diff --git a/examples/llama2/README.md b/examples/llama2/README.md
index 352ddd4e962d6db24096b041b482e5d81a58a452..eba332772beeb9d8ce2c35dc4ef42582125a6403 100755
--- a/examples/llama2/README.md
+++ b/examples/llama2/README.md
@@ -216,6 +216,11 @@ python tools/checkpoint/util.py --model-type GPT \
     --is-instruction-dataset \
 ```
+    Full-parameter fine-tuning also provides a new unpad scheme that removes padding to cut redundant computation and speed up training:
+    (1) Do not enable --use-flash-attn or --sequence-parallel
+    (2) Add --use-unpad
+    (3) Only fp16 is supported
+
 6.3 Lora微调
 Lora微调的脚本配置是在预训练脚本基础上加上lora参数,如下所示:

@@ -548,6 +553,11 @@ python tools/checkpoint/util.py --model-type GPT \
     --is-instruction-dataset \
 ```
+    Full-parameter fine-tuning also provides a new unpad scheme that removes padding to cut redundant computation and speed up training:
+    (1) Remove --use-flash-attn and --variable-seq-lengths
+    (2) Add --use-unpad
+    (3) Only fp16 is supported
+
 6.3 Lora微调
 Lora微调的脚本配置是在全参微调脚本基础上加上lora参数,如下所示:

diff --git a/examples/llama2/README_en.md b/examples/llama2/README_en.md
index bdfd50194888ba06a2aaa6a0012325311523bff4..dc3665ad51660dfd0de611adc169c44382ca8add 100644
--- a/examples/llama2/README_en.md
+++ b/examples/llama2/README_en.md
@@ -216,6 +216,11 @@ Here's a hardware summary of pre-training LLAMA2-7B:
     --finetune \
     --is-instruction-dataset \
 ```
+    Full-parameter fine-tuning also provides a new unpad scheme that removes padding to cut redundant computation and speed up training:
+    (1) Do not enable --use-flash-attn or --sequence-parallel
+    (2) Add --use-unpad
+    (3) Only fp16 is supported
+
 6.3 Lora Fine-Tuning
 The Lora fine-tuning script is configured by adding the following lora parameters to the pretrain_llama2_7b_ptd.sh script:
 ```bash
@@ -526,6 +531,12 @@ Here's a hardware summary of pre-training LLaMA2-13B:
     --finetune \
     --is-instruction-dataset \
 ```
+
+    Full-parameter fine-tuning also provides a new unpad scheme that removes padding to cut redundant computation and speed up training:
+    (1) Remove --use-flash-attn and --variable-seq-lengths
+    (2) Add --use-unpad
+    (3) Only fp16 is supported
+
 6.3 Lora Fine-Tuning
 The Lora fine-tuning script is configured by adding the following lora parameters based on the full-parameter finetune script pretrain_llama2_7b_ptd.sh:
 ```bash
diff --git a/megatron/arguments.py b/megatron/arguments.py
index d4f1cd5a324b5d23928a690ea7cf7bda30f4432f..b9d8352fbc9db20e9dd28339cdf93de8048578a8 100644
--- a/megatron/arguments.py
+++ b/megatron/arguments.py
@@ -646,7 +646,10 @@ def _add_network_size_args(parser):
     group.add_argument('--num-experts', type=int, default=None,
                        help='Number of Experts in Switch Transformer (None means no Switch)')
     group.add_argument('--untie-embeddings-and-output-weights', action='store_true',
-                       help='Untie embeddings and output weights.'),
+                       help='Untie embeddings and output weights.')
+    group.add_argument('--use-unpad', action='store_true',
+                       help='Use the unpad scheme: strip padding tokens before the transformer layers to reduce redundant computation.',
+                       dest='use_unpad')

     return parser
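Reviewer note on the new flag: with instruction data most samples are far shorter than --seq-length, so the padded [b, s, h] activations spend most of their matmul time on tokens that contribute nothing. The unpad scheme packs only the real tokens into a [bsq, h] tensor before the transformer layers and restores the padded layout afterwards. A minimal plain-PyTorch sketch of the idea (illustration only; the patch itself uses the fused ascendspeed UnpadSeqLen/PadSeqLen ops):

```python
import torch

def unpad(hidden_states, seq_lengths):
    """Pack [b, s, h] into [sum(seq_lengths), h] by dropping each sample's padding."""
    return torch.cat([hidden_states[i, :n] for i, n in enumerate(seq_lengths)], dim=0)

def pad(packed, seq_lengths, max_seq_len):
    """Restore the padded [b, s, h] layout, zero-filling the dropped positions."""
    out = packed.new_zeros(len(seq_lengths), max_seq_len, packed.size(-1))
    offset = 0
    for i, n in enumerate(seq_lengths):
        out[i, :n] = packed[offset:offset + n]
        offset += n
    return out

# Two samples with 48 and 96 real tokens, padded to seq-length 4096:
x = torch.randn(2, 4096, 8)
packed = unpad(x, [48, 96])             # [144, 8] instead of [8192, 8]
restored = pad(packed, [48, 96], 4096)  # back to [2, 4096, 8]
```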
diff --git a/megatron/core/tensor_parallel/layers.py b/megatron/core/tensor_parallel/layers.py
index f31ee42df623f4433af90570c5bf0fd0914aa4e0..362ba26ce2f461a3808d506b551f252f70ee292c 100644
--- a/megatron/core/tensor_parallel/layers.py
+++ b/megatron/core/tensor_parallel/layers.py
@@ -370,12 +370,14 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
         # https://github.com/pytorch/pytorch/blob/c47cf9bc7f9e02f649ab4ed53fe4d35732c92ab6/torch/_refs/__init__.py#L2761
         grad_output = grad_output.contiguous()
         # Convert the tensor shapes to 2D for execution compatibility
-        grad_output = grad_output.view(
-            grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2]
-        )
-        total_input = total_input.view(
-            total_input.shape[0] * total_input.shape[1], total_input.shape[2]
-        )
+        if grad_output.dim() != 2:
+            grad_output = grad_output.view(
+                grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2]
+            )
+        if total_input.dim() != 2:
+            total_input = total_input.view(
+                total_input.shape[0] * total_input.shape[1], total_input.shape[2]
+            )

         if ctx.async_grad_allreduce:
             # Asynchronous all-reduce
diff --git a/megatron/model/language_model.py b/megatron/model/language_model.py
index 69bfa2e8018ed30411764bf357111ba5f4ed4d95..40a34b90c8192033030dcf81b013f4542c4ee8cb 100644
--- a/megatron/model/language_model.py
+++ b/megatron/model/language_model.py
@@ -5,6 +5,7 @@
 import torch
 import torch.nn.functional as F

+from megatron import get_tokenizer
 from megatron import get_args
 from megatron.core import mpu, tensor_parallel
 from megatron.core.enums import ModelType
@@ -424,6 +425,7 @@ class TransformerLanguageModel(MegatronModule):
                 init_method=self.init_method,
                 bias=False) # Setting bias to False always to keep it consistent with embedding tying that also does not have a bias.
             self._output_layer_key = 'output_layer'
+        self.tokenizer = get_tokenizer()

     def set_input_tensor(self, input_tensor):
         """ See megatron.model.transformer.set_input_tensor()"""
@@ -463,6 +465,10 @@ class TransformerLanguageModel(MegatronModule):
                 pooling_sequence_index=0,
                 enc_hidden_states=None, output_enc_hidden=False):

+        seq_lengths = torch.sum(enc_input_ids.ne(self.tokenizer.eod), 1)
+        seq_lengths = seq_lengths.cpu().tolist()
+        seq_lengths = [(s+15)//16*16 for s in seq_lengths]
+
         # Encoder embedding.
         if self.pre_process:
             encoder_input = self.embedding(enc_input_ids, enc_position_ids,
@@ -496,7 +502,8 @@ class TransformerLanguageModel(MegatronModule):
                     retriever_input=retriever_input,
                     retriever_attn_mask=retriever_attn_mask,
                     inference_params=inference_params,
-                    rotary_pos_emb=rotary_pos_emb)
+                    rotary_pos_emb=rotary_pos_emb,
+                    seq_lengths=seq_lengths)
             else:
                 encoder_output = self.encoder_hidden_state
         else:
@@ -530,7 +537,8 @@ class TransformerLanguageModel(MegatronModule):
             encoder_output=encoder_output,
             enc_dec_attn_mask=enc_dec_attn_mask,
             inference_params=inference_params,
-            rotary_pos_emb=rotary_pos_emb)
+            rotary_pos_emb=rotary_pos_emb,
+            seq_lengths=seq_lengths)

         if self.add_pooler and self.post_process:
             return decoder_output, encoder_output, pooled_output
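The sequence lengths that drive the whole unpad path are computed once per forward pass in TransformerLanguageModel.forward above: tokens that are not the tokenizer's eod id (which the code treats as the pad id) are counted per sample, and the count is rounded up to a multiple of 16, presumably an alignment requirement of the ATB kernels. A standalone restatement of that logic:

```python
import torch

def get_unpad_seq_lengths(input_ids, eod_token_id, align=16):
    """Per-sample count of non-eod/pad tokens, rounded up to a multiple of
    `align`; mirrors the lines added to TransformerLanguageModel.forward
    (the multiple-of-16 rounding is assumed to be a kernel alignment need)."""
    lengths = torch.sum(input_ids.ne(eod_token_id), dim=1).cpu().tolist()
    return [(n + align - 1) // align * align for n in lengths]

# A sample with 70 real tokens padded with eod id 2 up to length 128 -> 80.
ids = torch.cat([torch.randint(3, 100, (1, 70)), torch.full((1, 58), 2)], dim=1)
print(get_unpad_seq_lengths(ids, eod_token_id=2))  # [80]
```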
+ """ + + def __init__(self, config, layer_number, + attention_type=AttnType.self_attn, + attn_mask_type=AttnMaskType.padding): + super(UnpadParallelAttention, self).__init__(config, layer_number, + attention_type=AttnType.self_attn, + attn_mask_type=AttnMaskType.padding) + + from ascendspeed.ops.unpad_gen_attention_mask import UnpadGenAttentionMask + from ascendspeed.ops.unpad_rope import UnpadRope + from ascendspeed.ops.unpad_softmax import UnpadSoftMax + from ascendspeed.ops.unpad_bmm1 import UnpadBmm1 + from ascendspeed.ops.unpad_bmm2 import UnpadBmm2 + from ascendspeed.ops.unpad_seqlen import UnpadSeqLen + from ascendspeed.ops.pad_seqlen import PadSeqLen + + args = get_args() + self.use_unpad = args.use_unpad + self.bmm1 = UnpadBmm1(self.num_attention_heads_per_partition) + self.bmm2 = UnpadBmm2(self.num_attention_heads_per_partition) + self.unpad_softmax = UnpadSoftMax() + self.rope = UnpadRope() + self.hidden_size_per_partition = self.hidden_size_per_attention_head * self.num_attention_heads_per_partition + self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) + + def apply_unpad_rotary_pos_emb(self, query_layer, key_layer, rotary_pos_emb, seq_lengths): + cos = torch.cos(rotary_pos_emb).to(query_layer.dtype) + sin = torch.sin(rotary_pos_emb).to(query_layer.dtype) + query_layer, key_layer = self.rope(query_layer, key_layer, cos, sin, seq_lengths, offset=0) + return query_layer, key_layer + + def forward(self, hidden_states, attention_mask, + encoder_output=None, inference_params=None, + rotary_pos_emb=None, seq_lengths=None): + # ===================== + # Query, Key, and Value + # ===================== + if self.attention_type == AttnType.self_attn: + # Attention heads [bsq, h] --> [bsq, (np * 3 * hn)] + mixed_x_layer, _ = self.query_key_value(hidden_states) + + # [bsq, hp] --> [bsq, ng, (np/ng + 2) * hn] + new_tensor_shape = mixed_x_layer.size()[:-1] + ( + self.num_query_groups_per_partition, + ( + (self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 2) + * self.hidden_size_per_attention_head + ), + ) + mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + + # [bsq, (np/ng + 2) * hn] --> [bsq, ng, np/ng * hn], [bsq, ng, hn], [bsq, ng, hn] + (query_layer, + key_layer, + value_layer) = torch.split( + mixed_x_layer, + [ + ( + self.num_attention_heads_per_partition // self.num_query_groups_per_partition + * self.hidden_size_per_attention_head + ), + self.hidden_size_per_attention_head, + self.hidden_size_per_attention_head + ], + dim=2) + + # [bsq, ng, np/ng * hn] -> [bsq, np * hn] + query_layer = query_layer.contiguous().view(query_layer.size(0), self.hidden_size_per_partition) + key_layer = key_layer.contiguous().view(key_layer.size(0), self.hidden_size_per_partition) + value_layer = value_layer.contiguous().view(value_layer.size(0), self.hidden_size_per_partition) + + query_layer, key_layer = self.apply_unpad_rotary_pos_emb(query_layer, key_layer, rotary_pos_emb[:, 0, 0, :], seq_lengths) + + # =================================== + # Raw attention scores. 
+ # =================================== + attention_scores = self.bmm1(query_layer, key_layer, seq_lengths) + + # =================================== + # Attention probs and dropout + # =================================== + attention_scores.masked_fill_(attention_mask, -10000.0) + attention_scores = attention_scores * (1.0 / self.norm_factor) + attention_scores = self.unpad_softmax(attention_scores, seq_lengths, self.num_attention_heads_per_partition) + + # =================================== + # Context layer. [sq, b, hp] + # =================================== + context_layer = self.bmm2(attention_scores, value_layer, seq_lengths) + # ================= + # Output. [bsq, h] + # ================= + output, bias = self.dense(context_layer) + return output, bias + + def bias_dropout_add(x, bias, residual, prob, training): # type: (Tensor, Optional[Tensor], Tensor, float, bool) -> Tensor if bias is not None: @@ -872,11 +975,18 @@ class ParallelTransformerLayer(MegatronModule): self.input_norm = get_norm(config) # Self attention. - self.self_attention = ParallelAttention( - config, - layer_number, - attention_type=AttnType.self_attn, - attn_mask_type=self_attn_mask_type) + if args.use_unpad: + self.self_attention = UnpadParallelAttention( + config, + layer_number, + attention_type=AttnType.self_attn, + attn_mask_type=self_attn_mask_type) + else: + self.self_attention = ParallelAttention( + config, + layer_number, + attention_type=AttnType.self_attn, + attn_mask_type=self_attn_mask_type) self.hidden_dropout = config.hidden_dropout self.bias_dropout_fusion = config.bias_dropout_fusion self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None @@ -1133,7 +1243,7 @@ class ParallelTransformerLayer(MegatronModule): return retriever_output, norm_input, norm_output - def forward(self, hidden_states, attention_mask, + def forward(self, hidden_states, seq_lengths=None, attention_mask=None, encoder_output=None, enc_dec_attn_mask=None, retriever_input=None, retriever_output=None, @@ -1151,7 +1261,8 @@ class ParallelTransformerLayer(MegatronModule): norm_output, attention_mask, inference_params=inference_params, - rotary_pos_emb=rotary_pos_emb) + rotary_pos_emb=rotary_pos_emb, + seq_lengths=seq_lengths) # Residual connection. if self.apply_residual_connection_post_norm: @@ -1466,6 +1577,12 @@ class ParallelTransformer(MegatronModule): "Full recompute not supported for Retro." assert args.transformer_impl == 'local', \ "Transformer engine does not support Retro layers." + + world_size = mpu.get_tensor_model_parallel_world_size() + self.num_attention_heads_per_partition = core.utils.divide( + config.num_attention_heads, world_size) + self.genAttentionMask = UnpadGenAttentionMask(self.num_attention_heads_per_partition) + def build_layer(layer_number): if args.transformer_impl == 'local': current_layer_type = _get_layer_type( @@ -1577,19 +1694,23 @@ class ParallelTransformer(MegatronModule): # Final layer norm before output. 
@@ -1577,19 +1694,23 @@ class ParallelTransformer(MegatronModule):
             # Final layer norm before output.
             self.final_norm = get_norm(config)

+        self.use_unpad = args.use_unpad
+        self.pad = PadSeqLen(args.seq_length)
+        self.unpad = UnpadSeqLen(args.seq_length)
+
     def _get_layer(self, layer_number):
         return self.layers[layer_number]

     def _checkpointed_forward(self, hidden_states, attention_mask,
                               encoder_output, enc_dec_attn_mask,
-                              rotary_pos_emb, is_first_microbatch):
+                              rotary_pos_emb, is_first_microbatch, seq_lengths=None):
         """Forward method with activation checkpointing."""
         def custom(start, end):
             def custom_forward(*args, **kwargs):
                 x_, *args = args
                 for index in range(start, end):
                     layer = self._get_layer(index)
-                    x_ = layer(x_, *args, **kwargs)
+                    x_ = layer(x_, seq_lengths, *args, **kwargs)
                 return x_
             return custom_forward

@@ -1675,7 +1796,8 @@ class ParallelTransformer(MegatronModule):
                 retriever_output=None,
                 retriever_attn_mask=None,
                 inference_params=None,
-                rotary_pos_emb=None):
+                rotary_pos_emb=None,
+                seq_lengths=None):
         # hidden_states: [s, b, h]

         # Checks.
@@ -1686,6 +1808,10 @@ class ParallelTransformer(MegatronModule):
         if not self.pre_process:
             # See set_input_tensor()
             hidden_states = self.input_tensor
+        if self.use_unpad:
+            hidden_states = hidden_states.transpose(0, 1).contiguous()
+            hidden_states = self.unpad(hidden_states, seq_lengths)
+            attention_mask = self.genAttentionMask(attention_mask, seq_lengths)

         # Viewless tensor.
         # - We only need to create a viewless tensor in the case of micro batch
@@ -1736,7 +1862,8 @@ class ParallelTransformer(MegatronModule):
                     encoder_output,
                     enc_dec_attn_mask,
                     rotary_pos_emb,
-                    is_first_microbatch)
+                    is_first_microbatch,
+                    seq_lengths)
             else:
                 forward_kwargs = {
                     'encoder_output': encoder_output,
@@ -1760,6 +1887,7 @@ class ParallelTransformer(MegatronModule):

                 hidden_states = layer(
                     hidden_states,
+                    seq_lengths,
                     attention_mask,
                     **forward_kwargs)

@@ -1775,6 +1903,10 @@ class ParallelTransformer(MegatronModule):
         if torch.is_grad_enabled() and self.training:
             self.microbatch_count += 1

+        if self.use_unpad:
+            hidden_states = self.pad(hidden_states, seq_lengths)
+            hidden_states = hidden_states.transpose(0, 1).contiguous()
+
         # Final layer norm.
         if self.post_process and self.post_norm:
             hidden_states = self.final_norm(hidden_states)
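Putting the ParallelTransformer changes together: when --use-unpad is set, the [s, b, h] activations are made batch-first and packed according to seq_lengths, the padded attention mask is converted by the mask generator, each layer is called with seq_lengths as its second argument, and the output is re-padded and transposed back before the final norm. A schematic of that control flow, with the helpers standing in for the ATB ops (same toy pad/unpad signatures as in the sketch further up):

```python
def transformer_forward_sketch(hidden_states, attention_mask, seq_lengths,
                               layers, unpad, pad, gen_unpad_mask, use_unpad):
    """Schematic of the control flow added to ParallelTransformer.forward.
    `unpad`, `pad` and `gen_unpad_mask` stand in for the ascendspeed ops
    (UnpadSeqLen, PadSeqLen, UnpadGenAttentionMask); `layers` are the
    ParallelTransformerLayer modules, called with seq_lengths as in the patch."""
    seq_len = hidden_states.size(0)                                   # [s, b, h]
    if use_unpad:
        hidden_states = hidden_states.transpose(0, 1).contiguous()    # [b, s, h]
        hidden_states = unpad(hidden_states, seq_lengths)             # [bsq, h]
        attention_mask = gen_unpad_mask(attention_mask, seq_lengths)
    for layer in layers:
        hidden_states = layer(hidden_states, seq_lengths, attention_mask)
    if use_unpad:
        hidden_states = pad(hidden_states, seq_lengths, seq_len)      # [b, s, h]
        hidden_states = hidden_states.transpose(0, 1).contiguous()    # [s, b, h]
    return hidden_states
```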
diff --git a/modellink/model/transformer.py b/modellink/model/transformer.py
index f81e447a670672f58fe3e05dd06f2ec6153a095a..41d6a22e34084c02d0c37208380d2ff7d024071b 100644
--- a/modellink/model/transformer.py
+++ b/modellink/model/transformer.py
@@ -116,8 +116,23 @@ def parallel_transformer_init(self, config,
                               pre_process=True, post_process=True,
                               drop_path_rate=0.0):
+    super(ParallelTransformer, self).__init__()
     args = get_args()

+    self.use_unpad = args.use_unpad
+
+    if self.use_unpad:
+        from ascendspeed.ops.unpad_gen_attention_mask import UnpadGenAttentionMask
+        from ascendspeed.ops.unpad_seqlen import UnpadSeqLen
+        from ascendspeed.ops.pad_seqlen import PadSeqLen
+
+        self.pad = PadSeqLen(args.seq_length)
+        self.unpad = UnpadSeqLen(args.seq_length)
+
+        world_size = mpu.get_tensor_model_parallel_world_size()
+        self.num_attention_heads_per_partition = core.utils.divide(
+            config.num_attention_heads, world_size)
+        self.genAttentionMask = UnpadGenAttentionMask(self.num_attention_heads_per_partition)

     self.layer_type = layer_type
     self.model_type = model_type
diff --git a/tests/pipeline/llama2-13B/tune_llama2_13b_ptd.sh b/tests/pipeline/llama2-13B/tune_llama2_13b_ptd.sh
new file mode 100644
index 0000000000000000000000000000000000000000..d96a405a28d8df42d0b0cc2dda59d72c9c1658af
--- /dev/null
+++ b/tests/pipeline/llama2-13B/tune_llama2_13b_ptd.sh
@@ -0,0 +1,94 @@
+#!/bin/bash
+
+export HCCL_CONNECT_TIMEOUT=1200
+export COMBINED_ENABLE=1
+export AZUREML_EXPERIMENT_ID=0
+
+export CUDA_DEVICE_MAX_CONNECTIONS=1
+export NPU_ASD_ENABLE=0
+
+GPUS_PER_NODE=8
+MASTER_ADDR=localhost
+MASTER_PORT=6020
+NNODES=1
+NODE_RANK=0
+WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))
+
+CKPT_LOAD_DIR=/home/dataset/llama2-13B-tp8-pp1
+DATA_PATH=/home/dataset/tune-dataset-llama2-13B/alpaca
+TOKENIZER_MODEL=/home/dataset/llama2-13B
+TP=8
+PP=1
+
+DISTRIBUTED_ARGS="
+    --nproc_per_node $GPUS_PER_NODE \
+    --nnodes $NNODES \
+    --node_rank $NODE_RANK \
+    --master_addr $MASTER_ADDR \
+    --master_port $MASTER_PORT
+"
+
+GPT_ARGS="
+    --tensor-model-parallel-size ${TP} \
+    --pipeline-model-parallel-size ${PP} \
+    --sequence-parallel \
+    --num-layers 40 \
+    --hidden-size 5120 \
+    --ffn-hidden-size 13824 \
+    --load ${CKPT_LOAD_DIR} \
+    --num-attention-heads 40 \
+    --tokenizer-type PretrainedFromHF \
+    --tokenizer-name-or-path ${TOKENIZER_MODEL} \
+    --tokenizer-not-use-fast \
+    --seq-length 4096 \
+    --max-position-embeddings 4096 \
+    --micro-batch-size 2 \
+    --global-batch-size 16 \
+    --make-vocab-size-divisible-by 1 \
+    --lr 1e-6 \
+    --train-iters 1000 \
+    --lr-decay-style cosine \
+    --untie-embeddings-and-output-weights \
+    --disable-bias-linear \
+    --attention-dropout 0.0 \
+    --init-method-std 0.01 \
+    --hidden-dropout 0.0 \
+    --position-embedding-type rope \
+    --normalization RMSNorm \
+    --use-fused-rmsnorm \
+    --swiglu \
+    --use-flash-attn \
+    --no-masked-softmax-fusion \
+    --attention-softmax-in-fp32 \
+    --min-lr 1e-8 \
+    --weight-decay 1e-1 \
+    --lr-warmup-fraction 0.01 \
+    --clip-grad 1.0 \
+    --adam-beta1 0.9 \
+    --initial-loss-scale 4096 \
+    --adam-beta2 0.95 \
+    --no-gradient-accumulation-fusion \
+    --no-load-optim \
+    --no-load-rng \
+    --finetune \
+    --is-instruction-dataset \
+    --fp16
+"
+
+DATA_ARGS="
+    --data-path $DATA_PATH \
+    --split 100,0,0
+"
+
+OUTPUT_ARGS="
+    --log-interval 1 \
+    --save-interval 10000 \
+    --eval-interval 1000 \
+    --eval-iters 0 \
+"
+
+torchrun $DISTRIBUTED_ARGS pretrain_gpt.py \
+    $GPT_ARGS \
+    $DATA_ARGS \
+    $OUTPUT_ARGS \
+    --distributed-backend nccl 2>&1 | tee /home/dataset/tune_llama2_13b_ptd_baseline.log
\ No newline at end of file
diff --git a/tests/pipeline/llama2-13B/tune_llama2_unpad_13b_ptd.sh b/tests/pipeline/llama2-13B/tune_llama2_unpad_13b_ptd.sh
new file mode 100644
index 0000000000000000000000000000000000000000..decaf6af2cca8868569051df15ccc589a79e857d
--- /dev/null
+++ b/tests/pipeline/llama2-13B/tune_llama2_unpad_13b_ptd.sh
@@ -0,0 +1,99 @@
+#!/bin/bash
+
+export HCCL_CONNECT_TIMEOUT=1200
+export COMBINED_ENABLE=1
+export AZUREML_EXPERIMENT_ID=0
+
+export CUDA_DEVICE_MAX_CONNECTIONS=1
+export NPU_ASD_ENABLE=0
+
+cd /usr/local/Ascend/atb
+source set_env.sh
+cd -
+
+export WITHOUT_JIT_COMPILE=1
+
+GPUS_PER_NODE=8
+MASTER_ADDR=localhost
+MASTER_PORT=6021
+NNODES=1
+NODE_RANK=0
+WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))
+
+CKPT_LOAD_DIR=/home/dataset/llama2-13B-tp8-pp1
+DATA_PATH=/home/dataset/tune-dataset-llama2-13B/alpaca
+TOKENIZER_MODEL=/home/dataset/llama2-13B
+TP=8
+PP=1
+
+DISTRIBUTED_ARGS="
+    --nproc_per_node $GPUS_PER_NODE \
+    --nnodes $NNODES \
+    --node_rank $NODE_RANK \
+    --master_addr $MASTER_ADDR \
+    --master_port $MASTER_PORT
+"
+
+GPT_ARGS="
+    --tensor-model-parallel-size ${TP} \
+    --pipeline-model-parallel-size ${PP} \
+    --num-layers 40 \
+    --hidden-size 5120 \
+    --ffn-hidden-size 13824 \
+    --load ${CKPT_LOAD_DIR} \
+    --num-attention-heads 40 \
+    --tokenizer-type PretrainedFromHF \
+    --tokenizer-name-or-path ${TOKENIZER_MODEL} \
+    --tokenizer-not-use-fast \
+    --seq-length 4096 \
+    --max-position-embeddings 4096 \
+    --micro-batch-size 2 \
+    --global-batch-size 16 \
+    --make-vocab-size-divisible-by 1 \
+    --lr 1e-6 \
+    --train-iters 1000 \
+    --lr-decay-style cosine \
+    --untie-embeddings-and-output-weights \
+    --disable-bias-linear \
+    --attention-dropout 0.0 \
+    --init-method-std 0.01 \
+    --hidden-dropout 0.0 \
+    --position-embedding-type rope \
+    --normalization RMSNorm \
+    --use-fused-rmsnorm \
+    --swiglu \
+    --no-masked-softmax-fusion \
+    --attention-softmax-in-fp32 \
+    --min-lr 1e-8 \
+    --weight-decay 1e-1 \
+    --lr-warmup-fraction 0.01 \
+    --clip-grad 1.0 \
+    --adam-beta1 0.9 \
+    --initial-loss-scale 4096 \
+    --adam-beta2 0.95 \
+    --no-gradient-accumulation-fusion \
+    --no-load-optim \
+    --no-load-rng \
+    --finetune \
+    --is-instruction-dataset \
+    --fp16
+"
+
+DATA_ARGS="
+    --data-path $DATA_PATH \
+    --split 100,0,0
+"
+
+OUTPUT_ARGS="
+    --log-interval 1 \
+    --save-interval 10000 \
+    --eval-interval 1000 \
+    --eval-iters 0 \
+"
+
+torchrun $DISTRIBUTED_ARGS pretrain_gpt.py \
+    $GPT_ARGS \
+    $DATA_ARGS \
+    $OUTPUT_ARGS \
+    --use-unpad \
+    --distributed-backend nccl 2>&1 | tee /home/dataset/tune_llama2_13b_ptd_unpad.log
\ No newline at end of file
+GPT_ARGS=" + --tensor-model-parallel-size ${TP} \ + --pipeline-model-parallel-size ${PP} \ + --sequence-parallel \ + --num-layers 32 \ + --hidden-size 4096 \ + --ffn-hidden-size 11008 \ + --load ${CKPT_LOAD_DIR} \ + --num-attention-heads 32 \ + --tokenizer-type PretrainedFromHF \ + --tokenizer-name-or-path ${TOKENIZER_MODEL} \ + --tokenizer-not-use-fast \ + --seq-length 4096 \ + --max-position-embeddings 4096 \ + --micro-batch-size 4 \ + --global-batch-size 16 \ + --make-vocab-size-divisible-by 1 \ + --lr 1.25e-6 \ + --train-iters 1000 \ + --lr-decay-style cosine \ + --untie-embeddings-and-output-weights \ + --disable-bias-linear \ + --attention-dropout 0.0 \ + --init-method-std 0.01 \ + --hidden-dropout 0.0 \ + --position-embedding-type rope \ + --normalization RMSNorm \ + --use-fused-rmsnorm \ + --swiglu \ + --use-flash-attn \ + --no-masked-softmax-fusion \ + --attention-softmax-in-fp32 \ + --min-lr 1.25e-7 \ + --weight-decay 1e-1 \ + --lr-warmup-fraction 0.01 \ + --clip-grad 1.0 \ + --adam-beta1 0.9 \ + --initial-loss-scale 65536 \ + --adam-beta2 0.95 \ + --no-gradient-accumulation-fusion \ + --no-load-optim \ + --no-load-rng \ + --finetune \ + --is-instruction-dataset \ + --fp16 +" + +DATA_ARGS=" + --data-path $DATA_PATH \ + --split 100,0,0 +" + +OUTPUT_ARGS=" + --log-interval 1 \ + --save-interval 10000 \ + --eval-interval 1000 \ + --eval-iters 0 \ +" + +torchrun $DISTRIBUTED_ARGS pretrain_gpt.py \ + $GPT_ARGS \ + $DATA_ARGS \ + $OUTPUT_ARGS \ + --distributed-backend nccl 2>&1 | tee /home/dataset/tune_llama2_7b_ptd_baseline.log \ No newline at end of file diff --git a/tests/pipeline/llama2-7B/tune_llama2_unpad_7b_ptd.sh b/tests/pipeline/llama2-7B/tune_llama2_unpad_7b_ptd.sh new file mode 100644 index 0000000000000000000000000000000000000000..c83c159345d71c0c11032a5a06e301ba28d00830 --- /dev/null +++ b/tests/pipeline/llama2-7B/tune_llama2_unpad_7b_ptd.sh @@ -0,0 +1,95 @@ +#!/bin/bash + +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export NPU_ASD_ENABLE=0 + +cd /usr/local/Ascend/atb +source set_env.sh +cd - + +export WITHOUT_JIT_COMPILE=1 + +GPUS_PER_NODE=8 +MASTER_ADDR=localhost +MASTER_PORT=6019 +NNODES=1 +NODE_RANK=0 +WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) + +CKPT_LOAD_DIR=/home/dataset/llama2-7B-tp8-pp1 +DATA_PATH=/home/dataset/tune-dataset-llama2-7B/alpaca +TOKENIZER_MODEL=/home/dataset/llama2-7B +TP=8 +PP=1 + +DISTRIBUTED_ARGS=" + --nproc_per_node $GPUS_PER_NODE \ + --nnodes $NNODES \ + --node_rank $NODE_RANK \ + --master_addr $MASTER_ADDR \ + --master_port $MASTER_PORT +" + +GPT_ARGS=" + --tensor-model-parallel-size ${TP} \ + --pipeline-model-parallel-size ${PP} \ + --num-layers 32 \ + --hidden-size 4096 \ + --ffn-hidden-size 11008 \ + --load ${CKPT_LOAD_DIR} \ + --num-attention-heads 32 \ + --tokenizer-type PretrainedFromHF \ + --tokenizer-name-or-path ${TOKENIZER_MODEL} \ + --tokenizer-not-use-fast \ + --seq-length 4096 \ + --max-position-embeddings 4096 \ + --micro-batch-size 4 \ + --global-batch-size 16 \ + --make-vocab-size-divisible-by 1 \ + --lr 1.25e-6 \ + --train-iters 1000 \ + --lr-decay-style cosine \ + --untie-embeddings-and-output-weights \ + --disable-bias-linear \ + --attention-dropout 0.0 \ + --init-method-std 0.01 \ + --hidden-dropout 0.0 \ + --position-embedding-type rope \ + --normalization RMSNorm \ + --use-fused-rmsnorm \ + --swiglu \ + --no-masked-softmax-fusion \ + --attention-softmax-in-fp32 \ + --min-lr 1.25e-7 \ + --weight-decay 1e-1 \ + --lr-warmup-fraction 0.01 \ + --clip-grad 1.0 \ + --adam-beta1 0.9 \ + --initial-loss-scale 65536 \ + --adam-beta2 0.95 \ + 
diff --git a/tests/pipeline/llama2-7B/tune_llama2_unpad_7b_ptd.sh b/tests/pipeline/llama2-7B/tune_llama2_unpad_7b_ptd.sh
new file mode 100644
index 0000000000000000000000000000000000000000..c83c159345d71c0c11032a5a06e301ba28d00830
--- /dev/null
+++ b/tests/pipeline/llama2-7B/tune_llama2_unpad_7b_ptd.sh
@@ -0,0 +1,95 @@
+#!/bin/bash
+
+export CUDA_DEVICE_MAX_CONNECTIONS=1
+export NPU_ASD_ENABLE=0
+
+cd /usr/local/Ascend/atb
+source set_env.sh
+cd -
+
+export WITHOUT_JIT_COMPILE=1
+
+GPUS_PER_NODE=8
+MASTER_ADDR=localhost
+MASTER_PORT=6019
+NNODES=1
+NODE_RANK=0
+WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))
+
+CKPT_LOAD_DIR=/home/dataset/llama2-7B-tp8-pp1
+DATA_PATH=/home/dataset/tune-dataset-llama2-7B/alpaca
+TOKENIZER_MODEL=/home/dataset/llama2-7B
+TP=8
+PP=1
+
+DISTRIBUTED_ARGS="
+    --nproc_per_node $GPUS_PER_NODE \
+    --nnodes $NNODES \
+    --node_rank $NODE_RANK \
+    --master_addr $MASTER_ADDR \
+    --master_port $MASTER_PORT
+"
+
+GPT_ARGS="
+    --tensor-model-parallel-size ${TP} \
+    --pipeline-model-parallel-size ${PP} \
+    --num-layers 32 \
+    --hidden-size 4096 \
+    --ffn-hidden-size 11008 \
+    --load ${CKPT_LOAD_DIR} \
+    --num-attention-heads 32 \
+    --tokenizer-type PretrainedFromHF \
+    --tokenizer-name-or-path ${TOKENIZER_MODEL} \
+    --tokenizer-not-use-fast \
+    --seq-length 4096 \
+    --max-position-embeddings 4096 \
+    --micro-batch-size 4 \
+    --global-batch-size 16 \
+    --make-vocab-size-divisible-by 1 \
+    --lr 1.25e-6 \
+    --train-iters 1000 \
+    --lr-decay-style cosine \
+    --untie-embeddings-and-output-weights \
+    --disable-bias-linear \
+    --attention-dropout 0.0 \
+    --init-method-std 0.01 \
+    --hidden-dropout 0.0 \
+    --position-embedding-type rope \
+    --normalization RMSNorm \
+    --use-fused-rmsnorm \
+    --swiglu \
+    --no-masked-softmax-fusion \
+    --attention-softmax-in-fp32 \
+    --min-lr 1.25e-7 \
+    --weight-decay 1e-1 \
+    --lr-warmup-fraction 0.01 \
+    --clip-grad 1.0 \
+    --adam-beta1 0.9 \
+    --initial-loss-scale 65536 \
+    --adam-beta2 0.95 \
+    --no-gradient-accumulation-fusion \
+    --no-load-optim \
+    --no-load-rng \
+    --finetune \
+    --is-instruction-dataset \
+    --fp16
+"
+
+DATA_ARGS="
+    --data-path $DATA_PATH \
+    --split 100,0,0
+"
+
+OUTPUT_ARGS="
+    --log-interval 1 \
+    --save-interval 10000 \
+    --eval-interval 1000 \
+    --eval-iters 0 \
+"
+
+torchrun $DISTRIBUTED_ARGS pretrain_gpt.py \
+    $GPT_ARGS \
+    $DATA_ARGS \
+    $OUTPUT_ARGS \
+    --use-unpad \
+    --distributed-backend nccl 2>&1 | tee /home/dataset/tune_llama2_7b_ptd_unpad.log
\ No newline at end of file