From 22bab3a0f296372bf9b1bc0cc91e811dfcaa3cea Mon Sep 17 00:00:00 2001 From: Yule100 Date: Fri, 19 Dec 2025 10:02:22 +0800 Subject: [PATCH] telechat2 infer st --- .../test_model/test_telechat2/__init__.py | 15 ++++ .../test_telechat2_infer/__init__.py | 15 ++++ .../run_telechat2_infer.py | 73 +++++++++++++++++++ .../test_telechat2_infer/telechat2_infer.yaml | 39 ++++++++++ .../test_telechat2_infer.py | 58 +++++++++++++++ 5 files changed, 200 insertions(+) create mode 100644 tests/st/test_multi_cards_cases/test_model/test_telechat2/__init__.py create mode 100644 tests/st/test_multi_cards_cases/test_model/test_telechat2/test_telechat2_infer/__init__.py create mode 100644 tests/st/test_multi_cards_cases/test_model/test_telechat2/test_telechat2_infer/run_telechat2_infer.py create mode 100644 tests/st/test_multi_cards_cases/test_model/test_telechat2/test_telechat2_infer/telechat2_infer.yaml create mode 100644 tests/st/test_multi_cards_cases/test_model/test_telechat2/test_telechat2_infer/test_telechat2_infer.py diff --git a/tests/st/test_multi_cards_cases/test_model/test_telechat2/__init__.py b/tests/st/test_multi_cards_cases/test_model/test_telechat2/__init__.py new file mode 100644 index 000000000..1d1a3b364 --- /dev/null +++ b/tests/st/test_multi_cards_cases/test_model/test_telechat2/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""test mcore telechat2.""" diff --git a/tests/st/test_multi_cards_cases/test_model/test_telechat2/test_telechat2_infer/__init__.py b/tests/st/test_multi_cards_cases/test_model/test_telechat2/test_telechat2_infer/__init__.py new file mode 100644 index 000000000..1d1a3b364 --- /dev/null +++ b/tests/st/test_multi_cards_cases/test_model/test_telechat2/test_telechat2_infer/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""test mcore telechat2.""" diff --git a/tests/st/test_multi_cards_cases/test_model/test_telechat2/test_telechat2_infer/run_telechat2_infer.py b/tests/st/test_multi_cards_cases/test_model/test_telechat2/test_telechat2_infer/run_telechat2_infer.py new file mode 100644 index 000000000..8219fd679 --- /dev/null +++ b/tests/st/test_multi_cards_cases/test_model/test_telechat2/test_telechat2_infer/run_telechat2_infer.py @@ -0,0 +1,73 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""mcore telechat2 model ST of inference""" +import argparse +import os +from transformers import AutoTokenizer + +from mindspore.nn.utils import no_init_parameters + +from mindformers import AutoModel, build_context, MindFormerConfig +from mindformers.core.parallel_config import build_parallel_config +from mindformers.tools.logger import logger + + +def test_telechat2_predict_mcore(device_num: int = 1): + """ + Feature: Mcore TeleChat2 predict task + Description: Two-card tp parallel + Expectation: Success or assert precision failed + """ + max_decode_length = 32 + config_path = os.path.join(os.path.dirname(__file__), "telechat2_infer.yaml") + config = MindFormerConfig(config_path) + config.use_parallel = device_num > 1 + config.parallel_config.model_parallel = device_num + config.pretrained_model_dir = "/home/workspace/mindspore_dataset/weight/telechat2_7b" + # Reduced layer network + config.model.model_config.num_hidden_layers = 2 + build_context(config) + build_parallel_config(config) + # Auto tokenizer + tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model_dir, trust_remote_code=True) + # init network + with no_init_parameters(): + network = AutoModel.from_config(config) + network.load_weights(config.pretrained_model_dir) + # Build prompt + question = "Please introduce some scenic spots in Beijing." 
+ + input_ids = tokenizer.encode(question) + + output = network.generate(input_ids, + max_length=max_decode_length, + do_sample=False, + return_dict_in_generate=False) + + output_text = tokenizer.decode(output[0]) + logger.info("test_telechat2_predict, output_text: %s", str(output_text)) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description="Run TeleChat2 ST") + parser.add_argument("--device_num", type=int, default=2) + + args = parser.parse_args() + os.environ['MS_ENABLE_LCCL'] = "off" + os.environ['HCCL_DETERMINICTIC'] = "true" + os.environ['LCCL_DETERMINICTIC'] = "1" + os.environ['ASCEND_LAUNCH_BLOCKING'] = "1" + os.environ['CUSTOM_MATMUL_SHUFFLE'] = "off" + test_telechat2_predict_mcore(args.device_num) diff --git a/tests/st/test_multi_cards_cases/test_model/test_telechat2/test_telechat2_infer/telechat2_infer.yaml b/tests/st/test_multi_cards_cases/test_model/test_telechat2/test_telechat2_infer/telechat2_infer.yaml new file mode 100644 index 000000000..f2341d2fd --- /dev/null +++ b/tests/st/test_multi_cards_cases/test_model/test_telechat2/test_telechat2_infer/telechat2_infer.yaml @@ -0,0 +1,39 @@ +seed: 0 +output_dir: './output' # path to save checkpoint/strategy +load_checkpoint: '' +use_parallel: True +run_mode: 'predict' +use_legacy: False +load_ckpt_format: 'safetensors' +infer_precision_sync: True + +trainer: + type: CausalLanguageModelingTrainer + model_name: 'telechat2' + +# default parallel of device num = 8 for Atlas 800T A2 +parallel_config: + data_parallel: 1 + model_parallel: 2 +# HuggingFace file directory +pretrained_model_dir: '/path/hf_dir' +model: + model_config: + compute_dtype: "bfloat16" + layernorm_compute_dtype: "float32" + rotary_dtype: "bfloat16" + params_dtype: "bfloat16" + +# mindspore context init config +context: + mode: 0 #0--Graph Mode; 1--Pynative Mode + max_device_memory: "28GB" + device_id: 0 + device_target: "Ascend" + affinity_cpu_list: None + deterministic: "ON" + +# parallel context config +parallel: + 
parallel_mode: "MANUAL_PARALLEL" + full_batch: False diff --git a/tests/st/test_multi_cards_cases/test_model/test_telechat2/test_telechat2_infer/test_telechat2_infer.py b/tests/st/test_multi_cards_cases/test_model/test_telechat2/test_telechat2_infer/test_telechat2_infer.py new file mode 100644 index 000000000..7720dfee1 --- /dev/null +++ b/tests/st/test_multi_cards_cases/test_model/test_telechat2/test_telechat2_infer/test_telechat2_infer.py @@ -0,0 +1,58 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+"""Test Mcore TeleChat2 inference"""
+import os
+import random
+from pathlib import Path
+
+import pytest
+
+from tests.st.test_multi_cards_cases.utils import TaskType
+from mindformers.tools.logger import logger
+
+
+_LEVEL_0_TASK_TIME = 80  # seconds budgeted when this case is scheduled as a level-0 task
+_LEVEL_1_TASK_TIME = 0  # NOTE(review): the case below is @pytest.mark.level1 yet the level-1 budget is 0 while level-0 is 80 — confirm the intended level/time pairing against the ST scheduler conventions
+_TASK_TYPE = TaskType.TWO_CARDS_TASK  # scheduler hint: this case occupies two NPUs
+
+
+class TestMcoreTeleChat2ParallelInference:
+    """Two-card msrun-launched inference ST for the mcore TeleChat2 model."""
+
+    def setup_method(self):
+        """Resolve the worker script path next to this file and fail fast if it is missing."""
+        self.sh_path = Path(__file__).parent.resolve()
+        self.run_script_path = self.sh_path / "run_telechat2_infer.py"
+        assert self.run_script_path.exists(), f"Run script not found: {self.run_script_path}"
+
+    @pytest.mark.level1
+    def test_two_cards_cases(self):
+        """Launch run_telechat2_infer.py on 2 workers via msrun and assert a zero exit code."""
+        port_id = int(os.environ.get("ASCEND_PORT_ID", random.randint(50000, 65535)))  # prefer CI-assigned port; otherwise pick a random high port
+        cmd_list = [
+            "msrun",
+            "--worker_num=2",
+            "--local_worker_num=2",  # Should match NPU cards available
+            f"--master_port={port_id}",  # Ensure port is unique per test run if parallelized at pytest level
+            "--log_dir=./msrun_log_telechat2",
+            "--join=True"]
+        cmd_list += [
+            str(self.run_script_path),
+            "--device_num=2"
+        ]
+        cmd = " ".join(cmd_list)
+        logger.info(f"Running command: {cmd}")
+        return_code = os.system(cmd)  # --join=True makes msrun block until all workers exit
+        assert return_code == 0, "TeleChat2 inference st failed."
-- Gitee