diff --git a/.gitignore b/.gitignore
index d4e2d42f6c46e3389d10e27ce984006001a07cb8..ec806f11f430202d9c00e71b4b1d336a0ea03d8f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -146,6 +146,9 @@ cython_debug/
 # megatron core
 /megatron/
 
+# mindspeed
+/mindspeed/
+
 # User stuff
 /kernel*/
 /logs/
diff --git a/docs/USER_GUIDE.md b/docs/USER_GUIDE.md
index 60c1e23b22c96372334f0ad83e696235a0b44cc0..b63660263f604c34628ffd69bd1b1549075d3594 100644
--- a/docs/USER_GUIDE.md
+++ b/docs/USER_GUIDE.md
@@ -98,7 +98,7 @@ source /usr/local/Ascend/nnal/atb/set_env.sh
 git clone https://gitee.com/ascend/MindSpeed.git
 cd MindSpeed
 # checkout commit from MindSpeed core_r0.7.0 in 2024.11.04
-git checkout f3332571
+git checkout c9d20b5
 pip install -r requirements.txt
 pip3 install -e .
 cd ..
diff --git a/examples/mcore/deepseek2_lite/convert_ckpt_deepseek2_lite_hf2mcore.sh b/examples/mcore/deepseek2_lite/convert_ckpt_deepseek2_lite_hf2mcore.sh
index 631ae66c6d42998740c1cad3a85c095d7c956463..4f9515818e28dbae770494b05114a455866fb03a 100644
--- a/examples/mcore/deepseek2_lite/convert_ckpt_deepseek2_lite_hf2mcore.sh
+++ b/examples/mcore/deepseek2_lite/convert_ckpt_deepseek2_lite_hf2mcore.sh
@@ -1,9 +1,10 @@
 # Modify the set_env.sh path according to your actual environment
 # Modify the directory settings as needed and configure the corresponding TP, PP, and EP parameters
-source /usr/local/Ascend/ascend-toolkit/set_up.sh
+source /usr/local/Ascend/ascend-toolkit/set_env.sh
 
 python convert_ckpt.py \
+    --moe-grouped-gemm \
     --use-mcore-models \
     --model-type-hf deepseek2-lite \
     --model-type GPT \
@@ -13,6 +14,7 @@ python convert_ckpt.py \
     --target-tensor-parallel-size 1 \
     --target-pipeline-parallel-size 1 \
     --target-expert-parallel-size 8 \
+    --spec modellink.tasks.models.spec.deepseek_spec layer_spec \
     --load-dir ./model_from_hf/deepseek_v2_lite/ \
     --save-dir ./model_weights/deepseek2_lite_mcore/ \
     --tokenizer-model ./model_from_hf/deepseek_v2_lite/
diff --git a/examples/mcore/deepseek2_lite/convert_ckpt_deepseek2_lite_mcore2hf.sh b/examples/mcore/deepseek2_lite/convert_ckpt_deepseek2_lite_mcore2hf.sh
index 3af937b75d0b0fe35002b1862e22551191ea29fc..f09f648120a3a16b5580f9b8962d1a058501305c 100644
--- a/examples/mcore/deepseek2_lite/convert_ckpt_deepseek2_lite_mcore2hf.sh
+++ b/examples/mcore/deepseek2_lite/convert_ckpt_deepseek2_lite_mcore2hf.sh
@@ -2,6 +2,7 @@
 source /usr/local/Ascend/ascend-toolkit/set_env.sh
 
 python convert_ckpt.py \
+    --moe-grouped-gemm \
     --use-mcore-models \
     --model-type-hf deepseek2-lite \
     --model-type GPT \
@@ -11,5 +12,6 @@ python convert_ckpt.py \
     --target-tensor-parallel-size 1 \
     --target-pipeline-parallel-size 1 \
     --target-expert-parallel-size 1 \
+    --spec modellink.tasks.models.spec.deepseek_spec layer_spec \
     --load-dir ./model_weights/deepseek2_lite_mcore/ \
     --save-dir ./model/deepseek2_lite/
diff --git a/examples/mcore/deepseek2_lite/evaluate_deepseek2_lite_16b_ptd.sh b/examples/mcore/deepseek2_lite/evaluate_deepseek2_lite_16b_ptd.sh
index b0d99f65dc37c7cdc2d98ea73eaeb3da66ccf243..1f240b62c7e08a3fe9f1b26f085c219657be70df 100644
--- a/examples/mcore/deepseek2_lite/evaluate_deepseek2_lite_16b_ptd.sh
+++ b/examples/mcore/deepseek2_lite/evaluate_deepseek2_lite_16b_ptd.sh
@@ -33,6 +33,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS evaluation.py \
     --no-chat-template \
     --max-new-tokens 1 \
     --use-mcore-models \
+    --moe-grouped-gemm \
     --tensor-model-parallel-size ${TP} \
     --pipeline-model-parallel-size ${PP} \
     --num-layers 27 \
@@ -89,4 +90,4 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS evaluation.py \
     --rope-scaling-original-max-position-embeddings 4096 \
     --rope-scaling-type yarn \
