From 57a2f03e76549b4c91babee013c2f1510b81ba1e Mon Sep 17 00:00:00 2001 From: somnus Date: Mon, 22 Dec 2025 11:55:06 +0800 Subject: [PATCH] add telechat3_36B model --- research/telechat3/README.md | 302 ++++++++ research/telechat3/convert_reversed.py | 85 +++ research/telechat3/convert_weight.py | 111 +++ research/telechat3/telechat.py | 567 +++++++++++++++ .../telechat3-36b/finetune_telechat3_36b.yaml | 194 +++++ .../telechat3-36b/predict_telechat3_36b.yaml | 133 ++++ .../predict_telechat3_36b_parallel.yaml | 91 +++ research/telechat3/telechat_config.py | 214 ++++++ research/telechat3/telechat_interleave.py | 683 ++++++++++++++++++ research/telechat3/telechat_layer.py | 292 ++++++++ research/telechat3/telechat_tokenizer.py | 278 +++++++ research/telechat3/telechat_transformer.py | 594 +++++++++++++++ 12 files changed, 3544 insertions(+) create mode 100644 research/telechat3/README.md create mode 100644 research/telechat3/convert_reversed.py create mode 100644 research/telechat3/convert_weight.py create mode 100644 research/telechat3/telechat.py create mode 100644 research/telechat3/telechat3-36b/finetune_telechat3_36b.yaml create mode 100644 research/telechat3/telechat3-36b/predict_telechat3_36b.yaml create mode 100644 research/telechat3/telechat3-36b/predict_telechat3_36b_parallel.yaml create mode 100644 research/telechat3/telechat_config.py create mode 100644 research/telechat3/telechat_interleave.py create mode 100644 research/telechat3/telechat_layer.py create mode 100644 research/telechat3/telechat_tokenizer.py create mode 100644 research/telechat3/telechat_transformer.py diff --git a/research/telechat3/README.md b/research/telechat3/README.md new file mode 100644 index 000000000..062fa1b02 --- /dev/null +++ b/research/telechat3/README.md @@ -0,0 +1,302 @@ +# 星辰语义大模型 TeleChat3 + +## 模型描述 + +- 星辰语义大模型**TeleChat3**是由中国电信人工智能研究院研发训练的大语言模型,包含36B, 105B两种规模,该系列模型**完全基于国产算力**训练。 +- 本次开源**TeleChat3-36B**模型采用10万亿Tokens中英文高质量语料进行训练,同步开源对话模型**TeleChat3-36B**的多格式、多平台权重文件。 
+- **TeleChat3**在训练数据、训练方法等方面进行了改进,在通用问答和知识类、代码类、数学类榜单上相比**TeleChat2**均有大幅提升。 + - **TeleChat3**完全基于国产算力和国产深度学习框架进行训练,算力和算法框架更自主可控。优化MP、PP、SP实现方式提升模型性能,优化算子来提升训练速度。 + - 我们使用大量小模型实验来验证scaling law规律,在不同模型结构、不同数据配比和数据清洗方式中寻找最优设计。 + - 采用RingAttention及其他序列切分方式,实现长文训练性能提升;通过ntk-aware+attention-scaling的方式保证训练长度切换时的平稳过渡,以此来保证模型在不同长度数据下的训练效果。 +- 在微调数据方面,我们进行了指令复杂性提升与多样性扩充,通过数据合成和人工标注生成高质量数据,并使用拒绝采样生成多样的推理路径;通过研究一套基于base模型反向选择偏好对齐数据方案,基于适配数据最大限度提升模型效果。 + - 通用能力较TeleChat1系列模型提升超过29%,在逻辑推理、总结摘要、长文写作和数学计算上均有大幅提升。 + +基于GPU,Torch版本的TeleChat3链接: + +[TeleChat3](https://github.com/Tele-AI/TeleChat3) + +## 模型性能 + +以下模型性能均由Atlas 800T A2硬件环境下测试得出。 + +TeleChat3-36b: + +| config | task | Datasets | SeqLength | phase | performance | +|-----------------------------------------------------| --------------------- |------------|-----------|-----------------|--------------| +| [TeleChat3_36b](./run_telechat_36b_finetune.yaml) | text_generation | example_dataset | 8192 | [finetune](#微调) | 516 tokens/s/p | +| [TeleChat3_36b](./run_telechat_36b_predict.yaml) | text_generation | example_dataset | 8192 | [predict](#推理) | 27.7 tokens/s | + +## 模型文件 + +`TeleChat3` 基于 `mindformers` 实现,主要涉及的文件有: + +1. 模型具体实现:`mindformers/research/telechat3` + + ```bash + telechat + ├── convert_weight.py # torch->ms权重转换脚本 + ├── convert_reversed.py # ms->torch权重转换脚本 + ├── telechat.py # 模型实现 + ├── telechat_config.py # 模型配置项 + ├── telechat_layer.py # telechat网络层定义 + ├── telechat_interleave.py # telechat细粒度多副本 + ├── telechat_tokenizer.py # telechat tokenizer + └── telechat_transformer.py # transformer层实现 + ``` + +2. 模型配置:`mindformers/research/telechat3` + + ```bash + telechat + ├── finetune_telechat_36b.yaml # 36b全量微调启动配置 + └── predict_telechat_36b.yaml # 36b推理启动配置 + ``` + +3. 
任务启动脚本:`mindformers/research/telechat3` + + ```text + telechat + ├── run_telechat_predict.py # 推理脚本 + └── run_telechat.py # telechat高阶接口使用脚本 + ``` + +## 环境及数据准备 + +### 安装环境 + +**MindFormers安装**以及**软硬件配套关系**参考[MindFormers安装](../../README_CN.md#二MindFormers安装)和[版本匹配关系](../../README_CN.md#三版本匹配关系)。 + +> 注:Atlas 800T A2芯片支持telechat_36B单机多卡推理,至少使用2卡,全参微调至少需要4机32卡。 + +### 数据及权重准备 + +#### 数据集下载 + +TeleChat3全系列模型中36b,105b所使用的微调数据集是由中电信人工智能科技有限公司所提供。 + +step 1. 构建数据集 + +可构建**单轮对话数据集**或**多轮对话数据集**,数据集为jsonl文件,每条数据内容需包含`input`和`output`字段,参考格式如下: + +- 单轮对话数据集 + + ```text + {"input": "<_user>描述一下中国古代文学的典型特点。", "output": "中国古代文学的典型特点包括:\n1. 以散文和诗歌为主要形式:中国古代文学中,散文和诗歌是最为常见和流行的文学形式。\n2. 强调意境:中国古代文学强调意境,而不仅仅关注情节的描述。作家致力于创造一种美感,通过文字表达出美好、深刻和含蓄的意境。\n3. 重视艺术性和审美价值:中国古代文学非常注重艺术性和审美价值,作家用精妙的文字表现出诗意和美感,增强作品的艺术性。\n4. 探究人性和社会问题:中国古代文学经常探究人性和社会问题,通过作品展示出人类在道德和伦理方面的深刻思考,及对社会现象和问题的探讨。\n5. 借物言志:中国古代文学中常借物言志,比如通过自然景观、动植物、传说和神话等形象,表达出人类的情感和内心世界。\n6. 注重传统文化的保护:中国古代文学注重传统文化的保护,作家在其作品中运用了许多古代文学的传统元素,生动展现出中国文化的独特魅力。"} + ``` + +- 多轮对话数据集 + + ```text + {"input": "<_user>你能告诉我什么是人工智能吗?<_bot>人工智能是指利用计算机技术和算法,让计算机模拟人类的思维和行为,实现智能化的程序设计。它主要包括机器学习、自然语言处理、计算机视觉等技术和应用。人工智能可以应用于语音识别、图像识别、自动驾驶、智能客服等领域,可以有效地提高工作效率,优化资源配置,带来更好的人机交互体验。<_user>未来的人工智能将会发展成怎样呢?", "output": "未来的人工智能发展方向众多,以下是一些可能的方向:\n\n1. 个性化AI:人们对于个性化服务的需求越来越高,未来的人工智能将更加注重个性化,为用户提供量身定制的体验。\n\n2. 更加智能化的自动化:未来的人工智能将致力于让机器更加智能地自动化完成各种任务,从而提升效率、降低生产成本。\n\n3. 完全智能化的知识处理:未来的人工智能将能够更深入地处理和理解知识,掌握更多的上下文信息,从而能够更智能地应用到真实的场景中。\n\n4. 更强大的计算能力:未来的人工智能将伴随着计算技术的不断发展而不断提高其处理和计算能力,为更加复杂的任务处理提供支持。\n\n5. 更好的人机交互:未来的人工智能将更好地与人类交互,可预见的未来将有更多的机器人出现在人们的生活中,应用到教育、医疗、娱乐等领域。\n\n总之,未来的人工智能将在不断的发展中迎来更多更广泛的应用场景和解决更多的实际问题,使人们的生活更加便捷、高效和智能化。"} + ``` + +step 2. 
处理数据成mindrecord格式 + +```bash +# 使用mindformers/research/telechat2/telechat_preprocess.py进行数据预处理和Mindrecord数据生成 +python telechat_preprocess.py \ +--input_dataset_file /{path}/ \ +--vocab_file_path /{path}/tokenizer.model \ +--max_length 8192 \ +--output_path /{path}/ +``` + +```text +# 参数说明 +input_dataset_file: 预训练的数据集 +vocab_file_path: 词模型文件路径(如使用上述链接下载,指定到对应路径下即可) +max_length: 数据集长度 +output_path: 生成数据集的路径 +``` + + > 注:`bos`, `eos`, `pad`等特殊`ids`要和`yaml`配置文件中`model_config`部分保持一致,默认`bos_token_id=1`, `eos_token_id=2`, `pad_token_id=3`。 +如果有所修改,配置文件中对应设置也需要修改,通常预训练数据不包含`pad_token`,因此建议设置`pad_token_id=-1`。 + +#### 模型权重下载与转换 + +MindFormers提供已经转换完成的预训练权重、词表文件用于预训练、微调和推理,开发者可以下载获取官方权重后,通过下面提供的**权重转换脚本**,将官方权重转换为MindSpore权重;或直接使用MindFormers提供的**已转换权重** + +1.torch模型权重及词模型下载链接: + +- [TeleChat3-36b](https://modelscope.cn/models/TeleAI/TeleChat3-36B-Thinking/files) + +下载完成后,运行如下转换脚本,将全量微调的权重转换为完整的ckpt权重。 + +```shell +python mindformers/research/telechat3/convert_weight_torch_to_ms.py \ +--torch_path TORCH_CKPT_DIR \ +--mindspore_path {path} \ +``` + +```text +# 参数说明 +torch_path: torch版本权重保存目录路径 +mindspore_path: 权重保存文件名,可以指定自定义保存路径 +``` + +2.获取MindFormers提供的已转换权重,可直接从下面的链接获取。 + +- [TeleChat3-36b](https://telechat-docker.obs.cn-north-4.myhuaweicloud.com/model_weight/Telechat_36b/Telechat_36b.zip) + +### [分布式权重切分与合并](https://www.mindspore.cn/mindformers/docs/zh-CN/master/index.html#%E5%88%86%E5%B8%83%E5%BC%8F%E6%9D%83%E9%87%8D%E5%88%87%E5%88%86%E4%B8%8E%E5%90%88%E5%B9%B6) + +分布式训练/微调后所得到的权重文件为根据策略切分后的权重,需要手动将切分权重合一,以用于评估和推理。 + +涉及到ckpt的单卡,多卡转换,详细教程请参考特性文档[分布式权重切分与合并](https://www.mindspore.cn/mindformers/docs/zh-CN/master/index.html) + +- step 1. 获取模型切分策略文件: + +在执行微调脚本时,模型完成编译后,将会在`output/strategy`路径下生成各卡的切分策略文件,用于权重合并。 + +- step 2. 
运行`mindformers/tools/transform_ckpt.py`脚本进行多卡权重合并: + +```shell +python transform_ckpt.py \ +--src_ckpt_strategy {path}/output/strategy/ \ +--src_ckpt_dir {path}/output/checkpoint/ \ +--dst_ckpt_dir {path}/target_checkpoint/ \ +--prefix telechat_36b +``` + +```text +# 参数说明 +src_ckpt_strategy: 步骤1中的切分策略文件路径 +src_ckpt_dir: 原切分权重文件夹 +dst_ckpt_dir: 目标路径 +prefix: ckpt文件前缀名 +``` + +> 注:`transform_checkpoints` 接口当前仅mindspore 2.0以上版本支持,如当前硬件环境只支持2.0以下版本,可以新建conda环境安装mindspore 2.0的cpu版本以执行该脚本 + +## 微调 + +MindFormers提供`TeleChat3-36b`的微调示例,过程中使用中电信人工智能科技有限公司提供的数据集对模型进行预训练,数据集可以参考[数据集下载](#数据集下载)获得。 + +### 全参微调 + +#### 多机训练 + +- step 1. 修改模型对应的配置文件。 + + 在模型对应的配置文件`research/telechat3/finetune_telechat_36b.yaml`中,用户可自行修改模型、训练相关参数(推荐开启flash_attention,可加速训练),并通过`train_dataset`的`dataset_dir`参数,指定训练数据集的路径。 + + 1. 增加脚本入参`--load_checkpoint /{path}/telechat_36b.ckpt`加载预训练权重 + 2. 设置启动脚本中的`--train_dataset_dir /{path}/dataset.mindrecord`加载微调数据集 + 3. 设置启动脚本中的`--run_mode finetune` + + 配置文件中各参数含义详见[Config配置说明文档](https://gitee.com/mindspore/mindformers/blob/master/configs/README.md)。auto_parallel说明详见[自动并行](../../docs/feature_cards/Auto_Parallel.md)。 + +- step 2. 根据服务器节点数等信息,修改相应的配置。 + + ```yaml + # 以telechat-36b模型8机64卡训练为例,默认配置机4096卡,如果节点数有变,需要修改相应的配置。 + # 配置文件路径:finetune_telechat_36b.yaml + parallel_config: + data_parallel: 1 + model_parallel: 8 + pipeline_stage: 8 + micro_batch_num: 8 + vocab_emb_dp: True + gradient_aggregation_group: 4 + ``` + +- step3. 设置环境变量,变量配置如下: + + ```bash + export ENABLE_CELL_REUSE=1 #编译加速 + export MS_DEV_SIDE_EFFECT_LOAD_ELIM=3 # 去除TensorMove + export MS_MEMORY_POOL_RECYCLE=1 # 内存优化 + export GE_NOT_CUT=1 # 内存优化 + ``` + +- step 4. 
执行运行脚本。 + + 在多机上同时拉起任务,每台机器拉起方式参考单机多卡启动方式。 + + ```shell + cd mindformers/ + + # 节点0,节点ip为192.168.1.1,作为主节点,总共16卡且每个节点8卡 + bash scripts/msrun_launcher.sh "python run_mindformer.py \ + --config research/telechat3/finetune_telechat_36b.yaml \ + --train_dataset /{path}/dataset.mindrecord \ + --use_parallel True \ + --register_path ./research/telechat3" \ + 16 8 192.168.1.1 8118 0 output/msrun_log False 300 + + # 节点1,节点ip为192.168.1.2,节点0与节点1启动命令仅参数NODE_RANK不同 + bash scripts/msrun_launcher.sh "python run_mindformer.py \ + --config research/telechat3/finetune_telechat_36b.yaml \ + --train_dataset /{path}/dataset.mindrecord \ + --use_parallel True \ + --register_path ./research/telechat3" \ + 16 8 192.168.1.1 8118 1 output/msrun_log False 300 + ``` + + ```text + # 参数说明 + config: 配置文件路径 + train_dataset: 训练数据集文件夹路径 + use_parallel:开启并行训练 + register_path: 外部模型注册路径 + ``` + +## 推理 + +推理时所需的模型词表可在[模型权重下载与转换](#模型权重下载与转换)章节中下载得到,对应文件为`tokenizer.model`。 + +### 快速推理 + +运行`run_mindformer.py`启动快速推理。 + +#### 参数配置 + +在`predict_telechat_xxx.yaml`中填写`vocab_file`字段 + +```yaml +processor: + tokenizer: + vocab_file: 'path/to/tokenizer.model' +``` + +#### 启动推理 + +- 36b模型2卡推理 + + 默认使用完整权重,开启权重自动转换`auto_trans_ckpt=True`。 + + ```bash + cd mindformers/ + bash scripts/msrun_launcher.sh "python run_mindformer.py \ + --config ./research/telechat3/telechat3-36b/predict_telechat_36b.yaml \ + --load_checkpoint path/to/ckpt_path \ + --predict_data '<_start><_user>生抽与老抽的区别?<_bot>' \ + --auto_trans_ckpt True \ + --use_parallel True \ + --register_path ./research/telechat3" 2 + ``` + +- 参数说明 + + ```text + config: 模型的配置文件 + load_checkpoint: 权重路径 + predict_data: 输入的问题 + auto_tans_ckpt: 权重自动转换开关 + use_parallel: 并行模式开关 + register_path: 外部模型注册路径 + ``` + +#### 推理结果 + +36b 模型推理结果如下: + +```text +生抽与老抽的区别? 
+ +生抽和老抽是两种不同的酱油,它们在风味、色泽和用途上都有所区别。 + +1.颜色:生抽的颜色比较淡,而老抽的颜色较深。生抽的颜色呈红褐色或棕红色,而老抽的颜色则呈棕黑色。 + +2.味道:生抽具有鲜美的咸味和微甜的味浅,而老抽浓郁,颜色较深。根据个人口味和烹饪需求选择不同的酱油类型可以获得更好的口感和菜肴效果。 +``` diff --git a/research/telechat3/convert_reversed.py b/research/telechat3/convert_reversed.py new file mode 100644 index 000000000..78afbe60e --- /dev/null +++ b/research/telechat3/convert_reversed.py @@ -0,0 +1,85 @@ +# Copyright 2025 TeleAI Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +""" +Convert Telechat weight. +Support mindformers format. 
+""" + +import argparse +import torch + +import mindspore as ms + +from mindformers.tools import logger +from mindformers.utils.convert_utils import ms2pt + +dtype_map = { + 'float32': torch.float32, + 'bfloat16': torch.bfloat16, + 'float16': torch.float16 +} + + +def name_replace(name: str): + """replace ms param name to hf.""" + name = name.replace("model.tok_embeddings.embedding_weight", "transformer.word_embeddings.weight") + name = name.replace('model.embedding_hidden_mapping_in.weight', 'model.embedding_hidden_mapping_in.weight') + name = name.replace('model.embedding_hidden_mapping_out.weight', 'model.embedding_hidden_mapping_out.weight') + name = name.replace("attention_norm.weight", "input_layernorm.weight") + name = name.replace("attention.wo.weight", "self_attn.o_proj.weight") + name = name.replace("attention.wq.weight", "self_attn.q_proj.weight") + name = name.replace("attention.wk.weight", "self_attn.k_proj.weight") + name = name.replace("attention.wv.weight", "self_attn.v_proj.weight") + name = name.replace("feed_forward.w1.weight", "mlp.gate_proj.weight") + name = name.replace("feed_forward.w2.weight", "mlp.down_proj.weight") + name = name.replace("feed_forward.w3.weight", "mlp.up_proj.weight") + name = name.replace("ffn_norm.weight", "post_attention_layernorm.weight") + name = name.replace('model.norm.', 'model.norm_out.') + name = name.replace("lm_head.weight", "lm_head.weight") + return name + + +# pylint: disable=W0613 +def convert_ms_to_pt(input_path, output_path, dtype=None, **kwargs): + """convert telechat ms weight to hf.""" + logger.info(f"Trying to convert mindspore checkpoint in '{input_path}'.") + model_ms = ms.load_checkpoint(input_path) + + state_dict = {} + for name, value in model_ms.items(): + value = ms2pt(value, dtype) + name = name_replace(name) + if name.startswith("model.layers."): + name = name.replace("model.layers.", "transformer.h.") + + state_dict[name] = value + logger.info(f'\rprocessing parameter: {name} {value.shape}') 
+ + torch.save(state_dict, output_path) + logger.info(f"\rConvert telechat checkpoint finished, the huggingface checkpoint is saved in '{output_path}'.") + return True + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--mindspore_path', default='transform.ckpt') + parser.add_argument('--torch_path', default='torch.bin') + parser.add_argument("--dtype", default='float32', choices=['float16', 'float32', 'bfloat16'], + help="Data type for output checkpoint file. Default: float16") + args = parser.parse_args() + torch_dtype = dtype_map.get(args.dtype) + + convert_ms_to_pt(input_path=args.mindspore_path, output_path=args.torch_path, dtype=torch_dtype) diff --git a/research/telechat3/convert_weight.py b/research/telechat3/convert_weight.py new file mode 100644 index 000000000..ed3c4f07c --- /dev/null +++ b/research/telechat3/convert_weight.py @@ -0,0 +1,111 @@ +# Copyright 2025 TeleAI Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +""" +Convert Telechat weight. +Support huggingface format. 
+""" + +import os +import argparse +from glob import glob + +import torch +from safetensors.torch import load_file + +import mindspore as ms +from mindformers.tools.utils import str2bool +from mindformers.tools import logger +from mindformers.utils.convert_utils import pt2ms + +dtype_map = { + 'float32': ms.float32, + 'bfloat16': ms.bfloat16, + 'float16': ms.float16 +} + + +def name_replace(name: str): + """replace hf param name to ms.""" + name = name.replace('model.embed_tokens.weight', 'model.tok_embeddings.embedding_weight') + name = name.replace('model.embedding_hidden_mapping_in.weight', 'model.embedding_hidden_mapping_in.weight') + name = name.replace('model.embedding_hidden_mapping_out.weight', 'model.embedding_hidden_mapping_out.weight') + name = name.replace('.input_layernorm', '.attention_norm') + name = name.replace('.self_attn.o_proj.', '.attention.wo.') + name = name.replace('.self_attn.q_proj.', '.attention.wq.') + name = name.replace('.self_attn.k_proj.', '.attention.wk.') + name = name.replace('.self_attn.v_proj.', '.attention.wv.') + name = name.replace('.mlp.gate_proj.', '.feed_forward.w1.') + name = name.replace('.mlp.down_proj.', '.feed_forward.w2.') + name = name.replace('.mlp.up_proj.', '.feed_forward.w3.') + name = name.replace('.post_attention_layernorm.', '.ffn_norm.') + name = name.replace('lm_head.', 'lm_head.') + name = name.replace('model.norm.', 'model.norm_out.') + return name + + +# pylint: disable=W0613 +def convert_pt_to_ms(input_path, output_path, dtype=None, **kwargs): + """convert telechat hf weight to ms.""" + files = list(glob(os.path.join(input_path, "pytorch_model*.bin"))) + convert_safetensors = False + if not files: + files = list(glob(os.path.join(input_path, "model*.safetensors"))) + if not files: + raise FileNotFoundError(f"No bin or safetensors found in the model path: {input_path}.") + convert_safetensors = True + files.sort() + pt_states_list = [] + for per_file in files: + if convert_safetensors: + pt_states = 
load_file(per_file) + else: + pt_states = torch.load(per_file, map_location='cpu') + pt_states_list.append(pt_states) + + ckpt_list = [] + for pt_states in pt_states_list: + for name, value in pt_states.items(): + name = name_replace(name) + if name.startswith('transformer.h.'): + name = name.replace('transformer.h.', 'model.layers.') + logger.info(f'\rprocessing parameter: {name} {value.shape}') + ckpt_list.append({'name': name, 'data': pt2ms(value, dtype)}) + + ms.save_checkpoint(ckpt_list, output_path) + logger.info(f"\rConvert huggingface checkpoint finished, the mindspore checkpoint is saved in '{output_path}'.") + return True + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description="Telechat convert script") + parser.add_argument("--torch_path", + type=str, + default="", + help="The input torch checkpoint path.") + parser.add_argument("--mindspore_path", + type=str, + default="", + help="The output mindspore checkpoint path.") + parser.add_argument("--dtype", default='float32', choices=['float16', 'float32', 'bfloat16'], + help="Data type for output checkpoint file. Default: float16") + parser.add_argument('--qkv_concat', default=False, type=str2bool) + parser.add_argument('--mindspore_ckpt_path', default='transform.ckpt') + parser.add_argument('--pre_ckpt_path', default=None) + parser.add_argument('--model_name', default="telechat_7B", type=str) + args = parser.parse_args() + ms_dtype = dtype_map.get(args.dtype) + + convert_pt_to_ms(input_path=args.torch_path, output_path=args.mindspore_path, dtype=ms_dtype) diff --git a/research/telechat3/telechat.py b/research/telechat3/telechat.py new file mode 100644 index 000000000..c7e939ffa --- /dev/null +++ b/research/telechat3/telechat.py @@ -0,0 +1,567 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Telechat models' APIs.""" +import copy +import numpy as np + +import mindspore.common.dtype as mstype +import mindspore.ops.functional as F +from mindspore import Tensor, nn, mint +from mindspore.context import ParallelMode +from mindspore.ops import operations as P +from mindspore.parallel._utils import _get_parallel_mode, _is_sharding_propagation + +from research.telechat3.telechat_transformer import TelechatDecodeLayer +from research.telechat3.telechat_interleave import TelechatDecodeLayerInterleave +from research.telechat3.telechat_layer import TelechatEmbedding +from research.telechat3.telechat_config import TelechatConfig + +from mindformers.core.loss.loss import CrossEntropyLoss +from mindformers.models.modeling_utils import PreTrainedModel +from mindformers.models.utils import LayerSetting, lazy_inline, check_fine_grain_interleave_valid +from mindformers.models.llama.llama_layer import LlamaRMSNorm +from mindformers.modules.layers import Linear, FreqsMgr +from mindformers.modules.transformer import LowerTriangularMaskWithDynamic +from mindformers.modules.transformer.op_parallel_config import _check_config, default_dpmp_config +from mindformers.tools.logger import logger +from mindformers.tools.register.register import MindFormerModuleType, MindFormerRegister +from mindformers.tools.utils import get_predict_run_mode + + +class TelechatPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and 
loading pretrained + models. + """ + + config_class = TelechatConfig + base_model_prefix = "telechat" + + +class AttentionMaskWithEod(nn.Cell): + """ + This class generates a two-dimensional attention mask matrix based on the input token sequence and + the end-of-document marker. It not only implements standard causal/padding masking but also ensures + attention isolation between different documents. + """ + def __init__(self, seq_length, parallel_config=default_dpmp_config, compute_dtype=mstype.float16, + use_flash_attention=False, **kwargs): + super().__init__() + dp = parallel_config.data_parallel + self.seq_length = seq_length + self.compute_dtype = compute_dtype + self.use_flash_attention = use_flash_attention + self.one = Tensor([1.0], dtype=compute_dtype) + self.multiply_data = Tensor([-10000.0], dtype=compute_dtype) + #self.arange_tensor = Tensor(np.tile(np.arange(seq_length).reshape(1, 1, seq_length), (1, seq_length, 1)), + # mstype.int16) + self.less = P.LessEqual().shard(((dp, 1, 1), (1, 1, 1))) + self.tile = P.Tile().shard(((dp, 1, 1),)) + self.expand_dims = P.ExpandDims().shard(((dp, 1),)) + self.mul = P.Mul().shard(((1, 1, 1), (dp, 1, 1))) + self.sub = P.Sub().shard(((1,), (dp, 1, 1))) + self.cast = P.Cast() + self.mul_post = P.Mul().shard(((dp, 1, 1, 1), (1,))) + self.expand_dim_post = P.ExpandDims().shard(((dp, 1, 1),)) + self.range = P.Range().shard(((1,),)) + self.less_equal = P.LessEqual().shard(((1, 1, 1), (1, 1, 1))) + self.reshape = P.Reshape() + + def construct(self, sequence_start_ids): + """ + Forward of telechat model. 
+ + Args: + sequence_start_ids: the start_ids of tokenized inputs with datatype int32 + Returns: + output: Tensor + """ + col_indices = self.range(0, self.seq_length, 1) + col_indices = self.reshape(col_indices, (1, 1, self.seq_length)) + + row_indices = self.reshape(col_indices, (1, self.seq_length, 1)) + + lower_triangle_mask = self.less_equal(col_indices, row_indices) + + start_pos = self.cast(self.expand_dims(sequence_start_ids, -1), mstype.int32) + doc_mask = self.less(start_pos, col_indices) + + combined_mask = self.mul(self.cast(lower_triangle_mask, self.compute_dtype), + self.cast(doc_mask, self.compute_dtype)) + return combined_mask + + def post_process(self, mask): + mask = self.sub(self.one, self.cast(mask, self.compute_dtype)) + if not self.use_flash_attention: + mask = self.expand_dim_post(mask, 1) + mask = self.mul_post(mask, self.multiply_data) + else: + mask = self.expand_dim_post(mask, 1) + mask = self.cast(mask, mstype.uint8) + return mask + + +class TelechatModel(TelechatPreTrainedModel): + r""" + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`TelechatDecoderLayer`] + Args: + config(TelechatConfig): the config of network + + Returns: + output: Tensor, the output of telechat decoderlayer + + Examples: + >>> from mindformers import TelechatModel + >>> network = TelechatModel.from_pretrained('telechat_36b') + >>> type(network) + + """ + def __init__(self, config: TelechatConfig=None): + super().__init__(config, auto_prefix=True) + _check_config(config.parallel_config) + self.model_name=config.net_name + self.dtype=config.compute_dtype + self.hidden_size=config.hidden_size + self.embedding_size=config.embedding_size + self.num_layers=config.num_layers + self.n_head=config.num_heads + self.head_dim=self.hidden_size // self.n_head + self.pad_token_id=config.pad_token_id + self.is_first_iteration=True + self.use_past=config.use_past + self.use_flash_attention=config.use_flash_attention + + self.concat = P.Concat(-1) + self.cast = P.Cast() + self.shape = P.Shape() + self.reshape = P.Reshape() + + self.freqs_mgr = FreqsMgr(head_dim=self.head_dim, + seq_length=config.seq_length, + max_position_embedding=config.max_position_embedding, + rotary_dtype=config.rotary_dtype, + theta=config.theta, + scaling_factor=config.scaling_factor, + extend_method=config.extend_method, + parallel_config=config.parallel_config, + is_dynamic=config.is_dynamic) + if config.eod_reset: + self.casual_mask = AttentionMaskWithEod(seq_length=config.seq_length, + parallel_config=config.parallel_config, + compute_dtype=config.compute_dtype, + use_flash_attention=config.use_flash_attention) + else: + self.casual_mask = LowerTriangularMaskWithDynamic(seq_length=config.seq_length, + compute_type=config.compute_dtype, + is_dynamic=config.is_dynamic, + pad_token_id=config.pad_token_id, + use_flash_attention=config.use_flash_attention, + use_attn_mask_compression=config.use_attn_mask_compression, + use_past=config.use_past) + self.tok_embeddings = TelechatEmbedding(vocab_table_size=config.vocab_size, + sigma=config.sigma, + 
mean=config.mean, + embedding_size=config.hidden_size, + param_init_type=config.embedding_init_type, + parallel_optimizer=config.parallel_optimizer) + self.norm_out = LlamaRMSNorm(config.hidden_size, config.rms_norm_eps, + compute_type=config.layernorm_compute_type) + self.fine_grain_interleave = check_fine_grain_interleave_valid(config.fine_grain_interleave, + config.parallel_config) + self.layers = nn.CellList() + self.layer_setting = LayerSetting(config.num_layers, + config.offset, + config.parallel_config, + config.pp_interleave_num) + for layer_id in range(config.num_layers): + if self.fine_grain_interleave: + layer = TelechatDecodeLayerInterleave(config.seq_length, + layer_id, + dim=config.hidden_size, + n_heads=config.num_heads, + num_layers=config.num_layers, + n_kv_heads=config.n_kv_heads, + intermediate_size=config.intermediate_size, + ffn_dim_multiplier=config.ffn_dim_multiplier, + norm_eps=config.rms_norm_eps, + qkv_has_bias=config.qkv_has_bias, + out_proj_has_bias=config.out_proj_has_bias, + compute_dtype=config.compute_dtype, + layernorm_compute_dtype=config.layernorm_compute_type, + softmax_compute_dtype=config.softmax_compute_type, + rotary_dtype=config.rotary_dtype, + param_init_type=config.param_init_type, + res_dtype=config.res_dtype, + use_flash_attention=config.use_flash_attention, + use_attn_mask_compression=config.use_attn_mask_compression, + use_rope_slice=config.use_rope_slice, + fine_grain_interleave=config.fine_grain_interleave, + parallel_config=config.parallel_config) + else: + layer = TelechatDecodeLayer(layer_id, + dim=config.hidden_size, + n_heads=config.num_heads, + n_kv_heads=config.n_kv_heads, + sigma=config.sigma, + mean=config.mean, + moe_config=config.moe_config, + intermediate_size=config.intermediate_size, + multiple_of=config.multiple_of, + ffn_dim_multiplier=config.ffn_dim_multiplier, + norm_eps=config.rms_norm_eps, + qkv_has_bias=config.qkv_has_bias, + out_proj_has_bias=config.out_proj_has_bias, + 
qkv_concat=config.qkv_concat, + compute_dtype=config.compute_dtype, + layernorm_compute_dtype=config.layernorm_compute_type, + softmax_compute_dtype=config.softmax_compute_type, + rotary_dtype=config.rotary_dtype, + param_init_type=config.param_init_type, + res_dtype=config.res_dtype, + use_past=config.use_past, + use_flash_attention=config.use_flash_attention, + use_attn_mask_compression=config.use_attn_mask_compression, + block_size=config.block_size, + num_blocks=config.num_blocks, + is_dynamic=config.is_dynamic, + use_rope_slice=config.use_rope_slice, + parallel_config=config.parallel_config) + self.layers.append(layer) + dp = config.parallel_config.data_parallel + sp = config.parallel_config.context_parallel + self.sp = sp + self.expert_num = 1 if config.moe_config is None else config.moe_config.expert_num + if not (_get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation()): + self.tok_embeddings.pipeline_stage = 0 + if config.parallel_config.pipeline_stage > 1: + self.norm_out.pipeline_stage = config.parallel_config.pipeline_stage - 1 + self.tok_embeddings.set_comm_fusion(2) + self.norm_out.set_comm_fusion(2) + else: + self.tok_embeddings.set_comm_fusion(config.parallel_config.gradient_aggregation_group) + self.norm_out.set_comm_fusion(config.parallel_config.gradient_aggregation_group) + + self.tok_embeddings.shard(config.parallel_config) + #self.casual_mask.shard(config.parallel_config) + self.concat.shard(((dp, 1, 1, 1), (dp, 1, 1, 1))) + if self.fine_grain_interleave: + self.norm_out.shard((dp * sp, 1)) + else: + self.norm_out.shard((dp, sp, 1)) + + # pylint: disable=W0613 + def construct(self, tokens: Tensor, batch_valid_length=None, batch_index=None, zactivate_len=None, aux_loss=None, + block_tables=None, attention_mask=None, slot_mapping=None, prefix_keys_values=None): + """ + Forward of telechat model. 
+ + Args: + tokens: the tokenized inputs with datatype int32 + batch_valid_length(Tensor): the past calculated the index with datatype int32, used for incremental + prediction. Tensor of shape :math:`(batch_size,)`. Default None. + block_tables (Tensor[int64]): Store mapping tables for each sequence. + slot_mapping (Tensor[int32]): Store token cache physical slot index. + Returns: + output: Tensor, the output of telechat decoderlayer + """ + # preprocess + bs, seq_len = self.shape(tokens) + mask = attention_mask + if self.use_past: + if self.is_first_iteration: + freqs_cis = self.freqs_mgr.prefill(bs, seq_len) + mask = self.casual_mask.prefill() + if prefix_keys_values is not None: + if mask is None: + mask = self.casual_mask(tokens) + prefix_length = prefix_keys_values[0].shape[2] + prefix_mask = Tensor(np.zeros((bs, 1, seq_len, prefix_length)), dtype=mask.dtype) + mask = self.concat((prefix_mask, mask)) + else: + freqs_cis = self.freqs_mgr.increment(batch_valid_length) + else: + if attention_mask is not None: + if not self.config.eod_reset: + mask = attention_mask + mask = self.cast(mask, mstype.uint8) + freqs_cis = self.freqs_mgr(seq_len) + else: + mask = self.casual_mask(mask) + mask = self.casual_mask.post_process(mask) + freqs_cis = self.freqs_mgr(seq_len) + else: + mask = self.casual_mask(tokens) + freqs_cis = self.freqs_mgr(seq_len) + if prefix_keys_values is not None: + prefix_length = prefix_keys_values[0].shape[2] + prefix_mask = Tensor(np.zeros((bs, 1, seq_len, prefix_length)), dtype=mask.dtype) + mask = self.concat((prefix_mask, mask)) + + # tokens: [bs, seq/1] + h = self.tok_embeddings(tokens) + h = self.reshape(h, (bs, seq_len, -1)) + # h: [bs, seq/1, hidden_dim] + for i in range(self.num_layers): + prefix_kv = prefix_keys_values[i] if prefix_keys_values is not None else None + if self.expert_num > 1: + h, aux_loss = self.layers[i](h, freqs_cis, mask, batch_valid_length=batch_valid_length, + block_tables=block_tables, aux_loss=aux_loss, 
slot_mapping=slot_mapping, + prefix_keys_values=prefix_kv) + else: + h = self.layers[i](h, freqs_cis, mask, batch_valid_length=batch_valid_length, block_tables=block_tables, + slot_mapping=slot_mapping, prefix_keys_values=prefix_kv) + output = self.norm_out(h) + return output + + +class TelechatHead(nn.Cell): + """Head for Telechat to get the logits of each token in the vocab.""" + def __init__(self, + in_channels, + out_channels, + compute_dtype="float16", + parallel_config=None): + super().__init__() + copied_parallel_config = copy.deepcopy(parallel_config) + self.in_channels = in_channels + self.out_channels = out_channels + self.dtype = compute_dtype + self.cast = P.Cast() + self.reshape = P.Reshape() + dp = copied_parallel_config.data_parallel + mp = copied_parallel_config.model_parallel + sp = copied_parallel_config.context_parallel + if parallel_config.vocab_emb_dp or (out_channels % mp != 0): + self.matmul = P.MatMul(transpose_b=True).shard(((dp * sp, 1), (1, 1))) + else: + self.matmul = P.MatMul(transpose_b=True).shard(((dp * sp, 1), (mp, 1))) + + def construct(self, x, embedding_weight=None): + out_shape = P.Shape()(x)[:-1] + (self.out_channels,) + x = self.reshape(x, (-1, self.in_channels)) + ori_dtype = F.dtype(x) + weight = self.cast(embedding_weight, self.dtype) + x = self.cast(x, self.dtype) + x = self.matmul(x, weight) + x = self.cast(x, ori_dtype) + output = self.reshape(x, out_shape) + return output + + +@MindFormerRegister.register(MindFormerModuleType.MODELS) +class TelechatForCausalLM(TelechatPreTrainedModel): + """ + Provide telechat training loss or logits through network. + + Args: + config (TelechatConfig): The config of telechat model. 
+ + Returns: + output: Tensor, the output of telechat decoderlayer + + Examples: + >>> from mindformers.models.telechat import TelechatConfig, TelechatForCausalLM + >>> config = TelechatConfig(batch_size=2) + >>> network = TelechatForCausalLM(config=config) + >>> type(network) + + >>> from mindformers import TelechatForCausalLM + >>> network = TelechatForCausalLM.from_pretrained('telechat_115b') + >>> type(network) + + """ + + @lazy_inline + def __init__(self, config: TelechatConfig = None): + super().__init__(config, auto_prefix=True) + _check_config(config.parallel_config) + self.config = config + self.model_name = config.net_name + self.ignore_token_id = config.ignore_token_id + self.pad_token_id = config.pad_token_id + self.use_past = config.use_past + self.vocab_size = config.vocab_size + self.rl_config = config.rl_config + self.is_first_iteration = True + + self.dp = config.parallel_config.data_parallel + self.mp = config.parallel_config.model_parallel + self.expert_num = config.moe_config.expert_num + self.init_aux_loss = Tensor(np.zeros([self.dp * self.mp, self.expert_num]), mstype.float32) + self.shape = P.Shape() + self.reshape = P.Reshape() + self.cast = P.Cast() + self.slice = P.StridedSlice() + self.not_equal = P.NotEqual() + self.mul = P.Mul() + self.add = P.Add() + self.ones = P.Ones() + self.gather = P.Gather() + self.sub_batch_valid_len = P.Sub() + self.model = TelechatModel(config=config) + self.lm_head = Linear(in_channels=config.hidden_size, + out_channels=config.vocab_size, + has_bias=False, + compute_dtype=config.compute_dtype, + param_init_type=config.param_init_type, + weight_init="normal") # meta default: xavier_normal + + mp = config.parallel_config.model_parallel + sp = config.parallel_config.context_parallel + vocab_size = config.vocab_size + loss_parallel_config = copy.deepcopy(config.parallel_config) + if vocab_size % mp != 0: + logger.warning("The vocab size of Loss is: %s, it is not divide by model_parallel: %s", + vocab_size, mp) + 
logger.warning("Now, the model_parallel num of Loss will be changed: mp = 1") + loss_parallel_config.model_parallel = 1 + loss_parallel_config.data_parallel *= loss_parallel_config.context_parallel + check_for_nan_in_loss_and_grad = getattr(config, "check_for_nan_in_loss_and_grad", False) + self.loss = CrossEntropyLoss(parallel_config=loss_parallel_config, + check_for_nan_in_loss_and_grad=check_for_nan_in_loss_and_grad) + + dp = config.parallel_config.data_parallel + mp = config.parallel_config.model_parallel + self.aux_reduce_mean = P.ReduceMean(keep_dims=True).shard(((1, 1),)) + if not (_get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation()): + self.slice.shard(((dp, 1),)) + self.not_equal.shard(((dp, 1), ())) + self.mul.shard(((dp, 1), (dp, 1))) + self.add.shard(((dp, 1), ())) + self.gather.shard(((dp, 1, 1), (dp,))) + self.sub_batch_valid_len.shard(((1,), ())) + if config.parallel_config.vocab_emb_dp or (vocab_size % mp != 0): + self.lm_head.shard(strategy_matmul=((dp * sp, 1), (1, 1))) + else: + self.lm_head.shard(strategy_matmul=((dp * sp, 1), (mp, 1))) + if config.parallel_config.pipeline_stage > 1: + self.lm_head.pipeline_stage = config.parallel_config.pipeline_stage - 1 + + self.predict_run_mode = get_predict_run_mode() + + logger.info(f"Predict run mode:{self.predict_run_mode}") + + # pylint: disable=W0613 + def prepare_inputs_for_predict_layout(self, input_ids, **kwargs): + """Get Telechat model input tuple for transform ckpt.""" + input_ids = Tensor(input_ids, mstype.int32) + labels = Tensor(kwargs["labels"]) if "labels" in kwargs else None + bs, seq = input_ids.shape[0], input_ids.shape[1] + slot_mapping = Tensor(np.ones(shape=tuple([bs * seq])), mstype.int32) + prefix_keys_values = Tensor(kwargs["prefix_keys_values"]) if "prefix_keys_values" in kwargs else None + return input_ids, labels, None, None, None, None, None, None, None, None, None, None, \ + slot_mapping, prefix_keys_values + + def set_dynamic_inputs(self, 
+                            **kwargs):
+        """Set dynamic inputs"""
+        dynamic_input_ids = Tensor(shape=[None, None], dtype=mstype.int32)
+        dynamic_batch_valid_length = Tensor(shape=[None, None], dtype=mstype.int32)
+        dynamic_block_tables = Tensor(shape=[None, None], dtype=mstype.int32)
+        dynamic_slot_mapping = Tensor(shape=[None], dtype=mstype.int32)
+        have_prefix_keys_values = kwargs.get("have_prefix_keys_values", False)
+        if have_prefix_keys_values:
+            dynamic_prefix_keys_values = Tensor(shape=[2, None, None, None, None], dtype=mstype.float16)
+            self.set_inputs(dynamic_input_ids, None, None, None, None, None, None, None,
+                            dynamic_batch_valid_length, None, None, dynamic_block_tables,
+                            dynamic_slot_mapping, dynamic_prefix_keys_values)
+        else:
+            self.set_inputs(dynamic_input_ids, None, None, None, None, None, None, None,
+                            dynamic_batch_valid_length, None, None, dynamic_block_tables,
+                            dynamic_slot_mapping, None)
+        logger.info("Set dynamic input for telechat.")
+
+    def add_flags_custom(self, is_first_iteration):
+        """Add customized attributes for specific cells in the model."""
+        self.add_flags(is_first_iteration=is_first_iteration)
+        self.model.add_flags(is_first_iteration=is_first_iteration)
+        for layer in self.model.layers:
+            layer.add_flags(is_first_iteration=is_first_iteration)
+            layer.attention.infer_attention.add_flags(is_first_iteration=is_first_iteration)
+
+    # pylint: disable=W0613
+    def construct(self, input_ids, labels=None, loss_mask=None, input_position=None, position_ids=None,
+                  attention_mask=None, input_embeds=None, init_reset=None, batch_valid_length=None,
+                  batch_index=None, zactivate_len=None, block_tables=None, slot_mapping=None,
+                  prefix_keys_values=None):
+        r"""
+        TelechatForCausalLM forward.
+
+        Args:
+            input_ids(Tensor): the tokenized inputs with datatype int32, Tensor of shape :math:`(batch, seq\_length)`.
+            labels(Tensor): the tokenized labels with datatype int32, Tensor of shape :math:`(batch, seq\_length)`.
+ input_position(Tensor): current position, used by model.predict. + position_ids(Tensor): Reserved param, not used. + attention_mask(Tensor): Reserved param, not used. + input_embeds(Tensor): Reserved param, not used. + init_reset(bool, optional): A bool tensor with shape [1], used to clear the past key parameter and + past value parameter used in the incremental prediction. Default True. + batch_valid_length(Tensor): the past calculated the index with datatype int32, used for incremental + prediction. Tensor of shape :math:`(batch_size,)`. Default None. + block_tables (Tensor[int64]): Store mapping tables for each sequence. + slot_mapping (Tensor[int32]): Store token cache physical slot index. + Returns: + Tensor: The loss or (logits, tokens, input_mask) of the network. + """ + bsz, seqlen = self.shape(input_ids) + aux_loss = None + if self.use_past: + if not isinstance(batch_valid_length, Tensor): + batch_valid_length = self.ones((bsz,), mstype.int32) + tokens = input_ids + if batch_valid_length is not None: + batch_valid_length = self.reshape(batch_valid_length, (-1,)) + if self.expert_num == 1: + output = self.model(tokens, batch_valid_length, batch_index, zactivate_len, block_tables=block_tables, \ + attention_mask=attention_mask, slot_mapping=slot_mapping, \ + prefix_keys_values=prefix_keys_values) + pre_gather = (not self.use_past or self.is_first_iteration) and batch_valid_length is not None + if pre_gather: + batch_valid_length = mint.cumsum(batch_valid_length, 0) + output = self.gather(output, self.sub_batch_valid_len(batch_valid_length, 1), 1) + logits = self.lm_head(output) + + if self.rl_config is not None: + return logits + + input_mask = self.cast(self.not_equal(tokens, self.pad_token_id), mstype.float32) + if labels is None: + labels = self.slice(input_ids, (0, 1), (bsz, seqlen), (1, 1)) + else: + if labels.ndim > 1: + label_mask = self.cast(self.not_equal(labels, self.ignore_token_id), mstype.float32) + input_mask = self.mul(input_mask, 
label_mask) + + if not self.training: + logits = self.cast(logits, mstype.float32) + if self.predict_run_mode: + logits = self.reshape(logits, (-1, logits.shape[-1])) + return logits + return logits, tokens, input_mask + + if logits.ndim > 2: + logits = self.reshape(logits, (-1, logits.shape[-1])) + logits = self.cast(logits, mstype.float32) + labels = self.reshape(labels, (-1,)) + input_mask = self.reshape(input_mask, (-1,)) + loss = self.loss(logits, labels, input_mask) + if self.expert_num > 1: + aux_loss = self.aux_reduce_mean(aux_loss).reshape(-1) + loss = loss + aux_loss + return loss + + def kvcache(self, layer_idx): + key_cache = self.model.layers[layer_idx].attention.infer_attention.paged_attention_mgr.key_cache + value_cache = self.model.layers[layer_idx].attention.infer_attention.paged_attention_mgr.value_cache + return key_cache, value_cache diff --git a/research/telechat3/telechat3-36b/finetune_telechat3_36b.yaml b/research/telechat3/telechat3-36b/finetune_telechat3_36b.yaml new file mode 100644 index 000000000..e8d7ad0de --- /dev/null +++ b/research/telechat3/telechat3-36b/finetune_telechat3_36b.yaml @@ -0,0 +1,194 @@ +seed: 0 +output_dir: './output' +load_checkpoint: '' +src_strategy_path_or_dir: '' +auto_trans_ckpt: False +only_save_strategy: False +resume_training: False +# transform_process_num: 256 +ignore_data_skip: False +run_mode: 'finetune' + +# trainer config +trainer: + type: CausalLanguageModelingTrainer + model_name: 'telechat3_36b' + +# runner config +runner_config: + epochs: 10 + batch_size: 1 + sink_mode: True + sink_size: 1 + gradient_accumulation_steps: 1 + +# optimizer +optimizer: + type: AdamW + betas: [0.9, 0.95] + eps: 1.e-8 + weight_decay: 0.1 + +# lr sechdule +lr_schedule: + type: CosineWithWarmUpLR + learning_rate: 2.e-4 + lr_end: 2.e-5 + warmup_steps: 2000 + total_steps: 1908000 # -1 means it will load the total steps of the dataset + +# dataset +train_dataset: &train_dataset + data_loader: + type: MindDataset + dataset_dir: 
"" + shuffle: False + input_columns: ["input_ids", "labels", "attention_mask"] + construct_args_key: ["input_ids", "labels", "attention_mask"] + num_parallel_workers: 8 + python_multiprocessing: False + drop_remainder: True + batch_size: 6 + repeat: 1 + numa_enable: False + prefetch_size: 1 +train_dataset_task: + type: CausalLanguageModelDataset + dataset_config: *train_dataset +# if True, do evaluate during the training process. if false, do nothing. +# note that the task trainer should support _evaluate_in_training function. + +use_parallel: True +# parallel context config +parallel: + parallel_mode: 1 # 0-data parallel, 1-semi-auto parallel, 2-auto parallel, 3-hybrid parallel + gradients_mean: False + enable_alltoall: False + full_batch: True + enable_parallel_optimizer: True + strategy_ckpt_save_file: "./ckpt_strategy.ckpt" + pipeline_config: + pipeline_interleave: True + pipeline_scheduler: 'seqpipe' + parallel_optimizer_config: + gradient_accumulation_shard: False + parallel_optimizer_threshold: 64 +parallel_config: + data_parallel: 128 + model_parallel: 8 + pipeline_stage: 2 + use_seq_parallel: True + micro_batch_num: 4 + vocab_emb_dp: False + gradient_aggregation_group: 4 +# when model parallel is greater than 1, we can set micro_batch_interleave_num=2, that may accelerate the train process. 
+micro_batch_interleave_num: 1 + +# recompute config +recompute_config: + recompute: False + select_comm_recompute: False + select_recompute: False + + parallel_optimizer_comm_recompute: False + mp_comm_recompute: False + recompute_slice_activation: False + +# callbacks +callbacks: + - type: MFLossMonitor + - type: CheckpointMonitor + prefix: "telechat3_36b" + save_checkpoint_steps: 1500 + keep_checkpoint_max: 300 + integrated_save: False + async_save: False + +# mindspore context init config +context: + mode: 0 #0--Graph Mode; 1--Pynative Mode + device_target: "Ascend" + max_call_depth: 10000 + max_device_memory: "56.5GB" + save_graphs: False + save_graphs_path: "./graph" + device_id: 0 + jit_config: {"jit_level":"O1"} + +model: + model_config: + type: TelechatConfig + auto_register: telechat_config.TelechatConfig + batch_size: 1 # add for increase predict + net_name: 'telechat3_36b' + seq_length: 8192 + hidden_size: 6144 + embedding_size: 1 + num_layers: 64 + layers_group: 1 + num_heads: 48 + n_kv_heads: 8 + sigma: 0.0048 + mean: 0.0 + vocab_size: 131072 + rms_norm_eps: 1.0e-5 + bos_token_id: 1 + eos_token_id: 2 + pad_token_id: 3 + eod_reset: True + pp_interleave_num: 2 + out_proj_has_bias: False + ignore_token_id: -100 + embed_dropout_prob: 0. + hidden_dropout_prob: 0. + attention_dropout_prob: 0. 
+ intermediate_size: 24576 #8192 + res_dtype: "float32" + compute_dtype: "bfloat16" + layernorm_compute_type: "float32" + softmax_compute_type: "float32" + rotary_dtype: "float32" + param_init_type: "float32" + max_position_embeddings: 8192 + scaling_factor: + beta_fast: 32.0 + beta_slow: 1.0 + factor: 16.0 + mscale: 1.0 + mscale_all_dim: 1.0 + original_max_position_embeddings: 8192 #4096 + theta: 10000 + use_past: False + pretrain_seqlen: 8192 # seqlen of the pretrain checkpoint + extend_method: "None" # support "None", "PI", "NTK" + parallel_optimizer: True + use_flash_attention: True # FA can accelerate training or finetune + use_past_shard: False + repetition_penalty: 1 + max_decode_length: 512 + fine_grain_interleave: 1 + top_k: 3 + top_p: 1 + do_sample: False + arch: + type: TelechatForCausalLM + auto_register: telechat.TelechatForCausalLM + +# wrapper cell config +runner_wrapper: + type: MFTrainOneStepCell + scale_sense: 1.0 + use_clip_grad: True + +profile: False +profile_start_step: 1 +profile_stop_step: 10 +init_start_profile: False +profile_communication: False +profile_memory: True +layer_scale: False +layer_decay: 0.65 +lr_scale_factor: 256 + +# aicc +remote_save_url: "Please input obs url on AICC platform." 
diff --git a/research/telechat3/telechat3-36b/predict_telechat3_36b.yaml b/research/telechat3/telechat3-36b/predict_telechat3_36b.yaml new file mode 100644 index 000000000..82cc10a94 --- /dev/null +++ b/research/telechat3/telechat3-36b/predict_telechat3_36b.yaml @@ -0,0 +1,133 @@ +seed: 0 +output_dir: './output' +load_checkpoint: '' +src_strategy_path_or_dir: '' +auto_trans_ckpt: False +only_save_strategy: False +resume_training: False +run_mode: 'predict' + +# trainer config +trainer: + type: CausalLanguageModelingTrainer + model_name: 'telechat_35b' + +use_parallel: True +# parallel context config +parallel: + parallel_mode: 1 # 0-data parallel, 1-semi-auto parallel, 2-auto parallel, 3-hybrid parallel + gradients_mean: False + enable_alltoall: False + full_batch: True + search_mode: "sharding_propagation" + strategy_ckpt_save_file: "./ckpt_strategy.ckpt" + parallel_optimizer_config: + gradient_accumulation_shard: False + parallel_optimizer_threshold: 64 +parallel_config: + data_parallel: 1 + model_parallel: 2 + pipeline_stage: 1 + use_seq_parallel: False + micro_batch_num: 1 + vocab_emb_dp: False + gradient_aggregation_group: 4 + +# mindspore context init config +context: + mode: 0 #0--Graph Mode; 1--Pynative Mode + device_target: "Ascend" + max_call_depth: 10000 + max_device_memory: "58GB" + save_graphs: False + save_graphs_path: "./graph" + device_id: 0 + mempool_block_size: "58GB" + affinity_cpu_list: "None" + +# model config +model: + model_config: + type: TelechatConfig + auto_register: telechat_config.TelechatConfig + batch_size: 1 # add for increase predict + net_name: "telechat3_36b" + seq_length: 8192 + hidden_size: 6144 + num_layers: 64 + layers_group: 1 + num_heads: 48 + n_kv_heads: 8 + embedding_size: 512 + vocab_size: 131072 + eod_reset: False + out_proj_has_bias: False + rms_norm_eps: 1.0e-5 + bos_token_id: 1 + eos_token_id: 2 + pad_token_id: 3 + ignore_token_id: -100 + embed_dropout_prob: 0. + hidden_dropout_prob: 0. + attention_dropout_prob: 0. 
+ intermediate_size: 24576 + res_dtype: "bfloat16" + compute_dtype: "bfloat16" + layernorm_compute_type: "float32" + softmax_compute_type: "float32" + rotary_dtype: "bfloat16" + param_init_type: "bfloat16" + use_past: True + pretrain_seqlen: 8192 # seqlen of the pretrain checkpoint + extend_method: "None" # support "None", "PI", "NTK" + use_flash_attention: True # FA can accelerate training or finetune + block_size: 16 + num_blocks: 512 + is_dynamic: True + use_past_shard: False + repetition_penalty: 1 + max_decode_length: 512 + top_k: 3 + top_p: 1 + do_sample: False + auto_map: + AutoModel: telechat.TelechatForCausalLM + AutoConfig: telechat_config.TelechatConfig + AutoTokenizer: [telechat_tokenizer.TelechatTokenizer, null] + arch: + type: TelechatForCausalLM + auto_register: telechat.TelechatForCausalLM + +processor: + return_tensors: ms + tokenizer: + vocab_file: "" + unk_token: '' + bos_token: '<_start>' + eos_token: '<_end>' + pad_token: '<_pad>' + type: TelechatTokenizer + auto_register: telechat_tokenizer.TelechatTokenizer + +# wrapper cell config +runner_wrapper: + type: MFTrainOneStepCell + scale_sense: 1.0 + use_clip_grad: True + +auto_tune: False +filepath_prefix: './autotune' +autotune_per_step: 10 + +profile: False +profile_start_step: 1 +profile_stop_step: 10 +init_start_profile: False +profile_communication: False +profile_memory: True +layer_scale: False +layer_decay: 0.65 +lr_scale_factor: 256 + +# aicc +remote_save_url: "Please input obs url on AICC platform." 
diff --git a/research/telechat3/telechat3-36b/predict_telechat3_36b_parallel.yaml b/research/telechat3/telechat3-36b/predict_telechat3_36b_parallel.yaml new file mode 100644 index 000000000..75ec21a87 --- /dev/null +++ b/research/telechat3/telechat3-36b/predict_telechat3_36b_parallel.yaml @@ -0,0 +1,91 @@ +seed: 0 +output_dir: './output' +load_checkpoint: '' +src_strategy_path_or_dir: '' +auto_trans_ckpt: False +run_mode: 'predict' + +# trainer config +trainer: + type: CausalLanguageModelingTrainer + model_name: 'telechat_35b' + +use_parallel: True +# parallel context config +parallel: + parallel_mode: "STAND_ALONE" # 0-data parallel, 1-semi-auto parallel, 2-auto parallel, 3-hybrid parallel + full_batch: False + strategy_ckpt_save_file: "./ckpt_strategy.ckpt" +parallel_config: + data_parallel: 1 + model_parallel: 2 + vocab_emb_dp: False + +# mindspore context init config +context: + mode: 0 #0--Graph Mode; 1--Pynative Mode + max_device_memory: "58GB" + device_id: 0 + save_graphs: False + save_graphs_path: "./graph" + mempool_block_size: "58GB" + affinity_cpu_list: "None" + +# model config +model: + model_config: + type: TelechatConfig + batch_size: 1 # add for increase predict + net_name: "telechat3_36b" + seq_length: 8192 + hidden_size: 6144 + num_layers: 64 + layers_group: 1 + num_heads: 48 + n_kv_heads: 8 + embedding_size: 512 + vocab_size: 131072 + eod_reset: False + out_proj_has_bias: False + rms_norm_eps: 1.0e-5 + bos_token_id: 1 + eos_token_id: 2 + pad_token_id: 3 + ignore_token_id: -100 + intermediate_size: 24576 + res_dtype: "bfloat16" + compute_dtype: "bfloat16" + layernorm_compute_type: "float32" + softmax_compute_type: "float32" + rotary_dtype: "float32" + param_init_type: "bfloat16" + use_past: True + pretrain_seqlen: 8192 # seqlen of the pretrain checkpoint + extend_method: "None" # support "None", "PI", "NTK" + use_flash_attention: True # FA can accelerate training or finetune + block_size: 16 + num_blocks: 512 + is_dynamic: True + qkv_concat: False 
+ use_past_shard: False + repetition_penalty: 1.03 + max_decode_length: 512 + top_k: 3 + top_p: 1 + do_sample: False + auto_map: + AutoModel: telechat.ParallelTelechatForCausalLM + AutoConfig: telechat_config.TelechatConfig + AutoTokenizer: [telechat_tokenizer.TelechatTokenizer, null] + arch: + type: ParallelTelechatForCausalLM + +processor: + return_tensors: ms + tokenizer: + vocab_file: "" + unk_token: '' + bos_token: '<_start>' + eos_token: '<_end>' + pad_token: '<_pad>' + type: TelechatTokenizer diff --git a/research/telechat3/telechat_config.py b/research/telechat3/telechat_config.py new file mode 100644 index 000000000..6ad859603 --- /dev/null +++ b/research/telechat3/telechat_config.py @@ -0,0 +1,214 @@ +# Copyright 2025 TeleAI Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Telechat Config API.""" + +from typing import Optional, Union + +from mindspore._checkparam import args_type_check + +from mindformers.modules.transformer.transformer import default_transformer_config, \ + TransformerOpParallelConfig +from mindformers.tools.register import MindFormerRegister, MindFormerModuleType +from mindformers.models.configuration_utils import PretrainedConfig +from mindformers.models.utils import convert_mstype + + +@MindFormerRegister.register(MindFormerModuleType.CONFIG) +class TelechatConfig(PretrainedConfig): + """ + Telechat config class which defines the model size. + + Args: + batch_size (Optional[int]): batch size for input data, use in predict. + seq_length (Optional[int]): The sequence length of input_ids, default is 1024. + vocab_size (`int`, *optional*, defaults to 50257): + Vocabulary size of the BERT model. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + multiple_of (Optional[int]): Define SwiGLU hidden layer size multiples, default 256. + n_kv_heads (Optional[int]): Define multi group head attention heads number, default None. + ffn_dim_multiplier (Optional[int]): Define ffn layer dim multiples, default None. + rms_norm_eps (Optional[float]): The epsilon value of the denominator. Default 1e-5. + bos_token_id (Optional[int]): The id of the *beginning-of-sequence* token. + eos_token_id (Optional[int]): The id of the *end-of-sequence* token. + pad_token_id (Optional[int]): The id of the *padding* token. + ignore_token_id (Optional[int]): The id of the *ignoring* token. + compute_dtype (Optional[str]): + Linear layer compute dtype, default is "float16". 
+ layernorm_compute_type (Optional[str]): + layernorm compute dtype, default is "float32". + softmax_compute_type (Optional[str]): + softmax compute dtype, default is "float32". + rotary_dtype (Optional[str]): + rope compute dtype, default is "float32". + param_init_type (Optional[str]): + parameter initial dtype, default is "float16". + qkv_has_bias (Optional[bool]): + Whether the Query, Key, and Value projection has bias. + use_past (`bool`, *optional*, defaults to `False`): + Whether the model should use the past last key/values attentions + (if applicable to the model) to speed up decoding. + with default values. Please see `MoEConfig`. + parallel_config(TransformerOpParallelConfig): + The parallel configure. Default `default_transformer_config`, + an instance of `TransformerOpParallelConfig` with default args. + extend_method(str): The extend method of seq length of inferencem,default None. + use_flash_attention(bool): Whether enable flash attention ops, default False. + offset(int): Offset of transformer layer when set pipeline stage number. + checkpoint_name_or_path (Optional[str]): + checkpoint path or name used to load to the network. + repetition_penalty (`float`, *optional*, defaults to 1.0): + The parameter for repetition penalty. 1.0 means no penalty. See [this + paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. + max_decode_length (`int`, *optional*, defaults to 1024): + The maximum length the generated tokens can have. Corresponds to the length of the input prompt + + `max_new_tokens`. Its effect is overridden by `max_new_tokens`, if also set. + top_k (`int`, *optional*, defaults to 5): + The number of highest probability vocabulary tokens to keep for top-k-filtering. + top_p (`float`, *optional*, defaults to 1.0): + If set to float < 1, only the smallest set of most probable tokens with probabilities + that add up to `top_p` or higher are kept for generation. 
+ do_sample (`bool`, *optional*, defaults to `False`): + Whether or not to use sampling ; use greedy decoding otherwise. + block_size (`int`, *optional*, defaults to 16): + The maximum number of tokens in one block can have when using paged attention. + num_blocks (`int`, *optional*, defaults to 512): + The maximum number of blocks when using paged attention. + Returns: + Class, TelechatConfig. + """ + + model_type = "telechat" + + @args_type_check(parallel_config=(dict, TransformerOpParallelConfig)) + def __init__(self, batch_size: int = 1, + seq_length: int = 2048, + hidden_size: int = 4096, + num_layers: int = 32, + num_heads: int = 32, + n_kv_heads: Optional[int] = None, + max_position_embedding: Optional[int] = None, + intermediate_size: Optional[int] = None, + vocab_size: int = 32000, # defined later by tokenizer + multiple_of: int = 256, # make SwiGLU hidden layer size multiple of large power of 2 + ffn_dim_multiplier: Optional[int] = None, + rms_norm_eps: float = 1e-5, + bos_token_id: int = 1, + eos_token_id: int = 2, + pad_token_id: int = 0, + embedding_size: int = 1, + ignore_token_id: int = -100, + theta: float = 10000.0, + compute_dtype: str = "float16", + layernorm_compute_type: str = "float32", + softmax_compute_type: str = "float32", + rotary_dtype: str = "float32", + param_init_type: str = "float16", + embedding_init_type=None, + res_dtype: str = "float32", + qkv_has_bias: bool = False, + out_proj_has_bias: bool = True, + qkv_concat: bool = False, + parallel_config: Union[dict, TransformerOpParallelConfig] = default_transformer_config, + use_past: bool = False, + extend_method: str = "None", + scaling_factor: float = 1.0, + is_dynamic: bool = False, + use_rope_slice: bool = False, + use_flash_attention: bool = False, + use_attn_mask_compression: bool = False, + parallel_optimizer: bool = False, + fine_grain_interleave: int = 1, + pp_interleave_num: int = 1, + offset: int = 0, + checkpoint_name_or_path: str = "", + repetition_penalty: float = 1.0, + 
max_decode_length: int = 1024, + block_size: int = 16, + num_blocks: int = 512, + top_k: int = 5, + top_p: float = 1.0, + do_sample: bool = True, + sigma: float = 0.0048, + mean: float = 0.0, + layers_group: int = 1, + eod_reset: bool = False, + tie_word_embeddings: bool = False, + return_hidden_states: bool = False, + **kwargs): + super().__init__(**kwargs) + if isinstance(parallel_config, dict): + parallel_config = TransformerOpParallelConfig(**parallel_config) + self.batch_size = batch_size + self.seq_length = seq_length + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_layers = num_layers + self.num_heads = num_heads + self.max_position_embedding = max_position_embedding if max_position_embedding else seq_length + self.intermediate_size = intermediate_size + self.multiple_of = multiple_of + self.n_kv_heads = n_kv_heads + self.ffn_dim_multiplier = ffn_dim_multiplier + self.rms_norm_eps = rms_norm_eps + self.out_proj_has_bias = out_proj_has_bias + self.param_init_type = convert_mstype(param_init_type) + if embedding_init_type is not None: + self.embedding_init_type = convert_mstype(embedding_init_type) + else: + self.embedding_init_type = self.param_init_type + self.qkv_has_bias = qkv_has_bias + self.layernorm_compute_type = convert_mstype(layernorm_compute_type) + self.softmax_compute_type = convert_mstype(softmax_compute_type) + self.rotary_dtype = convert_mstype(rotary_dtype) + self.compute_dtype = convert_mstype(compute_dtype) + self.res_dtype = convert_mstype(res_dtype) + self.parallel_config = parallel_config + self.checkpoint_name_or_path = checkpoint_name_or_path + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.ignore_token_id = ignore_token_id + self.use_past = use_past + self.extend_method = extend_method + self.scaling_factor = scaling_factor + self.is_dynamic = is_dynamic + self.use_rope_slice = use_rope_slice + self.use_flash_attention = use_flash_attention + 
self.use_attn_mask_compression = use_attn_mask_compression + self.parallel_optimizer = parallel_optimizer + self.fine_grain_interleave = fine_grain_interleave + self.offset = offset + self.repetition_penalty = repetition_penalty + self.max_decode_length = max_decode_length + self.pp_interleave_num = pp_interleave_num + self.top_k = top_k + self.top_p = top_p + self.do_sample = do_sample + self.sigma = sigma + self.mean = mean + self.theta = theta + self.block_size = block_size + self.num_blocks = num_blocks + self.qkv_concat = qkv_concat + self.layers_group = layers_group + self.eod_reset = eod_reset + self.embedding_size = embedding_size + self.tie_word_embeddings = tie_word_embeddings + self.return_hidden_states = return_hidden_states diff --git a/research/telechat3/telechat_interleave.py b/research/telechat3/telechat_interleave.py new file mode 100644 index 000000000..69d002d95 --- /dev/null +++ b/research/telechat3/telechat_interleave.py @@ -0,0 +1,683 @@ +# Copyright 2025 TeleAI Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Telechat fine grain interleave transformer Telechat's APIs.""" + +from typing import Optional +import math + +import mindspore as ms +from mindspore import nn, __version__ +import mindspore.common.dtype as mstype +from mindspore.common.tensor import Tensor +from mindspore.context import ParallelMode +from mindspore.ops import operations as P +from mindspore.parallel._utils import _get_parallel_mode, _is_sharding_propagation +from mindspore.parallel.shard import Layout + +from research.telechat3.telechat_layer import TelechatLinear, TelechatFeedForward + +from mindformers.models.llama.llama_layer import LlamaRMSNorm +from mindformers.modules.layers import _check_input_dtype, RotaryEmbedding +from mindformers.modules.transformer import TransformerOpParallelConfig +from mindformers.modules.flash_attention import FlashAttention + + +class _MicroBatch(nn.Cell): + """ + transform mini-batch to micro-batch in pipeline parallel. + + Args: + params (micro_size): The number of micro-batch. 
+ """ + def __init__(self, micro_size, input_size, axis_list): + super().__init__() + self.shape = P.Shape() + self.micro_size = micro_size + self.strided_slice_list = [] + for _ in range(input_size): + self.strided_slice_list.append(P.StridedSlice()) + self.axis_list = axis_list + + def construct(self, i, *inputs): + """construct for _MicroBatch.""" + micro_inputs = () + k = 0 + for each_input in inputs: + input_shape = self.shape(each_input) + micro_batch_begin = i * input_shape[self.axis_list[k]] // self.micro_size + micro_batch_end = (i + 1) * input_shape[self.axis_list[k]] // self.micro_size + strided_slice_begin = () + strided_slice_strides = () + strided_slice_end = () + for j, _ in enumerate(input_shape): + strided_slice_strides += (1,) + if j == self.axis_list[k]: + strided_slice_begin += (micro_batch_begin,) + strided_slice_end += (micro_batch_end,) + else: + strided_slice_begin += (0,) + strided_slice_end += (input_shape[j],) + + micro_input = self.strided_slice_list[k](each_input, strided_slice_begin, \ + strided_slice_end, strided_slice_strides) + micro_inputs += (micro_input,) + k += 1 + return micro_inputs + + +class TelechatAttentionInterleave(nn.Cell): + r""" + This is an implementation of multihead attention in Telechat. + + Args: + - **batch_size** (int): The batch size of the input tensor when do increnmental prediction. Should be a + positive value. + When do training or prediction, the argument will not work and the user can just pass None to the + argument. + - **src_seq_length** (int): The sequence length of the query vector. + - **tgt_seq_length** (int): The sequence length of the key and value vector. + - **dim** (int): The hidden size of the input. + - **head_dim** (int): The dim of head. + - **n_heads** (int): The number of the heads. + - **compute_dtype** (dtype.Number): The computation type of dense. Default mstype.float16. + Should be mstype.float32 or mstype.float16. 
+ - **softmax_compute_type** (dtype.Number): The type of softmax computation module. Default mstype.float32. + Should be mstype.float32 or mstype.float16. + - **param_init_type** (dtype.Number): The parameter initialization type of the module. Default mstype. + float32. Should be mstype.float32 or mstype.float16. + - **qkv_has_bias** (bool): Whether Q/K/V in attention has bias or not. + - **use_past** (bool): Use the past state to compute, used for incremental prediction. + For example, if we have two words and want to generate the ten more words. + We just need to compute the two words' state only once, and generate the next word one by one. + When use_past is True, there are two steps to run the prediction. + In the first step, set the is_first_iteration to be True by + `model.add_flags_recursive(is_first_iteration=True)`, and pass the full inputs. Then, set the + is_first_iteration to be False by `model.add_flags_recursive(is_first_iteration=False)`. At this moment, + pass the single step's input tensor, and loop it. Default False. + - **parallel_config** (OpParallelConfig): The parallel configure. Default `default_dpmp_config`, + an instance of `OpParallelConfig` with default args. + + Inputs: + - **x** (Tensor) - The input tokens with shape (batch_size, src_seq_length, hidden_size) or + (batch_size * src_seq_length, hidden_size), if the use_past is False or is_first_iteration=True. + Otherwise, must be (batch_size, 1, hidden_size) + - **freqs_cis** (Tuple) - The precompute freqs and mask for rotary position embedding used in attention. + - **attention_mask** (Tensor) - If the use_past is False or is_first_iteration=True, the attention mask + matrix should ba (batch_size, src_seq_length, tgt_seq_length), or None. None means there will be no mask + in softmax computation. Otherwise, the mask must be (batch_size, 1, tgt_seq_length) + - **key_past** (Tensor) - Float16 tensor with shape (batch_size, num_heads, head_dim, tgt_seq_length). 
+ The past calculated key vector. Used for incremental prediction when the use_past is True. + Default None. + - **value_past** (Tensor) - Float16 tensor with shape (batch_size, num_heads, tgt_seq_length, + head_dim). + The past calculated value vector. Used for incremental prediction when the use_past is True. + Default None. + - **batch_valid_length** (Tensor) - Int32 tensor with shape (batch_size,) the past calculated the index. + Used for incremental prediction when the use_past is True. Default None. + + Outputs: + Tuple, a tuple contains(`output`, `layer_present`) + + - **output** (Tensor) - Tensor, the float tensor of the output of the layer with + shape (batch_size, src_seq_length, hidden_size) or (batch_size * src_seq_length, hidden_size), + if the use_past is False or is_first_iteration=True. Otherwise, it will be (batch_size, 1, hidden_size). + + - **layer_present** (Tuple) - A tuple of the Tensor of the projected key and value vector with + ((batch_size, num_heads, head_dim, tgt_seq_length), + (batch_size, num_heads, tgt_seq_length, head_dim)). 
+ """ + def __init__(self, seq_length, + dim: int = 512, + n_heads: int = 8, + sigma: float = 0.0048, + mean: float = 0.0, + n_kv_heads: Optional[int] = None, + compute_dtype=mstype.float16, + softmax_compute_dtype=mstype.float32, + rotary_dtype=mstype.float32, + param_init_type=mstype.float32, + qkv_has_bias=False, + out_proj_has_bias=True, + use_rope_slice=False, + use_flash_attention=False, + use_attn_mask_compression=False, + parallel_config=TransformerOpParallelConfig()): + super().__init__() + self.seq_length = seq_length + self.hidden_size = dim + self.n_head = n_heads + self.head_dim = dim // n_heads + self.n_kv_head = n_heads if n_kv_heads is None else n_kv_heads + self.n_rep = self.n_head // self.n_kv_head + self.kv_dim = self.n_kv_head * self.head_dim + self.qkv_has_bias = qkv_has_bias + self.out_proj_has_bias = out_proj_has_bias + self.dtype = compute_dtype + self.softmax_dtype = softmax_compute_dtype + self.is_first_iteration = True + self.use_flash_attention = use_flash_attention + self.use_attn_mask_compression = use_attn_mask_compression + + if self.hidden_size % self.n_head != 0: + raise ValueError(f"For 'MultiHeadAttention', the class variable 'hidden_size' must be a multiple " + f"of 'n_head', but got the hidden_size is {self.hidden_size}" + f"and the n_head is {self.n_head}.") + if self.n_kv_head % parallel_config.model_parallel != 0: + raise ValueError(f"For 'MultiHeadAttention', the class variable 'n_kv_head' must be a multiple of " + f"'parallel_config.model_parallel', but got the n_kv_head is {self.n_kv_head} " + f"and the parallel_config.model_parallel is {parallel_config.model_parallel}.") + + self.inv_norm_factor = Tensor(1.0 / math.sqrt(self.head_dim), dtype=compute_dtype) + + self.shape = P.Shape() + self.reshape = P.Reshape() + self.transpose = P.Transpose() + self.merger_head_transpose = P.Transpose() + self.batch_matmul = P.BatchMatMul() + self.batch_matmul_q_k = P.BatchMatMul(transpose_b=True) + self.mul = P.Mul() + self.add = 
P.Add() + self.softmax = P.Softmax() + self.cast = P.Cast() + self.cast_attn = P.Cast() + self.tile_kv = P.Tile() + self.split_kv = ms.ops.auto_generate.SplitWithSize() + self.split_kv.add_prim_attr("skip_redistribution", True) + self.apply_rotary_emb = RotaryEmbedding(self.head_dim, rotary_dtype, use_rope_slice=use_rope_slice) + + self.wq = TelechatLinear(self.hidden_size, + self.hidden_size, + has_bias=qkv_has_bias, + sigma=sigma, + mean=mean, + compute_dtype=compute_dtype, + param_init_type=param_init_type) + self.wk = TelechatLinear(self.hidden_size, + self.n_kv_head * self.head_dim, + has_bias=qkv_has_bias, + sigma=sigma, + mean=mean, + compute_dtype=compute_dtype, + param_init_type=param_init_type) + self.wv = TelechatLinear(self.hidden_size, + self.n_kv_head * self.head_dim, + has_bias=qkv_has_bias, + sigma=sigma, + mean=mean, + compute_dtype=compute_dtype, + param_init_type=param_init_type) + self.wo = TelechatLinear(in_channels=self.hidden_size, + out_channels=self.hidden_size, + has_bias=out_proj_has_bias, + sigma=sigma, + mean=mean, + compute_dtype=compute_dtype, + param_init_type=param_init_type) + + dp = parallel_config.data_parallel + mp = parallel_config.model_parallel + sp = parallel_config.context_parallel + self.sp = sp + self.split_kv.shard(((dp * sp, mp, 1),)) + if not (_get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation()): + self.transpose.shard(((dp, sp, mp, 1),)) + if sp > 1: + layout = Layout((dp, sp, mp), ("dp", "sp", "mp")) + layout_merge_head_transpose = (layout("dp", "mp", "sp", "None"),) + self.merger_head_transpose.shard(in_strategy=layout_merge_head_transpose) + else: + self.merger_head_transpose.shard(((dp, mp, 1, 1),)) + self.batch_matmul_q_k.shard(((dp, mp, sp, 1), (dp, mp, sp, 1))) + self.batch_matmul.shard(((dp, mp, sp, 1), (dp, mp, sp, 1))) + self.mul.shard(((dp, mp, sp, 1), ())) + self.add.shard(((dp, 1, 1, 1), (dp, mp, sp, 1))) + self.softmax.shard(((dp, mp, 1, 1),)) + self.tile_kv.shard(((dp, 
mp, 1, sp),)) + + self.apply_rotary_emb.shard(parallel_config) + + if self.qkv_has_bias: + self.wq.shard(((dp * sp, 1), (mp, 1)), ((dp * sp, mp), (mp,))) + self.wk.shard(((dp * sp, 1), (mp, 1)), ((dp * sp, mp), (mp,))) + self.wv.shard(((dp * sp, 1), (mp, 1)), ((dp * sp, mp), (mp,))) + else: + self.wq.shard(((dp * sp, 1), (mp, 1))) + self.wk.shard(((dp * sp, 1), (mp, 1))) + self.wv.shard(((dp * sp, 1), (mp, 1))) + if self.out_proj_has_bias: + self.wo.shard(((dp * sp, mp), (1, mp)), ((dp * sp, 1), (1,))) + else: + self.wo.shard(((dp * sp, mp), (1, mp))) + if parallel_config.use_seq_parallel and self.is_first_iteration and sp==1: + if self.out_proj_has_bias: + self.wo.shard(((dp, mp), (1, mp)), ((dp * mp, 1), (1,)), out_strategy_matmul=((dp * mp, 1),)) + else: + self.wo.shard(((dp, mp), (1, mp)), out_strategy_matmul=((dp * mp, 1),)) + if parallel_config.recompute.select_recompute and not self.use_flash_attention: + self.apply_rotary_emb.recompute() + self.tile_kv.recompute() + self.batch_matmul_q_k.recompute() + self.mul.recompute() + self.add.recompute() + self.cast_attn.recompute() + self.softmax.recompute() + self.batch_matmul.recompute() + + if self.use_flash_attention: + self.input_layout = "BSH" if sp > 1 else "BNSD" + self.sparse_mode = 2 if self.use_attn_mask_compression else 0 + self.flash_attention = FlashAttention(head_num=self.n_head, + pre_tokens=65536, + next_tokens=0, + input_layout=self.input_layout, + scale_value=1. 
/ math.sqrt(self.head_dim), + sparse_mode=self.sparse_mode, + use_attention_mask=True) + self.flash_attention.shard(parallel_config) + + def compute_qkv(self, x): + """compute the qkv with interleave number""" + x = self.reshape(x, (-1, x.shape[-1])) + query = self.cast(self.wq(x), self.dtype) # dp, 1 -> dp, mp + key = self.cast(self.wk(x), self.dtype) # dp, 1 -> dp, mp + value = self.cast(self.wv(x), self.dtype) # dp, 1 -> dp, mp + key = self.reshape(key, (-1, self.n_kv_head * self.head_dim)) + value = self.reshape(value, (-1, self.n_kv_head * self.head_dim)) + return query, key, value + + def cal_attn(self, query, key, value, mask, freqs_cis): + """cal_attn""" + query = self.reshape(query, (-1, self.seq_length, self.n_head, self.head_dim)) + key = self.reshape(key, (-1, self.seq_length, self.n_kv_head, self.head_dim)) + if self.sp > 1: + value = self.reshape(value, (-1, self.seq_length, self.n_kv_head * self.head_dim)) + else: + value = self.reshape(value, (-1, self.seq_length, self.n_kv_head, self.head_dim)) + value = self.transpose(value, (0, 2, 1, 3)) + + # [bs, seq/1, n_head/n_kv_head, head_dim] + query = self.transpose(query, (0, 2, 1, 3)) + key = self.transpose(key, (0, 2, 1, 3)) + + # [bs, n_head/n_kv_head, seq/1, head_dim] + query, key = self.apply_rotary_emb(query, key, freqs_cis) # dp, mp, 1, 1 + # kv share: [bs, n_kv_head, seq, head_dim] -> [bs, n_head, seq, head_dim] + if self.sp > 1: + query = self._merge_heads(query) + key = self._merge_heads(key) + else: + bs, n_head, seq, head_dim = query.shape + n_kv_head = key.shape[1] + query = self.reshape(query, (bs, n_head, seq, head_dim)) + key = self.reshape(key, (bs, n_kv_head, seq, head_dim)) + value = self.reshape(value, (bs, n_kv_head, seq, head_dim)) + + # q, k, v: [bs, n_head, seq/1, head_dim], [bs, n_head, seq, head_dim], [bs, n_head, seq, head_dim] + if self.use_flash_attention: + attention = self.flash_attention(query, key, value, mask) + if self.sp > 1: + attention = self.reshape(attention, (-1, 
attention.shape[-1])) + else: + attention = self._merge_heads(attention) + else: + key = self._repeat_kv(key, self.n_rep) + value = self._repeat_kv(value, self.n_rep) + attention = self._attn(query, key, value, mask) + return attention + + def cal_output_proj(self, attention): + """cal_output_proj""" + output = self.wo(attention) # dp, mp -> dp, 1 / dp * mp, 1 + return output + + def _repeat_kv(self, x, rep): + """repeat_kv""" + if rep == 1: + return x + bs, n_kv_head, seqlen, head_dim = x.shape + x = self.reshape(x, (bs, n_kv_head, 1, seqlen * head_dim)) + x = self.tile_kv(x, (1, 1, rep, 1)) + x = self.reshape(x, (bs, n_kv_head * rep, seqlen, head_dim)) + return x + + def _merge_heads(self, x): + """ + convert a 4d input to a 2d or 3d output + + Inputs: + x: input tensor + + Output: + x_merge: the 2d output + """ + # [bs, n_head, seq/1, head_dim] + x = self.merger_head_transpose(x, (0, 2, 1, 3)) # dp,mp,1,1 -> dp,1,mp,1 + # [bs, seq/1, n_head, head_dim] + x_shape = x.shape + # [bs * seq/1, hidden_dim] + if self.sp > 1: + new_shape = (-1, x_shape[-3], x_shape[-2]*x_shape[-1]) + else: + new_shape = (-1, x_shape[-2] * x_shape[-1]) + x_merge = self.reshape(x, new_shape) + return x_merge + + def _attn(self, query, key, value, mask): + """ + Get the weighted score along the seq_length + + Inputs: + query: the query matrix + key: the key matrix + value: the value matrix + mask: the attention mask adder matrix with shape (batch_size, + 1, seq_length, seq_length) + Outputs: + weighted_values: Tensor, the weighted sum scores + """ + # q, k: [bs, n_head, seq/1, head_dim], [bs, n_head, seq, head_dim] + score = self.batch_matmul_q_k(query, key) + # score: [bs, n_head, seq/1, seq] + score = self.mul(score, self.inv_norm_factor) + score = self.add(mask, score) + + attention_probs = self.softmax(self.cast_attn(score, self.softmax_dtype)) + # score, v: [bs, n_head, seq/1, seq], [bs, n_head, seq, head_dim] + weighted_values = self.batch_matmul(self.cast(attention_probs, 
self.dtype), value) + # [bs, n_head, seq/1, head_dim] + attention_merge = self._merge_heads(weighted_values) + # [bs, seq/1, hidden_dim] or [bs * seq/1, hidden_dim] + return attention_merge + + +class TelechatDecodeLayerInterleave(nn.Cell): + r""" + Transformer Layer. This is an implementation of the single layer of the transformer + encoder layer, including multihead attention and feedward layer. + + Args: + seq_length(int): The input sequence length. + layer_id(int): The layer id of current transformer block layer. + dim(int): The hidden size of the input. + num_heads(int): The number of the heads. + multiple_of(int): The SwiGLU hidden layer size multiple of large power of 2. + norm_eps (float): The epsilon value of the denominator. Default 1e-5. + compute_dtype(dtype.Number): The computation type of the layer. + Should be mstype.float32 or mstype.float16. Default mstype.float32. + layernorm_compute_type(dtype.Number): The computation type of the norm. + Should be mstype.float32 or mstype.float16. Default mstype.float32. + softmax_compute_type(dtype.Number): The computation type of the softmax in the attention. + Should be mstype.float32 or mstype.float16. Default mstype.float32. + param_init_type(dtype.Number): The parameter initialization type of the module. + Should be mstype.float32 or mstype.float16. Default mstype.float32. + qkv_has_bias(bool): Whether Q/K/V in attention has bias or not. + use_past(bool): Use the past state to compute, used for incremental prediction. For example, if we have two + words and want to generate the ten more words. We just need to compute the two words' state only once, + and generate the next word one by one. When use_past is True, there are two steps to run the prediction. + In the first step, set the is_first_iteration to be True by + `model.add_flags_recursive(is_first_iteration=True)`, and pass the full inputs. Then, set the + is_first_iteration to be False by `model.add_flags_recursive(is_first_iteration=False)`. 
+ At this moment, pass the single step's input tensor, and loop it. Default False. + parallel_config(OpParallelConfig, MoEParallelConfig): The parallel configure. When MoE is applied, + MoEParallelConfig is effective, otherwise OpParallelConfig is effective. Default `default_dpmp_config`, + an instance of `OpParallelConfig` with default args. + + Inputs: + - **x** (Tensor) - Float Tensor, shape should be [batch_size, seq_length, hidden_size] or + [batch_size * seq_length, hidden_size], if the use_past is False or is_first_iteration=True. Otherwise, + should be [batch_size, 1, hidden_size] + - **freqs_cis** (Tuple) - The precompute freqs and mask for rotary position embedding used in attention. + - **input_mask** (Tensor) - Float Tensor, If the use_past is False or is_first_iteration=True, + the attention mask matrix should ba [batch_size, seq_length, seq_length], or None. None means there will + be no mask in softmax computation. Otherwise, should be [batch_size, 1, hidden_size] + - **init_reset** (Tensor) - A bool tensor with shape [1], used to clear the past key parameter and + past value parameter used in the incremental prediction. Only valid when use_past is True. Default True. + - **batch_valid_length** (Tensor) - Int32 tensor with shape [batch_size] the past calculated the index. + Used for incremental prediction when the use_past is True. Default None. + + Outputs: + Tuple, a tuple contains(`output`, `layer_present`). + + - **output** (Tensor) - The float tensor of the output of the layer with + shape (batch_size, seq_length, hidden_size) or (batch_size * seq_length, hidden_size), if the use_past is + False or is_first_iteration=True. Otherwise, it will be (batch_size, 1, hidden_size) + + - **layer_present** (Tuple) - A tuple of the Tensor of the projected key and value vector with + ((batch_size, num_heads, head_dim, seq_length), + (batch_size, num_heads, seq_length, head_dim)). 
+ + """ + def __init__(self, seq_length, + layer_id, + dim: int = 512, + n_heads: int = 8, + num_layers: int = 32, + sigma: float = 0.0048, + mean: float = 0.0, + n_kv_heads: Optional[int] = None, + intermediate_size: Optional[int] = None, + ffn_dim_multiplier: Optional[int] = None, + norm_eps: float = 1e-5, + compute_dtype=mstype.float16, + layernorm_compute_dtype=mstype.float32, + softmax_compute_dtype=mstype.float32, + rotary_dtype=mstype.float32, + param_init_type=mstype.float32, + res_dtype=mstype.float32, + qkv_has_bias=False, + out_proj_has_bias=True, + use_rope_slice=False, + use_flash_attention=False, + fine_grain_interleave=2, + use_attn_mask_compression=False, + parallel_config=TransformerOpParallelConfig()): + + super().__init__() + self.seq_length = seq_length + self.layer_id = layer_id + self.hidden_size = dim + self.n_head = n_heads + self.num_layers = num_layers + self.head_dim = self.hidden_size // self.n_head + self.n_kv_head = n_heads if n_kv_heads is None else n_kv_heads + + self.dtype = compute_dtype + self.res_dtype = res_dtype + self.is_first_iteration = True + self.interleave_num = fine_grain_interleave + self.key_past = None + self.value_past = None + + self.reshape = P.Reshape() + self.add = P.Add() + self.cast = P.Cast() + self.attention_norm = LlamaRMSNorm(self.hidden_size, norm_eps, compute_type=layernorm_compute_dtype) + self.ffn_norm = LlamaRMSNorm(self.hidden_size, norm_eps, compute_type=layernorm_compute_dtype) + self.attention = TelechatAttentionInterleave(seq_length=seq_length, + dim=dim, + n_heads=n_heads, + sigma=sigma, + mean=mean, + n_kv_heads=n_kv_heads, + compute_dtype=compute_dtype, + softmax_compute_dtype=softmax_compute_dtype, + rotary_dtype=rotary_dtype, + param_init_type=param_init_type, + qkv_has_bias=qkv_has_bias, + out_proj_has_bias=out_proj_has_bias, + use_rope_slice=use_rope_slice, + use_flash_attention=use_flash_attention, + use_attn_mask_compression=use_attn_mask_compression, + parallel_config=parallel_config) + 
self.feed_forward = TelechatFeedForward(dim=self.hidden_size, + intermediate_size=intermediate_size, + hidden_dim=4 * self.hidden_size, + sigma=sigma, + mean=mean, + ffn_dim_multiplier=ffn_dim_multiplier, + compute_dtype=compute_dtype, + param_init_type=param_init_type) + + dp = parallel_config.data_parallel + mp = parallel_config.model_parallel + sp = parallel_config.context_parallel + if not (_get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation()): + self.feed_forward.shard(parallel_config) + self.feed_forward.mul.shard(((dp * sp, mp), (dp * sp, mp))) + self.add.shard(((dp * sp, 1), (dp * sp, 1))) + if sp > 1: + self.attention_norm.shard((dp * sp * mp, 1)) + self.ffn_norm.shard((dp * sp * mp, 1)) + else: + self.attention_norm.shard((dp, 1)) + self.ffn_norm.shard((dp, 1)) + + if parallel_config.use_seq_parallel and self.is_first_iteration: + self.add.shard(((dp * sp * mp, 1), (dp * sp * mp, 1))) + self.attention_norm.shard((dp * sp * mp, 1)) + self.ffn_norm.shard((dp * sp * mp, 1)) + self.feed_forward.w2.shard( + strategy_matmul=((dp * sp, mp), (1, mp)), + strategy_bias=((dp * sp * mp, 1), (1, )), + out_strategy_matmul=((dp * sp * mp, 1),)) + + concat_stra1 = [] + concat_stra2 = [] + self.interleave1_inputs = nn.CellList() + self.interleave1_inputs_ = nn.CellList() + self.interleave2_inputs = nn.CellList() + self.interleaved_concat1 = P.Concat(axis=0) + self.interleaved_concat1.add_prim_attr("fine_grained_interleaved_index", self.layer_id) + self.interleaved_concat_1 = P.Concat(axis=0) + self.interleaved_concat2 = P.Concat(axis=0) + if self.layer_id != self.num_layers - 2: + self.interleaved_concat2.add_prim_attr("fine_grained_interleaved_index", 1000) + + for _ in range(self.interleave_num): + concat_stra1.append((dp * sp, mp)) + interleave_data1 = _MicroBatch(self.interleave_num, 1, [0]) + interleave_data1.strided_slice_list[0].add_prim_attr("skip_redistribution", True) + interleave_data1_ = _MicroBatch(self.interleave_num, 1, [0]) 
+ interleave_data1_.strided_slice_list[0].add_prim_attr("skip_redistribution", True) + interleave_data2 = _MicroBatch(self.interleave_num, 2, [0, 0]) + if parallel_config.use_seq_parallel: + if self.layer_id == self.num_layers - 2: + concat_stra2.append((dp * sp, 1)) + else: + concat_stra2.append((dp * mp * sp, 1)) + if self.layer_id == self.num_layers - 1: + interleave_data1.strided_slice_list[0].shard(((dp * sp, 1),)) + else: + interleave_data1.strided_slice_list[0].shard(((dp * mp * sp, 1),)) + interleave_data1_.strided_slice_list[0].shard(((1, 1),)) + interleave_data2.strided_slice_list[0].shard(((dp * mp * sp, 1),)) + else: + concat_stra2.append((dp * sp, 1)) + interleave_data1.strided_slice_list[0].shard(((dp * sp, 1),)) + interleave_data1_.strided_slice_list[0].shard(((1, 1),)) + interleave_data2.strided_slice_list[0].shard(((dp * sp, 1),)) + if self.layer_id == 0 and parallel_config.use_seq_parallel: + interleave_data2.strided_slice_list[0].shard(((dp * sp, 1),)) + interleave_data2.strided_slice_list[0].add_prim_attr("skip_redistribution", True) + else: + interleave_data2.strided_slice_list[0].add_prim_attr("skip_redistribution", True) + + interleave_data2.strided_slice_list[0].add_prim_attr("fine_grained_interleaved_index", self.layer_id) + interleave_data2.strided_slice_list[1].shard(((dp * sp, mp),)) + interleave_data2.strided_slice_list[1].add_prim_attr("fine_grained_interleaved_index", self.layer_id) + interleave_data2.strided_slice_list[1].add_prim_attr("skip_redistribution", True) + self.interleave1_inputs.append(interleave_data1) + self.interleave1_inputs_.append(interleave_data1_) + self.interleave2_inputs.append(interleave_data2) + concat_stra3 = tuple(concat_stra1) + concat_stra4 = tuple(concat_stra2) + self.interleaved_concat1.shard(concat_stra3) + self.interleaved_concat1.add_prim_attr("skip_redistribution", True) + self.interleaved_concat_1.shard(concat_stra3) + self.interleaved_concat_1.add_prim_attr("skip_redistribution", True) + 
self.interleaved_concat2.shard(concat_stra4) + self.interleaved_concat2.add_prim_attr("skip_redistribution", True) + + def linear_layer1(self, x): + """layer part 1""" + input_x = self.attention_norm(x) + query, key, value = self.attention.compute_qkv(input_x) + return query, key, value + + def linear_layer2(self, x, attention): + """layer part 2""" + attention_output = self.attention.cal_output_proj(attention) + ori_dtype = attention_output.dtype + # For post-layernorm the inputs for residual path are output of self-attention and output of layernorm + x = self.add(self.cast(x, self.res_dtype), self.cast(attention_output, self.res_dtype)) + output_x = self.ffn_norm(x) + mlp_logit = self.feed_forward(output_x) + output = self.add(self.cast(x, self.res_dtype), self.cast(mlp_logit, self.res_dtype)) + output = self.cast(output, ori_dtype) + return output + + # pylint: disable=W0613 + def construct(self, x, freqs_cis, mask=None, batch_valid_length=None, block_tables=None, + slot_mapping=None, prefix_keys_values=None, q_seq_lens=None): + """ Forward of transformer block. 
""" + self._check_input(x, freqs_cis, mask) + x = self.reshape(x, (-1, x.shape[-1])) + # ============linear-layer1================ + if self.layer_id == 0: + query, key, value = self.linear_layer1(x) + else: + query_tuple = () + key_tuple = () + value_tuple = () + for i in range(self.interleave_num): + x_part, = self.interleave1_inputs[i](i, x) + query_part, key_part, value_part = self.linear_layer1(x_part) + query_tuple += (query_part,) + key_tuple += (key_part,) + value_tuple += (value_part,) + query = self.interleaved_concat1(query_tuple) + key = self.interleaved_concat_1(key_tuple) + value = self.interleaved_concat_1(value_tuple) + # ===========linear-layer1 end============= + attention = self.attention.cal_attn(query, key, value, mask, freqs_cis) + # ============linear-layer2================ + if self.layer_id == self.num_layers - 1: + output = self.linear_layer2(x, attention) + else: + output_tuple = () + for i in range(self.interleave_num): + x_part, attention_part = self.interleave2_inputs[i](i, x, attention) + output_part = self.linear_layer2(x_part, attention_part) + output_tuple += (output_part,) + output = self.interleaved_concat2(output_tuple) + # ============linear-layer2 end=========== + return output + + def _check_input(self, x, freqs_cis, mask): + r"""Check inputs""" + _check_input_dtype( + x.dtype, "x", [mstype.float32, mstype.float16, mstype.bfloat16], self.cls_name) + freqs_cos, freqs_sin, swap_mask = freqs_cis + _check_input_dtype(freqs_cos.dtype, "freqs_cos", + [mstype.float32, mstype.float16, mstype.bfloat16], self.cls_name) + _check_input_dtype(freqs_sin.dtype, "freqs_sin", + [mstype.float32, mstype.float16, mstype.bfloat16], self.cls_name) + if swap_mask is not None: + _check_input_dtype(swap_mask.dtype, "swap_mask", + [mstype.float32, mstype.float16, mstype.bfloat16], self.cls_name) + if mask is not None: + _check_input_dtype(mask.dtype, "input_mask", + [mstype.float32, mstype.float16, mstype.uint8, mstype.bfloat16], self.cls_name) + 
return True diff --git a/research/telechat3/telechat_layer.py b/research/telechat3/telechat_layer.py new file mode 100644 index 000000000..f5c8d08de --- /dev/null +++ b/research/telechat3/telechat_layer.py @@ -0,0 +1,292 @@ +# Copyright 2025 TeleAI Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Telechat Model Layers' APIs.""" + +import mindspore as ms +from mindspore.common.parameter import Parameter +from mindspore import nn +import mindspore.common.dtype as mstype +from mindspore.ops import operations as P +from mindspore.ops import functional as F +from mindspore.nn.cell import Cell + +try: + from mindspore._checkparam import Validator +except ImportError: + import mindspore._checkparam as Validator +from mindspore import log as logger +from mindspore.common.initializer import initializer, Normal +from mindspore.parallel._utils import _get_parallel_mode +from mindspore.context import ParallelMode +from mindformers.modules.transformer.op_parallel_config import default_dpmp_config +from mindformers.models.llama.llama_layer import LlamaSiLU +from mindformers.modules.layers import Linear, _check_input_dtype, _args_type_validator_check, \ + _valid_value_checks +from mindformers.tools.logger import _LogActionOnce + + +class TelechatEmbedding(Cell): + """ + Embedding Layer. + + Args: + - **vocab_size** (int): Size of the dictionary of embeddings. 
+ - **embedding_size** (int): The size of each embedding vector. + - **param_init_type** (mstype): The param init type, default mstype.float32. + - **parallel_config** (TransformerOpParallelConfig): The parallel config of network. Default + `default_embedding_parallel_config`, an instance of `EmbeddingOpParallelConfig` with default args. + - **param_init** (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the embedding_table. + Refer to class `initializer` for the values of string when a string + is specified. Default: 'normal'. + Inputs: + - **input_ids** (Tensor) - The tokenized inputs with datatype int32 with shape (batch_size, seq_length) + + Outputs: + - **output** (Tensor) - The embedding vector for the input with shape (batch_size, + seq_length, embedding_size). + """ + + @_LogActionOnce(m_logger=logger, key='Embedding', + no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,)) + @_args_type_validator_check(vocab_table_size=Validator.check_positive_int, + embedding_size=Validator.check_positive_int) + def __init__(self, vocab_table_size, embedding_size, sigma=0.0048, mean=0.0, param_init_type=mstype.float32, + parallel_optimizer=True): + super().__init__() + self.vocab_table_size = vocab_table_size + self.embedding_size = embedding_size + self.embedding_weight = Parameter( + initializer(Normal(sigma=sigma, mean=mean), [self.vocab_table_size, self.embedding_size], + dtype=param_init_type), name='embedding_weight', parallel_optimizer=parallel_optimizer) + self.gather = P.Gather() + + def construct(self, input_ids): + """Forward of vocab embedding.""" + _check_input_dtype(F.dtype(input_ids), "input_ids", [mstype.int32, mstype.int64], self.cls_name) + shape = F.shape(input_ids) + input_ids = F.reshape(input_ids, (-1, 1)) + output = self.gather(self.embedding_weight, input_ids, 0) + output = F.reshape(output, (shape[0], shape[1], -1)) + return output + + def shard(self, parallel_config): + """sharding for embedding""" + dp = 
parallel_config.data_parallel
+        mp = parallel_config.model_parallel
+        sp = parallel_config.context_parallel
+        if parallel_config.vocab_emb_dp:
+            self.gather.shard(((1, 1), (dp * sp, 1)))
+            logger.info(f"Using {dp * sp} data parallel for the embedding lookup.")
+        else:
+            if self.vocab_table_size % mp != 0:
+                logger.warning("The vocab size of Loss is: %s, it is not divide by model_parallel: %s",
+                               self.vocab_table_size, mp)
+                logger.warning("Now, the model_parallel num of Loss will be changed: mp = 1")
+                self.gather.shard(((1, 1), (dp * sp, 1)))
+            else:
+                self.gather.shard(((mp, 1), (dp * sp, 1)))
+                logger.info(f"Using {dp * sp} data parallel X sequence parallel and {mp} "
+                            f"model parallel for the embedding lookup.")
+
+
+class TelechatLinear(Linear):
+    # pylint: disable=W0212
+    """
+    Linear function for Telechat.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 sigma=0.0048,
+                 mean=0.0,
+                 bias_init='zeros',
+                 has_bias=True,
+                 activation=None,
+                 transpose_b=True,
+                 param_init_type=mstype.float32,
+                 compute_dtype=mstype.float16):
+        super().__init__(in_channels,
+                         out_channels,
+                         weight_init=Normal(sigma=sigma, mean=mean),
+                         bias_init=bias_init,
+                         has_bias=has_bias,
+                         activation=activation,
+                         transpose_b=transpose_b,
+                         # NOTE(review): removed 'outer_batch=outer_batch' — the name was undefined here
+                         param_init_type=param_init_type,
+                         compute_dtype=compute_dtype)
+        weight_shape = [out_channels, in_channels] if transpose_b else [in_channels, out_channels]
+        self.weight = Parameter(initializer(Normal(sigma=sigma, mean=mean), weight_shape, param_init_type),
+                                name="weight")
+
+    def construct(self, x):
+        """construct of linear."""
+        out_shape = self.shape(x)[:-1] + (self.out_channels,)
+        x = self.reshape(x, (-1, self.in_channels))
+        ori_dtype = F.dtype(x)
+        weight = self.cast(self.weight, self.dtype)
+        x = self.cast(x, self.dtype)
+        x = self.matmul(x, weight)
+        if self.has_bias:
+            x = self.bias_add(x, self.cast(self.bias, self.dtype))
+        if self.activation_flag:
+            x = self.activation(x)
+        x = F.cast(x, ori_dtype)
+ output = self.reshape(x, out_shape) + return output + + +class TelechatFeedForward(Cell): + r""" + Telechat FeedForward. + + .. math:: + (xW_1 * xW_3)W_2 + + Inputs: + - **x** (Tensor) - should be `[batch, seq_length, hidden_size] or [batch * seq_length, hidden_size]`. + Float tensor. + + Outputs: + Tensor, the output of this layer after mapping. The shape is `[batch, seq_length, hidden_size] or + [batch * seq_length, hidden_size]`. + + Raises: + ValueError: `hidden_dim` is not a multiple of the model parallel way. + ValueError: `dim` is not a multiple of the model parallel way. + """ + + @_LogActionOnce(m_logger=logger, key='TelechatFeedForward', + no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,)) + @_args_type_validator_check(dim=Validator.check_positive_int, + hidden_dim=Validator.check_positive_int, + multiple_of=Validator.check_positive_int, + compute_dtype=_valid_value_checks([mstype.float32, mstype.float16, mstype.bfloat16], + "TelechatFeedForward"), + param_init_type=_valid_value_checks([mstype.float32, mstype.float16, mstype.bfloat16], + "TelechatFeedForward")) + def __init__(self, dim, + intermediate_size=None, + hidden_dim=None, + sigma=0.0048, + mean=0.0, + multiple_of=256, + hidden_act=LlamaSiLU, + ffn_dim_multiplier=None, + compute_dtype=mstype.float16, + param_init_type=mstype.float32, + ffn_concat=False, + parallel_config=default_dpmp_config): + super().__init__() + + if hidden_act is None or not (isinstance(hidden_act, str) or issubclass(hidden_act, nn.Cell)): + raise TypeError(f"For FeedForward cell, the hidden_act should str type or nn.Cell type, " + f"but got {hidden_act}.") + + if intermediate_size is not None: + hidden_dim = intermediate_size + else: + if ffn_dim_multiplier is not None: + hidden_dim = int((ffn_dim_multiplier + 0.01) * hidden_dim) + hidden_dim = int(2 * hidden_dim / 3) + hidden_dim = multiple_of * \ + ((hidden_dim + multiple_of - 1) // multiple_of) + + self.dtype = compute_dtype + self.hidden_act = hidden_act + 
self.dim = dim + self.hidden_dim = hidden_dim + self.mul = P.Mul() + self.cast = P.Cast() + self.ffn_concat = ffn_concat + if self.ffn_concat: + self.w_gate_hidden = TelechatLinear(in_channels=dim, + out_channels=hidden_dim * 2, + has_bias=False, + sigma=sigma, + mean=mean, + compute_dtype=compute_dtype, + param_init_type=param_init_type) + self.activate = self.hidden_act() + self.split = ms.ops.auto_generate.SplitWithSize() + else: + self.w1 = TelechatLinear(in_channels=dim, + out_channels=hidden_dim, + activation=hidden_act, + has_bias=False, + sigma=sigma, + mean=mean, + compute_dtype=compute_dtype, + param_init_type=param_init_type) + + self.w3 = TelechatLinear(in_channels=dim, + out_channels=hidden_dim, + has_bias=False, + sigma=sigma, + mean=mean, + compute_dtype=compute_dtype, + param_init_type=param_init_type) + + self.w2 = TelechatLinear(in_channels=hidden_dim, + out_channels=dim, + has_bias=False, + sigma=sigma, + mean=mean, + compute_dtype=compute_dtype, + param_init_type=param_init_type) + + def construct(self, x): + """Forward process of the FeedForward""" + _check_input_dtype(F.dtype(x), "x", [mstype.float32, mstype.float16, mstype.bfloat16], self.cls_name) + x = self.cast(x, self.dtype) + + if self.ffn_concat: + gate_hidden_out = self.w_gate_hidden(x) # dp,1 -> dp, mp + gate, hidden = self.split(gate_hidden_out, (self.hidden_dim, self.hidden_dim), 2) + gate = self.activate(gate) + else: + # [bs, seq, hidden_dim] or [bs * seq, hidden_dim] + gate = self.w1(x) # dp,1 -> dp, mp + hidden = self.w3(x) # dp,1 -> dp, mp + hidden = self.mul(hidden, gate) # dp,mp -> dp, mp + output = self.w2(hidden) # dp,mp -> dp, 1 + return output + + def shard(self, parallel_config): + """sharding for feedforward""" + dp = parallel_config.data_parallel + mp = parallel_config.model_parallel + sp = parallel_config.context_parallel + if self.hidden_dim % mp != 0: + raise ValueError(f"For 'FeedForward', the class variable 'hidden_dim' must be a multiple of the" + f"num of model 
parallel, but got the hidden_dim is {self.hidden_dim} and the num of model " + f"parallel is {mp}.") + if self.dim % mp != 0: + raise ValueError(f"For 'FeedForward', the class variable 'dim' must be a multiple of the num of " + f"model parallel, but got the dim is {self.dim} and the num of model parallel is {mp}.") + if self.ffn_concat: + self.w_gate_hidden.shard(((dp * sp, 1), (mp, 1))) + self.activate.shard(((dp * sp, 1, mp),)) + self.w2.shard(((dp * sp, mp), (1, mp))) + self.split.add_prim_attr("skip_redistribution", True) + self.split.shard(((dp * sp, 1, mp),)) + self.mul.shard(((dp * sp, mp), (dp * sp, mp))) + else: + self.w1.shard(((dp * sp, 1), (mp, 1))) + self.w1.activation.shard(((dp * sp, mp),)) + self.w2.shard(((dp * sp, mp), (1, mp)), ((dp * sp, 1), (1,))) + self.w3.shard(((dp * sp, 1), (mp, 1))) + self.mul.shard(((dp * sp, mp), (dp * sp, mp))) diff --git a/research/telechat3/telechat_tokenizer.py b/research/telechat3/telechat_tokenizer.py new file mode 100644 index 000000000..cd9f95d4d --- /dev/null +++ b/research/telechat3/telechat_tokenizer.py @@ -0,0 +1,278 @@ +# Copyright 2025 TeleAI Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# ============================================================================
"""Telechat tokenizer APIs."""
# coding: utf-8
import os
from shutil import copyfile
from typing import Any, Dict, List, Optional

import sentencepiece as spm

from mindformers.tools import logger
from mindformers.models.tokenization_utils import PreTrainedTokenizer, AddedToken
from mindformers.tools.register import MindFormerRegister, MindFormerModuleType


VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}


@MindFormerRegister.register(MindFormerModuleType.TOKENIZER)
class TelechatTokenizer(PreTrainedTokenizer):
    r"""
    Tokenizer for TeleChat3 models, backed by a SentencePiece model.

    Args:
        vocab_file (str): Path to the SentencePiece `tokenizer.model` file.
        unk_token (str): Token representing out-of-vocabulary pieces.
        bos_token (str): Beginning-of-sequence token. Default "<_start>".
        eos_token (str): End-of-sequence token. Default "<_end>".
        pad_token (str): Padding token. Default "<_pad>".
        usr_token (str): Role marker for user turns. Default "<_user>".
        bot_token (str): Role marker for assistant turns. Default "<_bot>".
        sys_token (str): Role marker for system turns. Default "<_system>".
        call_start_token (str): Marker opening a tool-call span.
        call_end_token (str): Marker closing a tool-call span.
        repo_start_token (str): Marker opening a repository-context span.
        repo_end_token (str): Marker closing a repository-context span.
        chat_template: Optional chat template forwarded to the base tokenizer. Default None.
        sp_model_kwargs (dict, optional): Extra kwargs for `sentencepiece.SentencePieceProcessor`.
        add_bos_token (bool): Whether to prepend `bos_token_id` when encoding. Default False.
        add_eos_token (bool): Whether to append `eos_token_id` when encoding. Default False.
        clean_up_tokenization_spaces (bool): Whether to clean up spaces added during tokenization
            when decoding. Default False.
        **kwargs: Other kwargs passed to `PreTrainedTokenizer`.

    NOTE(review): several token defaults below appear as empty strings ("" for the unk/call/repo
    tokens); angle-bracketed names may have been stripped from this source — confirm against the
    released tokenizer_config.json.

    Outputs:
        A dict containing the processed ids and attention_mask, as specified by
        `model_input_names`.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask"]
    FILE_LIST = ['tokenizer_config.json']

    def __init__(
            self,
            vocab_file,
            unk_token="",
            bos_token="<_start>",
            eos_token="<_end>",
            pad_token="<_pad>",
            usr_token="<_user>",
            bot_token="<_bot>",
            sys_token="<_system>",
            call_start_token="",
            call_end_token="",
            repo_start_token="",
            repo_end_token="",
            chat_template=None,
            sp_model_kwargs: Optional[Dict[str, Any]] = None,
            add_bos_token=False,
            add_eos_token=False,
            clean_up_tokenization_spaces=False,
            **kwargs,
    ):
        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
        # Wrap every special token in AddedToken so the base class treats them atomically
        # (never split by the SentencePiece model).
        unk_token = AddedToken(unk_token, lstrip=False, rstrip=False, single_word=False, normalized=False) \
            if isinstance(unk_token, str) else unk_token
        bos_token = AddedToken(bos_token, \
                               lstrip=False, rstrip=False, single_word=False, normalized=False, special=True)
        eos_token = AddedToken(eos_token, \
                               lstrip=False, rstrip=False, single_word=False, normalized=False, special=True)
        pad_token = AddedToken(pad_token, \
                               lstrip=False, rstrip=False, single_word=False, normalized=False, special=True)
        usr_token = AddedToken(usr_token, \
                               lstrip=False, rstrip=False, single_word=False, normalized=False, special=True)
        bot_token = AddedToken(bot_token, \
                               lstrip=False, rstrip=False, single_word=False, normalized=False, special=True)
        sys_token = AddedToken(sys_token, \
                               lstrip=False, rstrip=False, single_word=False, normalized=False, special=True)
        call_start_token = AddedToken(call_start_token, \
                                      lstrip=False, rstrip=False, single_word=False, normalized=False, special=True)
        call_end_token = AddedToken(call_end_token, \
                                    lstrip=False, rstrip=False, single_word=False, normalized=False, special=True)
        repo_start_token = AddedToken(repo_start_token, \
                                      lstrip=False, rstrip=False, single_word=False, normalized=False, special=True)
        repo_end_token = AddedToken(repo_end_token, \
                                    lstrip=False, rstrip=False, single_word=False, normalized=False, special=True)
        self.vocab_file = vocab_file
        self.add_bos_token = add_bos_token
        self.add_eos_token = add_eos_token
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(vocab_file)

        self.chat_template = chat_template

        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            add_bos_token=add_bos_token,
            add_eos_token=add_eos_token,
            sp_model_kwargs=self.sp_model_kwargs,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            chat_template=self.chat_template,
            **kwargs,
        )
        # Register the role/tool/repo markers so they survive tokenization as single ids.
        self.add_tokens([bos_token, eos_token, pad_token, usr_token, bot_token, sys_token, \
                         call_start_token, call_end_token, repo_start_token, repo_end_token])

    def __getstate__(self):
        # The SentencePiece processor is a C++ object and cannot be pickled; drop it and
        # reload from `vocab_file` in __setstate__.
        state = self.__dict__.copy()
        state["sp_model"] = None
        return state

    def __setstate__(self, d):
        self.__dict__ = d
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(self.vocab_file)

    @property
    def vocab_size(self):
        """Size of the base SentencePiece vocabulary (excludes tokens added at runtime)."""
        return self.sp_model.get_piece_size()

    def get_vocab(self):
        """Return the full vocabulary (base pieces plus added tokens) as a token->id dict."""
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def _tokenize(self, text):
        """Tokenize `text` into a list of SentencePiece string pieces."""
        return self.sp_model.encode(text, out_type=str)

    def _convert_token_to_id(self, token):
        """Convert a token (str) to an id using the SentencePiece vocab."""
        return self.sp_model.piece_to_id(token)

    def _convert_id_to_token(self, index):
        """Convert an id (int) to a token (str) using the SentencePiece vocab."""
        token = self.sp_model.IdToPiece(index)
        return token

    def convert_tokens_to_string(self, tokens):
        """Join a sequence of tokens back into a single string.

        Special tokens are emitted verbatim; the runs of ordinary pieces between them are
        decoded by the SentencePiece model so whitespace markers are restored correctly.
        """
        current_sub_tokens = []
        out_string = ""
        for token in tokens:
            # make sure that special tokens are not decoded using sentencepiece model
            if token in self.all_special_tokens:
                out_string += self.sp_model.decode(current_sub_tokens) + token
                current_sub_tokens = []
            else:
                current_sub_tokens.append(token)
        out_string += self.sp_model.decode(current_sub_tokens)
        return out_string

    # pylint: disable=R1710
    def save_vocabulary(self, save_directory, filename_prefix=None):
        """
        Save the vocabulary and special tokens file to a directory.

        Args:
            save_directory (`str`):
                The directory in which to save the vocabulary.
            filename_prefix (`str`, *optional*):
                Prefix prepended (with "-") to the saved file name.

        Returns:
            `str`: Path to the saved vocabulary file, or None if `save_directory` is invalid.
        """
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return None
        out_vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )

        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
            copyfile(self.vocab_file, out_vocab_file)
        elif not os.path.isfile(self.vocab_file):
            # Source model file is gone: re-serialize the in-memory model. 0o750 keeps the
            # written file non-world-readable.
            flags_ = os.O_WRONLY | os.O_CREAT | os.O_TRUNC
            with os.fdopen(os.open(out_vocab_file, flags_, 0o750), 'wb') as fi:
                content_spiece_model = self.sp_model.serialized_model_proto()
                fi.write(content_spiece_model)

        return out_vocab_file

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """Wrap one or two id sequences with BOS/EOS according to add_bos_token/add_eos_token."""
        bos_token_id = [self.bos_token_id] if self.add_bos_token else []
        eos_token_id = [self.eos_token_id] if self.add_eos_token else []

        output = bos_token_id + token_ids_0 + eos_token_id

        if token_ids_1 is not None:
            output = output + bos_token_id + token_ids_1 + eos_token_id

        return output

    def get_special_tokens_mask(self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None,
                                already_has_special_tokens: bool = False):
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        bos_token_id = [1] if self.add_bos_token else []
        eos_token_id = [1] if self.add_eos_token else []

        if token_ids_1 is None:
            return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
        return (
            bos_token_id
            + ([0] * len(token_ids_0))
            + eos_token_id
            + bos_token_id
            + ([0] * len(token_ids_1))
            + eos_token_id
        )

    def create_token_type_ids_from_sequences(self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None):
        """
        Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT
        sequence pair mask has the following format:

        ```
        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
        | first sequence    | second sequence |
        ```

        if token_ids_1 is None, only returns the first portion of the mask (0s).

        Args:
            token_ids_0 (`List[int]`):
                List of ids.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
        """
        bos_token_id = [self.bos_token_id] if self.add_bos_token else []
        eos_token_id = [self.eos_token_id] if self.add_eos_token else []

        output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)

        if token_ids_1 is not None:
            output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)

        return output
diff --git a/research/telechat3/telechat_transformer.py b/research/telechat3/telechat_transformer.py
new file mode 100644
index 000000000..13193842a
--- /dev/null
+++ b/research/telechat3/telechat_transformer.py
@@ -0,0 +1,594 @@
# Copyright 2025 TeleAI Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+# ============================================================================ +"""Telechat transformer Layer's APIs.""" +import math +from typing import Tuple, Optional + +import mindspore as ms +from mindspore import nn +import mindspore.common.dtype as mstype +from mindspore.common.tensor import Tensor +from mindspore.context import ParallelMode +from mindspore.ops import operations as P +from mindspore.parallel._utils import _get_parallel_mode, _is_sharding_propagation +from mindspore.parallel.shard import Layout + +from research.telechat3.telechat_layer import TelechatLinear, TelechatFeedForward +from mindformers.models.llama.llama_layer import LlamaRMSNorm +from mindformers.models.utils import predict_lazy_inline +from mindformers.modules.layers import _check_input_dtype, RotaryEmbedding +from mindformers.modules.transformer import TransformerOpParallelConfig +from mindformers.modules.flash_attention import FlashAttention +from mindformers.modules.infer_attention import InferAttention +from mindformers.tools.logger import logger +from mindformers.tools.utils import get_predict_run_mode + + +class TelechatAttention(nn.Cell): + r""" + This is an implementation of multihead attention in Telechat. + + Args: + - **dim** (int): The hidden size of the input. + - **head_dim** (int): The dim of head. + - **n_heads** (int): The number of the heads. + - **compute_dtype** (dtype.Number): The computation type of dense. Default mstype.float16. + Should be mstype.float32 or mstype.float16. + - **softmax_compute_type** (dtype.Number): The type of softmax computation module. Default mstype.float32. + Should be mstype.float32 or mstype.float16. + - **param_init_type** (dtype.Number): The parameter initialization type of the module. Default mstype. + float32. Should be mstype.float32 or mstype.float16. + - **qkv_has_bias** (bool): Whether Q/K/V in attention has bias or not. + - **use_past** (bool): Use the past state to compute, used for incremental prediction. 
+ For example, if we have two words and want to generate the ten more words. + We just need to compute the two words' state only once, and generate the next word one by one. + When use_past is True, there are two steps to run the prediction. + In the first step, set the is_first_iteration to be True by + `model.add_flags_recursive(is_first_iteration=True)`, and pass the full inputs. Then, set the + is_first_iteration to be False by `model.add_flags_recursive(is_first_iteration=False)`. At this moment, + pass the single step's input tensor, and loop it. Default False. + - **parallel_config** (OpParallelConfig): The parallel configure. Default `default_dpmp_config`, + an instance of `OpParallelConfig` with default args. + + Inputs: + - **x** (Tensor) - The input tokens with shape (batch_size, src_seq_length, hidden_size) or + (batch_size * src_seq_length, hidden_size), if the use_past is False or is_first_iteration=True. + Otherwise, must be (batch_size, 1, hidden_size) + - **freqs_cis** (Tuple) - The precompute freqs and mask for rotary position embedding used in attention. + - **attention_mask** (Tensor) - If the use_past is False or is_first_iteration=True, the attention mask + matrix should ba (batch_size, src_seq_length, tgt_seq_length), or None. None means there will be no mask + in softmax computation. Otherwise, the mask must be (batch_size, 1, tgt_seq_length) + - **batch_valid_length** (Tensor) - Int32 tensor with shape (batch_size,) the past calculated the index. + Used for incremental prediction when the use_past is True. Default None. + - **block_tables** (Tensor[int64]) - Store mapping tables for each sequence. + - **slot_mapping** (Tensor[int32]) - Store token cache physical slot index. 
+ Outputs: + Tuple, a tuple contains(`output`, `layer_present`) + + - **output** (Tensor) - Tensor, the float tensor of the output of the layer with + shape (batch_size, src_seq_length, hidden_size) or (batch_size * src_seq_length, hidden_size), + if the use_past is False or is_first_iteration=True. Otherwise, it will be (batch_size, 1, hidden_size). + + - **layer_present** (Tuple) - A tuple of the Tensor of the projected key and value vector with + ((batch_size, num_heads, head_dim, tgt_seq_length), + (batch_size, num_heads, tgt_seq_length, head_dim)). + """ + + def __init__(self, dim: int = 512, + n_heads: int = 8, + n_kv_heads: Optional[int] = None, + sigma: float = 0.0048, + mean: float = 0.0, + compute_dtype=mstype.float16, + softmax_compute_dtype=mstype.float32, + rotary_dtype=mstype.float32, + param_init_type=mstype.float32, + qkv_has_bias=False, + out_proj_has_bias=True, + qkv_concat=False, + use_past=False, + is_dynamic=False, + use_rope_slice=False, + use_flash_attention=False, + use_attn_mask_compression=False, + block_size: Optional[int] = None, + num_blocks: Optional[int] = None, + parallel_config=TransformerOpParallelConfig()): + super().__init__() + self.hidden_size = dim + self.n_head = n_heads + self.head_dim = dim // n_heads + self.n_kv_head = n_heads if n_kv_heads is None else n_kv_heads + self.n_rep = self.n_head // self.n_kv_head + self.kv_dim = self.n_kv_head * self.head_dim + self.block_size = block_size + self.num_blocks = num_blocks + self.sigma = sigma + self.mean = mean + self.dtype = compute_dtype + self.softmax_dtype = softmax_compute_dtype + self.is_first_iteration = True + self.use_past = use_past + self.use_flash_attention = use_flash_attention + self.use_attn_mask_compression = use_attn_mask_compression + self.qkv_concat = qkv_concat + self.qkv_has_bias = qkv_has_bias + + if self.hidden_size % self.n_head != 0: + raise ValueError( + f"For 'MultiHeadAttention', the class variable 'n_kv_head' must be a multiple of " + 
f"'parallel_config.model_parallel', but got the n_kv_head is {self.n_kv_head} " + f"and the parallel_config.model_parallel is {parallel_config.model_parallel}." + ) + if self.n_kv_head % parallel_config.model_parallel != 0: + raise ValueError( + f"MultiHeadAttention error: n_kv_head ({self.n_kv_head}) must be a multiple of " + f"parallel_config.model_parallel ({parallel_config.model_parallel})" + ) + dp = parallel_config.data_parallel + mp = parallel_config.model_parallel + self.sp = parallel_config.context_parallel + self.shape = P.Shape() + self.cast = P.Cast() + self.reshape = P.Reshape() + + if self.qkv_concat: + self.w_qkv = TelechatLinear(self.hidden_size, + self.hidden_size + self.n_kv_head * self.head_dim * 2, + has_bias=qkv_has_bias, + sigma=sigma, + mean=mean, + compute_dtype=compute_dtype, + param_init_type=param_init_type) + if self.qkv_has_bias: + self.w_qkv.shard(((dp * self.sp, 1), (mp, 1)), ((dp * self.sp, mp), (mp,))) + else: + self.w_qkv.shard(((dp * self.sp, 1), (mp, 1))) + self.split_qkv = ms.ops.auto_generate.SplitWithSize() + self.split_qkv.add_prim_attr("skip_redistribution", True) + self.split_qkv.shard(((dp * self.sp, 1, mp),)) + else: + self.wq = TelechatLinear(self.hidden_size, + self.hidden_size, + sigma=self.sigma, + mean=self.mean, + has_bias=qkv_has_bias, + compute_dtype=compute_dtype, + param_init_type=param_init_type) + self.wk = TelechatLinear(self.hidden_size, + self.n_kv_head * self.head_dim, + has_bias=qkv_has_bias, + sigma=self.sigma, + mean=self.mean, + compute_dtype=compute_dtype, + param_init_type=param_init_type) + self.wv = TelechatLinear(self.hidden_size, + self.n_kv_head * self.head_dim, + has_bias=qkv_has_bias, + sigma=self.sigma, + mean=self.mean, + compute_dtype=compute_dtype, + param_init_type=param_init_type) + + if qkv_has_bias: + self.wq.shard(((dp * self.sp, 1), (mp, 1)), ((dp * self.sp, mp), (mp,))) + self.wk.shard(((dp * self.sp, 1), (mp, 1)), ((dp * self.sp, mp), (mp,))) + self.wv.shard(((dp * self.sp, 1), 
(mp, 1)), ((dp * self.sp, mp), (mp,))) + else: + self.wq.shard(((dp * self.sp, 1), (mp, 1))) + self.wk.shard(((dp * self.sp, 1), (mp, 1))) + self.wv.shard(((dp * self.sp, 1), (mp, 1))) + + self.wo = TelechatLinear(in_channels=self.hidden_size, + out_channels=self.hidden_size, + sigma=self.sigma, + mean=self.mean, + has_bias=out_proj_has_bias, + compute_dtype=compute_dtype, + param_init_type=param_init_type) + if out_proj_has_bias: + self.wo.shard(((dp * self.sp, mp), (1, mp)), ((dp * self.sp, 1), (1,)), + out_strategy_matmul=((dp * self.sp, 1),)) + else: + self.wo.shard(((dp * self.sp, mp), (1, mp)), out_strategy_matmul=((dp * self.sp, 1),)) + + if self.use_past: + self.infer_attention = InferAttention(self.n_head, + self.head_dim, + self.n_kv_head, + pa_n_head_split=self.n_head // mp, + pa_n_kv_head_split=self.n_kv_head // mp, + scale_value=1. / math.sqrt(self.head_dim), + pre_tokens=2147483647, + next_tokens=0, + block_size=self.block_size, + num_blocks=self.num_blocks, + is_dynamic=is_dynamic, + use_flash_attention=self.use_flash_attention, + rotary_cos_format=2, + compute_dtype=compute_dtype) + self.infer_attention.shard(parallel_config) + else: + self.inv_norm_factor = Tensor(1.0 / math.sqrt(self.head_dim), dtype=compute_dtype) + + self.transpose = P.Transpose() + self.merger_head_transpose = P.Transpose() + self.batch_matmul = P.BatchMatMul() + self.batch_matmul_q_k = P.BatchMatMul(transpose_b=True) + self.mul = P.Mul() + self.add = P.Add() + self.softmax = P.Softmax() + self.cast_attn = P.Cast() + self.tile_kv = P.Tile() + + self.apply_rotary_emb = RotaryEmbedding(self.head_dim, rotary_dtype, use_rope_slice=use_rope_slice) + + if not (_get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation()): + self.transpose.shard(((dp, self.sp, mp, 1),)) + layout = Layout((dp, self.sp, mp), ("dp", "sp", "mp")) + layout_merger_transpose = (layout("dp", "mp", "sp", "None"),) + self.merger_head_transpose.shard(in_strategy=layout_merger_transpose) + 
self.batch_matmul_q_k.shard(((dp, mp, 1, 1), (dp, mp, 1, 1))) + self.batch_matmul.shard(((dp, mp, 1, 1), (dp, mp, 1, 1))) + self.mul.shard(((dp, mp, 1, 1), ())) + self.add.shard(((dp, 1, 1, 1), (dp, mp, 1, 1))) + self.softmax.shard(((dp, mp, 1, 1),)) + self.tile_kv.shard(((dp, mp, 1, 1),)) + + self.apply_rotary_emb.shard(parallel_config) + if parallel_config.use_seq_parallel and self.is_first_iteration: + self.wo.shard(((dp * self.sp, mp), (1, mp)), out_strategy_matmul=((dp * mp * self.sp, 1),)) + if parallel_config.recompute.select_recompute and not self.use_flash_attention: + self.apply_rotary_emb.recompute() + self.tile_kv.recompute() + self.batch_matmul_q_k.recompute() + self.mul.recompute() + self.add.recompute() + self.cast_attn.recompute() + self.softmax.recompute() + self.batch_matmul.recompute() + + if self.use_flash_attention: + self.input_layout = "BSH" if self.sp > 1 else "BNSD" + self.sparse_mode = 2 if self.use_attn_mask_compression else 0 + self.flash_attention = FlashAttention(head_num=self.n_head, + pre_tokens=65536, + next_tokens=0, + input_layout=self.input_layout, + scale_value=1. 
/ math.sqrt(self.head_dim), + sparse_mode=self.sparse_mode, + use_attention_mask=True) + self.flash_attention.shard(parallel_config) + + def construct(self, x: Tensor, freqs_cis: Tuple[Tensor, Tensor], mask=None, batch_valid_length=None, + block_tables=None, slot_mapping=None, prefix_keys_values=None): + """Forward process of the MultiHeadAttention""" + ori_dtype = x.dtype + # [bs, seq/1, hidden_dim] + bs, seq_len, _ = self.shape(x) + + if self.qkv_concat: + qkv = self.cast(self.w_qkv(x), self.dtype) + query, key, value = self.split_qkv(qkv, (self.hidden_size, self.kv_dim, self.kv_dim), 2) + else: + query = self.cast(self.wq(x), self.dtype) # dp, 1 -> dp, mp + key = self.cast(self.wk(x), self.dtype) # dp, 1 -> dp, mp + value = self.cast(self.wv(x), self.dtype) # dp, 1 -> dp, mp + + # key and value for current token(s) + if self.use_past: + if not self.qkv_concat: + key = self.reshape(key, (bs, seq_len, self.n_kv_head * self.head_dim)) + value = self.reshape(value, (bs, seq_len, self.n_kv_head * self.head_dim)) + context_layer = self.infer_attention(query, key, value, batch_valid_length, block_tables, slot_mapping, + freqs_cis, mask, prefix_keys_values=prefix_keys_values) + else: + query = self.transpose(self.reshape(query, (bs, seq_len, self.n_head, self.head_dim)), (0, 2, 1, 3)) + key = self.transpose(self.reshape(key, (bs, seq_len, self.n_kv_head, self.head_dim)), (0, 2, 1, 3)) + query, key = self.apply_rotary_emb(query, key, freqs_cis) # dp, mp, 1, 1 + + if self.sp > 1: + value = self.reshape(value, (bs, seq_len, self.n_kv_head * self.head_dim)) + else: + value = self.transpose(self.reshape(value, (bs, seq_len, self.n_kv_head, self.head_dim)), (0, 2, 1, 3)) + key, value = self._cat_prefix(key, value, prefix_keys_values) + + if self.sp > 1: + query = self._merge_heads(query) + key = self._merge_heads(key) + + if self.use_flash_attention: + context_layer = self.flash_attention(query, key, value, mask) + if self.sp == 1: + context_layer = 
self._merge_heads(context_layer) + else: + key = self._repeat_kv(key, self.n_rep) + value = self._repeat_kv(value, self.n_rep) + context_layer = self._attn(query, key, value, mask) + + # [bs, seq/1, hidden_dim] or [bs * seq/1, hidden_dim] + output = self.wo(context_layer) # dp, mp -> dp, 1 / dp * mp, 1 + output = self.cast(output, ori_dtype) + return output + + def _cat_prefix(self, key, value, prefix_keys_values): + r''' + concat prefix_keys_values to key and value + prefix_keys_values: shape(2, bs, pre_len, num_heads * kv_channels) + ''' + if prefix_keys_values is not None: + bs, n_kv_head, _, head_dim = key.shape + past_key = prefix_keys_values[0] + past_value = prefix_keys_values[1] + past_key = self.transpose(self.reshape(past_key, (bs, -1, n_kv_head, head_dim)), (0, 2, 1, 3)) + past_value = self.transpose(self.reshape(past_value, (bs, -1, n_kv_head, head_dim)), (0, 2, 1, 3)) + past_key = self.cast(past_key, self.dtype) + past_value = self.cast(past_value, self.dtype) + cat = P.Concat(2) + key = cat((past_key, key)) + value = cat((past_value, value)) + return key, value + + def _repeat_kv(self, x, rep): + if rep == 1: + return x + bs, n_kv_head, seqlen, head_dim = self.shape(x) + x = self.reshape(x, (bs, n_kv_head, 1, seqlen * head_dim)) + x = self.tile_kv(x, (1, 1, rep, 1)) + x = self.reshape(x, (bs, n_kv_head * rep, seqlen, head_dim)) + return x + + def _merge_heads(self, x): + """ + convert a 4d input to a 3d output + + Inputs: + x: input tensor + + Output: + x_merge: the 2d output + """ + # [bs, n_head, seq/1, head_dim] + x = self.merger_head_transpose(x, (0, 2, 1, 3)) # dp,mp,1,1 -> dp,1,mp,1 + # [bs, seq/1, n_head, head_dim] + bs, seq_len, n_head, head_dim = self.shape(x) + # [bs, seq/1, hidden_dim] + new_shape = (bs, seq_len, n_head * head_dim) + x_merge = self.reshape(x, new_shape) + return x_merge + + def _attn(self, query, key, value, mask): + """ + Get the weighted score along the seq_length + + Inputs: + query: the query matrix + key: the key 
matrix + value: the value matrix + mask: the attention mask adder matrix with shape (batch_size, + 1, seq_length, seq_length) + Outputs: + weighted_values: Tensor, the weighted sum scores + """ + # q, k: [bs, n_head, seq/1, head_dim], [bs, n_head, seq, head_dim] + score = self.batch_matmul_q_k(query, key) + # score: [bs, n_head, seq/1, seq] + score = self.mul(score, self.inv_norm_factor) + score = self.add(mask, score) + + attention_probs = self.softmax(self.cast_attn(score, self.softmax_dtype)) + # score, v: [bs, n_head, seq/1, seq], [bs, n_head, seq, head_dim] + weighted_values = self.batch_matmul(self.cast(attention_probs, self.dtype), value) + # [bs, n_head, seq/1, head_dim] + attention_merge = self._merge_heads(weighted_values) + # [bs, seq/1, hidden_dim] or [bs * seq/1, hidden_dim] + return attention_merge + + +# pylint: disable=C0326 +class TelechatDecodeLayer(nn.Cell): + r""" + Transformer Layer. This is an implementation of the single layer of the transformer + encoder layer, including multihead attention and feedward layer. + + Args: + layer_id(int): The layer id of current transformer block layer. + dim(int): The hidden size of the input. + num_heads(int): The number of the heads. + multiple_of(int): The SwiGLU hidden layer size multiple of large power of 2. + norm_eps (float): The epsilon value of the denominator. Default 1e-5. + compute_dtype(dtype.Number): The computation type of the layer. + Should be mstype.float32 or mstype.float16. Default mstype.float32. + layernorm_compute_type(dtype.Number): The computation type of the norm. + Should be mstype.float32 or mstype.float16. Default mstype.float32. + softmax_compute_type(dtype.Number): The computation type of the softmax in the attention. + Should be mstype.float32 or mstype.float16. Default mstype.float32. + param_init_type(dtype.Number): The parameter initialization type of the module. + Should be mstype.float32 or mstype.float16. Default mstype.float32. 
+ qkv_has_bias(bool): Whether Q/K/V in attention has bias or not. + use_past(bool): Use the past state to compute, used for incremental prediction. For example, if we have two + words and want to generate the ten more words. We just need to compute the two words' state only once, + and generate the next word one by one. When use_past is True, there are two steps to run the prediction. + In the first step, set the is_first_iteration to be True by + `model.add_flags_recursive(is_first_iteration=True)`, and pass the full inputs. Then, set the + is_first_iteration to be False by `model.add_flags_recursive(is_first_iteration=False)`. + At this moment, pass the single step's input tensor, and loop it. Default False. + parallel_config(OpParallelConfig, MoEParallelConfig): The parallel configure. When MoE is applied, + MoEParallelConfig is effective, otherwise OpParallelConfig is effective. Default `default_dpmp_config`, + an instance of `OpParallelConfig` with default args. + + Inputs: + - **x** (Tensor) - Float Tensor, shape should be [batch_size, seq_length, hidden_size] or + [batch_size * seq_length, hidden_size], if the use_past is False or is_first_iteration=True. Otherwise, + should be [batch_size, 1, hidden_size] + - **freqs_cis** (Tuple) - The precompute freqs and mask for rotary position embedding used in attention. + - **input_mask** (Tensor) - Float Tensor, If the use_past is False or is_first_iteration=True, + the attention mask matrix should ba [batch_size, seq_length, seq_length], or None. None means there will + be no mask in softmax computation. Otherwise, should be [batch_size, 1, hidden_size] + - **init_reset** (Tensor) - A bool tensor with shape [1], used to clear the past key parameter and + past value parameter used in the incremental prediction. Only valid when use_past is True. Default True. + - **batch_valid_length** (Tensor) - Int32 tensor with shape [batch_size] the past calculated the index. 
+ Used for incremental prediction when the use_past is True. Default None. + - **block_tables** (Tensor[int64]) - Store mapping tables for each sequence. + - **slot_mapping** (Tensor[int32]) - Store token cache physical slot index. + Outputs: + Tuple, a tuple contains(`output`, `layer_present`). + + - **output** (Tensor) - The float tensor of the output of the layer with + shape (batch_size, seq_length, hidden_size) or (batch_size * seq_length, hidden_size), if the use_past is + False or is_first_iteration=True. Otherwise, it will be (batch_size, 1, hidden_size) + + - **layer_present** (Tuple) - A tuple of the Tensor of the projected key and value vector with + ((batch_size, num_heads, head_dim, seq_length), + (batch_size, num_heads, seq_length, head_dim)). + + """ + + @predict_lazy_inline + def __init__(self, layer_id, + dim: int = 512, + n_heads: int = 8, + sigma: float = 0.0048, + mean: float = 0.0, + n_kv_heads: Optional[int] = None, + intermediate_size: Optional[int] = None, + multiple_of: int = 256, + ffn_dim_multiplier: Optional[int] = None, + norm_eps: float = 1e-5, + compute_dtype=mstype.float16, + layernorm_compute_dtype=mstype.float32, + softmax_compute_dtype=mstype.float32, + rotary_dtype=mstype.float32, + param_init_type=mstype.float32, + res_dtype=mstype.float32, + qkv_has_bias=False, + out_proj_has_bias=True, + qkv_concat=False, + use_past=False, + is_dynamic=False, + use_rope_slice=False, + use_flash_attention=False, + use_attn_mask_compression=False, + block_size: Optional[int] = None, + num_blocks: Optional[int] = None, + parallel_config=TransformerOpParallelConfig()): + super().__init__() + self.layer_id = layer_id + self.hidden_size = dim + self.n_head = n_heads + self.head_dim = self.hidden_size // self.n_head + self.n_kv_head = n_heads if n_kv_heads is None else n_kv_heads + self.dtype = compute_dtype + self.res_dtype = res_dtype + self.is_first_iteration = True + self.use_past = use_past + + self.sigma = sigma + self.mean = mean + 
self.qkv_concat = qkv_concat + self.shape = P.Shape() + self.reshape = P.Reshape() + self.cast = P.Cast() + self.add = P.Add() + self.ffn_norm = LlamaRMSNorm(self.hidden_size, norm_eps, compute_type=layernorm_compute_dtype) + self.attention_norm = LlamaRMSNorm(self.hidden_size, norm_eps, compute_type=layernorm_compute_dtype) + self.attention = TelechatAttention(dim=dim, + n_heads=n_heads, + n_kv_heads=n_kv_heads, + sigma=self.sigma, + mean=self.mean, + compute_dtype=compute_dtype, + softmax_compute_dtype=softmax_compute_dtype, + rotary_dtype=rotary_dtype, + param_init_type=param_init_type, + qkv_has_bias=qkv_has_bias, + out_proj_has_bias=out_proj_has_bias, + qkv_concat=qkv_concat, + use_past=use_past, + is_dynamic=is_dynamic, + use_rope_slice=use_rope_slice, + use_flash_attention=use_flash_attention, + use_attn_mask_compression=use_attn_mask_compression, + block_size=block_size, + num_blocks=num_blocks, + parallel_config=parallel_config) + self.feed_forward = TelechatFeedForward(dim=self.hidden_size, + intermediate_size=intermediate_size, + hidden_dim=4 * self.hidden_size, + sigma=self.sigma, + mean=self.mean, + multiple_of=multiple_of, + ffn_dim_multiplier=ffn_dim_multiplier, + ffn_concat=self.qkv_concat, + compute_dtype=compute_dtype, + param_init_type=param_init_type, + parallel_config=parallel_config_new) + dp = parallel_config.data_parallel + mp = parallel_config.model_parallel + sp = parallel_config.context_parallel + if not (_get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation()): + self.feed_forward.shard(parallel_config) + self.add.shard(((dp, sp, 1), (dp, sp, 1))) + self.attention_norm.shard((dp, sp, 1)) + self.ffn_norm.shard((dp, sp, 1)) + + if parallel_config.use_seq_parallel and self.is_first_iteration: + self.add.shard(((dp, mp * sp, 1), (dp, mp * sp, 1))) + self.attention_norm.shard((dp, mp * sp, 1)) + self.ffn_norm.shard((dp, mp * sp, 1)) + self.feed_forward.w2.shard(((dp * sp, mp), (1, mp)), ((dp * mp * sp, 1), (1, )), 
+ out_strategy_matmul=((dp * mp * sp, 1),)) + + self.predict_run_mode = get_predict_run_mode() + logger.info(f"Predict run mode:{self.predict_run_mode}") + + if self.predict_run_mode: + self.no_inline = False + + def construct(self, x, freqs_cis, mask=None, batch_valid_length=None, block_tables=None, + slot_mapping=None, aux_loss=None, prefix_keys_values=None): + """ Forward of transformer block. """ + if not self.use_past: + self._check_input(x, freqs_cis, mask) + # [bs, seq/1, hidden_dim] + input_x = self.attention_norm(x) + # [bs, seq/1, hidden_dim] + h = self.attention(input_x, freqs_cis, mask, batch_valid_length, block_tables, + slot_mapping, prefix_keys_values) + h = self.add(self.cast(x, self.res_dtype), self.cast(h, self.res_dtype)) + ffn_norm = self.ffn_norm(h) + # [bs, seq/1, hidden_dim] + ffn_out = self.feed_forward(ffn_norm) + # [bs, seq/1, hidden_dim] or [bs * seq/1, hidden_dim] + out = self.add(self.cast(h, self.res_dtype), self.cast(ffn_out, self.res_dtype)) + return out + + def _check_input(self, x, freqs_cis, mask): + r"""Check inputs""" + _check_input_dtype( + x.dtype, "x", [mstype.float32, mstype.float16, mstype.bfloat16], self.cls_name) + freqs_cos, freqs_sin, swap_mask = freqs_cis + _check_input_dtype(freqs_cos.dtype, "freqs_cos", + [mstype.float32, mstype.float16, mstype.bfloat16], self.cls_name) + _check_input_dtype(freqs_sin.dtype, "freqs_sin", + [mstype.float32, mstype.float16, mstype.bfloat16], self.cls_name) + if swap_mask is not None: + _check_input_dtype(swap_mask.dtype, "swap_mask", + [mstype.float32, mstype.float16, mstype.bfloat16], self.cls_name) + if mask is not None: + _check_input_dtype(mask.dtype, "input_mask", + [mstype.float32, mstype.float16, mstype.bfloat16, mstype.uint8, mstype.bool_], + self.cls_name) + return True -- Gitee