diff --git a/README.md b/README.md index 178164eabd6728ebb32d4fd85393b1f68150d238..c95dd38912c85500220cb717fa82d2f8693402bd 100644 --- a/README.md +++ b/README.md @@ -1,172 +1,152 @@ -# vllm_mindspore - -## 项目介绍 - -功能介绍 - -API - ---- - -## 前置依赖 - -### 运行环境 - -OS:linux aarch64 - -python:python3.9-3.11 - -device:Ascend A2/A3卡 - -软件: - -1. `CANN>=8.0.0` -2. `mindspore>=2.5.0` - -### 环境验证 - -`python -c "import mindspore;mindspore.set_context(device_target='Ascend');mindspore.run_check()"` - -> **若报错显示CANN不匹配,则需重新安装CANN包** -> -> 则可通过如下方法找到对应的CANN: -> -> 1. 找到包安装路径下的 `.commit_id` (如: `/your_path/site-packages/mindspore/.commit_id`), 可获取其对应的代码 commit记录,如: -> -> ``` -> __commit_id__ = '[sha1]:94cf8828,[branch]:(HEAD,origin/master,origin/HEAD,master)' -> ``` -> -> 2. 通过该 commit id 查找源码中对应的 `./.jenkins/task/config/cann_version.txt` 文件即可知配套 CANN 的归档日期,重新安装对应CANN。 -> -> -> -> **mindspore,CANN包推荐** -> -> 依赖的 CANN 和 MindSpore 仍处于开发态未正式发布,可通过 -> https://repo.mindspore.cn/mindspore/mindspore/version 获取每日构建版本,并安装对应的 CANN 配套环境。 -> -> **推荐获取: Milan_C20_20241231 的CANN,和 20250125 的每日 mindspore 包。** -> mindspore包地址: -> https://repo.mindspore.cn/mindspore/mindspore/version/202501/20250125/master_20250125160017_3f1def978242de1dda3ef0544e282b6ef369d165_newest/unified/aarch64/ -> -> CANN包地址: -> https://mindspore-repo.csi.rnd.huawei.com/productrepo - ---- - -## 安装 - -### 源码安装 - -```shell -git clone https://gitee.com/mindspore/vllm_mindspore.git -cd vllm_mindspore -# 1. 安装vllm(可选) -bash install_vllm.sh -# bash install_vllm.sh develop # 开发者模式 - -# 2. 安装vllm_mindspore -pip3 install . -# pip3 install -e . # 开发者模式 - -# 3. 卸载torch相关包 -pip3 uninstall torch torch-npu torchvision # 卸载 torch 相关包,当前msadapter带来的限制,后续清除 -``` - -> msadapter 需要申请仓权限: -> -> https://gitee.com/mindspore/msadapter - -### 通过镜像使用 - -```` -bash build_image.sh $DEVICE_TYPE $VERSION -# bash build_image.sh 800I 2.0.RC1.B020 -```` - -> DEVICE_TYPE可以取值`300I`、`800I`、`A3` -> -> VERSION取值为MindIE版本号,如2.0.RC1.B020 - ---- - -## 部署 - -1. 离线批量推理 - - ```python - import vllm_mindspore # Add this line on the top of script. - from vllm import LLM, SamplingParams - - # Sample prompts. - prompts = [ - "I am", - "Today is", - "Llama is" - ] - - # Create a sampling params object. - sampling_params = SamplingParams(temperature=0.0, top_p=0.95) - - # Create an LLM. - llm = LLM(model="meta-llama/Llama-2-7b-hf") - # Generate texts from the prompts. The output is a list of RequestOutput objects - # that contain the prompt, generated text, and other information. - outputs = llm.generate(prompts, sampling_params) - # Print the outputs. - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - ``` - - > **关于权重设置** - > - > 1. 在线权重需要设置HF_TOKEN - > - > `export HF_TOKEN=Your_token` - > - > 2. 本地权重设置 - > - > 如果已经有下载好的模型配置、权重等,将 `meta-llama/Llama-2-7b-hf` 替换为本地的路径即可。 - > - > - > - > - > **https请求失败时额外设置** - > - > 由于一些限制,在线下载在特定的服务器上需要通过安装较低版本的 requests 包 `requests-2.27.1`,且需要在脚本最上方添加如下代码: - > - > ```python - > import urllib3 - > import os - > # disable SSL certificate verification - > os.environ['CURL_CA_BUNDLE'] = '' - > # disable_warning - > urllib3.disable_warnings() - > ``` - -2. 
服务化(兼容openai) - - **拉起服务** - - `python3 -m vllm_mindspore.entrypoints vllm.entrypoints.openai.api_server --model "meta-llama/Llama-2-7b-hf"` - - **发起请求** - - ```shell - curl http://localhost:8000/v1/completions \ - -H "Content-Type: application/json" \ - -d '{ - "model": "meta-llama/Llama-2-7b-hf", - "prompt": "Llama is", - "max_tokens": 120, - "temperature": 0 - }' - ``` - - - - - +# vllm_mindspore + +## 项目介绍 + +功能介绍 + +API + +--- + +## 前置依赖 + +### 运行环境 + +OS:linux aarch64 + +python:python3.9-3.11 + +device:Ascend A2/A3卡 + +软件: + +1. `CANN>=8.0.0` +2. `mindspore>=2.5.0` + +### 环境验证 + +`python -c "import mindspore;mindspore.set_context(device_target='Ascend');mindspore.run_check()"` + +> **若报错显示CANN不匹配,则需重新安装CANN包** +> +> 则可通过如下方法找到对应的CANN: +> +> 1. 找到包安装路径下的 `.commit_id` (如: `/your_path/site-packages/mindspore/.commit_id`), 可获取其对应的代码 commit记录,如: +> +> ```bash +> __commit_id__ = '[sha1]:94cf8828,[branch]:(HEAD,origin/master,origin/HEAD,master)' +> ``` +> +> 2. 通过该 commit id 查找源码中对应的 `./.jenkins/task/config/cann_version.txt` 文件即可知配套 CANN 的归档日期,重新安装对应CANN。 +> **mindspore,CANN包推荐** +> 依赖的 CANN 和 MindSpore 仍处于开发态未正式发布,可通过 +> https://repo.mindspore.cn/mindspore/mindspore/version 获取每日构建版本,并安装对应的 CANN 配套环境。 +> **推荐获取: Milan_C20_20241231 的CANN,和 20250125 的每日 mindspore 包。** +> mindspore包地址: +> https://repo.mindspore.cn/mindspore/mindspore/version/202501/20250125/master_20250125160017_3f1def978242de1dda3ef0544e282b6ef369d165_newest/unified/aarch64/ +> CANN包地址: +> https://mindspore-repo.csi.rnd.huawei.com/productrepo + +--- + +## 安装 + +### 源码安装 + +```shell +git clone https://gitee.com/mindspore/vllm_mindspore.git +cd vllm_mindspore +# 1. 安装vllm(可选) +bash install_vllm.sh +# bash install_vllm.sh develop # 开发者模式 + +# 2. 安装vllm_mindspore +pip3 install . +# pip3 install -e . # 开发者模式 + +# 3. 卸载torch相关包 +pip3 uninstall torch torch-npu torchvision # 卸载 torch 相关包,当前msadapter带来的限制,后续清除 +``` + +> msadapter 需要申请仓权限: +> +> https://gitee.com/mindspore/msadapter + +### 通过镜像使用 + +````bash +bash build_image.sh $DEVICE_TYPE $VERSION +# bash build_image.sh 800I 2.0.RC1.B020 +```` + +> DEVICE_TYPE可以取值`300I`、`800I`、`A3` +> +> VERSION取值为MindIE版本号,如2.0.RC1.B020 +--- + +## 部署 + +1. 离线批量推理 + + ```python + import vllm_mindspore # Add this line on the top of script. + from vllm import LLM, SamplingParams + + # Sample prompts. + prompts = [ + "I am", + "Today is", + "Llama is" + ] + + # Create a sampling params object. + sampling_params = SamplingParams(temperature=0.0, top_p=0.95) + + # Create an LLM. + + llm = LLM(model="meta-llama/Llama-2-7b-hf") + + # Generate texts from the prompts. The output is a list of RequestOutput objects + # that contain the prompt, generated text, and other information. + outputs = llm.generate(prompts, sampling_params) + # Print the outputs. + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + ``` + + > **关于权重设置** + > 1. 在线权重需要设置HF_TOKEN + > `export HF_TOKEN=Your_token` + > 2. 本地权重设置 + > 如果已经有下载好的模型配置、权重等,将 `meta-llama/Llama-2-7b-hf` 替换为本地的路径即可。 + > **https请求失败时额外设置** + > 由于一些限制,在线下载在特定的服务器上需要通过安装较低版本的 requests 包 `requests-2.27.1`,且需要在脚本最上方添加如下代码: + > ```python + > import urllib3 + > import os + > # disable SSL certificate verification + > os.environ['CURL_CA_BUNDLE'] = '' + > # disable_warning + > urllib3.disable_warnings() + > ``` + +2. 
服务化(兼容openai) + + **拉起服务** + + `python3 -m vllm_mindspore.entrypoints vllm.entrypoints.openai.api_server --model "meta-llama/Llama-2-7b-hf"` + + **发起请求** + + ```shell + curl http://localhost:8000/v1/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "meta-llama/Llama-2-7b-hf", + "prompt": "Llama is", + "max_tokens": 120, + "temperature": 0 + }' + ``` diff --git a/setup.py b/setup.py index a3a865fa45f9e8f5b5c4bcef71ac5bf8ab664350..1a3be6a4781f3795fc06447cb1a587b73b7f09f4 100644 --- a/setup.py +++ b/setup.py @@ -37,7 +37,6 @@ def load_module_from_path(module_name, path): ROOT_DIR = os.path.dirname(__file__) logger = logging.getLogger(__name__) - if not sys.platform.startswith("linux"): logger.warning( "vllm_mindspore only supports Linux platform." @@ -84,19 +83,23 @@ def get_requirements() -> List[str]: def prepare_submodules() -> None: + """prepare submodules for vllm_mindspore.""" + def _run_cmd(args: str, check: bool = True) -> None: cmds = args.split(" ") returned = subprocess.run(cmds, stderr=subprocess.STDOUT) if check: returned.check_returncode() elif returned.returncode != 0: - logger.warning("Run %s error, please check!" % args) + # 将参数直接传递给 logger.warning + logger.warning("Run %s error, please check!", args) old_dir = os.getcwd() os.chdir(ROOT_DIR) _run_cmd("rm -rf vllm_mindspore/msadapter") - _run_cmd("git submodule update --init vllm_mindspore/msadapter", check=False) + _run_cmd("git submodule update --init vllm_mindspore/msadapter", + check=False) os.chdir(get_path("vllm_mindspore", "msadapter")) # Add __init__.py for packing. @@ -108,17 +111,14 @@ def prepare_submodules() -> None: prepare_submodules() - setup( name="vllm-mindspore", use_scm_version=True, setup_requires=["setuptools_scm"], author="MindSpore Team", license="Apache 2.0", - description=( - "A high-throughput and memory-efficient inference and " - "serving engine for LLMs" - ), + description=("A high-throughput and memory-efficient inference and " + "serving engine for LLMs"), long_description=read_readme(), long_description_content_type="text/markdown", url="https://gitee.com/mindspore/vllm_mindspore", diff --git a/vllm_mindspore/__init__.py b/vllm_mindspore/__init__.py index a972fd3968a356bfbc7ffa734596d4ec4226014f..1b3da6e4defa5463d2abd99eb0b54c6be732dc03 100644 --- a/vllm_mindspore/__init__.py +++ b/vllm_mindspore/__init__.py @@ -19,49 +19,64 @@ import os import sys import warnings +import vllm.attention +import vllm.config +import vllm.engine.async_llm_engine +import vllm.engine.llm_engine +import vllm.executor +import vllm.executor.ray_gpu_executor +import vllm.model_executor +import vllm.platforms +import vllm.utils +import vllm.worker.cache_engine +import vllm.worker.model_runner +import vllm.worker.worker +import vllm.distributed.parallel_state + +from vllm.worker.worker import Worker +from vllm.executor.ray_gpu_executor import RayGPUExecutor +from vllm_mindspore.attention.selector import get_ms_attn_backend +from vllm_mindspore.executor.multiproc_worker_utils import get_mp_context as ms_get_mp_context +from vllm_mindspore.executor.ray_gpu_executor import initialize_ray_cluster, ms_init_workers_ray +from vllm_mindspore.model_executor.model_loader.utils import get_ms_model_architecture +from vllm_mindspore.model_executor.model_loader.weight_utils import safetensors_weights_iterator +from vllm_mindspore.model_executor.models.registry import MindSporeModelRegistry, _run_in_subprocess +from vllm_mindspore.model_executor.sampling_metadata import SamplingMetadata, SamplingMetadataCache, 
SequenceGroupToSample +from vllm_mindspore.platforms.ascend import AscendPlatform +from vllm_mindspore.scripts import env_setup +from vllm_mindspore.utils import ( + ascend_device_count_stateless, + ascend_is_initialized, + async_tensor_h2d, + direct_register_custom_op, + get_dtype_size, + make_tensor_with_pad, + memory_profiling) +from vllm_mindspore.worker.cache_engine import cache_engine_init, ms_allocate_kv_cache, ms_swap_in, ms_swap_out +from vllm_mindspore.worker.model_runner import _get_cuda_graph_pad_size, profile_run +from vllm_mindspore.worker.worker import _warm_up_model, determine_num_available_blocks +from vllm_mindspore.distributed.parallel_state import all_reduce_for_GroupCoordinator, init_model_parallel_group +from .config import _verify_quantization, get_head_size +from .utils import check_ready + msadapter_path = os.path.abspath( os.path.join(os.path.dirname(__file__), "msadapter/mindtorch") ) sys.path.insert(0, msadapter_path) -from .version import __version__ - if "vllm" in sys.modules: - # Check models variable in sub process, cannot raise here. warnings.warn( - "vllm import before vllm_mindspore, vllm_mindspore cannot worker right!" + "vllm imported before vllm_mindspore; vllm_mindspore may not work correctly!" ) -from vllm_mindspore.scripts import env_setup - env_setup() -from vllm_mindspore.platforms.ascend import AscendPlatform - ascend_platform = AscendPlatform() -import vllm.config - vllm.config.current_platform = ascend_platform - -import vllm.platforms - vllm.platforms.current_platform = ascend_platform - -import vllm.utils - vllm.utils.current_platform = ascend_platform -from vllm_mindspore.utils import ( - direct_register_custom_op, - memory_profiling, - make_tensor_with_pad, - async_tensor_h2d, - get_dtype_size, - ascend_device_count_stateless, - ascend_is_initialized, -) - vllm.utils.direct_register_custom_op = direct_register_custom_op vllm.utils.memory_profiling = memory_profiling vllm.utils.make_tensor_with_pad = make_tensor_with_pad @@ -71,127 +86,49 @@ vllm.utils.cuda_device_count_stateless = ascend_device_count_stateless vllm.utils.cuda_is_initialized = ascend_is_initialized vllm.config.cuda_device_count_stateless = ascend_device_count_stateless -import vllm.executor - vllm.executor.cuda_device_count_stateless = ascend_device_count_stateless -from vllm_mindspore.model_executor.models.registry import ( - MindSporeModelRegistry, - _run_in_subprocess, -) - -import vllm.model_executor - vllm.model_executor.models.ModelRegistry = MindSporeModelRegistry vllm.config.ModelRegistry = MindSporeModelRegistry -from vllm_mindspore.model_executor.model_loader.utils import get_ms_model_architecture -from vllm.model_executor.model_loader import get_model_architecture - vllm.model_executor.model_loader.get_model_architecture = get_ms_model_architecture -vllm.model_executor.model_loader.utils.get_model_architecture = ( - get_ms_model_architecture -) -vllm.model_executor.model_loader.loader.get_model_architecture = ( - get_ms_model_architecture -) +vllm.model_executor.model_loader.utils.get_model_architecture = get_ms_model_architecture +vllm.model_executor.model_loader.loader.get_model_architecture = get_ms_model_architecture vllm.model_executor.models.registry._run_in_subprocess = _run_in_subprocess -from vllm_mindspore.model_executor.sampling_metadata import ( - SequenceGroupToSample, - SamplingMetadataCache, - SamplingMetadata, -) - vllm.model_executor.SamplingMetadataCache = SamplingMetadataCache vllm.model_executor.SamplingMetadata = SamplingMetadata 
vllm.model_executor.sampling_metadata.SequenceGroupToSample = SequenceGroupToSample vllm.model_executor.sampling_metadata.SamplingMetadataCache = SamplingMetadataCache vllm.model_executor.sampling_metadata.SamplingMetadata = SamplingMetadata -from vllm_mindspore.attention.selector import get_ms_attn_backend - -import vllm.attention - vllm.attention.get_attn_backend = get_ms_attn_backend -from vllm_mindspore.worker.cache_engine import ( - ms_allocate_kv_cache, - ms_swap_in, - ms_swap_out, - cache_engine_init -) - -import vllm.worker.cache_engine - vllm.worker.cache_engine.CacheEngine._allocate_kv_cache = ms_allocate_kv_cache vllm.worker.cache_engine.CacheEngine.__init__ = cache_engine_init vllm.worker.cache_engine.CacheEngine.swap_in = ms_swap_in vllm.worker.cache_engine.CacheEngine.swap_out = ms_swap_out -from vllm_mindspore.model_executor.model_loader.weight_utils import ( - safetensors_weights_iterator, -) - -vllm.model_executor.model_loader.loader.safetensors_weights_iterator = ( - safetensors_weights_iterator -) - -from vllm_mindspore.worker.worker import ( - _warm_up_model, - determine_num_available_blocks, -) -from vllm.worker.worker import Worker +vllm.model_executor.model_loader.loader.safetensors_weights_iterator = safetensors_weights_iterator Worker._warm_up_model = _warm_up_model Worker.determine_num_available_blocks = determine_num_available_blocks -from vllm_mindspore.worker.model_runner import _get_cuda_graph_pad_size, profile_run - -vllm.worker.model_runner.ModelInputForGPUBuilder._get_cuda_graph_pad_size = ( - _get_cuda_graph_pad_size -) +vllm.worker.model_runner.ModelInputForGPUBuilder._get_cuda_graph_pad_size = _get_cuda_graph_pad_size vllm.worker.model_runner.GPUModelRunnerBase.profile_run = profile_run -from vllm_mindspore.distributed.parallel_state import ( - all_reduce_for_GroupCoordinator, - init_model_parallel_group, -) - -vllm.distributed.parallel_state.GroupCoordinator.all_reduce = ( - all_reduce_for_GroupCoordinator -) +vllm.distributed.parallel_state.GroupCoordinator.all_reduce = all_reduce_for_GroupCoordinator vllm.distributed.parallel_state.init_model_parallel_group = init_model_parallel_group -from vllm_mindspore.executor.multiproc_worker_utils import ( - get_mp_context as ms_get_mp_context, -) -from vllm.executor.multiproc_worker_utils import get_mp_context - vllm.executor.multiproc_worker_utils.get_mp_context = ms_get_mp_context -from vllm_mindspore.executor.ray_gpu_executor import ( - ms_init_workers_ray, - initialize_ray_cluster, -) - -from vllm.executor.ray_gpu_executor import RayGPUExecutor - RayGPUExecutor._init_workers_ray = ms_init_workers_ray - vllm.executor.ray_utils.initialize_ray_cluster = initialize_ray_cluster -import vllm.engine.llm_engine -import vllm.engine.async_llm_engine - vllm.engine.llm_engine.initialize_ray_cluster = initialize_ray_cluster vllm.engine.async_llm_engine.initialize_ray_cluster = initialize_ray_cluster - -from .config import get_head_size, _verify_quantization vllm.config.ModelConfig.get_head_size = get_head_size vllm.config.ModelConfig._verify_quantization = _verify_quantization -from .utils import check_ready - check_ready() diff --git a/vllm_mindspore/attention/__init__.py b/vllm_mindspore/attention/__init__.py index a523461741c88a476b964f76d43f73c19ed201db..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 100644 --- a/vllm_mindspore/attention/__init__.py +++ b/vllm_mindspore/attention/__init__.py @@ -1 +0,0 @@ -from .layer import Attention diff --git a/vllm_mindspore/attention/backends/ms_attn.py 
b/vllm_mindspore/attention/backends/ms_attn.py index f01c1517f9385f93ddb95880344e2dca192e8784..76f706f12da373ac856421c5bdad439dbf0f1a98 100644 --- a/vllm_mindspore/attention/backends/ms_attn.py +++ b/vllm_mindspore/attention/backends/ms_attn.py @@ -21,35 +21,24 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type import torch - -from vllm.attention.backends.abstract import ( - AttentionBackend, - AttentionImpl, - AttentionMetadata, - AttentionMetadataBuilder, - AttentionType, - AttentionState, -) - -if TYPE_CHECKING: - from vllm.worker.model_runner import ModelInputForGPUBuilder - -from vllm.utils import make_tensor_with_pad -from vllm.attention.backends.utils import ( - compute_slot_mapping, - compute_slot_mapping_start_idx, - is_block_tables_empty, -) +import mindspore as ms +from mindspore import mutable +from mindspore._c_expression import swap_cache +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionMetadata, + AttentionMetadataBuilder, + AttentionState, AttentionType) +from vllm.attention.backends.utils import (compute_slot_mapping, + compute_slot_mapping_start_idx, + is_block_tables_empty) from vllm.multimodal import MultiModalPlaceholderMap - +from vllm.utils import make_tensor_with_pad from vllm_mindspore.attention.backends.utils import MsAttentionState from vllm_mindspore.attention.ops.paged_attn import PagedAttentionMetadata - from vllm_mindspore.utils import MsKVCache -import mindspore as ms -from mindspore import mutable -from mindspore._c_expression import swap_cache +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUBuilder @dataclass @@ -100,8 +89,8 @@ class MSAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): return self def get_seq_lens( - self, - attn_type: str, + self, + attn_type: str, ): """ Extract appropriate sequence lengths from attention metadata @@ -119,8 +108,8 @@ class MSAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): """ if ( - attn_type == AttentionType.DECODER - or attn_type == AttentionType.ENCODER_ONLY + attn_type == AttentionType.DECODER + or attn_type == AttentionType.ENCODER_ONLY ): seq_lens_q = self.seq_lens seq_lens_kv = self.seq_lens @@ -135,16 +124,19 @@ class MSAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): return seq_lens_q, seq_lens_kv def get_seq_len_block_table_args( - self, - attn_type: str, + self, + attn_type: str, ) -> tuple: if ( - attn_type == AttentionType.DECODER - or attn_type == AttentionType.ENCODER_ONLY + attn_type == AttentionType.DECODER + or attn_type == AttentionType.ENCODER_ONLY ): # Decoder self-attention # Choose max_seq_len based on whether we are in prompt_run - return (self.seq_lens_tensor, self.max_decode_seq_len, self.block_tables) + return ( + self.seq_lens_tensor, + self.max_decode_seq_len, + self.block_tables) elif attn_type == AttentionType.ENCODER_DECODER: # Enc/dec cross-attention KVs match encoder sequence length; # cross-attention utilizes special "cross" block tables @@ -155,12 +147,21 @@ class MSAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): ) elif attn_type == AttentionType.ENCODER: # No block tables associated with encoder attention - return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len, None) + return ( + self.encoder_seq_lens_tensor, + self.max_encoder_seq_len, + None) else: raise AttributeError(f"Invalid attention type {str(attn_type)}") def keys(self): - return ["num_prefill_tokens", "num_decode_tokens", "slot_mapping", 
"batch_valid_length", "context_lens", "block_tables"] + return [ + "num_prefill_tokens", + "num_decode_tokens", + "slot_mapping", + "batch_valid_length", + "context_lens", + "block_tables"] def __getitem__(self, key): if key == "context_lens": @@ -173,7 +174,9 @@ class MSAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): return mutable(getattr(self, key)) return mutable(getattr(self, key)) -class MsAttentionMetadataBuilder(AttentionMetadataBuilder[MSAttentionMetadata]): + +class MsAttentionMetadataBuilder( + AttentionMetadataBuilder[MSAttentionMetadata]): def __init__(self, input_builder: "ModelInputForGPUBuilder"): self.slot_mapping: List[int] = [] @@ -182,8 +185,7 @@ class MsAttentionMetadataBuilder(AttentionMetadataBuilder[MSAttentionMetadata]): self.block_tables: List[List[int]] = [] self.curr_seq_lens: List[int] = [] self.multimodal_placeholder_maps: Dict[str, MultiModalPlaceholderMap] = ( - defaultdict(MultiModalPlaceholderMap) - ) + defaultdict(MultiModalPlaceholderMap)) self.num_prefills = 0 self.num_prefill_tokens = 0 self.num_decode_tokens = 0 @@ -195,10 +197,10 @@ class MsAttentionMetadataBuilder(AttentionMetadataBuilder[MSAttentionMetadata]): self.block_size = input_builder.block_size def _add_seq_group( - self, - inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", - chunked_prefill_enabled: bool, - prefix_cache_hit: bool, + self, + inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", + chunked_prefill_enabled: bool, + prefix_cache_hit: bool, ): """Add a sequence group to the metadata. Specifically update/append 1. context length. @@ -209,13 +211,13 @@ class MsAttentionMetadataBuilder(AttentionMetadataBuilder[MSAttentionMetadata]): block_tables = inter_data.block_tables for ( - seq_id, - token_len, - seq_len, - curr_seq_len, - query_len, - context_len, - curr_sliding_window_block, + seq_id, + token_len, + seq_len, + curr_seq_len, + query_len, + context_len, + curr_sliding_window_block, ) in zip( inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], @@ -231,7 +233,8 @@ class MsAttentionMetadataBuilder(AttentionMetadataBuilder[MSAttentionMetadata]): mm_maps = inter_data.multi_modal_placeholder_maps if mm_maps: for modality, placeholders in mm_maps.items(): - self.multimodal_placeholder_maps[modality].extend(placeholders) + self.multimodal_placeholder_maps[modality].extend( + placeholders) self.num_prefills += 1 self.num_prefill_tokens += token_len @@ -250,7 +253,7 @@ class MsAttentionMetadataBuilder(AttentionMetadataBuilder[MSAttentionMetadata]): # include the entries for the incoming prefill tokens. block_table = block_tables[seq_id] elif ( - chunked_prefill_enabled or not is_prompt + chunked_prefill_enabled or not is_prompt ) and block_tables is not None: if curr_sliding_window_block == 0: block_table = block_tables[seq_id] @@ -275,20 +278,25 @@ class MsAttentionMetadataBuilder(AttentionMetadataBuilder[MSAttentionMetadata]): ) def build( - self, - seq_lens: List[int], - query_lens: List[int], - cuda_graph_pad_size: int, - batch_size: int, + self, + seq_lens: List[int], + query_lens: List[int], + cuda_graph_pad_size: int, + _batch_size: int, ): - """Build attention metadata with on-device tensors. + """ + Build attention metadata with on-device tensors. Args: + seq_lens: The maybe padded sequence lengths of the input sequences. + query_lens: The query lengths of the input sequences. + cuda_graph_pad_size: The padding size for cuda graph. -1 if cuda graph is not used. - batch_size: The maybe padded batch size. 
+ + _batch_size: The maybe padded batch size. """ prefix_cache_hit = any( [ @@ -298,8 +306,9 @@ class MsAttentionMetadataBuilder(AttentionMetadataBuilder[MSAttentionMetadata]): ) for inter_data in self.input_builder.inter_data_list: self._add_seq_group( - inter_data, self.input_builder.chunked_prefill_enabled, prefix_cache_hit - ) + inter_data, + self.input_builder.chunked_prefill_enabled, + prefix_cache_hit) device = self.runner.device use_captured_graph = cuda_graph_pad_size != -1 @@ -309,8 +318,9 @@ class MsAttentionMetadataBuilder(AttentionMetadataBuilder[MSAttentionMetadata]): num_decode_tokens = self.num_decode_tokens if use_captured_graph: - # TODO(tronzhang): Maybe here only turn graph mode on , and go with then same condition branch logic? - raise RuntimeError("Doesnot support captured graph now!") + # TODO(tronzhang): Maybe here only turn graph mode on , and go with + # then same condition branch logic? + raise RuntimeError("Doesn't support captured graph now!") else: block_tables = make_tensor_with_pad( self.block_tables, @@ -361,10 +371,10 @@ class MsAttentionBackend(AttentionBackend): @staticmethod def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, ) -> Tuple[int, ...]: if block_size % 16 != 0: raise ValueError("Block size must be a multiple of 16.") @@ -372,10 +382,10 @@ class MsAttentionBackend(AttentionBackend): @staticmethod def swap_blocks( - src_kv_cache: MsKVCache, - dst_kv_cache: MsKVCache, - src_to_dst: torch.Tensor, - swap_type: bool, + src_kv_cache: MsKVCache, + dst_kv_cache: MsKVCache, + src_to_dst: torch.Tensor, + swap_type: bool, ) -> None: """ Swap key/value cache between host and device, to support multi-batch and long-sequence inference. @@ -395,10 +405,11 @@ class MsAttentionBackend(AttentionBackend): @staticmethod def copy_blocks( - kv_caches: List[MsKVCache], - src_to_dists: torch.Tensor, + kv_caches: List[MsKVCache], + src_to_dists: torch.Tensor, ) -> None: - # TODO(tronzhang): this may be slow, a faster interface should be implemented by custom op! + # TODO(tronzhang): this may be slow, a faster interface should be + # implemented by custom op! 
blocks_to_copy = src_to_dists.asnumpy().tolist() for kv_cache in kv_caches: npu_key_block, npu_value_block = kv_cache @@ -434,30 +445,30 @@ class MsAttentionImpl(AttentionImpl): """ def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[List[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - blocksparse_params: Optional[Dict[str, Any]] = None, - logits_soft_cap: Optional[float] = None, + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, ) -> None: pass def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: MSAttentionMetadata, - k_scale: float = 1.0, - v_scale: float = 1.0, - attn_type: str = AttentionType.DECODER, - output: Optional[torch.Tensor] = None, + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: MSAttentionMetadata, + k_scale: float = 1.0, + v_scale: float = 1.0, + attn_type: str = AttentionType.DECODER, + output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention. @@ -472,7 +483,6 @@ class MsAttentionImpl(AttentionImpl): attn_metadata: Metadata for attention. NOTE: It in-place updates the output tensor. """ - pass class MLABackend(AttentionBackend): @@ -498,28 +508,27 @@ class MLABackend(AttentionBackend): @staticmethod def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, # assumed to be 1 for MLA - head_size: int, + num_blocks: int, + block_size: int, + num_kv_heads: int, # assumed to be 1 for MLA + head_size: int, ) -> Tuple[int, ...]: return (1, num_blocks, block_size, 1, head_size) @staticmethod def swap_blocks( - src_kv_cache: torch.Tensor, - dst_kv_cache: torch.Tensor, - src_to_dst: torch.Tensor, + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, ) -> None: src_key_cache = src_kv_cache[0] dst_key_cache = dst_kv_cache[0] swap_cache(src_key_cache, dst_key_cache, src_to_dst) - @staticmethod def copy_blocks( - kv_caches: List[torch.Tensor], - src_to_dists: torch.Tensor, + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, ) -> None: blocks_to_copy = src_to_dists.asnumpy().tolist() for kv_cache in kv_caches: diff --git a/vllm_mindspore/attention/backends/utils.py b/vllm_mindspore/attention/backends/utils.py index 00970f8f5682c68e4da8cecdfcbb496a8e989a9f..b8af257cb6276f14ecd795133d447022d94003de 100644 --- a/vllm_mindspore/attention/backends/utils.py +++ b/vllm_mindspore/attention/backends/utils.py @@ -35,7 +35,7 @@ class MsAttentionState(AttentionState): return def get_graph_input_buffers( - self, attn_metadata, is_encoder_decoder_model: bool = False + self, attn_metadata, is_encoder_decoder_model: bool = False ): """Get attention-specific input buffers for CUDA graph capture.""" ... @@ -46,13 +46,15 @@ class MsAttentionState(AttentionState): yield def graph_capture_get_metadata_for_batch( - self, batch_size: int, is_encoder_decoder_model: bool = False + self, batch_size: int, is_encoder_decoder_model: bool = False ): ... def graph_clone(self, batch_size: int): ... 
def prepare_graph_input_buffers( - self, input_buffers, attn_metadata, is_encoder_decoder_model: bool = False - ) -> None: + self, + input_buffers, + attn_metadata, + is_encoder_decoder_model: bool = False) -> None: """In-place modify input buffers dict for CUDA graph replay.""" ... diff --git a/vllm_mindspore/attention/layer.py b/vllm_mindspore/attention/layer.py index 01eacca3a2a62986c639796db8b4cc7e417a2892..c50473aa2710f5aa1f68bdc01feac7ad937793bf 100644 --- a/vllm_mindspore/attention/layer.py +++ b/vllm_mindspore/attention/layer.py @@ -17,37 +17,31 @@ """Common layer for LLM.""" from typing import Any, Dict, List, Optional, Tuple -from mindspore import Tensor, mint, nn, ops, jit +from mindspore import Tensor, jit, mint, nn, ops from mindspore.common import dtype as mstype from mindspore.ops.auto_generate import PagedAttention, ReshapeAndCache from mindspore.ops.operations.nn_ops import FlashAttentionScore - -from vllm.config import CacheConfig from vllm.attention.backends.abstract import AttentionType -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.config import CacheConfig +from vllm.model_executor.layers.quantization.base_config import \ + QuantizationConfig def _pad_to_max_tensor( - input_: Tensor, - max_len: int, - dim: int = 0, - pad_value: int = -1 + input_: Tensor, max_len: int, dim: int = 0, pad_value: int = -1 ) -> Tensor: """Temporary function, will be deprecated in the future.""" if input_.shape[dim] == max_len: return input_ - pad_shape = (input_.shape[0], max_len - input_.shape[dim], *input_.shape[dim + 1:]) + pad_shape = (input_.shape[0], + max_len - input_.shape[dim], + *input_.shape[dim + 1:]) pad_tensor = mint.ones(size=pad_shape, dtype=input_.dtype) * pad_value output = mint.cat([input_, pad_tensor], dim=dim) return output -def _generate_attn_mask( - query: Tensor, - value: Tensor, - flatten: bool -) -> Tensor: +def _generate_attn_mask(query: Tensor, value: Tensor, flatten: bool) -> Tensor: """Temporary function, will be deprecated in the future.""" if flatten: return mint.triu(mint.ones(size=(128, 128), dtype=query.dtype), 1) @@ -59,9 +53,8 @@ def _generate_attn_mask( def _hidden_states_th2bsh( - input_: Tensor, - batch_valid_length: Tensor -) -> Tensor: + input_: Tensor, + batch_valid_length: Tensor) -> Tensor: """Temporary function, will be deprecated in the future.""" max_seq_len = batch_valid_length.max().item() start_pos = 0 @@ -76,13 +69,12 @@ def _hidden_states_th2bsh( def _hidden_states_bsh2th( - input_: Tensor, - batch_valid_length: Tensor -) -> Tensor: + input_: Tensor, + batch_valid_length: Tensor) -> Tensor: """Temporary function, will be deprecated in the future.""" unpadded_input_list = [] for batch_index, valid_length in enumerate(batch_valid_length): - padded_input = input_[batch_index:batch_index + 1] + padded_input = input_[batch_index: batch_index + 1] unpadded_input = padded_input[:, :valid_length, ...] 
unpadded_input_list.append(unpadded_input) th_output = mint.cat(unpadded_input_list, dim=1) @@ -102,21 +94,21 @@ class Attention(nn.Cell): """ def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: Optional[int] = None, - alibi_slopes: Optional[List[float]] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - blocksparse_params: Optional[Dict[str, Any]] = None, - logits_soft_cap: Optional[float] = None, - per_layer_sliding_window: Optional[int] = None, - use_mla: bool = False, - prefix: str = "", - attn_type: str = AttentionType.DECODER, - **extra_impl_args, + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: Optional[int] = None, + _alibi_slopes: Optional[List[float]] = None, + _cache_config: Optional[CacheConfig] = None, + _quant_config: Optional[QuantizationConfig] = None, + _blocksparse_params: Optional[Dict[str, Any]] = None, + _logits_soft_cap: Optional[float] = None, + _per_layer_sliding_window: Optional[int] = None, + _use_mla: bool = False, + _prefix: str = "", + attn_type: str = AttentionType.DECODER, + **_extra_impl_args, ) -> None: super().__init__() if attn_type != AttentionType.DECODER: @@ -127,8 +119,8 @@ class Attention(nn.Cell): self.num_heads = num_heads self.num_kv_heads = num_kv_heads self.head_size = head_size - self.hidden_size_per_partition = num_heads*head_size - self.kv_hidden_size_per_partition = num_kv_heads*head_size + self.hidden_size_per_partition = num_heads * head_size + self.kv_hidden_size_per_partition = num_kv_heads * head_size self.flatten = True input_layout = "TH" if self.flatten else "BSH" # pynative 下不支持拉平操作。 @@ -148,21 +140,21 @@ class Attention(nn.Cell): @jit def construct( - self, - query: Tensor, - key: Tensor, - value: Tensor, - kv_cache: Tuple[Tensor, Tensor], - # attn_metadata: MSMetadata, - num_prefill_tokens: int, - num_decode_tokens: int, - slot_mapping: Tensor, - batch_valid_length: Tuple[int], - context_lens: Tensor, - block_tables: Tensor, - attn_mask: Tensor, + self, + query: Tensor, + key: Tensor, + value: Tensor, + kv_cache: Tuple[Tensor, Tensor], + # attn_metadata: MSMetadata, + num_prefill_tokens: int, + num_decode_tokens: int, + slot_mapping: Tensor, + batch_valid_length: Tuple[int], + context_lens: Tensor, + block_tables: Tensor, + attn_mask: Tensor, ) -> Tensor: - """Attention foward, support MHA and GQA. + """Attention forward, support MHA and GQA. 
Args: query: shape = [1, num_tokens, hidden_size] @@ -175,22 +167,25 @@ class Attention(nn.Cell): """ output = query key_cache, value_cache = kv_cache - cache_out = self.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping) + cache_out = self.reshape_and_cache( + key, value, key_cache, value_cache, slot_mapping) query = ops.depend(query, cache_out) if num_prefill_tokens > 0: - output = self._run_prefill_forward(query, key, value, attn_mask, batch_valid_length, batch_valid_length) + output = self._run_prefill_forward( + query, key, value, attn_mask, batch_valid_length, batch_valid_length) if num_decode_tokens > 0: - output = self._run_decode_forward(query, key_cache, value_cache, block_tables, context_lens) + output = self._run_decode_forward( + query, key_cache, value_cache, block_tables, context_lens) return output def _run_prefill_forward( - self, - query: Tensor, - key: Tensor, - value: Tensor, - attn_mask: Tensor, - actual_seq_qlen: Tuple[int], - actual_seq_kvlen: Tuple[int], + self, + query: Tensor, + key: Tensor, + value: Tensor, + attn_mask: Tensor, + actual_seq_qlen: Tuple[int], + actual_seq_kvlen: Tuple[int], ) -> Tensor: """Prefill with FlashAttention. @@ -219,12 +214,12 @@ class Attention(nn.Cell): return output def _run_decode_forward( - self, - query: Tensor, - key_cache: Tensor, - value_cache: Tensor, - block_tables: Tensor, - context_lens: Tensor, + self, + query: Tensor, + key_cache: Tensor, + value_cache: Tensor, + block_tables: Tensor, + context_lens: Tensor, ) -> Tensor: """Decode with PagedAttention. @@ -235,5 +230,10 @@ class Attention(nn.Cell): block_tables: shape = [block_size, num_block] context_lens: shape = [batch_size, ] """ - output = self.paged_attention(query, key_cache, value_cache, block_tables, context_lens) + output = self.paged_attention( + query, + key_cache, + value_cache, + block_tables, + context_lens) return output diff --git a/vllm_mindspore/attention/ops/paged_attn.py b/vllm_mindspore/attention/ops/paged_attn.py index 57f58db6f9430a9d02894641786166381f9aa237..87704aaaef25082b6236bb8e9d608ab99bc24caf 100644 --- a/vllm_mindspore/attention/ops/paged_attn.py +++ b/vllm_mindspore/attention/ops/paged_attn.py @@ -20,7 +20,6 @@ from dataclasses import dataclass from typing import List, Optional, Tuple import torch - from vllm import _custom_ops as ops from vllm.triton_utils import HAS_TRITON @@ -30,6 +29,7 @@ if HAS_TRITON: # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. _PARTITION_SIZE = 512 + # TODO(tronzhang): delete all not work codes. 
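The `slot_mapping` and `block_tables` produced by `MsAttentionMetadataBuilder` and consumed by `ReshapeAndCache`/`PagedAttention` above follow the paged-KV-cache bookkeeping used by vLLM. Below is a minimal, framework-free sketch of that indexing; the block-table values are hypothetical and it assumes the usual vLLM convention that a flat cache slot is `physical_block * block_size + offset_within_block`:

```python
# Sketch only: illustrates the slot indexing that slot_mapping encodes.
# Hypothetical values; assumes slot = physical_block * block_size + offset.
from typing import List


def slot_mapping_for_positions(block_table: List[int],
                               positions: List[int],
                               block_size: int) -> List[int]:
    """Map token positions of one sequence to flat KV-cache slot indices."""
    slots = []
    for pos in positions:
        logical_block = pos // block_size          # which logical block holds this token
        offset = pos % block_size                  # position inside that block
        physical_block = block_table[logical_block]  # logical -> physical block id
        slots.append(physical_block * block_size + offset)
    return slots


if __name__ == "__main__":
    # A sequence whose logical blocks 0..2 live in physical blocks 7, 3, 9.
    block_table = [7, 3, 9]
    block_size = 16
    # Decode step for token position 33 -> logical block 2, offset 1 -> slot 145.
    print(slot_mapping_for_positions(block_table, [33], block_size))  # [145]
```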
@@ -59,38 +59,39 @@ class PagedAttention: @staticmethod def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, ) -> Tuple[int, ...]: return (2, num_blocks, block_size * num_kv_heads * head_size) @staticmethod def split_kv_cache( - kv_cache: torch.Tensor, - num_kv_heads: int, - head_size: int, + kv_cache: torch.Tensor, + num_kv_heads: int, + head_size: int, ) -> Tuple[torch.Tensor, torch.Tensor]: x = 16 // kv_cache.element_size() num_blocks = kv_cache.shape[1] key_cache = kv_cache[0] - key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x, -1, x) + key_cache = key_cache.view( + num_blocks, num_kv_heads, head_size // x, -1, x) value_cache = kv_cache[1] value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1) return key_cache, value_cache @staticmethod def write_to_paged_cache( - key: torch.Tensor, - value: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - slot_mapping: torch.Tensor, - kv_cache_dtype: str, - k_scale: float, - v_scale: float, + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, ) -> None: ops.reshape_and_cache( key, @@ -105,47 +106,48 @@ class PagedAttention: @staticmethod def forward_decode( - query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - block_tables: torch.Tensor, - seq_lens: torch.Tensor, - max_seq_len: int, - kv_cache_dtype: str, - num_kv_heads: int, - scale: float, - alibi_slopes: Optional[torch.Tensor], - k_scale: float, - v_scale: float, - tp_rank: int = 0, - blocksparse_local_blocks: int = 0, - blocksparse_vert_stride: int = 0, - blocksparse_block_size: int = 64, - blocksparse_head_sliding_step: int = 0, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + max_seq_len: int, + kv_cache_dtype: str, + num_kv_heads: int, + scale: float, + alibi_slopes: Optional[torch.Tensor], + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, ) -> torch.Tensor: if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1: # use blocksparse paged attention block_size = value_cache.size(-1) assert ( - blocksparse_block_size > 0 and blocksparse_block_size % block_size == 0 - ), ( - f"{blocksparse_block_size=} needs to be a multiple of" - f"{block_size=} used in block_tables." - ) + blocksparse_block_size > 0 and blocksparse_block_size % + block_size == 0), (f"{ + blocksparse_block_size=} needs to be a multiple of" f"{ + block_size=} used in block_tables.") output = torch.empty_like(query) block_size = value_cache.shape[3] num_seqs, num_heads, head_size = query.shape - max_num_partitions = (max_seq_len + _PARTITION_SIZE - 1) // _PARTITION_SIZE + max_num_partitions = ( + max_seq_len + _PARTITION_SIZE - 1) // _PARTITION_SIZE # NOTE(woosuk): We use a simple heuristic to decide whether to use # PagedAttention V1 or V2. If the number of partitions is 1, we use # V1 to avoid the overhead of reduction. Also, if the number of # sequences or heads is large, we use V1 since there is enough work # to parallelize. # TODO(woosuk): Tune this heuristic. 
- # For context len > 8192, use V2 kernel to avoid shared memory shortage. + # For context len > 8192, use V2 kernel to avoid shared memory + # shortage. use_v1 = max_seq_len <= 8192 and ( - max_num_partitions == 1 or num_seqs * num_heads > 512 + max_num_partitions == 1 or num_seqs * num_heads > 512 ) if use_v1: @@ -213,21 +215,21 @@ class PagedAttention: @staticmethod def forward_prefix( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache_dtype: str, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - block_tables: torch.Tensor, - query_start_loc: torch.Tensor, - seq_lens_tensor: torch.Tensor, - context_lens: torch.Tensor, - max_query_len: int, - alibi_slopes: Optional[torch.Tensor], - sliding_window: Optional[int], - k_scale: float, - v_scale: float, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache_dtype: str, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + query_start_loc: torch.Tensor, + seq_lens_tensor: torch.Tensor, + context_lens: torch.Tensor, + max_query_len: int, + alibi_slopes: Optional[torch.Tensor], + sliding_window: Optional[int], + k_scale: float, + v_scale: float, ) -> torch.Tensor: output = torch.empty_like(query) context_attention_fwd( @@ -253,9 +255,9 @@ class PagedAttention: @staticmethod def swap_blocks( - src_kv_cache: torch.Tensor, - dst_kv_cache: torch.Tensor, - src_to_dst: torch.Tensor, + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, ) -> None: src_key_cache = src_kv_cache[0] dst_key_cache = dst_kv_cache[0] @@ -267,8 +269,8 @@ class PagedAttention: @staticmethod def copy_blocks( - kv_caches: List[torch.Tensor], - src_to_dists: torch.Tensor, + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, ) -> None: key_caches = [kv_cache[0] for kv_cache in kv_caches] value_caches = [kv_cache[1] for kv_cache in kv_caches] diff --git a/vllm_mindspore/attention/selector.py b/vllm_mindspore/attention/selector.py index 1dd046661ba137f92ed1fd2e02ef79afb698411e..761e28f6177be4ed2fd57da1d29a26444c868dde 100644 --- a/vllm_mindspore/attention/selector.py +++ b/vllm_mindspore/attention/selector.py @@ -15,16 +15,10 @@ # limitations under the License. 
# ============================================================================ -from typing import Optional, Type - -import torch -from vllm.attention.backends.abstract import AttentionBackend - from functools import lru_cache from typing import Optional, Type import torch - import vllm.envs as envs from vllm.attention.backends.abstract import AttentionBackend from vllm.logger import init_logger @@ -34,13 +28,13 @@ logger = init_logger(__name__) def which_attn_to_use( - head_size: int, - dtype: torch.dtype, - kv_cache_dtype: Optional[str], - block_size: int, - is_attention_free: bool, - use_v1: bool = False, - use_mla: bool = False, + _head_size: int, + _dtype: torch.dtype, + _kv_cache_dtype: Optional[str], + _block_size: int, + is_attention_free: bool, + _use_v1: bool = False, + use_mla: bool = False, ): """Returns which flash attention backend to use.""" if use_mla: @@ -59,7 +53,8 @@ def which_attn_to_use( ) # get device-specific default attn_backend - default_backend = current_platform.get_default_attn_backend(selected_backend) + default_backend = current_platform.get_default_attn_backend( + selected_backend) if default_backend is not None: return default_backend @@ -68,28 +63,31 @@ def which_attn_to_use( @lru_cache(maxsize=None) def _cached_get_attn_backend( - head_size: int, - dtype: torch.dtype, - kv_cache_dtype: Optional[str], - block_size: int, - is_attention_free: bool, - is_blocksparse: bool = False, - use_v1: bool = False, - use_mla: bool = False, + head_size: int, + dtype: torch.dtype, + kv_cache_dtype: Optional[str], + block_size: int, + is_attention_free: bool, + is_blocksparse: bool = False, + use_v1: bool = False, + use_mla: bool = False, ) -> Type[AttentionBackend]: if is_blocksparse: logger.warning( - "MindSpore doesnot support BlocksparseFlashAttention backend now." + "MindSpore doesn't support BlocksparseFlashAttention backend now." 
) backend = which_attn_to_use( - head_size, dtype, kv_cache_dtype, block_size, is_attention_free, use_v1, use_mla - ) + head_size, + dtype, + kv_cache_dtype, + block_size, + is_attention_free, + use_v1, + use_mla) if backend == _Backend.FLASH_ATTN: logger.info("Using Flash Attention backend.") - from vllm_mindspore.attention.backends.ms_attn import ( # noqa: F401 - MsAttentionBackend, - ) + from vllm_mindspore.attention.backends.ms_attn import MsAttentionBackend return MsAttentionBackend elif backend == "MLA_ATTN": @@ -102,13 +100,13 @@ def _cached_get_attn_backend( def get_ms_attn_backend( - head_size: int, - dtype: torch.dtype, - kv_cache_dtype: Optional[str], - block_size: int, - is_attention_free: bool, - is_blocksparse: bool = False, - use_mla: bool = False, + head_size: int, + dtype: torch.dtype, + kv_cache_dtype: Optional[str], + block_size: int, + is_attention_free: bool, + is_blocksparse: bool = False, + use_mla: bool = False, ) -> Type[AttentionBackend]: """Selects which attention backend to use and lazily imports it.""" # Accessing envs.* behind an @lru_cache decorator can cause the wrong diff --git a/vllm_mindspore/config.py b/vllm_mindspore/config.py index 18fed6fc54a7bee1dbebc9ea2311cff048bf5cec..596c843a08e16f205bd136f1dff1a453bbd35e9c 100644 --- a/vllm_mindspore/config.py +++ b/vllm_mindspore/config.py @@ -20,11 +20,12 @@ from vllm_mindspore.utils import is_mindformers_model_backend def get_head_size(self) -> int: if hasattr(self.hf_text_config, "model_type") and ( - self.hf_text_config.model_type in ("deepseek_v2", "deepseek_v3") + self.hf_text_config.model_type in ("deepseek_v2", "deepseek_v3") ): if is_mindformers_model_backend(): - qk_rope_head_dim = getattr(self.hf_text_config, "qk_rope_head_dim", 0) + qk_rope_head_dim = getattr( + self.hf_text_config, "qk_rope_head_dim", 0) return self.hf_text_config.kv_lora_rank + qk_rope_head_dim # FlashAttention supports only head_size 32, 64, 128, 256, @@ -39,6 +40,7 @@ def get_head_size(self) -> int: # FIXME(woosuk): This may not be true for all models. return self.hf_text_config.hidden_size // self.hf_text_config.num_attention_heads + def _verify_quantization(self) -> None: - # Donnot verify now. - return \ No newline at end of file + # Do not verify now. + return diff --git a/vllm_mindspore/distributed/communication_op.py b/vllm_mindspore/distributed/communication_op.py index 58c8c1e8e5f81306dc000cfea0819b6b3666fff6..d9d30ad1e878f22b991f6e664bb14d574b0d8c9b 100644 --- a/vllm_mindspore/distributed/communication_op.py +++ b/vllm_mindspore/distributed/communication_op.py @@ -1,11 +1,8 @@ - - # 该文件实现底层通信接口, 要求动静统一, 最后才可以在网络中入图。 # 不要去照搬mindspeed的, 因为训练当中包含太多的特性, 推理只需要非常简单的通信,可以提升性能。 from typing import Any, Dict, Optional, Union -import mindspore as ms from mindspore import Tensor, nn, ops from mindspore.communication.comm_func import (all_gather_into_tensor, all_reduce, broadcast, @@ -23,8 +20,7 @@ def tensor_model_parallel_all_reduce(input_: Tensor) -> Tensor: return output -def tensor_model_parallel_all_gather(input_: Tensor, - dim: int = -1) -> Tensor: +def tensor_model_parallel_all_gather(input_: Tensor, dim: int = -1) -> Tensor: if get_tensor_model_parallel_world_size() == 1: return input_ """All-gather the input tensor across model parallel group.""" @@ -34,18 +30,17 @@ def tensor_model_parallel_all_gather(input_: Tensor, # Convert negative dim to positive. 
dim += len(input_size) # Reshape - output_tensor = output_tensor.reshape((world_size, ) + input_size) + output_tensor = output_tensor.reshape((world_size,) + input_size) output_tensor = output_tensor.movedim(0, dim) - output_tensor = output_tensor.reshape(input_size[:dim] + - (world_size * - input_size[dim], ) + - input_size[dim + 1:]) + output_tensor = output_tensor.reshape( + input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1:] + ) return output -def tensor_model_parallel_gather(input_: Tensor, - dst: int = 0, - dim: int = -1) -> Optional[Tensor]: +def tensor_model_parallel_gather( + input_: Tensor, dst: int = 0, dim: int = -1 +) -> Optional[Tensor]: if get_tensor_model_parallel_world_size() == 1: return input_ """Gather the input tensor across model parallel group.""" @@ -69,9 +64,8 @@ def broadcast_tensor(tensor, src: int = 0): return broadcast(tensor, src, group=get_world_group()) -def broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[Tensor, - Any]]] = None, - src: int = 0): +def broadcast_tensor_dict( + tensor_dict: Optional[Dict[Any, Union[Tensor, Any]]] = None, src: int = 0): ... # if not torch.distributed.is_initialized(): # return tensor_dict @@ -105,7 +99,7 @@ class ReduceFromModelParallelRegion(nn.Cell): class GatherFromModelParallelRegion(nn.Cell): - "Gather the input from model parallel region and concatinate." + "Gather the input from model parallel region and concatenate." def __init__(self): super().__init__() @@ -114,14 +108,15 @@ class GatherFromModelParallelRegion(nn.Cell): if self.world_size > 1: self.tp_group = get_tp_group().device_group._name - def construct(self, - input_: Tensor, - dst: int = 0, - dim: int = -1) -> Optional[Tensor]: + def construct( + self, input_: Tensor, dst: int = 0, dim: int = -1 + ) -> Optional[Tensor]: # Size and dimension. if self.world_size == 1: return input_ - output = ops.CollectiveGather(dest_rank=dst, group=self.tp_group)(input_.transpose(2, 1, 0)) + output = ops.CollectiveGather(dest_rank=dst, group=self.tp_group)( + input_.transpose(2, 1, 0) + ) if self.tp_rank != dst: return ops.depend(ops.zeros_like(input_), output) return output.transpose(2, 1, 0) diff --git a/vllm_mindspore/distributed/parallel_state.py b/vllm_mindspore/distributed/parallel_state.py index b669f82f90bbd0ee825ca81c7c924b1cc5124b7c..15949ee9479ac6089581bfc3d5ece51d481f601b 100644 --- a/vllm_mindspore/distributed/parallel_state.py +++ b/vllm_mindspore/distributed/parallel_state.py @@ -15,31 +15,28 @@ # limitations under the License. 
# ============================================================================ -import pickle -from typing import List, Optional, Any +from typing import List, Optional -import numpy as np import torch import torch.distributed def init_model_parallel_group( - group_ranks: List[List[int]], - local_rank: int, - backend: str, - use_custom_allreduce: Optional[bool] = None, - use_message_queue_broadcaster: bool = False, - group_name: Optional[str] = None, + group_ranks: List[List[int]], + local_rank: int, + backend: str, + use_custom_allreduce: Optional[bool] = None, + _use_message_queue_broadcaster: bool = False, + group_name: Optional[str] = None, ) -> "GroupCoordinator": - from vllm.distributed.parallel_state import ( - GroupCoordinator, - _ENABLE_CUSTOM_ALL_REDUCE, - ) + from vllm.distributed.parallel_state import (_ENABLE_CUSTOM_ALL_REDUCE, + GroupCoordinator) if use_custom_allreduce is None: use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE - # TODO(tronzhang): mindspore doesnot support enough communicate cpu ops, set use_message_queue_broadcaster to False now. + # TODO(tronzhang): mindspore doesn't support enough communicate cpu ops, + # set use_message_queue_broadcaster to False now. return GroupCoordinator( group_ranks=group_ranks, local_rank=local_rank, @@ -54,7 +51,8 @@ def init_model_parallel_group( ) -def all_reduce_for_GroupCoordinator(self, input_: torch.Tensor) -> torch.Tensor: +def all_reduce_for_GroupCoordinator( + self, input_: torch.Tensor) -> torch.Tensor: """ User-facing all-reduce function before we actually call the all-reduce operation. diff --git a/vllm_mindspore/entrypoints.py b/vllm_mindspore/entrypoints.py index 208acb72441c8f8bdc5e5fd0b50f1d5cbac38494..afa279d21fe3a6965dfc7c2d884667e14208fced 100644 --- a/vllm_mindspore/entrypoints.py +++ b/vllm_mindspore/entrypoints.py @@ -31,8 +31,8 @@ if __name__ == "__main__": module = importlib.import_module(module_name) except Exception as e: raise ValueError( - "Invalid entrypoint(%s) for vllm, error: %s!" % (module_name, str(e)) - ) + "Invalid entrypoint(%s) for vllm, error: %s!" % + (module_name, str(e))) module_code = inspect.getsource(module) vllm_mindspore_enable_line = "import vllm_mindspore\n" diff --git a/vllm_mindspore/executor/ray_gpu_executor.py b/vllm_mindspore/executor/ray_gpu_executor.py index 8b9cd11abba67a7fd4d5dcf391f580fbd32ba694..a4798641f6a9fb576e506b205df7f04ac0d7537f 100644 --- a/vllm_mindspore/executor/ray_gpu_executor.py +++ b/vllm_mindspore/executor/ray_gpu_executor.py @@ -15,17 +15,17 @@ # limitations under the License. 
# ============================================================================ -from typing import Dict, List, Optional from collections import defaultdict +from typing import Dict, List, Optional import vllm.envs as envs -from vllm.utils import (get_distributed_init_method, get_ip, get_open_port) -from vllm.logger import init_logger from vllm.config import ParallelConfig -from vllm.platforms import current_platform -from vllm.executor.ray_utils import RayWorkerWrapper, ray, available_resources_per_node from vllm.executor.ray_gpu_executor import PlacementGroupSchedulingStrategy - +from vllm.executor.ray_utils import (RayWorkerWrapper, + available_resources_per_node, ray) +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.utils import get_distributed_init_method, get_ip, get_open_port logger = init_logger(__name__) @@ -36,7 +36,7 @@ class MsRayWorkerWrapper(RayWorkerWrapper): def ms_init_workers_ray(self, placement_group: "PlacementGroup", - **ray_remote_kwargs): + **ray_remote_kwargs): if (self.parallel_config.tensor_parallel_size == 1 and self.parallel_config.pipeline_parallel_size == 1): # For single GPU case, we use a ray worker with constrained memory. @@ -145,8 +145,8 @@ def ms_init_workers_ray(self, placement_group: "PlacementGroup", # driver_dummy_worker can be None when using ray spmd worker. continue worker_node_and_gpu_ids.append( - ray.get(worker.get_node_and_gpu_ids.remote()) \ - ) # type: ignore + ray.get(worker.get_node_and_gpu_ids.remote()) + ) # type: ignore node_workers = defaultdict(list) # node id -> list of worker ranks node_gpus = defaultdict(list) # node id -> list of gpu ids @@ -178,14 +178,14 @@ def ms_init_workers_ray(self, placement_group: "PlacementGroup", # Set environment variables for the driver and workers. all_args_to_update_environment_variables = [({ - "CUDA_VISIBLE_DEVICES": - ",".join(map(str, node_gpus[node_id])), - "VLLM_TRACE_FUNCTION": - str(envs.VLLM_TRACE_FUNCTION), - **({ - "VLLM_ATTENTION_BACKEND": envs.VLLM_ATTENTION_BACKEND - } if envs.VLLM_ATTENTION_BACKEND is not None else {}) - }, ) for (node_id, _) in worker_node_and_gpu_ids] + "CUDA_VISIBLE_DEVICES": + ",".join(map(str, node_gpus[node_id])), + "VLLM_TRACE_FUNCTION": + str(envs.VLLM_TRACE_FUNCTION), + **({ + "VLLM_ATTENTION_BACKEND": envs.VLLM_ATTENTION_BACKEND + } if envs.VLLM_ATTENTION_BACKEND is not None else {}) + },) for (node_id, _) in worker_node_and_gpu_ids] self._env_vars_for_all_workers = ( all_args_to_update_environment_variables) @@ -254,8 +254,8 @@ def ms_init_workers_ray(self, placement_group: "PlacementGroup", def initialize_ray_cluster( - parallel_config: ParallelConfig, - ray_address: Optional[str] = None, + parallel_config: ParallelConfig, + ray_address: Optional[str] = None, ): """Initialize the distributed cluster with Ray. @@ -274,7 +274,8 @@ def initialize_ray_cluster( # Connect to a ray cluster. 
if current_platform.is_rocm() or current_platform.is_xpu(): - # Try to connect existing ray instance and create a new one if not found + # Try to connect existing ray instance and create a new one if not + # found try: ray.init("auto", ignore_reinit_error=True) except ConnectionError: @@ -285,8 +286,8 @@ def initialize_ray_cluster( ignore_reinit_error=True, num_gpus=parallel_config.world_size) else: - ray.init(address=ray_address, ignore_reinit_error=True, - runtime_env={"env_vars":{"RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES": "1"}}) + ray.init(address=ray_address, ignore_reinit_error=True, runtime_env={ + "env_vars": {"RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES": "1"}}) if parallel_config.placement_group: # Placement group is already set. diff --git a/vllm_mindspore/model_executor/custom_op.py b/vllm_mindspore/model_executor/custom_op.py index 7b913ef8510bb294993e35c8d2345f24673484e7..75bdc2cd8657edb92429ae7f78ef3553dbf2f5d8 100644 --- a/vllm_mindspore/model_executor/custom_op.py +++ b/vllm_mindspore/model_executor/custom_op.py @@ -17,6 +17,7 @@ from mindspore import nn + class CustomOp(nn.Cell): """ Base class for custom ops. @@ -35,6 +36,6 @@ class CustomOp(nn.Cell): def forward_cuda(self, *args, **kwargs): raise NotImplementedError - + def dispatch_forward(self): - return self.forward_native \ No newline at end of file + return self.forward_native diff --git a/vllm_mindspore/model_executor/layers/layernorm.py b/vllm_mindspore/model_executor/layers/layernorm.py index dd497e08eafd32b624977d355fe95f80d46f2276..ad1ade3561a85aca30912394f16d8290ce4f89da 100644 --- a/vllm_mindspore/model_executor/layers/layernorm.py +++ b/vllm_mindspore/model_executor/layers/layernorm.py @@ -15,31 +15,30 @@ # limitations under the License. # ============================================================================ -from typing import Optional, Tuple, Union, Any +from typing import Any, Optional, Tuple, Union from mindspore import Parameter, Tensor, mint, ops from mindspore.common import dtype as mstype -from mindspore.common.dtype import typing from vllm_mindspore.model_executor.custom_op import CustomOp class RMSNorm(CustomOp): def __init__( - self, - hidden_size: int, - eps: float = 1e-6, - var_hidden_size: Optional[int] = None, - params_dtype: Optional[Any] = mstype.float16, + self, + hidden_size: int, + eps: float = 1e-6, + _var_hidden_size: Optional[int] = None, + params_dtype: Optional[Any] = mstype.float16, ) -> None: super().__init__() self.weight = Parameter(mint.ones(hidden_size, dtype=params_dtype)) self.rms_norm = ops.RmsNorm(eps) def forward_native( - self, - x: Tensor, - residual: Optional[Tensor] = None + self, + x: Tensor, + residual: Optional[Tensor] = None ) -> Union[Tensor, Tuple[Tensor, Tensor]]: if residual is not None: x = x + residual diff --git a/vllm_mindspore/model_executor/layers/linear.py b/vllm_mindspore/model_executor/layers/linear.py index 62142ac8c44316a3bef0f195b1b325ea9b5d77c0..161c72a9fb9d5817e0128d26b55c85bdb472eda3 100644 --- a/vllm_mindspore/model_executor/layers/linear.py +++ b/vllm_mindspore/model_executor/layers/linear.py @@ -15,29 +15,21 @@ # limitations under the License. 
# ============================================================================ -from typing import List, Optional from abc import abstractmethod +from typing import List, Optional -import numpy as np import mindspore as ms -from mindspore import mint, ops, Tensor -from mindspore import Parameter - -from vllm.distributed import ( - divide, - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - split_tensor_along_last_dim, - tensor_model_parallel_all_gather, - tensor_model_parallel_all_reduce, -) +from mindspore import Parameter, Tensor, mint, ops +from vllm.distributed import (divide, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + split_tensor_along_last_dim, + tensor_model_parallel_all_gather) + +from vllm_mindspore.distributed.communication_op import \ + ReduceFromModelParallelRegion from vllm_mindspore.model_executor.layers.quantization.base_config import ( - QuantizationConfig, - QuantizeMethodBase, -) + QuantizationConfig, QuantizeMethodBase) from vllm_mindspore.model_executor.utils import set_weight_attrs -from vllm_mindspore.distributed.communication_op import ReduceFromModelParallelRegion - WEIGHT_LOADER_V2_SUPPORTED = [ "CompressedTensorsLinearMethod", @@ -63,14 +55,14 @@ class LinearMethodBase(QuantizeMethodBase): @abstractmethod def create_weights( - self, - layer: ms.nn.Cell, - input_size_per_partition: int, - output_partition_sizes: List[int], - input_size: int, - output_size: int, - params_dtype, - **extra_weight_attrs + self, + layer: ms.nn.Cell, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype, + **extra_weight_attrs ): """Create weights for a linear layer. The weights will be set as attributes of the layer. @@ -88,9 +80,8 @@ class LinearMethodBase(QuantizeMethodBase): raise NotImplementedError @abstractmethod - def apply( - self, layer: ms.nn.Cell, x: ms.Tensor, bias: Optional[ms.Tensor] = None - ) -> ms.Tensor: + def apply(self, layer: ms.nn.Cell, x: ms.Tensor, + bias: Optional[ms.Tensor] = None) -> ms.Tensor: """Apply the weights in layer to the input tensor. 
Expects create_weights to have been called before on the layer.""" raise NotImplementedError @@ -100,14 +91,14 @@ class UnquantizedLinearMethod(LinearMethodBase): """Linear method without quantization.""" def create_weights( - self, - layer: ms.nn.Cell, - input_size_per_partition: int, - output_partition_sizes: List[int], - input_size: int, - output_size: int, - params_dtype, - **extra_weight_attrs + self, + layer: ms.nn.Cell, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype, + **extra_weight_attrs ): weight = Parameter( mint.zeros( @@ -151,13 +142,13 @@ class LinearBase(ms.nn.Cell): """ def __init__( - self, - input_size: int, - output_size: int, - skip_bias_add: bool = False, - params_dtype=None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", + self, + input_size: int, + output_size: int, + skip_bias_add: bool = False, + params_dtype=None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() @@ -170,9 +161,11 @@ class LinearBase(ms.nn.Cell): params_dtype = ms.float16 self.params_dtype = params_dtype if quant_config is None: - self.quant_method: Optional[QuantizeMethodBase] = UnquantizedLinearMethod() + self.quant_method: Optional[QuantizeMethodBase] = UnquantizedLinearMethod( + ) else: - self.quant_method = quant_config.get_quant_method(self, prefix=prefix) + self.quant_method = quant_config.get_quant_method( + self, prefix=prefix) def construct(self, x: ms.Tensor) -> ms.Tensor: raise NotImplementedError @@ -183,20 +176,24 @@ class LinearBase(ms.nn.Cell): class ColumnParallelLinear(LinearBase): def __init__( - self, - input_size: int, - output_size: int, - bias: bool = True, - gather_output: bool = False, - skip_bias_add: bool = False, - params_dtype=None, - quant_config: Optional[QuantizationConfig] = None, - output_sizes: Optional[List[int]] = None, - prefix: str = "", + self, + input_size: int, + output_size: int, + bias: bool = True, + gather_output: bool = False, + skip_bias_add: bool = False, + params_dtype=None, + quant_config: Optional[QuantizationConfig] = None, + output_sizes: Optional[List[int]] = None, + prefix: str = "", ): super().__init__( - input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix - ) + input_size, + output_size, + skip_bias_add, + params_dtype, + quant_config, + prefix) self.gather_output = gather_output @@ -208,12 +205,14 @@ class ColumnParallelLinear(LinearBase): # If QKV or MergedColumn, use output size of each partition. 
if hasattr(self, "output_sizes"): self.output_partition_sizes = [ - divide(output_size, tp_size) for output_size in self.output_sizes - ] + divide( + output_size, + tp_size) for output_size in self.output_sizes] if output_sizes is None: output_sizes = [output_size] - # vllm 中变量名称与megatron相似度高, 然后mindspeed 与 megatron的变量命名相似度高, 所以以vllm为基准, 可以最大可能保证变量命名一致性, 方便load。 + # vllm 中变量名称与megatron相似度高, 然后mindspeed 与 megatron的变量命名相似度高, 所以以vllm为基准, + # 可以最大可能保证变量命名一致性, 方便load。 self.quant_method.create_weights( layer=self, input_size_per_partition=self.input_size, @@ -278,7 +277,8 @@ class ColumnParallelLinear(LinearBase): if output_dim is not None and not use_bitsandbytes_4bit: shard_size = param_data.shape[output_dim] start_idx = tp_rank * shard_size - loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) + loaded_weight = loaded_weight.narrow( + output_dim, start_idx, shard_size) # Special case for loading scales off disk, which often do not # have a shape (such as in the case of AutoFP8). @@ -316,15 +316,15 @@ class MergedColumnParallelLinear(ColumnParallelLinear): """ def __init__( - self, - input_size: int, - output_sizes: List[int], - bias: bool = True, - gather_output: bool = False, - skip_bias_add: bool = False, - params_dtype=None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", + self, + input_size: int, + output_sizes: List[int], + bias: bool = True, + gather_output: bool = False, + skip_bias_add: bool = False, + params_dtype=None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): self.output_sizes = output_sizes tp_size = get_tensor_model_parallel_world_size() @@ -341,7 +341,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ) def weight_loader( - self, param, loaded_weight, loaded_shard_id: Optional[int] = None + self, param, loaded_weight, loaded_shard_id: Optional[int] = None ): use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False) param_data = param.data @@ -371,12 +371,14 @@ class MergedColumnParallelLinear(ColumnParallelLinear): # loaded_shard_id param_data = param.data - param_data = param_data.narrow(output_dim, shard_offset, shard_size) + param_data = param_data.narrow( + output_dim, shard_offset, shard_size) start_idx = tp_rank * shard_size # bitsandbytes loads the weights of the specific portion # no need to narrow here if not use_bitsandbytes_4bit: - loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) + loaded_weight = loaded_weight.narrow( + output_dim, start_idx, shard_size) assert param_data.shape == loaded_weight.shape # param_data.copy_(loaded_weight) # param_data.set_data(loaded_weight) @@ -385,16 +387,16 @@ class MergedColumnParallelLinear(ColumnParallelLinear): class QKVParallelLinear(ColumnParallelLinear): def __init__( - self, - hidden_size: int, - head_size: int, - total_num_heads: int, - total_num_kv_heads: Optional[int] = None, - bias: bool = True, - skip_bias_add: bool = False, - params_dtype=None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", + self, + hidden_size: int, + head_size: int, + total_num_heads: int, + total_num_kv_heads: Optional[int] = None, + bias: bool = True, + skip_bias_add: bool = False, + params_dtype=None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): self.hidden_size = hidden_size self.head_size = head_size @@ -407,13 +409,14 @@ class QKVParallelLinear(ColumnParallelLinear): self.num_heads = divide(self.total_num_heads, tp_size) if tp_size >= self.total_num_kv_heads: 
self.num_kv_heads = 1 - self.num_kv_head_replicas = divide(tp_size, self.total_num_kv_heads) + self.num_kv_head_replicas = divide( + tp_size, self.total_num_kv_heads) else: self.num_kv_heads = divide(self.total_num_kv_heads, tp_size) self.num_kv_head_replicas = 1 input_size = self.hidden_size output_size = ( - (self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size + (self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size ) self.output_sizes = [ self.num_heads * self.head_size * tp_size, # q_proj @@ -448,10 +451,12 @@ class QKVParallelLinear(ColumnParallelLinear): shard_offset = self.num_heads * self.head_size shard_size = self.num_kv_heads * self.head_size elif loaded_shard_id == "v": - shard_offset = (self.num_heads + self.num_kv_heads) * self.head_size + shard_offset = ( + self.num_heads + self.num_kv_heads) * self.head_size shard_size = self.num_kv_heads * self.head_size - param_data = param_data.narrow(output_dim, shard_offset, shard_size) + param_data = param_data.narrow( + output_dim, shard_offset, shard_size) if loaded_shard_id == "q": shard_id = tp_rank else: @@ -459,12 +464,15 @@ class QKVParallelLinear(ColumnParallelLinear): start_idx = shard_id * shard_size if not use_bitsandbytes_4bit: - loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) + loaded_weight = loaded_weight.narrow( + output_dim, start_idx, shard_size) assert param_data.shape == loaded_weight.shape if param.name.endswith("weight"): - self.weight[shard_offset: shard_offset + shard_size, :] = loaded_weight + self.weight[shard_offset: shard_offset + + shard_size, :] = loaded_weight if param.name.endswith("bias"): - self.bias[shard_offset: shard_offset + shard_size] = loaded_weight + self.bias[shard_offset: shard_offset + + shard_size] = loaded_weight # tp_rank = get_tensor_model_parallel_rank() # if shard_id is "q": # start_index = self.num_heads * tp_rank * self.head_size @@ -483,20 +491,24 @@ class QKVParallelLinear(ColumnParallelLinear): class RowParallelLinear(LinearBase): def __init__( - self, - input_size: int, - output_size: int, - bias: bool = True, - input_is_parallel: bool = True, - skip_bias_add: bool = False, - params_dtype=None, - reduce_results: bool = True, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", + self, + input_size: int, + output_size: int, + bias: bool = True, + input_is_parallel: bool = True, + skip_bias_add: bool = False, + params_dtype=None, + reduce_results: bool = True, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__( - input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix - ) + input_size, + output_size, + skip_bias_add, + params_dtype, + quant_config, + prefix) self.input_is_parallel = input_is_parallel self.reduce_results = reduce_results @@ -527,7 +539,10 @@ class RowParallelLinear(LinearBase): ) if bias: - self.bias = Parameter(mint.zeros(self.output_size), dtype=params_dtype) + self.bias = Parameter( + mint.zeros( + self.output_size), + dtype=params_dtype) set_weight_attrs( self.bias, { @@ -556,7 +571,8 @@ class RowParallelLinear(LinearBase): # Only fuse bias add into GEMM for rank 0 (this ensures that # bias will not get added more than once in TP>1 case) bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias - output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_) + output_parallel = self.quant_method.apply( + self, input_parallel, bias=bias_) if self.reduce_results and self.tp_size > 1: output = 
self.tensor_model_parallel_all_reduce(output_parallel) else: @@ -578,7 +594,8 @@ class RowParallelLinear(LinearBase): if input_dim is not None and not use_bitsandbytes_4bit: shard_size = param_data.shape[input_dim] start_idx = tp_rank * shard_size - loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size) + loaded_weight = loaded_weight.narrow( + input_dim, start_idx, shard_size) # Special case for loading scales off disk, which often do not # have a shape (such as in the case of AutoFP8). diff --git a/vllm_mindspore/model_executor/layers/logits_processor.py b/vllm_mindspore/model_executor/layers/logits_processor.py index 9399e518a38f57177857a865e48961881b786976..e96c996eb03de474d6ac0a85502ef76cdc6b2ca9 100644 --- a/vllm_mindspore/model_executor/layers/logits_processor.py +++ b/vllm_mindspore/model_executor/layers/logits_processor.py @@ -19,19 +19,16 @@ import inspect from typing import Optional import mindspore.nn as nn -from mindspore import Tensor -from mindspore import mint - -from vllm.distributed import ( - tensor_model_parallel_all_gather, - tensor_model_parallel_gather, -) -from vllm_mindspore.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding, -) -from vllm_mindspore.model_executor.sampling_metadata import SamplingMetadata +from mindspore import Tensor, mint +from vllm.distributed import (tensor_model_parallel_all_gather, + tensor_model_parallel_gather) from vllm.platforms import current_platform +from vllm_mindspore.model_executor.layers.vocab_parallel_embedding import \ + VocabParallelEmbedding +from vllm_mindspore.model_executor.sampling_metadata import SamplingMetadata + + # TODO(tronzhang): Use vllm's logits_processor.py latter... @@ -45,12 +42,12 @@ class LogitsProcessor(nn.Cell): """ def __init__( - self, - vocab_size: int, - org_vocab_size: Optional[int] = None, - scale: float = 1.0, - logits_as_input: bool = False, - soft_cap: Optional[float] = None, + self, + vocab_size: int, + org_vocab_size: Optional[int] = None, + scale: float = 1.0, + logits_as_input: bool = False, + soft_cap: Optional[float] = None, ) -> None: """ Args: @@ -69,17 +66,18 @@ class LogitsProcessor(nn.Cell): self.use_gather = not current_platform.is_tpu() def construct( - self, - lm_head: VocabParallelEmbedding, - hidden_states: Tensor, - sampling_metadata: Optional[SamplingMetadata] = None, - embedding_bias: Optional[Tensor] = None, + self, + lm_head: VocabParallelEmbedding, + hidden_states: Tensor, + sampling_metadata: Optional[SamplingMetadata] = None, + embedding_bias: Optional[Tensor] = None, ) -> Optional[Tensor]: if self.logits_as_input: logits = hidden_states else: if sampling_metadata is not None: - hidden_states = _prune_hidden_states(hidden_states, sampling_metadata) + hidden_states = _prune_hidden_states( + hidden_states, sampling_metadata) # Get the logits for the next tokens. logits = self._get_logits(hidden_states, lm_head, embedding_bias) @@ -99,10 +97,10 @@ class LogitsProcessor(nn.Cell): return logits def _get_logits( - self, - hidden_states: Tensor, - lm_head: VocabParallelEmbedding, - embedding_bias: Optional[Tensor], + self, + hidden_states: Tensor, + lm_head: VocabParallelEmbedding, + embedding_bias: Optional[Tensor], ) -> Optional[Tensor]: # Get the logits for the next tokens. 
logits = lm_head.linear_method.apply( @@ -131,21 +129,22 @@ class LogitsProcessor(nn.Cell): def _prune_hidden_states( - hidden_states: Tensor, - sampling_metadata: SamplingMetadata, + hidden_states: Tensor, + sampling_metadata: SamplingMetadata, ) -> Tensor: # NOTE(kzawora): The if guard is needed for Gaudi - in some scenarios # (warmup, profile_run) we might not have selected_token_indices, # so we skip pruning. if sampling_metadata.selected_token_indices is not None: - return hidden_states.index_select(0, sampling_metadata.selected_token_indices) + return hidden_states.index_select( + 0, sampling_metadata.selected_token_indices) else: return hidden_states def _apply_logits_processors( - logits: Tensor, - sampling_metadata: SamplingMetadata, + logits: Tensor, + sampling_metadata: SamplingMetadata, ) -> Tensor: found_logits_processors = False logits_processed = 0 @@ -156,7 +155,8 @@ def _apply_logits_processors( if logits_processors: found_logits_processors = True - for seq_id, logits_row_idx in zip(seq_ids, seq_group.sample_indices): + for seq_id, logits_row_idx in zip( + seq_ids, seq_group.sample_indices): logits_row = logits[logits_row_idx] past_tokens_ids = seq_group.seq_data[seq_id].output_token_ids prompt_tokens_ids = seq_group.seq_data[seq_id].prompt_token_ids @@ -168,7 +168,8 @@ def _apply_logits_processors( prompt_tokens_ids, past_tokens_ids, logits_row ) else: - logits_row = logits_processor(past_tokens_ids, logits_row) + logits_row = logits_processor( + past_tokens_ids, logits_row) logits[logits_row_idx] = logits_row diff --git a/vllm_mindspore/model_executor/layers/quantization/base_config.py b/vllm_mindspore/model_executor/layers/quantization/base_config.py index ea259ee65742a5089a36a0966520144b6a8973bc..9185962cfe7ac863ae295f9d50badb8afdc19e09 100644 --- a/vllm_mindspore/model_executor/layers/quantization/base_config.py +++ b/vllm_mindspore/model_executor/layers/quantization/base_config.py @@ -18,15 +18,22 @@ import inspect from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional, Type + import mindspore as ms + # TODO(tronzhang): Use vllm's quantization base_config.py latter. + class QuantizeMethodBase(ABC): """Base class for different quantized methods.""" @abstractmethod - def create_weights(self, layer: ms.nn.Cell, *weight_args, **extra_weight_attrs): + def create_weights( + self, + layer: ms.nn.Cell, + *weight_args, + **extra_weight_attrs): """Create weights for a layer. 
The weights will be set as attributes of the layer.""" @@ -91,7 +98,8 @@ class QuantizationConfig(ABC): raise NotImplementedError @classmethod - def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]: + def override_quantization_method( + cls, hf_quant_cfg, user_quant) -> Optional[str]: """ Detects if this quantization method can support a given checkpoint format by overriding the user specified quantization method -- @@ -106,10 +114,12 @@ class QuantizationConfig(ABC): for key in keys: if key in config: return config[key] - raise ValueError(f"Cannot find any of {keys} in the model's " "quantization config.") + raise ValueError( + f"Cannot find any of {keys} in the model's " "quantization config.") @staticmethod - def get_from_keys_or(config: Dict[str, Any], keys: List[str], default: Any) -> Any: + def get_from_keys_or( + config: Dict[str, Any], keys: List[str], default: Any) -> Any: """Get a optional value from the model's quantization config.""" try: return QuantizationConfig.get_from_keys(config, keys) @@ -117,7 +127,10 @@ class QuantizationConfig(ABC): return default @abstractmethod - def get_quant_method(self, layer: ms.nn.Cell, prefix: str) -> Optional[QuantizeMethodBase]: + def get_quant_method( + self, + layer: ms.nn.Cell, + prefix: str) -> Optional[QuantizeMethodBase]: """Get the quantize method to use for the quantized layer. Args: @@ -130,13 +143,15 @@ class QuantizationConfig(ABC): raise NotImplementedError -def method_has_implemented_embedding(method_class: Type[QuantizeMethodBase]) -> bool: +def method_has_implemented_embedding( + method_class: Type[QuantizeMethodBase]) -> bool: """ Not all quant methods have embedding implemented, so we need to check that it exists for our given method. We check this by making sure the function has been changed from the base implementation. """ - base_embedding = inspect.getattr_static(QuantizeMethodBase, "embedding", None) + base_embedding = inspect.getattr_static( + QuantizeMethodBase, "embedding", None) class_embedding = inspect.getattr_static(method_class, "embedding", None) return class_embedding is not None and class_embedding is not base_embedding diff --git a/vllm_mindspore/model_executor/layers/rotary_embedding.py b/vllm_mindspore/model_executor/layers/rotary_embedding.py index 77827002f512d77570e3438f47a2a6549b56ac8b..4dec26fd3965291e67dd8048adcfa3531cef9272 100644 --- a/vllm_mindspore/model_executor/layers/rotary_embedding.py +++ b/vllm_mindspore/model_executor/layers/rotary_embedding.py @@ -25,10 +25,10 @@ from vllm_mindspore.model_executor.custom_op import CustomOp def _apply_rotary_emb( - x: Tensor, - cos: Tensor, - sin: Tensor, - is_neox_style: bool, + x: Tensor, + cos: Tensor, + sin: Tensor, + is_neox_style: bool, ) -> Tensor: """ Args: @@ -55,13 +55,13 @@ def _apply_rotary_emb( class RotaryEmbedding(CustomOp): def __init__( - self, - head_size: int, - rotary_dim: int, - max_position_embeddings: int, - base: int, - is_neox_style: bool, - dtype, + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype, ) -> None: super().__init__() self.head_size = head_size @@ -82,10 +82,8 @@ class RotaryEmbedding(CustomOp): # use CPU to compute the cache and then move it to GPU. However, we # create the cache on GPU for faster initialization. This may cause # a slight numerical difference between the HF implementation and ours. 
- inv_freq = 1.0 / ( - base - ** (mint.arange(0, self.rotary_dim, 2, dtype=mstype.float32) / self.rotary_dim) - ) + inv_freq = 1.0 / (base ** (mint.arange(0, self.rotary_dim, + 2, dtype=mstype.float32) / self.rotary_dim)) return inv_freq def _compute_cos_sin_cache(self) -> Tensor: @@ -101,11 +99,11 @@ class RotaryEmbedding(CustomOp): return cache def forward_native( - self, - positions: Tensor, - query: Tensor, - key: Tensor, - offsets: Optional[Tensor] = None, + self, + positions: Tensor, + query: Tensor, + key: Tensor, + offsets: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor]: """A PyTorch-native implementation of forward().""" if offsets is not None: @@ -133,16 +131,17 @@ class RotaryEmbedding(CustomOp): class InferRotaryEmbedding(CustomOp): def __init__( - self, - head_size: int, - rotary_dim: int, - max_position_embeddings: int, - base: int, - is_neox_style: bool, - dtype, + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype, ) -> None: super().__init__() - freqs_base = np.arange(0, rotary_dim, 2)[: (rotary_dim // 2)].astype(np.float32) # (head_dim // 2, ) + freqs_base = np.arange(0, rotary_dim, 2)[: ( + rotary_dim // 2)].astype(np.float32) # (head_dim // 2, ) freqs = 1.0 / (base ** (freqs_base / rotary_dim)) # (head_dim // 2, ) mscale = 1.0 t = np.arange(0, max_position_embeddings, 1).astype(np.float32) @@ -157,34 +156,36 @@ class InferRotaryEmbedding(CustomOp): self.rotary_embedding_op = ops.ApplyRotaryPosEmb(2) def forward_native( - self, - positions: Tensor, - query: Tensor, - key: Tensor, - batch_valid_length: Tensor, - num_prefill_tokens: int, - offsets: Optional[Tensor] = None, + self, + positions: Tensor, + query: Tensor, + key: Tensor, + batch_valid_length: Tensor, + num_prefill_tokens: int, + offsets: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor]: if num_prefill_tokens > 0: - return self.rotary_embedding_op(query, key, self.freqs_cos, self.freqs_sin, batch_valid_length) + return self.rotary_embedding_op( + query, key, self.freqs_cos, self.freqs_sin, batch_valid_length) freqs_cos = self.freqs_cos.index_select(0, positions) freqs_sin = self.freqs_sin.index_select(0, positions) - return self.rotary_embedding_op(query, key, freqs_cos, freqs_sin, batch_valid_length) + return self.rotary_embedding_op( + query, key, freqs_cos, freqs_sin, batch_valid_length) _ROPE_DICT: Dict[Tuple, InferRotaryEmbedding] = {} def get_rope( - head_size: int, - rotary_dim: int, - max_position: int, - base: int, - is_neox_style: bool = True, - rope_scaling: Optional[Dict[str, Any]] = None, - dtype: Optional[Any] = mstype.float16, - partial_rotary_factor: float = 1.0, + head_size: int, + rotary_dim: int, + max_position: int, + base: int, + is_neox_style: bool = True, + rope_scaling: Optional[Dict[str, Any]] = None, + dtype: Optional[Any] = mstype.float16, + partial_rotary_factor: float = 1.0, ) -> InferRotaryEmbedding: if rope_scaling is not None: # Transforms every value that is a list into a tuple for caching calls diff --git a/vllm_mindspore/model_executor/layers/sampler.py b/vllm_mindspore/model_executor/layers/sampler.py index c51b526addb2e813b4da6f6a597a413e9f523f37..ad1981b1d27816699cf629439f419a16727bd337 100644 --- a/vllm_mindspore/model_executor/layers/sampler.py +++ b/vllm_mindspore/model_executor/layers/sampler.py @@ -23,35 +23,27 @@ from importlib.util import find_spec from math import inf from typing import Dict, Iterator, List, Optional, Tuple, Union -# TODO(tronzhang): for some ops, msadaptor cannnot support, 
latter use vllm's... - import msgspec import torch import torch.nn as nn - import vllm.envs as envs -from vllm_mindspore.model_executor.layers.utils import apply_penalties from vllm.sampling_params import SamplingType -from vllm.sequence import ( - VLLM_INVALID_TOKEN_ID, - CompletionSequenceGroupOutput, - Logprob, - PromptLogprobs, - SampleLogprobs, - SequenceOutput, -) +from vllm.sequence import (VLLM_INVALID_TOKEN_ID, + CompletionSequenceGroupOutput, Logprob, + PromptLogprobs, SampleLogprobs, SequenceOutput) from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics +from vllm_mindspore.model_executor.layers.utils import apply_penalties from vllm_mindspore.model_executor.sampling_metadata import ( - SamplingMetadata, - SamplingTensors, - SequenceGroupToSample, -) + SamplingMetadata, SamplingTensors, SequenceGroupToSample) + +# TODO(tronzhang): for some ops, msadaptor cannot support, latter use vllm's... + if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"): raise RuntimeError("Donot support for mindspore now.") -else: - flashinfer_top_k_top_p_sampling = None + +flashinfer_top_k_top_p_sampling = None def get_sampler() -> torch.nn.Module: @@ -63,7 +55,8 @@ SampleResultType = List[Tuple[List[int], List[int]]] # Types of temporary data structures used for # computing sample_result -SampleMetadataType = Dict[SamplingType, Tuple[List[int], List[SequenceGroupToSample]]] +SampleMetadataType = Dict[SamplingType, +Tuple[List[int], List[SequenceGroupToSample]]] MultinomialSamplesType = Dict[SamplingType, torch.Tensor] SampleResultsDictType = Dict[int, Tuple[List[int], List[int]]] @@ -97,7 +90,8 @@ SampleReturnType = Tuple[MaybeDeferredSampleResultType, Optional[torch.Tensor]] class SamplerOutput( - msgspec.Struct, omit_defaults=True, array_like=True # type: ignore[call-arg] + # type: ignore[call-arg] + msgspec.Struct, omit_defaults=True, array_like=True ): # type: ignore[call-arg] """For each sequence group, we generate a list of SequenceOutput object, each of which contains one possible candidate for the next token. @@ -156,7 +150,8 @@ class SamplerOutput( return len(self.outputs) def __eq__(self, other: object): - return isinstance(other, self.__class__) and self.outputs == other.outputs + return isinstance( + other, self.__class__) and self.outputs == other.outputs def __repr__(self) -> str: """Show the shape of a tensor instead of its values to reduce noise.""" @@ -166,8 +161,7 @@ class SamplerOutput( else self.sampled_token_probs.shape ) sampled_token_ids_repr = ( - "None" if self.sampled_token_ids is None else self.sampled_token_ids.shape - ) + "None" if self.sampled_token_ids is None else self.sampled_token_ids.shape) return ( f"SamplerOutput(outputs={self.outputs}, " f"sampled_token_probs={sampled_token_probs_repr}, " @@ -207,9 +201,9 @@ class Sampler(nn.Module): self.should_modify_greedy_probs_inplace = False def _init_sampling_tensors( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, ): """The goal here is to reuse sampling tensors between similar decode runs. 
This is possible because sampling logic does not change between @@ -235,9 +229,9 @@ class Sampler(nn.Module): self._do_min_p = do_min_p def forward( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: """ Single-step scheduling: @@ -258,7 +252,6 @@ class Sampler(nn.Module): sampling_metadata: Metadata for sampling. """ assert logits is not None - _, vocab_size = logits.shape # Prepare sampling tensors with pinned memory to avoid blocking. if not sampling_metadata.reuse_sampling_tensors: @@ -336,7 +329,9 @@ class Sampler(nn.Module): sample_logprobs = None if not sampling_metadata.skip_sampler_cpu_output: # Pythonize logprobs now (GPU -> CPU); do not defer. - assert not isinstance(maybe_deferred_sample_results, SampleResultArgsType) + assert not isinstance( + maybe_deferred_sample_results, + SampleResultArgsType) prompt_logprobs, sample_logprobs = get_logprobs( logprobs, sampling_metadata, maybe_deferred_sample_results ) @@ -366,8 +361,8 @@ class Sampler(nn.Module): def _apply_min_tokens_penalty( - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, ) -> torch.Tensor: """Apply min_tokens penalty which sets stop tokens to -inf if min_tokens have not been generated yet @@ -380,7 +375,8 @@ def _apply_min_tokens_penalty( sampling_params = seq_group.sampling_params sample_indices = seq_group.sample_indices - logits_applied += len(sample_indices) + len(seq_group.prompt_logprob_indices) + logits_applied += len(sample_indices) + \ + len(seq_group.prompt_logprob_indices) if not seq_group.do_sample: continue @@ -413,10 +409,20 @@ def _apply_min_tokens_penalty( def _apply_top_k_top_p( - logits: torch.Tensor, - p: torch.Tensor, - k: torch.Tensor, + logits: torch.Tensor, + p: torch.Tensor, + k: torch.Tensor, ) -> torch.Tensor: + """ + + Args: + logits: + p: + k: + + Returns: + + """ logits_sort, logits_idx = logits.sort(axis=-1, descending=False) # logits_sort, logits_idx = logits.sort(dim=-1, descending=False) @@ -444,8 +450,8 @@ def _apply_top_k_top_p( def _apply_min_p( - logits: torch.Tensor, - min_p: torch.Tensor, + logits: torch.Tensor, + min_p: torch.Tensor, ) -> torch.Tensor: """ Adapted from @@ -461,8 +467,8 @@ def _apply_min_p( def _greedy_sample( - selected_seq_groups: List[SequenceGroupToSample], - samples: torch.Tensor, + selected_seq_groups: List[SequenceGroupToSample], + samples: torch.Tensor, ) -> SampleResultType: """Run greedy sampling on a given samples. @@ -495,8 +501,8 @@ def _greedy_sample( def _random_sample( - selected_seq_groups: List[SequenceGroupToSample], - random_samples: torch.Tensor, + selected_seq_groups: List[SequenceGroupToSample], + random_samples: torch.Tensor, ) -> SampleResultType: """Run random sampling on a given samples. @@ -525,21 +531,22 @@ def _random_sample( if is_prompt: # Prompt phase. parent_ids = [0] * sampling_params.n - next_token_ids = random_samples[sample_idx, : sampling_params.n].tolist() + next_token_ids = random_samples[sample_idx, + : sampling_params.n].tolist() else: # Generation phase. 
parent_ids = list(range(num_parent_seqs)) next_token_ids = random_samples[ - sample_idx : sample_idx + num_parent_seqs, 0 - ].tolist() + sample_idx: sample_idx + num_parent_seqs, 0 + ].tolist() results.append((next_token_ids, parent_ids)) sample_idx += num_parent_seqs return results def _beam_search_sample( - selected_seq_groups: List[SequenceGroupToSample], - logprobs: torch.Tensor, + selected_seq_groups: List[SequenceGroupToSample], + logprobs: torch.Tensor, ) -> SampleResultType: """Run beam sampling on a given samples. @@ -572,25 +579,25 @@ def _beam_search_sample( seq_ids, sampling_params = seq_group.seq_ids, seq_group.sampling_params num_parent_seqs = len(seq_ids) beam_width = sampling_params.n - seq_group_logprobs = logprobs[sample_idx : sample_idx + num_parent_seqs] + seq_group_logprobs = logprobs[sample_idx: sample_idx + num_parent_seqs] if is_prompt: # Prompt phase. assert num_parent_seqs == 1, "Prompt input should have only one seq." parent_ids = [0] * (2 * beam_width) - _, next_token_ids = torch.topk(seq_group_logprobs[0], 2 * beam_width) + _, next_token_ids = torch.topk( + seq_group_logprobs[0], 2 * beam_width) next_token_ids = next_token_ids.tolist() else: # Generation phase. cumulative_logprobs: List[float] = [ - seq_group.seq_data[seq_id].cumulative_logprob for seq_id in seq_ids - ] + seq_group.seq_data[seq_id].cumulative_logprob for seq_id in seq_ids] cumulative_logprobs_tensor = torch.tensor( - cumulative_logprobs, dtype=torch.float, device=seq_group_logprobs.device - ) + cumulative_logprobs, dtype=torch.float, device=seq_group_logprobs.device) seq_group_logprobs = ( - seq_group_logprobs + cumulative_logprobs_tensor.unsqueeze(1) + seq_group_logprobs + cumulative_logprobs_tensor.unsqueeze(1) ) - _, topk_ids = torch.topk(seq_group_logprobs.flatten(), 2 * beam_width) + _, topk_ids = torch.topk( + seq_group_logprobs.flatten(), 2 * beam_width) topk_ids = topk_ids.tolist() vocab_size = seq_group_logprobs.size(-1) parent_ids = [i // vocab_size for i in topk_ids] @@ -604,8 +611,8 @@ def _beam_search_sample( def exponential(x, lambd=1.0, *, generator=None): if generator is not None: raise ValueError("`generator` can not be supported.") - import numpy as np import mindspore as ms + import numpy as np output = np.random.exponential(scale=lambd, size=x.shape) return ms.Tensor(output).astype(x.dtype) @@ -617,10 +624,20 @@ def exponential(x, lambd=1.0, *, generator=None): # probs will be modified in place, but this is fine, as we pass # in a copy already. 
def _multinomial( - probs: torch.Tensor, - num_samples: int, - seq_groups: Optional[List[SequenceGroupToSample]] = None, + probs: torch.Tensor, + num_samples: int, + seq_groups: Optional[List[SequenceGroupToSample]] = None, ) -> torch.Tensor: + """ + + Args: + probs: + num_samples: + seq_groups: + + Returns: + + """ if num_samples > 1: probs = probs.repeat_interleave(num_samples, dim=0) q = torch.empty_like(probs) @@ -633,8 +650,8 @@ def _multinomial( seq_ids = seq_group.seq_ids stride = len(seq_ids) * num_samples assert seq_group.generator is not None - q[sample_idx : sample_idx + stride] = exponential( - q[sample_idx : sample_idx + stride] + q[sample_idx: sample_idx + stride] = exponential( + q[sample_idx: sample_idx + stride] ) # q[sample_idx:sample_idx + # stride].exponential_(generator=seq_group.generator) @@ -643,19 +660,32 @@ def _multinomial( def _top_k_top_p_multinomial_with_flashinfer( - probs: torch.Tensor, - top_ks: torch.Tensor, - top_ps: torch.Tensor, - num_samples: int, - seq_groups: Optional[List[SequenceGroupToSample]], + probs: torch.Tensor, + top_ks: torch.Tensor, + top_ps: torch.Tensor, + num_samples: int, + seq_groups: Optional[List[SequenceGroupToSample]], ): + """ + + Args: + probs: + top_ks: + top_ps: + num_samples: + seq_groups: + + Returns: + + """ max_top_k_round = 32 if num_samples > 1: probs = probs.repeat_interleave(num_samples, dim=0) top_ks = top_ks.repeat_interleave(num_samples) top_ps = top_ps.repeat_interleave(num_samples) batch_size = probs.shape[0] - uniform_samples = torch.empty((max_top_k_round, batch_size), device=probs.device) + uniform_samples = torch.empty( + (max_top_k_round, batch_size), device=probs.device) if seq_groups is None: uniform_samples.uniform_() else: @@ -664,7 +694,7 @@ def _top_k_top_p_multinomial_with_flashinfer( seq_ids = seq_group.seq_ids stride = len(seq_ids) * num_samples assert seq_group.generator is not None - uniform_samples[:, sample_idx : sample_idx + stride].uniform_( + uniform_samples[:, sample_idx: sample_idx + stride].uniform_( generator=seq_group.generator ) sample_idx += stride @@ -675,7 +705,9 @@ def _top_k_top_p_multinomial_with_flashinfer( top_ps, ) if not success.all(): - warnings.warn("FlashInfer rejection sampling failed, fallback.", stacklevel=1) + warnings.warn( + "FlashInfer rejection sampling failed, fallback.", + stacklevel=1) probs = flashinfer.sampling.top_k_renorm_prob(probs, top_ks) probs = flashinfer.sampling.top_p_renorm_prob(probs, top_ps) batch_next_token_ids = flashinfer.sampling.sampling_from_probs( @@ -685,7 +717,7 @@ def _top_k_top_p_multinomial_with_flashinfer( def get_pythonized_sample_results( - sample_result_args: SampleResultArgsType, + sample_result_args: SampleResultArgsType, ) -> SampleResultType: """This function consumes GPU-side sampler results and computes Pythonized CPU-side sampler results (GPU -> CPU sync.) 
@@ -730,7 +762,8 @@ def get_pythonized_sample_results( seq_groups, multinomial_samples[sampling_type] ) elif sampling_type == SamplingType.BEAM: - sample_results = _beam_search_sample(seq_groups, beam_search_logprobs) + sample_results = _beam_search_sample( + seq_groups, beam_search_logprobs) sample_results_dict.update(zip(seq_group_id, sample_results)) return [ @@ -740,12 +773,12 @@ def get_pythonized_sample_results( def _sample_with_torch( - probs: torch.Tensor, - logprobs: torch.Tensor, - sampling_metadata: SamplingMetadata, - sampling_tensors: SamplingTensors, - include_gpu_probs_tensor: bool, - modify_greedy_probs: bool, + probs: torch.Tensor, + logprobs: torch.Tensor, + sampling_metadata: SamplingMetadata, + sampling_tensors: SamplingTensors, + include_gpu_probs_tensor: bool, + modify_greedy_probs: bool, ) -> SampleReturnType: """Torch-oriented _sample() implementation. @@ -798,7 +831,8 @@ def _sample_with_torch( sample_metadata[sampling_type] = (seq_group_id, seq_groups) long_sample_indices = sample_indices.long() if sampling_type == SamplingType.GREEDY: - greedy_samples = torch.argmax(logprobs[long_sample_indices], dim=-1) + greedy_samples = torch.argmax( + logprobs[long_sample_indices], dim=-1) if sampled_token_ids_tensor is not None: # Store sampled tokens in output tensor. @@ -871,22 +905,22 @@ def _sample_with_torch( get_pythonized_sample_results(maybe_deferred_args), sampled_token_ids_tensor, ) - else: - # Defer sampler result Pythonization; return deferred - # Pythonization args & sampled token ids - return ( - maybe_deferred_args, - sampled_token_ids_tensor, - ) + + # Defer sampler result Pythonization; return deferred + # Pythonization args & sampled token ids + return ( + maybe_deferred_args, + sampled_token_ids_tensor, + ) def _sample( - probs: torch.Tensor, - logprobs: torch.Tensor, - sampling_metadata: SamplingMetadata, - sampling_tensors: SamplingTensors, - include_gpu_probs_tensor: bool, - modify_greedy_probs: bool, + probs: torch.Tensor, + logprobs: torch.Tensor, + sampling_metadata: SamplingMetadata, + sampling_tensors: SamplingTensors, + include_gpu_probs_tensor: bool, + modify_greedy_probs: bool, ) -> SampleReturnType: """ Args: @@ -924,16 +958,17 @@ def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: Each element in the returned tensor represents the rank of the chosen token in the input logprob tensor. """ - vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype), indices] + vals = x[torch.arange(0, len(x), device=x.device, + dtype=indices.dtype), indices] result = x > vals[:, None] del vals return result.sum(1).add_(1) def get_logprobs( - logprobs: torch.Tensor, - sampling_metadata: SamplingMetadata, - sample_results: SampleResultType, + logprobs: torch.Tensor, + sampling_metadata: SamplingMetadata, + sample_results: SampleResultType, ) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]: """Return sample logprobs and prompt logprobs. @@ -973,7 +1008,8 @@ def get_logprobs( # Select indices to compute logprob from, ranks of token ids, and the top # k token ids from logprobs. - for seq_group, sample_result in zip(sampling_metadata.seq_groups, sample_results): + for seq_group, sample_result in zip( + sampling_metadata.seq_groups, sample_results): sampling_params = seq_group.sampling_params # Update indices and tokens for prompt logprobs. 
@@ -1005,7 +1041,7 @@ def get_logprobs( assert len(next_token_ids) == len(query_indices) - if len(query_indices) == 0: + if not query_indices: empty_sampled_logprob: SampleLogprobs = [] empty_prompt_logprob: Optional[PromptLogprobs] = None return [empty_prompt_logprob], [empty_sampled_logprob] @@ -1017,7 +1053,8 @@ def get_logprobs( # skip the whole logprob calculation. if largest_num_logprobs >= 0: query_indices_gpu = torch.tensor(query_indices, device=logprobs.device) - next_token_ids_gpu = torch.tensor(next_token_ids, device=logprobs.device) + next_token_ids_gpu = torch.tensor( + next_token_ids, device=logprobs.device) # (num_selected_query_tokens, num_logprobs). Note that query_indices can # contain duplicates if beam search is enabled. @@ -1052,7 +1089,8 @@ def get_logprobs( top_logprob_idx = 0 selected_logprobs_idx = 0 - for seq_group, sample_result in zip(sampling_metadata.seq_groups, sample_results): + for seq_group, sample_result in zip( + sampling_metadata.seq_groups, sample_results): (prompt_logprobs, top_logprob_idx, selected_logprobs_idx) = ( _get_prompt_logprob_if_needed( seq_group, @@ -1084,13 +1122,13 @@ def get_logprobs( def _get_prompt_logprob_if_needed( - seq_group: SequenceGroupToSample, - selected_logprobs: torch.Tensor, - ranks: torch.Tensor, - top_token_ids: torch.Tensor, - top_logprobs: torch.Tensor, - selected_logprobs_idx: int, - top_logprob_idx: int, + seq_group: SequenceGroupToSample, + selected_logprobs: torch.Tensor, + ranks: torch.Tensor, + top_token_ids: torch.Tensor, + top_logprobs: torch.Tensor, + selected_logprobs_idx: int, + top_logprob_idx: int, ): """Compute the prompt logprob from a sequence group if needed.""" sampling_params = seq_group.sampling_params @@ -1105,11 +1143,10 @@ def _get_prompt_logprob_if_needed( # Pre-select indexes and create a list. It is faster than calling .item # repetitively. selected_logprob_items = selected_logprobs[ - selected_logprobs_idx : selected_logprobs_idx + len(next_prompt_tokens) - ].tolist() - rank_items = ranks[ - selected_logprobs_idx : selected_logprobs_idx + len(next_prompt_tokens) - ].tolist() + selected_logprobs_idx: selected_logprobs_idx + len(next_prompt_tokens) + ].tolist() + rank_items = ranks[selected_logprobs_idx: selected_logprobs_idx + + len(next_prompt_tokens)].tolist() for idx, token_id in enumerate(next_prompt_tokens): # Calculate the prompt logprob of the real prompt tokens. @@ -1120,8 +1157,10 @@ def _get_prompt_logprob_if_needed( # Add top K prompt logprobs along with its rank. if num_logprobs > 0: - top_ids = top_token_ids[top_logprob_idx, :num_logprobs].tolist() - top_probs = top_logprobs[top_logprob_idx, :num_logprobs].tolist() + top_ids = top_token_ids[top_logprob_idx, + :num_logprobs].tolist() + top_probs = top_logprobs[top_logprob_idx, + :num_logprobs].tolist() # Top K is already sorted by rank, so we can use 1 ~ # num_logprobs + 1 for rank. 
top_ranks = range(1, num_logprobs + 1) @@ -1146,14 +1185,14 @@ def _get_prompt_logprob_if_needed( def _get_sampled_logprob_if_needed( - seq_group: SequenceGroupToSample, - sample_result: Tuple[List[int], List[int]], - selected_logprobs: torch.Tensor, - ranks: torch.Tensor, - top_token_ids: torch.Tensor, - top_logprobs: torch.Tensor, - selected_logprobs_idx: int, - top_logprob_idx: int, + seq_group: SequenceGroupToSample, + sample_result: Tuple[List[int], List[int]], + selected_logprobs: torch.Tensor, + ranks: torch.Tensor, + top_token_ids: torch.Tensor, + top_logprobs: torch.Tensor, + selected_logprobs_idx: int, + top_logprob_idx: int, ): """Compute the sample logprob if needed.""" seq_ids = seq_group.seq_ids @@ -1162,7 +1201,7 @@ def _get_sampled_logprob_if_needed( next_token_ids, parent_seq_ids = sample_result if seq_group.do_sample: - assert len(next_token_ids) > 0 + assert next_token_ids if num_logprobs is None: for next_token_id in next_token_ids: # Use a dummy logprob @@ -1171,26 +1210,26 @@ def _get_sampled_logprob_if_needed( # Pre-select items from tensor. tolist() is faster than repetitive # `.item()` calls. selected_logprob_items = selected_logprobs[ - selected_logprobs_idx : selected_logprobs_idx + len(next_token_ids) - ].tolist() - rank_items = ranks[ - selected_logprobs_idx : selected_logprobs_idx + len(next_token_ids) - ].tolist() + selected_logprobs_idx: selected_logprobs_idx + len(next_token_ids) + ].tolist() + rank_items = ranks[selected_logprobs_idx: selected_logprobs_idx + + len(next_token_ids)].tolist() for idx, (next_token_id, parent_id) in enumerate( - zip(next_token_ids, parent_seq_ids) + zip(next_token_ids, parent_seq_ids) ): # Get the logprob of a sampled token. sampled_logprobs_dict = { - next_token_id: (selected_logprob_items[idx], rank_items[idx]) - } + next_token_id: ( + selected_logprob_items[idx], + rank_items[idx])} if num_logprobs is not None and num_logprobs > 0: # Get top K logprobs. top_ids = top_token_ids[ - top_logprob_idx + parent_id, :num_logprobs - ].tolist() + top_logprob_idx + parent_id, :num_logprobs + ].tolist() top_probs = top_logprobs[ - top_logprob_idx + parent_id, :num_logprobs - ].tolist() + top_logprob_idx + parent_id, :num_logprobs + ].tolist() # Top K is already sorted by rank, so we can use 1 ~ # num_logprobs + 1 for rank. top_ranks = range(1, num_logprobs + 1) @@ -1198,8 +1237,8 @@ def _get_sampled_logprob_if_needed( { top_id: (top_prob, rank) for top_id, top_prob, rank in zip( - top_ids, top_probs, top_ranks - ) + top_ids, top_probs, top_ranks + ) } ) @@ -1223,10 +1262,10 @@ def _get_sampled_logprob_if_needed( def _modify_greedy_probs_inplace( - logprobs: torch.Tensor, - probs: torch.Tensor, - sample_indices: torch.Tensor, - greedy_samples: torch.Tensor, + _logprobs: torch.Tensor, + probs: torch.Tensor, + sample_indices: torch.Tensor, + greedy_samples: torch.Tensor, ) -> None: """Modify the probability distributions of the greedily-sampled tokens such that each sampled token has a "probability" of 1.0. 
This is required by @@ -1276,12 +1315,12 @@ def _modify_greedy_probs_inplace( def _build_sampler_output( - maybe_deferred_sample_results: MaybeDeferredSampleResultType, - sampling_metadata: SamplingMetadata, - prompt_logprobs: Optional[List[Optional[PromptLogprobs]]], - sample_logprobs: Optional[List[SampleLogprobs]], - on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]], - skip_sampler_cpu_output: bool = False, + maybe_deferred_sample_results: MaybeDeferredSampleResultType, + sampling_metadata: SamplingMetadata, + prompt_logprobs: Optional[List[Optional[PromptLogprobs]]], + sample_logprobs: Optional[List[SampleLogprobs]], + on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]], + skip_sampler_cpu_output: bool = False, ) -> SamplerOutput: """Construct Python objects with the output of sampling. @@ -1299,14 +1338,16 @@ def _build_sampler_output( else: assert prompt_logprobs is not None assert sample_logprobs is not None - assert not isinstance(maybe_deferred_sample_results, SampleResultArgsType) + assert not isinstance( + maybe_deferred_sample_results, + SampleResultArgsType) deferred_sample_results_args = None for ( - seq_group, - sample_result, - group_prompt_logprobs, - group_sample_logprobs, + seq_group, + sample_result, + group_prompt_logprobs, + group_sample_logprobs, ) in zip( sampling_metadata.seq_groups, maybe_deferred_sample_results, @@ -1317,20 +1358,22 @@ def _build_sampler_output( next_token_ids, parent_ids = sample_result seq_outputs: List[SequenceOutput] = [] for parent_id, next_token_id, logprobs in zip( - parent_ids, next_token_ids, group_sample_logprobs + parent_ids, next_token_ids, group_sample_logprobs ): seq_outputs.append( SequenceOutput(seq_ids[parent_id], next_token_id, logprobs) ) sampler_output.append( - CompletionSequenceGroupOutput(seq_outputs, group_prompt_logprobs) - ) + CompletionSequenceGroupOutput( + seq_outputs, group_prompt_logprobs)) # If not specified, store None values in SamplerOutput. if on_device_tensors is not None: - (sampled_token_probs, logprobs_tensor, sampled_token_ids) = on_device_tensors + (sampled_token_probs, logprobs_tensor, + sampled_token_ids) = on_device_tensors else: - sampled_token_probs, logprobs_tensor, sampled_token_ids = (None, None, None) + sampled_token_probs, logprobs_tensor, sampled_token_ids = ( + None, None, None) return SamplerOutput( outputs=sampler_output, @@ -1368,6 +1411,8 @@ def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]: prompt_tokens = seq_data.prompt_token_ids # +1 because we are looking for a next prompt token. 
next_token_index_start = computed_len + 1 - next_token_index_end = min(computed_len + query_len + 1, len(prompt_tokens)) + next_token_index_end = min( + computed_len + query_len + 1, + len(prompt_tokens)) next_prompt_tokens = prompt_tokens[next_token_index_start:next_token_index_end] return next_prompt_tokens diff --git a/vllm_mindspore/model_executor/layers/utils.py b/vllm_mindspore/model_executor/layers/utils.py index eedfaa12b0e84176c7e0b445291cadc82a6f2dfd..f7b5065e595b3a2ff96bbd7b44ad0bb1b2e51df3 100644 --- a/vllm_mindspore/model_executor/layers/utils.py +++ b/vllm_mindspore/model_executor/layers/utils.py @@ -1,35 +1,49 @@ """Utility methods for model layers.""" + from typing import Tuple + import torch + def get_token_bin_counts_and_mask( - tokens: torch.Tensor, - vocab_size: int, - num_seqs: int, + tokens: torch.Tensor, + vocab_size: int, + num_seqs: int, ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + + :param tokens: + :param vocab_size: + :param num_seqs: + :return: + """ # Compute the bin counts for the tokens. # vocab_size + 1 for padding. - bin_counts = torch.zeros((num_seqs, vocab_size + 1), - dtype=torch.long, - device=tokens.device) + bin_counts = torch.zeros( + (num_seqs, vocab_size + 1), dtype=torch.long, device=tokens.device + ) bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens)) bin_counts = bin_counts[:, :vocab_size] mask = bin_counts > 0 return bin_counts, mask -def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, - output_tokens_tensor: torch.Tensor, - presence_penalties: torch.Tensor, - frequency_penalties: torch.Tensor, - repetition_penalties: torch.Tensor) -> torch.Tensor: + +def apply_penalties( + logits: torch.Tensor, + prompt_tokens_tensor: torch.Tensor, + output_tokens_tensor: torch.Tensor, + presence_penalties: torch.Tensor, + frequency_penalties: torch.Tensor, + repetition_penalties: torch.Tensor, +) -> torch.Tensor: """ Applies penalties in place to the logits tensor logits : The input logits tensor of shape [num_seqs, vocab_size] - prompt_tokens_tensor: A tensor containing the prompt tokens. The prompts - are padded to the maximum prompt length within the batch using - `vocab_size` as the padding value. The value `vocab_size` is used - for padding because it does not correspond to any valid token ID + prompt_tokens_tensor: A tensor containing the prompt tokens. The prompts + are padded to the maximum prompt length within the batch using + `vocab_size` as the padding value. The value `vocab_size` is used + for padding because it does not correspond to any valid token ID in the vocabulary. output_tokens_tensor: The output tokens tensor. 
presence_penalties: The presence penalties of shape (num_seqs, ) @@ -37,20 +51,25 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, repetition_penalties: The repetition penalties of shape (num_seqs, ) """ num_seqs, vocab_size = logits.shape - _, prompt_mask = get_token_bin_counts_and_mask(prompt_tokens_tensor, - vocab_size, num_seqs) + _, prompt_mask = get_token_bin_counts_and_mask( + prompt_tokens_tensor, vocab_size, num_seqs + ) output_bin_counts, output_mask = get_token_bin_counts_and_mask( - output_tokens_tensor, vocab_size, num_seqs) + output_tokens_tensor, vocab_size, num_seqs + ) # repetition_penalties = repetition_penalties.unsqueeze_(dim=1).repeat( # 1, vocab_size) - repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat(1, vocab_size) - logits[logits > 0] /= torch.where(prompt_mask | output_mask, - repetition_penalties, 1.0)[logits > 0] - logits[logits <= 0] *= torch.where(prompt_mask | output_mask, - repetition_penalties, 1.0)[logits <= 0] + repetition_penalties = repetition_penalties.unsqueeze( + dim=1).repeat( + 1, vocab_size) + logits[logits > 0] /= torch.where( + prompt_mask | output_mask, repetition_penalties, 1.0 + )[logits > 0] + logits[logits <= 0] *= torch.where( + prompt_mask | output_mask, repetition_penalties, 1.0 + )[logits <= 0] # We follow the definition in OpenAI API. # Refer to https://platform.openai.com/docs/api-reference/parameter-details logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts logits -= presence_penalties.unsqueeze(dim=1) * output_mask return logits - diff --git a/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py b/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py index f45064e194c34e7252cc070e680f7fe57e746560..536e3c176ec5f14a02d4b4e64cfa9161fbafc457 100644 --- a/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py @@ -18,23 +18,23 @@ from dataclasses import dataclass from typing import List, Optional, Sequence, Tuple -from mindspore import Parameter, Tensor, mint, nn, ops +from mindspore import Parameter, Tensor, jit, mint, nn, ops from mindspore.common import dtype as mstype from mindspore.common.dtype import typing from vllm.distributed import (divide, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce,) -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) + get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.quantization.base_config import \ + QuantizationConfig +from vllm_mindspore.distributed.communication_op import \ + ReduceFromModelParallelRegion from vllm_mindspore.model_executor.layers.quantization.base_config import ( QuantizeMethodBase, method_has_implemented_embedding) from vllm_mindspore.model_executor.utils import set_weight_attrs -from vllm_mindspore.distributed.communication_op import ReduceFromModelParallelRegion -from mindspore import jit DEFAULT_VOCAB_PADDING_SIZE = 64 + # TODO(tronzhang): Most same as vllm's one, check latter... 
@@ -79,12 +79,12 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase): def get_masked_input_and_mask( - input_: Tensor, - org_vocab_start_index: int, - org_vocab_end_index: int, - num_org_vocab_padding: int, - added_vocab_start_index: int, - added_vocab_end_index: int, + input_: Tensor, + org_vocab_start_index: int, + org_vocab_end_index: int, + num_org_vocab_padding: int, + added_vocab_start_index: int, + added_vocab_end_index: int, ) -> Tuple[Tensor, Tensor]: displaced_x = mint.sub(input_, org_vocab_start_index) down_truncated_x = mint.nn.functional.relu(displaced_x) @@ -96,7 +96,7 @@ def get_masked_input_and_mask( truncated_x = mint.minimum(down_truncated_x, added_vocab_end_index) added_vocab_mask = mint.eq(displaced_x, truncated_x) added_offset = added_vocab_start_index - ( - org_vocab_end_index - org_vocab_start_index) - num_org_vocab_padding + org_vocab_end_index - org_vocab_start_index) - num_org_vocab_padding valid_offset = (org_vocab_start_index * org_vocab_mask) + (added_offset * added_vocab_mask) vocab_mask = mint.logical_or(org_vocab_mask, added_vocab_mask) @@ -104,13 +104,15 @@ def get_masked_input_and_mask( return input_, vocab_mask.expand_dims(-1) -def pad_vocab_size(vocab_size: int, pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int: +def pad_vocab_size( + vocab_size: int, + pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int: """Pad the vocab size to the given value.""" return ((vocab_size + pad_to - 1) // pad_to) * pad_to def vocab_range_from_per_partition_vocab_size( - per_partition_vocab_size: int, rank: int, offset: int = 0 + per_partition_vocab_size: int, rank: int, offset: int = 0 ) -> Sequence[int]: index_f = rank * per_partition_vocab_size index_l = index_f + per_partition_vocab_size @@ -118,7 +120,7 @@ def vocab_range_from_per_partition_vocab_size( def vocab_range_from_global_vocab_size( - global_vocab_size: int, rank: int, world_size: int, offset: int = 0 + global_vocab_size: int, rank: int, world_size: int, offset: int = 0 ) -> Sequence[int]: per_partition_vocab_size = divide(global_vocab_size, world_size) return vocab_range_from_per_partition_vocab_size( @@ -187,14 +189,14 @@ class VocabParallelEmbeddingShardIndices: class VocabParallelEmbedding(nn.Cell): def __init__( - self, - num_embeddings: int, - embedding_dim: int, - params_dtype: Optional[typing.Type] = None, - org_num_embeddings: Optional[int] = None, - padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", + self, + num_embeddings: int, + embedding_dim: int, + params_dtype: Optional[typing.Type] = None, + org_num_embeddings: Optional[int] = None, + padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() # Keep the input dimensions. 
@@ -208,8 +210,7 @@ class VocabParallelEmbedding(nn.Cell): self.org_vocab_size, self.padding_size ) self.num_embeddings_padded = pad_vocab_size( - self.org_vocab_size_padded + num_added_embeddings, self.padding_size - ) + self.org_vocab_size_padded + num_added_embeddings, self.padding_size) assert self.org_vocab_size_padded <= self.num_embeddings_padded self.shard_indices = self._get_indices( @@ -251,16 +252,15 @@ class VocabParallelEmbedding(nn.Cell): self.num_embeddings_per_partition = divide( self.num_embeddings_padded, self.tp_size ) - assert ( - self.shard_indices.num_elements_padded == self.num_embeddings_per_partition - ) + assert (self.shard_indices.num_elements_padded == + self.num_embeddings_per_partition) self.num_org_embeddings_per_partition = ( - self.shard_indices.org_vocab_end_index - - self.shard_indices.org_vocab_start_index + self.shard_indices.org_vocab_end_index + - self.shard_indices.org_vocab_start_index ) self.num_added_embeddings_per_partition = ( - self.shard_indices.added_vocab_end_index - - self.shard_indices.added_vocab_start_index + self.shard_indices.added_vocab_end_index + - self.shard_indices.added_vocab_start_index ) self.linear_method.create_weights( @@ -276,13 +276,13 @@ class VocabParallelEmbedding(nn.Cell): @classmethod def _get_indices( - cls, - vocab_size_padded: int, - org_vocab_size_padded: int, - vocab_size: int, - org_vocab_size: int, - tp_rank: int, - tp_size: int, + cls, + vocab_size_padded: int, + org_vocab_size_padded: int, + vocab_size: int, + org_vocab_size: int, + tp_rank: int, + tp_size: int, ) -> VocabParallelEmbeddingShardIndices: """Get start and end indices for vocab parallel embedding, following the layout outlined in the class docstring, based on the given tp_rank and @@ -293,13 +293,13 @@ class VocabParallelEmbedding(nn.Cell): ) padded_added_vocab_start_index, padded_added_vocab_end_index = ( vocab_range_from_global_vocab_size( - num_added_embeddings_padded, tp_rank, tp_size, offset=org_vocab_size - ) - ) + num_added_embeddings_padded, tp_rank, tp_size, offset=org_vocab_size)) # remove padding - org_vocab_start_index = min(padded_org_vocab_start_index, org_vocab_size) + org_vocab_start_index = min( + padded_org_vocab_start_index, org_vocab_size) org_vocab_end_index = min(padded_org_vocab_end_index, org_vocab_size) - added_vocab_start_index = min(padded_added_vocab_start_index, vocab_size) + added_vocab_start_index = min( + padded_added_vocab_start_index, vocab_size) added_vocab_end_index = min(padded_added_vocab_end_index, vocab_size) return VocabParallelEmbeddingShardIndices( padded_org_vocab_start_index, @@ -344,8 +344,9 @@ class VocabParallelEmbedding(nn.Cell): assert param.data.shape == loaded_weight.shape if param.data.shape != loaded_weight.shape: raise ValueError( - f"'param.data.shape' should be equal to 'loaded_weight.shape'," - f" but got {param.data.shape} and {loaded_weight.shape}") + f"'param.data.shape' should be equal to 'loaded_weight.shape'," f" but got { + param.data.shape} and { + loaded_weight.shape}") param.set_data(loaded_weight) return @@ -380,15 +381,15 @@ class ParallelLMHead(VocabParallelEmbedding): """ def __init__( - self, - num_embeddings: int, - embedding_dim: int, - bias: bool = False, - params_dtype=None, - org_num_embeddings: Optional[int] = None, - padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, - quant_config=None, - prefix: str = "", + self, + num_embeddings: int, + embedding_dim: int, + bias: bool = False, + params_dtype=None, + org_num_embeddings: Optional[int] = None, + padding_size: int = 
DEFAULT_VOCAB_PADDING_SIZE, + quant_config=None, + prefix: str = "", ): super().__init__( num_embeddings, @@ -402,8 +403,9 @@ class ParallelLMHead(VocabParallelEmbedding): self.quant_config = quant_config if bias: self.bias = Parameter( - mint.zeros(self.num_embeddings_per_partition, dtype=params_dtype) - ) + mint.zeros( + self.num_embeddings_per_partition, + dtype=params_dtype)) set_weight_attrs( self.bias, { diff --git a/vllm_mindspore/model_executor/model_loader/utils.py b/vllm_mindspore/model_executor/model_loader/utils.py index c94e3150b96842bd9b598879d05fc338a0afee9b..ff9eb8f629caa5b059191caf2e899e67d7f6a185 100644 --- a/vllm_mindspore/model_executor/model_loader/utils.py +++ b/vllm_mindspore/model_executor/model_loader/utils.py @@ -18,21 +18,22 @@ from typing import Tuple, Type from torch import nn - from vllm.config import ModelConfig -from vllm_mindspore.model_executor.models.registry import MindSporeModelRegistry +from vllm_mindspore.model_executor.models.registry import \ + MindSporeModelRegistry -def get_ms_model_architecture(model_config: ModelConfig) -> Tuple[Type[nn.Module], str]: +def get_ms_model_architecture( + model_config: ModelConfig) -> Tuple[Type[nn.Module], str]: architectures = getattr(model_config.hf_config, "architectures", []) model_cls, arch = MindSporeModelRegistry.resolve_model_cls(architectures) if model_config.task == "embed": - raise RecursionError("MindSpore unsupport embed model task now!") + raise RecursionError("MindSpore unsupported embed model task now!") elif model_config.task == "classify": - raise RecursionError("MindSpore unsupport classify model task now!") + raise RecursionError("MindSpore unsupported classify model task now!") elif model_config.task == "reward": - raise RecursionError("MindSpore unsupport reward model task now!") + raise RecursionError("MindSpore unsupported reward model task now!") return model_cls, arch diff --git a/vllm_mindspore/model_executor/model_loader/weight_utils.py b/vllm_mindspore/model_executor/model_loader/weight_utils.py index ead186957a59adf97789be023e2af5e5a19925bc..e124918875c179daf22e654325e576e06aaf0735 100644 --- a/vllm_mindspore/model_executor/model_loader/weight_utils.py +++ b/vllm_mindspore/model_executor/model_loader/weight_utils.py @@ -15,30 +15,28 @@ # limitations under the License. 
# ============================================================================ -from tqdm.auto import tqdm from typing import Generator, List, Tuple -import torch - import mindspore as ms +import torch from mindspore import Parameter, Tensor +from tqdm.auto import tqdm def safetensors_weights_iterator( - hf_weights_files: List[str], + hf_weights_files: List[str], ) -> Generator[Tuple[str, torch.Tensor], None, None]: """Iterate over the weights in the model safetensor files.""" from safetensors import safe_open from vllm.model_executor.model_loader.weight_utils import _BAR_FORMAT - enable_tqdm = ( - not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0 - ) + enable_tqdm = (not torch.distributed.is_initialized() + or torch.distributed.get_rank() == 0) for st_file in tqdm( - hf_weights_files, - desc="Loading safetensors checkpoint shards", - disable=not enable_tqdm, - bar_format=_BAR_FORMAT, + hf_weights_files, + desc="Loading safetensors checkpoint shards", + disable=not enable_tqdm, + bar_format=_BAR_FORMAT, ): with safe_open(st_file, framework="np") as f: for name in f.keys(): diff --git a/vllm_mindspore/model_executor/models/interfaces.py b/vllm_mindspore/model_executor/models/interfaces.py index 0b1510d973dc911b5fe85d0d4980648d9e605c18..132be356c707afe05792dada53ecec8d6e2a0ef4 100644 --- a/vllm_mindspore/model_executor/models/interfaces.py +++ b/vllm_mindspore/model_executor/models/interfaces.py @@ -1,5 +1,9 @@ -from typing import (TYPE_CHECKING, ClassVar, Dict, List, Literal, Optional, - Protocol, Type, Union, overload, runtime_checkable) +""" +interface +""" + +from typing import ClassVar, Dict, List, Literal, Protocol, runtime_checkable + @runtime_checkable class SupportsLoRA(Protocol): @@ -19,6 +23,7 @@ class SupportsLoRA(Protocol): embedding_modules: ClassVar[Dict[str, str]] embedding_padding_modules: ClassVar[List[str]] + @runtime_checkable class _SupportsLoRAType(Protocol): supports_lora: Literal[True] @@ -26,4 +31,4 @@ class _SupportsLoRAType(Protocol): packed_modules_mapping: Dict[str, List[str]] supported_lora_modules: List[str] embedding_modules: Dict[str, str] - embedding_padding_modules: List[str] \ No newline at end of file + embedding_padding_modules: List[str] diff --git a/vllm_mindspore/model_executor/models/llama.py b/vllm_mindspore/model_executor/models/llama.py index c20f54fe8571b1879846f393a438fb047f67cd52..5f83e7458d0e33a5e4d246726254ad79fad2e254 100644 --- a/vllm_mindspore/model_executor/models/llama.py +++ b/vllm_mindspore/model_executor/models/llama.py @@ -15,48 +15,38 @@ # limitations under the License. 
# ============================================================================ -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union +from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, + Tuple, Type, Union) if TYPE_CHECKING: from transformers import LlamaConfig else: LlamaConfig = None +from mindspore import Tensor +from mindspore import dtype as mstype +from mindspore import jit, mint, nn from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.models.interfaces import SupportsPP +from vllm.sequence import IntermediateTensors -from vllm_mindspore.model_executor.layers.linear import ( - MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear, -) -from vllm_mindspore.model_executor.layers.logits_processor import LogitsProcessor from vllm_mindspore.attention import Attention from vllm_mindspore.model_executor.layers.activation import SiluAndMul -from vllm_mindspore.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, - ParallelLMHead, - VocabParallelEmbedding, -) -from vllm_mindspore.model_executor.models.utils import ( - PPMissingLayer, - extract_layer_index, - make_layers, - maybe_prefix, - make_empty_intermediate_tensors_factory, -) -from vllm_mindspore.model_executor.layers.sampler import get_sampler, SamplerOutput from vllm_mindspore.model_executor.layers.layernorm import RMSNorm +from vllm_mindspore.model_executor.layers.linear import ( + MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) +from vllm_mindspore.model_executor.layers.logits_processor import \ + LogitsProcessor from vllm_mindspore.model_executor.layers.rotary_embedding import get_rope -from vllm_mindspore.model_executor.sampling_metadata import SamplingMetadata - +from vllm_mindspore.model_executor.layers.sampler import (SamplerOutput, + get_sampler) +from vllm_mindspore.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm_mindspore.model_executor.models.model_base import MsModelBase - -from vllm.sequence import IntermediateTensors -from vllm.attention import AttentionMetadata -from vllm.model_executor.models.interfaces import SupportsPP - -from mindspore import Tensor, mint, jit, nn -from mindspore import dtype as mstype +from vllm_mindspore.model_executor.models.utils import ( + PPMissingLayer, extract_layer_index, + make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) +from vllm_mindspore.model_executor.sampling_metadata import SamplingMetadata def default_weight_loader(param, loaded_weight) -> None: @@ -65,13 +55,13 @@ def default_weight_loader(param, loaded_weight) -> None: class LlamaMLP(nn.Cell): def __init__( - self, - hidden_size: int, - intermediate_size: int, - hidden_act: str, - quant_config=None, - bias: bool = False, - prefix: str = "", + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config=None, + bias: bool = False, + prefix: str = "", ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( @@ -105,18 +95,18 @@ class LlamaMLP(nn.Cell): class LlamaAttention(nn.Cell): def __init__( - self, - config: LlamaConfig, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - rope_theta: float = 10000, - rope_scaling: Optional[Dict[str, Any]] = None, - max_position_embeddings: int = 8192, - quant_config=None, - bias: bool = False, - cache_config=None, - prefix: str = "", + self, + config: 
LlamaConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + quant_config=None, + bias: bool = False, + cache_config=None, + prefix: str = "", ) -> None: super().__init__() layer_idx = extract_layer_index(prefix) @@ -141,7 +131,7 @@ class LlamaAttention(nn.Cell): ) self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim - self.scaling = self.head_dim**-0.5 + self.scaling = self.head_dim ** -0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings @@ -196,50 +186,69 @@ class LlamaAttention(nn.Cell): per_layer_sliding_window=sliding_window, prefix=f"{prefix}.attn", ) - self.attn_mask = mint.triu(mint.ones(size=(128, 128), dtype=mstype.float16), 1) * -10000.0 + self.attn_mask = mint.triu( + mint.ones( + size=( + 128, + 128), + dtype=mstype.float16), + 1) * -10000.0 @jit def construct( - self, - positions: Tensor, - hidden_states: Tensor, - kv_cache: Tuple[Tensor, Tensor], - # attn_metadata: AttentionMetadata, - num_prefill_tokens: int, - num_decode_tokens: int, - slot_mapping: Tensor, - batch_valid_length: Tuple[int], - context_lens: Tensor, - block_tables: Tensor, + self, + positions: Tensor, + hidden_states: Tensor, + kv_cache: Tuple[Tensor, Tensor], + # attn_metadata: AttentionMetadata, + num_prefill_tokens: int, + num_decode_tokens: int, + slot_mapping: Tensor, + batch_valid_length: Tuple[int], + context_lens: Tensor, + block_tables: Tensor, ) -> Tensor: qkv, _ = self.qkv_proj(hidden_states) - q, k, v = mint.split(qkv, (self.q_size, self.kv_size, self.kv_size), -1) - q, k = self.rotary_emb(positions, q, k, context_lens, num_prefill_tokens) - attn_output = self.attn(q, k, v, kv_cache, num_prefill_tokens, num_decode_tokens, - slot_mapping, batch_valid_length, context_lens, block_tables, self.attn_mask) + q, k, v = mint.split( + qkv, (self.q_size, self.kv_size, self.kv_size), -1) + q, k = self.rotary_emb( + positions, q, k, context_lens, num_prefill_tokens) + attn_output = self.attn( + q, + k, + v, + kv_cache, + num_prefill_tokens, + num_decode_tokens, + slot_mapping, + batch_valid_length, + context_lens, + block_tables, + self.attn_mask) output, _ = self.o_proj(attn_output) return output class LlamaDecoderLayer(nn.Cell): def __init__( - self, - config: LlamaConfig, - cache_config=None, - quant_config=None, - prefix: str = "", + self, + config: LlamaConfig, + cache_config=None, + quant_config=None, + prefix: str = "", ): super().__init__() self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) if rope_scaling is not None and getattr( - config, "original_max_position_embeddings", None + config, "original_max_position_embeddings", None ): rope_scaling["original_max_position_embeddings"] = ( config.original_max_position_embeddings ) - max_position_embeddings = getattr(config, "max_position_embeddings", 8192) + max_position_embeddings = getattr( + config, "max_position_embeddings", 8192) # Support abacusai/Smaug-72B-v0.1 with attention_bias # Support internlm/internlm-7b with bias attention_bias = getattr(config, "attention_bias", False) or getattr( @@ -268,32 +277,34 @@ class LlamaDecoderLayer(nn.Cell): bias=getattr(config, "mlp_bias", False), prefix=f"{prefix}.mlp", ) - self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm( + config.hidden_size, 
eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm( config.hidden_size, eps=config.rms_norm_eps ) @jit def construct( - self, - positions: Tensor, - hidden_states: Tensor, - kv_cache: Tuple[Tensor, Tensor], - # attn_metadata: AttentionMetadata, - num_prefill_tokens: int, - num_decode_tokens: int, - slot_mapping: Tensor, - batch_valid_length: Tuple[int], - context_lens: Tensor, - block_tables: Tensor, - residual: Optional[Tensor], + self, + positions: Tensor, + hidden_states: Tensor, + kv_cache: Tuple[Tensor, Tensor], + # attn_metadata: AttentionMetadata, + num_prefill_tokens: int, + num_decode_tokens: int, + slot_mapping: Tensor, + batch_valid_length: Tuple[int], + context_lens: Tensor, + block_tables: Tensor, + residual: Optional[Tensor], ) -> Tuple[Tensor, Tensor]: # Self Attention if residual is None: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm(hidden_states, residual) + hidden_states, residual = self.input_layernorm( + hidden_states, residual) hidden_states = self.self_attn( positions, @@ -308,7 +319,8 @@ class LlamaDecoderLayer(nn.Cell): ) # Fully Connected - hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual @@ -318,11 +330,11 @@ class LlamaModel(nn.Cell): SUPPORT_PP = False def __init__( - self, - *, - vllm_config, - prefix: str = "", - layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer, + self, + *, + vllm_config, + prefix: str = "", + layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer, ): super().__init__() config = vllm_config @@ -340,7 +352,7 @@ class LlamaModel(nn.Cell): cache_config = None if get_pp_group().is_first_rank or ( - config.tie_word_embeddings and get_pp_group().is_last_rank + config.tie_word_embeddings and get_pp_group().is_last_rank ): self.embed_tokens = VocabParallelEmbedding( self.vocab_size, @@ -368,27 +380,26 @@ class LlamaModel(nn.Cell): self.norm = PPMissingLayer() self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size - ) + ["hidden_states", "residual"], config.hidden_size) def get_input_embeddings(self, input_ids: Tensor) -> Tensor: return self.embed_tokens(input_ids) @jit def construct( - self, - input_ids: Optional[Tensor], - positions: Tensor, - kv_caches: List[Tuple[Tensor, Tensor]], - # attn_metadata: AttentionMetadata, - num_prefill_tokens: int, - num_decode_tokens: int, - slot_mapping: Tensor, - batch_valid_length: Tuple[int], - context_lens: Tensor, - block_tables: Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[Tensor] = None, + self, + input_ids: Optional[Tensor], + positions: Tensor, + kv_caches: List[Tuple[Tensor, Tensor]], + # attn_metadata: AttentionMetadata, + num_prefill_tokens: int, + num_decode_tokens: int, + slot_mapping: Tensor, + batch_valid_length: Tuple[int], + context_lens: Tensor, + block_tables: Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[Tensor] = None, ) -> Union[Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: if inputs_embeds is not None: @@ -516,13 +527,13 @@ class LlamaForCausalLM(MsModelBase, SupportsPP): self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) def forward( - self, - input_ids, - positions, - kv_caches, - 
attn_metadata, - intermediate_tensors=None, - inputs_embeds=None, + self, + input_ids, + positions, + kv_caches, + attn_metadata, + intermediate_tensors=None, + inputs_embeds=None, ): if attn_metadata.num_prefill_tokens > 0: input_ids = input_ids.expand_dims(0) @@ -545,15 +556,16 @@ class LlamaForCausalLM(MsModelBase, SupportsPP): self.model.load_weights(weights, params_dict) def sample( - self, logits: Tensor, sampling_metadata: SamplingMetadata + self, logits: Tensor, sampling_metadata: SamplingMetadata ) -> Optional[SamplerOutput]: next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def compute_logits( - self, - hidden_states: Tensor, - sampling_metadata: SamplingMetadata, + self, + hidden_states: Tensor, + sampling_metadata: SamplingMetadata, ) -> Optional[Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) + logits = self.logits_processor( + self.lm_head, hidden_states, sampling_metadata) return logits diff --git a/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py b/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py index e4b41792e6f5ba049a6f1106428c634704486c07..7f29c991b411cf9978afa86649e6544a13a8480d 100644 --- a/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py +++ b/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py @@ -17,39 +17,30 @@ import os from typing import Iterable, List, Optional, Set, Tuple, Union -from pathlib import Path +import mindspore as ms import numpy as np - +from mindformers.core.context import build_context +from mindformers.core.parallel_config import build_parallel_config +from mindformers.tools.register.config import MindFormerConfig +from mindformers.trainer.utils import transform_and_load_checkpoint +from mindspore import JitConfig, Model, Tensor +from research.deepseek3.deepseek3 import \ + DeepseekV3ForCausalLM as DeepseekV3ForCausalLM_MF +from research.deepseek3.deepseek3_config import \ + DeepseekV3Config as DeepseekV3Config_MF from vllm.attention import AttentionMetadata from vllm.config import VllmConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.logger import init_logger from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.logger import init_logger - - -from mindformers.tools.register.config import MindFormerConfig - -from mindformers.core.context import build_context -from mindformers.core.parallel_config import build_parallel_config -from mindformers.trainer.utils import transform_and_load_checkpoint -from research.deepseek3.deepseek3_config import ( - DeepseekV3Config as DeepseekV3Config_MF, -) -from research.deepseek3.deepseek3 import ( - DeepseekV3ForCausalLM as DeepseekV3ForCausalLM_MF, -) from vllm_mindspore.model_executor.layers.sampler import get_sampler from vllm_mindspore.model_executor.models.model_base import MsModelBase from vllm_mindspore.utils import cal_block_num -import mindspore as ms -from mindspore import Tensor, JitConfig, Model - - logger = init_logger(__name__) @@ -87,7 +78,8 @@ class DeepseekV3ForCausalLM(MsModelBase): vllm_config=vllm_config, prefix=prefix ) - self.mf_config = MindFormerConfig(os.getenv("MINDFORMERS_MODEL_CONFIG")) + self.mf_config = MindFormerConfig( + os.getenv("MINDFORMERS_MODEL_CONFIG")) build_context(self.mf_config, is_set_ms_ctx=False, is_init_ms=False) 
build_parallel_config(self.mf_config) self.mf_config.model.model_config.parallel_config = ( @@ -98,29 +90,38 @@ class DeepseekV3ForCausalLM(MsModelBase): ) self.mf_config.model.model_config.parallel_config.pipeline_stage = 1 - self.mf_model_config = DeepseekV3Config_MF(**self.mf_config.model.model_config) - self.mf_model_config.num_blocks = cal_block_num(self.cache_config, self.model_config, self.parallel_config) + self.mf_model_config = DeepseekV3Config_MF( + **self.mf_config.model.model_config) + self.mf_model_config.num_blocks = cal_block_num( + self.cache_config, self.model_config, self.parallel_config + ) self.mf_model_config.block_size = self.cache_config.block_size if self.mf_config.moe_config: self.mf_model_config.moe_config = self.mf_config.moe_config - # Initital network + # Initial network self.network = DeepseekV3ForCausalLM_MF(self.mf_model_config) # quant if self.mf_model_config.quantization_config: - from mindspore_gs.ptq.ptq import PTQ - from mindspore_gs.ptq.ptq_config import PTQMode, PTQConfig, OutliersSuppressionType, PrecisionRecovery, QuantGranularity - from mindspore_gs.common import BackendTarget from mindspore.common import dtype as msdtype - cfg = PTQConfig(mode=PTQMode.DEPLOY, - backend=BackendTarget.ASCEND, - weight_quant_dtype=msdtype.int8, - act_quant_dtype=msdtype.int8, - outliers_suppression=OutliersSuppressionType.NONE, - opname_blacklist=['lkv2kv', 'lm_head', '61'], - act_quant_granularity=QuantGranularity.PER_TENSOR, - weight_quant_granularity=QuantGranularity.PER_CHANNEL) + from mindspore_gs.common import BackendTarget + from mindspore_gs.ptq.ptq import PTQ + from mindspore_gs.ptq.ptq_config import (OutliersSuppressionType, + PrecisionRecovery, + PTQConfig, PTQMode, + QuantGranularity) + + cfg = PTQConfig( + mode=PTQMode.DEPLOY, + backend=BackendTarget.ASCEND, + weight_quant_dtype=msdtype.int8, + act_quant_dtype=msdtype.int8, + outliers_suppression=OutliersSuppressionType.NONE, + opname_blacklist=["lkv2kv", "lm_head", "61"], + act_quant_granularity=QuantGranularity.PER_TENSOR, + weight_quant_granularity=QuantGranularity.PER_CHANNEL, + ) ptq = PTQ(config=cfg) ptq.apply(self.network) ptq.convert(self.network) @@ -148,13 +149,13 @@ class DeepseekV3ForCausalLM(MsModelBase): self.mf_kvcaches_init = True def forward( - self, - input_ids: Tensor, - positions: Tensor, - kv_caches: List[Tensor], - attn_metadata: AttentionMetadata, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[Tensor] = None, + self, + input_ids: Tensor, + positions: Tensor, + kv_caches: List[Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[Tensor] = None, ) -> Union[Tensor, IntermediateTensors]: self.update_mf_kvcaches(kv_caches) @@ -186,16 +187,17 @@ class DeepseekV3ForCausalLM(MsModelBase): return None def compute_logits( - self, - hidden_states: Tensor, - sampling_metadata: SamplingMetadata, + self, + hidden_states: Tensor, + sampling_metadata: SamplingMetadata, ) -> Optional[Tensor]: + return self.logits def sample( - self, - logits: Tensor, - sampling_metadata: SamplingMetadata, + self, + logits: Tensor, + sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: next_tokens = self.sampler(logits, sampling_metadata) return next_tokens diff --git a/vllm_mindspore/model_executor/models/mf_models/qwen2.py b/vllm_mindspore/model_executor/models/mf_models/qwen2.py index c7a963cae2dce496623871cfcb1809b1f8909ff0..2ca72ec4b74a21afe467122964f4ee8be5c18725 100644 --- 
a/vllm_mindspore/model_executor/models/mf_models/qwen2.py +++ b/vllm_mindspore/model_executor/models/mf_models/qwen2.py @@ -18,36 +18,30 @@ import os from typing import Iterable, List, Optional, Set, Tuple, Union +import mindspore as ms import numpy as np - +from mindformers.core.context import build_context +from mindformers.core.parallel_config import build_parallel_config +from mindformers.models.llama import LlamaConfig as LlamaConfig_MF +from mindformers.tools.register.config import MindFormerConfig +from mindformers.tools.utils import set_output_path +from mindformers.trainer import BaseTrainer +from mindformers.trainer.utils import transform_and_load_checkpoint +from mindspore import JitConfig, Model, Tensor +from research.qwen2_5.infer.qwen2_5 import \ + ParallelQwenForCausalLM as ParallelQwenForCausalLM_MF from vllm.attention import AttentionMetadata from vllm.config import VllmConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.logger import init_logger from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.logger import init_logger - - -from mindformers.tools.register.config import MindFormerConfig - -from mindformers.core.context import build_context -from mindformers.core.parallel_config import build_parallel_config - -from mindformers.models.llama import LlamaConfig as LlamaConfig_MF -from mindformers.trainer import BaseTrainer -from mindformers.tools.utils import set_output_path, set_strategy_save_path -from research.qwen2_5.infer.qwen2_5 import ParallelQwenForCausalLM as ParallelQwenForCausalLM_MF from vllm_mindspore.model_executor.layers.sampler import get_sampler from vllm_mindspore.model_executor.models.model_base import MsModelBase from vllm_mindspore.utils import cal_block_num -import mindspore as ms -from mindspore import Tensor, JitConfig, Model -from mindformers.trainer.utils import transform_and_load_checkpoint - - logger = init_logger(__name__) @@ -81,9 +75,14 @@ def _batch_seq(input_tokens, prefill): class Qwen2ForCausalLM(MsModelBase): def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: - super(Qwen2ForCausalLM, self).__init__(vllm_config=vllm_config, prefix=prefix) - - self.mf_config = MindFormerConfig(os.getenv("MINDFORMERS_MODEL_CONFIG")) + super( + Qwen2ForCausalLM, + self).__init__( + vllm_config=vllm_config, + prefix=prefix) + + self.mf_config = MindFormerConfig( + os.getenv("MINDFORMERS_MODEL_CONFIG")) build_context(self.mf_config, is_set_ms_ctx=False, is_init_ms=False) build_parallel_config(self.mf_config) self.mf_config.model.model_config.parallel_config = ( @@ -94,9 +93,11 @@ class Qwen2ForCausalLM(MsModelBase): ) self.mf_config.model.model_config.parallel_config.pipeline_stage = 1 - self.mf_model_config = LlamaConfig_MF(**self.mf_config.model.model_config) + self.mf_model_config = LlamaConfig_MF( + **self.mf_config.model.model_config) # Cannot get num_gpu_blocks from cache config now, calculate one first. 
- self.mf_model_config.num_blocks = cal_block_num(self.cache_config, self.model_config, self.parallel_config) + self.mf_model_config.num_blocks = cal_block_num( + self.cache_config, self.model_config, self.parallel_config) self.mf_model_config.block_size = self.cache_config.block_size if self.mf_config.moe_config: self.mf_model_config.moe_config = self.mf_config.moe_config @@ -137,13 +138,13 @@ class Qwen2ForCausalLM(MsModelBase): self.mf_kvcaches_init = True def forward( - self, - input_ids: Tensor, - positions: Tensor, - kv_caches: List[Tensor], - attn_metadata: AttentionMetadata, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[Tensor] = None, + self, + input_ids: Tensor, + positions: Tensor, + kv_caches: List[Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[Tensor] = None, ) -> Union[Tensor, IntermediateTensors]: self.update_mf_kvcaches(kv_caches) @@ -175,16 +176,16 @@ class Qwen2ForCausalLM(MsModelBase): return None def compute_logits( - self, - hidden_states: Tensor, - sampling_metadata: SamplingMetadata, + self, + hidden_states: Tensor, + sampling_metadata: SamplingMetadata, ) -> Optional[Tensor]: return self.logits def sample( - self, - logits: Tensor, - sampling_metadata: SamplingMetadata, + self, + logits: Tensor, + sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: next_tokens = self.sampler(logits, sampling_metadata) return next_tokens diff --git a/vllm_mindspore/model_executor/models/model_base.py b/vllm_mindspore/model_executor/models/model_base.py index 86c12e252240a64689755ef8c9df90619fdf9046..26f1faea88c183f9278cf82a2cf43ca862bbb8bb 100644 --- a/vllm_mindspore/model_executor/models/model_base.py +++ b/vllm_mindspore/model_executor/models/model_base.py @@ -15,23 +15,31 @@ # limitations under the License. 
# ============================================================================ +""" +model_base +""" + from abc import abstractmethod -from typing import Iterable, List, Optional, Set, Tuple, Union, Dict +from typing import Dict, Iterable, List, Optional, Set, Tuple, Union +from mindspore import Tensor +from mindspore import dtype as mstype +from mindspore import mutable, nn from vllm.attention import AttentionMetadata from vllm.config import VllmConfig from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from mindspore import Tensor, nn, mutable -from mindspore import dtype as mstype - from vllm_mindspore.utils import STR_DTYPE_TO_MS_DTYPE class MsModelBase(): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: + """ + MsModelBase + """ + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: # pylint: disable=unused-argument super(MsModelBase, self).__init__() config = vllm_config.model_config.hf_config lora_config = vllm_config.lora_config @@ -62,6 +70,11 @@ class MsModelBase(): yield par_name, par def get_params_dict(self): + """ + + Returns: + + """ self._check_modules_valid() params_dict = dict() @@ -76,10 +89,8 @@ class MsModelBase(): return params_dict - def named_modules(self, remove_duplicate: bool = True): + def named_modules(self, _remove_duplicate: bool = True): self._check_modules_valid() - - res_modules = set() for name, module in self.modules_dict.items(): for module_name, sub_module in module.cells_and_names(): if name != "self": @@ -98,13 +109,13 @@ class MsModelBase(): return self def __call__( - self, - input_ids: Tensor, - positions: Tensor, - kv_caches: List[Tensor], - attn_metadata: AttentionMetadata, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[Tensor] = None, + self, + input_ids: Tensor, + positions: Tensor, + kv_caches: List[Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[Tensor] = None, ) -> Union[Tensor, IntermediateTensors]: return self.forward( input_ids, @@ -116,17 +127,20 @@ class MsModelBase(): ) def forward( - self, - input_ids: Tensor, - positions: Tensor, - kv_caches: List[Tensor], - attn_metadata: AttentionMetadata, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[Tensor] = None, + self, + input_ids: Tensor, + positions: Tensor, + kv_caches: List[Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[Tensor] = None, ) -> Union[Tensor, IntermediateTensors]: raise NotImplementedError def set_model_inputs(self): + """ + set_model_inputs + """ dyn_input_ids = Tensor(shape=[None, None], dtype=mstype.int64) dyn_position_ids = Tensor(shape=[None], dtype=mstype.int64) @@ -141,8 +155,14 @@ class MsModelBase(): num_layers = self.model_config.get_num_layers(self.parallel_config) - dyn_key_cache = mutable(Tensor(shape=kv_cache_shape, dtype=kv_cache_dtype)) - dyn_value_cache = mutable(Tensor(shape=kv_cache_shape, dtype=kv_cache_dtype)) + dyn_key_cache = mutable( + Tensor( + shape=kv_cache_shape, + dtype=kv_cache_dtype)) + dyn_value_cache = mutable( + Tensor( + shape=kv_cache_shape, + dtype=kv_cache_dtype)) dyn_kv_cache = mutable((dyn_key_cache, dyn_value_cache)) dyn_kv_caches = mutable([dyn_kv_cache for _ in range(num_layers)]) @@ -171,20 +191,22 @@ class 
MsModelBase(): @abstractmethod def compute_logits( - self, - hidden_states: Tensor, - sampling_metadata: SamplingMetadata, + self, + hidden_states: Tensor, + sampling_metadata: SamplingMetadata, ) -> Optional[Tensor]: - raise NotImplementedError("Function compute_logits should be Implemented!") + raise NotImplementedError( + "Function compute_logits should be Implemented!") @abstractmethod def sample( - self, - logits: Tensor, - sampling_metadata: SamplingMetadata, + self, + logits: Tensor, + sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: raise NotImplementedError("Function sample should be Implemented!") @abstractmethod def load_weights(self, weights: Iterable[Tuple[str, Tensor]]) -> Set[str]: - raise NotImplementedError("Function load_weights should be Implemented!") + raise NotImplementedError( + "Function load_weights should be Implemented!") diff --git a/vllm_mindspore/model_executor/models/qwen2.py b/vllm_mindspore/model_executor/models/qwen2.py index 7a46c83a0500f16708978b163590575dec6e9350..1c3363c736c28b57e445711d4d0e0978d7ddfbe0 100644 --- a/vllm_mindspore/model_executor/models/qwen2.py +++ b/vllm_mindspore/model_executor/models/qwen2.py @@ -1,23 +1,25 @@ -from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union, Iterable +from typing import (TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple, + Union) if TYPE_CHECKING: from transformers import Qwen2Config else: Qwen2Config = None -from mindspore import Parameter, Tensor, mint, nn, jit, mutable +from mindspore import Parameter, Tensor, jit, mint, nn from mindspore.common import dtype as mstype - +from vllm.attention.backends.abstract import AttentionMetadata, AttentionType +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.sequence import IntermediateTensors from vllm_mindspore.attention import Attention - from vllm_mindspore.model_executor.layers.activation import SwiGLU from vllm_mindspore.model_executor.layers.layernorm import RMSNorm from vllm_mindspore.model_executor.layers.linear import ( MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm_mindspore.model_executor.layers.logits_processor import \ LogitsProcessor -from vllm.model_executor.layers.quantization import \ - QuantizationConfig from vllm_mindspore.model_executor.layers.rotary_embedding import get_rope from vllm_mindspore.model_executor.layers.sampler import (SamplerOutput, get_sampler) @@ -25,29 +27,22 @@ from vllm_mindspore.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm_mindspore.model_executor.model_loader.weight_utils import \ default_weight_loader +from vllm_mindspore.model_executor.models.model_base import MsModelBase from vllm_mindspore.model_executor.models.utils import ( PPMissingLayer, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) from vllm_mindspore.model_executor.sampling_metadata import SamplingMetadata -from vllm_mindspore.model_executor.models.model_base import MsModelBase - - -from vllm.config import CacheConfig, VllmConfig -from vllm.sequence import IntermediateTensors -from vllm.attention.backends.abstract import AttentionType -from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size -from vllm.attention.backends.abstract import AttentionMetadata class Qwen2MLP(nn.Cell): def __init__( - self, - hidden_size: int, - 
intermediate_size: int, - hidden_act: str, - quant_config=None, - bias: bool = False, - prefix: str = "", + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config=None, + bias: bool = False, + prefix: str = "", ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( @@ -112,7 +107,7 @@ class Qwen2Attention(nn.Cell): self.head_dim = hidden_size // self.total_num_heads self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim - self.scaling = self.head_dim**-0.5 + self.scaling = self.head_dim ** -0.5 self.rope_theta = rope_theta self.qkv_proj = QKVParallelLinear( @@ -152,27 +147,45 @@ class Qwen2Attention(nn.Cell): prefix=f"{prefix}.attn", attn_type=attn_type ) - self.attn_mask = mint.triu(mint.ones(size=(128, 128), dtype=mstype.bfloat16), 1) + self.attn_mask = mint.triu( + mint.ones( + size=( + 128, + 128), + dtype=mstype.bfloat16), + 1) @jit def construct( - self, - positions: Tensor, - hidden_states: Tensor, - kv_cache: Tuple[Tensor, Tensor], - # attn_metadata: AttentionMetadata, - num_prefill_tokens: int, - num_decode_tokens: int, - slot_mapping: Tensor, - batch_valid_length: Tuple[int], - context_lens: Tensor, - block_tables: Tensor, + self, + positions: Tensor, + hidden_states: Tensor, + kv_cache: Tuple[Tensor, Tensor], + # attn_metadata: AttentionMetadata, + num_prefill_tokens: int, + num_decode_tokens: int, + slot_mapping: Tensor, + batch_valid_length: Tuple[int], + context_lens: Tensor, + block_tables: Tensor, ) -> Tensor: qkv, _ = self.qkv_proj(hidden_states) - q, k, v = mint.split(qkv, (self.q_size, self.kv_size, self.kv_size), -1) - q, k = self.rotary_emb(positions, q, k, context_lens, num_prefill_tokens) - attn_output = self.attn(q, k, v, kv_cache, num_prefill_tokens, num_decode_tokens, - slot_mapping, batch_valid_length, context_lens, block_tables, self.attn_mask) + q, k, v = mint.split( + qkv, (self.q_size, self.kv_size, self.kv_size), -1) + q, k = self.rotary_emb( + positions, q, k, context_lens, num_prefill_tokens) + attn_output = self.attn( + q, + k, + v, + kv_cache, + num_prefill_tokens, + num_decode_tokens, + slot_mapping, + batch_valid_length, + context_lens, + block_tables, + self.attn_mask) output, _ = self.o_proj(attn_output) return output @@ -180,11 +193,11 @@ class Qwen2Attention(nn.Cell): class Qwen2DecoderLayer(nn.Cell): def __init__( - self, - config: Qwen2Config, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", + self, + config: Qwen2Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -222,32 +235,33 @@ class Qwen2DecoderLayer(nn.Cell): ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, - params_dtype=mstype.bfloat16,) + params_dtype=mstype.bfloat16, ) self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, - params_dtype=mstype.bfloat16,) + params_dtype=mstype.bfloat16, ) @jit def construct( - self, - positions: Tensor, - hidden_states: Tensor, - kv_cache: Tuple[Tensor, Tensor], - # attn_metadata: AttentionMetadata, - num_prefill_tokens: int, - num_decode_tokens: int, - slot_mapping: Tensor, - batch_valid_length: Tuple[int], - context_lens: Tensor, - block_tables: Tensor, - residual: Optional[Tensor], + self, + positions: Tensor, + hidden_states: Tensor, + kv_cache: Tuple[Tensor, Tensor], + # attn_metadata: 
AttentionMetadata, + num_prefill_tokens: int, + num_decode_tokens: int, + slot_mapping: Tensor, + batch_valid_length: Tuple[int], + context_lens: Tensor, + block_tables: Tensor, + residual: Optional[Tensor], ) -> Tuple[Tensor, Tensor]: # Self Attention if residual is None: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm(hidden_states, residual) + hidden_states, residual = self.input_layernorm( + hidden_states, residual) hidden_states = self.self_attn( positions, hidden_states, @@ -261,7 +275,8 @@ class Qwen2DecoderLayer(nn.Cell): ) # Fully Connected - hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual @@ -306,7 +321,7 @@ class Qwen2Model(nn.Cell): ["hidden_states", "residual"], config.hidden_size)) if get_pp_group().is_last_rank: self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, - params_dtype=mstype.bfloat16,) + params_dtype=mstype.bfloat16, ) else: self.norm = PPMissingLayer() @@ -315,19 +330,19 @@ class Qwen2Model(nn.Cell): @jit def construct( - self, - input_ids: Optional[Tensor], - positions: Tensor, - kv_caches: List[Tuple[Tensor, Tensor]], - # attn_metadata: AttentionMetadata, - num_prefill_tokens: int, - num_decode_tokens: int, - slot_mapping: Tensor, - batch_valid_length: Tuple[int], - context_lens: Tensor, - block_tables: Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[Tensor] = None, + self, + input_ids: Optional[Tensor], + positions: Tensor, + kv_caches: List[Tuple[Tensor, Tensor]], + # attn_metadata: AttentionMetadata, + num_prefill_tokens: int, + num_decode_tokens: int, + slot_mapping: Tensor, + batch_valid_length: Tuple[int], + context_lens: Tensor, + block_tables: Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[Tensor] = None, ) -> Union[Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: if inputs_embeds is not None: @@ -361,7 +376,8 @@ class Qwen2Model(nn.Cell): hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[Tuple[str, Tensor]], params_dict: Dict[str, Parameter]): + def load_weights( + self, weights: Iterable[Tuple[str, Tensor]], params_dict: Dict[str, Parameter]): loaded_params: Set[str] = set() stacked_params_mapping = [ # (param_name, shard_name, shard_id) @@ -442,11 +458,14 @@ class Qwen2ForCausalLM(MsModelBase): if config.tie_word_embeddings: self.lm_head = self.model.embed_tokens else: - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - params_dtype=mstype.bfloat16, - quant_config=quant_config, - prefix=maybe_prefix(prefix, "lm_head")) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + params_dtype=mstype.bfloat16, + quant_config=quant_config, + prefix=maybe_prefix( + prefix, + "lm_head")) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = get_sampler() else: @@ -462,13 +481,13 @@ class Qwen2ForCausalLM(MsModelBase): return self.model.get_input_embeddings(input_ids) def forward( - self, - input_ids: Tensor, - positions: Tensor, - kv_caches: List[Tuple[Tensor, Tensor]], - attn_metadata: AttentionMetadata, - intermediate_tensors: IntermediateTensors = None, - inputs_embeds: Tensor = None + self, + input_ids: Tensor, + positions: Tensor, + 
kv_caches: List[Tuple[Tensor, Tensor]], + attn_metadata: AttentionMetadata, + intermediate_tensors: IntermediateTensors = None, + inputs_embeds: Tensor = None ) -> Union[Tensor, IntermediateTensors]: if attn_metadata.num_prefill_tokens > 0: input_ids = input_ids.expand_dims(0) @@ -491,15 +510,16 @@ class Qwen2ForCausalLM(MsModelBase): self.model.load_weights(weights, params_dict) def sample( - self, logits: Tensor, sampling_metadata: SamplingMetadata + self, logits: Tensor, sampling_metadata: SamplingMetadata ) -> Optional[SamplerOutput]: next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def compute_logits( - self, - hidden_states: Tensor, - sampling_metadata: SamplingMetadata, + self, + hidden_states: Tensor, + sampling_metadata: SamplingMetadata, ) -> Optional[Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) + logits = self.logits_processor( + self.lm_head, hidden_states, sampling_metadata) return logits diff --git a/vllm_mindspore/model_executor/models/registry.py b/vllm_mindspore/model_executor/models/registry.py index ef38ee0b9946723215a89b633f745dddbdb3d7ee..7865bc81f6b23570a7040ed3de9d902a84cae9ae 100644 --- a/vllm_mindspore/model_executor/models/registry.py +++ b/vllm_mindspore/model_executor/models/registry.py @@ -23,8 +23,8 @@ import tempfile from typing import Callable, TypeVar import cloudpickle - -from vllm.model_executor.models.registry import _ModelRegistry, _LazyRegisteredModel +from vllm.model_executor.models.registry import (_LazyRegisteredModel, + _ModelRegistry) from vllm_mindspore.utils import is_mindformers_model_backend diff --git a/vllm_mindspore/model_executor/models/utils.py b/vllm_mindspore/model_executor/models/utils.py index c84b6dc315b190b8d2ebc1d01b71c5a9b3f7684e..aef4783236c121b73eeeeec9a7ec73e2a0516755 100644 --- a/vllm_mindspore/model_executor/models/utils.py +++ b/vllm_mindspore/model_executor/models/utils.py @@ -17,13 +17,12 @@ from typing import List, Tuple +import mindspore as ms +from mindspore import mint from vllm.sequence import IntermediateTensors from vllm_mindspore.utils import get_valid_dtype -import mindspore as ms -from mindspore import mint - class PPMissingLayer(ms.nn.Cell): """ @@ -78,9 +77,9 @@ def extract_layer_index(layer_name: str) -> int: def make_layers( - num_hidden_layers: int, - layer_fn, - prefix: str, + num_hidden_layers: int, + layer_fn, + prefix: str, ) -> Tuple[int, int, ms.nn.CellList]: """Make a list of layers with the given layer function, taking pipeline parallelism into account. 
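The `make_layers()` changes above and in the hunk just below are formatting-only: the function still asks `get_pp_indices` for the slice of decoder layers this pipeline rank owns and fills the rest with `PPMissingLayer`. A minimal sketch of that partitioning idea (my own even-split stand-in, not vLLM's actual `get_pp_indices` implementation):

```python
# Assumed even split purely for illustration; the real helper may balance
# layers differently when num_layers does not divide evenly.
def pp_layer_slice(num_layers: int, rank: int, world_size: int) -> tuple:
    per_rank = num_layers // world_size
    start = rank * per_rank
    end = num_layers if rank == world_size - 1 else start + per_rank
    return start, end

# 32 decoder layers over 4 pipeline stages:
for rank in range(4):
    print(rank, pp_layer_slice(32, rank, 4))  # (0, 8), (8, 16), (16, 24), (24, 32)
```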
@@ -89,8 +88,7 @@ def make_layers( from vllm.distributed.utils import get_pp_indices start_layer, end_layer = get_pp_indices( - num_hidden_layers, get_pp_group().rank_in_group, get_pp_group().world_size - ) + num_hidden_layers, get_pp_group().rank_in_group, get_pp_group().world_size) modules = ms.nn.CellList( [PPMissingLayer() for _ in range(start_layer)] + [ @@ -103,11 +101,10 @@ def make_layers( def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int): - def make_empty_intermediate_tensors( - batch_size: int, - dtype, - device, + batch_size: int, + dtype, + device, ) -> IntermediateTensors: dtype = get_valid_dtype(dtype) return IntermediateTensors( diff --git a/vllm_mindspore/model_executor/sampling_metadata.py b/vllm_mindspore/model_executor/sampling_metadata.py index e6b60f579ba00ac507a4f83d1ad6cbf89327f035..6b41e55a906d8d623946ea4463fabcfdb0752bc0 100644 --- a/vllm_mindspore/model_executor/sampling_metadata.py +++ b/vllm_mindspore/model_executor/sampling_metadata.py @@ -15,36 +15,39 @@ # limitations under the License. # ============================================================================ +""" +sampling_metadata +""" + from array import array from dataclasses import dataclass from typing import Dict, List, Optional, Tuple - +import mindspore as ms +from mindspore import Tensor from vllm.sampling_params import SamplingParams, SamplingType -from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData, SequenceGroupMetadata -from vllm.utils import ( - PyObjectCache, - async_tensor_h2d, - is_pin_memory_available, - make_tensor_with_pad, -) +from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData, + SequenceGroupMetadata) +from vllm.utils import (PyObjectCache, async_tensor_h2d, + is_pin_memory_available, make_tensor_with_pad) _SAMPLING_EPS = 1e-5 -from mindspore import Tensor -import mindspore as ms -# TODO(tronzhang): use vllm's SequenceGroupToSample. (now for tensor create pin/device and tensor.to) +# TODO(tronzhang): use vllm's SequenceGroupToSample. (now for tensor +# create pin/device and tensor.to) @dataclass class SequenceGroupToSample: + """ # |---------- N-1 iteration --------| # |---------------- N iteration ---------------------| # |- tokenA -|......................|-- newTokens ---| # |---------- context_len ----------| # |-------------------- seq_len ----------------------| # |-- query_len ---| + """ # Sequence ids for the sequence group in a previous step. 
seq_ids: List[int] @@ -75,7 +78,7 @@ class SequenceGroupToSample: return len(self.sample_indices) > 0 def __post_init__(self): - if len(self.prompt_logprob_indices) > 0: + if not self.prompt_logprob_indices: assert self.sampling_params.prompt_logprobs is not None if self.is_prompt: assert self.seq_len is not None @@ -152,13 +155,13 @@ class SamplingMetadata: """ def __init__( - self, - seq_groups: List[SequenceGroupToSample], - selected_token_indices: Tensor, - categorized_sample_indices: Dict[SamplingType, Tensor], - num_prompts: int, - skip_sampler_cpu_output: bool = False, - reuse_sampling_tensors: bool = False, + self, + seq_groups: List[SequenceGroupToSample], + selected_token_indices: Tensor, + categorized_sample_indices: Dict[SamplingType, Tensor], + num_prompts: int, + skip_sampler_cpu_output: bool = False, + reuse_sampling_tensors: bool = False, ) -> None: self.seq_groups = seq_groups self.selected_token_indices = selected_token_indices @@ -169,22 +172,38 @@ class SamplingMetadata: @staticmethod def prepare( - seq_group_metadata_list: List[SequenceGroupMetadata], - seq_lens: List[int], - query_lens: List[int], - device: str, - pin_memory: bool, - generators: Optional[Dict[str, ms.Generator]] = None, - cache: Optional[SamplingMetadataCache] = None, + seq_group_metadata_list: List[SequenceGroupMetadata], + seq_lens: List[int], + query_lens: List[int], + device: str, + pin_memory: bool, + generators: Optional[Dict[str, ms.Generator]] = None, + cache: Optional[SamplingMetadataCache] = None, ) -> "SamplingMetadata": - ( - seq_groups, - selected_token_indices, - categorized_sample_indices, - num_prompts, - ) = _prepare_seq_groups( - seq_group_metadata_list, seq_lens, query_lens, device, generators, cache - ) + """ + + Args: + seq_group_metadata_list: + seq_lens: + query_lens: + device: + pin_memory: + generators: + cache: + + Returns: + + """ + (seq_groups, + selected_token_indices, + categorized_sample_indices, + num_prompts, + ) = _prepare_seq_groups(seq_group_metadata_list, + seq_lens, + query_lens, + device, + generators, + cache) selected_token_indices = async_tensor_h2d( selected_token_indices, dtype=ms.int64, @@ -219,12 +238,12 @@ class SamplingMetadata: def _prepare_seq_groups( - seq_group_metadata_list, #: List[SequenceGroupMetadata], - seq_lens: List[int], - query_lens: List[int], - device: str, - generators: Optional[Dict[str, ms.Generator]] = None, - cache: Optional[SamplingMetadataCache] = None, + seq_group_metadata_list, #: List[SequenceGroupMetadata], + seq_lens: List[int], + query_lens: List[int], + device: str, + generators: Optional[Dict[str, ms.Generator]] = None, + cache: Optional[SamplingMetadataCache] = None, ) -> Tuple[ List[SequenceGroupToSample], List[int], @@ -319,7 +338,7 @@ def _prepare_seq_groups( # Decode prompt_logprob_len = 0 query_len = ( - query_lens[i] if query_lens is not None and len(query_lens) > 0 else 1 + query_lens[i] if query_lens is not None else 1 ) sample_len = len(seq_ids) * query_len if do_sample else 0 @@ -396,7 +415,11 @@ def _prepare_seq_groups( if cache is not None: cache.reset() - return (seq_groups, selected_token_indices, categorized_sample_indices, num_prompts) + return ( + seq_groups, + selected_token_indices, + categorized_sample_indices, + num_prompts) @dataclass @@ -415,12 +438,23 @@ class SamplingTensors: @classmethod def from_sampling_metadata( - cls, - sampling_metadata: "SamplingMetadata", - vocab_size: int, - device, #: torch.device, - dtype, #: torch.dtype, + cls, + sampling_metadata: "SamplingMetadata", + vocab_size: int, 
+ device, #: torch.device, + dtype, #: torch.dtype, ) -> Tuple["SamplingTensors", bool, bool, bool]: + """ + + Args: + sampling_metadata: + vocab_size: + device: + dtype: + + Returns: + + """ prompt_tokens: List[array] = [] output_tokens: List[array] = [] top_ks: List[int] = [] @@ -454,15 +488,15 @@ class SamplingTensors: # Set the temperature to 1 to avoid division by zero. temperature = 1.0 if not do_top_p_top_k and ( - top_p < 1.0 - _SAMPLING_EPS or top_k != vocab_size + top_p < 1.0 - _SAMPLING_EPS or top_k != vocab_size ): do_top_p_top_k = True if not do_min_p and min_p > _SAMPLING_EPS: do_min_p = True if not do_penalties and ( - abs(p) >= _SAMPLING_EPS - or abs(f) >= _SAMPLING_EPS - or abs(r - 1.0) >= _SAMPLING_EPS + abs(p) >= _SAMPLING_EPS + or abs(f) >= _SAMPLING_EPS + or abs(r - 1.0) >= _SAMPLING_EPS ): do_penalties = True @@ -499,11 +533,9 @@ class SamplingTensors: if seq_group.is_prompt and sampling_params.prompt_logprobs is not None: prefill_len = len(seq_group.prompt_logprob_indices) prompt_tokens.extend( - array(VLLM_TOKEN_ID_ARRAY_TYPE) for _ in range(prefill_len) - ) + array(VLLM_TOKEN_ID_ARRAY_TYPE) for _ in range(prefill_len)) output_tokens.extend( - array(VLLM_TOKEN_ID_ARRAY_TYPE) for _ in range(prefill_len) - ) + array(VLLM_TOKEN_ID_ARRAY_TYPE) for _ in range(prefill_len)) if seq_group.do_sample: for seq_id in seq_ids: seq_data = seq_group.seq_data[seq_id] @@ -528,20 +560,39 @@ class SamplingTensors: @classmethod def from_lists( - cls, - temperatures: List[float], - top_ps: List[float], - top_ks: List[int], - min_ps: List[float], - presence_penalties: List[float], - frequency_penalties: List[float], - repetition_penalties: List[float], - prompt_tokens: List[array], - output_tokens: List[array], - vocab_size: int, - device, - dtype, + cls, + temperatures: List[float], + top_ps: List[float], + top_ks: List[int], + min_ps: List[float], + presence_penalties: List[float], + frequency_penalties: List[float], + repetition_penalties: List[float], + prompt_tokens: List[array], + output_tokens: List[array], + vocab_size: int, + _device, + dtype, ) -> "SamplingTensors": + """ + + Args: + temperatures: + top_ps: + top_ks: + min_ps: + presence_penalties: + frequency_penalties: + repetition_penalties: + prompt_tokens: + output_tokens: + vocab_size: + _: + dtype: + + Returns: + + """ # Note that the performance will be very bad without # pinned memory. pin_memory = is_pin_memory_available() @@ -601,7 +652,8 @@ class SamplingTensors: # Because the memory is pinned, we can do non-blocking # transfer to device. - # TODO(tronzhang): mindspore tensor donot support tensor.to(device=xxx, non_blocking=xxx), but tensor.move_to(to, blocking=xxx). + # TODO(tronzhang): mindspore tensor donot support tensor.to(device=xxx, + # non_blocking=xxx), but tensor.move_to(to, blocking=xxx). return cls( temperatures=temperatures_t, top_ps=top_ps_t, diff --git a/vllm_mindspore/model_executor/utils.py b/vllm_mindspore/model_executor/utils.py index c6de292aafdd2e56052b24dfe0ba2eb940d06a65..4ede48d8c877182146acc5ea0aef1c46c1dbb452 100644 --- a/vllm_mindspore/model_executor/utils.py +++ b/vllm_mindspore/model_executor/utils.py @@ -16,14 +16,16 @@ # ============================================================================ from typing import Any, Dict, Optional + from mindspore import Tensor + # TODO(tronzhang): Use vllm's latter... 
def set_weight_attrs( - weight: Tensor, - weight_attrs: Optional[Dict[str, Any]], + weight: Tensor, + weight_attrs: Optional[Dict[str, Any]], ): if weight_attrs is None: return diff --git a/vllm_mindspore/msadapter b/vllm_mindspore/msadapter deleted file mode 160000 index 6417e50602fc83d2a9291ef652749c3935b37588..0000000000000000000000000000000000000000 --- a/vllm_mindspore/msadapter +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 6417e50602fc83d2a9291ef652749c3935b37588 diff --git a/vllm_mindspore/platforms/ascend.py b/vllm_mindspore/platforms/ascend.py index de56d2680faac68893d0b035e88dc34408ea6507..e5844982160be6adb5b64a57210d1a5c49c8e72e 100644 --- a/vllm_mindspore/platforms/ascend.py +++ b/vllm_mindspore/platforms/ascend.py @@ -19,9 +19,9 @@ from typing import TYPE_CHECKING, Optional import torch - -from vllm.platforms.interface import DeviceCapability, Platform, PlatformEnum, _Backend from vllm.logger import init_logger +from vllm.platforms.interface import (DeviceCapability, Platform, PlatformEnum, + _Backend) if TYPE_CHECKING: from vllm.config import VllmConfig @@ -44,8 +44,8 @@ class AscendPlatform(Platform): @classmethod def get_device_capability( - cls, - device_id: int = 0, + cls, + device_id: int = 0, ) -> Optional[DeviceCapability]: major, minor = torch.cuda.get_device_capability(device_id) return DeviceCapability(major=major, minor=minor) diff --git a/vllm_mindspore/sequence.py b/vllm_mindspore/sequence.py index 339f2d89ee7976d1597b784463a5926aa9441686..60d2adfacb185d1796d94b511098fd3950fec813 100644 --- a/vllm_mindspore/sequence.py +++ b/vllm_mindspore/sequence.py @@ -1,25 +1,23 @@ - """Sequence and its related classes.""" + import copy import enum from abc import ABC, abstractmethod from array import array -from collections import defaultdict -from dataclasses import dataclass, field +from dataclasses import dataclass from functools import reduce -from typing import Any, Callable, DefaultDict, Dict, List, Mapping, Optional +from typing import Any, Dict, List, Mapping, Optional from typing import Sequence as GenericSequence -from typing import Set, Tuple, Union +from typing import Tuple, Union import msgspec import torch - from vllm.inputs import SingletonInputs, SingletonInputsAdapter from vllm.lora.request import LoRARequest from vllm.multimodal import MultiModalDataDict, MultiModalPlaceholderDict from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sampling_params import RequestOutputKind, SamplingParams +from vllm.sampling_params import SamplingParams VLLM_TOKEN_ID_ARRAY_TYPE = "l" @@ -43,6 +41,7 @@ class Logprob: rank: The vocab rank of chosen token (>=1) decoded_token: The decoded chosen token index """ + logprob: float rank: Optional[int] = None decoded_token: Optional[str] = None @@ -57,6 +56,7 @@ SampleLogprobs = List[Dict[int, Logprob]] class SequenceStatus(enum.IntEnum): """Status of a sequence.""" + WAITING = 0 RUNNING = 1 SWAPPED = 2 @@ -73,6 +73,11 @@ class SequenceStatus(enum.IntEnum): @staticmethod def get_finished_reason(status: "SequenceStatus") -> Union[str, None]: + """ + + :param status: + :return: + """ if status == SequenceStatus.FINISHED_STOPPED: finish_reason = "stop" elif status == SequenceStatus.FINISHED_LENGTH_CAPPED: @@ -112,6 +117,7 @@ class RequestMetrics: will include model forward, block/sync across workers, cpu-gpu sync time and sampling time. 
""" + arrival_time: float last_token_time: float first_scheduled_time: Optional[float] @@ -124,10 +130,11 @@ class RequestMetrics: class SequenceDataDelta( - msgspec.Struct, - array_like=True, # type: ignore[call-arg] - omit_defaults=True): # type: ignore[call-arg] + # type: ignore[call-arg] + msgspec.Struct, array_like=True, omit_defaults=True +): # type: ignore[call-arg] """Delta SequenceData to send to workers per step.""" + # A new token to be appended to existing SequenceData. new_output_token_ids: List[int] # Overwriting existing `cumulative_logprob` @@ -138,8 +145,9 @@ class SequenceDataDelta( new_stage: SequenceStage -class SequenceData(msgspec.Struct, - omit_defaults=True): # type: ignore[call-arg] +class SequenceData( + msgspec.Struct, + omit_defaults=True): # type: ignore[call-arg] """Data associated with a sequence. Args: @@ -152,16 +160,18 @@ class SequenceData(msgspec.Struct, output_token_ids: The token IDs of the output. cumulative_logprob: The cumulative log probability of the output. """ + # NOTE: we cannot use Union[List, array] because msgspec cannot support # union of 2 list types. _prompt_token_ids: array _output_token_ids: array = msgspec.field( - default_factory=lambda: array(VLLM_TOKEN_ID_ARRAY_TYPE, [])) + default_factory=lambda: array(VLLM_TOKEN_ID_ARRAY_TYPE, []) + ) ### The below fields should not be passed as an argument ### _cumulative_logprob: float = 0.0 - _prompt_token_ids_tuple: Tuple[int, - ...] = msgspec.field(default_factory=tuple) + _prompt_token_ids_tuple: Tuple[int, ...] = msgspec.field( + default_factory=tuple) # The number of tokens that are computed (that run against the model). _num_computed_tokens: int = 0 # The number of tokens with prefix cache hit. @@ -186,7 +196,7 @@ class SequenceData(msgspec.Struct, Each tuple represents one token sequence, expressed in the form :code:`(token_id, count)`. """ - if len(token_counts) == 0: + if not token_counts: return SequenceData.from_seqs([]) prompt_token_ids_arr = reduce( @@ -198,24 +208,25 @@ class SequenceData(msgspec.Struct, @staticmethod def from_seqs( - prompt_token_ids: GenericSequence[int], - output_token_ids: Optional[GenericSequence[int]] = None, + prompt_token_ids: GenericSequence[int], + output_token_ids: Optional[GenericSequence[int]] = None, ) -> "SequenceData": """ Construct a :class:`SequenceData` instance from prompt and output token sequences. 
""" - prompt_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE, - prompt_token_ids) + prompt_token_ids_arr = array( + VLLM_TOKEN_ID_ARRAY_TYPE, prompt_token_ids) if output_token_ids is None: return SequenceData(prompt_token_ids_arr) - output_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE, - output_token_ids) + output_token_ids_arr = array( + VLLM_TOKEN_ID_ARRAY_TYPE, output_token_ids) - return SequenceData(prompt_token_ids_arr, - _output_token_ids=output_token_ids_arr) + return SequenceData( + prompt_token_ids_arr, _output_token_ids=output_token_ids_arr + ) def __post_init__(self) -> None: assert self._prompt_token_ids.typecode == "l" @@ -227,8 +238,9 @@ class SequenceData(msgspec.Struct, def _update_cached_all_tokens(self): assert isinstance(self._prompt_token_ids, array) assert isinstance(self._output_token_ids, array) - self._cached_all_token_ids: List[int] = list(self._prompt_token_ids + - self._output_token_ids) + self._cached_all_token_ids: List[int] = list( + self._prompt_token_ids + self._output_token_ids + ) @property def cumulative_logprob(self) -> float: @@ -256,10 +268,11 @@ class SequenceData(msgspec.Struct, return tuple(self._output_token_ids) @output_token_ids.setter - def output_token_ids(self, - new_output_token_ids: GenericSequence[int]) -> None: - self._output_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, - new_output_token_ids) + def output_token_ids( + self, + new_output_token_ids: GenericSequence[int]) -> None: + self._output_token_ids = array( + VLLM_TOKEN_ID_ARRAY_TYPE, new_output_token_ids) self._update_cached_all_tokens() @property @@ -304,20 +317,26 @@ class SequenceData(msgspec.Struct, """Get prefix tokens, and make the return value hashable""" prompt_length = self.get_prompt_len() if num_tokens > prompt_length: - return (self._prompt_token_ids_tuple, - tuple(self._output_token_ids[:num_tokens - prompt_length])) - else: - return (self._prompt_token_ids_tuple[:num_tokens], None) + return ( + self._prompt_token_ids_tuple, + tuple(self._output_token_ids[: num_tokens - prompt_length]), + ) + return (self._prompt_token_ids_tuple[:num_tokens], None) def get_num_computed_tokens(self) -> int: """Return the number of prefill tokens that are already computed.""" return self._num_computed_tokens + def get_cached_all_token_ids(self) -> List[int]: + return self._cached_all_token_ids + def update_num_computed_tokens(self, num_new_computed_tokens: int): """Update number of tokens computed so far.""" self._num_computed_tokens += num_new_computed_tokens assert self._num_computed_tokens <= self.get_len(), ( - self._num_computed_tokens, self.get_len()) + self._num_computed_tokens, + self.get_len(), + ) # If all tokens are computed, it means it is in decoding phase. if self.get_num_uncomputed_tokens() == 0: self._stage = SequenceStage.DECODE @@ -358,9 +377,12 @@ class SequenceData(msgspec.Struct, return self.output_token_ids def get_delta_and_reset(self) -> SequenceDataDelta: - delta = SequenceDataDelta(self._new_appended_tokens, - self._cumulative_logprob, - self.get_num_computed_tokens(), self.stage) + delta = SequenceDataDelta( + self._new_appended_tokens, + self._cumulative_logprob, + self.get_num_computed_tokens(), + self.stage, + ) # Reset delta state. 
self._new_appended_tokens = [] return delta @@ -377,11 +399,13 @@ class SequenceData(msgspec.Struct, return self._stage def __repr__(self) -> str: - return (f"SequenceData(" - f"prompt_token_ids={self._prompt_token_ids}, " - f"output_token_ids={self.output_token_ids}, " - f"cumulative_logprob={self.cumulative_logprob}, " - f"get_num_computed_tokens={self.get_num_computed_tokens()}") + return ( + f"SequenceData(" + f"prompt_token_ids={self._prompt_token_ids}, " + f"output_token_ids={self.output_token_ids}, " + f"cumulative_logprob={self.cumulative_logprob}, " + f"get_num_computed_tokens={self.get_num_computed_tokens()}" + ) class Sequence: @@ -402,13 +426,13 @@ class Sequence: """ def __init__( - self, - seq_id: int, - inputs: SingletonInputs, - block_size: int, - eos_token_id: Optional[int] = None, - lora_request: Optional[LoRARequest] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, + self, + seq_id: int, + inputs: SingletonInputs, + block_size: int, + eos_token_id: Optional[int] = None, + lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> None: self.seq_id = seq_id self.inputs = SingletonInputsAdapter(inputs) @@ -472,19 +496,24 @@ class Sequence: @property def prompt_adapter_id(self) -> int: - return self.prompt_adapter_request.prompt_adapter_id \ - if self.prompt_adapter_request else 0 + return ( + self.prompt_adapter_request.prompt_adapter_id + if self.prompt_adapter_request + else 0 + ) - def get_output_text_to_return(self, buffer_length: int, - delta: bool) -> str: + def get_output_text_to_return( + self, + buffer_length: int, + delta: bool) -> str: """If delta is True, only new text since the last call to this method is returned""" # We return the full output text if the sequence is finished. 
        truncate = buffer_length and not self.is_finished()
         if not delta:
-            return self.output_text[:-buffer_length] if truncate else (
-                self.output_text)
+            return (self.output_text[:-buffer_length]
+                    if truncate else self.output_text)
         length = len(self.output_text)
         if truncate:
             length -= buffer_length
@@ -495,7 +524,8 @@ class Sequence:
         return ""
 
     def get_output_token_ids_to_return(
-            self, delta: bool) -> Union[GenericSequence[int], int]:
+        self, delta: bool
+    ) -> Union[GenericSequence[int], int]:
         """If delta is True, only new tokens since the last call to
         this method are returned"""
         if not delta:
@@ -511,12 +541,12 @@ class Sequence:
         if num_new_tokens == 1:
             # Optimization for single decode token case
             # (which is what we have most of the time)
-            return self.data._cached_all_token_ids[-1]
+            return self.data.get_cached_all_token_ids()[-1]
 
         if num_new_tokens == 0:
             return []
 
-        return self.data._cached_all_token_ids[-num_new_tokens:]
+        return self.data.get_cached_all_token_ids()[-num_new_tokens:]
 
     def hash_of_block(self, logical_idx: int) -> int:
         # TODO This can produce incorrect hash when block size > prompt size
@@ -548,8 +578,8 @@ class Sequence:
         """Reset the sequence states for recomputation."""
         self.data.reset_state_for_recompute()
 
-    def append_token_id(self, token_id: int, logprobs: Dict[int,
-                                                            Logprob]) -> None:
+    def append_token_id(self, token_id: int,
+                        logprobs: Dict[int, Logprob]) -> None:
         assert token_id in logprobs
         self.output_logprobs.append(logprobs)
         self.data.append_token_id(token_id, logprobs[token_id].logprob)
@@ -604,13 +634,16 @@ class Sequence:
         return self.data.stage == SequenceStage.PREFILL
 
     def __repr__(self) -> str:
-        return (f"Sequence(seq_id={self.seq_id}, "
-                f"status={self.status.name}, "
-                f"num_blocks={self.n_blocks}, ")
+        return (
+            f"Sequence(seq_id={self.seq_id}, "
+            f"status={self.status.name}, "
+            f"num_blocks={self.n_blocks}, "
+        )
 
 
-class SequenceGroupState(msgspec.Struct,
-                         omit_defaults=True):  # type: ignore[call-arg]
+class SequenceGroupState(
+        msgspec.Struct,
+        omit_defaults=True):  # type: ignore[call-arg]
     """Mutable state tied to a specific sequence group"""
 
     # for multi-step decoding
@@ -642,18 +675,18 @@ class SequenceGroup:
     """
 
     def __init__(
-        self,
-        request_id: str,
-        seqs: List[Sequence],
-        arrival_time: float,
-        sampling_params: Optional[SamplingParams] = None,
-        lora_request: Optional[LoRARequest] = None,
-        pooling_params: Optional[PoolingParams] = None,
-        pooled_data: Optional[torch.Tensor] = None,
-        encoder_seq: Optional[Sequence] = None,
-        trace_headers: Optional[Mapping[str, str]] = None,
-        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
-        priority: int = 0,
+            self,
+            request_id: str,
+            seqs: List[Sequence],
+            arrival_time: float,
+            sampling_params: Optional[SamplingParams] = None,
+            lora_request: Optional[LoRARequest] = None,
+            pooling_params: Optional[PoolingParams] = None,
+            pooled_data: Optional[torch.Tensor] = None,
+            encoder_seq: Optional[Sequence] = None,
+            trace_headers: Optional[Mapping[str, str]] = None,
+            prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+            priority: int = 0,
     ) -> None:
         self.request_id = request_id
         self.seqs = seqs
@@ -663,11 +696,13 @@ class SequenceGroup:
         self.seqs_dict = {seq.seq_id: seq for seq in seqs}
 
         self.sampling_params = sampling_params
-        self.metrics = RequestMetrics(arrival_time=arrival_time,
-                                      last_token_time=arrival_time,
-                                      first_scheduled_time=None,
-                                      first_token_time=None,
-                                      time_in_queue=None)
+        self.metrics = RequestMetrics(
+            arrival_time=arrival_time,
+
last_token_time=arrival_time, + first_scheduled_time=None, + first_token_time=None, + time_in_queue=None, + ) self.lora_request = lora_request self.prompt_logprobs: Optional[PromptLogprobs] = None self.state = SequenceGroupState() @@ -693,16 +728,15 @@ class SequenceGroup: # There are either 0 or 1 encoder sequences # If one is present, its prompt is distinct # from the decoder's. - return (self.encoder_seq.prompt - if self.encoder_seq is not None else None) + return self.encoder_seq.prompt if self.encoder_seq is not None else None @property def encoder_prompt_token_ids(self) -> Optional[List[int]]: # There are either 0 or 1 encoder sequences # If one is present, its prompt token ids are # distinct from the decoder's. - return (self.encoder_seq.prompt_token_ids - if self.encoder_seq is not None else None) + return ( + self.encoder_seq.prompt_token_ids if self.encoder_seq is not None else None) @property def token_type_ids(self) -> Optional[List[int]]: @@ -726,23 +760,38 @@ class SequenceGroup: @property def prompt_adapter_id(self) -> int: - return self.prompt_adapter_request.prompt_adapter_id \ - if self.prompt_adapter_request else 0 + return ( + self.prompt_adapter_request.prompt_adapter_id + if self.prompt_adapter_request + else 0 + ) @property def prompt_adapter_num_virtual_tokens(self) -> int: - return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens\ - if self.prompt_adapter_request else 0 + return ( + self.prompt_adapter_request.prompt_adapter_num_virtual_tokens + if self.prompt_adapter_request + else 0 + ) def init_multi_step(self, num_steps: int) -> None: self.state.num_steps = num_steps self.state.current_step = 0 - def init_multi_step_from_lookahead_slots(self, num_lookahead_slots: int, - num_scheduler_steps: int, - is_multi_step: bool, - enable_chunking: bool) -> None: - + def init_multi_step_from_lookahead_slots( + self, + num_lookahead_slots: int, + num_scheduler_steps: int, + is_multi_step: bool, + enable_chunking: bool, + ) -> None: + """ + :param num_lookahead_slots: + :param num_scheduler_steps: + :param is_multi_step: + :param enable_chunking: + :return: + """ if not is_multi_step: self.init_multi_step(num_steps=num_scheduler_steps) return @@ -769,7 +818,8 @@ class SequenceGroup: if self.is_prefill(): raise ValueError( "seq_group.get_last_latency() should not be called " - "if the seq_group is in prefill phase.") + "if the seq_group is in prefill phase." + ) # Otherwise return token latency. latency = now - self.metrics.last_token_time @@ -782,8 +832,10 @@ class SequenceGroup: # recomputed, the time between iterations is counted # in TPOT, rather than recalculating TTFT (since from the ) # POV of the user, there is simply a long generation delay. 
-        if (self.metrics.first_token_time is None
-                and self.first_seq.get_output_len() == 1):
+        if (
+            self.metrics.first_token_time is None
+            and self.first_seq.get_output_len() == 1
+        ):
             self.metrics.first_token_time = time
 
     def maybe_set_first_scheduled_time(self, time: float) -> None:
@@ -803,8 +855,8 @@ class SequenceGroup:
         return 0 if self.first_seq.is_finished() else 1
 
     def get_seqs(
-        self,
-        status: Optional[SequenceStatus] = None,
+            self,
+            status: Optional[SequenceStatus] = None,
     ) -> List[Sequence]:
         if status is None:
             return self.seqs
@@ -854,21 +906,25 @@ class SequenceGroup:
         return self.first_seq.is_prefill()
 
     def __repr__(self) -> str:
-        return (f"SequenceGroup(request_id={self.request_id}, "
-                f"sampling_params={self.sampling_params}, "
-                f"num_seqs={len(self.seqs)})")
+        return (
+            f"SequenceGroup(request_id={self.request_id}, "
+            f"sampling_params={self.sampling_params}, "
+            f"num_seqs={len(self.seqs)})"
+        )
 
 
 class SequenceGroupMetadataDelta(
-        msgspec.Struct,
-        tag=True,  # type: ignore[call-arg]
-        array_like=True,  # type: ignore[call-arg]
-        omit_defaults=True):  # type: ignore[call-arg]
+        msgspec.Struct,
+        tag=True,  # type: ignore[call-arg]
+        array_like=True,  # type: ignore[call-arg]
+        omit_defaults=True,
+):  # type: ignore[call-arg]
     """Delta of SequenceGroupMetadata.
 
     After sending the first SequenceGroupMetadata, vLLM scheduler
    only sends delta to reduce the data payload size.
     """
+
     seq_data_delta: Dict[int, SequenceDataDelta]
     request_id: str
     block_tables: Dict[int, List[int]]
@@ -877,14 +933,16 @@ class SequenceGroupMetadataDelta(
     token_chunk_size: Optional[int] = None
     computed_block_nums: Optional[List[int]] = None
     state: Optional[SequenceGroupState] = msgspec.field(
-        default_factory=lambda: SequenceGroupState())
+        default_factory=SequenceGroupState
+    )
 
 
 class SequenceGroupMetadata(
-        msgspec.Struct,
-        tag=True,  # type: ignore[call-arg]
-        array_like=True,  # type: ignore[call-arg]
-        omit_defaults=True):  # type: ignore[call-arg]
+        msgspec.Struct,
+        tag=True,  # type: ignore[call-arg]
+        array_like=True,  # type: ignore[call-arg]
+        omit_defaults=True,
+):  # type: ignore[call-arg]
     """Metadata for a sequence group. Used to create `AttentionMetadata`.
 
     Args:
@@ -927,7 +985,8 @@ class SequenceGroupMetadata(
     lora_request: Optional[LoRARequest] = None
     computed_block_nums: Optional[List[int]] = None
     state: Optional[SequenceGroupState] = msgspec.field(
-        default_factory=lambda: SequenceGroupState())
+        default_factory=SequenceGroupState
+    )
 
     # "MultiModalDataDict" types. We have to use Any due to msgspec
     # doesn't allow to have union of 2 different dicts.
token_type_ids: Optional[List[int]] = None @@ -949,8 +1008,8 @@ class SequenceGroupMetadata( def __post_init__(self): if self.seq_data is not None and self.token_chunk_size is None: if self.is_prompt: - self.token_chunk_size = next(iter( - self.seq_data.values())).get_len() + self.token_chunk_size = next( + iter(self.seq_data.values())).get_len() else: self.token_chunk_size = 1 @@ -960,13 +1019,19 @@ class SequenceGroupMetadata( @property def prompt_adapter_id(self) -> int: - return self.prompt_adapter_request.prompt_adapter_id \ - if self.prompt_adapter_request else 0 + return ( + self.prompt_adapter_request.prompt_adapter_id + if self.prompt_adapter_request + else 0 + ) @property def prompt_adapter_num_virtual_tokens(self) -> int: - return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens \ - if self.prompt_adapter_request else 0 + return ( + self.prompt_adapter_request.prompt_adapter_num_virtual_tokens + if self.prompt_adapter_request + else 0 + ) # Multi-Step Chunked-Prefill property @property @@ -982,10 +1047,11 @@ class SequenceGroupMetadata( # we know this SequenceGroup has only one sequence. return next(iter(self.seq_data)) - def apply_delta(self, - sequence_group_metadata_delta: SequenceGroupMetadataDelta): - for id, delta in sequence_group_metadata_delta.seq_data_delta.items(): - self.seq_data[id].apply_delta(delta) + def apply_delta( + self, + sequence_group_metadata_delta: SequenceGroupMetadataDelta): + for idx, delta in sequence_group_metadata_delta.seq_data_delta.items(): + self.seq_data[idx].apply_delta(delta) assert self.request_id == sequence_group_metadata_delta.request_id self.block_tables = sequence_group_metadata_delta.block_tables self.token_chunk_size = sequence_group_metadata_delta.token_chunk_size @@ -994,15 +1060,16 @@ class SequenceGroupMetadata( def finish_step(self) -> None: assert self.state is not None - assert self.state.current_step < self.state.num_steps, \ - f"current step {self.state.current_step}, num_steps {self.state.num_steps}" # noqa + assert ( + self.state.current_step < self.state.num_steps + ), f"current step {self.state.current_step}, num_steps {self.state.num_steps}" # noqa self.state.current_step += 1 class SequenceOutput( - msgspec.Struct, - omit_defaults=True, # type: ignore[call-arg] - array_like=True): # type: ignore[call-arg] + # type: ignore[call-arg] + msgspec.Struct, omit_defaults=True, array_like=True +): # type: ignore[call-arg] """The model output associated with a sequence. Args: @@ -1012,20 +1079,25 @@ class SequenceOutput( logprobs: The logprobs of the output token. 
(Token id -> logP(x_i+1 | x_0, ..., x_i)) """ + parent_seq_id: int output_token: int logprobs: Dict[int, Logprob] def __repr__(self) -> str: - return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, " - f"output_token={self.output_token}, " - f"logprobs={self.logprobs})") + return ( + f"SequenceOutput(parent_seq_id={self.parent_seq_id}, " + f"output_token={self.output_token}, " + f"logprobs={self.logprobs})" + ) def __eq__(self, other: object) -> bool: if not isinstance(other, SequenceOutput): raise NotImplementedError() - equal = (self.parent_seq_id == other.parent_seq_id - and self.output_token == other.output_token) + equal = ( + self.parent_seq_id == other.parent_seq_id + and self.output_token == other.output_token + ) log_probs_equal = other.logprobs == self.logprobs return equal and log_probs_equal @@ -1039,4 +1111,4 @@ class SequenceGroupOutput(ABC): @abstractmethod def __eq__(self, other: object) -> bool: - pass \ No newline at end of file + pass diff --git a/vllm_mindspore/tests/test_sampler.py b/vllm_mindspore/tests/test_sampler.py index e0d91147ce80960300b78abe55802c1fe5327653..b075e213987be9e0200196aae3ce7fef38240333 100644 --- a/vllm_mindspore/tests/test_sampler.py +++ b/vllm_mindspore/tests/test_sampler.py @@ -1,21 +1,18 @@ -import vllm_mindspore -import itertools import random -from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple -from unittest.mock import Mock, patch +from typing import List, Tuple -import pytest import torch from vllm_mindspore.model_executor.layers.sampler import Sampler from vllm_mindspore.model_executor.sampling_metadata import SamplingMetadata # from vllm_mindspore.model_executor.utils import set_random_seed -from vllm_mindspore.sequence import SamplingParams, SequenceData, SequenceGroupMetadata +from vllm_mindspore.sequence import (SamplingParams, SequenceData, + SequenceGroupMetadata) VOCAB_SIZE = 32000 RANDOM_SEEDS = list(range(128)) + class MockLogitsSampler(Sampler): def __init__(self, fake_logits: torch.Tensor): @@ -25,6 +22,7 @@ class MockLogitsSampler(Sampler): def forward(self, *args, **kwargs): return super().forward(*args, **kwargs) + def _prepare_test( batch_size: int ) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler]: @@ -35,12 +33,13 @@ def _prepare_test( sampler = MockLogitsSampler(fake_logits) return input_tensor, fake_logits, sampler + def _do_sample( - batch_size: int, - input_tensor: torch.Tensor, - sampler: MockLogitsSampler, - sampling_params: SamplingParams, - device: str, + batch_size: int, + input_tensor: torch.Tensor, + sampler: MockLogitsSampler, + sampling_params: SamplingParams, + device: str, ): seq_group_metadata_list: List[SequenceGroupMetadata] = [] seq_lens: List[int] = [] @@ -63,9 +62,10 @@ def _do_sample( pin_memory=False) return sampler(logits=input_tensor, sampling_metadata=sampling_metadata) + def test_sampler_all_greedy(): # set_random_seed(seed) - device='cuda' + device = 'cuda' # torch.set_default_device(device) batch_size = random.randint(1, 256) input_tensor, fake_logits, sampler = _prepare_test(batch_size) @@ -82,7 +82,7 @@ def test_sampler_all_greedy(): def test_sampler_all_random(): # set_random_seed(seed) # torch.set_default_device(device) - device='cuda' + device = 'cuda' batch_size = random.randint(1, 256) _, fake_logits, sampler = _prepare_test(batch_size) @@ -101,9 +101,8 @@ def test_sampler_all_random(): assert nth_output.output_token == i - def test_sampler_repetition_penalty_mixed(): - device='cuda' + device = 'cuda' vocab_size = 8 def 
test_sampling_params(sampling_params: List[SamplingParams]): @@ -136,7 +135,7 @@ def test_sampler_repetition_penalty_mixed(): fake_logits[:, 1] = 1.2e-2 sampler = MockLogitsSampler(fake_logits) - print(f'fake_logits is: {fake_logits}', flush = True) + print(f'fake_logits is: {fake_logits}', flush=True) sampler_output = sampler(logits=fake_logits, sampling_metadata=sampling_metadata) @@ -165,4 +164,4 @@ def test_sampler_repetition_penalty_mixed(): tokens2 = test_sampling_params( [sampling_params_sample, sampling_params_rep]) - assert tokens1[0] == tokens2[1] \ No newline at end of file + assert tokens1[0] == tokens2[1] diff --git a/vllm_mindspore/utils.py b/vllm_mindspore/utils.py index e35195143dd7b217b5e46342ac6b9b5f1af55daa..6d3680b765b09e92d9aa985dc995450862cf4edf 100644 --- a/vllm_mindspore/utils.py +++ b/vllm_mindspore/utils.py @@ -19,15 +19,8 @@ import contextlib import gc import logging import os -from typing import ( - TYPE_CHECKING, - Callable, - Generator, - List, - Optional, - Tuple, - Union, -) +from typing import (TYPE_CHECKING, Callable, Generator, List, Optional, Tuple, + Union) import torch @@ -36,17 +29,15 @@ if TYPE_CHECKING: else: Library = None -from vllm.utils import T, TORCH_DTYPE_TO_NUMPY_DTYPE, make_ndarray_with_pad - import mindspore as ms -from mindspore.common.initializer import Zero from mindspore import dtype as mstype +from mindspore.common.initializer import Zero +from vllm.utils import TORCH_DTYPE_TO_NUMPY_DTYPE, T, make_ndarray_with_pad MsKVCache = Tuple[ms.Tensor, ms.Tensor] logger = logging.getLogger(__name__) - STR_DTYPE_TO_MS_DTYPE = { "half": ms.float16, "float16": ms.float16, @@ -65,18 +56,18 @@ def get_valid_dtype(dtype): def direct_register_custom_op( - op_name: str, - op_func: Callable, - mutates_args: List[str], - fake_impl: Optional[Callable] = None, - target_lib: Optional[Library] = None, - dispatch_key: str = "CUDA", + op_name: str, + op_func: Callable, + mutates_args: List[str], + fake_impl: Optional[Callable] = None, + target_lib: Optional[Library] = None, + dispatch_key: str = "CUDA", ): ... @contextlib.contextmanager def memory_profiling( - baseline_memory_in_bytes: int, weights_memory_in_bytes: int + baseline_memory_in_bytes: int, weights_memory_in_bytes: int ) -> Generator["MemoryProfilingResult", None, None]: """Memory profiling context manager. baseline_memory_in_bytes: memory used by all the components other than @@ -151,19 +142,20 @@ def memory_profiling( diff = result.after_profile - result.before_profile result.torch_peak_increase_in_bytes = diff.torch_peak_in_bytes - # For mindspore, the memory is allocated and free in memory pool, so cannot read the current used memory by `torch.cuda.mem_get_info`. + # For mindspore, the memory is allocated and free in memory pool, so + # cannot read the current used memory by `torch.cuda.mem_get_info`. 
current_cuda_memory_bytes = result.after_profile.torch_memory_in_bytes result.non_torch_increase_in_bytes = ( - current_cuda_memory_bytes - - baseline_memory_in_bytes - - weights_memory_in_bytes - - diff.torch_memory_in_bytes + current_cuda_memory_bytes + - baseline_memory_in_bytes + - weights_memory_in_bytes + - diff.torch_memory_in_bytes ) # noqa result.profile_time = diff.timestamp result.non_kv_cache_memory_in_bytes = ( - result.non_torch_increase_in_bytes - + result.torch_peak_increase_in_bytes - + result.weights_memory_in_bytes + result.non_torch_increase_in_bytes + + result.torch_peak_increase_in_bytes + + result.weights_memory_in_bytes ) # noqa @@ -177,13 +169,13 @@ def _create_empty_tensor(ms_type): def make_tensor_with_pad( - x: List[List[T]], - pad: T, - dtype: torch.dtype, - *, - max_len: Optional[int] = None, - device: Optional[Union[str, torch.device]] = None, - pin_memory: bool = False, + x: List[List[T]], + pad: T, + dtype: torch.dtype, + *, + max_len: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + pin_memory: bool = False, ) -> torch.Tensor: """ Make a padded tensor from 2D inputs. @@ -205,16 +197,20 @@ def make_tensor_with_pad( def async_tensor_h2d( - data: list, - dtype: torch.dtype, - target_device: Union[str, torch.device], - pin_memory: bool, + data: list, + dtype: torch.dtype, + target_device: Union[str, torch.device], + pin_memory: bool, ) -> torch.Tensor: """Asynchronously create a tensor and copy it from host to device.""" if not data: t = _create_empty_tensor(dtype) else: - t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="CPU") + t = torch.tensor( + data, + dtype=dtype, + pin_memory=pin_memory, + device="CPU") return t @@ -252,7 +248,9 @@ def ascend_device_count_stateless() -> List[str]: try: res = visible_device_str.split(",") except Exception as e: - logger.error('Cannot parse "ASCEND_RT_VISIBLE_DEVICES" for: %s!' % str(e)) + logger.error( + 'Cannot parse "ASCEND_RT_VISIBLE_DEVICES" for: %s!' % + str(e)) raise ValueError( 'Error argument(%s) of environ "ASCEND_RT_VISIBLE_DEVICES"!' % visible_device_str @@ -277,7 +275,9 @@ def ascend_device_count_stateless() -> List[str]: avl_devices.append(str(i)) visible_device_str = ",".join(avl_devices) os.environ["ASCEND_RT_VISIBLE_DEVICES"] = visible_device_str - logger.info('Set environ "ASCEND_RT_VISIBLE_DEVICES" as %s' % visible_device_str) + logger.info( + 'Set environ "ASCEND_RT_VISIBLE_DEVICES" as %s' % + visible_device_str) return len(avl_devices) @@ -289,24 +289,27 @@ def ascend_is_initialized(): def is_mindformers_model_backend(): return ( - os.getenv("vLLM_MODEL_BACKEND") - and os.environ["vLLM_MODEL_BACKEND"] == "MindFormers" + os.getenv("vLLM_MODEL_BACKEND") + and os.environ["vLLM_MODEL_BACKEND"] == "MindFormers" ) def check_ready(): from mindspore import set_context - set_context(jit_config={"jit_level": "O0", "infer_boost": "on"}) # Common environment variables of predict. + # Common environment variables of predict. 
+ set_context(jit_config={"jit_level": "O0", "infer_boost": "on"}) if is_mindformers_model_backend(): logger.info("Run with Mindformers backend!") - necessary_envs = ("vLLM_MODEL_MEMORY_USE_GB", "MINDFORMERS_MODEL_CONFIG") - lost_envs = [env_item for env_item in necessary_envs if not os.getenv(env_item)] + necessary_envs = ( + "vLLM_MODEL_MEMORY_USE_GB", + "MINDFORMERS_MODEL_CONFIG") + lost_envs = [ + env_item for env_item in necessary_envs if not os.getenv(env_item)] if lost_envs: raise RuntimeError( - 'For "MindFormers" model backend, environments %s should be set!' - % str(lost_envs) - ) + 'For "MindFormers" model backend, environments %s should be set!' % + str(lost_envs)) set_context(mode=0, device_target="Ascend", max_call_depth=10000) else: @@ -321,7 +324,8 @@ def cal_block_num(cache_config, model_config, parallel_config): _, total_gpu_memory = torch.cuda.mem_get_info() memory_can_use = total_gpu_memory * cache_config.gpu_memory_utilization - model_use_memory_b = int(os.getenv("vLLM_MODEL_MEMORY_USE_GB")) * 1024 * 1024 * 1024 + model_use_memory_b = int( + os.getenv("vLLM_MODEL_MEMORY_USE_GB")) * 1024 * 1024 * 1024 available_cache_memory = memory_can_use - model_use_memory_b cache_block_size = CacheEngine.get_cache_block_size( cache_config, model_config, parallel_config diff --git a/vllm_mindspore/version.py b/vllm_mindspore/version.py index c7b78b1dedcb24b87f61f448af111175edf7001a..344442e815dc092b4a6f2af69f5c68e9ce9f8cf1 100644 --- a/vllm_mindspore/version.py +++ b/vllm_mindspore/version.py @@ -18,12 +18,15 @@ from setuptools_scm import get_version - try: __version__ = get_version() except Exception as e: import warnings - warnings.warn("Failed to read version:\n%s" % str(e), RuntimeWarning, stacklevel=2) + warnings.warn( + "Failed to read version:\n%s" % + str(e), + RuntimeWarning, + stacklevel=2) __version__ = "dev" diff --git a/vllm_mindspore/worker/cache_engine.py b/vllm_mindspore/worker/cache_engine.py index 99fc7693c486fb51188092a3777e8c506e587b5a..a2024c88ab55b96266931f591112ee44d869cf1c 100644 --- a/vllm_mindspore/worker/cache_engine.py +++ b/vllm_mindspore/worker/cache_engine.py @@ -16,21 +16,16 @@ # ============================================================================ """CacheEngine class for managing the KV cache.""" +from vllm_mindspore.utils import (MsKVCache, get_valid_dtype, + is_mindformers_model_backend) +from mindspore import mutable +import mindspore as ms from typing import List from vllm.logger import init_logger logger = init_logger(__name__) -from vllm_mindspore.utils import ( - MsKVCache, - get_valid_dtype, - is_mindformers_model_backend, -) - -import mindspore as ms -from mindspore import mutable - def create_block(shape, dtype, name=None, device=None): from mindspore.ops.function.array_func import empty as empty_tensor @@ -40,9 +35,9 @@ def create_block(shape, dtype, name=None, device=None): def ms_allocate_kv_cache( - self, - num_blocks: int, - device: str, + self, + num_blocks: int, + device: str, ) -> List[MsKVCache]: """Allocates KV cache on the specified device.""" kv_cache_shape = self.attn_backend.get_kv_cache_shape( @@ -52,7 +47,8 @@ def ms_allocate_kv_cache( self.dtype = get_valid_dtype(self.dtype) - # TODO(tronzhang): A shape with (2, ...) for a kv tensor cannot support in mindspore's tensor and block operation, so split it to two tensor. + # TODO(tronzhang): A shape with (2, ...) for a kv tensor cannot support in + # mindspore's tensor and block operation, so split it to two tensor. 
for _ in range(self.num_attention_layers): device_type = "CPU" if device == "cpu" else "Ascend" current_cache = [] @@ -80,15 +76,14 @@ def ms_swap_out(self, src_to_dst: ms.Tensor) -> None: def cache_engine_init( - self, - cache_config, - model_config, - parallel_config, - device_config, + self, + cache_config, + model_config, + parallel_config, + device_config, ) -> None: - - from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType from vllm.attention import get_attn_backend + from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType self.cache_config = cache_config self.model_config = model_config @@ -116,9 +111,9 @@ def cache_engine_init( self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] if ( - is_mindformers_model_backend() - and hasattr(model_config.hf_text_config, "model_type") - and (model_config.hf_text_config.model_type in ("deepseek_v3",)) + is_mindformers_model_backend() + and hasattr(model_config.hf_text_config, "model_type") + and (model_config.hf_text_config.model_type in ("deepseek_v3",)) ): is_mla = True else: diff --git a/vllm_mindspore/worker/model_runner.py b/vllm_mindspore/worker/model_runner.py index e8752d942da2e4c3ffa87f4889de1b2e82448cb3..379f2ad51022c8f3c122fc708c3130a90fd0339b 100644 --- a/vllm_mindspore/worker/model_runner.py +++ b/vllm_mindspore/worker/model_runner.py @@ -18,15 +18,14 @@ from typing import List import torch +from mindspore import mutable from vllm.distributed import get_pp_group from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.sampling_params import SamplingParams from vllm.sequence import SequenceGroupMetadata -from vllm_mindspore.utils import STR_DTYPE_TO_TENSOR_DTYPE -from mindspore.common import dtype as mstype -from mindspore import mutable, Tensor +from vllm_mindspore.utils import STR_DTYPE_TO_TENSOR_DTYPE logger = init_logger(__name__) @@ -34,8 +33,10 @@ LORA_WARMUP_RANK = 8 def _get_cuda_graph_pad_size( - self, num_seqs: int, max_decode_seq_len: int, max_encoder_seq_len: int = 0 -) -> int: + self, + num_seqs: int, + max_decode_seq_len: int, + max_encoder_seq_len: int = 0) -> int: # No need to use cuda graph for mindspore. return -1 @@ -127,7 +128,8 @@ def profile_run(self) -> None: # multiplying the list, to avoid Dynamo from treating them as # tensor aliasing. - # TODO(tronzhang): MindSpore's tensor view is limit now, delete this whole funtion patching latter. + # TODO(tronzhang): MindSpore's tensor view is limit now, delete this whole + # function patching latter. 
kv_cache_dtype = self.model_config.dtype if self.cache_config.cache_dtype == "auto" \ else self.cache_config.cache_dtype kv_cache_dtype = STR_DTYPE_TO_TENSOR_DTYPE[kv_cache_dtype] diff --git a/vllm_mindspore/worker/worker.py b/vllm_mindspore/worker/worker.py index d1e52a410a0ca439db5fafcf97c0481731e02523..3023504e17cf27517b3f25513ec8714703151b29 100644 --- a/vllm_mindspore/worker/worker.py +++ b/vllm_mindspore/worker/worker.py @@ -18,23 +18,11 @@ """Worker functions""" import gc import os -from typing import Tuple, Optional +from typing import Tuple import torch - -from vllm.config import VllmConfig -from vllm.distributed import ( - ensure_kv_transfer_initialized, - ensure_model_parallel_initialized, - init_distributed_environment, - set_custom_all_reduce, -) - from vllm.logger import init_logger -from vllm_mindspore.utils import is_mindformers_model_backend - - logger = init_logger(__name__) @@ -70,13 +58,15 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: _, total_gpu_memory = torch.cuda.mem_get_info() if os.getenv("vLLM_MODEL_MEMORY_USE_GB"): - memory_use_for_model_run = int(os.environ["vLLM_MODEL_MEMORY_USE_GB"]) * 1024 * 1024 * 1024 + memory_use_for_model_run = ( + int(os.environ["vLLM_MODEL_MEMORY_USE_GB"]) * 1024 * 1024 * 1024 + ) else: # Execute a forward pass with dummy inputs to profile the memory usage # of the model. with memory_profiling( - baseline_memory_in_bytes=total_gpu_memory - self.init_gpu_memory, - weights_memory_in_bytes=self.model_runner.model_memory_usage, + baseline_memory_in_bytes=total_gpu_memory - self.init_gpu_memory, + weights_memory_in_bytes=self.model_runner.model_memory_usage, ) as result: self.model_runner.profile_run() torch.cuda.synchronize() @@ -86,7 +76,7 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: memory_use_for_model_run = result.non_kv_cache_memory_in_bytes memory_for_current_instance = ( - total_gpu_memory * self.cache_config.gpu_memory_utilization + total_gpu_memory * self.cache_config.gpu_memory_utilization ) available_kv_cache_memory = memory_for_current_instance - memory_use_for_model_run @@ -98,7 +88,9 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: num_cpu_blocks = 0 else: num_gpu_blocks = int(available_kv_cache_memory // cache_block_size) - num_cpu_blocks = int(self.cache_config.swap_space_bytes // cache_block_size) + num_cpu_blocks = int( + self.cache_config.swap_space_bytes // + cache_block_size) num_gpu_blocks = max(num_gpu_blocks, 0) num_cpu_blocks = max(num_cpu_blocks, 0)
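
For reference, the block accounting in the patched `determine_num_available_blocks` budgets the KV cache as `total_gpu_memory * gpu_memory_utilization` minus the model's memory use (taken from `vLLM_MODEL_MEMORY_USE_GB` when that variable is set, otherwise from the profiling run), then divides by the per-block cache size; CPU blocks come from the configured swap space. The sketch below only illustrates that arithmetic under assumed numbers; `estimate_num_cache_blocks` and the sample sizes are hypothetical and are not part of the vllm_mindspore API.

```python
import os
from typing import Tuple

GiB = 1024 * 1024 * 1024


def estimate_num_cache_blocks(
    total_device_memory: int,
    gpu_memory_utilization: float,
    cache_block_size: int,
    swap_space_bytes: int,
    profiled_model_memory: int = 0,
) -> Tuple[int, int]:
    """Sketch of the GPU/CPU KV-cache block budgeting described above.

    When vLLM_MODEL_MEMORY_USE_GB is set (the MindFormers backend path),
    the model's memory use comes from that variable instead of a profiling
    run; whatever remains inside the utilization budget goes to the KV cache.
    """
    env_model_memory = os.getenv("vLLM_MODEL_MEMORY_USE_GB")
    if env_model_memory:
        model_memory = int(env_model_memory) * GiB
    else:
        model_memory = profiled_model_memory

    memory_for_instance = total_device_memory * gpu_memory_utilization
    available_kv_cache_memory = memory_for_instance - model_memory

    num_gpu_blocks = max(int(available_kv_cache_memory // cache_block_size), 0)
    num_cpu_blocks = max(int(swap_space_bytes // cache_block_size), 0)
    return num_gpu_blocks, num_cpu_blocks


if __name__ == "__main__":
    # Hypothetical numbers: a 64 GiB device, 90% utilization, 16 GiB reserved
    # for the model via the env var, 2 MiB per KV-cache block, 4 GiB of swap.
    os.environ["vLLM_MODEL_MEMORY_USE_GB"] = "16"
    print(estimate_num_cache_blocks(
        total_device_memory=64 * GiB,
        gpu_memory_utilization=0.9,
        cache_block_size=2 * 1024 * 1024,
        swap_space_bytes=4 * GiB,
    ))  # -> (21299, 2048)
```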