--distributed-backend nccl \ - | tee ./logs/evaluation_mcore_deepseek2_lite_16b_${TASK}.log + | tee ./logs/evaluation_deepseek2_lite_ptd_8p_${TASK}.log diff --git a/examples/mcore/deepseek2_lite/generate_deepseek2_lite_16b_ptd.sh b/examples/mcore/deepseek2_lite/generate_deepseek2_lite_16b_ptd.sh index ec6754b67e41b27382c4c02134d3c21eda5ce613..e79f9f78fde71b8eb73b88ecc9a0b49eac26b0cf 100644 --- a/examples/mcore/deepseek2_lite/generate_deepseek2_lite_16b_ptd.sh +++ b/examples/mcore/deepseek2_lite/generate_deepseek2_lite_16b_ptd.sh @@ -36,6 +36,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS inference.py \ --task chat \ --max-new-tokens 256 \ --use-mcore-models \ + --moe-grouped-gemm \ --tensor-model-parallel-size ${TP} \ --pipeline-model-parallel-size ${PP} \ --num-layers 27 \ @@ -92,5 +93,5 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS inference.py \ --rope-scaling-original-max-position-embeddings 4096 \ --rope-scaling-type yarn \ --distributed-backend nccl \ - | tee logs/generate_mcore_deepseek2_lite.log + | tee logs/generate_deepseek2_lite.log diff --git a/examples/mcore/deepseek2_lite/pretrain_deepseek2_lite_16b_ptd_16p.sh b/examples/mcore/deepseek2_lite/pretrain_deepseek2_lite_16b_ptd_16p.sh new file mode 100644 index 0000000000000000000000000000000000000000..6a6ce8a36ca895abf381aba9a55afac27ece77d8 --- /dev/null +++ b/examples/mcore/deepseek2_lite/pretrain_deepseek2_lite_16b_ptd_16p.sh @@ -0,0 +1,150 @@ +#!/bin/bash +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True + +GPUS_PER_NODE=16 +MASTER_ADDR=localhost +MASTER_PORT=6000 +NNODES=1 +NODE_RANK=0 +WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) + +CKPT_SAVE_DIR="your model save ckpt path" +DATA_PATH="your data path" +TOKENIZER_MODEL="your tokenizer path" +CKPT_LOAD_DIR="your model ckpt path" + +TP=1 +PP=1 +EP=8 + +DISTRIBUTED_ARGS=" + --nproc_per_node $GPUS_PER_NODE \ + --nnodes $NNODES \ + --node_rank $NODE_RANK \ + --master_addr $MASTER_ADDR \ + --master_port $MASTER_PORT +" + +MLA_ARGS=" + --spec modellink.tasks.models.spec.deepseek_spec layer_spec \ + --multi-head-latent-attention \ + --qk-rope-head-dim 64 \ + --qk-nope-head-dim 128 \ + --kv-lora-rank 512 \ + --v-head-dim 128 \ + --qk-layernorm \ +" + +MOE_ARGS=" + --moe-grouped-gemm \ + --moe-alltoall-overlap-comm \ + --moe-permutation-async-comm \ + --moe-token-dispatcher-type alltoall \ + --use-fused-moe-token-permute-and-unpermute \ + --first-k-dense-replace 1 \ + --moe-layer-freq 1 \ + --n-shared-experts 2 \ + --num-experts 64 \ + --moe-router-topk 6 \ + --moe-intermediate-size 1408 \ + --moe-router-load-balancing-type pai_megatron_aux_loss \ + --topk-group 1 \ + --moe-aux-loss-coeff 0.01 \ + --routed-scaling-factor 1.0 \ + --seq-aux +" + +ROPE_ARGS=" + --rope-scaling-beta-fast 32 \ + --rope-scaling-beta-slow 1 \ + --rope-scaling-factor 40 \ + --rope-scaling-mscale 0.707 \ + --rope-scaling-mscale-all-dim 0.707 \ + --rope-scaling-original-max-position-embeddings 4096 \ + --rope-scaling-type yarn +" + +GPT_ARGS=" + --shape-order BNSD \ + --reuse-fp32-param \ + --load $CKPT_LOAD_DIR \ + --use-distributed-optimizer \ + --use-flash-attn \ + --use-mcore-models \ + --tensor-model-parallel-size ${TP} \ + --pipeline-model-parallel-size ${PP} \ + --expert-model-parallel-size ${EP} \ + --sequence-parallel \ + --num-layers 27 \ + --hidden-size 2048 \ + --ffn-hidden-size 10944 \ + --num-attention-heads 16 \ + --tokenizer-type PretrainedFromHF \ + --tokenizer-name-or-path ${TOKENIZER_MODEL} \ + --finetune \ + --num-workers 8 \ + 
--seq-length 4096 \ + --max-position-embeddings 163840 \ + --micro-batch-size 1 \ + --global-batch-size 768 \ + --make-vocab-size-divisible-by 1 \ + --lr 2e-5 \ + --train-iters 2000 \ + --lr-decay-style cosine \ + --lr-decay-iters 2000 \ + --untie-embeddings-and-output-weights \ + --disable-bias-linear \ + --attention-dropout 0.0 \ + --init-method-std 0.02 \ + --hidden-dropout 0.0 \ + --position-embedding-type rope \ + --normalization RMSNorm \ + --use-fused-rotary-pos-emb \ + --use-rotary-position-embeddings \ + --use-fused-swiglu \ + --use-fused-rmsnorm \ + --swiglu \ + --no-masked-softmax-fusion \ + --attention-softmax-in-fp32 \ + --min-lr 1.0e-8 \ + --weight-decay 1e-1 \ + --lr-warmup-iters 100 \ + --clip-grad 1.0 \ + --adam-beta1 0.9 \ + --adam-beta2 0.95 \ + --initial-loss-scale 65536 \ + --vocab-size 102400 \ + --padded-vocab-size 102400 \ + --rotary-base 10000 \ + --no-gradient-accumulation-fusion \ + --norm-epsilon 1e-6 \ + --no-load-optim \ + --no-load-rng \ + --bf16 +" + +DATA_ARGS=" + --data-path $DATA_PATH \ + --split 99,1,0 +" + +OUTPUT_ARGS=" + --log-interval 1 \ + --save-interval 1000 \ + --eval-interval 10000 \ + --eval-iters 10 \ + --no-save-optim \ + --no-save-rng +" + +python -m torch.distributed.launch $DISTRIBUTED_ARGS pretrain_gpt.py \ + $GPT_ARGS \ + $DATA_ARGS \ + $OUTPUT_ARGS \ + $MLA_ARGS \ + $ROPE_ARGS \ + $MOE_ARGS \ + --distributed-backend nccl \ + --save $CKPT_SAVE_DIR \ + | tee logs/pretrain_deepseek2_lite_ptd_16p.log \ No newline at end of file diff --git a/examples/mcore/deepseek2_lite/pretrain_deepseek2_lite_16b_ptd_8p.sh b/examples/mcore/deepseek2_lite/pretrain_deepseek2_lite_16b_ptd_8p.sh index fd6fa3dc3e9976ac60d16766c6d7d07e07e79afe..766ef830b8c42633ddf75525244a49d3bc6309fc 100644 --- a/examples/mcore/deepseek2_lite/pretrain_deepseek2_lite_16b_ptd_8p.sh +++ b/examples/mcore/deepseek2_lite/pretrain_deepseek2_lite_16b_ptd_8p.sh @@ -37,17 +37,20 @@ MLA_ARGS=" " MOE_ARGS=" + --moe-grouped-gemm \ + --moe-alltoall-overlap-comm \ --moe-permutation-async-comm \ - --moe-token-dispatcher-type allgather \ + --moe-token-dispatcher-type alltoall \ + --use-fused-moe-token-permute-and-unpermute \ --first-k-dense-replace 1 \ --moe-layer-freq 1 \ --n-shared-experts 2 \ --num-experts 64 \ --moe-router-topk 6 \ --moe-intermediate-size 1408 \ - --moe-router-load-balancing-type softmax_topk \ + --moe-router-load-balancing-type pai_megatron_aux_loss \ --topk-group 1 \ - --moe-aux-loss-coeff 0.001 \ + --moe-aux-loss-coeff 0.01 \ --routed-scaling-factor 1.0 \ --seq-aux " @@ -63,6 +66,8 @@ ROPE_ARGS=" " GPT_ARGS=" + --shape-order BNSD \ + --reuse-fp32-param \ --load $CKPT_LOAD_DIR \ --use-distributed-optimizer \ --use-flash-attn \ @@ -80,13 +85,16 @@ GPT_ARGS=" --num-attention-heads 16 \ --tokenizer-type PretrainedFromHF \ --tokenizer-name-or-path ${TOKENIZER_MODEL} \ - --seq-length 8192 \ + --finetune \ + --num-workers 8 \ + --seq-length 4096 \ --max-position-embeddings 163840 \ --micro-batch-size 1 \ --global-batch-size 8 \ --make-vocab-size-divisible-by 1 \ - --lr 1.0e-6 \ - --train-iters 2000 \ + --lr 2e-5 \ + --train-iters 462240 \ + --lr-decay-iters 462240 \ --lr-decay-style cosine \ --untie-embeddings-and-output-weights \ --disable-bias-linear \ @@ -103,11 +111,11 @@ GPT_ARGS=" --no-masked-softmax-fusion \ --attention-softmax-in-fp32 \ --min-lr 1.0e-8 \ - --weight-decay 1e-2 \ - --lr-warmup-iters 500 \ + --weight-decay 1e-1 \ + --lr-warmup-iters 1920 \ --clip-grad 1.0 \ --adam-beta1 0.9 \ - --adam-beta2 0.999 \ + --adam-beta2 0.95 \ --initial-loss-scale 65536 \ 
--vocab-size 102400 \ --padded-vocab-size 102400 \ @@ -121,14 +129,14 @@ GPT_ARGS=" DATA_ARGS=" --data-path $DATA_PATH \ - --split 100,0,0 + --split 99,1,0 " OUTPUT_ARGS=" --log-interval 1 \ - --save-interval 20000 \ - --eval-interval 20000 \ - --eval-iters 0 \ + --save-interval 1000 \ + --eval-interval 10000 \ + --eval-iters 10 \ --no-save-optim \ --no-save-rng " @@ -142,4 +150,4 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS pretrain_gpt.py \ $MOE_ARGS \ --distributed-backend nccl \ --save $CKPT_SAVE_DIR \ - | tee logs/npu_pretrain_mcore_deepseek2_lite_ptd_8p.log + | tee logs/pretrain_deepseek2_lite_ptd_8p.log diff --git a/examples/mcore/deepseek2_lite/tune_deepseek2_lite_16b_full_ptd.sh b/examples/mcore/deepseek2_lite/tune_deepseek2_lite_16b_full_ptd.sh index 7ee2231e2407b6faee4354ae31c3e2fdd6c56be5..bd4a57ab708c6f8128800aa8906ea93f1688eec7 100644 --- a/examples/mcore/deepseek2_lite/tune_deepseek2_lite_16b_full_ptd.sh +++ b/examples/mcore/deepseek2_lite/tune_deepseek2_lite_16b_full_ptd.sh @@ -37,17 +37,20 @@ MLA_ARGS=" " MOE_ARGS=" + --moe-grouped-gemm \ + --moe-alltoall-overlap-comm \ --moe-permutation-async-comm \ --moe-token-dispatcher-type alltoall \ + --use-fused-moe-token-permute-and-unpermute \ --first-k-dense-replace 1 \ --moe-layer-freq 1 \ --n-shared-experts 2 \ --num-experts 64 \ --moe-router-topk 6 \ --moe-intermediate-size 1408 \ - --moe-router-load-balancing-type softmax_topk \ + --moe-router-load-balancing-type aux_loss \ --topk-group 1 \ - --moe-aux-loss-coeff 0.001 \ + --moe-aux-loss-coeff 0.01 \ --routed-scaling-factor 1.0 \ --seq-aux " @@ -72,6 +75,7 @@ FITUNE_ARGS=" " GPT_ARGS=" + --shape-order BNSD \ --load $CKPT_LOAD_DIR \ --use-distributed-optimizer \ --use-flash-attn \ @@ -89,18 +93,20 @@ GPT_ARGS=" --num-attention-heads 16 \ --tokenizer-type PretrainedFromHF \ --tokenizer-name-or-path ${TOKENIZER_MODEL} \ - --seq-length 8192 \ + --num-workers 8 \ + --seq-length 4096 \ --max-position-embeddings 163840 \ --micro-batch-size 1 \ --global-batch-size 8 \ --make-vocab-size-divisible-by 1 \ - --lr 5e-5 \ + --lr 9e-6 \ --train-iters 2000 \ - --lr-decay-style constant \ + --lr-decay-style cosine \ + --lr-decay-iters 2000 \ --untie-embeddings-and-output-weights \ --disable-bias-linear \ --attention-dropout 0.0 \ - --init-method-std 0.02 \ + --init-method-std 0.008 \ --hidden-dropout 0.0 \ --position-embedding-type rope \ --normalization RMSNorm \ @@ -109,12 +115,13 @@ GPT_ARGS=" --use-fused-swiglu \ --use-fused-rmsnorm \ --swiglu \ + --dataloader-type cyclic \ --no-masked-softmax-fusion \ --attention-softmax-in-fp32 \ - --weight-decay 0e0 \ + --weight-decay 0.1 \ --clip-grad 1.0 \ --adam-beta1 0.9 \ - --adam-beta2 0.999 \ + --adam-beta2 0.95 \ --initial-loss-scale 1 \ --vocab-size 102400 \ --padded-vocab-size 102400 \ @@ -124,7 +131,7 @@ GPT_ARGS=" --no-load-optim \ --no-load-rng \ --bf16 \ - --reuse-fp32-param + --reuse-fp32-param \ " DATA_ARGS=" @@ -151,4 +158,4 @@ torchrun $DISTRIBUTED_ARGS posttrain_gpt.py \ $FITUNE_ARGS \ --distributed-backend nccl \ --save $CKPT_SAVE_DIR \ - | tee ./logs/npu_tune_mcore_deepseek2_lite.log + | tee ./logs/tune_deepseek2_lite_ptd_8p.log diff --git a/examples/mcore/deepseek2_lite/tune_deepseek2_lite_16b_full_ptd_16p.sh b/examples/mcore/deepseek2_lite/tune_deepseek2_lite_16b_full_ptd_16p.sh new file mode 100644 index 0000000000000000000000000000000000000000..26c871495499e090f62d195b42018febfcbd8fe5 --- /dev/null +++ b/examples/mcore/deepseek2_lite/tune_deepseek2_lite_16b_full_ptd_16p.sh @@ -0,0 +1,162 @@ +#!/bin/bash +export 
CUDA_DEVICE_MAX_CONNECTIONS=1 +export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True + +GPUS_PER_NODE=16 +MASTER_ADDR=localhost +MASTER_PORT=6000 +NNODES=1 +NODE_RANK=0 +WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) + +CKPT_SAVE_DIR="your checkpoint save path" +DATA_PATH="your finetune dataset path" +TOKENIZER_MODEL="your tokenizer model path" +CKPT_LOAD_DIR="your checkpoint load path" + +TP=1 +PP=1 +EP=8 + +DISTRIBUTED_ARGS=" + --nproc_per_node $GPUS_PER_NODE \ + --nnodes $NNODES \ + --node_rank $NODE_RANK \ + --master_addr $MASTER_ADDR \ + --master_port $MASTER_PORT +" + +MLA_ARGS=" + --spec modellink.tasks.models.spec.deepseek_spec layer_spec \ + --multi-head-latent-attention \ + --qk-rope-head-dim 64 \ + --qk-nope-head-dim 128 \ + --kv-lora-rank 512 \ + --v-head-dim 128 \ + --qk-layernorm \ +" + +MOE_ARGS=" + --moe-grouped-gemm \ + --moe-alltoall-overlap-comm \ + --moe-permutation-async-comm \ + --moe-token-dispatcher-type alltoall \ + --use-fused-moe-token-permute-and-unpermute \ + --first-k-dense-replace 1 \ + --moe-layer-freq 1 \ + --n-shared-experts 2 \ + --num-experts 64 \ + --moe-router-topk 6 \ + --moe-intermediate-size 1408 \ + --moe-router-load-balancing-type aux_loss \ + --topk-group 1 \ + --moe-aux-loss-coeff 0.01 \ + --routed-scaling-factor 1.0 \ + --seq-aux +" + +ROPE_ARGS=" + --rope-scaling-beta-fast 32 \ + --rope-scaling-beta-slow 1 \ + --rope-scaling-factor 40 \ + --rope-scaling-mscale 0.707 \ + --rope-scaling-mscale-all-dim 0.707 \ + --rope-scaling-original-max-position-embeddings 4096 \ + --rope-scaling-type yarn +" + +FITUNE_ARGS=" + --stage sft \ + --finetune \ + --is-instruction-dataset \ + --variable-seq-lengths \ + --prompt-type deepseek2-lite \ + --tokenizer-not-use-fast \ + " + + +GPT_ARGS=" + --shape-order BNSD \ + --load $CKPT_LOAD_DIR \ + --use-distributed-optimizer \ + --use-flash-attn \ + --use-mcore-models \ + --reuse-fp32-param \ + --tensor-model-parallel-size ${TP} \ + --pipeline-model-parallel-size ${PP} \ + --expert-model-parallel-size ${EP} \ + --sequence-parallel \ + --num-layers 27 \ + --recompute-granularity full \ + --recompute-method uniform \ + --recompute-num-layers 1 \ + --hidden-size 2048 \ + --ffn-hidden-size 10944 \ + --num-attention-heads 16 \ + --tokenizer-type PretrainedFromHF \ + --tokenizer-name-or-path ${TOKENIZER_MODEL} \ + --num-workers 8 \ + --seq-length 4096 \ + --max-position-embeddings 163840 \ + --micro-batch-size 1 \ + --global-batch-size 768 \ + --make-vocab-size-divisible-by 1 \ + --lr 9e-6 \ + --train-iters 462240 \ + --lr-decay-style cosine \ + --lr-decay-iters 462240 \ + --untie-embeddings-and-output-weights \ + --disable-bias-linear \ + --attention-dropout 0.0 \ + --init-method-std 0.008 \ + --hidden-dropout 0.0 \ + --position-embedding-type rope \ + --normalization RMSNorm \ + --use-fused-rotary-pos-emb \ + --use-rotary-position-embeddings \ + --use-fused-swiglu \ + --use-fused-rmsnorm \ + --swiglu \ + --dataloader-type cyclic \ + --no-masked-softmax-fusion \ + --attention-softmax-in-fp32 \ + --weight-decay 0.1 \ + --clip-grad 1.0 \ + --adam-beta1 0.9 \ + --adam-beta2 0.95 \ + --initial-loss-scale 1 \ + --vocab-size 102400 \ + --padded-vocab-size 102400 \ + --rotary-base 10000 \ + --no-gradient-accumulation-fusion \ + --norm-epsilon 1e-6 \ + --no-load-optim \ + --no-load-rng \ + --bf16 \ +" + +DATA_ARGS=" + --data-path $DATA_PATH \ + --split 100,0,0 +" + +OUTPUT_ARGS=" + --log-interval 1 \ + --save-interval 2000 \ + --eval-interval 1000 \ + --eval-iters 0 \ + --no-save-optim \ + --no-save-rng +" + +torchrun $DISTRIBUTED_ARGS 
posttrain_gpt.py \ + $GPT_ARGS \ + $DATA_ARGS \ + $OUTPUT_ARGS \ + $MLA_ARGS \ + $ROPE_ARGS \ + $MOE_ARGS \ + $FITUNE_ARGS \ + --distributed-backend nccl \ + --save $CKPT_SAVE_DIR \ + | tee ./logs/tune_deepseek2_lite_ptd_16p.log \ No newline at end of file diff --git a/modellink/core/transformer/mlp.py b/modellink/core/transformer/mlp.py index 28e6df291390446061f910258f40433bdf08b8fe..0a1a8580a3fbaf8f4b46a720b2ebc24442229b95 100644 --- a/modellink/core/transformer/mlp.py +++ b/modellink/core/transformer/mlp.py @@ -69,7 +69,7 @@ def should_recompute_activation(self): return recompute_priority < activation_recompute_layers -def core_mlp_init(self, config, submodules, is_expert=False, input_size=None): +def core_mlp_init(self, config, submodules, is_expert=False, input_size=None, shared_expert=False): super(MLP, self).__init__(config=config) self.config: TransformerConfig = config @@ -94,30 +94,62 @@ def core_mlp_init(self, config, submodules, is_expert=False, input_size=None): if self.config.gated_linear_unit: ffn_hidden_size *= 2 - self.linear_fc1 = build_module( - submodules.linear_fc1, - self.input_size, - ffn_hidden_size, - config=self.config, - init_method=self.config.init_method, - gather_output=False, - bias=self.config.add_bias_linear, - skip_bias_add=True, - is_expert=is_expert, - tp_comm_buffer_name='fc1', - ) + if shared_expert: + self.linear_fc1 = build_module( + submodules.linear_fc1, + self.input_size, + ffn_hidden_size, + config=self.config, + init_method=self.config.init_method, + gather_output=False, + bias=self.config.add_bias_linear, + skip_bias_add=True, + is_expert=is_expert, + tp_comm_buffer_name='fc1', + shared_expert=shared_expert + ) + else: + self.linear_fc1 = build_module( + submodules.linear_fc1, + self.input_size, + ffn_hidden_size, + config=self.config, + init_method=self.config.init_method, + gather_output=False, + bias=self.config.add_bias_linear, + skip_bias_add=True, + is_expert=is_expert, + tp_comm_buffer_name='fc1' + ) self.activation_func = self.config.activation_func - self.linear_fc2 = build_module( - submodules.linear_fc2, - self.config.ffn_hidden_size, - self.config.hidden_size, - config=self.config, - init_method=self.config.output_layer_init_method, - bias=self.config.add_bias_linear, - input_is_parallel=True, - skip_bias_add=True, - is_expert=is_expert, - tp_comm_buffer_name='fc2', - ) + if shared_expert: + self.linear_fc2 = build_module( + submodules.linear_fc2, + self.config.ffn_hidden_size, + self.config.hidden_size, + config=self.config, + init_method=self.config.output_layer_init_method, + bias=self.config.add_bias_linear, + input_is_parallel=True, + skip_bias_add=True, + is_expert=is_expert, + tp_comm_buffer_name='fc2', + shared_expert=shared_expert + ) + else: + self.linear_fc2 = build_module( + submodules.linear_fc2, + self.config.ffn_hidden_size, + self.config.hidden_size, + config=self.config, + init_method=self.config.output_layer_init_method, + bias=self.config.add_bias_linear, + input_is_parallel=True, + skip_bias_add=True, + is_expert=is_expert, + tp_comm_buffer_name='fc2' + ) + + self.shared_expert = shared_expert diff --git a/modellink/core/transformer/moe/moe_layer.py b/modellink/core/transformer/moe/moe_layer.py index d9fca0793dc1f67137fc95a097449fe5bfa416da..bf84dd0a9c41e888c14cb472a4c9f90d7e07e3c8 100644 --- a/modellink/core/transformer/moe/moe_layer.py +++ b/modellink/core/transformer/moe/moe_layer.py @@ -13,6 +13,7 @@ from megatron.core.transformer.mlp import MLPSubmodules, MLP from megatron.core.transformer.moe.experts import 
GroupedMLP, SequentialMLP
from megatron.core.transformer.moe.moe_utils import save_to_aux_losses_tracker
from megatron.training import get_args
+from mindspeed.core.transformer.moe.moe_layer_overlap_all2all import MoELayerOverlapAll2All
 
 
 def moe_layer_init_wrapper(init_func):
