From a8dbc4e98013e40fdaa1b1dee27567fa8fb28aca Mon Sep 17 00:00:00 2001
From: zhanzhan1
Date: Tue, 9 Sep 2025 16:43:00 +0800
Subject: [PATCH] pp add testcase and fix v0

---
 tests/st/python/test_mcore_mix_parallel.py | 104 +++++++++++++--------
 vllm_mindspore/worker/worker.py            |  19 +++-
 2 files changed, 83 insertions(+), 40 deletions(-)

diff --git a/tests/st/python/test_mcore_mix_parallel.py b/tests/st/python/test_mcore_mix_parallel.py
index d7bc7a31..d7498687 100644
--- a/tests/st/python/test_mcore_mix_parallel.py
+++ b/tests/st/python/test_mcore_mix_parallel.py
@@ -63,7 +63,7 @@ common_qwen_expect_result = '\n好的'
 quant_type = 'ascend'
 
 
-def dp_func(dp_size, local_dp_rank, global_dp_rank, tp_size, ep_size,
+def dp_func(dp_size, local_dp_rank, global_dp_rank, tp_size, ep_size, pp_size,
             dp_master_port, prompts, expect_list, result_q, model_path,
             quantization):
     dp_master_ip = "127.0.0.1"
@@ -93,6 +93,7 @@ def dp_func(dp_size, local_dp_rank, global_dp_rank, tp_size, ep_size,
     gpu_memory_utilization = 0.7 if model_path == ds_model_path else 0.9
     llm = LLM(model=model_path,
               tensor_parallel_size=tp_size,
+              pipeline_parallel_size=pp_size,
               gpu_memory_utilization=gpu_memory_utilization,
               max_model_len=4096,
               max_num_batched_tokens=8,
@@ -114,6 +115,7 @@ def dp_func(dp_size, local_dp_rank, global_dp_rank, tp_size, ep_size,
 def exec_model_with_dp(dp_size,
                        tp_size,
                        ep_size,
+                       pp_size,
                        prompts,
                        expect_list,
                        model_path,
@@ -129,8 +131,8 @@ def exec_model_with_dp(dp_size,
             range(node_rank * dp_per_node, (node_rank + 1) * dp_per_node)):
         proc = Process(target=dp_func,
                        args=(dp_size, local_dp_rank, global_dp_rank, tp_size,
-                             ep_size, dp_master_port, prompts, expect_list,
-                             result_q, model_path, quantization))
+                             ep_size, pp_size, dp_master_port, prompts,
+                             expect_list, result_q, model_path, quantization))
         proc.start()
         procs.append(proc)
     exit_code = 0
@@ -157,6 +159,7 @@ def exec_model_with_dp(dp_size,
 
 def exec_model_without_dp(tp_size,
                           ep_size,
+                          pp_size,
                           prompts,
                           expect_list,
                           model_path,
@@ -171,6 +174,7 @@ def exec_model_without_dp(tp_size,
     # Create an LLM.
     llm = LLM(model=model_path,
               tensor_parallel_size=tp_size,
+              pipeline_parallel_size=pp_size,
               trust_remote_code=True,
               gpu_memory_utilization=0.9,
               max_model_len=4096,
@@ -195,151 +199,177 @@ def exec_model_without_dp(tp_size,
 @pytest.mark.level0
 @pytest.mark.platform_arm_ascend910b_training
 @pytest.mark.allcards
-def test_vllm_qwen3_moe_30b_dp4_tp2_ep4():
+def test_vllm_qwen3_moe_30b_dp4_tp2_ep4_pp1():
     """
-    test case qwen3_moe_30B with DP4TP2EP4
+    test case qwen3_moe_30B with DP4TP2EP4PP1
     """
     dp_size = 4
     tp_size = 2
     ep_size = 4
+    pp_size = 1
     # Sample prompts.
     prompts = [common_qwen_prompt] * 4
     expect_list = [common_qwen_expect_result] * 4
-    exec_model_with_dp(dp_size, tp_size, ep_size, prompts, expect_list,
-                       qwen_model_path)
+    exec_model_with_dp(dp_size, tp_size, ep_size, pp_size, prompts,
+                       expect_list, qwen_model_path)
 
 
 @pytest.mark.level0
 @pytest.mark.platform_arm_ascend910b_training
 @pytest.mark.allcards
-def test_deepseek_r1_dp4_tp2_ep4():
+def test_vllm_qwen3_moe_30b_dp1_tp4_ep2_pp2():
     """
-    test case deepseek r1 w8a8 dp4 tp2 ep4
+    test case qwen3_moe_30B with DP1TP4EP2PP2
+    """
+    tp_size = 4
+    ep_size = 2
+    pp_size = 2
+    # Sample prompts.
+    prompts = [common_qwen_prompt]
+    expect_list = [common_qwen_expect_result]
+    exec_model_without_dp(tp_size, ep_size, pp_size, prompts, expect_list,
+                          qwen_model_path)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend910b_training
+@pytest.mark.allcards
+def test_deepseek_r1_dp4_tp2_ep4_pp1():
+    """
+    test case deepseek r1 w8a8 dp4 tp2 ep4 pp1
     """
     dp_size = 4
     tp_size = 2
     ep_size = 4
+    pp_size = 1
     # Sample prompts.
     prompts = [common_ds_prompt] * 4
     expect_list = [common_ds_expect_result] * 4
-    exec_model_with_dp(dp_size, tp_size, ep_size, prompts, expect_list,
-                       ds_model_path, quant_type)
+    exec_model_with_dp(dp_size, tp_size, ep_size, pp_size, prompts,
+                       expect_list, ds_model_path, quant_type)
 
 
 @pytest.mark.skip(
     reason=
     "Currently does not support relevant communication fusion operators in 910b"
 )
-def test_deepseek_r1_dp8_tp1_ep8():
+def test_deepseek_r1_dp8_tp1_ep8_pp1():
     """
-    test case deepseek r1 w8a8 Dp8 tp1 ep8
+    test case deepseek r1 w8a8 Dp8 tp1 ep8 pp1
     """
     dp_size = 8
     tp_size = 1
     ep_size = 8
+    pp_size = 1
     # Sample prompts.
     prompts = [common_ds_prompt] * 8
     expect_list = [common_ds_expect_result] * 8
-    exec_model_with_dp(dp_size, tp_size, ep_size, prompts, expect_list,
-                       ds_model_path, quant_type)
+    exec_model_with_dp(dp_size, tp_size, ep_size, pp_size, prompts,
+                       expect_list, ds_model_path, quant_type)
 
 
 @pytest.mark.level4
 @pytest.mark.platform_arm_ascend910b_training
 @pytest.mark.allcards
-def test_deepseek_r1_dp2_tp4_ep1():
+def test_deepseek_r1_dp2_tp4_ep1_pp1():
     """
-    test case deepseek r1 w8a8 dp2 tp4 ep1
+    test case deepseek r1 w8a8 dp2 tp4 ep1 pp1
     """
     dp_size = 2
     tp_size = 4
     ep_size = 1
+    pp_size = 1
     # Sample prompts.
     prompts = [common_ds_prompt] * 2
     expect_list = [common_ds_expect_result] * 2
-    exec_model_with_dp(dp_size, tp_size, ep_size, prompts, expect_list,
-                       ds_model_path, quant_type)
+    exec_model_with_dp(dp_size, tp_size, ep_size, pp_size, prompts,
+                       expect_list, ds_model_path, quant_type)
 
 
 @pytest.mark.skip(
     reason=
     "Currently does not support relevant communication fusion operators in 910b"
 )
-def test_deepseek_r1_dp4_tp2_ep8():
+def test_deepseek_r1_dp4_tp2_ep8_pp1():
     """
-    test case deepseek r1 w8a8 dp4 tp2 ep8
+    test case deepseek r1 w8a8 dp4 tp2 ep8 pp1
     """
     dp_size = 4
     tp_size = 2
     ep_size = 8
+    pp_size = 1
     # Sample prompts.
     prompts = [common_ds_prompt] * 4
     expect_list = [common_ds_expect_result] * 4
-    exec_model_with_dp(dp_size, tp_size, ep_size, prompts, expect_list,
-                       ds_model_path, quant_type)
+    exec_model_with_dp(dp_size, tp_size, ep_size, pp_size, prompts,
+                       expect_list, ds_model_path, quant_type)
 
 
 @pytest.mark.level4
 @pytest.mark.platform_arm_ascend910b_training
 @pytest.mark.allcards
-def test_deepseek_r1_dp8_tp1_ep1():
+def test_deepseek_r1_dp8_tp1_ep1_pp1():
     """
-    test case deepseek r1 w8a8 dp8 tp1 ep1
+    test case deepseek r1 w8a8 dp8 tp1 ep1 pp1
     """
     dp_size = 8
     tp_size = 1
     ep_size = 1
+    pp_size = 1
     # Sample prompts.
     prompts = [common_ds_prompt] * 8
     expect_list = [common_ds_expect_result] * 8
-    exec_model_with_dp(dp_size, tp_size, ep_size, prompts, expect_list,
-                       ds_model_path, quant_type)
+    exec_model_with_dp(dp_size, tp_size, ep_size, pp_size, prompts,
+                       expect_list, ds_model_path, quant_type)
 
 
 @pytest.mark.level4
 @pytest.mark.platform_arm_ascend910b_training
 @pytest.mark.allcards
-def test_deepseek_r1_dp8_tp1_ep4():
+def test_deepseek_r1_dp8_tp1_ep4_pp1():
     """
-    test case deepseek r1 w8a8 dp8 tp1 ep4
+    test case deepseek r1 w8a8 dp8 tp1 ep4 pp1
     """
     dp_size = 8
     tp_size = 1
     ep_size = 4
+    pp_size = 1
     # Sample prompts.
     prompts = [common_ds_prompt] * 8
     expect_list = [common_ds_expect_result] * 8
-    exec_model_with_dp(dp_size, tp_size, ep_size, prompts, expect_list,
-                       ds_model_path, quant_type)
+    exec_model_with_dp(dp_size, tp_size, ep_size, pp_size, prompts,
+                       expect_list, ds_model_path, quant_type)
 
 
 @pytest.mark.level4
 @pytest.mark.platform_arm_ascend910b_training
 @pytest.mark.allcards
-def test_deepseek_r1_tp8_ep8():
+def test_deepseek_r1_tp8_ep8_pp1():
     """
-    test case deepseek r1 w8a8 tp8 ep8
+    test case deepseek r1 w8a8 tp8 ep8 pp1
     """
     tp_size = 8
     ep_size = 8
+    pp_size = 1
     # Sample prompts.
     prompts = [common_ds_prompt]
     expect_list = [common_ds_expect_result]
-    exec_model_without_dp(tp_size, ep_size, prompts, expect_list,
+    exec_model_without_dp(tp_size, ep_size, pp_size, prompts, expect_list,
                           ds_model_path, quant_type)
 
 
 @pytest.mark.level4
 @pytest.mark.platform_arm_ascend910b_training
 @pytest.mark.allcards
-def test_deepseek_r1_tp8_ep4():
+def test_deepseek_r1_tp8_ep4_pp1():
     """
-    test case deepseek r1 w8a8 tp8 ep4
+    test case deepseek r1 w8a8 tp8 ep4 pp1
     """
     tp_size = 8
     ep_size = 4
+    pp_size = 1
     # Sample prompts.
     prompts = [common_ds_prompt]
     expect_list = [common_ds_expect_result]
-    exec_model_without_dp(tp_size, ep_size, prompts, expect_list,
+    exec_model_without_dp(tp_size, ep_size, pp_size, prompts, expect_list,
                           ds_model_path, quant_type)
diff --git a/vllm_mindspore/worker/worker.py b/vllm_mindspore/worker/worker.py
index 66afa92c..945a4abb 100644
--- a/vllm_mindspore/worker/worker.py
+++ b/vllm_mindspore/worker/worker.py
@@ -21,6 +21,7 @@
 import subprocess
 import psutil
 import torch
+from vllm.distributed import get_pp_group
 from vllm.logger import init_logger
 from vllm.model_executor import set_random_seed
 from vllm.sampling_params import SamplingParams
@@ -206,6 +207,18 @@ def _warm_up_model(self) -> None:
     kv_cache = self.cache_engine[0].gpu_cache
     is_mtp_model = self.speculative_config is not None and \
         self.model_config.hf_config.model_type == "deepseek_mtp"
+
+    def get_model(cls):
+        if cls.vllm_config.scheduler_config.is_multi_step:
+            return cls.model_runner._base_model_runner.model
+        return cls.model_runner.model
+
+    intermediate_tensors = None
+    model = get_model(self)
+    if not get_pp_group().is_first_rank:
+        intermediate_tensors = model.make_empty_intermediate_tensors(
+            batch_size=1, dtype=self.model_config.dtype, device=self.device)
+
     if is_mtp_model:
         # prefill mtp model
         model_input, previous_hidden_states = _prepare_input_for_warmup(
@@ -214,7 +227,7 @@ def _warm_up_model(self) -> None:
         self.model_runner.execute_model(
             model_input,
             kv_cache,
-            None,
+            intermediate_tensors,
             previous_hidden_states=previous_hidden_states)
 
         # warmup for decode
         if self.vllm_config.scheduler_config.is_multi_step:
             model_input = _prepare_input_for_warmup(
@@ -223,7 +236,7 @@ def _warm_up_model(self) -> None:
                 self.model_config, self.model_runner._base_model_runner,
                 self.cache_engine[0], False)
             self.model_runner._base_model_runner.execute_model(
-                model_input, kv_cache, None)
+                model_input, kv_cache, intermediate_tensors)
     else:
         model_input, previous_hidden_states = _prepare_input_for_warmup(
             self.model_config, self.model_runner, self.cache_engine[0], False,
@@ -231,7 +244,7 @@ def _warm_up_model(self) -> None:
         self.model_runner.execute_model(
             model_input,
             kv_cache,
-            None,
+            intermediate_tensors,
             previous_hidden_states=previous_hidden_states)
 
     torch.cuda.synchronize()
-- 
Gitee