From 265d7365bb1e1b874751eba77cc0ce9ca3edf30d Mon Sep 17 00:00:00 2001 From: anzhengqi Date: Sat, 15 Apr 2023 15:34:35 +0800 Subject: [PATCH] modify seq2seq --- research/nlp/seq2seq/src/seq2seq_model/__init__.py | 4 +++- .../nlp/seq2seq/src/seq2seq_model/seq2seq_for_infer.py | 8 +------- .../nlp/seq2seq/src/seq2seq_model/seq2seq_for_train.py | 1 + 3 files changed, 5 insertions(+), 8 deletions(-) diff --git a/research/nlp/seq2seq/src/seq2seq_model/__init__.py b/research/nlp/seq2seq/src/seq2seq_model/__init__.py index 4292350c1..8a5a61845 100644 --- a/research/nlp/seq2seq/src/seq2seq_model/__init__.py +++ b/research/nlp/seq2seq/src/seq2seq_model/__init__.py @@ -19,6 +19,7 @@ from .seq2seq_for_train import Seq2seqTraining, LabelSmoothedCrossEntropyCriteri Seq2seqNetworkWithLoss, Seq2seqTrainOneStepWithLossScaleCell from .bleu_calculate import bleu_calculate from .seq2seq_for_infer_onnx import infer_onnx +from .seq2seq_for_infer import infer __all__ = [ "Seq2seqTraining", @@ -28,5 +29,6 @@ __all__ = [ "Seq2seqModel", "Seq2seqConfig", "bleu_calculate", - "infer_onnx" + "infer_onnx", + "infer" ] diff --git a/research/nlp/seq2seq/src/seq2seq_model/seq2seq_for_infer.py b/research/nlp/seq2seq/src/seq2seq_model/seq2seq_for_infer.py index 76fa2536c..3c8bdb38d 100644 --- a/research/nlp/seq2seq/src/seq2seq_model/seq2seq_for_infer.py +++ b/research/nlp/seq2seq/src/seq2seq_model/seq2seq_for_infer.py @@ -21,7 +21,7 @@ import mindspore.nn as nn import mindspore.common.dtype as mstype from mindspore.common.tensor import Tensor from mindspore.ops import operations as P -from mindspore import context, Parameter +from mindspore import Parameter from mindspore.train.model import Model from src.dataset import load_dataset @@ -29,12 +29,6 @@ from .seq2seq import Seq2seqModel from ..utils import zero_weight from ..utils.load_weights import load_infer_weights -context.set_context( - mode=context.GRAPH_MODE, - save_graphs=False, - device_target="Ascend", - reserve_class_name_in_scope=False) - class Seq2seqInferCell(nn.Cell): """ diff --git a/research/nlp/seq2seq/src/seq2seq_model/seq2seq_for_train.py b/research/nlp/seq2seq/src/seq2seq_model/seq2seq_for_train.py index 6330926e1..e5ac97131 100644 --- a/research/nlp/seq2seq/src/seq2seq_model/seq2seq_for_train.py +++ b/research/nlp/seq2seq/src/seq2seq_model/seq2seq_for_train.py @@ -132,6 +132,7 @@ class LabelSmoothedCrossEntropyCriterion(nn.Cell): """ prediction_scores = self.reshape(prediction_scores, (-1, self.vocab_size)) label_ids = self.reshape(label_ids, (-1, 1)) + label_ids = F.select(label_ids < self.vocab_size, label_ids, Tensor(0, mstype.int32)) label_weights = self.reshape(label_weights, (-1,)) tmp_gather_indices = self.concat((self.index_ids, label_ids)) nll_loss = self.neg(self.gather_nd(prediction_scores, tmp_gather_indices)) -- Gitee