diff --git a/PyTorch/built-in/nlp/Bert-Squad_ID0470_for_PyTorch/run_squad.py b/PyTorch/built-in/nlp/Bert-Squad_ID0470_for_PyTorch/run_squad.py index c46721df1c6d68d9543aefced7e6967d9a2cef63..6ca21e29d2d6c0960a91dfd336877d4d03e612d6 100644 --- a/PyTorch/built-in/nlp/Bert-Squad_ID0470_for_PyTorch/run_squad.py +++ b/PyTorch/built-in/nlp/Bert-Squad_ID0470_for_PyTorch/run_squad.py @@ -31,7 +31,10 @@ import numpy as np import torch if torch.__version__ >= "1.8": import torch_npu - torch.npu.config.allow_internal_format = False + if torch_npu.npu.utils.get_soc_version() == 103: + torch.npu.config.allow_internal_format = True + else: + torch.npu.config.allow_internal_format = False from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, TensorDataset) from torch.utils.data.distributed import DistributedSampler