diff --git a/PyTorch/built-in/cv/classification/EfficientNet_for_PyTorch/examples/imagenet/main.py b/PyTorch/built-in/cv/classification/EfficientNet_for_PyTorch/examples/imagenet/main.py
index fb92fc1ac25b88ec285a38f9e1eda6bd85b1a19d..d8343ad748952ac5faefb4dfde02761b7f5a66cc 100644
--- a/PyTorch/built-in/cv/classification/EfficientNet_for_PyTorch/examples/imagenet/main.py
+++ b/PyTorch/built-in/cv/classification/EfficientNet_for_PyTorch/examples/imagenet/main.py
@@ -32,7 +32,7 @@ import torch.utils.data.distributed
 import torchvision.transforms as transforms
 import torchvision.datasets as datasets
 import torchvision.models as models
-
+import apex
 from apex import amp
 
 sys.path.append(os.path.join(os.path.abspath(os.path.dirname(__file__)),'../../'))
@@ -195,7 +195,7 @@ def main_worker(npu, nnpus_per_node, args):
 
     criterion = nn.CrossEntropyLoss().to('npu:' + str(args.npu))
 
-    optimizer = torch.optim.SGD(model.parameters(), args.lr,
+    optimizer = apex.optimizers.NpuFusedSGD(model.parameters(), args.lr,
                                 momentum=args.momentum, weight_decay=args.weight_decay)
 
     model = model.to('npu:' + str(args.npu))
@@ -205,7 +205,7 @@ def main_worker(npu, nnpus_per_node, args):
         print('=>unsupported precision mode!')
         exit()
     opt_level = args.pm
-    model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level, loss_scale=args.loss_scale)
+    model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level, loss_scale=args.loss_scale,combine_grad=True)
 
     global total_batch_size
     total_batch_size = args.batch_size