diff --git a/vllm_mindspore/model_executor/layers/rotary_embedding.py b/vllm_mindspore/model_executor/layers/rotary_embedding.py
index ade162460bddf1e3e9de4e5d6809716bc65e4004..51962dbc97f1c45d2806c7fdd6dd00e4ec52fbbf 100644
--- a/vllm_mindspore/model_executor/layers/rotary_embedding.py
+++ b/vllm_mindspore/model_executor/layers/rotary_embedding.py
@@ -49,16 +49,6 @@ def _get_feat_extract_output_lengths(input_lengths: ms.Tensor):
     )
     return feat_lengths, output_lengths
 
-def apply_interleaved_rope(x: Tensor, mrope_section: list[int]) -> Tensor:
-    """Apply interleaved MRoPE to 3D rotary embeddings.
-    Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
-    interleaved [THTHWHTHW...TT], preserving frequency continuity.
-    """
-    x_t = x[0].clone()
-    x_t[..., 1 : mrope_section[1] * 3 : 3] = x[1, ..., 1 : mrope_section[1] * 3 : 3]
-    x_t[..., 2 : mrope_section[2] * 3 : 3] = x[2, ..., 2 : mrope_section[2] * 3 : 3]
-    return x_t
-
 def _apply_rotary_emb(
     x: Tensor,
     cos: Tensor,
@@ -299,40 +289,35 @@ class MRotaryEmbedding(RotaryEmbedding):
         self.cache_max_position_num = max_position_embeddings * 4
         super().__init__(head_size, rotary_dim, self.cache_max_position_num,
                          base, is_neox_style, dtype)
-        self.inv_freq = 1.0 / (base ** (mint.arange(
-            0, rotary_dim, 2, dtype=mstype.float32)[: (rotary_dim // 2)] / self.rotary_dim))
         self.mrope_section = mrope_section
         if self.mrope_section:
             assert sum(self.mrope_section) == rotary_dim // 2
         self.mrope_interleaved = mrope_interleaved
-
-    def get_freqs(self, position_ids):
-        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(3, -1, 1).asnumpy()
-        position_ids = position_ids.float().asnumpy()
-        position_ids_expanded = position_ids[:, None, :]
-        freqs = (inv_freq_expanded @ position_ids_expanded).transpose(0, 2, 1)
-        freqs = self.apply_interleaved_mrope(freqs, self.mrope_section)
-        freqs = Tensor.from_numpy(freqs.astype(np.float32))
-        return freqs
-
-    def apply_interleaved_mrope(self, freqs, mrope_section):
+        if self.mrope_interleaved:
+            assert len(self.mrope_section) == 3
+            mrope_section_np = np.array(self.mrope_section, dtype=np.int64)
+            sec_total = mrope_section_np.sum()
+            h_sec = np.array(list(range(1, self.mrope_section[1] * 3, 3))) + sec_total
+            w_sec = np.array(list(range(2, self.mrope_section[2] * 3, 3))) + 2 * sec_total
+            select_index = np.arange(sec_total, dtype=np.int64)
+            select_index[1 : mrope_section[1] * 3 : 3] = h_sec
+            select_index[2 : mrope_section[2] * 3 : 3] = w_sec
+            self.rope_select_index = ms.from_numpy(select_index)
+
+        if self.is_neox_style and self.rotary_dim == self.head_size:
+            self.rotary_embedding_op = ops.ApplyRotaryPosEmb(2)
+
+    def apply_interleaved_rope(self, x: Tensor, mrope_section: list[int]) -> Tensor:
         """Apply interleaved MRoPE to 3D rotary embeddings.
         Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
         interleaved [THTHWHTHW...TT], preserving frequency continuity.
-        args:
-            x: (3, bs, seq_len, head_dim // 2)
-            mrope_section: (3,)
-        returns:
-            x_t: (bs, seq_len, head_dim // 2)
         """
-        freqs_t = freqs[0]  # just overwrite the first dimension T
-        for dim, offset in enumerate((1, 2), start=1):  # H, W
-            length = mrope_section[dim] * 3
-            idx = slice(offset, length, 3)  #难道说cos,sin也是不连续导致的
-            freqs_t[..., idx] = freqs[dim, ..., idx]
-        return freqs_t
+        x = ops.transpose(x, (1, 0, 2))
+        x = mint.flatten(x, start_dim=1)
+        x_t = mint.index_select(x, -1, self.rope_select_index)
+        return x_t
 
     def construct(
         self,
@@ -340,7 +325,6 @@ class MRotaryEmbedding(RotaryEmbedding):
         query: mindspore.Tensor,
         key: mindspore.Tensor,
         batch_valid_length: Tensor = None,
-        freqs: mindspore.Tensor = None,
     ) -> tuple[mindspore.Tensor, mindspore.Tensor]:
         """
         Args:
@@ -357,10 +341,29 @@ class MRotaryEmbedding(RotaryEmbedding):
         # cos_sin: (3, 5120, rotary_dim)  # noqa: ERA001
         # cos/sin: cat[(1, 5120, mrope_sec),...] -> (1, 5120, rotary_dim//2)
         ######################################################################
-        # emb = mint.cat((freqs, freqs), dim=-1).view(-1, self.rotary_dim)
         num_tokens = positions.shape[-1]
-        cos = freqs.cos()
-        sin = freqs.sin()
+        cos_sin = self.cos_sin_cache[positions]
+        cos, sin = ops.chunk(cos_sin, 2, axis=-1)
+        if positions.ndim == 2:
+            if self.mrope_interleaved:
+                cos = self.apply_interleaved_rope(cos, self.mrope_section)
+                sin = self.apply_interleaved_rope(sin, self.mrope_section)
+            else:
+                cos_l = mint.split(cos, self.mrope_section, dim=-1)
+                sin_l = mint.split(sin, self.mrope_section, dim=-1)
+                cos, sin = (), ()
+                for i in range(len(self.mrope_section)):  # type: ignore[arg-type]
+                    cos += (cos_l[i][i], )
+                    sin += (sin_l[i][i], )
+                cos = mint.cat(cos, dim=-1)
+                sin = mint.cat(sin, dim=-1)
+
+        if self.is_neox_style and self.rotary_dim == self.head_size:
+            freqs_cos = mint.cat((cos, cos), dim=-1)
+            freqs_sin = mint.cat((sin, sin), dim=-1)
+            query, key = self.rotary_embedding_op(query, key, freqs_cos, freqs_sin,
+                                                  batch_valid_length)
+            return query, key
 
         query_shape = query.shape
         query = query.view(num_tokens, -1, self.head_size)
diff --git a/vllm_mindspore/model_executor/models/qwen3_moe.py b/vllm_mindspore/model_executor/models/qwen3_moe.py
index 5dfec8ff3ca53419e3c98acf12cbf23bb69123c8..5c0260e40ca9fecdd5fb1c5b62aee95c972a3130 100644
--- a/vllm_mindspore/model_executor/models/qwen3_moe.py
+++ b/vllm_mindspore/model_executor/models/qwen3_moe.py
@@ -243,7 +243,6 @@ class Qwen3MoeAttention(nn.Cell):
         batch_valid_length: Tensor,
         q_seq_lens: Tensor,
         block_tables: Tensor,
-        freqs: Optional[Tensor] = None
     ) -> Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
@@ -257,7 +256,7 @@ class Qwen3MoeAttention(nn.Cell):
                                 self.head_dim)
         k_by_head = self.k_norm(k_by_head)
         k = k_by_head.view(k.shape)
-        q, k = self.rotary_emb(positions, q, k, batch_valid_length, freqs)
+        q, k = self.rotary_emb(positions, q, k, batch_valid_length)
         attn_output = self.attn(q, k, v, key_cache, value_cache, slot_mapping,
                                 attn_mask, batch_valid_length, q_seq_lens,
                                 block_tables)
@@ -332,7 +331,6 @@ class Qwen3MoeDecoderLayer(nn.Cell):
         dp_unpad_index: Optional[Tensor] = None,
         dp_pad_index_with_offset: Optional[Tensor] = None,
         dp_unpad_index_total_with_offset: Optional[Tensor] = None,
-        freqs: Optional[Tensor] = None
     ) -> Tensor:
         # Self Attention
         if residual is None:
@@ -344,7 +342,7 @@ class Qwen3MoeDecoderLayer(nn.Cell):
         hidden_states = self.self_attn(positions, hidden_states, key_cache,
                                        value_cache, slot_mapping, attn_mask,
                                       batch_valid_length, q_seq_lens,
-                                       block_tables, freqs)
+                                       block_tables)
         # Fully Connected
         hidden_states, residual = self.post_attention_layernorm(
             hidden_states, residual)
diff --git a/vllm_mindspore/model_executor/models/qwen3_omni_moe_thinker.py b/vllm_mindspore/model_executor/models/qwen3_omni_moe_thinker.py
index 44b1a8f9cd3b62e4b346f4d8728b5690f392272a..43ed3ef78065a0082c89a158f59d891b04069c37 100644
--- a/vllm_mindspore/model_executor/models/qwen3_omni_moe_thinker.py
+++ b/vllm_mindspore/model_executor/models/qwen3_omni_moe_thinker.py
@@ -584,7 +584,6 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[Tensor] = None,
         deepstack_input_embeds: Optional[Mapping[str, Tensor]] = None,
-        freqs: Optional[Tensor] = None,
     ) -> ms.Tensor | IntermediateTensors:
         if get_pp_group().is_first_rank:
             if inputs_embeds is not None:
@@ -612,7 +611,7 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
                 positions, hidden_states, key_caches[i - self.start_layer],
                 value_caches[i - self.start_layer], slot_mapping, attn_mask,
                 batch_valid_length, q_seq_lens, block_tables, residual,
-                None, None, None, None, freqs)
+                None, None, None, None)
             if deepstack_input_embeds is not None and i in range(self.deepstack_layers):
                 hidden_states = mint.add(hidden_states,
                                          deepstack_input_embeds[i])
@@ -1254,6 +1253,13 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
         self.lm_head = self.language_model.lm_head
         self.common_preprocess(vllm_config, prefix)
 
+        self.model.embed_tokens._set_jit_graph_name("prefill")
+        self.model.embed_tokens.phase = "prefill"
+        dyn_input_ids = ms.Tensor(shape=[None], dtype=ms.int32)
+        self.model.embed_tokens.set_inputs(dyn_input_ids)
+        self.model.embed_tokens.construct = ms.jit(
+            function=self.model.embed_tokens, jit_level='O0')
+
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors
         )
@@ -1502,8 +1508,6 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
         else:
             deepstack_input_embeds = None
 
-
-        freqs = self.language_model.model.layers[0].self_attn.rotary_emb.get_freqs(positions)
         hidden_states = self.exec_model(
             input_ids=input_ids,
             positions=positions,
@@ -1511,7 +1515,6 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
             inputs_embeds=inputs_embeds,
             # args for deepstack
            deepstack_input_embeds=deepstack_input_embeds,
-            freqs=freqs
         )
 
         if inputs_embeds is not None and get_pp_group().is_first_rank:
@@ -2006,13 +2009,12 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
         dyn_block_tables = Tensor(shape=[None, None], dtype=mstype.int32)
         dyn_deepstack_input_embeds = Tensor(shape=[None, None, None],
                                             dtype=self.model_config.dtype)
-        dyn_freqs = Tensor(shape=[None, None], dtype=mstype.float32)
         self.ready_model.set_inputs(
             dyn_input_ids, dyn_position_ids, dyn_key_caches, dyn_value_caches,
             dyn_slot_mapping, dynamic_attention_mask, dyn_batch_valid_length,
             dyn_q_seq_lens, dyn_block_tables,
-            dyn_intermediate_tensors, dyn_inputs_embeds, dyn_deepstack_input_embeds, dyn_freqs)
+            dyn_intermediate_tensors, dyn_inputs_embeds, dyn_deepstack_input_embeds)
 
         dynamic_hidden_states = Tensor(shape=[None, None],
                                        dtype=self.model_config.dtype)
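
The core of the rotary_embedding.py change is that the interleaved MRoPE reorganization is no longer recomputed from host-side freqs on every step; cos/sin now come from cos_sin_cache and a gather index is built once in __init__ and applied with index_select. Below is a minimal NumPy sketch (illustration only, not part of the patch; the mrope_section values are placeholder assumptions) checking that gathering the flattened [T | H | W] frequencies with the precomputed index reproduces the per-slice overwrite that the removed apply_interleaved_mrope performed at runtime.

import numpy as np

# Hypothetical T/H/W split; real values come from the model config and must
# sum to rotary_dim // 2.
mrope_section = [24, 20, 20]
sec_total = sum(mrope_section)

# Precompute the gather index once, mirroring MRotaryEmbedding.__init__ above.
h_sec = np.arange(1, mrope_section[1] * 3, 3) + sec_total
w_sec = np.arange(2, mrope_section[2] * 3, 3) + 2 * sec_total
select_index = np.arange(sec_total)
select_index[1:mrope_section[1] * 3:3] = h_sec
select_index[2:mrope_section[2] * 3:3] = w_sec

# Fake per-axis frequencies: x[0] = T, x[1] = H, x[2] = W, shape (3, seq_len, dim).
rng = np.random.default_rng(0)
x = rng.random((3, 5, sec_total), dtype=np.float32)

# Reference: the per-slice overwrite done by the removed apply_interleaved_mrope.
ref = x[0].copy()
ref[..., 1:mrope_section[1] * 3:3] = x[1, ..., 1:mrope_section[1] * 3:3]
ref[..., 2:mrope_section[2] * 3:3] = x[2, ..., 2:mrope_section[2] * 3:3]

# Patched path: transpose to (seq_len, 3, dim), flatten to (seq_len, 3 * dim),
# then gather with the precomputed index (ops.transpose / mint.flatten /
# mint.index_select in the MindSpore code).
out = x.transpose(1, 0, 2).reshape(5, -1)[:, select_index]

assert np.array_equal(out, ref)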