diff --git a/vllm_mindspore/model_executor/models/qwen3_next.py b/vllm_mindspore/model_executor/models/qwen3_next.py
new file mode 100644
index 0000000000000000000000000000000000000000..23a3de6ac4d7514d89d61134b7dce5099655ed66
--- /dev/null
+++ b/vllm_mindspore/model_executor/models/qwen3_next.py
@@ -0,0 +1,97 @@
+# SPDX-License-Identifier: Apache-2.0
+
+# Adapted from
+# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py
+# Copyright 2025 Huawei Technologies Co., Ltd.
+# Copyright 2024 The Qwen team.
+# Copyright 2023 The vLLM team.
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# type: ignore
+# isort:skip_file
+
+import mindspore as ms
+import mindspore.ops as ops
+
+
+def recurrent_gated_delta_rule(
+    query,
+    key,
+    value,
+    g,
+    beta,
+    initial_state,
+    output_final_state,
+    use_qk_l2norm_in_kernel=False,
+):
+    initial_dtype = query.dtype
+    if use_qk_l2norm_in_kernel:
+        l2_normalize = ops.L2Normalize(axis=-1)
+        query = l2_normalize(query)
+        key = l2_normalize(key)
+
+    # Move the head axis in front of the sequence axis. q/k/v are
+    # (batch, seq_len, num_heads, dim); g and beta are (batch, seq_len, num_heads).
+    transpose = ops.Transpose()
+    query = transpose(query, (0, 2, 1, 3)).astype(ms.float32)
+    key = transpose(key, (0, 2, 1, 3)).astype(ms.float32)
+    value = transpose(value, (0, 2, 1, 3)).astype(ms.float32)
+    beta = transpose(beta, (0, 2, 1)).astype(ms.float32)
+    g = transpose(g, (0, 2, 1)).astype(ms.float32)
+
+    batch_size, num_heads, seq_len, k_head_dim = key.shape
+    v_head_dim = value.shape[-1]
+    scale = 1 / (query.shape[-1] ** 0.5)
+    query = query * scale
+
+    zeros = ops.Zeros()
+    core_attn_out = zeros((batch_size, num_heads, seq_len, v_head_dim), value.dtype)
+
+    if initial_state is None:
+        last_recurrent_state = zeros(
+            (batch_size, num_heads, k_head_dim, v_head_dim), value.dtype
+        )
+    else:
+        last_recurrent_state = initial_state.astype(value.dtype)
+
+    expand_dims = ops.ExpandDims()
+    reduce_sum = ops.ReduceSum()
+    # Sequential scan over the time axis.
+    for i in range(seq_len):
+        q_t = query[:, :, i, :]
+        k_t = key[:, :, i, :]
+        v_t = value[:, :, i, :]
+        g_t = expand_dims(expand_dims(ops.exp(g[:, :, i]), -1), -1)
+        beta_t = expand_dims(beta[:, :, i], -1)
+
+        # Decay the state, compute the delta-rule correction, then read out.
+        last_recurrent_state = last_recurrent_state * g_t
+        kv_mem = reduce_sum(last_recurrent_state * expand_dims(k_t, -1), -2)
+        delta = (v_t - kv_mem) * beta_t
+        last_recurrent_state = last_recurrent_state + expand_dims(k_t, -1) * expand_dims(delta, -2)
+        core_attn_out[:, :, i] = reduce_sum(last_recurrent_state * expand_dims(q_t, -1), -2)
+
+    if not output_final_state:
+        last_recurrent_state = None
+    else:
+        last_recurrent_state = last_recurrent_state.astype(initial_dtype)
+
+    core_attn_out = transpose(core_attn_out, (0, 2, 1, 3)).astype(initial_dtype)
+    return core_attn_out, last_recurrent_state
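
For reference, a minimal smoke test of the new helper might look like the sketch below. This is not part of the patch: the module path simply mirrors the diff, the shapes follow the layout the function assumes (q/k/v as (batch, seq_len, num_heads, head_dim), g and beta as (batch, seq_len, num_heads)), and all tensor values are arbitrary. Per time step, with the decayed state S~ = exp(g_t) * S_{t-1}, the loop computes S_t = S~ + k_t (x) beta_t * (v_t - S~^T k_t) and emits o_t = S_t^T q_t.

import numpy as np
import mindspore as ms

from vllm_mindspore.model_executor.models.qwen3_next import (
    recurrent_gated_delta_rule)

batch, seq_len, num_heads, k_dim, v_dim = 1, 4, 2, 8, 8
query = ms.Tensor(np.random.randn(batch, seq_len, num_heads, k_dim), ms.float16)
key = ms.Tensor(np.random.randn(batch, seq_len, num_heads, k_dim), ms.float16)
value = ms.Tensor(np.random.randn(batch, seq_len, num_heads, v_dim), ms.float16)
# g is a log-decay (<= 0 so exp(g) <= 1); beta is a write strength in (0, 1).
g = ms.Tensor(-np.random.rand(batch, seq_len, num_heads), ms.float16)
beta = ms.Tensor(np.random.rand(batch, seq_len, num_heads), ms.float16)

out, state = recurrent_gated_delta_rule(
    query, key, value, g, beta,
    initial_state=None,
    output_final_state=True,
    use_qk_l2norm_in_kernel=True,
)
print(out.shape)    # (1, 4, 2, 8): (batch, seq_len, num_heads, v_dim)
print(state.shape)  # (1, 2, 8, 8): (batch, num_heads, k_dim, v_dim)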