From 76b51e71e9e0f4f8fabc23c9d884c5050668d663 Mon Sep 17 00:00:00 2001 From: zxq <342239412@qq.com> Date: Thu, 20 Nov 2025 17:51:16 +0800 Subject: [PATCH] [master] Replace the mint.split operator with ops.function.array_func.split_ext MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../inference/transformer/attention.py | 9 ++- .../inference/transformer/mlp.py | 4 +- .../inference/transformer/moe/experts.py | 4 +- .../transformer/multi_latent_attention.py | 19 +++--- research/deepseek3/deepseek3_model_infer.py | 60 ++++++++++--------- research/deepseek3/infer/transformer.py | 6 +- research/deepseek3/moe.py | 26 ++++---- research/llama3_1/infer/transformer.py | 7 ++- research/qwen2_5/infer/transformer.py | 24 ++++---- 9 files changed, 83 insertions(+), 76 deletions(-) diff --git a/mindformers/parallel_core/inference/transformer/attention.py b/mindformers/parallel_core/inference/transformer/attention.py index a38416a9b..896974d9a 100644 --- a/mindformers/parallel_core/inference/transformer/attention.py +++ b/mindformers/parallel_core/inference/transformer/attention.py @@ -24,7 +24,7 @@ from dataclasses import dataclass import math from typing import Union, Optional -from mindspore import mint, nn, ops +from mindspore import nn, ops from mindformers.parallel_core.inference.quantization import QuantizationConfig from mindformers.parallel_core.inference.transformer.identity_op import IdentityOp @@ -146,13 +146,12 @@ class Attention(nn.Cell): self.tp_group_size = self.tp.size self.num_attention_heads_per_partition = divide(self.num_heads, self.tp_group_size) - self.use_gqa = (self.num_heads != self.num_query_groups) + self.use_gqa = self.num_heads != self.num_query_groups if self.use_gqa: self._check_gqa_valid() # Note: Special handling when kv heads is less than tp size - if self.num_query_groups < self.tp_group_size: - self.num_query_groups = self.tp_group_size + self.num_query_groups = max(self.num_query_groups, self.tp_group_size) self.num_query_groups_per_partition = divide(self.num_query_groups, self.tp_group_size) self.repeat_num = divide(self.num_heads, self.num_query_groups) else: @@ -370,7 +369,7 @@ class SelfAttention(Attention): def get_query_key_value_tensors(self, hidden_states): qkv = self.cast(self.linear_qkv(hidden_states), self.compute_dtype) - query, key, value = mint.split(qkv, + query, key, value = ops.function.array_func.split_ext(qkv, (self.hidden_size_per_partition, self.kv_hidden_size_per_partition, self.kv_hidden_size_per_partition), -1) diff --git a/mindformers/parallel_core/inference/transformer/mlp.py b/mindformers/parallel_core/inference/transformer/mlp.py index 7c9b08c46..1f63370e2 100644 --- a/mindformers/parallel_core/inference/transformer/mlp.py +++ b/mindformers/parallel_core/inference/transformer/mlp.py @@ -21,7 +21,7 @@ __all__ = [ from dataclasses import dataclass from typing import Union, Optional -from mindspore import nn, mint +from mindspore import nn, mint, ops from mindformers.parallel_core.inference.quantization import QuantizationConfig from mindformers.parallel_core.transformer_config import TransformerConfig @@ -157,7 +157,7 @@ class MLP(nn.Cell): intermediate_parallel = self.linear_fc1(hidden_states) if self.config.gated_linear_unit: - gate, hidden = mint.split(intermediate_parallel, + gate, hidden = ops.function.array_func.split_ext(intermediate_parallel, (self.ffn_hidden_size_per_partition,
self.ffn_hidden_size_per_partition), -1) gate = self.activation_func(gate) if self.activation_type else gate diff --git a/mindformers/parallel_core/inference/transformer/moe/experts.py b/mindformers/parallel_core/inference/transformer/moe/experts.py index 6e77b3d47..e03976761 100644 --- a/mindformers/parallel_core/inference/transformer/moe/experts.py +++ b/mindformers/parallel_core/inference/transformer/moe/experts.py @@ -18,7 +18,7 @@ __all__ = ["GroupedMLP"] from typing import Optional -from mindspore import mint, nn +from mindspore import mint, nn, ops from mindformers.parallel_core.transformer_config import TransformerConfig from mindformers.parallel_core.utils.spec_utils import build_module @@ -125,7 +125,7 @@ class GroupedMLP(nn.Cell): intermediate_parallel = self.linear_fc1(hidden_states, group_list=group_list) if self.config.gated_linear_unit: - gate, hidden = mint.split(intermediate_parallel, + gate, hidden = ops.function.array_func.split_ext(intermediate_parallel, (self.ffn_hidden_size_per_partition, self.ffn_hidden_size_per_partition), -1) gate = self.activation_func(gate) if self.activation_type else gate diff --git a/mindformers/parallel_core/inference/transformer/multi_latent_attention.py b/mindformers/parallel_core/inference/transformer/multi_latent_attention.py index df402ff6e..bb704a603 100644 --- a/mindformers/parallel_core/inference/transformer/multi_latent_attention.py +++ b/mindformers/parallel_core/inference/transformer/multi_latent_attention.py @@ -363,7 +363,7 @@ class MLASelfAttention(MultiLatentAttention): Process the weight after loading. This can be used for example, to transpose weights for computation. """ - q_absorb, out_absorb = mint.split(self.linear_kv_up_proj.weight, + q_absorb, out_absorb = ops.function.array_func.split_ext(self.linear_kv_up_proj.weight, [self.num_attention_heads_per_partition * self.config.qk_head_dim, self.num_attention_heads_per_partition * self.config.v_head_dim], -2) self.q_absorb = q_absorb.reshape(self.num_attention_heads_per_partition, @@ -384,7 +384,7 @@ class MLASelfAttention(MultiLatentAttention): hidden_states = self.input_layernorm(hidden_states) if self.config.q_lora_rank is not None: qkv = self.linear_qkv_down_proj(hidden_states) - kv_compressed, k_pos_emb, q_compressed = mint.split(qkv, + kv_compressed, k_pos_emb, q_compressed = ops.function.array_func.split_ext(qkv, [self.config.kv_lora_rank, self.config.qk_pos_emb_head_dim, self.config.q_lora_rank], @@ -404,7 +404,7 @@ class MLASelfAttention(MultiLatentAttention): if kv_combined.shape[-1] != self.config.kv_lora_rank + self.config.qk_pos_emb_head_dim: # the shape of kv_combined is [s, b, (kv_lora_rank + qk_pos_emb_head_dim)] kv_combined = gather_from_model_parallel_region(q_compressed, self.tp) - kv_compressed, k_pos_emb = mint.split( + kv_compressed, k_pos_emb = ops.function.array_func.split_ext( kv_combined, [self.config.kv_lora_rank, self.config.qk_pos_emb_head_dim], dim=-1 ) # the shape of q is [num_tokens, n * (qk_head_dim + qk_pos_emb_head_dim)] @@ -413,7 +413,8 @@ class MLASelfAttention(MultiLatentAttention): # the shape of q is [num_tokens, n, q_head_dim] q = q.reshape(*q.shape[:-1], self.num_attention_heads_per_partition, self.q_head_dim) # the shape of q_no_pe is [num_tokens, n, qk_head_dim], q_pos_emb: [num_tokens, n, qk_pos_emb_head_dim] - q_no_pe, q_pos_emb = mint.split(q, [self.config.qk_head_dim, self.config.qk_pos_emb_head_dim], dim=-1) + q_no_pe, q_pos_emb = ops.function.array_func.split_ext( + q, [self.config.qk_head_dim, 
self.config.qk_pos_emb_head_dim], dim=-1) # the shape of kv_compressed is [num_tokens, kv_lora_rank] kv_compressed = self.kv_layernorm(kv_compressed) @@ -443,8 +444,9 @@ class MLASelfAttention(MultiLatentAttention): # the shape of k_no_pe is [num_tokens, qk_head_dim * self.kv_num_heads_per_partition], # the shape of value is [num_tokens, v_head_dim * self.kv_num_heads_per_partition] - k_no_pe, value = mint.split(kv, [self.config.qk_head_dim * self.kv_num_heads_per_partition, - self.config.v_head_dim * self.kv_num_heads_per_partition], dim=-1) + k_no_pe, value = ops.function.array_func.split_ext( + kv, [self.config.qk_head_dim * self.kv_num_heads_per_partition, + self.config.v_head_dim * self.kv_num_heads_per_partition], dim=-1) k_no_pe = k_no_pe.reshape(-1, self.kv_num_heads_per_partition, self.config.qk_head_dim) # the shape of value_states is [num_tokens, n, v_head_dim] @@ -531,7 +533,7 @@ class FusedMLASelfAttention(MLASelfAttention): self.is_modelslim = quant_config.is_modelslim self.fa3_quant = quant_config.fa3_quant self.fa3_quant_layer = quant_config.fa3_quant_layer - self.is_fa3_quant_layer = (layer_number - 1) in self.fa3_quant_layer # layer_number start from 1 + self.is_fa3_quant_layer = layer_number - 1 in self.fa3_quant_layer # layer_number start from 1 self.input_layernorm_weight = None self.qkv_down_proj_input_scale = None self.q_layernorm_weight = None @@ -542,6 +544,7 @@ class FusedMLASelfAttention(MLASelfAttention): self.q_up_proj_input_offset = None self.input_format = 1 if self.fa3_quant else 0 self.use_ringmla = use_ms_custom_ops() and get_tensor_model_parallel_world_size() < 16 + # pylint: disable=C0415 import ms_custom_ops self.ms_custom_ops = ms_custom_ops self.scale_value = 1 / math.sqrt(self.config.kv_lora_rank + self.config.qk_head_dim) \ @@ -793,7 +796,7 @@ class FusedMLASelfAttention(MLASelfAttention): k_cache = self.transpose(key_cache.reshape(-1, self.config.kv_lora_rank // 32, \ self.config.block_size, 32), (0, 2, 1, 3)).reshape( \ -1, self.config.block_size, self.config.kv_lora_rank) - k_cache = (self.cast(k_cache, dtype.bfloat16) / self.quant_ctkv_scale) + k_cache = self.cast(k_cache, dtype.bfloat16) / self.quant_ctkv_scale else: k_cache = self.ms_custom_ops.trans_data(key_cache, transdata_type=0) v_cache = self.ms_custom_ops.trans_data(value_cache, transdata_type=0) diff --git a/research/deepseek3/deepseek3_model_infer.py b/research/deepseek3/deepseek3_model_infer.py index a491d5c08..f05297b17 100644 --- a/research/deepseek3/deepseek3_model_infer.py +++ b/research/deepseek3/deepseek3_model_infer.py @@ -37,6 +37,13 @@ try: from mindspore._checkparam import Validator except ImportError: import mindspore._checkparam as Validator + +from research.deepseek3.deepseek3_config import DeepseekV3Config +from research.deepseek3.moe import ExpertParallelMoE, ParallelMoEV2, RoutedParallelMLP, SharedMLP, SharedParallelMLP +from research.deepseek3.utils import convert_model_config +from research.deepseek3.infer.norm import RMSNorm +from research.deepseek3.infer.transformer import ParallelMLP, VocabEmbedding +from research.deepseek3.infer.layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding from mindformers.models.modeling_utils import PreTrainedModel from mindformers.models.utils import lazy_inline, check_fine_grain_interleave_valid, predict_lazy_inline,\ jit @@ -59,13 +66,8 @@ from mindformers.parallel_core.inference.tensor_parallel.mappings import (gather reduce_scatter_to_model_parallel_region, scatter_to_model_parallel_region) from 
mindformers.version_control import is_910b +from mindformers.parallel_core.inference.parallel_state import get_data_parallel_group -from research.deepseek3.deepseek3_config import DeepseekV3Config -from research.deepseek3.moe import ExpertParallelMoE, ParallelMoEV2, RoutedParallelMLP, SharedMLP, SharedParallelMLP -from research.deepseek3.utils import convert_model_config -from research.deepseek3.infer.norm import RMSNorm -from research.deepseek3.infer.transformer import ParallelMLP, VocabEmbedding -from research.deepseek3.infer.layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding __all__ = ['InferenceDeepseekV3ForCausalLM', 'DeepseekV3Model'] @@ -249,7 +251,7 @@ class MLAInferAttention(nn.Cell): prefill_head_dim=None, config: DeepseekV3Config = None ): - super(MLAInferAttention, self).__init__() + super().__init__() self.n_head = n_head self.head_dim = head_dim self.n_kv_head = n_kv_head @@ -438,14 +440,13 @@ class DeepseekV3Attention(nn.Cell): raise ValueError("For 'DeepseekV3Attention', the use_flash_attention must be enabled.") if self.hidden_size % self.n_head != 0: - raise ValueError("For 'MultiHeadAttention', the class variable 'hidden_size' must be a multiple " - "of 'n_head', but got the hidden_size is {} and the n_head is {}." - .format(self.hidden_size, self.n_head)) + raise ValueError(f"For 'MultiHeadAttention', the class variable 'hidden_size' must be a multiple " + f"of 'n_head', but got the hidden_size is {self.hidden_size} and " + f"the n_head is {self.n_head}.") if self.n_kv_head % parallel_config.model_parallel != 0: - raise ValueError("For 'MultiHeadAttention', the class variable 'n_kv_head' must be a multiple of " - "'parallel_config.model_parallel', but got the n_kv_head is {} " - "and the parallel_config.model_parallel is {}." 
- .format(self.n_kv_head, parallel_config.model_parallel)) + raise ValueError(f"For 'MultiHeadAttention', the class variable 'n_kv_head' must be a multiple of " + f"'parallel_config.model_parallel', but got the n_kv_head is {self.n_kv_head} " + f"and the parallel_config.model_parallel is {parallel_config.model_parallel}.") self.shape = P.Shape() self.cast = P.Cast() if self.q_lora_rank == 0: @@ -572,12 +573,13 @@ class DeepseekV3Attention(nn.Cell): if self.q_lora_rank == 0: q = self.q_proj(x) latent_kv_all = self.kv2l(x) - latent_kv, k_pe = mint.split(latent_kv_all, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + latent_kv, k_pe = ops.function.array_func.split_ext( + latent_kv_all, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) else: if self.qkv_concat: qkv2l = self.qkv2l(x) - q, latent_kv, k_pe = mint.split(qkv2l, [self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim], - dim=-1) + q, latent_kv, k_pe = ops.function.array_func.split_ext( + qkv2l, [self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) norm_q = self.lq_norm(q) q = self.l2q_proj(norm_q) else: @@ -585,10 +587,11 @@ class DeepseekV3Attention(nn.Cell): norm_q = self.lq_norm(q) q = self.l2q_proj(norm_q) latent_kv_all = self.kv2l(x) - latent_kv, k_pe = mint.split(latent_kv_all, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + latent_kv, k_pe = ops.function.array_func.split_ext( + latent_kv_all, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) q = self.reshape(q, (-1, self.n_local_heads, self.q_head_dim)) - q_nope, q_pe = mint.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + q_nope, q_pe = ops.function.array_func.split_ext(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) # (T, kv_lora_rank) i_kv = self.lkv_norm(latent_kv) q_pe = self.reshape(q_pe, (-1, self.n_local_heads * self.qk_rope_head_dim)) @@ -663,7 +666,7 @@ class DeepseekV3ParallelMLP(ParallelMLP): # [B, S, H] -> [B, S, ffn_H] if self.ffn_concat: gate_hidden_out = self.w_gate_hidden(x) # dp,1 -> dp, mp # dp,1 -> dp, mp - gate, hidden = mint.split(gate_hidden_out, + gate, hidden = ops.function.array_func.split_ext(gate_hidden_out, (self.ffn_hidden_size_per_partition, self.ffn_hidden_size_per_partition), -1) else: gate = self.w1(x) # dp,1 -> dp, mp @@ -692,7 +695,7 @@ class DeepseekV3MoE(Cell): """ def __init__(self, config): - super(DeepseekV3MoE, self).__init__() + super().__init__() self.config = config self.parallel_config = config.parallel_config self.moe_config = config.moe_config @@ -766,7 +769,7 @@ class DeepseekV3MoEWithMicroBatch(DeepseekV3MoE): """ def __init__(self, config): - super(DeepseekV3MoEWithMicroBatch, self).__init__(config=config) + super().__init__(config=config) self.moe_tp_size = get_moe_tp_world_size() self.moe_ep_size = get_moe_ep_world_size() self.ep_rank_id = get_rank() // self.moe_tp_size @@ -846,7 +849,7 @@ class AttentionReduceScatter(Cell): """ def __init__(self, config): - super(AttentionReduceScatter, self).__init__() + super().__init__() self.config = config self.compute_dtype = config.compute_dtype self.hidden_size = config.hidden_size @@ -1439,7 +1442,7 @@ class InferenceDeepseekV3ForCausalLM(DeepseekV3PreTrainedModel): @lazy_inline def __init__(self, config: DeepseekV3Config = None): - super(InferenceDeepseekV3ForCausalLM, self).__init__(config, auto_prefix=True) + super().__init__(config, auto_prefix=True) _check_config(config.parallel_config) self.config = convert_model_config(config) @@ -1499,7 +1502,7 @@ class 
InferenceDeepseekV3ForCausalLM(DeepseekV3PreTrainedModel): self.load_checkpoint(config) self.predict_run_mode = get_predict_run_mode() - logger.info("Predict run mode:{}".format(self.predict_run_mode)) + logger.info(f"Predict run mode:{self.predict_run_mode}") self.return_hidden_states = config.return_hidden_states # pylint: disable=W0613 @@ -1602,7 +1605,6 @@ class InferenceDeepseekV3ForCausalLM(DeepseekV3PreTrainedModel): if dp_size == 1 or q_seq_len is None: return model_inputs - from mindformers.parallel_core.inference.parallel_state import get_data_parallel_group tokens_len_per_dp = q_seq_len.sum().reshape(-1) tokens_len_per_dp = ops.AllGather(group=get_data_parallel_group().group)(tokens_len_per_dp) tokens_len_per_dp = tokens_len_per_dp.asnumpy() @@ -1749,7 +1751,7 @@ class DeepseekV3MTPLayer(nn.Cell): """ def __init__(self, config: DeepseekV3Config = None): - super(DeepseekV3MTPLayer, self).__init__() + super().__init__() self.enorm = RMSNorm(config.hidden_size, config.rms_norm_eps, compute_type=config.layernorm_compute_type) self.hnorm = RMSNorm(config.hidden_size, config.rms_norm_eps, @@ -1826,7 +1828,7 @@ class DeepseekV3MTPModel(DeepseekV3PreTrainedModel): """ def __init__(self, config: DeepseekV3Config = None): - super(DeepseekV3MTPModel, self).__init__(config, auto_prefix=True) + super().__init__(config, auto_prefix=True) self.dtype = config.compute_dtype self.use_past = config.use_past self.is_first_iteration = True @@ -1925,7 +1927,7 @@ class InferenceDeepseekV3MTPForCausalLM(DeepseekV3PreTrainedModel): """ def __init__(self, config: DeepseekV3Config = None): - super(InferenceDeepseekV3MTPForCausalLM, self).__init__(config, auto_prefix=True) + super().__init__(config, auto_prefix=True) self.dtype = config.compute_dtype self.config = convert_model_config(config) self.parallel_config = self.config.parallel_config diff --git a/research/deepseek3/infer/transformer.py b/research/deepseek3/infer/transformer.py index 0626cc6e9..56be892af 100644 --- a/research/deepseek3/infer/transformer.py +++ b/research/deepseek3/infer/transformer.py @@ -17,11 +17,11 @@ import mindspore.common.dtype as mstype from mindspore import Parameter, mint, nn, ops from mindspore.common.initializer import initializer +from research.deepseek3.infer.activation import SiLU +from research.deepseek3.infer.layers import ColumnParallelLinear, RowParallelLinear from mindformers.parallel_core.inference.utils import get_tp_world_size from mindformers.parallel_core.inference.parallel_state import get_tensor_model_parallel_group from mindformers.tools.utils import divide -from research.deepseek3.infer.activation import SiLU -from research.deepseek3.infer.layers import ColumnParallelLinear, RowParallelLinear class VocabEmbedding(nn.Cell): @@ -174,7 +174,7 @@ class ParallelMLP(nn.Cell): gate_hidden_out_shape = gate_hidden_out.shape reshape_out = self.reshape(gate_hidden_out, (*gate_hidden_out_shape[:-1], self.ffn_hidden_size_per_partition, 2)) - gate, hidden = mint.split(reshape_out, + gate, hidden = ops.function.array_func.split_ext(reshape_out, (1, 1), -1) gate = self.reshape(gate, (*gate_hidden_out_shape[:-1], self.ffn_hidden_size_per_partition)) hidden = self.reshape(hidden, (*gate_hidden_out_shape[:-1], self.ffn_hidden_size_per_partition)) diff --git a/research/deepseek3/moe.py b/research/deepseek3/moe.py index 18aa33d5b..d59d63b9d 100644 --- a/research/deepseek3/moe.py +++ b/research/deepseek3/moe.py @@ -37,6 +37,8 @@ try: except ImportError: MOE_FUSED_OP_VALID = False +from research.deepseek3.infer.activation 
import SiLU +from research.deepseek3.infer.layers import ColumnParallelLinear, RowParallelLinear from mindformers.modules.layers import Linear from mindformers.parallel_core.inference.parallel_state import (default_pgs, get_moe_expert_parallel_group, get_moe_expert_parallel_world_size, @@ -45,8 +47,6 @@ from mindformers.parallel_core.inference.parallel_state import (default_pgs, get from mindformers.version_control import is_910b from mindformers.tools.utils import divide -from research.deepseek3.infer.activation import SiLU -from research.deepseek3.infer.layers import ColumnParallelLinear, RowParallelLinear dtype_map = { 'float16': mstype.float32, @@ -60,7 +60,7 @@ class TopkRouter(nn.Cell): A router implementation which maps each tokens to the topk expert. """ def __init__(self, expert_num): - super(TopkRouter, self).__init__() + super().__init__() self.topk_bias = Parameter(initializer('zeros', (expert_num), mstype.float32), requires_grad=False, parallel_optimizer=False) @@ -73,7 +73,7 @@ class Router(nn.Cell): def __init__(self, hidden_size, moe_config): - super(Router, self).__init__() + super().__init__() self.expert_num = moe_config.expert_num self.dense = nn.Dense(in_channels=hidden_size, out_channels=self.expert_num, has_bias=False, dtype=dtype_map.get(moe_config.router_dense_type)) @@ -103,7 +103,7 @@ class ParallelMoE(nn.Cell): hidden_size, moe_config, use_fused_op=True): - super(ParallelMoE, self).__init__() + super().__init__() self.hidden_size = hidden_size self.moe_config = moe_config self.expert_dim = moe_config.expert_num @@ -290,7 +290,7 @@ class SharedMLP(nn.Cell): """ Construct function of mlp block. """ if self.ffn_concat: gate_hidden_out = self.w_gate_hidden(x) # dp,1 -> dp, mp # dp,1 -> dp, mp - gate, hidden = mint.split(gate_hidden_out, + gate, hidden = ops.function.array_func.split_ext(gate_hidden_out, (self.ffn_hidden_size, self.ffn_hidden_size), -1) else: gate = self.w1(x) @@ -387,7 +387,7 @@ class SharedParallelMLP(nn.Cell): """ Construct function of mlp block. 
""" if self.ffn_concat: gate_hidden_out = self.w_gate_hidden(x) # dp,1 -> dp, mp # dp,1 -> dp, mp - gate, hidden = mint.split(gate_hidden_out, + gate, hidden = ops.function.array_func.split_ext(gate_hidden_out, (self.ffn_hidden_size_per_partition, self.ffn_hidden_size_per_partition), -1) else: gate = self.w1(x) @@ -455,7 +455,7 @@ class ColumnParallelGroupLinear(ColumnParallelLinear): tp_group=default_pgs, **kwargs ): - super(ColumnParallelGroupLinear, self).__init__( + super().__init__( input_size=input_size, output_size=output_size, config=config, @@ -541,7 +541,7 @@ class RowParallelGroupLinear(RowParallelLinear): tp_group=default_pgs, **kwargs ): - super(RowParallelGroupLinear, self).__init__( + super().__init__( input_size=input_size, output_size=output_size, config=config, @@ -661,7 +661,7 @@ class RoutedParallelMLP(nn.Cell): """Forward process of the FeedForward""" if self.ffn_concat: gate_hidden_out = self.w_gate_hidden(x, group_list=group_list) # dp,1 -> dp, mp # dp,1 -> dp, mp - gate, hidden = mint.split(gate_hidden_out, + gate, hidden = ops.function.array_func.split_ext(gate_hidden_out, (self.ffn_hidden_size_per_partition, self.ffn_hidden_size_per_partition), -1) else: gate = self.w1(x, group_list=group_list) @@ -711,7 +711,7 @@ class ParallelMoEV2(nn.Cell): hidden_size, moe_config, is_reduce_moe_output=True): - super(ParallelMoEV2, self).__init__() + super().__init__() self.hidden_size = hidden_size self.moe_config = moe_config self.is_reduce_moe_output = is_reduce_moe_output @@ -821,7 +821,7 @@ class ExpertParallelMoE(nn.Cell): moe_config, use_alltoall, compute_dtype): - super(ExpertParallelMoE, self).__init__() + super().__init__() self.compute_dtype = compute_dtype self.hidden_size = hidden_size self.moe_config = moe_config @@ -860,7 +860,7 @@ class ExpertParallelMoE(nn.Cell): self.group_list_index = Tensor([0,], mstype.int32) if self.moe_ep_size > 1 and not self.use_alltoall: - bias_idx = [idx for idx in range(self.expert_num)] + bias_idx = list(range(self.expert_num)) self.bias_idx = bias_idx[self.in_start_expert_idx:] + bias_idx[:self.in_start_expert_idx] self.router.e_score_correction_bias.init_data() self.router.e_score_correction_bias = self.router.e_score_correction_bias[self.bias_idx] diff --git a/research/llama3_1/infer/transformer.py b/research/llama3_1/infer/transformer.py index 2d8fe4200..1d1b6f857 100644 --- a/research/llama3_1/infer/transformer.py +++ b/research/llama3_1/infer/transformer.py @@ -193,7 +193,7 @@ class ParallelMLP(nn.Cell): gate_hidden_out_shape = gate_hidden_out.shape reshape_out = self.reshape(gate_hidden_out, (*gate_hidden_out_shape[:-1], self.ffn_hidden_size_per_partition, 2)) - gate, hidden = mint.split(reshape_out, + gate, hidden = ops.function.array_func.split_ext(reshape_out, (1, 1), -1) gate = self.reshape(gate, (*gate_hidden_out_shape[:-1], self.ffn_hidden_size_per_partition)) hidden = self.reshape(hidden, (*gate_hidden_out_shape[:-1], self.ffn_hidden_size_per_partition)) @@ -422,7 +422,7 @@ class ParallelAttention(nn.Cell): (-1, self.kv_num_heads_per_partition, (self.n_rep + 2) * self.head_dim)) - query, key, value = mint.split(reshape_qkv, + query, key, value = ops.function.array_func.split_ext(reshape_qkv, (self.head_dim * self.n_rep, self.head_dim, self.head_dim), -1) @@ -444,7 +444,8 @@ class ParallelAttention(nn.Cell): query = self.cast(self.wq(x), self.compute_dtype) if self.qkv_concat: kv = self.cast(self.w_kv(encoder_output), self.compute_dtype) - key, value = mint.split(kv, (self.kv_hidden_size_per_partition, 
self.kv_hidden_size_per_partition), -1) + key, value = ops.function.array_func.split_ext( + kv, (self.kv_hidden_size_per_partition, self.kv_hidden_size_per_partition), -1) else: key = self.cast(self.wk(encoder_output), self.compute_dtype) value = self.cast(self.wv(encoder_output), self.compute_dtype) diff --git a/research/qwen2_5/infer/transformer.py b/research/qwen2_5/infer/transformer.py index bb4874711..18cc8adaa 100644 --- a/research/qwen2_5/infer/transformer.py +++ b/research/qwen2_5/infer/transformer.py @@ -21,6 +21,10 @@ import mindspore.common.dtype as mstype from mindspore import Parameter, Tensor, mint, nn, ops from mindspore.common.initializer import initializer +from research.qwen2_5.infer.layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding +from research.qwen2_5.infer.norm import get_norm +from research.qwen2_5.infer.parallel_paged_attention_mgr import ParallelPagedAttentionMgr +from research.qwen2_5.infer.scale_mask_softmax import ScaleMaskSoftmax from mindformers.modules.flash_attention import FlashAttention from mindformers.modules.infer_attention import InferRotaryEmbedding from mindformers.modules.layers import FreqsMgr, RotaryEmbedding @@ -30,10 +34,7 @@ from mindformers.parallel_core.inference.utils import divide from mindformers.parallel_core.inference.utils import get_attn_mask_func from mindformers.parallel_core.process_group_config import default_model_comm_pgs from mindformers.version_control import need_nz -from research.qwen2_5.infer.layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding -from research.qwen2_5.infer.norm import get_norm -from research.qwen2_5.infer.parallel_paged_attention_mgr import ParallelPagedAttentionMgr -from research.qwen2_5.infer.scale_mask_softmax import ScaleMaskSoftmax + __all__ = [ "ParallelMLP", @@ -195,7 +196,7 @@ class ParallelMLP(nn.Cell): gate_hidden_out_shape = gate_hidden_out.shape reshape_out = self.reshape(gate_hidden_out, (*gate_hidden_out_shape[:-1], self.ffn_hidden_size_per_partition, 2)) - gate, hidden = mint.split(reshape_out, + gate, hidden = ops.function.array_func.split_ext(reshape_out, (1, 1), -1) gate = self.reshape(gate, (*gate_hidden_out_shape[:-1], self.ffn_hidden_size_per_partition)) hidden = self.reshape(hidden, (*gate_hidden_out_shape[:-1], self.ffn_hidden_size_per_partition)) @@ -237,7 +238,7 @@ class CoreAttention(nn.Cell): """ def __init__(self, layer_number, config, attn_mask_type=None): - super(CoreAttention, self).__init__() + super().__init__() if attn_mask_type: raise NotImplementedError("For CoreAttention, `attn_mask_type` is not supported for now.") self.config = config @@ -353,7 +354,7 @@ class ParallelAttention(nn.Cell): self.tp_group_size = self.tp.size self.num_heads_per_partition = divide(self.num_heads, self.tp_group_size) - self.use_gqa = (self.num_heads != self.kv_num_heads) + self.use_gqa = self.num_heads != self.kv_num_heads if self.use_gqa: self._check_gqa_valid() @@ -424,7 +425,7 @@ class ParallelAttention(nn.Cell): (-1, self.kv_num_heads_per_partition, (self.n_rep + 2) * self.head_dim)) - query, key, value = mint.split(reshape_qkv, + query, key, value = ops.function.array_func.split_ext(reshape_qkv, (self.head_dim * self.n_rep, self.head_dim, self.head_dim), -1) @@ -446,7 +447,8 @@ class ParallelAttention(nn.Cell): query = self.cast(self.wq(x), self.compute_dtype) if self.qkv_concat: kv = self.cast(self.w_kv(encoder_output), self.compute_dtype) - key, value = mint.split(kv, (self.kv_hidden_size_per_partition, 
self.kv_hidden_size_per_partition), -1) + key, value = ops.function.array_func.split_ext( + kv, (self.kv_hidden_size_per_partition, self.kv_hidden_size_per_partition), -1) else: key = self.cast(self.wk(encoder_output), self.compute_dtype) value = self.cast(self.wv(encoder_output), self.compute_dtype) @@ -688,8 +690,8 @@ class ParallelTransformerLayer(nn.Cell): raise NotImplementedError("For ParallelTransformerLayer, `self_attn_mask_type` is not supported for now.") if drop_path_rate > 0.0: raise NotImplementedError( - "For ParallelTransformerLayer, `drop_path_rate > 0` is not supported for now, " - "but got `drop_path_rate={}`".format(drop_path_rate) + f"For ParallelTransformerLayer, `drop_path_rate > 0` is not supported for now, " + f"but got `drop_path_rate={drop_path_rate}`" ) self.config = config self.apply_residual_connection_post_norm = self.config.apply_residual_connection_post_norm -- Gitee
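
Note for reviewers: below is a minimal, self-contained sketch of the call-site substitution this patch applies, so the intended equivalence can be checked in isolation. It assumes a MindSpore build in which both mint.split and ops.function.array_func.split_ext are available; the argument order (input, split sections, axis) follows the call sites changed above, and the tensor shape and the numpy helper are illustrative assumptions, not part of the patch.

import numpy as np
import mindspore as ms
from mindspore import mint, ops

# Hypothetical input: a (2, 8) tensor split into two (2, 4) halves on the last axis.
x = ms.Tensor(np.arange(16, dtype=np.float32).reshape(2, 8))

# Old call-site pattern removed by this patch: mint.split with explicit section sizes.
gate_old, hidden_old = mint.split(x, (4, 4), -1)

# New call-site pattern introduced by this patch: same sections, same axis.
gate_new, hidden_new = ops.function.array_func.split_ext(x, (4, 4), -1)

# Both interfaces are expected to produce identically shaped outputs.
assert gate_old.shape == gate_new.shape == (2, 4)
assert hidden_old.shape == hidden_new.shape == (2, 4)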