@@ -34,8 +35,14 @@ def moe_layer_init_wrapper(init_func):
         if global_args.n_shared_experts:
             config = deepcopy(self.config)
             config.ffn_hidden_size = global_args.n_shared_experts * self.config.ffn_hidden_size
-            self.shared_experts = MLP(config, MLPSubmodules(linear_fc1=ColumnParallelLinear,
-                                                            linear_fc2=RowParallelLinear,))
+
+            if global_args.moe_allgather_overlap_comm or global_args.moe_alltoall_overlap_comm:
+                from mindspeed.core.transformer.moe.layers import ColumnParallelLinear, RowParallelLinear
+                self.shared_experts = MLP(config, MLPSubmodules(linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear), shared_expert=True)
+            else:
+                from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear
+                self.shared_experts = MLP(config, MLPSubmodules(linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear))
+
             # For using layer_number when recompute activation function is enabled.
             self.shared_experts.layer_number = self.layer_number
             if global_args.shared_expert_gate:
@@ -53,6 +60,10 @@
 
 def moe_layer_forward(self, hidden_states: torch.Tensor):
+    global_args = get_args()
+    if global_args.moe_token_dispatcher_type == 'alltoall' and global_args.moe_alltoall_overlap_comm:
+        return MoELayerOverlapAll2All.apply(hidden_states, self)
+
     # process MoE
     scores, indices = self.router(hidden_states)
 
