diff --git a/omni/layers/attention/deepseek_mla.py b/omni/layers/attention/deepseek_mla.py
index acff0b16ebfca8f98b28368fed069a1cdf404870..525367e08ebaa23ff7710c2c6097cd0a424726b9 100644
--- a/omni/layers/attention/deepseek_mla.py
+++ b/omni/layers/attention/deepseek_mla.py
@@ -554,6 +554,7 @@ class DeepseekMLA(nn.Module):
        hidden_states: torch.Tensor,
        kv_cache: Optional[torch.Tensor],
        attn_metadata: AttentionMetadata,
+        stream1: Optional[torch.npu.Stream],
        comm_group: Optional[GroupCoordinator] = None,
    ) -> torch.Tensor:
        if not self.is_init:
@@ -570,7 +571,7 @@ class DeepseekMLA(nn.Module):
            if model_extra_config.operator_opt_config.enable_dsa:
                output = self._forward_prefill_absorb(positions, hidden_states, kv_cache, attn_metadata, comm_group=comm_group)
            else:
-                output = self._forward_prefill(positions, hidden_states, kv_cache, attn_metadata, comm_group=comm_group)
+                output = self._forward_prefill(positions, hidden_states, kv_cache, attn_metadata, stream1, comm_group=comm_group)
        else:
            output = self._forward_decode(positions, hidden_states, kv_cache, attn_metadata)
        if model_extra_config.operator_opt_config.use_mlaprolog and not self.is_mla_prolog_init:
@@ -706,6 +707,7 @@ class DeepseekMLA(nn.Module):
        hidden_states: torch.Tensor,
        kv_cache: Optional[torch.Tensor],
        attn_metadata: AttentionMetadata,
+        stream1: Optional[torch.npu.Stream],
        comm_group: Optional[GroupCoordinator] = None,
    ) -> torch.Tensor:
        main_stream = current_stream()
@@ -890,36 +892,57 @@ class DeepseekMLA(nn.Module):
            prefix_meta=attn_metadata.prefill.prefix_meta,
            layer_idx=self.layer_idx + 1,
        )
-
+        # Split the attention output into two micro-batches so that one
+        # half's all-to-all can overlap with the other half's o_proj matmul.
+        attn_output_mb0, attn_output_mb1 = torch.split(
+            attn_output,
+            [attn_output.size(0) // 2, attn_output.size(0) - attn_output.size(0) // 2],
+            dim=0)
+        mb0_size = attn_output_mb0.size(0)
+        mb1_size = attn_output_mb1.size(0)
+        event = torch_npu.npu.ExternalEvent()
        # prefill_enable_mla_alltoall alone: prefill o_proj goes tp -> dp
        # with o_proj_tp_size also set: prefill o_proj goes tp -> dp + tp
        if model_extra_config.operator_opt_config.prefill_enable_mla_alltoall:
            if attn_metadata is not None:
                if model_extra_config.parall_config.o_proj_tp_size > 1:
-                    attn_output = attn_output.view(get_o_proj_dp_group().world_size, -1, self.num_local_heads, self.v_head_dim)
-                    attn_output = attn_output.reshape(-1)
-                    all_to_all_attn_output = torch.empty(
-                        [q.shape[0] * self.num_local_heads * self.v_head_dim],
-                        dtype=attn_output.dtype,
-                        device=current_platform.device_type
-                    )
+                    attn_output_mb0 = attn_output_mb0.view(get_o_proj_dp_group().world_size, -1, self.num_local_heads, self.v_head_dim)
+                    attn_output_mb1 = attn_output_mb1.view(get_o_proj_dp_group().world_size, -1, self.num_local_heads, self.v_head_dim)
+                    attn_output_mb0 = attn_output_mb0.reshape(-1)
+                    attn_output_mb1 = attn_output_mb1.reshape(-1)
                device_group = get_o_proj_dp_group().device_group \
                    if model_extra_config.parall_config.o_proj_tp_size > 1 else get_tp_group().device_group
-                dist.all_to_all_single(all_to_all_attn_output, attn_output, group=device_group)
-                if model_extra_config.parall_config.o_proj_tp_size > 1:
-                    attn_output = all_to_all_attn_output.view(
-                        get_tensor_model_parallel_world_size() // get_o_proj_tp_group().world_size,
-                        q.shape[0] // get_tensor_model_parallel_world_size() * get_o_proj_tp_group().world_size,
+                # Micro-batch 0: all-to-all and o_proj on the main stream.
+                with torch.npu.stream(main_stream):
+                    all_to_all_attn_output_mb0 = torch.empty(
+                        [mb0_size * self.num_local_heads * self.v_head_dim],
+                        dtype=attn_output_mb0.dtype,
+                        device=current_platform.device_type
+                    )
+                    dist.all_to_all_single(all_to_all_attn_output_mb0, attn_output_mb0, group=device_group)
+                    event.record(main_stream)
+                    attn_output_mb0 = all_to_all_attn_output_mb0.view(
+                        get_tensor_model_parallel_world_size(),
+                        mb0_size // get_tensor_model_parallel_world_size(),
                        self.num_local_heads * self.v_head_dim
                    ).transpose(0, 1).contiguous()
-                else:
-                    attn_output = all_to_all_attn_output.view(
+                    output_mb0, _ = self.o_proj.forward(
+                        attn_output_mb0.reshape(-1, o_proj_tp_size * self.num_local_heads * self.v_head_dim))
+                # Micro-batch 1: issue its all-to-all on the side stream once
+                # micro-batch 0's collective has been recorded, so it overlaps
+                # with micro-batch 0's o_proj matmul on the main stream.
+                with torch.npu.stream(stream1):
+                    all_to_all_attn_output_mb1 = torch.empty(
+                        [mb1_size * self.num_local_heads * self.v_head_dim],
+                        dtype=attn_output_mb1.dtype,
+                        device=current_platform.device_type
+                    )
+                    event.wait(stream1)
+                    dist.all_to_all_single(all_to_all_attn_output_mb1, attn_output_mb1, group=device_group)
+                    event.reset()
+                    attn_output_mb1 = all_to_all_attn_output_mb1.view(
                        get_tensor_model_parallel_world_size(),
-                        q.shape[0] // get_tensor_model_parallel_world_size(),
+                        mb1_size // get_tensor_model_parallel_world_size(),
                        self.num_local_heads * self.v_head_dim
                    ).transpose(0, 1).contiguous()
-                output, _ = self.o_proj.forward(
-                    attn_output.reshape(-1, o_proj_tp_size * self.num_local_heads * self.v_head_dim))
+                    output_mb1, _ = self.o_proj.forward(
+                        attn_output_mb1.reshape(-1, o_proj_tp_size * self.num_local_heads * self.v_head_dim))
+                if stream1 is not None:
+                    main_stream.wait_stream(stream1)
+                output = torch.cat((output_mb0, output_mb1), 0)
            else:
                attn_output = attn_output.view(-1, self.num_local_heads * self.v_head_dim)
                if model_extra_config.parall_config.o_proj_tp_size > 1:
diff --git a/omni/models/pangu/pangu_ultra_moe.py b/omni/models/pangu/pangu_ultra_moe.py
index f958d040ff5d9f4a142bb8c044bfed323c6ddfb7..8e61d53c8f5b90e0e96b0f4699fc6823f6c1847a 100644
--- a/omni/models/pangu/pangu_ultra_moe.py
+++ b/omni/models/pangu/pangu_ultra_moe.py
@@ -207,7 +207,8 @@ class PanguUltraMoEDecoderLayer(nn.Module):
        residual: Optional[torch.Tensor],
        layer_id: Optional[int] = None,
        next_attention_weights: Optional[dict] = None,
-        next_input_layernorm: Optional[nn.Module] = None
+        next_input_layernorm: Optional[nn.Module] = None,
+        stream1: Optional[torch.npu.Stream] = None
    ) -> torch.Tensor:
        if isinstance(attn_metadata, dict):
            attn_metadata = attn_metadata[self.layer_name]
@@ -230,6 +231,7 @@ class PanguUltraMoEDecoderLayer(nn.Module):
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
+            stream1=stream1,
        )

        if enable_superkernel:
@@ -433,6 +435,7 @@ class PanguUltraMoEModel(nn.Module):
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
    ) -> Union[torch.Tensor, IntermediateTensors]:
+        stream1 = torch.npu.Stream()
        if get_pp_group().is_first_rank:
            hidden_states = self.get_input_embeddings(input_ids)
            residual = None
@@ -470,7 +473,8 @@ class PanguUltraMoEModel(nn.Module):
                residual,
                layer_id,
                next_attention_weights,
-                next_input_layernorm)
+                next_input_layernorm,
+                stream1)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
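The heart of this change is a two-micro-batch pipeline inside `_forward_prefill`: micro-batch 0's `all_to_all_single` and `o_proj` run on the main stream, while micro-batch 1's `all_to_all_single` is issued on the side stream `stream1` once the `ExternalEvent` recorded after micro-batch 0's collective is signaled, so communication for one half overlaps with computation for the other. Below is a minimal sketch of the same pattern, not the patch itself: it assumes an Ascend environment with `torch_npu` installed, and `comm`/`compute` are hypothetical stand-ins for `dist.all_to_all_single` and `self.o_proj.forward`.

```python
import torch
import torch_npu


def overlapped_two_microbatch(x, comm, compute, stream1):
    """Sketch of the overlap pattern in _forward_prefill (assumes torch_npu
    on an Ascend NPU; comm/compute are placeholder callables)."""
    main_stream = torch.npu.current_stream()
    # Split into two micro-batches; the second half takes any remainder.
    mb0, mb1 = torch.split(x, [x.size(0) // 2, x.size(0) - x.size(0) // 2], dim=0)
    event = torch_npu.npu.ExternalEvent()

    with torch.npu.stream(main_stream):
        mb0 = comm(mb0)            # collective for micro-batch 0
        event.record(main_stream)  # mark mb0's collective as issued
        out0 = compute(mb0)        # matmul for micro-batch 0

    with torch.npu.stream(stream1):
        event.wait(stream1)        # serialize the two collectives
        mb1 = comm(mb1)            # overlaps with compute(mb0) above
        event.reset()              # re-arm the event, mirroring the patch
        out1 = compute(mb1)

    main_stream.wait_stream(stream1)  # rejoin before concatenating
    return torch.cat((out0, out1), dim=0)
```

On the model side, `PanguUltraMoEModel.forward` allocates one `torch.npu.Stream()` per forward call and threads it through every decoder layer down to `DeepseekMLA`; caching a single long-lived side stream on the module would be a possible refinement, but the diff as written creates a fresh one each pass.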
true, "prefill_moe_all_to_all": true, + "prefill_mla_all_to_all": true, "best_ep": false, "merge_qkv": false, "gmm_nz": true,
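The test config gains `prefill_mla_all_to_all` next to the existing `prefill_moe_all_to_all` switch; note the attention code gates on `model_extra_config.operator_opt_config.prefill_enable_mla_alltoall`, and the mapping from the JSON key onto that attribute is done by the config loader, which this diff does not show. A quick sanity check that the flag is present and enabled, using the test file from this diff:

```python
import json

# Load the prefill test config and confirm the new MLA all-to-all switch is on.
with open("tests/test_config/test_config_prefill_pangu_ultra_moe.json") as f:
    cfg = json.load(f)

assert cfg["operator_optimization_config"]["prefill_mla_all_to_all"] is True
```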