From 9c9a7f16ee479855d21ae7aa92e36e02a102309f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E8=B5=B5=E7=84=B1?=
Date: Sat, 9 Apr 2022 01:48:09 +0000
Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E8=AE=AD=E7=BB=83=E7=BB=93?=
 =?UTF-8?q?=E6=9D=9F=E6=8E=A8=E5=87=BA=E5=BC=82=E5=B8=B8=E9=97=AE=E9=A2=98?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../nlp/XLM_ID0740_for_PyTorch/xlm/trainer.py | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/PyTorch/built-in/nlp/XLM_ID0740_for_PyTorch/xlm/trainer.py b/PyTorch/built-in/nlp/XLM_ID0740_for_PyTorch/xlm/trainer.py
index 718cb7ed14..f3cdb893cd 100644
--- a/PyTorch/built-in/nlp/XLM_ID0740_for_PyTorch/xlm/trainer.py
+++ b/PyTorch/built-in/nlp/XLM_ID0740_for_PyTorch/xlm/trainer.py
@@ -39,6 +39,7 @@ from .utils import to_cuda, concat_batches, find_modules
 from .utils import parse_lambda_config, update_lambdas
 from .model.memory import HashingMemory
 from .model.transformer import TransformerFFN
+import torch.distributed as dist
 
 logger = getLogger()
 
@@ -624,7 +625,16 @@ class Trainer(object):
                 logger.info("Not a better validation score (%i / %i)."
                             % (self.decrease_counts, self.decrease_counts_max))
                 self.decrease_counts += 1
-            if self.decrease_counts > self.decrease_counts_max:
+
+            if self.params.multi_npu:
+                decrease_counts = torch.from_numpy(np.array(self.decrease_counts)).npu().half()
+                dist.all_reduce(decrease_counts, op=dist.ReduceOp.MAX)
+                dist.barrier()
+                decrease_counts = decrease_counts.int().item()
+            else:
+                decrease_counts = self.decrease_counts
+
+            if decrease_counts > self.decrease_counts_max:
                 logger.info("Stopping criterion has been below its best value for more "
                             "than %i epochs. Ending the experiment..." % self.decrease_counts_max)
                 if self.params.multi_npu and 'SLURM_JOB_ID' in os.environ:
-- 
Gitee