diff --git a/modellink/core/transformer/moe/router.py b/modellink/core/transformer/moe/router.py
index a5badc22a25f1b2969538cc47fe014cec2919a8a..4c3e9966346b5ec47c2ef309ffb73e15492a2ccb 100644
--- a/modellink/core/transformer/moe/router.py
+++ b/modellink/core/transformer/moe/router.py
@@ -21,9 +21,8 @@ from megatron.core.tensor_parallel import gather_from_sequence_parallel_region
 from megatron.training import get_args
 from megatron.core.transformer.moe.moe_utils import MoEAuxLossAutoScaler, save_to_aux_losses_tracker
 from megatron.core import parallel_state
-
 from .moe_utils import topk_softmax_with_capacity, switch_load_balancing_loss_func
-
+from modellink.tasks.models.common.pai_megatron import pai_megatron_aux_loss
 
 def group_limited_greedy_topKgating(self, logits: torch.Tensor):
     args = get_args()
@@ -234,6 +233,8 @@ def topk_router_routing(self, logits: torch.Tensor):
         scores, indices = torch.topk(logits_, k=self.topk, dim=1)
     elif self.routing_type == "group_limited_greedy":
         scores, indices = group_limited_greedy_topKgating(self, logits)
+    elif self.routing_type == "pai_megatron_aux_loss":
+        scores, indices = pai_megatron_aux_loss(self, logits)
     elif self.routing_type == "none":
         # A naive top-k routing without load balancing
         # top_logits, indices = torch.topk(logits, k=self.topk, dim=1)
