diff --git a/boost/README.md b/boost/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..fe541329bec923d6831a266dbb1ef985c8ffad84
--- /dev/null
+++ b/boost/README.md
@@ -0,0 +1,47 @@
+# Usage
+
+## 1. Install vllm and vllm-ascend
+```bash
+# vllm
+git clone -b v0.9.1 https://github.com/vllm-project/vllm.git
+pip install -r vllm/requirements/build.txt -i https://mirrors.aliyun.com/pypi/simple/  # first remove torch==2.7 from vllm/requirements/build.txt
+pip install -r vllm/requirements/common.txt -i https://mirrors.aliyun.com/pypi/simple/
+cd vllm
+VLLM_TARGET_DEVICE=empty python setup.py develop
+
+
+# vllm-ascend
+git clone -b v0.9.1-dev https://github.com/vllm-project/vllm-ascend.git
+cd vllm-ascend
+git checkout 4014ad2a46e01c79fd8d98d6283404d0bc414dce
+pip install -v -e .
+```
+
+## 2. Install Megatron
+```bash
+git clone https://github.com/NVIDIA/Megatron-LM.git  # Megatron is cloned from GitHub; make sure your network can reach it
+cd Megatron-LM
+git checkout core_v0.12.1
+pip install .
+```
+
+## 3. Install verl
+```bash
+git clone https://github.com/volcengine/verl.git
+cd verl
+git checkout 503ea75f533f715cf1667551bd375ec9b1e0d5fc
+pip install -e .
+```
+
+## 4. Install the plugin
+```bash
+# Make sure vllm is installed correctly and will not be overwritten afterwards
+git clone https://gitee.com/ascend/MindSpeed-RL.git
+cd MindSpeed-RL/boost
+pip install -e .
+```
+
+**Note**: verl must be installed from source before installing the plugin; otherwise the plugin will not take effect. If verl cannot be installed from source, point the installer at the verl source tree instead:
+
+```bash
+VERL_PATH=path_to_verl pip install -e .
+```
diff --git a/boost/setup.py b/boost/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..50f3cee757112503e2894a2ef4e2db63afcd84a0
--- /dev/null
+++ b/boost/setup.py
@@ -0,0 +1,337 @@
+import os
+import sys
+import sysconfig
+import subprocess
+from setuptools import setup, find_packages
+from setuptools.command.develop import develop
+from setuptools.command.install import install
+
+
+def inject_verl_plugin(custom_path=None):
+    print("Starting verl plugin injection...")
+
+    if 'VERL_PATH' in os.environ:
+        verl_path = os.path.join(os.environ['VERL_PATH'], "verl")
+        print(f"Using verl path from environment variable: {verl_path}")
+    elif custom_path:
+        verl_path = custom_path
+        print(f"Using custom verl path: {verl_path}")
+    else:
+        print("Searching for verl package automatically...")
+        paths_to_try = [
+            sysconfig.get_paths()["purelib"],
+            sysconfig.get_paths()["platlib"],
+        ] + sys.path
+
+        verl_path = None
+        for path in paths_to_try:
+            if not path:
+                continue
+
+            candidate = os.path.join(path, "verl")
+            if os.path.exists(candidate) and os.path.isdir(candidate):
+                verl_path = candidate
+                break
+
+    if not verl_path:
+        try:
+            result = subprocess.run(
+                [sys.executable, "-m", "pip", "show", "verl"],
+                capture_output=True,
+                text=True,
+                check=True
+            )
+            for line in result.stdout.splitlines():
+                if line.startswith("Location:"):
+                    verl_path = os.path.join(line.split(": ")[1], "verl")
+                    break
+        except (subprocess.CalledProcessError, FileNotFoundError) as e:
+            print(f"pip show failed: {e}")
+
+    if not verl_path:
+        print("Error: verl package not found. 
Please specify with VERL_PATH environment variable.") + return False + + print(f"Found verl at: {verl_path}") + + init_file = os.path.join(verl_path, "__init__.py") + if not os.path.exists(init_file): + print(f"Error: verl initialization file not found: {init_file}") + return False + + import_content = """ +# NPU acceleration support added by mindspeed-rl plugin +from verl.utils.device import is_npu_available + +if is_npu_available: + import verl_npu + print("NPU acceleration enabled for verl") +""" + + try: + with open(init_file, "r") as f: + content = f.read() + except Exception as e: + print(f"Error reading {init_file}: {e}") + return False + + if import_content in content: + print(f"Info: {init_file} already contains NPU acceleration import") + else: + try: + with open(init_file, "a") as f: + f.write(import_content) + print(f"Successfully modified {init_file} to add NPU acceleration support") + except Exception as e: + print(f"Error writing to {init_file}: {e}") + return False + + linear_cross_entropy_file = os.path.join(verl_path, "utils", "kernel", "linear_cross_entropy.py") + if not os.path.exists(linear_cross_entropy_file): + print(f"Warning: linear_cross_entropy file not found: {linear_cross_entropy_file}") + return True + + line_to_comment = "from . import kernels" + + try: + with open(linear_cross_entropy_file, "r") as f: + lines = f.readlines() + + modified = False + new_lines = [] + for line in lines: + if line.strip() == line_to_comment: + new_lines.append(f"# {line}") + print(f"Commented out line in {linear_cross_entropy_file}: {line.strip()}") + modified = True + else: + new_lines.append(line) + + if modified: + with open(linear_cross_entropy_file, "w") as f: + f.writelines(new_lines) + print(f"Successfully modified {linear_cross_entropy_file}") + else: + already_commented = any(f"# {line_to_comment}" in line for line in lines) + if already_commented: + print(f"Info: line already commented in {linear_cross_entropy_file}") + else: + print(f"Warning: line to comment not found in {linear_cross_entropy_file}: {line_to_comment}") + + except Exception as e: + print(f"Error modifying {linear_cross_entropy_file}: {e}") + return False + + return True + + +def inject_vllm_plugin(): + print("Searching for vllm package automatically...") + paths_to_try = [ + sysconfig.get_paths()["purelib"], + sysconfig.get_paths()["platlib"], + ] + sys.path + + vllm_path = None + for path in paths_to_try: + if not path: + continue + + candidate = os.path.join(path, "vllm") + if os.path.exists(candidate) and os.path.isdir(candidate): + vllm_path = candidate + break + + if not vllm_path: + try: + result = subprocess.run( + [sys.executable, "-m", "pip", "show", "vllm"], + capture_output=True, + text=True, + check=True + ) + for line in result.stdout.splitlines(): + if line.startswith("Location:"): + vllm_path = os.path.join(line.split(": ")[1], "vllm") + break + except (subprocess.CalledProcessError, FileNotFoundError) as e: + print(f"pip show failed: {e}") + + if not vllm_path: + print("Error: vllm package not found. 
Please specify with VLLM_PATH environment variable.") + return False + + print(f"Found vllm at: {vllm_path}") + + fp8_utils_file = os.path.join(vllm_path, "model_executor", "layers", "quantization", "utils", "fp8_utils.py") + if not os.path.exists(fp8_utils_file): + print(f"Warning: linear_cross_entropy file not found: {fp8_utils_file}") + else: + line_to_change = "from typing import Any, Callable" + + try: + with open(fp8_utils_file, "r") as f: + lines = f.readlines() + + modified = False + new_lines = [] + for line in lines: + if line_to_change in line.strip() and 'List' not in line.strip(): + new_lines.append(f"{line[:-1]}, List\n") + print(f"Commented out line in {fp8_utils_file}: {line.strip()}") + modified = True + elif 'list' in line: + new_lines.append(line.replace('list','List')) + modified = True + else: + new_lines.append(line) + + if modified: + with open(fp8_utils_file, "w") as f: + f.writelines(new_lines) + print(f"Successfully modified {fp8_utils_file}") + else: + already_commented = any(f"List" in line for line in lines) + if already_commented: + print(f"Info: line already commented in {fp8_utils_file}") + else: + print(f"Warning: line to comment not found in {fp8_utils_file}: {line_to_change}") + except Exception as e: + print(f"Error modifying {fp8_utils_file}: {e}") + return False + + fused_moe_file = os.path.join(vllm_path, "model_executor", "layers", "fused_moe", "fused_moe.py") + if not os.path.exists(fused_moe_file): + print(f"Warning: linear_cross_entropy file not found: {fused_moe_file}") + else: + line_to_change = "from typing import Any, Callable" + + try: + with open(fused_moe_file, "r") as f: + lines = f.readlines() + + modified = False + new_lines = [] + for line in lines: + if line_to_change in line.strip() and 'List' not in line.strip(): + new_lines.append(f"{line[:-1]}, List\n") + print(f"Commented out line in {fused_moe_file}: {line.strip()}") + modified = True + elif 'list' in line: + new_lines.append(line.replace('list','List')) + modified = True + else: + new_lines.append(line) + + if modified: + with open(fused_moe_file, "w") as f: + f.writelines(new_lines) + print(f"Successfully modified {fused_moe_file}") + else: + already_commented = any(f"List" in line for line in lines) + if already_commented: + print(f"Info: line already commented in {fused_moe_file}") + else: + print(f"Warning: line to comment not found in {fused_moe_file}: {line_to_change}") + except Exception as e: + print(f"Error modifying {fused_moe_file}: {e}") + return False + + +def clone_and_install_mindspeed(): + repo_dir = os.path.join(os.getcwd(), "mindspeed_repo") + os.makedirs(repo_dir, exist_ok=True) + + if os.path.exists(repo_dir): + import shutil + shutil.rmtree(repo_dir) + + clone_cmd = ["git", "clone", "https://gitee.com/humphrey007/MindSpeed.git", repo_dir] + try: + subprocess.run(clone_cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + print(f"Successfully cloned MindSpeed repository to {repo_dir}") + except subprocess.CalledProcessError as e: + print(f"Error cloning MindSpeed repository: {e}") + print(f"Installing MindSpeed core from {repo_dir}...") + try: + subprocess.run(["python", "setup.py", "install"], cwd=repo_dir, check=True) + print("MindSpeed core installed successfully.") + except subprocess.CalledProcessError as e: + print(f"Error installing MindSpeed core: {e}") + + +class CustomInstallCommand(install): + def run(self): + super().run() + print("Running verl injection after standard install...") + custom_path = os.environ.get('VERL_PATH', None) + 
inject_verl_plugin(custom_path) + inject_vllm_plugin() + clone_and_install_mindspeed() + + +class CustomDevelopCommand(develop): + def run(self): + super().run() + print("Running verl injection after develop install...") + custom_path = os.environ.get('VERL_PATH', None) + inject_verl_plugin(custom_path) + inject_vllm_plugin() + + +def main(): + print("Setting up verl_npu...") + + custom_path = None + i = 0 + while i < len(sys.argv): + arg = sys.argv[i] + if arg.startswith('--verl-path='): + custom_path = arg.split('=', 1)[1] + sys.argv.pop(i) + break + elif arg == '--verl-path': + if i + 1 < len(sys.argv): + custom_path = sys.argv[i+1] + sys.argv.pop(i) + sys.argv.pop(i) + break + else: + print("Error: --verl-path requires a path argument") + sys.exit(1) + i += 1 + + setup( + name="verl_npu", + version="0.0.1", + author="verl-npu team", + license="Apache 2.0", + description="verl Ascend backend plugin", + long_description=open("README.md", encoding="utf-8").read(), + long_description_content_type="text/markdown", + packages=find_packages(), + classifiers=[ + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "License :: OSI Approved :: Apache Software License", + "Intended Audience :: Developers", + "Intended Audience :: Information Technology", + "Intended Audience :: Science/Research", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Scientific/Engineering :: Information Analysis", + ], + python_requires=">=3.9", + cmdclass={ + 'install': CustomInstallCommand, + 'develop': CustomDevelopCommand, + }, + ) + + if custom_path: + print("Running direct injection from command line argument...") + inject_verl_plugin(custom_path) + + +if __name__ == '__main__': + main() diff --git a/boost/verl_npu/__init__.py b/boost/verl_npu/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..099e9f22c1c5e57d0cd5d08e5895b550e2e50513 --- /dev/null +++ b/boost/verl_npu/__init__.py @@ -0,0 +1,11 @@ +from verl_npu.workers.rollout.vllm_rollout.vllm_rollout_spmd import patch_vllm_rollout_spmd +from verl_npu.models.mcore.registry import patch_mcore_registry + + +def adapt_verl_to_ascend(): + from mindspeed import megatron_adaptor + + patch_mcore_registry() + patch_vllm_rollout_spmd() + +adapt_verl_to_ascend() diff --git a/boost/verl_npu/models/__init__.py b/boost/verl_npu/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/boost/verl_npu/models/mcore/__init__.py b/boost/verl_npu/models/mcore/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/boost/verl_npu/models/mcore/model_forward.py b/boost/verl_npu/models/mcore/model_forward.py new file mode 100644 index 0000000000000000000000000000000000000000..1d73018c7ed84db62310f2fd274612020df31a13 --- /dev/null +++ b/boost/verl_npu/models/mcore/model_forward.py @@ -0,0 +1,65 @@ +# Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved. 
+ +# need this patch to patch megatron +from mindspeed import megatron_adaptor # noqa: F401 + +from verl.utils.megatron_utils import unwrap_model + +from verl.models.mcore.util import postprocess_packed_seqs, recover_left_padding, remove_left_padding +from .util import preprocess_packed_seqs + + +def gptmodel_forward( + model, + input_ids, + attention_mask, + position_ids, + sequence_parallel, + value_model=False, + pack_seqs=True, + logits_processor=None, + logits_processor_args: dict = None, + **kwargs, +): + """Default forward pass for GPT models with optional sequence packing.""" + pre_process = unwrap_model(model).pre_process + post_process = unwrap_model(model).post_process + if pack_seqs: + batch_size, seq_len = attention_mask.shape[:2] + input_ids_rmpad, packed_seq_params = preprocess_packed_seqs(input_ids, attention_mask, pre_process=pre_process) + input_ids_rmpad = input_ids_rmpad.contiguous() + output_orig = model( + input_ids=input_ids_rmpad, + attention_mask=None, + position_ids=position_ids, + packed_seq_params=packed_seq_params, + ) + if post_process and logits_processor is not None: + args = { + k: preprocess_packed_seqs(v, attention_mask, pre_process=True)[0] + for k, v in logits_processor_args.items() + } + output_dict = logits_processor(output_orig, **args) + output = { + k: postprocess_packed_seqs( + v, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process + ) + for k, v in output_dict.items() + } + else: + output = postprocess_packed_seqs( + output_orig, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process + ) + else: + assert logits_processor is None, "logits_processor is not supported for non-packed sequence" + batch_size, sequence_length = attention_mask.shape + new_input_ids, new_attention_mask, new_position_ids = remove_left_padding( + input_ids, attention_mask, position_ids, sequence_parallel, pre_process=pre_process + ) + output = model(input_ids=new_input_ids, attention_mask=new_attention_mask, position_ids=new_position_ids) + output = recover_left_padding( + output, new_attention_mask, attention_mask, sequence_length, post_process=post_process + ) + if value_model and post_process: + output = output[..., 0] + return output diff --git a/boost/verl_npu/models/mcore/registry.py b/boost/verl_npu/models/mcore/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..3903929e26743d18186aa9b7ed968eaa064bf3de --- /dev/null +++ b/boost/verl_npu/models/mcore/registry.py @@ -0,0 +1,33 @@ +# Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved. 
+ +from typing import Callable + +from .model_forward import gptmodel_forward +from verl.models.mcore.registry import SupportedModel, gptmodel_forward_qwen2_5_vl + + +# Registry for model forward functions +MODEL_FORWARD_REGISTRY: dict[SupportedModel, Callable] = { + SupportedModel.LLAMA: gptmodel_forward, + SupportedModel.QWEN2: gptmodel_forward, + SupportedModel.QWEN2_MOE: gptmodel_forward, + SupportedModel.MIXTRAL: gptmodel_forward, + SupportedModel.DEEPSEEK_V3: gptmodel_forward, + SupportedModel.QWEN2_5_VL: gptmodel_forward, + SupportedModel.LLAMA4: gptmodel_forward, + SupportedModel.QWEN3: gptmodel_forward, + SupportedModel.QWEN3_MOE: gptmodel_forward, + SupportedModel.QWEN2_5_VL: gptmodel_forward_qwen2_5_vl, + SupportedModel.DEEPSEEK_V3: gptmodel_forward, +} + + +def patch_mcore_registry(): + from verl.models import mcore + from verl_npu.patch_utils import apply_patches + + patch_list = [ + ("registry.MODEL_FORWARD_REGISTRY", MODEL_FORWARD_REGISTRY), + ] + + apply_patches(patch_list, mcore) diff --git a/boost/verl_npu/models/mcore/util.py b/boost/verl_npu/models/mcore/util.py new file mode 100644 index 0000000000000000000000000000000000000000..162e4b995cfd95e1803995be5976df1fc5844adb --- /dev/null +++ b/boost/verl_npu/models/mcore/util.py @@ -0,0 +1,97 @@ +# Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved. + +import torch +from megatron.core import parallel_state as mpu +from megatron.core.packed_seq_params import PackedSeqParams + + +def compute_qkv_index(seq_lens): + full_indices = list(range(seq_lens[-1])) + prev_eod_pos = 0 + kv_indices = [] + q_indices = [] + for eod_pos in seq_lens: + mid = (eod_pos + prev_eod_pos) // 2 + kv_indices.extend(full_indices[prev_eod_pos:mid]) + q_indices.extend(full_indices[mid:eod_pos]) + prev_eod_pos = eod_pos + + kv_index = torch.tensor(kv_indices).cuda(non_blocking=True) + q_index = torch.tensor(q_indices).cuda(non_blocking=True) + + return q_index, kv_index + + +def preprocess_packed_seqs( + input_ids: torch.Tensor, attention_mask: torch.Tensor, pre_process: bool = True +) -> tuple[torch.Tensor, PackedSeqParams]: + """ + Preprocess packed sequences + CP splits sequence into CP*2 chunks, and each GPU gets 2 chunks (GPU0 gets first and last chunks, GPU1 + gets second and second last chunks, and so on), this is for load balancing with causal masking. 
+ See https://github.com/NVIDIA/TransformerEngine/issues/1368 + """ + batch_size = input_ids.shape[0] + + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + tp_size = mpu.get_tensor_model_parallel_world_size() + cp_size = mpu.get_context_parallel_world_size() + cp_rank = mpu.get_context_parallel_rank() + align_size = tp_size * cp_size * 2 if cp_size > 1 else tp_size + + pad_size = (align_size - seqlens_in_batch % align_size) % align_size + seqlens_in_batch_padded = seqlens_in_batch + pad_size + cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device=input_ids.device) + cu_seqlens[1:] = torch.cumsum(seqlens_in_batch, dim=0) + cu_seqlens_padded = torch.zeros(batch_size + 1, dtype=torch.int32, device=input_ids.device) + cu_seqlens_padded[1:] = torch.cumsum(seqlens_in_batch_padded, dim=0) + max_seqlen_in_batch = seqlens_in_batch_padded.max().item() + + shape = list(input_ids.shape[1:]) + shape[0] = seqlens_in_batch_padded.sum().item() // cp_size + if pre_process: + input_ids_rmpad = torch.zeros(shape, dtype=input_ids.dtype, device=input_ids.device) + for i in range(batch_size): + if cp_size <= 1: + seqlen = seqlens_in_batch[i] + input_ids_rmpad[cu_seqlens_padded[i] : cu_seqlens_padded[i] + seqlen] = input_ids[i, attention_mask[i]] + continue + seqlen = seqlens_in_batch_padded[i] // cp_size + half_seqlen = seqlen // 2 + start_idx = cu_seqlens_padded[i] // cp_size + # split to 2 chunks + d = input_ids[i, attention_mask[i]] + input_ids_rmpad[start_idx : start_idx + half_seqlen] = d[ + half_seqlen * cp_rank : half_seqlen * (cp_rank + 1) + ] + + remain_start = seqlens_in_batch_padded[i] - half_seqlen * (cp_rank + 1) + remain_end = seqlens_in_batch_padded[i] - half_seqlen * cp_rank + remain_end = min(remain_end, d.shape[0]) + remain_len = remain_end - remain_start + if remain_len > 0: + input_ids_rmpad[start_idx + half_seqlen : start_idx + half_seqlen + remain_len] = d[ + remain_start:remain_end + ] + + packed_seq_params = PackedSeqParams( + qkv_format="thd", + cu_seqlens_q=cu_seqlens_padded, + max_seqlen_q=max_seqlen_in_batch, + cu_seqlens_kv=cu_seqlens_padded, + max_seqlen_kv=max_seqlen_in_batch, + cu_seqlens_q_padded=cu_seqlens_padded, + cu_seqlens_kv_padded=cu_seqlens_padded, + ) + + # patched for npu cp + cu_seqlens_padded_div_cp = cu_seqlens_padded // cp_size + q_index, kv_index = compute_qkv_index(cu_seqlens_padded_div_cp.clone().tolist()) + packed_seq_params.q_index = q_index + packed_seq_params.kv_index = kv_index + packed_seq_params.cu_seqlens_padded_div_cp = cu_seqlens_padded_div_cp + + if pre_process: + return input_ids_rmpad.unsqueeze(0), packed_seq_params + else: + return input_ids, packed_seq_params diff --git a/boost/verl_npu/patch_utils.py b/boost/verl_npu/patch_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7faff366ab41f80ef918846cd6501c9adbdf36da --- /dev/null +++ b/boost/verl_npu/patch_utils.py @@ -0,0 +1,37 @@ +# Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved. 
+ +import sys +import types + + +def apply_patches(monkey_patches, root_module): + def _getattr(module_list, root_module): + if len(module_list) <= 1: + return root_module + + if hasattr(root_module, module_list[0]): + return _getattr(module_list[1:], getattr(root_module, module_list[0])) + else: + empty_module_name = f"{root_module.__name__}.{module_list[0]}" + sys.modules[empty_module_name] = types.ModuleType(empty_module_name) + setattr(root_module, module_list[0], sys.modules.get(empty_module_name)) + return _getattr(module_list[1:], getattr(root_module, module_list[0])) + + for patch_pair in monkey_patches: + dest, patch = patch_pair + dest_module = _getattr(dest.split("."), root_module) + last_module_level = dest.split(".")[-1] + if not isinstance(patch, types.ModuleType): + setattr(dest_module, last_module_level, patch) + continue + + if not hasattr(dest_module, last_module_level) or not hasattr(patch, "__all__"): + setattr(dest_module, last_module_level, patch) + sys.modules[f"{dest_module.__name__}.{last_module_level}"] = patch + continue + + if not hasattr(patch, "__all__"): + raise NotImplementedError("Patch module must have __all__ definition.") + dest_module = getattr(dest_module, last_module_level) + for attr in patch.__all__: + setattr(dest_module, attr, getattr(patch, attr)) diff --git a/boost/verl_npu/workers/__init__.py b/boost/verl_npu/workers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/boost/verl_npu/workers/rollout/__init__.py b/boost/verl_npu/workers/rollout/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/boost/verl_npu/workers/rollout/vllm_rollout/__init__.py b/boost/verl_npu/workers/rollout/vllm_rollout/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/boost/verl_npu/workers/rollout/vllm_rollout/vllm_rollout_spmd.py b/boost/verl_npu/workers/rollout/vllm_rollout/vllm_rollout_spmd.py new file mode 100644 index 0000000000000000000000000000000000000000..eafb18169cab23ee668eb566c81ca1be94b7ada3 --- /dev/null +++ b/boost/verl_npu/workers/rollout/vllm_rollout/vllm_rollout_spmd.py @@ -0,0 +1,226 @@ +# Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved. 
+
+import os
+import torch
+import socket
+from copy import deepcopy
+from omegaconf import DictConfig, OmegaConf
+
+import torch.distributed as dist
+from vllm import LLM, SamplingParams
+import vllm.envs as envs
+from vllm.distributed import parallel_state as vllm_ps
+from verl.workers.rollout.base import BaseRollout
+
+
+def get_cluster_info():
+    # Make sure the distributed environment is initialized
+    if not dist.is_initialized():
+        raise RuntimeError("Distributed environment not initialized")
+
+    world_size = dist.get_world_size()
+
+    # Get the IP address of the current node
+    ip_address = _get_current_node_ip()
+
+    # Gather the IP addresses of all ranks
+    ip_list = [None] * world_size
+    dist.all_gather_object(ip_list, ip_address)
+
+    return ip_list
+
+
+def _get_current_node_ip() -> str:
+    try:
+        # Create a UDP socket (only used to look up the local interface)
+        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
+            # Connect to an external address (no real traffic is exchanged)
+            s.connect(("8.8.8.8", 80))  # Google DNS server
+            local_ip = s.getsockname()[0]
+    except Exception:
+        local_ip = _get_ip_by_ifname()
+        if not local_ip:
+            # If that also fails, fall back to iterating over the host's interfaces
+            local_ip = "127.0.0.1"
+            hostname = socket.gethostname()
+            for addr in socket.getaddrinfo(hostname, None):
+                ip = addr[4][0]
+                if not ip.startswith("::"):
+                    local_ip = ip
+                    break
+    return local_ip
+
+def _init_dp_envs(config):
+    rank = torch.distributed.get_rank()
+    world_size = int(config.get("rollout_world_size", 1))
+    # world_size = int(os.getenv("WORLD_SIZE", "-1"))
+    tp_size = int(config.get("tensor_model_parallel_size", 1))
+    dp_size = int(config.get("dp_model_parallel_size", 1))
+
+    all_ranks = torch.arange(world_size).reshape(-1, dp_size, 1, tp_size)  # noqa
+    group_ranks = all_ranks.transpose(1, 3).reshape(-1, dp_size).unbind(0)
+    group_ranks = [x.tolist() for x in group_ranks]
+    ip_list = get_cluster_info()
+    for index, group_rank in enumerate(group_ranks):
+        if torch.distributed.get_rank() in group_rank:
+            os.environ["VLLM_DP_MASTER_PORT"] = str(int(os.environ.get("MASTER_PORT")) + 1 + index)
+            os.environ["VLLM_DP_MASTER_IP"] = ip_list[group_rank[0]]
+    local_dp_rank = rank // tp_size % dp_size
+    os.environ["VLLM_DP_RANK"] = str(local_dp_rank)
+    os.environ["VLLM_DP_SIZE"] = str(dp_size)
+    os.environ["VLLM_PORT"] = os.environ["VLLM_DP_MASTER_PORT"]
+    envs.VLLM_DP_RANK = int(os.environ["VLLM_DP_RANK"])
+    envs.VLLM_DP_MASTER_IP = os.environ["VLLM_DP_MASTER_IP"]
+    envs.VLLM_DP_MASTER_PORT = int(os.environ["VLLM_DP_MASTER_PORT"])
+
+    print(f"[VLLM] using TP={tp_size}, DP={dp_size}", flush=True)
+
+
+def __init__(self, model_path: str, config: DictConfig, tokenizer, model_hf_config, **kwargs):
+    """A vLLM rollout. It requires that the module is supported by vLLM. 
+ + Args: + module: module here follows huggingface APIs + config: DictConfig + tokenizer: the task/model tokenizer + model_hf_config: the huggingface config to initiallize the generating model in vllm + **kwargs: train_tp, for Megatron Backend to initialize hybrid engine (zero redundancy) process group + """ + super(BaseRollout, self).__init__() + self.config = config + + tensor_parallel_size = self.config.get("tensor_model_parallel_size", 1) + assert tensor_parallel_size <= torch.distributed.get_world_size(), ( + "tensor parallel size should be less than or equal to the world size" + ) + max_num_batched_tokens = self.config.get("max_num_batched_tokens", 8192) + + if kwargs.get("train_tp") is not None: + # deployed with megatron + import os + + os.environ["CUDA_TIMER_STREAM_KAFKA_ENABLE"] = "0" + os.environ["MEGATRON_IMPORT_TIMERS"] = "0" + vllm_ps.initialize_model_parallel(tensor_model_parallel_size=tensor_parallel_size) + + rope_scaling_config = getattr(model_hf_config, "rope_scaling", None) + if not rope_scaling_config: + max_position_embeddings = None + if hasattr(model_hf_config, "max_position_embeddings"): + max_position_embeddings = model_hf_config.max_position_embeddings + elif hasattr(model_hf_config, "llm_config") and hasattr( + model_hf_config.llm_config, "max_position_embeddings" + ): + max_position_embeddings = model_hf_config.llm_config.max_position_embeddings + elif hasattr(model_hf_config, "text_config") and hasattr( + model_hf_config.text_config, "max_position_embeddings" + ): + max_position_embeddings = model_hf_config.text_config.max_position_embeddings + if max_position_embeddings is None: + raise ValueError("max_position_embeddings not found in model_hf_config") + assert max_position_embeddings >= config.prompt_length + config.response_length, ( + "model context length should be greater than total sequence length" + ) + else: + # handle type where there's a length extend factor + # see https://qwen.readthedocs.io/en/latest/deployment/vllm.html#extended-context-support + # for using yarn as an example + rope_scaling_factor = rope_scaling_config.get("factor", 1.0) + + assert ( + model_hf_config.max_position_embeddings * rope_scaling_factor + >= config.prompt_length + config.response_length + ), ( + "model context length should be greater than total sequence length, " + + f"got rope_scaling_factor={rope_scaling_factor} and " + + f"max_position_embeddings={model_hf_config.max_position_embeddings}" + ) + + max_model_len = int(config.max_model_len or config.prompt_length + config.response_length) + + if max_num_batched_tokens < max_model_len and self.config.enable_chunked_prefill: + raise ValueError( + "Enable chunked prefill, max_num_batched_tokens is smaller than max_model_len, \ + please increase max_num_batched_tokens or disable chunked prefill" + ) + + trust_remote_code = kwargs.get("trust_remote_code", False) + load_format = "dummy" if config.load_format.startswith("dummy") else config.load_format + + lora_kwargs = kwargs.pop("lora_kwargs", {}) + self.lora_kwargs = lora_kwargs + # copy it to avoid secretly modifying the engine config + engine_kwargs = ( + {} + if "engine_kwargs" not in config or "vllm" not in config.engine_kwargs + else OmegaConf.to_container(deepcopy(config.engine_kwargs.vllm)) + ) + # For each vLLM engine parameter, + # - `None` means not setting it, so we pop it, and leave it to vLLM default value + # (which can vary across different vLLM versions); + # - Otherwise it's the desired value we want to explicitly set. 
+ engine_kwargs = {key: val for key, val in engine_kwargs.items() if val is not None} + if config.get("limit_images", None): # support for multi-image data + engine_kwargs["limit_mm_per_prompt"] = {"image": config.get("limit_images")} + + # patch this for npu + enable_infer_ep = False + if hasattr(config, "dp_model_parallel_size") and config.dp_model_parallel_size > 1: + _init_dp_envs(config) + enable_infer_ep = True + + self.inference_engine = LLM( + model=model_path, + enable_sleep_mode=True, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend="external_launcher", + dtype=config.dtype, + enforce_eager=config.enforce_eager, + gpu_memory_utilization=config.gpu_memory_utilization, + disable_custom_all_reduce=False, + skip_tokenizer_init=False, + max_model_len=max_model_len, + load_format=load_format, + disable_log_stats=config.disable_log_stats, + max_num_batched_tokens=max_num_batched_tokens, + enable_chunked_prefill=config.enable_chunked_prefill, + enable_prefix_caching=True, + trust_remote_code=trust_remote_code, + enable_expert_parallel=enable_infer_ep, + seed=config.get("seed", 0), + **lora_kwargs, + **engine_kwargs, + ) + + # Offload vllm model to reduce peak memory usage + if config.free_cache_engine: + self.inference_engine.sleep(level=1) + + kwargs = dict( + n=1, + logprobs=0, # can be set to 0 and let actor to recompute + max_tokens=config.response_length, + ) + + kwargs["detokenize"] = False + + # supporting adding any sampling params from the config file + for k in config.keys(): + if hasattr(SamplingParams(), str(k)): + kwargs[k] = config.get(k) + kwargs["n"] = 1 # already repeat in ray_trainer + print(f"kwargs: {kwargs}") + self.sampling_params = SamplingParams(**kwargs) + + self.pad_token_id = tokenizer.pad_token_id + + +def patch_vllm_rollout_spmd(): + from verl.workers.rollout.vllm_rollout import vLLMRollout + from verl_npu.patch_utils import apply_patches + + patch_list = [ + ("__init__", __init__), + ] + + apply_patches(patch_list, vLLMRollout)
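
As a post-install sanity check, the following minimal sketch (assuming verl, vllm/vllm-ascend, and the `verl_npu` plugin were installed as described in `boost/README.md` on an Ascend NPU host) verifies that the code which `boost/setup.py` appends to `verl/__init__.py` actually loads the plugin:

```python
# Minimal verification sketch; assumes an NPU host with verl, vllm-ascend and verl_npu installed.
import sys

import verl  # the injected block in verl/__init__.py imports verl_npu when an NPU is available
from verl.utils.device import is_npu_available

if is_npu_available:
    # On success, the injected code also prints "NPU acceleration enabled for verl" at import time.
    assert "verl_npu" in sys.modules, "verl_npu plugin was not loaded by verl/__init__.py"
    print("verl_npu plugin is active")
else:
    print("No NPU detected; the plugin import is skipped by design")
```

If the assertion fails, re-run `pip install -e .` in `MindSpeed-RL/boost`, setting `VERL_PATH` to the verl source tree as described in the README.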