From 5b5b5a4477ac5ca518a82e8022f3183521f17e58 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E4=BC=9F=E6=A0=B9?= <1101204667@qq.com> Date: Mon, 11 Apr 2022 01:59:33 +0000 Subject: [PATCH] update main.py. --- .../EfficientNet_for_PyTorch/examples/imagenet/main.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/PyTorch/built-in/cv/classification/EfficientNet_for_PyTorch/examples/imagenet/main.py b/PyTorch/built-in/cv/classification/EfficientNet_for_PyTorch/examples/imagenet/main.py index fb92fc1ac2..d8343ad748 100644 --- a/PyTorch/built-in/cv/classification/EfficientNet_for_PyTorch/examples/imagenet/main.py +++ b/PyTorch/built-in/cv/classification/EfficientNet_for_PyTorch/examples/imagenet/main.py @@ -32,7 +32,7 @@ import torch.utils.data.distributed import torchvision.transforms as transforms import torchvision.datasets as datasets import torchvision.models as models - +import apex from apex import amp sys.path.append(os.path.join(os.path.abspath(os.path.dirname(__file__)),'../../')) @@ -195,7 +195,7 @@ def main_worker(npu, nnpus_per_node, args): criterion = nn.CrossEntropyLoss().to('npu:' + str(args.npu)) - optimizer = torch.optim.SGD(model.parameters(), args.lr, + optimizer = apex.optimizers.NpuFusedSGD(model.parameters(), args.lr, momentum=args.momentum, weight_decay=args.weight_decay) model = model.to('npu:' + str(args.npu)) @@ -205,7 +205,7 @@ def main_worker(npu, nnpus_per_node, args): print('=>unsupported precision mode!') exit() opt_level = args.pm - model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level, loss_scale=args.loss_scale) + model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level, loss_scale=args.loss_scale,combine_grad=True) global total_batch_size total_batch_size = args.batch_size -- Gitee