diff --git a/modellink/tasks/megatron_adaptor.py b/modellink/tasks/megatron_adaptor.py
index 6b882591f5d58c7f49a1b86c091bd1879c1eeb06..360b43fa241fb922a330e1198fd8f122d52bff8d 100644
--- a/modellink/tasks/megatron_adaptor.py
+++ b/modellink/tasks/megatron_adaptor.py
@@ -235,14 +235,13 @@ class CoreAdaptation(MegatronAdaptationABC):
                                     transformer_block_checkpointed_forward_wrapper)
 
     def patch_core_transformers(self):
-        from mindspeed.core.transformer.moe.router import aux_loss_load_balancing
         from mindspeed.core.transformer.moe.token_dispatcher import allgather_token_permutation, \
allgather_token_unpermutation from mindspeed.core.transformer.moe.grouped_gemm_util import Ops, grouped_gemm_is_available, \ get_device_capability, assert_grouped_gemm_is_available from mindspeed.core.transformer.transformer import core_mlp_forward_wrapper from mindspeed.core.transformer.moe.moe_utils import permute, unpermute - + from mindspeed.core.transformer.moe.experts import group_mlp_forward from ..core.transformer.moe.moe_layer import moe_layer_init_wrapper, moe_layer_forward from ..core.transformer.transformer_block import _transformer_block_build_layers from ..core.transformer.transformer_layer import transformer_layer_init_wrapper @@ -286,6 +285,7 @@ class CoreAdaptation(MegatronAdaptationABC): args = MegatronAdaptation.get_args() if args.moe_permutation_async_comm: if args.moe_token_dispatcher_type == 'allgather': + from mindspeed.core.transformer.moe.router import aux_loss_load_balancing MegatronAdaptation.register( 'megatron.core.transformer.moe.token_dispatcher.MoEAllGatherTokenDispatcher.token_permutation', allgather_token_permutation) @@ -300,13 +300,26 @@ class CoreAdaptation(MegatronAdaptationABC): MegatronAdaptation.register( 'megatron.core.transformer.moe.token_dispatcher.MoEAlltoAllTokenDispatcher.preprocess', preprocess) - MegatronAdaptation.register( - 'megatron.core.transformer.moe.token_dispatcher.MoEAlltoAllTokenDispatcher.token_permutation', - alltoall_token_permutation) MegatronAdaptation.register('megatron.core.transformer.moe.experts.SequentialMLP.forward', sequential_mlp_forward) MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.permute', permute) MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.unpermute', unpermute) + if args.moe_alltoall_overlap_comm: + from mindspeed.core.transformer.moe.token_dispatcher import alltoall_token_permutation_new, \ + alltoall_token_unpermutation_new + MegatronAdaptation.register('megatron.core.transformer.moe.experts.GroupedMLP.forward', + group_mlp_forward) + MegatronAdaptation.register( + 'megatron.core.transformer.moe.token_dispatcher.MoEAlltoAllTokenDispatcher.token_permutation', + alltoall_token_permutation_new) + MegatronAdaptation.register( + 'megatron.core.transformer.moe.token_dispatcher.MoEAlltoAllTokenDispatcher.token_unpermutation', + alltoall_token_unpermutation_new) + else: + MegatronAdaptation.register( + 'megatron.core.transformer.moe.token_dispatcher.MoEAlltoAllTokenDispatcher.token_permutation', + alltoall_token_permutation) + if hasattr(args, 'use_fused_moe_token_permute_and_unpermute') and args.use_fused_moe_token_permute_and_unpermute and not args.moe_expert_capacity_factor: from mindspeed.core.fusions.npu_moe_token_permute import permute_wrapper from mindspeed.core.fusions.npu_moe_token_unpermute import unpermute_wrapper diff --git a/modellink/tasks/models/common/pai_megatron.py b/modellink/tasks/models/common/pai_megatron.py new file mode 100644 index 0000000000000000000000000000000000000000..3299db58415a3f7ae98675d7e26a783c4f5636f6 --- /dev/null +++ b/modellink/tasks/models/common/pai_megatron.py @@ -0,0 +1,40 @@ +# coding=utf-8 +# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from megatron.training import get_args
+
+def pai_megatron_aux_loss(self, logits: torch.Tensor):
+    # PAI-Megatron routing order: softmax over all experts first, then select top-k.
+    routing_weights = torch.softmax(logits, dim=1, dtype=torch.float32).type_as(logits)
+    scores, indices = torch.topk(routing_weights, k=self.topk, dim=-1)
+
+    # Top-k without capacity: count how many tokens are routed to each expert.
+    num_experts = logits.shape[1]
+    tokens_per_expert = torch.histc(indices, bins=num_experts, min=0, max=num_experts)
+
+    # Apply the auxiliary load-balancing loss to the routing scores.
+    probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
+    scores = self.apply_load_balancing_loss(probs, tokens_per_expert, activation=scores)
+
+    args = get_args()
+    global_indices = indices
+    if args.moe_token_dispatcher_type == "allgather":
+        # The allgather dispatcher needs indices for the full sequence-parallel region.
+        if args.moe_permutation_async_comm and (
+                self.config.sequence_parallel or (self.config.expert_model_parallel_size > 1)):
+            from mindspeed.core.transformer.moe.router import gather_from_sequence_parallel_region_to_moe_async
+            with torch.no_grad():
+                global_indices = gather_from_sequence_parallel_region_to_moe_async(indices)
+    return scores, global_indices
diff --git a/modellink/tasks/preprocess/decoder_packed_mtf_dataset.py b/modellink/tasks/preprocess/decoder_packed_mtf_dataset.py
index 0630bb43bb315de91fe9607ebd87a37ec3aa1548..4b978c803fe0cec5c333741ffd4ce3c9e214fa64 100644
--- a/modellink/tasks/preprocess/decoder_packed_mtf_dataset.py
+++ b/modellink/tasks/preprocess/decoder_packed_mtf_dataset.py
@@ -185,7 +185,7 @@ class DecoderPackedMTFDataset(torch.utils.data.Dataset):
 
     def _cut_token(self, token, dtype):
         token_length = len(token)
