diff --git a/pom.xml b/pom.xml index 16d9dd95b42985c7c83d5d0c72fb56e2340490cd..6fee4e62af8b02a3503a2657acb74f08812c3f67 100644 --- a/pom.xml +++ b/pom.xml @@ -15,6 +15,7 @@ smartjavaai-objectdetection smartjavaai-all smartjavaai-ocr + smartjavaai-translate smartjavaai-bom @@ -80,6 +81,13 @@ ai.djl.mxnet mxnet-model-zoo + + + + ai.djl.mxnet + mxnet-engine + + @@ -162,12 +170,13 @@ tensorflow-engine runtime - - + + + ai.djl.onnxruntime onnxruntime-engine @@ -189,6 +198,24 @@ runtime + + ai.djl.pytorch + pytorch-engine + ${djl.version} + + + + ai.djl.huggingface + tokenizers + ${djl.version} + + + + ai.djl.sentencepiece + sentencepiece + ${djl.version} + + cn.hutool hutool-system diff --git a/smartjavaai-translate/pom.xml b/smartjavaai-translate/pom.xml new file mode 100644 index 0000000000000000000000000000000000000000..9095fe3292105d2acc05ec8d74648e790217900c --- /dev/null +++ b/smartjavaai-translate/pom.xml @@ -0,0 +1,149 @@ + + + 4.0.0 + + cn.smartjavaai + smartjavaai-parent + 1.0.13 + + + smartjavaai-translate + + + + + UTF-8 + 1.5.8 + 5.1.2-1.5.8 + + + + + cn.smartjavaai + smartjavaai-common + ${project.version} + + + org.bytedeco + javacpp + ${javacv.version} + + + org.bytedeco + ffmpeg + ${javacv.ffmpeg.version} + + + ai.djl.opencv + opencv + + + + 1.0.13 + smartjavaai-ocr + SmartJavaAI + https://github.com/geekwenjie/SmartJavaAI + + + MIT License + https://opensource.org/licenses/MIT + + + + + + + org.sonatype.central + central-publishing-maven-plugin + 0.4.0 + true + + dengwenjie + true + ${project.groupId}:${project.artifactId}:${project.version} + + + + + org.apache.maven.plugins + maven-source-plugin + 3.1.0 + + + attach-sources + + jar-no-fork + + + + + + org.apache.maven.plugins + maven-javadoc-plugin + 3.1.0 + + + none + + -Xdoclint:none + + + + + attach-javadocs + + jar + + + + + + org.apache.maven.plugins + maven-gpg-plugin + 3.1.0 + + + sign-artifacts + verify + + sign + + + + + + + + + + scm:git:git://github.com/geekwenjie/SmartJavaAI.git + 
scm:git:ssh://github.com/geekwenjie/SmartJavaAI.git + http://github.com/geekwenjie/SmartJavaAI/tree/master + + + + + + dengwenjie + https://s01.oss.sonatype.org/content/repositories/snapshots + + + dengwenjie + https://s01.oss.sonatype.org/service/local/staging/deploy/maven2/ + + + + + + dengwenjie + 775747758@qq.com + + Project Manager + Architect + + + + +
diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translate/TextTranslation.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translate/TextTranslation.java
new file mode 100644
index 0000000000000000000000000000000000000000..c86080bca90af7e4937e0f6f884e30d88b17fc15
--- /dev/null
+++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translate/TextTranslation.java
@@ -0,0 +1,65 @@
+package cn.smartjavaai.translate;
+
+import ai.djl.Device;
+import ai.djl.ModelException;
+import ai.djl.translate.TranslateException;
+import cn.smartjavaai.translate.generate.SearchConfig;
+import cn.smartjavaai.translate.model.NllbModel;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+
+/**
+ * Text translation demo; supports translation between 202 languages.
+ *
+ * @author Calvin
+ * @mail 179209347@qq.com
+ * @website www.aias.top
+ */
+public final class TextTranslation {
+
+    private static final Logger logger = LoggerFactory.getLogger(TextTranslation.class);
+
+    private TextTranslation() {
+    }
+
+    /**
+     * Translates one hard-coded Chinese sentence to English and logs the result.
+     *
+     * @param args unused
+     * @throws ModelException if the model cannot be loaded
+     * @throws IOException if the model or tokenizer files cannot be read
+     * @throws TranslateException if inference fails
+     */
+    public static void main(String[] args) throws ModelException, IOException,
+            TranslateException {
+
+        SearchConfig config = new SearchConfig();
+        // Maximum length of the generated output sequence.
+        config.setMaxSeqLength(128);
+        // Source language: Chinese, "zho_Hans": 256200
+        config.setSrcLangId(256200);
+        // Target language: English, "eng_Latn": 256047. Set exactly once: a second
+        // setForcedBosTokenId call would silently replace the target language.
+        config.setForcedBosTokenId(256047);
+
+        // Text to translate.
+        String input = "智利北部的丘基卡马塔矿是世界上最大的露天矿之一,长约4公里,宽3公里,深1公里。";
+
+        String modelPath = "E:\\ai\\models\\nlp\\";
+        String cpuModelName = "traced_translation_cpu.pt";
+        String gpuModelName = "traced_translation_gpu.pt";
+        try (NllbModel nllbModel = new NllbModel(config, modelPath, cpuModelName, Device.cpu())) {
+
+            System.setProperty("ai.djl.pytorch.graph_optimizer", "false");
+
+            // Run the model and fetch the translation.
+            String result = nllbModel.translate(input);
+
+            logger.info("result========={}", result);
+        } finally {
+            System.clearProperty("ai.djl.pytorch.graph_optimizer");
+        }
+    }
+}
\ No newline at end of file
diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translate/generate/BatchTensorList.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translate/generate/BatchTensorList.java
new file mode 100644
index 0000000000000000000000000000000000000000..51ebfc5718df23051df2791e2e94b356109f7309
--- /dev/null
+++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translate/generate/BatchTensorList.java
@@ -0,0 +1,174 @@
+/*
+ * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package cn.smartjavaai.translate.generate;
+
+import ai.djl.ndarray.NDArray;
+import ai.djl.ndarray.NDList;
+
+/**
+ * BatchTensorList represents a search state, and the NDArrays inside are updated in each iteration
+ * of the autoregressive loop.
+ *
+ *

It is a struct consisting of NDArrays, whose first dimension is batch, and also contains + * sequence dimension (whose position in tensor's shape is specified by seqDimOrder). The SeqBatcher + * batch operations will operate on these two dimensions. + */ +public abstract class BatchTensorList { + // [batch, seq_past]. seq-dim-size == |past_seq| + |inputIds|. Will grow. + private NDArray pastOutputIds; + + // [batch, seq_past] + // The cache of past attentionMask. seq-dim-size == |past_seq| + |inputIds|. Will grow. + private NDArray pastAttentionMask; + + // (k, v) * numLayer, + // kv: [batch, heads, seq_past, kvfeature] + // The cache of past sequence. seq-dim-size == |past_seq| + |inputIds|. Will grow. + private NDList pastKeyValues; + + // Sequence dimension order among all dimensions for each element in the batch list. + private long[] seqDimOrder; + + BatchTensorList() {} + + /** + * Constructs a new {@code BatchTensorList} instance. + * + * @param list the NDList that contains the serialized version of the batch tensors + * @param seqDimOrder the sequence dimension order that specifies where the sequence dimension + * is in a tensor's shape + */ + BatchTensorList(NDList list, long[] seqDimOrder) { + this.seqDimOrder = seqDimOrder; + pastOutputIds = list.get(0); + pastAttentionMask = list.get(1); + pastKeyValues = list.subNDList(2); + } + + /** + * Constructs a new {@code BatchTensorList} instance. 
+ * + * @param pastOutputIds past output token ids + * @param pastAttentionMask past attention mask + * @param pastKeyValues past kv cache + * @param seqDimOrder the sequence dimension order that specifies where the sequence dimension + * is in a tensor's shape + */ + BatchTensorList( + NDArray pastOutputIds, + NDArray pastAttentionMask, + NDList pastKeyValues, + long[] seqDimOrder) { + this.pastKeyValues = pastKeyValues; + this.pastOutputIds = pastOutputIds; + this.pastAttentionMask = pastAttentionMask; + this.seqDimOrder = seqDimOrder; + } + + /** + * Constructs a new {@code BatchTensorList} instance from the serialized version of the batch + * tensors. + * + *

The pastOutputIds has to be the first in the output list. + * + * @param inputList the serialized version of the batch tensors + * @param seqDimOrder the sequence dimension order that specifies where the sequence dimension + * is in a tensor's shape + * @return BatchTensorList + */ + public abstract BatchTensorList fromList(NDList inputList, long[] seqDimOrder); + + /** + * Returns the serialized version of the BatchTensorList. The pastOutputIds has to be the first + * in the output list. + * + * @return the NDList that contains the serialized BatchTensorList + */ + public abstract NDList getList(); + + /** + * Returns the sequence dimension order which specifies where the sequence dimension is in a + * tensor's shape. + * + * @return the sequence dimension order which specifies where the sequence dimension is in a + * tensor's shape + */ + public long[] getSeqDimOrder() { + return seqDimOrder; + } + + /** + * Returns the value of the pastOutputIds. + * + * @return the value of pastOutputIds + */ + public NDArray getPastOutputIds() { + return pastOutputIds; + } + + /** + * Sets the past output token ids. + * + * @param pastOutputIds the past output token ids + */ + public void setPastOutputIds(NDArray pastOutputIds) { + this.pastOutputIds = pastOutputIds; + } + + /** + * Returns the value of the pastAttentionMask. + * + * @return the value of pastAttentionMask + */ + public NDArray getPastAttentionMask() { + return pastAttentionMask; + } + + /** + * Sets the attention mask. + * + * @param pastAttentionMask the attention mask + */ + public void setPastAttentionMask(NDArray pastAttentionMask) { + this.pastAttentionMask = pastAttentionMask; + } + + /** + * Returns the value of the pastKeyValues. + * + * @return the value of pastKeyValues + */ + public NDList getPastKeyValues() { + return pastKeyValues; + } + + /** + * Sets the kv cache. 
+     *
+     * @param pastKeyValues the kv cache
+     */
+    public void setPastKeyValues(NDList pastKeyValues) {
+        this.pastKeyValues = pastKeyValues;
+    }
+
+    /**
+     * Sets the sequence dimension order which specifies where the sequence dimension is in a
+     * tensor's shape.
+     *
+     * @param seqDimOrder the sequence dimension order which specifies where the sequence dimension
+     *     is in a tensor's shape
+     */
+    public void setSeqDimOrder(long[] seqDimOrder) {
+        this.seqDimOrder = seqDimOrder;
+    }
+}
\ No newline at end of file
diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translate/generate/CausalLMOutput.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translate/generate/CausalLMOutput.java
new file mode 100644
index 0000000000000000000000000000000000000000..62ce2c9e6a0c2377e21af9b0221f766564b67e03
--- /dev/null
+++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translate/generate/CausalLMOutput.java
@@ -0,0 +1,32 @@
+package cn.smartjavaai.translate.generate;
+
+import ai.djl.ndarray.NDArray;
+import ai.djl.ndarray.NDList;
+/**
+ * Holder for the output of one decoder step: the logits and the kv cache.
+ *
+ * @author Calvin
+ * @mail 179209347@qq.com
+ * @website www.aias.top
+ */
+public class CausalLMOutput {
+    private NDArray logits;
+    private NDList pastKeyValuesList;
+
+    public CausalLMOutput(NDArray logits, NDList pastKeyValues) {
+        this.logits = logits;
+        this.pastKeyValuesList = pastKeyValues;
+    }
+
+    public NDArray getLogits() {
+        return logits;
+    }
+
+    public void setLogits(NDArray logits) {
+        this.logits = logits;
+    }
+
+    public NDList getPastKeyValuesList() {
+        return pastKeyValuesList;
+    }
+}
\ No newline at end of file
diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translate/generate/GreedyBatchTensorList.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translate/generate/GreedyBatchTensorList.java
new file mode 100644
index 0000000000000000000000000000000000000000..5b0747688766f76d0b7dd9d7cb9c3920605ac02b
--- /dev/null
+++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translate/generate/GreedyBatchTensorList.java
@@ -0,0 +1,93 @@
+package cn.smartjavaai.translate.generate;
+
+import ai.djl.ndarray.NDArray;
+import ai.djl.ndarray.NDList;
+
+/**
+ * Search state of the greedy decoding loop.
+ *
+ * <p>The past output ids and the kv cache are stored in the {@link BatchTensorList} superclass;
+ * this class adds the tensors that only the greedy loop needs.
+ *
+ * @author Calvin
+ * @mail 179209347@qq.com
+ * @website www.aias.top
+ */
+public class GreedyBatchTensorList extends BatchTensorList {
+    // [batch, 1]
+    private NDArray nextInputIds;
+
+    private NDArray encoderHiddenStates;
+    private NDArray attentionMask;
+
+    /**
+     * Constructs a populated search state.
+     *
+     * @param nextInputIds the token ids to feed to the next decoder step, may be {@code null}
+     * @param pastOutputIds the token ids generated so far
+     * @param pastKeyValues the kv cache
+     * @param encoderHiddenStates the encoder output
+     * @param attentionMask the attention mask
+     */
+    public GreedyBatchTensorList(
+            NDArray nextInputIds,
+            NDArray pastOutputIds,
+            NDList pastKeyValues,
+            NDArray encoderHiddenStates,
+            NDArray attentionMask) {
+        // pastOutputIds / pastKeyValues are kept once, in the superclass, rather than in
+        // duplicate fields that shadow the inherited ones.
+        super(pastOutputIds, null, pastKeyValues, null);
+        this.nextInputIds = nextInputIds;
+        this.attentionMask = attentionMask;
+        this.encoderHiddenStates = encoderHiddenStates;
+    }
+
+    /** Constructs an empty search state. */
+    public GreedyBatchTensorList() {}
+
+    /**
+     * Not implemented for the greedy state: the arguments are ignored and an empty state is
+     * returned.
+     */
+    @Override
+    public BatchTensorList fromList(NDList inputList, long[] seqDimOrder) {
+        return new GreedyBatchTensorList();
+    }
+
+    /** Not implemented for the greedy state: always returns an empty list. */
+    @Override
+    public NDList getList() {
+        return new NDList();
+    }
+
+    /** Returns the token ids to feed to the next decoder step. */
+    public NDArray getNextInputIds() {
+        return nextInputIds;
+    }
+
+    /** Sets the token ids to feed to the next decoder step. */
+    public void setNextInputIds(NDArray nextInputIds) {
+        this.nextInputIds = nextInputIds;
+    }
+
+    /** Returns the encoder hidden states. */
+    public NDArray getEncoderHiddenStates() {
+        return encoderHiddenStates;
+    }
+
+    /** Sets the encoder hidden states. */
+    public void setEncoderHiddenStates(NDArray encoderHiddenStates) {
+        this.encoderHiddenStates = encoderHiddenStates;
+    }
+
+    /** Returns the attention mask. */
+    public NDArray getAttentionMask() {
+        return attentionMask;
+    }
+
+    /** Sets the attention mask. */
+    public void setAttentionMask(NDArray attentionMask) {
+        this.attentionMask = attentionMask;
+    }
+}
\ No newline at end of file
diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translate/generate/SearchConfig.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translate/generate/SearchConfig.java
new file mode 100644
index 0000000000000000000000000000000000000000..50d873fc14d794ba9357e3e987cc6b5f26aa79aa
--- /dev/null
+++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translate/generate/SearchConfig.java
@@ -0,0 +1,104 @@
+package cn.smartjavaai.translate.generate;
+/**
+ * 配置信息
+ *
+ * @author Calvin
+ * @mail 179209347@qq.com
+ * @website www.aias.top
+ */
+public class SearchConfig {
+
+    private int maxSeqLength;
+    private long padTokenId;
+    private long eosTokenId;
+    private long bosTokenId;
+    private long decoderStartTokenId;
+    private float encoderRepetitionPenalty;
+    private long forcedBosTokenId;
+    private long srcLangId;
+    private float lengthPenalty;
+    public SearchConfig() {
+        this.maxSeqLength = 512;
+        this.eosTokenId = 2;
+        this.bosTokenId = 0;
+        this.padTokenId = 1;
+        this.decoderStartTokenId = 2;
+        this.encoderRepetitionPenalty = 1.0f;
+        this.srcLangId = 0;
+        this.forcedBosTokenId = 0;
+        this.lengthPenalty = 1.0f;
+
+    }
+
+    public long getSrcLangId() {
+        return srcLangId;
+    }
+
+    public void setSrcLangId(long srcLangId) {
+        this.srcLangId = srcLangId;
+    }
+
+    public void setEosTokenId(long eosTokenId) {
+        this.eosTokenId = eosTokenId;
+    }
+
+    public int getMaxSeqLength() {
+        return maxSeqLength;
+    }
+
+    public void setMaxSeqLength(int maxSeqLength) {
+        this.maxSeqLength = maxSeqLength;
+    }
+
+    public long getPadTokenId() {
+        return padTokenId;
+    }
+
+    public void setPadTokenId(long padTokenId) {
+        this.padTokenId = padTokenId;
+    }
+
+    public long getEosTokenId() {
+        return eosTokenId;
+    }
+
+    public long getDecoderStartTokenId() {
+        return decoderStartTokenId;
+    }
+
+    public void setDecoderStartTokenId(long decoderStartTokenId) {
+        this.decoderStartTokenId = decoderStartTokenId;
+    }
+
+    public float getEncoderRepetitionPenalty() {
+        return encoderRepetitionPenalty;
+    }
+
+    public void setEncoderRepetitionPenalty(float encoderRepetitionPenalty) {
+        this.encoderRepetitionPenalty = encoderRepetitionPenalty;
+    }
+
+    public long getForcedBosTokenId() {
+        return forcedBosTokenId;
+    }
+
+    public void setForcedBosTokenId(long forcedBosTokenId) {
+        this.forcedBosTokenId = forcedBosTokenId;
+    }
+
+    public float getLengthPenalty() {
+        return lengthPenalty;
+    }
+
+    public void setLengthPenalty(float lengthPenalty) {
+        this.lengthPenalty = lengthPenalty;
+    }
+
+    public long getBosTokenId() {
+        return bosTokenId;
+    }
+
+    public void setBosTokenId(long bosTokenId) {
+        this.bosTokenId = bosTokenId;
+    }
+}
\ No newline at end of file
diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translate/model/Decoder2Translator.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translate/model/Decoder2Translator.java
new file mode 100644
index 0000000000000000000000000000000000000000..f3a4f9cf80678fe894e0377c66e8b52ad454107e
--- /dev/null
+++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translate/model/Decoder2Translator.java
@@ -0,0 +1,65 @@
+package cn.smartjavaai.translate.model;
+
+import ai.djl.ndarray.NDArray;
+import ai.djl.ndarray.NDList;
+import ai.djl.translate.NoBatchifyTranslator;
+import ai.djl.translate.TranslatorContext;
+import cn.smartjavaai.translate.generate.CausalLMOutput;
+
+/**
+ * Pre/post-processing for the decoder step whose input includes pastKeyValues.
+ *
+ * @author Calvin
+ * @mail 179209347@qq.com
+ * @website www.aias.top
+ */
+public class Decoder2Translator implements NoBatchifyTranslator<NDList, CausalLMOutput> {
+    // Dimensions of the past_key_values tuple. Presumably 12 decoder layers with 4 cached
+    // tensors each - TODO confirm against the traced model.
+    private static final int KV_TUPLE_COUNT = 12;
+    private static final int KV_TUPLE_SIZE = 4;
+
+    private final String tupleName;
+
+    public Decoder2Translator() {
+        tupleName = "past_key_values(" + KV_TUPLE_COUNT + ',' + KV_TUPLE_SIZE + ')';
+    }
+
+    /**
+     * Appends a scalar placeholder named {@code module_method:decoder2} to the input list.
+     *
+     * @param ctx the translator context that supplies the NDManager
+     * @param input the decoder inputs; modified in place
+     * @return the same list, with the placeholder appended
+     */
+    @Override
+    public NDList processInput(TranslatorContext ctx, NDList input) {
+        // NOTE(review): the name presumably selects the traced method to run - confirm.
+        NDArray placeholder = ctx.getNDManager().create(0);
+        placeholder.setName("module_method:decoder2");
+
+        input.add(placeholder);
+
+        return input;
+    }
+
+    /**
+     * Splits the raw output into the logits (first array) and the kv cache (the following
+     * {@code KV_TUPLE_COUNT * KV_TUPLE_SIZE} arrays), naming every cache array.
+     *
+     * @param ctx the translator context (unused)
+     * @param output the raw model output
+     * @return the logits and the named kv cache
+     */
+    @Override
+    public CausalLMOutput processOutput(TranslatorContext ctx, NDList output) {
+        NDArray logitsOutput = output.get(0);
+        NDList pastKeyValuesOutput = output.subNDList(1, KV_TUPLE_COUNT * KV_TUPLE_SIZE + 1);
+
+        for (NDArray array : pastKeyValuesOutput) {
+            array.setName(tupleName);
+        }
+
+        return new CausalLMOutput(logitsOutput, pastKeyValuesOutput);
+    }
+}
\ No newline at end of file
diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translate/model/DecoderTranslator.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translate/model/DecoderTranslator.java
new file mode 100644
index 0000000000000000000000000000000000000000..34bdf94afd900e3160e0803716550e932e437514
--- /dev/null
+++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translate/model/DecoderTranslator.java
@@ -0,0 +1,65 @@
+package cn.smartjavaai.translate.model;
+
+import ai.djl.ndarray.NDArray;
+import ai.djl.ndarray.NDList;
+import ai.djl.translate.NoBatchifyTranslator;
+import ai.djl.translate.TranslatorContext;
+import cn.smartjavaai.translate.generate.CausalLMOutput;
+
+/**
+ * Pre/post-processing for the decoder step whose input has no pastKeyValues.
+ *
+ * @author Calvin
+ * @mail 179209347@qq.com
+ * @website www.aias.top
+ */
+public class DecoderTranslator implements NoBatchifyTranslator<NDList, CausalLMOutput> {
+    // Dimensions of the past_key_values tuple. Presumably 12 decoder layers with 4 cached
+    // tensors each - TODO confirm against the traced model.
+    private static final int KV_TUPLE_COUNT = 12;
+    private static final int KV_TUPLE_SIZE = 4;
+
+    private final String tupleName;
+
+    public DecoderTranslator() {
+        tupleName = "past_key_values(" + KV_TUPLE_COUNT + ',' + KV_TUPLE_SIZE + ')';
+    }
+
+    /**
+     * Appends a scalar placeholder named {@code module_method:decoder} to the input list.
+     *
+     * @param ctx the translator context that supplies the NDManager
+     * @param input the decoder inputs; modified in place
+     * @return the same list, with the placeholder appended
+     */
+    @Override
+    public NDList processInput(TranslatorContext ctx, NDList input) {
+        // NOTE(review): the name presumably selects the traced method to run - confirm.
+        NDArray placeholder = ctx.getNDManager().create(0);
+        placeholder.setName("module_method:decoder");
+
+        input.add(placeholder);
+
+        return input;
+    }
+
+    /**
+     * Splits the raw output into the logits (first array) and the kv cache (the following
+     * {@code KV_TUPLE_COUNT * KV_TUPLE_SIZE} arrays), naming every cache array.
+     *
+     * @param ctx the translator context (unused)
+     * @param output the raw model output
+     * @return the logits and the named kv cache
+     */
+    @Override
+    public CausalLMOutput processOutput(TranslatorContext ctx, NDList output) {
+        NDArray logitsOutput = output.get(0);
+        NDList pastKeyValuesOutput = output.subNDList(1, KV_TUPLE_COUNT * KV_TUPLE_SIZE + 1);
+
+        for (NDArray array : pastKeyValuesOutput) {
+            array.setName(tupleName);
+        }
+
+        return new CausalLMOutput(logitsOutput, pastKeyValuesOutput);
+    }
+}
\ No newline at end of file
diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translate/model/EncoderTranslator.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translate/model/EncoderTranslator.java
new file mode 100644
index 0000000000000000000000000000000000000000..53ef90d2c09ff42eb7831828a44cbc53c73cfc20
--- /dev/null
+++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translate/model/EncoderTranslator.java
@@ -0,0 +1,64 @@
+package cn.smartjavaai.translate.model;
+
+import ai.djl.ndarray.NDArray;
+import ai.djl.ndarray.NDList;
+import ai.djl.ndarray.NDManager;
+import ai.djl.translate.NoBatchifyTranslator;
+import ai.djl.translate.TranslatorContext;
+
+import java.util.Arrays;
+
+/**
+ * Pre/post-processing for the encoder.
+ *
+ * @author Calvin
+ * @mail 179209347@qq.com
+ * @website www.aias.top
+ */
+public class EncoderTranslator implements NoBatchifyTranslator<long[], NDArray> {
+
+    public EncoderTranslator() {
+    }
+
+    /**
+     * Builds the encoder inputs: the token ids and an all-ones attention mask, each with a
+     * leading dimension of size 1, plus a placeholder named {@code module_method:encoder}.
+     *
+     * @param ctx the translator context that supplies the NDManager
+     * @param input the token ids of one sentence
+     * @return the list {@code [input_ids, attention_mask, placeholder]}
+     */
+    @Override
+    public NDList processInput(TranslatorContext ctx, long[] input) throws Exception {
+        NDManager manager = ctx.getNDManager();
+
+        NDArray inputIdArray = manager.create(input).expandDims(0);
+        inputIdArray.setName("input_ids");
+
+        long[] attentionMask = new long[input.length];
+        Arrays.fill(attentionMask, 1);
+        NDArray attentionMaskArray = manager.create(attentionMask).expandDims(0);
+        attentionMaskArray.setName("attention_mask");
+
+        NDArray placeholder = manager.create(0);
+        placeholder.setName("module_method:encoder");
+
+        return new NDList(inputIdArray, attentionMaskArray, placeholder);
+    }
+
+    /**
+     * Returns the first output array, detached from its manager.
+     *
+     * @param ctx the translator context (unused)
+     * @param list the raw encoder output
+     * @return the encoder hidden states
+     */
+    @Override
+    public NDArray processOutput(TranslatorContext ctx, NDList list) {
+        NDArray encoderHiddenStates = list.get(0);
+        // Detach so the array is no longer owned by the manager it was created in.
+        encoderHiddenStates.detach();
+        return encoderHiddenStates;
+    }
+
+}
\ No newline at end of file
diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translate/model/NllbModel.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translate/model/NllbModel.java
new file mode 100644
index 
0000000000000000000000000000000000000000..9b6b339ccea7c18b57f09657bc8e1ab5313ba5de
--- /dev/null
+++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translate/model/NllbModel.java
@@ -0,0 +1,256 @@
+package cn.smartjavaai.translate.model;
+
+import ai.djl.Device;
+import ai.djl.ModelException;
+import ai.djl.huggingface.tokenizers.Encoding;
+import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
+import ai.djl.inference.Predictor;
+import ai.djl.ndarray.NDArray;
+import ai.djl.ndarray.NDList;
+import ai.djl.ndarray.NDManager;
+import ai.djl.ndarray.index.NDIndex;
+import ai.djl.repository.zoo.Criteria;
+import ai.djl.repository.zoo.ZooModel;
+import ai.djl.translate.NoopTranslator;
+import ai.djl.translate.TranslateException;
+import cn.smartjavaai.translate.generate.CausalLMOutput;
+import cn.smartjavaai.translate.generate.GreedyBatchTensorList;
+import cn.smartjavaai.translate.generate.SearchConfig;
+import cn.smartjavaai.translate.tokenizer.TokenUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.nio.file.Paths;
+import java.util.Arrays;
+
+/**
+ * Loads the traced NLLB model and runs greedy-search translation with it.
+ *
+ * @author Calvin
+ * @mail 179209347@qq.com
+ * @website www.aias.top
+ */
+public class NllbModel implements AutoCloseable {
+    private static final Logger logger = LoggerFactory.getLogger(NllbModel.class);
+    private SearchConfig config;
+    private ZooModel<NDList, NDList> nllbModel;
+    private HuggingFaceTokenizer tokenizer;
+    private Predictor<long[], NDArray> encoderPredictor;
+    private Predictor<NDList, CausalLMOutput> decoderPredictor;
+    private Predictor<NDList, CausalLMOutput> decoder2Predictor;
+    private NDManager manager;
+
+    /**
+     * Loads the model and the tokenizer and creates the encoder/decoder predictors.
+     *
+     * @param config the search configuration
+     * @param modelPath the model directory; it is concatenated directly with the file names, so
+     *     it must end with a path separator
+     * @param modelName the file name of the traced model inside {@code modelPath}
+     * @param device the device to load the model on
+     * @throws ModelException if the model cannot be loaded
+     * @throws IOException if the model or {@code tokenizer.json} cannot be read
+     */
+    public NllbModel(SearchConfig config, String modelPath, String modelName, Device device)
+            throws ModelException, IOException {
+        this.config = config;
+        Criteria<NDList, NDList> criteria =
+                Criteria.builder()
+                        .setTypes(NDList.class, NDList.class)
+                        .optModelPath(Paths.get(modelPath + modelName))
+                        .optEngine("PyTorch")
+                        .optDevice(device)
+                        .optTranslator(new NoopTranslator())
+                        .build();
+
+        try {
+            manager = NDManager.newBaseManager(device);
+            nllbModel = criteria.loadModel();
+            tokenizer = HuggingFaceTokenizer.newInstance(Paths.get(modelPath + "tokenizer.json"));
+            encoderPredictor = nllbModel.newPredictor(new EncoderTranslator());
+            decoderPredictor = nllbModel.newPredictor(new DecoderTranslator());
+            decoder2Predictor = nllbModel.newPredictor(new Decoder2Translator());
+        } catch (ModelException | IOException | RuntimeException e) {
+            // A failed constructor is never close()d by the caller: release what was created.
+            close();
+            throw e;
+        }
+    }
+
+    /**
+     * Runs the encoder.
+     *
+     * @param ids the input token ids
+     * @return the encoder hidden states
+     * @throws TranslateException if inference fails
+     */
+    public NDArray encoder(long[] ids) throws TranslateException {
+        return encoderPredictor.predict(ids);
+    }
+
+    /**
+     * Runs the first decoder step, which takes no kv cache.
+     *
+     * @param input the decoder inputs
+     * @return the logits and the kv cache
+     * @throws TranslateException if inference fails
+     */
+    public CausalLMOutput decoder(NDList input) throws TranslateException {
+        return decoderPredictor.predict(input);
+    }
+
+    /**
+     * Runs a follow-up decoder step, whose inputs include the kv cache.
+     *
+     * @param input the decoder inputs followed by the kv cache
+     * @return the logits and the updated kv cache
+     * @throws TranslateException if inference fails
+     */
+    public CausalLMOutput decoder2(NDList input) throws TranslateException {
+        return decoder2Predictor.predict(input);
+    }
+
+    /** Closes the predictors, the model, the manager and the tokenizer. */
+    @Override
+    public void close() {
+        // Null checks keep close() safe on a partially constructed instance.
+        if (encoderPredictor != null) {
+            encoderPredictor.close();
+        }
+        if (decoderPredictor != null) {
+            decoderPredictor.close();
+        }
+        if (decoder2Predictor != null) {
+            decoder2Predictor.close();
+        }
+        if (nllbModel != null) {
+            nllbModel.close();
+        }
+        if (manager != null) {
+            manager.close();
+        }
+        if (tokenizer != null) {
+            tokenizer.close();
+        }
+    }
+
+    /**
+     * Translates text from the configured source language to the configured target language
+     * with greedy search.
+     *
+     * @param input the text to translate
+     * @return the translated text
+     * @throws TranslateException if the input encodes to no tokens or inference fails
+     */
+    public String translate(String input) throws TranslateException {
+
+        Encoding encoding = tokenizer.encode(input);
+        long[] ids = encoding.getIds();
+        if (ids.length == 0) {
+            throw new TranslateException("Input text produced no tokens");
+        }
+        // 1. Encoder
+        // Put the source language id first and drop the last id of the encoding, so the
+        // length still matches the tokenizer's attention mask.
+        long[] inputIds = new long[ids.length];
+        inputIds[0] = config.getSrcLangId();
+        System.arraycopy(ids, 0, inputIds, 1, ids.length - 1);
+        logger.info("inputIds: {}", Arrays.toString(inputIds));
+        long[] attentionMask = encoding.getAttentionMask();
+        // NOTE(review): arrays created on the long-lived manager are only released by close().
+        NDArray attentionMaskArray = manager.create(attentionMask).expandDims(0);
+
+        NDArray encoderHiddenStates = encoder(inputIds);
+
+        NDArray decoderInputIds =
+                manager.create(new long[] {config.getDecoderStartTokenId()}).reshape(1, 1);
+        NDList decoderInput = new NDList(decoderInputIds, encoderHiddenStates, attentionMaskArray);
+
+        // 2. Initial Decoder
+        CausalLMOutput modelOutput = decoder(decoderInput);
+        modelOutput.getLogits().attach(manager);
+        modelOutput.getPastKeyValuesList().attach(manager);
+
+        GreedyBatchTensorList searchState =
+                new GreedyBatchTensorList(
+                        null,
+                        decoderInputIds,
+                        modelOutput.getPastKeyValuesList(),
+                        encoderHiddenStates,
+                        attentionMaskArray);
+
+        while (true) {
+            NDArray pastOutputIds = searchState.getPastOutputIds();
+
+            if (searchState.getNextInputIds() != null) {
+                decoderInput =
+                        new NDList(
+                                searchState.getNextInputIds(),
+                                searchState.getEncoderHiddenStates(),
+                                searchState.getAttentionMask());
+                decoderInput.addAll(searchState.getPastKeyValues());
+                // 3. Decoder loop
+                modelOutput = decoder2(decoderInput);
+            }
+
+            NDArray outputIds = greedyStepGen(config, pastOutputIds, modelOutput.getLogits());
+
+            searchState.setNextInputIds(outputIds);
+            pastOutputIds = pastOutputIds.concat(outputIds, 1);
+            searchState.setPastOutputIds(pastOutputIds);
+
+            searchState.setPastKeyValues(modelOutput.getPastKeyValuesList());
+
+            // Termination criteria: EOS was generated or the length limit is reached.
+            long id = outputIds.toLongArray()[0];
+            if (config.getEosTokenId() == id) {
+                break;
+            }
+            if (pastOutputIds.getShape().get(1) + 1 >= config.getMaxSeqLength()) {
+                break;
+            }
+        }
+
+        return TokenUtils.decode(config, tokenizer, searchState.getPastOutputIds());
+    }
+
+    /**
+     * Greedily picks the next token ids from the logits of the last sequence position.
+     *
+     * <p>While {@code pastOutputIds} holds a single token, the forced BOS token of the config
+     * (the target language id) is the only candidate.
+     *
+     * @param config the search configuration
+     * @param pastOutputIds the token ids generated so far
+     * @param next_token_scores the decoder logits
+     * @return the selected token ids, with a leading dimension of size 1 added
+     */
+    public NDArray greedyStepGen(SearchConfig config, NDArray pastOutputIds, NDArray next_token_scores) {
+        next_token_scores = next_token_scores.get(":, -1, :");
+
+        NDArray scores;
+        long curLen = pastOutputIds.getShape().getLastDimension();
+        if (curLen == 1) {
+            // LogitsProcessor 1. ForcedBOSTokenLogitsProcessor: every score is -inf except the
+            // forced token's, which is 0. One fill plus one column write replaces a native
+            // set() call per vocabulary entry.
+            scores =
+                    manager.full(
+                            next_token_scores.getShape(),
+                            Float.NEGATIVE_INFINITY,
+                            next_token_scores.getDataType());
+            scores.set(new NDIndex(":," + config.getForcedBosTokenId()), 0);
+        } else {
+            // Copy the logits into an array owned by this model's manager.
+            scores = manager.create(next_token_scores.getShape(), next_token_scores.getDataType());
+            next_token_scores.copyTo(scores);
+        }
+
+        NDArray probs = scores.softmax(-1);
+        NDArray nextTokens = probs.argMax(-1);
+
+        return nextTokens.expandDims(0);
+    }
+
+}
diff --git a/smartjavaai-translate/src/main/java/cn/smartjavaai/translate/tokenizer/TokenUtils.java b/smartjavaai-translate/src/main/java/cn/smartjavaai/translate/tokenizer/TokenUtils.java
new file mode 100644
index 0000000000000000000000000000000000000000..f4c0ab9f68ddb120f2ccac0da61d5f3cc45a5db6
--- /dev/null
+++ b/smartjavaai-translate/src/main/java/cn/smartjavaai/translate/tokenizer/TokenUtils.java
@@ -0,0 +1,45 @@
+package cn.smartjavaai.translate.tokenizer;
+
+import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
+import ai.djl.ndarray.NDArray;
+import cn.smartjavaai.translate.generate.SearchConfig;
+
+import java.util.Arrays;
+
+/**
+ * Token utilities.
+ *
+ * @author Calvin
+ * @mail 179209347@qq.com
+ * @website www.aias.top
+ */
+public final class TokenUtils {
+
+    private TokenUtils() {
+    }
+
+    /**
+     * Decodes generated token ids to text, skipping the EOS token and the source/target
+     * language ids of the config.
+     *
+     * @param config the search configuration that supplies the ids to skip
+     * @param tokenizer the tokenizer used for decoding
+     * @param output the generated token ids
+     * @return the decoded text
+     */
+    public static String decode(SearchConfig config, HuggingFaceTokenizer tokenizer, NDArray output) {
+        long[] outputIds = output.toLongArray();
+        // Filter into a primitive buffer; no boxing through a List<Long> is needed.
+        long[] kept = new long[outputIds.length];
+        int count = 0;
+
+        for (long id : outputIds) {
+            if (id == config.getEosTokenId() || id == config.getSrcLangId() || id == config.getForcedBosTokenId()) {
+                continue;
+            }
+            kept[count++] = id;
+        }
+
+        return tokenizer.decode(Arrays.copyOf(kept, count));
+    }
+}
\ No newline at end of file
diff --git "a/smartjavaai-translate/\350\257\255\350\250\200\347\274\226\347\240\201.xlsx" "b/smartjavaai-translate/\350\257\255\350\250\200\347\274\226\347\240\201.xlsx"
new file mode 100644
index 0000000000000000000000000000000000000000..d3a879f88f3d2df0f467b769dbb7a12d0720c684
Binary files /dev/null and "b/smartjavaai-translate/\350\257\255\350\250\200\347\274\226\347\240\201.xlsx" differ