From 976706c6d0748379eac0b817553ddb15716c27d7 Mon Sep 17 00:00:00 2001
From: 13848007649 <410039586@qq.com>
Date: Thu, 28 Aug 2025 09:58:10 +0800
Subject: [PATCH] Add OPUS-MT translation model
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
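
Add an OPUS-MT (Marian) translation model backed by DJL: beam search decoding
(BeamSearchScorer / BeamHypotheses), a SentencePiece tokenizer, an
OpusSearchConfig for search parameters, and registration of the new
OPUS_MT_ZH_EN / OPUS_MT_EN_ZH modes in TranslationModelFactory.

Minimal usage sketch. The setters and the model path below are illustrative
assumptions, not part of this patch; adjust them to the actual
TranslationModelConfig API and to the directory that holds the traced PyTorch
model, source.spm and vocab.txt:

    TranslationModelConfig config = new TranslationModelConfig();
    config.setModelEnum(TranslationModeEnum.OPUS_MT_EN_ZH); // assumed setter
    config.setModelPath("/models/opus-mt-en-zh/model.pt");  // assumed setter and path
    try (TranslationModel model = new OpusMtModel()) {
        model.loadModel(config);
        System.out.println(model.translate("Hello world"));
    }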
---
smartjavaai-translate/pom.xml | 5 +
.../translation/config/OpusSearchConfig.java | 63 +++
.../entity/BeamBatchTensorList.java | 61 +++
.../translation/entity/TranslateParam.java | 10 +
.../enums/TranslationModeEnum.java | 6 +-
.../factory/TranslationModelFactory.java | 20 +-
.../translation/model/BeamHypotheses.java | 121 +++++
.../translation/model/BeamSearchScorer.java | 118 +++++
.../translation/model/NllbModel.java | 4 +-
.../translation/model/OpusMtModel.java | 442 ++++++++++++++++++
.../translation/model/TranslationModel.java | 9 +
.../translator/NllbDecoder2Translator.java | 2 +-
.../translator/NllbDecoderTranslator.java | 2 +-
.../translation/utils/NDArrayUtils.java | 28 ++
.../translation/utils/TokenUtils.java | 28 ++
15 files changed, 906 insertions(+), 13 deletions(-)
create mode 100644 smartjavaai-translate/src/main/java/cn/smartjavaai/translation/config/OpusSearchConfig.java
create mode 100644 smartjavaai-translate/src/main/java/cn/smartjavaai/translation/entity/BeamBatchTensorList.java
create mode 100644 smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/BeamHypotheses.java
create mode 100644 smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/BeamSearchScorer.java
create mode 100644 smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/OpusMtModel.java
create mode 100644 smartjavaai-translate/src/main/java/cn/smartjavaai/translation/utils/NDArrayUtils.java
diff --git a/smartjavaai-translate/pom.xml b/smartjavaai-translate/pom.xml
index 5c70904..8c53b1d 100644
--- a/smartjavaai-translate/pom.xml
+++ b/smartjavaai-translate/pom.xml
@@ -18,6 +18,11 @@
<artifactId>smartjavaai-common</artifactId>
<version>${project.version}</version>
+
+ <dependency>
+ <groupId>ai.djl.sentencepiece</groupId>
+ <artifactId>sentencepiece</artifactId>
+ </dependency>
<version>1.0.23</version>
diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/config/OpusSearchConfig.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/config/OpusSearchConfig.java
new file mode 100644
index 0000000..dc2a881
--- /dev/null
+++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/config/OpusSearchConfig.java
@@ -0,0 +1,63 @@
+package cn.smartjavaai.translation.config;
+/**
+ * Search configuration for the OPUS-MT model
+ *
+ * @author Calvin
+ * @mail 179209347@qq.com
+ * @website www.aias.top
+ */
+public class OpusSearchConfig {
+ private int maxSeqLength;
+ private long padTokenId;
+ private long eosTokenId;
+ private int beam;
+ private boolean suffixPadding;
+
+ public OpusSearchConfig() {
+ this.eosTokenId = 0;
+ this.padTokenId = 65000;
+ this.maxSeqLength = 512;
+ this.beam = 6;
+ }
+
+
+ public void setEosTokenId(long eosTokenId) {
+ this.eosTokenId = eosTokenId;
+ }
+
+ public int getMaxSeqLength() {
+ return maxSeqLength;
+ }
+
+ public void setMaxSeqLength(int maxSeqLength) {
+ this.maxSeqLength = maxSeqLength;
+ }
+
+ public long getPadTokenId() {
+ return padTokenId;
+ }
+
+ public void setPadTokenId(long padTokenId) {
+ this.padTokenId = padTokenId;
+ }
+
+ public long getEosTokenId() {
+ return eosTokenId;
+ }
+
+ public int getBeam() {
+ return beam;
+ }
+
+ public void setBeam(int beam) {
+ this.beam = beam;
+ }
+
+ public boolean isSuffixPadding() {
+ return suffixPadding;
+ }
+
+ public void setSuffixPadding(boolean suffixPadding) {
+ this.suffixPadding = suffixPadding;
+ }
+}
diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/entity/BeamBatchTensorList.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/entity/BeamBatchTensorList.java
new file mode 100644
index 0000000..2c40021
--- /dev/null
+++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/entity/BeamBatchTensorList.java
@@ -0,0 +1,61 @@
+package cn.smartjavaai.translation.entity;
+
+import ai.djl.ndarray.NDArray;
+import ai.djl.ndarray.NDList;
+
+/**
+ * Beam search tensor list (state carried between decoding steps)
+ *
+ * @author Calvin
+ * @mail 179209347@qq.com
+ * @website www.aias.top
+ */
+public class BeamBatchTensorList {
+ private NDArray nextInputIds;
+ private NDArray encoderHiddenStates;
+ private NDArray attentionMask;
+ private NDList pastKeyValues;
+
+
+ public BeamBatchTensorList() {
+ }
+
+ public BeamBatchTensorList(NDArray nextInputIds, NDArray attentionMask, NDArray encoderHiddenStates, NDList pastKeyValues) {
+ this.nextInputIds = nextInputIds;
+ this.attentionMask = attentionMask;
+ this.pastKeyValues = pastKeyValues;
+ this.encoderHiddenStates = encoderHiddenStates;
+ }
+
+ public NDArray getNextInputIds() {
+ return nextInputIds;
+ }
+
+ public void setNextInputIds(NDArray nextInputIds) {
+ this.nextInputIds = nextInputIds;
+ }
+
+ public NDArray getEncoderHiddenStates() {
+ return encoderHiddenStates;
+ }
+
+ public void setEncoderHiddenStates(NDArray encoderHiddenStates) {
+ this.encoderHiddenStates = encoderHiddenStates;
+ }
+
+ public NDArray getAttentionMask() {
+ return attentionMask;
+ }
+
+ public void setAttentionMask(NDArray attentionMask) {
+ this.attentionMask = attentionMask;
+ }
+
+ public NDList getPastKeyValues() {
+ return pastKeyValues;
+ }
+
+ public void setPastKeyValues(NDList pastKeyValues) {
+ this.pastKeyValues = pastKeyValues;
+ }
+}
diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/entity/TranslateParam.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/entity/TranslateParam.java
index 6d426c9..aac5b41 100644
--- a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/entity/TranslateParam.java
+++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/entity/TranslateParam.java
@@ -45,6 +45,16 @@ public class TranslateParam {
return R.ok(null);
}
+ public TranslateParam(String input, LanguageCode sourceLanguage, LanguageCode targetLanguage) {
+ this.input = input;
+ this.sourceLanguage = sourceLanguage;
+ this.targetLanguage = targetLanguage;
+ }
+ public TranslateParam(String input) {
+ this.input = input;
+ }
+ public TranslateParam() {
+ }
}
diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/enums/TranslationModeEnum.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/enums/TranslationModeEnum.java
index e558a48..c23b0b3 100644
--- a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/enums/TranslationModeEnum.java
+++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/enums/TranslationModeEnum.java
@@ -6,7 +6,11 @@ package cn.smartjavaai.translation.enums;
*/
public enum TranslationModeEnum {
- NLLB_MODEL;
+ NLLB_MODEL,
+
+ OPUS_MT_ZH_EN,
+
+ OPUS_MT_EN_ZH;
/**
* Get the enum by name (case-insensitive, underscore variants accepted)
diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/factory/TranslationModelFactory.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/factory/TranslationModelFactory.java
index 5ca6797..e13ad9e 100644
--- a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/factory/TranslationModelFactory.java
+++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/factory/TranslationModelFactory.java
@@ -4,8 +4,10 @@ import cn.smartjavaai.common.config.Config;
import cn.smartjavaai.translation.config.TranslationModelConfig;
+import cn.smartjavaai.translation.enums.TranslationModeEnum;
import cn.smartjavaai.translation.exception.TranslationException;
import cn.smartjavaai.translation.model.NllbModel;
+import cn.smartjavaai.translation.model.OpusMtModel;
import cn.smartjavaai.translation.model.TranslationModel;
import lombok.extern.slf4j.Slf4j;
@@ -23,14 +25,14 @@ public class TranslationModelFactory {
// Use volatile and double-checked locking for a thread-safe singleton
private static volatile TranslationModelFactory instance;
- private static final ConcurrentHashMap<String, TranslationModel> modelMap = new ConcurrentHashMap<>();
+ private static final ConcurrentHashMap<TranslationModeEnum, TranslationModel> modelMap = new ConcurrentHashMap<>();
/**
* Translation model registry
*/
- private static final Map<String, Class<? extends TranslationModel>> modelRegistry =
+ private static final Map<TranslationModeEnum, Class<? extends TranslationModel>> modelRegistry =
new ConcurrentHashMap<>();
@@ -49,11 +51,11 @@ public class TranslationModelFactory {
/**
* Register a translation model
- * @param name
+ * @param translationModeEnum
* @param clazz
*/
- private static void registerCommonDetModel(String name, Class<? extends TranslationModel> clazz) {
- modelRegistry.put(name.toLowerCase(), clazz);
+ private static void registerCommonDetModel(TranslationModeEnum translationModeEnum, Class<? extends TranslationModel> clazz) {
+ modelRegistry.put(translationModeEnum, clazz);
}
/**
@@ -65,7 +67,7 @@ public class TranslationModelFactory {
if(Objects.isNull(config) || Objects.isNull(config.getModelEnum())){
throw new TranslationException("未配置OCR模型");
}
- return modelMap.computeIfAbsent(config.getModelEnum().name(), k -> {
+ return modelMap.computeIfAbsent(config.getModelEnum(), k -> {
return createModel(config);
});
}
@@ -77,7 +79,7 @@ public class TranslationModelFactory {
* @return
*/
private TranslationModel createModel(TranslationModelConfig config) {
- Class<?> clazz = modelRegistry.get(config.getModelEnum().name().toLowerCase());
+ Class<?> clazz = modelRegistry.get(config.getModelEnum());
if(clazz == null){
throw new TranslationException("Unsupported model");
}
@@ -94,7 +96,9 @@ public class TranslationModelFactory {
// Initialize default models
static {
- registerCommonDetModel("NLLB_MODEL", NllbModel.class);
+ registerCommonDetModel(TranslationModeEnum.NLLB_MODEL, NllbModel.class);
+ registerCommonDetModel(TranslationModeEnum.OPUS_MT_EN_ZH, OpusMtModel.class);
+ registerCommonDetModel(TranslationModeEnum.OPUS_MT_ZH_EN, OpusMtModel.class);
log.debug("缓存目录:{}", Config.getCachePath());
}
diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/BeamHypotheses.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/BeamHypotheses.java
new file mode 100644
index 0000000..6ae5ce4
--- /dev/null
+++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/BeamHypotheses.java
@@ -0,0 +1,121 @@
+package cn.smartjavaai.translation.model;
+
+import ai.djl.util.Pair;
+
+import java.util.ArrayList;
+
+/**
+ * Beam hypothesis
+ *
+ * @author Calvin
+ * @mail 179209347@qq.com
+ * @website www.aias.top
+ */
+public class BeamHypotheses {
+ float length_penalty;
+ boolean early_stopping;
+ int num_beams;
+ ArrayList<Pair<Float, long[]>> beams;
+ float worst_score = 1e9f;
+
+ public BeamHypotheses(float length_penalty, boolean early_stopping, int num_beams) {
+ this.length_penalty = length_penalty;
+ this.early_stopping = early_stopping;
+ this.num_beams = num_beams;
+ beams = new ArrayList<>();
+ }
+
+ /**
+ * Get length
+ *
+ * @return
+ */
+ public int getLen() {
+ return beams.size();
+ }
+
+ /**
+ * Add a new hypothesis to the list.
+ *
+ * @param sum_logprobs
+ * @param hyp
+ */
+ public void add(float sum_logprobs, long[] hyp) {
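+ // Length-normalized score: dividing by hyp.length^length_penalty keeps longer hypotheses competitive; a larger length_penalty favors longer outputs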
+ float score = sum_logprobs / (float) (Math.pow(hyp.length, this.length_penalty));
+
+ if (getLen() < this.num_beams || score > this.worst_score) {
+ this.beams.add(new Pair<>(score, hyp));
+ if (getLen() > this.num_beams) {
+ int index = min();
+ this.beams.remove(index);
+ index = min();
+ this.worst_score = this.beams.get(index).getKey();
+ }else {
+ this.worst_score = Math.min(score, this.worst_score);
+ }
+ }
+ }
+
+ /**
+ * Get Pair
+ * @param index
+ * @return
+ */
+ public Pair<Float, long[]> getPair(int index) {
+ return beams.get(index);
+ }
+
+ /**
+ * Get index of the minimum score value
+ *
+ * @return
+ */
+ public int min() {
+ float min = beams.get(0).getKey();
+ int index = 0;
+ for (int i = 1; i < beams.size(); ++i) {
+ if (beams.get(i).getKey() < min) {
+ min = beams.get(i).getKey();
+ index = i;
+ }
+ }
+ return index;
+ }
+
+ /**
+ * Get index for maximum score value
+ * @return
+ */
+ public int max() {
+ float max = beams.get(0).getKey();
+ int index = 0;
+ for (int i = 1; i < beams.size(); ++i) {
+ if (beams.get(i).getKey() > max) {
+ max = beams.get(i).getKey();
+ index = i;
+ }
+ }
+ return index;
+ }
+
+ /**
+ * If there are enough hypotheses and none of the hypotheses being generated can become better than the worst
+ * one in the heap, then we are done with this sentence.
+ *
+ * @param best_sum_logprobs
+ * @param cur_len
+ * @return
+ */
+ public boolean isDone(float best_sum_logprobs, long cur_len) {
+ if (getLen() < this.num_beams)
+ return false;
+
+ if (this.early_stopping)
+ return true;
+ else {
+ float highest_attainable_score = best_sum_logprobs / (float) Math.pow(cur_len, this.length_penalty);
+ boolean ret = (this.worst_score >= highest_attainable_score);
+ return ret;
+ }
+ }
+}
diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/BeamSearchScorer.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/BeamSearchScorer.java
new file mode 100644
index 0000000..48ed946
--- /dev/null
+++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/BeamSearchScorer.java
@@ -0,0 +1,118 @@
+package cn.smartjavaai.translation.model;
+
+import ai.djl.ndarray.NDArray;
+import ai.djl.ndarray.NDList;
+import ai.djl.ndarray.NDManager;
+import ai.djl.ndarray.index.NDIndex;
+import ai.djl.ndarray.types.Shape;
+import ai.djl.util.Pair;
+
+/**
+ * Implementing standard beam search decoding.
+ *
+ * @author Calvin
+ * @mail 179209347@qq.com
+ * @website www.aias.top
+ */
+public class BeamSearchScorer {
+ private int num_beams;
+ private float length_penalty;
+ private boolean do_early_stopping;
+ private int num_beam_hyps_to_keep;
+ private int num_beam_groups;
+ private BeamHypotheses beam_hyp;
+ private boolean _done;
+
+ public BeamSearchScorer(int num_beams, float length_penalty, boolean do_early_stopping, int num_beam_hyps_to_keep, int num_beam_groups) {
+ this.num_beams = num_beams;
+ this.length_penalty = length_penalty;
+ this.do_early_stopping = do_early_stopping;
+ this.num_beam_hyps_to_keep = num_beam_hyps_to_keep;
+ this.num_beam_groups = num_beam_groups;
+ beam_hyp = new BeamHypotheses(length_penalty, do_early_stopping, num_beams);
+ _done = false;
+ }
+
+ public boolean isDone() {
+ return _done;
+ }
+
+ public NDList process(NDManager manager, NDArray input_ids, NDArray next_scores, NDArray next_tokens, NDArray next_indices, long pad_token_id, long eos_token_id) {
+
+ float[] next_scores_arr = next_scores.toFloatArray();
+ long[] next_indices_arr = next_indices.toLongArray();
+ long[] next_tokens_arr = next_tokens.toLongArray();
+
+ NDArray next_beam_scores = manager.zeros(new Shape(1, this.num_beams), next_scores.getDataType());
+ NDArray next_beam_tokens = manager.zeros(new Shape(1, this.num_beams), next_tokens.getDataType());
+ NDArray next_beam_indices = manager.zeros(new Shape(1, this.num_beams), next_indices.getDataType());
+
+
+ // next tokens for this sentence
+ int beam_idx = 0;
+ float maxScore = Float.NEGATIVE_INFINITY;
+ for (int i = 0; i < next_scores_arr.length; ++i) {
+ int beam_token_rank = i;
+ long next_token = next_tokens_arr[i];
+ float next_score = next_scores_arr[i];
+ if (maxScore < next_score) {
+ maxScore = next_score;
+ }
+ long next_index = next_indices_arr[i];
+
+ long batch_beam_idx = next_index;
+
+ // add to generated hypotheses if end of sentence
+ if (next_token == eos_token_id) {
+ // if beam_token does not belong to top num_beams tokens, it should not be added
+ if (beam_token_rank >= this.num_beams)
+ continue;
+ long[] arr = input_ids.get(batch_beam_idx).toLongArray();
+ // Add a new hypothesis to the list.
+ beam_hyp.add(next_score, arr);
+ } else {
+ // add next predicted token since it is not eos_token
+ next_beam_scores.set(new NDIndex(0, beam_idx), next_score);
+ next_beam_tokens.set(new NDIndex(0, beam_idx), next_token);
+ next_beam_indices.set(new NDIndex(0, beam_idx), batch_beam_idx);
+ beam_idx += 1;
+ }
+
+ // once the beam for next step is full, don't add more tokens to it.
+ if (beam_idx == this.num_beams)
+ break;
+ }
+
+ long cur_len = input_ids.getShape().getLastDimension();
+ this._done = this._done || beam_hyp.isDone(maxScore, cur_len);
+
+ NDList list = new NDList();
+ list.add(next_beam_scores);
+ list.add(next_beam_tokens);
+ list.add(next_beam_indices);
+
+ return list;
+ }
+
+ public long[] finalize(int max_length, long eos_token_id) {
+
+ // best_hyp_tuple
+ Pair<Float, long[]> pair = beam_hyp.getPair(beam_hyp.max());
+ float best_score = pair.getKey();
+ long[] best_hyp = pair.getValue();
+ int sent_length = best_hyp.length;
+
+ // prepare for adding eos
+ int sent_max_len = Math.min(sent_length + 1, max_length);
+ long[] decodedArr = new long[sent_max_len];
+
+ for (int i = 0; i < sent_length; ++i) {
+ decodedArr[i] = best_hyp[i];
+ }
+ if (sent_length < max_length) {
+ decodedArr[sent_length] = eos_token_id;
+ }
+
+ return decodedArr;
+ }
+}
diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/NllbModel.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/NllbModel.java
index 71653f7..1b9f22b 100644
--- a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/NllbModel.java
+++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/NllbModel.java
@@ -6,6 +6,7 @@ import ai.djl.engine.Engine;
import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.inference.Predictor;
+import ai.djl.modality.nlp.generate.CausalLMOutput;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
@@ -20,7 +21,6 @@ import cn.smartjavaai.common.enums.DeviceEnum;
import cn.smartjavaai.common.pool.CommonPredictorFactory;
import cn.smartjavaai.translation.config.TranslationModelConfig;
import cn.smartjavaai.translation.config.NllbSearchConfig;
-import cn.smartjavaai.translation.entity.CausalLMOutput;
import cn.smartjavaai.translation.entity.GreedyBatchTensorList;
import cn.smartjavaai.translation.entity.TranslateParam;
import cn.smartjavaai.translation.exception.TranslationException;
@@ -40,7 +40,7 @@ import java.nio.file.Paths;
import java.util.Objects;
/**
- * Generic machine translation detection model
+ * NLLB machine translation model
*
* @author lwx
* @date 2025/6/05
diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/OpusMtModel.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/OpusMtModel.java
new file mode 100644
index 0000000..60f6608
--- /dev/null
+++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/OpusMtModel.java
@@ -0,0 +1,442 @@
+package cn.smartjavaai.translation.model;
+
+import ai.djl.Device;
+import ai.djl.MalformedModelException;
+import ai.djl.engine.Engine;
+import ai.djl.huggingface.tokenizers.Encoding;
+import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
+import ai.djl.inference.Predictor;
+import ai.djl.modality.cv.Image;
+import ai.djl.modality.cv.output.DetectedObjects;
+import ai.djl.modality.nlp.generate.CausalLMOutput;
+import ai.djl.modality.nlp.generate.SearchConfig;
+import ai.djl.ndarray.NDArray;
+import ai.djl.ndarray.NDArrays;
+import ai.djl.ndarray.NDList;
+import ai.djl.ndarray.NDManager;
+import ai.djl.ndarray.index.NDIndex;
+import ai.djl.ndarray.types.DataType;
+import ai.djl.ndarray.types.Shape;
+import ai.djl.repository.zoo.Criteria;
+import ai.djl.repository.zoo.ModelNotFoundException;
+import ai.djl.repository.zoo.ModelZoo;
+import ai.djl.repository.zoo.ZooModel;
+import ai.djl.sentencepiece.SpTokenizer;
+import ai.djl.translate.NoopTranslator;
+import ai.djl.translate.TranslateException;
+import ai.djl.util.Utils;
+import cn.smartjavaai.common.entity.R;
+import cn.smartjavaai.common.enums.DeviceEnum;
+import cn.smartjavaai.common.pool.CommonPredictorFactory;
+import cn.smartjavaai.translation.config.NllbSearchConfig;
+import cn.smartjavaai.translation.config.OpusSearchConfig;
+import cn.smartjavaai.translation.config.TranslationModelConfig;
+import cn.smartjavaai.translation.entity.BeamBatchTensorList;
+import cn.smartjavaai.translation.entity.GreedyBatchTensorList;
+import cn.smartjavaai.translation.entity.TranslateParam;
+import cn.smartjavaai.translation.exception.TranslationException;
+import cn.smartjavaai.translation.model.translator.NllbDecoder2Translator;
+import cn.smartjavaai.translation.model.translator.NllbDecoderTranslator;
+import cn.smartjavaai.translation.model.translator.NllbEncoderTranslator;
+import cn.smartjavaai.translation.model.translator.opus.Decoder2Translator;
+import cn.smartjavaai.translation.model.translator.opus.DecoderTranslator;
+import cn.smartjavaai.translation.model.translator.opus.EncoderTranslator;
+import cn.smartjavaai.translation.utils.NDArrayUtils;
+import cn.smartjavaai.translation.utils.TokenUtils;
+import com.google.gson.Gson;
+import com.google.gson.reflect.TypeToken;
+import lombok.extern.slf4j.Slf4j;
+import org.apache.commons.lang3.StringUtils;
+import org.apache.commons.pool2.impl.GenericObjectPool;
+
+import java.io.IOException;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.*;
+import java.util.concurrent.ConcurrentHashMap;
+
+/**
+ * OPUS-MT machine translation model
+ *
+ * @author dwj
+ */
+@Slf4j
+public class OpusMtModel implements TranslationModel{
+
+ private GenericObjectPool<Predictor<int[], NDArray>> encodePredictorPool;
+
+ private GenericObjectPool<Predictor<NDList, CausalLMOutput>> decodePredictorPool;
+
+ private GenericObjectPool<Predictor<NDList, CausalLMOutput>> decode2PredictorPool;
+
+ private ZooModel<NDList, NDList> model;
+ private SpTokenizer sourceTokenizer;
+
+ private OpusSearchConfig searchConfig;
+ private TranslationModelConfig config;
+
+ private ConcurrentHashMap<String, Long> map;
+
+ private ConcurrentHashMap<Long, String> reverseMap;
+
+ private float length_penalty = 1.0f;
+ private boolean do_early_stopping = false;
+ private int num_beam_hyps_to_keep = 1;
+ private int num_beam_groups = 1;
+
+
+
+ @Override
+ public void loadModel(TranslationModelConfig config) {
+ if (StringUtils.isBlank(config.getModelPath())) {
+ throw new TranslationException("modelPath is null");
+ }
+ Device device = null;
+ if (!Objects.isNull(config.getDevice())) {
+ device = config.getDevice() == DeviceEnum.CPU ? Device.cpu() : Device.gpu(config.getGpuId());
+ }
+ this.config = config;
+ Path modelPath = Paths.get(config.getModelPath());
+ Criteria<NDList, NDList> criteria =
+ Criteria.builder()
+ .setTypes(NDList.class, NDList.class)
+ .optModelPath(modelPath)
+ .optEngine("PyTorch")
+ .optDevice(device)
+ .optTranslator(new NoopTranslator())
+ .build();
+ try {
+ model = ModelZoo.loadModel(criteria);
+ encodePredictorPool = new GenericObjectPool<>(new CommonPredictorFactory(model,new EncoderTranslator()));
+ decodePredictorPool = new GenericObjectPool<>(new CommonPredictorFactory(model,new DecoderTranslator()));
+ decode2PredictorPool = new GenericObjectPool<>(new CommonPredictorFactory(model,new Decoder2Translator()));
+
+ Path tokenizerPath = modelPath.getParent().resolve("source.spm");
+ sourceTokenizer = new SpTokenizer(tokenizerPath);
+ List<String> words = Utils.readLines(modelPath.getParent().resolve("vocab.txt"));
+ String jsonStr = String.join("", words);
+ map = new Gson().fromJson(jsonStr, new TypeToken<ConcurrentHashMap<String, Long>>() {
+ }.getType());
+ // Build the reverse (id -> token) vocabulary for decoding
+ reverseMap = new ConcurrentHashMap<>();
+ for (Map.Entry<String, Long> entry : map.entrySet()) {
+ reverseMap.put(entry.getValue(), entry.getKey());
+ }
+ // Initialize the search configuration
+ this.searchConfig = new OpusSearchConfig();
+ int predictorPoolSize = config.getPredictorPoolSize();
+ if(config.getPredictorPoolSize() <= 0){
+ predictorPoolSize = Runtime.getRuntime().availableProcessors(); // defaults to the number of CPU cores
+ }
+ encodePredictorPool.setMaxTotal(predictorPoolSize);
+ decodePredictorPool.setMaxTotal(predictorPoolSize);
+ decode2PredictorPool.setMaxTotal(predictorPoolSize);
+ log.debug("当前设备: " + model.getNDManager().getDevice());
+ log.debug("当前引擎: " + Engine.getInstance().getEngineName());
+ log.debug("模型推理器线程池最大数量: " + predictorPoolSize);
+ } catch (IOException | ModelNotFoundException | MalformedModelException e) {
+ throw new TranslationException("模型加载失败", e);
+ }
+ }
+
+ @Override
+ public R translate(TranslateParam translateParam) {
+ if(translateParam == null){
+ return R.fail(R.Status.PARAM_ERROR);
+ }
+ // Validate input
+ if (StringUtils.isBlank(translateParam.getInput())) {
+ return R.fail(R.Status.PARAM_ERROR.getCode(), "输入文本不能为空");
+ }
+ return R.ok(translateLanguage(translateParam));
+ }
+
+ @Override
+ public R translate(String input) {
+ // Validate input
+ if (StringUtils.isBlank(input)) {
+ return R.fail(R.Status.PARAM_ERROR.getCode(), "输入文本不能为空");
+ }
+ return R.ok(translateLanguage(new TranslateParam(input)));
+ }
+
+ private String translateLanguage(TranslateParam translateParam) {
+ try (NDManager manager = NDManager.newBaseManager()) {
+ long numBeam = searchConfig.getBeam();
+ BeamSearchScorer beamSearchScorer = new BeamSearchScorer((int) numBeam, length_penalty, do_early_stopping, num_beam_hyps_to_keep, num_beam_groups);
+ // 1. Encode
+ List<String> tokens = sourceTokenizer.tokenize(translateParam.getInput());
+ String[] strs = tokens.toArray(new String[]{});
+ log.info("Tokens: " + Arrays.toString(strs));
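+ // Map tokens to vocabulary ids and append the EOS token (id 0)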
+ int[] sourceIds = new int[tokens.size() + 1];
+ sourceIds[tokens.size()] = 0;
+ for (int i = 0; i < tokens.size(); i++) {
+ sourceIds[i] = map.get(tokens.get(i)).intValue();
+ }
+ NDArray encoder_hidden_states = encoder(sourceIds);
+ encoder_hidden_states = NDArrayUtils.expand(encoder_hidden_states, searchConfig.getBeam());
+
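+ // The decoder is seeded with token 65000, the OPUS-MT (Marian) pad token, which these models use as the decoder start token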
+ NDArray decoder_input_ids = manager.create(new long[]{65000}).reshape(1, 1);
+ decoder_input_ids = NDArrayUtils.expand(decoder_input_ids, numBeam);
+
+
+ long[] attentionMask = new long[sourceIds.length];
+ Arrays.fill(attentionMask, 1);
+ NDArray attentionMaskArray = manager.create(attentionMask).expandDims(0);
+ NDArray new_attention_mask = NDArrayUtils.expand(attentionMaskArray, searchConfig.getBeam());
+ NDList decoderInput = new NDList(decoder_input_ids, encoder_hidden_states, new_attention_mask);
+
+
+ // 2. Initial Decoder
+ CausalLMOutput modelOutput = decoder(decoderInput);
+ modelOutput.getLogits().attach(manager);
+ modelOutput.getPastKeyValuesList().attach(manager);
+
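+ // Mask every beam except the first with -1e9 so the first decoding step expands only one beam (all beams start identical)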
+ NDArray beam_scores = manager.zeros(new Shape(1, numBeam), DataType.FLOAT32);
+ beam_scores.set(new NDIndex(":, 1:"), -1e9);
+ beam_scores = beam_scores.reshape(numBeam, 1);
+
+ NDArray input_ids = decoder_input_ids;
+ BeamBatchTensorList searchState = new BeamBatchTensorList(null, new_attention_mask, encoder_hidden_states, modelOutput.getPastKeyValuesList());
+ NDArray next_tokens;
+ NDArray next_indices;
+ while (true) {
+ if (searchState.getNextInputIds() != null) {
+ decoder_input_ids = searchState.getNextInputIds().get(new NDIndex(":, -1:"));
+ decoderInput = new NDList(decoder_input_ids, searchState.getEncoderHiddenStates(), searchState.getAttentionMask());
+ decoderInput.addAll(searchState.getPastKeyValues());
+ // 3. Decoder loop
+ modelOutput = decoder2(decoderInput);
+ }
+
+ NDArray next_token_logits = modelOutput.getLogits().get(":, -1, :");
+
+ // hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
+ // cannot be generated both before and after the `nn.functional.log_softmax` operation.
+ NDArray new_next_token_logits = manager.create(next_token_logits.getShape(), next_token_logits.getDataType());
+ next_token_logits.copyTo(new_next_token_logits);
+ new_next_token_logits.set(new NDIndex(":," + searchConfig.getPadTokenId()), Float.NEGATIVE_INFINITY);
+
+ NDArray next_token_scores = new_next_token_logits.logSoftmax(1);
+
+ // next_token_scores = logits_processor(input_ids, next_token_scores)
+ // 1. NoBadWordsLogitsProcessor
+ next_token_scores.set(new NDIndex(":," + searchConfig.getPadTokenId()), Float.NEGATIVE_INFINITY);
+
+ // 2. MinLengthLogitsProcessor (has no effect here)
+ // 3. ForcedEOSTokenLogitsProcessor
+ long cur_len = input_ids.getShape().getLastDimension();
+ if (cur_len == (searchConfig.getMaxSeqLength() - 1)) {
+ long num_tokens = next_token_scores.getShape().getLastDimension();
+ for (long i = 0; i < num_tokens; i++) {
+ if(i != searchConfig.getEosTokenId()){
+ next_token_scores.set(new NDIndex(":," + i), Float.NEGATIVE_INFINITY);
+ }
+ }
+ next_token_scores.set(new NDIndex(":," + searchConfig.getEosTokenId()), 0);
+ }
+
+ long vocab_size = next_token_scores.getShape().getLastDimension();
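+ // Add each beam's running score to all of that beam's candidate token log-probabilities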
+ beam_scores = beam_scores.repeat(1, vocab_size);
+ next_token_scores = next_token_scores.add(beam_scores);
+
+ // reshape for beam search
+ next_token_scores = next_token_scores.reshape(1, numBeam * vocab_size);
+
+ // [batch, beam]
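+ // Take 2 * numBeam candidates so at least numBeam non-EOS continuations remain after finished hypotheses are set aside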
+ NDList topK = next_token_scores.topK(Math.toIntExact(numBeam) * 2, 1, true, true);
+
+ next_token_scores = topK.get(0);
+ next_tokens = topK.get(1);
+
+ // next_indices = next_tokens // vocab_size
+ next_indices = next_tokens.div(vocab_size).toType(DataType.INT64, true);
+
+ // next_tokens = next_tokens % vocab_size
+ next_tokens = next_tokens.mod(vocab_size);
+
+ // stateless
+ NDList beam_outputs = beamSearchScorer.process(manager, input_ids, next_token_scores, next_tokens, next_indices, searchConfig.getPadTokenId(), searchConfig.getEosTokenId());
+
+ beam_scores = beam_outputs.get(0).reshape(numBeam, 1);
+ NDArray beam_next_tokens = beam_outputs.get(1);
+ NDArray beam_idx = beam_outputs.get(2);
+
+ // input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
+ long[] beam_next_tokens_arr = beam_next_tokens.toLongArray();
+ long[] beam_idx_arr = beam_idx.toLongArray();
+ NDList inputList = new NDList();
+ for (int i = 0; i < numBeam; i++) {
+ long index = beam_idx_arr[i];
+ NDArray ndArray = input_ids.get(index).reshape(1, input_ids.getShape().getLastDimension());
+ ndArray = ndArray.concat(manager.create(beam_next_tokens_arr[i]).reshape(1, 1), 1);
+ inputList.add(ndArray);
+ }
+ input_ids = NDArrays.concat(inputList, 0);
+ searchState.setNextInputIds(input_ids);
+ searchState.setPastKeyValues(modelOutput.getPastKeyValuesList());
+
+ boolean maxLengthCriteria = (input_ids.getShape().getLastDimension() >= searchConfig.getMaxSeqLength());
+ if (beamSearchScorer.isDone() || maxLengthCriteria) {
+ break;
+ }
+
+ }
+
+ long[] sequences = beamSearchScorer.finalize(searchConfig.getMaxSeqLength(), searchConfig.getEosTokenId());
+ String result = TokenUtils.decode(reverseMap, sequences);
+ return result;
+ } catch (Exception e) {
+ throw new TranslationException("翻译错误", e);
+ }
+ }
+
+ public NDArray encoder(int[] ids) {
+ Predictor<int[], NDArray> predictor = null;
+ try {
+ predictor = (Predictor<int[], NDArray>) encodePredictorPool.borrowObject();
+ return predictor.predict(ids);
+ } catch (Exception e) {
+ throw new TranslationException("机器翻译编码错误", e);
+ }finally {
+ if (predictor != null) {
+ try {
+ encodePredictorPool.returnObject(predictor); // return to the pool
+ } catch (Exception e) {
+ log.warn("归还Predictor失败", e);
+ try {
+ predictor.close(); // destroy only if returning to the pool fails
+ } catch (Exception ex) {
+ log.error("关闭Predictor失败", ex);
+ }
+ }
+ }
+ }
+ }
+
+ public CausalLMOutput decoder(NDList input) throws TranslateException {
+ Predictor<NDList, CausalLMOutput> predictor = null;
+ try {
+ predictor = (Predictor<NDList, CausalLMOutput>) decodePredictorPool.borrowObject();
+ return predictor.predict(input);
+ } catch (Exception e) {
+ throw new TranslationException("机器翻译编码错误", e);
+ }finally {
+ if (predictor != null) {
+ try {
+ decodePredictorPool.returnObject(predictor); // return to the pool
+ } catch (Exception e) {
+ log.warn("归还Predictor失败", e);
+ try {
+ predictor.close(); // destroy only if returning to the pool fails
+ } catch (Exception ex) {
+ log.error("关闭Predictor失败", ex);
+ }
+ }
+ }
+ }
+ }
+
+ public CausalLMOutput decoder2(NDList input) throws TranslateException {
+ Predictor<NDList, CausalLMOutput> predictor = null;
+ try {
+ predictor = (Predictor<NDList, CausalLMOutput>) decode2PredictorPool.borrowObject();
+ return predictor.predict(input);
+ } catch (Exception e) {
+ throw new TranslationException("机器翻译编码错误", e);
+ }finally {
+ if (predictor != null) {
+ try {
+ decode2PredictorPool.returnObject(predictor); // return to the pool
+ } catch (Exception e) {
+ log.warn("归还Predictor失败", e);
+ try {
+ predictor.close(); // destroy only if returning to the pool fails
+ } catch (Exception ex) {
+ log.error("关闭Predictor失败", ex);
+ }
+ }
+ }
+ }
+ }
+
+ public NDArray greedyStepGen(NllbSearchConfig config, NDArray pastOutputIds, NDArray next_token_scores, NDManager manager) {
+ next_token_scores = next_token_scores.get(":, -1, :");
+
+ NDArray new_next_token_scores = manager.create(next_token_scores.getShape(), next_token_scores.getDataType());
+ next_token_scores.copyTo(new_next_token_scores);
+
+ // LogitsProcessor 1. ForcedBOSTokenLogitsProcessor
+ // Set the target language
+ long cur_len = pastOutputIds.getShape().getLastDimension();
+ if (cur_len == 1) {
+ long num_tokens = new_next_token_scores.getShape().getLastDimension();
+ for (long i = 0; i < num_tokens; i++) {
+ if (i != config.getForcedBosTokenId()) {
+ new_next_token_scores.set(new NDIndex(":," + i), Float.NEGATIVE_INFINITY);
+ }
+ }
+ new_next_token_scores.set(new NDIndex(":," + config.getForcedBosTokenId()), 0);
+ }
+
+ NDArray probs = new_next_token_scores.softmax(-1);
+ NDArray next_tokens = probs.argMax(-1);
+
+ return next_tokens.expandDims(0);
+ }
+
+ public GenericObjectPool<Predictor<int[], NDArray>> getEncodePredictorPool() {
+ return encodePredictorPool;
+ }
+
+ public GenericObjectPool<Predictor<NDList, CausalLMOutput>> getDecodePredictorPool() {
+ return decodePredictorPool;
+ }
+
+ public GenericObjectPool<Predictor<NDList, CausalLMOutput>> getDecode2PredictorPool() {
+ return decode2PredictorPool;
+ }
+
+ @Override
+ public void close() throws Exception {
+ try {
+ if (model != null) {
+ model.close();
+ }
+ } catch (Exception e) {
+ log.warn("关闭 model 失败", e);
+ }
+ try {
+ if (sourceTokenizer != null) {
+ sourceTokenizer.close();
+ }
+ } catch (Exception e) {
+ log.warn("关闭 tokenizer 失败", e);
+ }
+ try {
+ if (encodePredictorPool != null) {
+ encodePredictorPool.close();
+ }
+ } catch (Exception e) {
+ log.warn("关闭 encodePredictorPool 失败", e);
+ }
+ try {
+ if (decodePredictorPool != null) {
+ decodePredictorPool.close();
+ }
+ } catch (Exception e) {
+ log.warn("关闭 decodePredictorPool 失败", e);
+ }
+ try {
+ if (decode2PredictorPool != null) {
+ decode2PredictorPool.close();
+ }
+ } catch (Exception e) {
+ log.warn("关闭 decode2PredictorPool 失败", e);
+ }
+ }
+}
diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/TranslationModel.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/TranslationModel.java
index f0b35d2..a9d55ec 100644
--- a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/TranslationModel.java
+++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/TranslationModel.java
@@ -28,6 +28,15 @@ public interface TranslationModel extends AutoCloseable{
}
+ /**
+ * Machine translation
+ * @param input the input text
+ * @return
+ */
+ default R translate(String input) {
+ throw new UnsupportedOperationException("默认不支持该功能");
+ }
+
diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/translator/NllbDecoder2Translator.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/translator/NllbDecoder2Translator.java
index 6e18541..f98e4aa 100644
--- a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/translator/NllbDecoder2Translator.java
+++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/translator/NllbDecoder2Translator.java
@@ -1,10 +1,10 @@
package cn.smartjavaai.translation.model.translator;
+import ai.djl.modality.nlp.generate.CausalLMOutput;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.translate.NoBatchifyTranslator;
import ai.djl.translate.TranslatorContext;
-import cn.smartjavaai.translation.entity.CausalLMOutput;
/**
diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/translator/NllbDecoderTranslator.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/translator/NllbDecoderTranslator.java
index 7867c87..717b49c 100644
--- a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/translator/NllbDecoderTranslator.java
+++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/model/translator/NllbDecoderTranslator.java
@@ -1,10 +1,10 @@
package cn.smartjavaai.translation.model.translator;
+import ai.djl.modality.nlp.generate.CausalLMOutput;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.translate.NoBatchifyTranslator;
import ai.djl.translate.TranslatorContext;
-import cn.smartjavaai.translation.entity.CausalLMOutput;
/**
* Decoder without pastKeyValues in its parameters
diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/utils/NDArrayUtils.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/utils/NDArrayUtils.java
new file mode 100644
index 0000000..f159698
--- /dev/null
+++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/utils/NDArrayUtils.java
@@ -0,0 +1,28 @@
+package cn.smartjavaai.translation.utils;
+
+import ai.djl.ndarray.NDArray;
+import ai.djl.ndarray.NDArrays;
+import ai.djl.ndarray.NDList;
+
+/**
+ * NDArray 工具类
+ *
+ * @author Calvin
+ * @mail 179209347@qq.com
+ * @website www.aias.top
+ */
+public final class NDArrayUtils {
+
+ private NDArrayUtils() {
+ }
+
+ public static NDArray expand(NDArray array, long beam) {
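+ // Repeat the array beam times along axis 0, one copy per beam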
+ NDList list = new NDList();
+ for (long i = 0; i < beam; i++) {
+ list.add(array);
+ }
+ NDArray result = NDArrays.concat(list, 0);
+
+ return result;
+ }
+}
diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/utils/TokenUtils.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/utils/TokenUtils.java
index d613432..f19dc4b 100644
--- a/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/utils/TokenUtils.java
+++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translation/utils/TokenUtils.java
@@ -6,6 +6,8 @@ import cn.smartjavaai.translation.config.NllbSearchConfig;
import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Map;
/**
*
@@ -44,4 +46,30 @@ public final class TokenUtils {
String text = tokenizer.decode(ids);
return text;
}
+
+ /**
+ * Token decoding
+ * Update this method according to the language type
+ *
+ * @param reverseMap
+ * @param outputIds
+ * @return
+ */
+ public static String decode(Map<Long, String> reverseMap, long[] outputIds) {
+ int[] intArray = Arrays.stream(outputIds).mapToInt(l -> (int) l).toArray();
+
+ StringBuffer sb = new StringBuffer();
+ for (int value : intArray) {
+ // Skip special token ids: pad (65000), eos (0) and id 8
+ if (value == 65000 || value == 0 || value == 8)
+ continue;
+ String text = reverseMap.get(Long.valueOf(value));
+ sb.append(text);
+ }
+
+ String result = sb.toString();
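+ // "▁" is the SentencePiece word-boundary marker; replace it with a plain space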
+ result = result.replaceAll("▁"," ");
+ return result;
+ }
}
--
Gitee