diff --git a/PyTorch/built-in/nlp/XLM_ID0740_for_PyTorch/xlm/trainer.py b/PyTorch/built-in/nlp/XLM_ID0740_for_PyTorch/xlm/trainer.py
index 718cb7ed1463fa7a0564570b2002c8050572818a..f3cdb893cd3fc2977a6f91ec47eee3ce36ee1490 100644
--- a/PyTorch/built-in/nlp/XLM_ID0740_for_PyTorch/xlm/trainer.py
+++ b/PyTorch/built-in/nlp/XLM_ID0740_for_PyTorch/xlm/trainer.py
@@ -39,6 +39,7 @@ from .utils import to_cuda, concat_batches, find_modules
 from .utils import parse_lambda_config, update_lambdas
 from .model.memory import HashingMemory
 from .model.transformer import TransformerFFN
+import torch.distributed as dist
 
 logger = getLogger()
 
@@ -624,7 +625,16 @@ class Trainer(object):
                 logger.info("Not a better validation score (%i / %i)."
                             % (self.decrease_counts, self.decrease_counts_max))
                 self.decrease_counts += 1
-            if self.decrease_counts > self.decrease_counts_max:
+
+            if self.params.multi_npu:
+                decrease_counts = torch.from_numpy(np.array(self.decrease_counts)).npu().half()
+                dist.all_reduce(decrease_counts, op=dist.ReduceOp.MAX)
+                dist.barrier()
+                decrease_counts = decrease_counts.int().item()
+            else:
+                decrease_counts = self.decrease_counts
+
+            if decrease_counts > self.decrease_counts_max:
                 logger.info("Stopping criterion has been below its best value for more "
                             "than %i epochs. Ending the experiment..." % self.decrease_counts_max)
                 if self.params.multi_npu and 'SLURM_JOB_ID' in os.environ: