From 976706c6d0748379eac0b817553ddb15716c27d7 Mon Sep 17 00:00:00 2001 From: 13848007649 <410039586@qq.com> Date: Thu, 28 Aug 2025 09:58:10 +0800 Subject: [PATCH] =?UTF-8?q?=E6=96=B0=E5=A2=9Eopus=E7=BF=BB=E8=AF=91?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- smartjavaai-translate/pom.xml | 5 + .../translation/config/OpusSearchConfig.java | 63 +++ .../entity/BeamBatchTensorList.java | 61 +++ .../translation/entity/TranslateParam.java | 10 + .../enums/TranslationModeEnum.java | 6 +- .../factory/TranslationModelFactory.java | 20 +- .../translation/model/BeamHypotheses.java | 121 +++++ .../translation/model/BeamSearchScorer.java | 118 +++++ .../translation/model/NllbModel.java | 4 +- .../translation/model/OpusMtModel.java | 442 ++++++++++++++++++ .../translation/model/TranslationModel.java | 9 + .../translator/NllbDecoder2Translator.java | 2 +- .../translator/NllbDecoderTranslator.java | 2 +- .../translation/utils/NDArrayUtils.java | 28 ++ .../translation/utils/TokenUtils.java | 28 ++ 15 files changed, 906 insertions(+), 13 deletions(-) create mode 100644 smartjavaai-translate/src/main/java/cn/smartjavaai/translation/config/OpusSearchConfig.java create mode 100644 smartjavaai-translate/src/main/java/cn/smartjavaai/translation/entity/BeamBatchTensorList.java create mode 100644 smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/BeamHypotheses.java create mode 100644 smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/BeamSearchScorer.java create mode 100644 smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/OpusMtModel.java create mode 100644 smartjavaai-translate/src/main/java/cn/smartjavaai/translation/utils/NDArrayUtils.java diff --git a/smartjavaai-translate/pom.xml b/smartjavaai-translate/pom.xml index 5c70904..8c53b1d 100644 --- a/smartjavaai-translate/pom.xml +++ b/smartjavaai-translate/pom.xml @@ -18,6 +18,11 @@ smartjavaai-common ${project.version} + + + ai.djl.sentencepiece + sentencepiece + 1.0.23 diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/config/OpusSearchConfig.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/config/OpusSearchConfig.java new file mode 100644 index 0000000..dc2a881 --- /dev/null +++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/config/OpusSearchConfig.java @@ -0,0 +1,63 @@ +package cn.smartjavaai.translation.config; +/** + * 配置信息 + * + * @author Calvin + * @mail 179209347@qq.com + * @website www.aias.top + */ +public class OpusSearchConfig { + private int maxSeqLength; + private long padTokenId; + private long eosTokenId; + private int beam; + private boolean suffixPadding; + + public OpusSearchConfig() { + this.eosTokenId = 0; + this.padTokenId = 65000; + this.maxSeqLength = 512; + this.beam = 6; + } + + + public void setEosTokenId(long eosTokenId) { + this.eosTokenId = eosTokenId; + } + + public int getMaxSeqLength() { + return maxSeqLength; + } + + public void setMaxSeqLength(int maxSeqLength) { + this.maxSeqLength = maxSeqLength; + } + + public long getPadTokenId() { + return padTokenId; + } + + public void setPadTokenId(long padTokenId) { + this.padTokenId = padTokenId; + } + + public long getEosTokenId() { + return eosTokenId; + } + + public int getBeam() { + return beam; + } + + public void setBeam(int beam) { + this.beam = beam; + } + + public boolean isSuffixPadding() { + return suffixPadding; + } + + public void setSuffixPadding(boolean suffixPadding) { + this.suffixPadding = suffixPadding; + } +} diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/entity/BeamBatchTensorList.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/entity/BeamBatchTensorList.java new file mode 100644 index 0000000..2c40021 --- /dev/null +++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/entity/BeamBatchTensorList.java @@ -0,0 +1,61 @@ +package cn.smartjavaai.translation.entity; + +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; + +/** + * beam 搜索张量对象列表 + * + * @author Calvin + * @mail 179209347@qq.com + * @website www.aias.top + */ +public class BeamBatchTensorList { + private NDArray nextInputIds; + private NDArray encoderHiddenStates; + private NDArray attentionMask; + private NDList pastKeyValues; + + + public BeamBatchTensorList() { + } + + public BeamBatchTensorList(NDArray nextInputIds, NDArray attentionMask, NDArray encoderHiddenStates, NDList pastKeyValues) { + this.nextInputIds = nextInputIds; + this.attentionMask = attentionMask; + this.pastKeyValues = pastKeyValues; + this.encoderHiddenStates = encoderHiddenStates; + } + + public NDArray getNextInputIds() { + return nextInputIds; + } + + public void setNextInputIds(NDArray nextInputIds) { + this.nextInputIds = nextInputIds; + } + + public NDArray getEncoderHiddenStates() { + return encoderHiddenStates; + } + + public void setEncoderHiddenStates(NDArray encoderHiddenStates) { + this.encoderHiddenStates = encoderHiddenStates; + } + + public NDArray getAttentionMask() { + return attentionMask; + } + + public void setAttentionMask(NDArray attentionMask) { + this.attentionMask = attentionMask; + } + + public NDList getPastKeyValues() { + return pastKeyValues; + } + + public void setPastKeyValues(NDList pastKeyValues) { + this.pastKeyValues = pastKeyValues; + } +} diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/entity/TranslateParam.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/entity/TranslateParam.java index 6d426c9..aac5b41 100644 --- a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/entity/TranslateParam.java +++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/entity/TranslateParam.java @@ -45,6 +45,16 @@ public class TranslateParam { return R.ok(null); } + public TranslateParam(String input, LanguageCode sourceLanguage, LanguageCode targetLanguage) { + this.input = input; + this.sourceLanguage = sourceLanguage; + this.targetLanguage = targetLanguage; + } + public TranslateParam(String input) { + this.input = input; + } + public TranslateParam() { + } } diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/enums/TranslationModeEnum.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/enums/TranslationModeEnum.java index e558a48..c23b0b3 100644 --- a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/enums/TranslationModeEnum.java +++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/enums/TranslationModeEnum.java @@ -6,7 +6,11 @@ package cn.smartjavaai.translation.enums; */ public enum TranslationModeEnum { - NLLB_MODEL; + NLLB_MODEL, + + OPUS_MT_ZH_EN, + + OPUS_MT_EN_ZH; /** * 根据名称获取枚举 (忽略大小写和下划线变体) diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/factory/TranslationModelFactory.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/factory/TranslationModelFactory.java index 5ca6797..e13ad9e 100644 --- a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/factory/TranslationModelFactory.java +++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/factory/TranslationModelFactory.java @@ -4,8 +4,10 @@ import cn.smartjavaai.common.config.Config; import cn.smartjavaai.translation.config.TranslationModelConfig; +import cn.smartjavaai.translation.enums.TranslationModeEnum; import cn.smartjavaai.translation.exception.TranslationException; import cn.smartjavaai.translation.model.NllbModel; +import cn.smartjavaai.translation.model.OpusMtModel; import cn.smartjavaai.translation.model.TranslationModel; import lombok.extern.slf4j.Slf4j; @@ -23,14 +25,14 @@ public class TranslationModelFactory { // 使用 volatile 和双重检查锁定来确保线程安全的单例模式 private static volatile TranslationModelFactory instance; - private static final ConcurrentHashMap modelMap = new ConcurrentHashMap<>(); + private static final ConcurrentHashMap modelMap = new ConcurrentHashMap<>(); /** * 检测模型注册表 */ - private static final Map> modelRegistry = + private static final Map> modelRegistry = new ConcurrentHashMap<>(); @@ -49,11 +51,11 @@ public class TranslationModelFactory { /** * 注册翻译模型 - * @param name + * @param translationModeEnum * @param clazz */ - private static void registerCommonDetModel(String name, Class clazz) { - modelRegistry.put(name.toLowerCase(), clazz); + private static void registerCommonDetModel(TranslationModeEnum translationModeEnum, Class clazz) { + modelRegistry.put(translationModeEnum, clazz); } /** @@ -65,7 +67,7 @@ public class TranslationModelFactory { if(Objects.isNull(config) || Objects.isNull(config.getModelEnum())){ throw new TranslationException("未配置OCR模型"); } - return modelMap.computeIfAbsent(config.getModelEnum().name(), k -> { + return modelMap.computeIfAbsent(config.getModelEnum(), k -> { return createModel(config); }); } @@ -77,7 +79,7 @@ public class TranslationModelFactory { * @return */ private TranslationModel createModel(TranslationModelConfig config) { - Class clazz = modelRegistry.get(config.getModelEnum().name().toLowerCase()); + Class clazz = modelRegistry.get(config.getModelEnum()); if(clazz == null){ throw new TranslationException("Unsupported model"); } @@ -94,7 +96,9 @@ public class TranslationModelFactory { // 初始化默认算法 static { - registerCommonDetModel("NLLB_MODEL", NllbModel.class); + registerCommonDetModel(TranslationModeEnum.NLLB_MODEL, NllbModel.class); + registerCommonDetModel(TranslationModeEnum.OPUS_MT_EN_ZH, OpusMtModel.class); + registerCommonDetModel(TranslationModeEnum.OPUS_MT_ZH_EN, OpusMtModel.class); log.debug("缓存目录:{}", Config.getCachePath()); } diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/BeamHypotheses.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/BeamHypotheses.java new file mode 100644 index 0000000..6ae5ce4 --- /dev/null +++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/BeamHypotheses.java @@ -0,0 +1,121 @@ +package cn.smartjavaai.translation.model; + +import ai.djl.util.Pair; + +import java.util.ArrayList; + +/** + * Beam hypothesis + * + * @author Calvin + * @mail 179209347@qq.com + * @website www.aias.top + */ +public class BeamHypotheses { + float length_penalty; + boolean early_stopping; + int num_beams; + ArrayList> beams; + float worst_score = 1e9f; + + public BeamHypotheses(float length_penalty, boolean early_stopping, int num_beams) { + this.length_penalty = length_penalty; + this.early_stopping = early_stopping; + this.num_beams = num_beams; + beams = new ArrayList<>(); + } + + /** + * Get length + * + * @return + */ + public int getLen() { + return beams.size(); + } + + /** + * Add a new hypothesis to the list. + * + * @param sum_logprobs + * @param hyp + */ + public void add(float sum_logprobs, long[] hyp) { + float score = sum_logprobs / (float) (Math.pow(hyp.length, this.length_penalty)); + + if (getLen() < this.num_beams || score > this.worst_score) { + this.beams.add(new Pair<>(score, hyp)); + if (getLen() > this.num_beams) { + int index = min(); + this.beams.remove(index); + index = min(); + this.worst_score = this.beams.get(index).getKey(); + }else { + this.worst_score = Math.min(score, this.worst_score); + } + } + } + + /** + * Get Pair + * @param index + * @return + */ + public Pair getPair(int index) { + return beams.get(index); + } + + /** + * Get index for minmum score value + * + * @return + */ + public int min() { + float min = beams.get(0).getKey(); + int index = 0; + for (int i = 1; i < beams.size(); ++i) { + if (beams.get(i).getKey() < min) { + min = beams.get(i).getKey(); + index = i; + } + } + return index; + } + + /** + * Get index for maximum score value + * @return + */ + public int max() { + float max = beams.get(0).getKey(); + int index = 0; + for (int i = 1; i < beams.size(); ++i) { + if (beams.get(i).getKey() > max) { + max = beams.get(i).getKey(); + index = i; + } + } + return index; + } + + /** + * If there are enough hypotheses and that none of the hypotheses being generated can become better than the worst + * one in the heap, then we are done with this sentence. + * + * @param best_sum_logprobs + * @param cur_len + * @return + */ + public boolean isDone(float best_sum_logprobs, long cur_len) { + if (getLen() < this.num_beams) + return false; + + if (this.early_stopping) + return true; + else { + float highest_attainable_score = best_sum_logprobs / (float) Math.pow(cur_len, this.length_penalty); + boolean ret = (this.worst_score >= highest_attainable_score); + return ret; + } + } +} diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/BeamSearchScorer.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/BeamSearchScorer.java new file mode 100644 index 0000000..48ed946 --- /dev/null +++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/BeamSearchScorer.java @@ -0,0 +1,118 @@ +package cn.smartjavaai.translation.model; + +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.index.NDIndex; +import ai.djl.ndarray.types.Shape; +import ai.djl.util.Pair; + +/** + * Implementing standard beam search decoding. + * + * @author Calvin + * @mail 179209347@qq.com + * @website www.aias.top + */ +public class BeamSearchScorer { + private int num_beams; + private float length_penalty; + private boolean do_early_stopping; + private int num_beam_hyps_to_keep; + private int num_beam_groups; + private BeamHypotheses beam_hyp; + private boolean _done; + + public BeamSearchScorer(int num_beams, float length_penalty, boolean do_early_stopping, int num_beam_hyps_to_keep, int num_beam_groups) { + this.num_beams = num_beams; + this.length_penalty = length_penalty; + this.do_early_stopping = do_early_stopping; + this.num_beam_hyps_to_keep = num_beam_hyps_to_keep; + this.num_beam_groups = num_beam_groups; + beam_hyp = new BeamHypotheses(length_penalty, do_early_stopping, num_beams); + _done = false; + } + + public boolean isDone() { + return _done; + } + + public NDList process(NDManager manager, NDArray input_ids, NDArray next_scores, NDArray next_tokens, NDArray next_indices, long pad_token_id, long eos_token_id) { + + float[] next_scores_arr = next_scores.toFloatArray(); + long[] next_indices_arr = next_indices.toLongArray(); + long[] next_tokens_arr = next_tokens.toLongArray(); + + NDArray next_beam_scores = manager.zeros(new Shape(1, this.num_beams), next_scores.getDataType()); + NDArray next_beam_tokens = manager.zeros(new Shape(1, this.num_beams), next_tokens.getDataType()); + NDArray next_beam_indices = manager.zeros(new Shape(1, this.num_beams), next_indices.getDataType()); + + + // next tokens for this sentence + int beam_idx = 0; + float maxScore = Float.NEGATIVE_INFINITY; + for (int i = 0; i < next_scores_arr.length; ++i) { + int beam_token_rank = i; + long next_token = next_tokens_arr[i]; + float next_score = next_scores_arr[i]; + if (maxScore < next_score) { + maxScore = next_score; + } + long next_index = next_indices_arr[i]; + + long batch_beam_idx = next_index; + + // add to generated hypotheses if end of sentence + if (next_token == eos_token_id) { + // if beam_token does not belong to top num_beams tokens, it should not be added + if (beam_token_rank >= this.num_beams) + continue; + long[] arr = input_ids.get(batch_beam_idx).toLongArray(); + // Add a new hypothesis to the list. + beam_hyp.add(next_score, arr); + } else { + // add next predicted token since it is not eos_token + next_beam_scores.set(new NDIndex(0, beam_idx), next_score); + next_beam_tokens.set(new NDIndex(0, beam_idx), next_token); + next_beam_indices.set(new NDIndex(0, beam_idx), batch_beam_idx); + beam_idx += 1; + } + + // once the beam for next step is full, don't add more tokens to it. + if (beam_idx == this.num_beams) + break; + } + + long cur_len = input_ids.getShape().getLastDimension(); + this._done = this._done || beam_hyp.isDone(maxScore, cur_len); + + NDList list = new NDList(); + list.add(next_beam_scores); + list.add(next_beam_tokens); + list.add(next_beam_indices); + + return list; + } + + public long[] finalize(int max_length, long eos_token_id) { + + // best_hyp_tuple + Pair pair = beam_hyp.getPair(beam_hyp.max()); + float best_score = pair.getKey(); + long[] best_hyp = pair.getValue(); + int sent_length = best_hyp.length; + + // prepare for adding eos + int sent_max_len = Math.min(sent_length + 1, max_length); + long[] decodedArr = new long[sent_max_len]; + + for (int i = 0; i < sent_length; ++i) { + decodedArr[i] = best_hyp[i]; + } + if (sent_length < max_length) { + decodedArr[sent_length] = eos_token_id; + } + + return decodedArr; + } +} diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/NllbModel.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/NllbModel.java index 71653f7..1b9f22b 100644 --- a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/NllbModel.java +++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/NllbModel.java @@ -6,6 +6,7 @@ import ai.djl.engine.Engine; import ai.djl.huggingface.tokenizers.Encoding; import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer; import ai.djl.inference.Predictor; +import ai.djl.modality.nlp.generate.CausalLMOutput; import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDList; import ai.djl.ndarray.NDManager; @@ -20,7 +21,6 @@ import cn.smartjavaai.common.enums.DeviceEnum; import cn.smartjavaai.common.pool.CommonPredictorFactory; import cn.smartjavaai.translation.config.TranslationModelConfig; import cn.smartjavaai.translation.config.NllbSearchConfig; -import cn.smartjavaai.translation.entity.CausalLMOutput; import cn.smartjavaai.translation.entity.GreedyBatchTensorList; import cn.smartjavaai.translation.entity.TranslateParam; import cn.smartjavaai.translation.exception.TranslationException; @@ -40,7 +40,7 @@ import java.nio.file.Paths; import java.util.Objects; /** - * 机器翻译通用检测模型 + * Nllb机器翻译模型 * * @author lwx * @date 2025/6/05 diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/OpusMtModel.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/OpusMtModel.java new file mode 100644 index 0000000..60f6608 --- /dev/null +++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/OpusMtModel.java @@ -0,0 +1,442 @@ +package cn.smartjavaai.translation.model; + +import ai.djl.Device; +import ai.djl.MalformedModelException; +import ai.djl.engine.Engine; +import ai.djl.huggingface.tokenizers.Encoding; +import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer; +import ai.djl.inference.Predictor; +import ai.djl.modality.cv.Image; +import ai.djl.modality.cv.output.DetectedObjects; +import ai.djl.modality.nlp.generate.CausalLMOutput; +import ai.djl.modality.nlp.generate.SearchConfig; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDArrays; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.index.NDIndex; +import ai.djl.ndarray.types.DataType; +import ai.djl.ndarray.types.Shape; +import ai.djl.repository.zoo.Criteria; +import ai.djl.repository.zoo.ModelNotFoundException; +import ai.djl.repository.zoo.ModelZoo; +import ai.djl.repository.zoo.ZooModel; +import ai.djl.sentencepiece.SpTokenizer; +import ai.djl.translate.NoopTranslator; +import ai.djl.translate.TranslateException; +import ai.djl.util.Utils; +import cn.smartjavaai.common.entity.R; +import cn.smartjavaai.common.enums.DeviceEnum; +import cn.smartjavaai.common.pool.CommonPredictorFactory; +import cn.smartjavaai.translation.config.NllbSearchConfig; +import cn.smartjavaai.translation.config.OpusSearchConfig; +import cn.smartjavaai.translation.config.TranslationModelConfig; +import cn.smartjavaai.translation.entity.BeamBatchTensorList; +import cn.smartjavaai.translation.entity.GreedyBatchTensorList; +import cn.smartjavaai.translation.entity.TranslateParam; +import cn.smartjavaai.translation.exception.TranslationException; +import cn.smartjavaai.translation.model.translator.NllbDecoder2Translator; +import cn.smartjavaai.translation.model.translator.NllbDecoderTranslator; +import cn.smartjavaai.translation.model.translator.NllbEncoderTranslator; +import cn.smartjavaai.translation.model.translator.opus.Decoder2Translator; +import cn.smartjavaai.translation.model.translator.opus.DecoderTranslator; +import cn.smartjavaai.translation.model.translator.opus.EncoderTranslator; +import cn.smartjavaai.translation.utils.NDArrayUtils; +import cn.smartjavaai.translation.utils.TokenUtils; +import com.google.gson.Gson; +import com.google.gson.reflect.TypeToken; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; +import org.apache.commons.pool2.impl.GenericObjectPool; + +import java.io.IOException; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.*; +import java.util.concurrent.ConcurrentHashMap; + +/** + * OpusMt机器翻译模型 + * + * @author dwj + */ +@Slf4j +public class OpusMtModel implements TranslationModel{ + + private GenericObjectPool> encodePredictorPool; + + private GenericObjectPool> decodePredictorPool; + + private GenericObjectPool> decode2PredictorPool; + + private ZooModel model; + private SpTokenizer sourceTokenizer; + + private OpusSearchConfig searchConfig; + private TranslationModelConfig config; + + private ConcurrentHashMap map; + + private ConcurrentHashMap reverseMap; + + private float length_penalty = 1.0f; + private boolean do_early_stopping = false; + private int num_beam_hyps_to_keep = 1; + private int num_beam_groups = 1; + + + + @Override + public void loadModel(TranslationModelConfig config) { + if (StringUtils.isBlank(config.getModelPath())) { + throw new TranslationException("modelPath is null"); + } + Device device = null; + if (!Objects.isNull(config.getDevice())) { + device = config.getDevice() == DeviceEnum.CPU ? Device.cpu() : Device.gpu(config.getGpuId()); + } + this.config = config; + Path modelPath = Paths.get(config.getModelPath()); + Criteria criteria = + Criteria.builder() + .setTypes(NDList.class, NDList.class) + .optModelPath(modelPath) + .optEngine("PyTorch") + .optDevice(device) + .optTranslator(new NoopTranslator()) + .build(); + try { + model = ModelZoo.loadModel(criteria); + encodePredictorPool = new GenericObjectPool<>(new CommonPredictorFactory(model,new EncoderTranslator())); + decodePredictorPool = new GenericObjectPool<>(new CommonPredictorFactory(model,new DecoderTranslator())); + decode2PredictorPool = new GenericObjectPool<>(new CommonPredictorFactory(model,new Decoder2Translator())); + + Path tokenizerPath = modelPath.getParent().resolve("source.spm"); + sourceTokenizer = new SpTokenizer(tokenizerPath); + List words = Utils.readLines(modelPath.getParent().resolve("vocab.txt")); + String jsonStr = ""; + for (String line : words) { + jsonStr = jsonStr + line; + } + map = new Gson().fromJson(jsonStr, new TypeToken>() { + }.getType()); + reverseMap = new ConcurrentHashMap<>(); + Iterator it = map.entrySet().iterator(); + while (it.hasNext()) { + Map.Entry next = (Map.Entry) it.next(); + reverseMap.put(next.getValue(), next.getKey()); + } + //初始化searchConfig + this.searchConfig = new OpusSearchConfig(); + int predictorPoolSize = config.getPredictorPoolSize(); + if(config.getPredictorPoolSize() <= 0){ + predictorPoolSize = Runtime.getRuntime().availableProcessors(); // 默认等于CPU核心数 + } + encodePredictorPool.setMaxTotal(predictorPoolSize); + decodePredictorPool.setMaxTotal(predictorPoolSize); + decode2PredictorPool.setMaxTotal(predictorPoolSize); + log.debug("当前设备: " + model.getNDManager().getDevice()); + log.debug("当前引擎: " + Engine.getInstance().getEngineName()); + log.debug("模型推理器线程池最大数量: " + predictorPoolSize); + } catch (IOException | ModelNotFoundException | MalformedModelException e) { + throw new TranslationException("模型加载失败", e); + } + } + + @Override + public R translate(TranslateParam translateParam) { + if(translateParam == null){ + return R.fail(R.Status.PARAM_ERROR); + } + //验证 + if (StringUtils.isBlank(translateParam.getInput())) { + return R.fail(R.Status.PARAM_ERROR.getCode(), "输入文本不能为空"); + } + return R.ok(translateLanguage(translateParam)); + } + + @Override + public R translate(String input) { + //验证 + if (StringUtils.isBlank(input)) { + return R.fail(R.Status.PARAM_ERROR.getCode(), "输入文本不能为空"); + } + return R.ok(translateLanguage(new TranslateParam(input))); + } + + private String translateLanguage(TranslateParam translateParam) { + try (NDManager manager = NDManager.newBaseManager()) { + long numBeam = searchConfig.getBeam(); + BeamSearchScorer beamSearchScorer = new BeamSearchScorer((int) numBeam, length_penalty, do_early_stopping, num_beam_hyps_to_keep, num_beam_groups); + // 1. Encode + List tokens = sourceTokenizer.tokenize(translateParam.getInput()); + String[] strs = tokens.toArray(new String[]{}); + log.info("Tokens: " + Arrays.toString(strs)); + int[] sourceIds = new int[tokens.size() + 1]; + sourceIds[tokens.size()] = 0; + for (int i = 0; i < tokens.size(); i++) { + sourceIds[i] = map.get(tokens.get(i)).intValue(); + } + NDArray encoder_hidden_states = encoder(sourceIds); + encoder_hidden_states = NDArrayUtils.expand(encoder_hidden_states, searchConfig.getBeam()); + + NDArray decoder_input_ids = manager.create(new long[]{65000}).reshape(1, 1); + decoder_input_ids = NDArrayUtils.expand(decoder_input_ids, numBeam); + + + long[] attentionMask = new long[sourceIds.length]; + Arrays.fill(attentionMask, 1); + NDArray attentionMaskArray = manager.create(attentionMask).expandDims(0); + NDArray new_attention_mask = NDArrayUtils.expand(attentionMaskArray, searchConfig.getBeam()); + NDList decoderInput = new NDList(decoder_input_ids, encoder_hidden_states, new_attention_mask); + + + // 2. Initial Decoder + CausalLMOutput modelOutput = decoder(decoderInput); + modelOutput.getLogits().attach(manager); + modelOutput.getPastKeyValuesList().attach(manager); + + NDArray beam_scores = manager.zeros(new Shape(1, numBeam), DataType.FLOAT32); + beam_scores.set(new NDIndex(":, 1:"), -1e9); + beam_scores = beam_scores.reshape(numBeam, 1); + + NDArray input_ids = decoder_input_ids; + BeamBatchTensorList searchState = new BeamBatchTensorList(null, new_attention_mask, encoder_hidden_states, modelOutput.getPastKeyValuesList()); + NDArray next_tokens; + NDArray next_indices; + while (true) { + if (searchState.getNextInputIds() != null) { + decoder_input_ids = searchState.getNextInputIds().get(new NDIndex(":, -1:")); + decoderInput = new NDList(decoder_input_ids, searchState.getEncoderHiddenStates(), searchState.getAttentionMask()); + decoderInput.addAll(searchState.getPastKeyValues()); + // 3. Decoder loop + modelOutput = decoder2(decoderInput); + } + + NDArray next_token_logits = modelOutput.getLogits().get(":, -1, :"); + + // hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id` + // cannot be generated both before and after the `nn.functional.log_softmax` operation. + NDArray new_next_token_logits = manager.create(next_token_logits.getShape(), next_token_logits.getDataType()); + next_token_logits.copyTo(new_next_token_logits); + new_next_token_logits.set(new NDIndex(":," + searchConfig.getPadTokenId()), Float.NEGATIVE_INFINITY); + + NDArray next_token_scores = new_next_token_logits.logSoftmax(1); + + // next_token_scores = logits_processor(input_ids, next_token_scores) + // 1. NoBadWordsLogitsProcessor + next_token_scores.set(new NDIndex(":," + searchConfig.getPadTokenId()), Float.NEGATIVE_INFINITY); + + // 2. MinLengthLogitsProcessor 没生效 + // 3. ForcedEOSTokenLogitsProcessor + long cur_len = input_ids.getShape().getLastDimension(); + if (cur_len == (searchConfig.getMaxSeqLength() - 1)) { + long num_tokens = next_token_scores.getShape().getLastDimension(); + for (long i = 0; i < num_tokens; i++) { + if(i != searchConfig.getEosTokenId()){ + next_token_scores.set(new NDIndex(":," + i), Float.NEGATIVE_INFINITY); + } + } + next_token_scores.set(new NDIndex(":," + searchConfig.getEosTokenId()), 0); + } + + long vocab_size = next_token_scores.getShape().getLastDimension(); + beam_scores = beam_scores.repeat(1, vocab_size); + next_token_scores = next_token_scores.add(beam_scores); + + // reshape for beam search + next_token_scores = next_token_scores.reshape(1, numBeam * vocab_size); + + // [batch, beam] + NDList topK = next_token_scores.topK(Math.toIntExact(numBeam) * 2, 1, true, true); + + next_token_scores = topK.get(0); + next_tokens = topK.get(1); + + // next_indices = next_tokens // vocab_size + next_indices = next_tokens.div(vocab_size).toType(DataType.INT64, true); + + // next_tokens = next_tokens % vocab_size + next_tokens = next_tokens.mod(vocab_size); + + // stateless + NDList beam_outputs = beamSearchScorer.process(manager, input_ids, next_token_scores, next_tokens, next_indices, searchConfig.getPadTokenId(), searchConfig.getEosTokenId()); + + beam_scores = beam_outputs.get(0).reshape(numBeam, 1); + NDArray beam_next_tokens = beam_outputs.get(1); + NDArray beam_idx = beam_outputs.get(2); + + // input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) + long[] beam_next_tokens_arr = beam_next_tokens.toLongArray(); + long[] beam_idx_arr = beam_idx.toLongArray(); + NDList inputList = new NDList(); + for (int i = 0; i < numBeam; i++) { + long index = beam_idx_arr[i]; + NDArray ndArray = input_ids.get(index).reshape(1, input_ids.getShape().getLastDimension()); + ndArray = ndArray.concat(manager.create(beam_next_tokens_arr[i]).reshape(1, 1), 1); + inputList.add(ndArray); + } + input_ids = NDArrays.concat(inputList, 0); + searchState.setNextInputIds(input_ids); + searchState.setPastKeyValues(modelOutput.getPastKeyValuesList()); + + boolean maxLengthCriteria = (input_ids.getShape().getLastDimension() >= searchConfig.getMaxSeqLength()); + if (beamSearchScorer.isDone() || maxLengthCriteria) { + break; + } + + } + + long[] sequences = beamSearchScorer.finalize(searchConfig.getMaxSeqLength(), searchConfig.getEosTokenId()); + String result = TokenUtils.decode(reverseMap, sequences); + return result; + } catch (Exception e) { + throw new TranslationException("翻译错误", e); + } + } + + public NDArray encoder(int[] ids) { + Predictor predictor = null; + try { + predictor = (Predictor)encodePredictorPool.borrowObject(); + return predictor.predict(ids); + } catch (Exception e) { + throw new TranslationException("机器翻译编码错误", e); + }finally { + if (predictor != null) { + try { + encodePredictorPool.returnObject(predictor); //归还 + } catch (Exception e) { + log.warn("归还Predictor失败", e); + try { + predictor.close(); // 归还失败才销毁 + } catch (Exception ex) { + log.error("关闭Predictor失败", ex); + } + } + } + } + } + + public CausalLMOutput decoder(NDList input) throws TranslateException { + Predictor predictor = null; + try { + predictor = (Predictor)decodePredictorPool.borrowObject(); + return predictor.predict(input); + } catch (Exception e) { + throw new TranslationException("机器翻译编码错误", e); + }finally { + if (predictor != null) { + try { + decodePredictorPool.returnObject(predictor); //归还 + } catch (Exception e) { + log.warn("归还Predictor失败", e); + try { + predictor.close(); // 归还失败才销毁 + } catch (Exception ex) { + log.error("关闭Predictor失败", ex); + } + } + } + } + } + + public CausalLMOutput decoder2(NDList input) throws TranslateException { + Predictor predictor = null; + try { + predictor = (Predictor)decode2PredictorPool.borrowObject(); + return predictor.predict(input); + } catch (Exception e) { + throw new TranslationException("机器翻译编码错误", e); + }finally { + if (predictor != null) { + try { + decode2PredictorPool.returnObject(predictor); //归还 + } catch (Exception e) { + log.warn("归还Predictor失败", e); + try { + predictor.close(); // 归还失败才销毁 + } catch (Exception ex) { + log.error("关闭Predictor失败", ex); + } + } + } + } + } + + public NDArray greedyStepGen(NllbSearchConfig config, NDArray pastOutputIds, NDArray next_token_scores, NDManager manager) { + next_token_scores = next_token_scores.get(":, -1, :"); + + NDArray new_next_token_scores = manager.create(next_token_scores.getShape(), next_token_scores.getDataType()); + next_token_scores.copyTo(new_next_token_scores); + + // LogitsProcessor 1. ForcedBOSTokenLogitsProcessor + // 设置目标语言 + long cur_len = pastOutputIds.getShape().getLastDimension(); + if (cur_len == 1) { + long num_tokens = new_next_token_scores.getShape().getLastDimension(); + for (long i = 0; i < num_tokens; i++) { + if (i != config.getForcedBosTokenId()) { + new_next_token_scores.set(new NDIndex(":," + i), Float.NEGATIVE_INFINITY); + } + } + new_next_token_scores.set(new NDIndex(":," + config.getForcedBosTokenId()), 0); + } + + NDArray probs = new_next_token_scores.softmax(-1); + NDArray next_tokens = probs.argMax(-1); + + return next_tokens.expandDims(0); + } + + public GenericObjectPool> getEncodePredictorPool() { + return encodePredictorPool; + } + + public GenericObjectPool> getDecodePredictorPool() { + return decodePredictorPool; + } + + public GenericObjectPool> getDecode2PredictorPool() { + return decode2PredictorPool; + } + + @Override + public void close() throws Exception { + try { + if (model != null) { + model.close(); + } + } catch (Exception e) { + log.warn("关闭 model 失败", e); + } + try { + if (sourceTokenizer != null) { + sourceTokenizer.close(); + } + } catch (Exception e) { + log.warn("关闭 tokenizer 失败", e); + } + try { + if (encodePredictorPool != null) { + encodePredictorPool.close(); + } + } catch (Exception e) { + log.warn("关闭 encodePredictorPool 失败", e); + } + try { + if (decodePredictorPool != null) { + decodePredictorPool.close(); + } + } catch (Exception e) { + log.warn("关闭 decodePredictorPool 失败", e); + } + try { + if (decode2PredictorPool != null) { + decode2PredictorPool.close(); + } + } catch (Exception e) { + log.warn("关闭 decode2PredictorPool 失败", e); + } + } +} diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/TranslationModel.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/TranslationModel.java index f0b35d2..a9d55ec 100644 --- a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/TranslationModel.java +++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/TranslationModel.java @@ -28,6 +28,15 @@ public interface TranslationModel extends AutoCloseable{ } + /** + * 机器翻译 + * @param input 输入文本 + * @return + */ + default R translate(String input) { + throw new UnsupportedOperationException("默认不支持该功能"); + } + diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/translator/NllbDecoder2Translator.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/translator/NllbDecoder2Translator.java index 6e18541..f98e4aa 100644 --- a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/translator/NllbDecoder2Translator.java +++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/translator/NllbDecoder2Translator.java @@ -1,10 +1,10 @@ package cn.smartjavaai.translation.model.translator; +import ai.djl.modality.nlp.generate.CausalLMOutput; import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDList; import ai.djl.translate.NoBatchifyTranslator; import ai.djl.translate.TranslatorContext; -import cn.smartjavaai.translation.entity.CausalLMOutput; /** diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/translator/NllbDecoderTranslator.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/translator/NllbDecoderTranslator.java index 7867c87..717b49c 100644 --- a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/translator/NllbDecoderTranslator.java +++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/translator/NllbDecoderTranslator.java @@ -1,10 +1,10 @@ package cn.smartjavaai.translation.model.translator; +import ai.djl.modality.nlp.generate.CausalLMOutput; import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDList; import ai.djl.translate.NoBatchifyTranslator; import ai.djl.translate.TranslatorContext; -import cn.smartjavaai.translation.entity.CausalLMOutput; /** * 解碼器,參數沒有 pastKeyValues diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/utils/NDArrayUtils.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/utils/NDArrayUtils.java new file mode 100644 index 0000000..f159698 --- /dev/null +++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/utils/NDArrayUtils.java @@ -0,0 +1,28 @@ +package cn.smartjavaai.translation.utils; + +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDArrays; +import ai.djl.ndarray.NDList; + +/** + * NDArray 工具类 + * + * @author Calvin + * @mail 179209347@qq.com + * @website www.aias.top + */ +public final class NDArrayUtils { + + private NDArrayUtils() { + } + + public static NDArray expand(NDArray array, long beam) { + NDList list = new NDList(); + for (long i = 0; i < beam; i++) { + list.add(array); + } + NDArray result = NDArrays.concat(list, 0); + + return result; + } +} diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/utils/TokenUtils.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/utils/TokenUtils.java index d613432..f19dc4b 100644 --- a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/utils/TokenUtils.java +++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/utils/TokenUtils.java @@ -6,6 +6,8 @@ import cn.smartjavaai.translation.config.NllbSearchConfig; import java.util.ArrayList; +import java.util.Arrays; +import java.util.Map; /** * @@ -44,4 +46,30 @@ public final class TokenUtils { String text = tokenizer.decode(ids); return text; } + + /** + * Token 解码 + * 根据语言的类型更新下面的方法 + * + * @param reverseMap + * @param outputIds + * @return + */ + public static String decode(Map reverseMap, long[] outputIds) { + int[] intArray = Arrays.stream(outputIds).mapToInt(l -> (int) l).toArray(); + + StringBuffer sb = new StringBuffer(); + for (int value : intArray) { + // 65000 + // 0 + if (value == 65000 || value == 0 || value == 8) + continue; + String text = reverseMap.get(Long.valueOf(value)); + sb.append(text); + } + + String result = sb.toString(); + result = result.replaceAll("▁"," "); + return result; + } } -- Gitee