From 838ec32ec901a5c26aba37aa50d175239e14c69f Mon Sep 17 00:00:00 2001
From: niujunhao
Date: Wed, 16 Apr 2025 17:57:00 +0800
Subject: [PATCH] fix cp with use_seq_parallel in llama_interleave

---
 mindformers/models/llama/llama_interleave.py | 49 +++++++++++---------
 1 file changed, 27 insertions(+), 22 deletions(-)

diff --git a/mindformers/models/llama/llama_interleave.py b/mindformers/models/llama/llama_interleave.py
index ac1712c47..676cfb620 100644
--- a/mindformers/models/llama/llama_interleave.py
+++ b/mindformers/models/llama/llama_interleave.py
@@ -309,7 +309,6 @@ class LLamaAttentionInterleave(nn.Cell):
                 self.merger_head_transpose.shard(in_strategy=layout_merger_head_transpose)
             else:
                 self.merger_head_transpose.shard(((dp, mp, 1, 1),))
-            self.merger_head_transpose.shard(((dp, mp, 1, 1),))
             self.batch_matmul_q_k.shard(((dp, mp, 1, 1), (dp, mp, 1, 1)))
             self.batch_matmul.shard(((dp, mp, 1, 1), (dp, mp, 1, 1)))
             self.slice_qkv.shard(((dp, mp),))
@@ -332,11 +331,10 @@ class LLamaAttentionInterleave(nn.Cell):
                 self.merger_head_transpose.shard(in_strategy=layout_merger_head_transpose)
             else:
                 self.merger_head_transpose.shard(((dp, mp, 1, 1),))
-            self.merger_head_transpose.shard(((dp, mp, 1, 1),))
-            self.batch_matmul_q_k.shard(((dp, mp, 1, 1), (dp, mp, 1, 1)))
-            self.batch_matmul.shard(((dp, mp, 1, 1), (dp, mp, 1, 1)))
-            self.mul.shard(((dp, mp, 1, 1), ()))
-            self.add.shard(((dp, 1, 1, 1), (dp, mp, 1, 1)))
+            self.batch_matmul_q_k.shard(((dp, mp, cp, 1), (dp, mp, cp, 1)))
+            self.batch_matmul.shard(((dp, mp, cp, 1), (dp, mp, cp, 1)))
+            self.mul.shard(((dp, mp, cp, 1), ()))
+            self.add.shard(((dp, 1, 1, 1), (dp, mp, cp, 1)))
             self.softmax.softmax.shard(((dp, mp, 1, 1),))
             self.tile_kv.shard(((dp, mp, 1, 1),))
             self.slice_qkv.shard(((dp, mp),))
@@ -354,7 +352,10 @@ class LLamaAttentionInterleave(nn.Cell):
             self.wv.shard(((dp * cp, 1), (mp, 1)))
             self.wo.shard(((dp * cp, mp), (1, mp)), ((dp * cp, 1), (1,)), out_strategy_matmul=((dp * cp, 1),))
             if parallel_config.use_seq_parallel and self.is_first_iteration and cp == 1:
-                self.wo.shard(((dp, mp), (1, mp)), ((dp * cp, 1), (1,)), out_strategy_matmul=((dp * mp, 1),))
+                self.wo.shard(
+                    strategy_matmul=((dp * cp, mp), (1, mp)),
+                    strategy_bias=((dp * cp * mp, 1), (1,)),
+                    out_strategy_matmul=((dp * cp * mp, 1),))
             if parallel_config.recompute.select_recompute and not self.use_flash_attention:
                 self.apply_rotary_emb.recompute()
                 self.tile_kv.recompute()
@@ -694,6 +695,7 @@ class LLamaDecodeLayerInterleave(nn.Cell):
             self.feed_forward.shard(parallel_config)
             self.feed_forward.mul.shard(((dp * cp, mp), (dp * cp, mp)))
             self.add.shard(((dp * cp, 1), (dp * cp, 1)))
+            # equal to rmsnorm_compute_2d in llama_transformer.py
             if cp > 1:
                 self.attention_norm.shard((dp * cp * mp, 1))
                 self.ffn_norm.shard((dp * cp * mp, 1))
@@ -702,10 +704,13 @@
                 self.ffn_norm.shard((dp, 1))

             if parallel_config.use_seq_parallel and self.is_first_iteration:
-                self.add.shard(((dp * mp, 1), (dp * mp, 1)))
-                self.attention_norm.shard((dp * mp, 1))
-                self.ffn_norm.shard((dp * mp, 1))
-                self.feed_forward.w2.shard(((dp, mp), (1, mp)), out_strategy_matmul=((dp * mp, 1),))
+                self.add.shard(((dp * mp * cp, 1), (dp * mp * cp, 1)))
+                self.attention_norm.shard((dp * mp * cp, 1))
+                self.ffn_norm.shard((dp * mp * cp, 1))
+                self.feed_forward.w2.shard(
+                    strategy_matmul=((dp * cp, mp), (1, mp)),
+                    strategy_bias=((dp * cp * mp, 1), (1,)),
+                    out_strategy_matmul=((dp * cp * mp, 1),))

         concat_stra1 = []
         concat_stra2 = []
@@ -720,7 +725,7 @@ class LLamaDecodeLayerInterleave(nn.Cell):
         self.interleaved_concat2.add_prim_attr("fine_grained_interleaved_index", 1000)

         for _ in range(self.interleave_num):
-            concat_stra1.append((dp, mp))
+            concat_stra1.append((dp * cp, mp))
             interleave_data1 = _MicroBatch(self.interleave_num, 1, [0])
             interleave_data1.strided_slice_list[0].add_prim_attr("skip_redistribution", True)
             interleave_data1_ = _MicroBatch(self.interleave_num, 1, [0])
@@ -728,28 +733,28 @@
             interleave_data2 = _MicroBatch(self.interleave_num, 2, [0, 0])
             if parallel_config.use_seq_parallel:
                 if self.layer_id == self.num_layers - 2:
-                    concat_stra2.append((dp, 1))
+                    concat_stra2.append((dp * cp, 1))
                 else:
-                    concat_stra2.append((dp * mp, 1))
+                    concat_stra2.append((dp * mp * cp, 1))
                 if self.layer_id == self.num_layers - 1:
-                    interleave_data1.strided_slice_list[0].shard(((dp, 1),))
+                    interleave_data1.strided_slice_list[0].shard(((dp * cp, 1),))
                 else:
-                    interleave_data1.strided_slice_list[0].shard(((dp * mp, 1),))
+                    interleave_data1.strided_slice_list[0].shard(((dp * mp * cp, 1),))
                 interleave_data1_.strided_slice_list[0].shard(((1, 1),))
-                interleave_data2.strided_slice_list[0].shard(((dp * mp, 1),))
+                interleave_data2.strided_slice_list[0].shard(((dp * mp * cp, 1),))
             else:
-                concat_stra2.append((dp, 1))
-                interleave_data1.strided_slice_list[0].shard(((dp, 1),))
+                concat_stra2.append((dp * cp, 1))
+                interleave_data1.strided_slice_list[0].shard(((dp * cp, 1),))
                 interleave_data1_.strided_slice_list[0].shard(((1, 1),))
-                interleave_data2.strided_slice_list[0].shard(((dp, 1),))
+                interleave_data2.strided_slice_list[0].shard(((dp * cp, 1),))
             if self.layer_id == 0 and parallel_config.use_seq_parallel:
-                interleave_data2.strided_slice_list[0].shard(((dp, 1),))
+                interleave_data2.strided_slice_list[0].shard(((dp * cp, 1),))
                 interleave_data2.strided_slice_list[0].add_prim_attr("skip_redistribution", True)
             else:
                 interleave_data2.strided_slice_list[0].add_prim_attr("skip_redistribution", True)
             interleave_data2.strided_slice_list[0].add_prim_attr("fine_grained_interleaved_index", self.layer_id)
-            interleave_data2.strided_slice_list[1].shard(((dp, mp),))
+            interleave_data2.strided_slice_list[1].shard(((dp * cp, mp),))
             interleave_data2.strided_slice_list[1].add_prim_attr("fine_grained_interleaved_index", self.layer_id)
             interleave_data2.strided_slice_list[1].add_prim_attr("skip_redistribution", True)
             self.interleave1_inputs.append(interleave_data1)
--
Gitee
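The hunks above all apply one pattern: every shard strategy that previously folded only dp and mp into a single leading split now also folds in the context-parallel factor cp, and the 4-D attention tensors gain a cp split on the query-sequence axis. The sketch below is illustrative only and not part of the patch; the dp/mp/cp values and tensor shapes are assumptions, chosen to show the per-device slice shapes such strategy tuples imply.

# Illustrative sketch only, not part of the patch: it mirrors the strategy
# arithmetic used in the hunks above. dp/mp/cp values and tensor shapes are
# assumptions for demonstration purposes.

def slice_shape(shape, strategy):
    """Per-device slice shape for a shard strategy tuple.

    Each strategy element is the number of splits on the matching dimension;
    every dimension must divide evenly by its split count.
    """
    assert len(shape) == len(strategy)
    for dim, cuts in zip(shape, strategy):
        assert dim % cuts == 0, f"dim {dim} is not divisible by {cuts} splits"
    return tuple(dim // cuts for dim, cuts in zip(shape, strategy))

dp, mp, cp = 2, 4, 2                         # data / model / context parallel (assumed)
bs, n_head, seq, hidden = 4, 32, 4096, 4096  # assumed shapes

# 4-D attention scores: cp now splits the query-sequence axis, matching
# batch_matmul_q_k.shard(((dp, mp, cp, 1), (dp, mp, cp, 1))).
print(slice_shape((bs, n_head, seq, seq), (dp, mp, cp, 1)))    # (2, 8, 2048, 4096)

# 2-D residual/norm tensors are flattened to (bs * seq, hidden); with
# use_seq_parallel the leading axis absorbs all three factors, matching
# attention_norm.shard((dp * mp * cp, 1)).
print(slice_shape((bs * seq, hidden), (dp * mp * cp, 1)))      # (1024, 4096)

Each sharded dimension has to divide evenly by its split count, which is the check the helper performs; in the model code these same tuples are the arguments passed to the primitives' shard() calls shown in the diff, so every dp/mp/cp factor has to appear consistently across the matmul, norm, and residual-add strategies.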