-        if token_length >= self.seq_length:
+        if not self.args.no_cut_token and token_length >= self.seq_length:
             token = token[:self.seq_length]
         return token.astype(dtype)
 
diff --git a/modellink/training/arguments.py b/modellink/training/arguments.py
index e8cc4ad26d55cd3b5b94b3e67a60db2372867a4c..b926d0011691ab1a13d9051b39d5b357a742e431 100644
--- a/modellink/training/arguments.py
+++ b/modellink/training/arguments.py
@@ -288,13 +288,14 @@ def _add_moe_args(parser):
     group.add_argument('--moe-router-topk', type=int, default=2,
                        help='Number of experts to route to for each token. The default is 2.')
     group.add_argument('--moe-router-load-balancing-type', type=str,
-                       choices=['aux_loss', "group_limited_greedy", "softmax_topk"],
+                       choices=['aux_loss', "group_limited_greedy", "softmax_topk", "pai_megatron_aux_loss"],
                        default='aux_loss',
                        help='Determines the load balancing strategy for the router. "aux_loss" corresponds '
                             'to the load balancing loss used in GShard and SwitchTransformer, "sinkhorn" corresponds '
                             'to the balancing algorithm used in S-BASE, "softmax_topk" implies no load balancing and '
-                            'softmax before topk , "None" implies no load balancing, and "group_limited_greedy" corresponds '
-                            'to the Device-Limited Routing method in DeepSeekV2.'
+                            'softmax before topk, "None" implies no load balancing, "group_limited_greedy" corresponds '
+                            'to the Device-Limited Routing method in DeepSeekV2, and "pai_megatron_aux_loss" corresponds '
+                            'to the load balancing loss used in PAI-Megatron. '
                             'The default is "aux_loss".')
     group.add_argument('--expert-interval', type=int, default=1,
                        help='Use experts in every "expert-interval" layers')
