diff --git a/omni/adaptors/vllm/worker/npu_model_runner.py b/omni/adaptors/vllm/worker/npu_model_runner.py index 1bed68c7f0e1edbface6b6b2a5b63e06c198a71f..85b9b668c186482a479f59b207df03c1e9119acb 100644 --- a/omni/adaptors/vllm/worker/npu_model_runner.py +++ b/omni/adaptors/vllm/worker/npu_model_runner.py @@ -166,7 +166,6 @@ class NPUModelRunner(GPUModelRunner): device="cpu", pin_memory=is_pin_memory_available()) self.seq_lens_np = self.seq_lens_cpu.numpy() - # TODO: support arbitrary spec tokens self.graph_block_tables = np.zeros( (self.max_num_reqs if not self.use_spec_decode else self.max_num_reqs * (1 + self.speculative_config.num_speculative_tokens), (self.model_config.max_model_len + self.block_size - 1) // @@ -820,7 +819,35 @@ class NPUModelRunner(GPUModelRunner): finished_recving=finished_recving, ) return model_runner_output - + + def _simple_advance_step(self, + attn_metadata, + block_size: int, + positions: torch.Tensor, + ) -> None: + token_each_reqs = 1 + self.speculative_config.num_speculative_tokens + num_reqs = self.input_batch.num_reqs + num_tokens = token_each_reqs * num_reqs + positions[:num_tokens] += 1 + + req_indices = torch.repeat_interleave(torch.arange(num_reqs, device=self.device), token_each_reqs, dim=0) + block_table: BlockTable = self.input_batch.block_table[-1] + block_table_indices = ( + req_indices * block_table.max_num_blocks_per_req + + positions[:num_tokens] // block_size) + block_table_cpu = block_table.get_device_tensor() + block_numbers = block_table_cpu.flatten()[block_table_indices] + block_offsets = positions[:num_tokens] % block_size + block_table.slot_mapping[:num_tokens] = block_numbers * block_size + block_offsets + + attn_metadata.slot_mapping[:num_tokens] = block_table.slot_mapping[:num_tokens] + input_positions = positions[:num_tokens] + attn_metadata.decode.input_positions[:num_tokens] = input_positions + attn_metadata.decode.seq_lens[:num_tokens] = (input_positions + 1).to(self.seq_lens.dtype) + cos, sin = self.model.model.layers[0].self_attn.rotary_emb.get_cos_sin(attn_metadata.decode.input_positions) + attn_metadata.decode.cos = cos + attn_metadata.decode.sin = sin + @torch.inference_mode() def run_mtp(self, attn_metadata, scheduler_output, input_ids, raw_hidden_states, mtp_input_tokens, positions, sample_indices, last_accepted_index): attn_state = next(iter(attn_metadata.values())).attn_state @@ -833,10 +860,11 @@ class NPUModelRunner(GPUModelRunner): if not self.drafter_mark_static: torch._dynamo.mark_static(mtp_input_tokens) torch._dynamo.mark_static(raw_hidden_states) - mtp_logits, mtp_hidden_states = self.compile_drafter_list[layer_idx]( + mtp_layer_idx = 0 if self.num_mtp_layers == 1 else layer_idx + mtp_logits, mtp_hidden_states = self.compile_drafter_list[mtp_layer_idx]( input_ids=mtp_input_tokens.to(torch.long), positions=positions, - kv_caches=self.kv_caches[-self.speculative_config.num_speculative_tokens + layer_idx:], + kv_caches=self.kv_caches[-self.num_mtp_layers + mtp_layer_idx:], attn_metadata=attn_metadata, previous_hidden_states=raw_hidden_states, intermediate_tensors=None, @@ -850,6 +878,13 @@ class NPUModelRunner(GPUModelRunner): mtp_input_tokens[:-1] = mtp_input_tokens.clone()[1:] mtp_input_tokens[last_accepted_index] = mtp_forward_tokens raw_hidden_states = mtp_hidden_states + + if self.num_mtp_layers == 1: + mtp_layer_name = list(attn_metadata.keys())[-1] + mtp_attn_metadata = attn_metadata[mtp_layer_name] + self._simple_advance_step(mtp_attn_metadata, + self.vllm_config.cache_config.block_size, + positions) self.drafter_mark_static = True else: @@ -858,23 +893,27 @@ class NPUModelRunner(GPUModelRunner): self.vllm_config, num_tokens=scheduler_output.total_num_scheduled_tokens): for layer_idx in range(self.speculative_config.num_speculative_tokens): - mtp_logits, mtp_hidden_states = self.drafter_list[layer_idx]( + mtp_layer_idx = 0 if self.num_mtp_layers == 1 else layer_idx + mtp_logits, mtp_hidden_states = self.drafter_list[mtp_layer_idx]( input_ids=mtp_input_tokens.to(torch.long), positions=positions, - kv_caches=self.kv_caches[-self.speculative_config.num_speculative_tokens + layer_idx:], + kv_caches=self.kv_caches[-self.num_mtp_layers + mtp_layer_idx:], attn_metadata=attn_metadata, previous_hidden_states=raw_hidden_states, - prefill_padding_or_selected_indices=sample_indices, + prefill_padding_or_selected_indices=last_accepted_index, intermediate_tensors=None, inputs_embeds=None, require_hidden_states=True, ) mtp_forward_tokens = mtp_logits[last_accepted_index].argmax(dim=-1) + if self.num_mtp_layers == 1: + mtp_forward_token_list.extend([mtp_forward_tokens] * self.speculative_config.num_speculative_tokens) + break mtp_forward_token_list.append(mtp_forward_tokens) if layer_idx == self.speculative_config.num_speculative_tokens - 1: continue mtp_input_tokens[:-1] = mtp_input_tokens.clone()[1:] - mtp_input_tokens[sample_indices] = mtp_forward_tokens + mtp_input_tokens[last_accepted_index] = mtp_forward_tokens raw_hidden_states = mtp_hidden_states return torch.stack(mtp_forward_token_list, dim=1) @@ -921,7 +960,8 @@ class NPUModelRunner(GPUModelRunner): raw_hidden_states, hidden_states = forward_results if self.use_spec_decode and self.speculative_config.method in ('mtp'): for layer_idx in range(self.speculative_config.num_speculative_tokens): - self.drafter_list[layer_idx]( + mtp_layer_idx = 0 if self.num_mtp_layers == 1 else layer_idx + self.drafter_list[mtp_layer_idx]( input_ids=input_ids, positions=positions, kv_caches=None, @@ -977,10 +1017,11 @@ class NPUModelRunner(GPUModelRunner): torch._dynamo.mark_static(raw_hidden_states) self.dummy_drafter_mark_static = True for layer_idx in range(self.speculative_config.num_speculative_tokens): - self.compile_drafter_list[layer_idx]( + mtp_layer_idx = 0 if self.num_mtp_layers == 1 else layer_idx + self.compile_drafter_list[mtp_layer_idx]( input_ids=input_ids, positions=positions, - kv_caches=self.kv_caches[-self.speculative_config.num_speculative_tokens + layer_idx:] if self.kv_caches else None, + kv_caches=self.kv_caches[-self.num_mtp_layers + mtp_layer_idx:] if self.kv_caches else None, attn_metadata=attn_metadata, previous_hidden_states=raw_hidden_states, intermediate_tensors=None, @@ -1003,10 +1044,11 @@ class NPUModelRunner(GPUModelRunner): attn_metadata=attn_metadata) if self.use_spec_decode and self.speculative_config.method in ('mtp'): for layer_idx in range(self.speculative_config.num_speculative_tokens): - self.drafter_list[layer_idx]( + mtp_layer_idx = 0 if self.num_mtp_layers == 1 else layer_idx + self.drafter_list[mtp_layer_idx]( input_ids=input_ids, positions=positions, - kv_caches=self.kv_caches[-self.speculative_config.num_speculative_tokens + layer_idx:] if self.kv_caches else None, + kv_caches=self.kv_caches[-self.num_mtp_layers + mtp_layer_idx:] if self.kv_caches else None, attn_metadata=attn_metadata, previous_hidden_states=raw_hidden_states, intermediate_tensors=None, @@ -1038,7 +1080,7 @@ class NPUModelRunner(GPUModelRunner): logger.info("Loading mtp model...") original_arch = self.model_config.hf_config.architectures # ['DeepseekV3ForCausalLM'] original_type = self.model_config.hf_config.model_type # 'deepseek_v3' - + self.num_mtp_layers = min(self.speculative_config.num_speculative_tokens, self.model_config.hf_config.num_nextn_predict_layers) self.drafter_list = [] architecture_list = ["DeepSeekMTPModel", "DeepSeekMTPModelDuo", "DeepSeekMTPModelTres"] for mtp_layer_idx in range(self.model_config.hf_config.num_nextn_predict_layers): @@ -1051,7 +1093,6 @@ class NPUModelRunner(GPUModelRunner): self.drafter_list.append(drafter) self.model_config.hf_config.architectures = original_arch self.model_config.hf_config.model_type = original_type - # zxp TODO: check if fusion_spec.py from line 90 needed? if not int(os.getenv("NO_NPU_MOCK", "0")): logger.info("Loading model weights took %.4f GB", m.consumed_memory / float(2**30)) @@ -1078,9 +1119,9 @@ class NPUModelRunner(GPUModelRunner): backend=npu_backend) if hasattr(self, "drafter"): self.compile_drafter_list = [] - for layer_idx in range(self.speculative_config.num_speculative_tokens): + for mtp_layer_idx in range(self.num_mtp_layers): self.compile_drafter_list.append(torch.compile( - self.drafter_list[layer_idx], + self.drafter_list[mtp_layer_idx], dynamic=True, fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, backend=npu_backend)) @@ -1090,8 +1131,8 @@ class NPUModelRunner(GPUModelRunner): if hasattr(self, "drafter"): wrap_list = [WrapDrafter, WrapDrafterDuo, WrapDrafterTres] self.compile_drafter_list = [] - for idx in range(self.speculative_config.num_speculative_tokens): - self.compile_drafter_list.append(wrap_list[idx](self.drafter_list[idx], self.decode_gear_list)) + for mtp_layer_idx in range(self.num_mtp_layers): + self.compile_drafter_list.append(wrap_list[mtp_layer_idx](self.drafter_list[mtp_layer_idx], self.decode_gear_list)) if model_extra_config.operator_opt_config.use_omni_placement: param_dict = dict(self.model.named_parameters()) self.planner = OmniPlanner(config_file= model_extra_config.operator_opt_config.omni_placement_config_path) diff --git a/tools/scripts/pd_run.sh b/tools/scripts/pd_run.sh index 5b087c8be7ac459d01c21118b815cff944056bff..cc8396fa6300a5da5db442c431500b71e74bed03 100644 --- a/tools/scripts/pd_run.sh +++ b/tools/scripts/pd_run.sh @@ -53,6 +53,7 @@ ADDITIONAL_CONFIG="" VLLM_ENABLE_MC2=0 HCCL_BUFFSIZE=0 HCCL_OP_EXPANSION_MODE="" +NUM_SPECULATIVE_TOKENS=1 # Help information print_help() { @@ -100,6 +101,7 @@ print_help() { echo " --vllm-enable-mc2 vLLM framework: GRAPH parameter (default: $VLLM_ENABLE_MC2)" echo " --hccl-op-expansion-mode vLLM framework: HCCL_OP_EXPANSION_MODE" echo " --hccl-buffsize vLLM framework: HCCL_BUFFSIZE" + echo " --num-speculative-tokens vLLM framework: MTP parameter, number of speculative tokens per step (default: $NUM_SPECULATIVE_TOKENS)" exit 0 } @@ -236,6 +238,9 @@ parse_long_option() { --hccl-op-expansion-mode) HCCL_OP_EXPANSION_MODE="$2" ;; + --num-speculative-tokens) + NUM_SPECULATIVE_TOKENS="$2" + ;; --help) print_help ;; @@ -393,6 +398,7 @@ common_operations() { --gpu-util "$GPU_UTIL" \ --additional-config "$ADDITIONAL_CONFIG" \ --enable-mtp \ + --num-speculative-tokens "$NUM_SPECULATIVE_TOKENS" \ --extra-args "$EXTRA_ARGS" }