@@ -334,6 +335,8 @@
                        help="moe model shared expert gate output dimension for qwen2 moe, this parameter can only configured with"
                             "1 or hidden_state")
     group.add_argument("--fix-router", action='store_true', help="fix router for load balancing.")
+    group.add_argument('--moe-alltoall-overlap-comm', action='store_true', default=False,
+                       help='Overlap all-to-all communication with computation in MoE layers.')
     return parser
@@ -587,6 +590,8 @@ def _add_training_args(parser):
                        help='scale embed tokens')
     group.add_argument('--dim-model-base', type=float, default=None,
                        help='dim-model-base')
+    group.add_argument('--no-cut-token', action='store_true', default=False,
+                       help='Do not truncate tokens to seq-length during finetuning.')
     group.add_argument('--scale-depth', type=float, default=None,
                        help='scale-depth')
     group.add_argument('--swap-attention', action='store_true', default=False,
@@ -769,6 +774,14 @@ def _validate_moe_args(args):
         raise AssertionError('shared expert gate output dimension can only be configured with 1 or hidden_size')
     if hasattr(args, 'use_fused_moe_token_permute_and_unpermute') and args.use_fused_moe_token_permute_and_unpermute:
         raise AssertionError('moe_expert_capacity_factor mode does not support use_fused_moe_token_permute_and_unpermute')
+    if args.moe_alltoall_overlap_comm:
+        if not args.moe_permutation_async_comm or not args.moe_grouped_gemm:
+            raise AssertionError(
+                '`--moe-alltoall-overlap-comm` is only supported together with `--moe-permutation-async-comm` and `--moe-grouped-gemm`.')
+        if args.moe_token_dispatcher_type != 'alltoall':
+            raise AssertionError('`--moe-alltoall-overlap-comm` is only supported with `--moe-token-dispatcher-type alltoall`.')
+        if args.tensor_model_parallel_size > 1:
+            raise AssertionError('`--moe-alltoall-overlap-comm` does not support tensor parallelism for now.')
 
 
 def _validate_mla(args):
@@ -960,11 +973,12 @@ def _add_dummy_args(args):
     args.tp_x = 1
     args.tp_y = 1
     args.use_nd_matmul = False
-    args.moe_alltoall_overlap_comm = False
     args.moe_allgather_overlap_comm = False
     args.moe_without_activation = False
     args.disable_gloo_group = None
     args.ampipe_degree = 0
+    args.moe_zero_memory = 'disable'
+    args.moe_zero_memory_num_layers = None
 
 
 def _validate_noop_layer(args):