diff --git a/amp_C.py b/amp_C.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/examples/gpt_task/pretrain_gpt_1p.sh b/examples/gpt_task/pretrain_gpt_1p.sh new file mode 100644 index 0000000000000000000000000000000000000000..5b77e4b1b932ea2e23ca3a21b5da958f44e61626 --- /dev/null +++ b/examples/gpt_task/pretrain_gpt_1p.sh @@ -0,0 +1,43 @@ +export LD_LIBRARY_PATH=/root/miniconda3/lib:$LD_LIBRARY_PATH +export HCCL_CONNECT_TIMEOUT=1200 +source /usr/local/Ascend/ascend-toolkit/set_env.sh +RANK=0 +WORLD_SIZE=1 + +DATA_PATH=./dataset/enwiki-gpt/gpt_text_sentence +CHECKPOINT_PATH=./ckpt + +export LOCAL_RANK=0 + +python pretrain_gpt.py \ + --DDP-impl local \ + --use-contiguous-buffers-in-ddp \ + --num-layers 1 \ + --hidden-size 4096 \ + --num-attention-heads 32 \ + --micro-batch-size 4 \ + --global-batch-size 8 \ + --seq-length 2048 \ + --max-position-embeddings 2048 \ + --train-iters 500000 \ + --lr-decay-iters 320000 \ + --save $CHECKPOINT_PATH \ + --load $CHECKPOINT_PATH \ + --data-path $DATA_PATH \ + --vocab-file ./dataset/gpt2-vocab.json \ + --merge-file ./dataset/gpt2-merges.txt \ + --data-impl mmap \ + --split 949,50,1 \ + --distributed-backend nccl \ + --lr 0.00015 \ + --min-lr 1.0e-5 \ + --lr-decay-style cosine \ + --weight-decay 1e-2 \ + --clip-grad 1.0 \ + --lr-warmup-fraction .01 \ + --checkpoint-activations \ + --log-interval 10 \ + --save-interval 10000 \ + --eval-interval 1000 \ + --eval-iters 10 \ + --fp16 | tee logs/train.log diff --git a/examples/gpt_task/pretrain_gpt_ptd_8p.sh b/examples/gpt_task/pretrain_gpt_ptd_8p.sh new file mode 100644 index 0000000000000000000000000000000000000000..4ffa16900b2599b3c2a8bd72f3c823b66eef647f --- /dev/null +++ b/examples/gpt_task/pretrain_gpt_ptd_8p.sh @@ -0,0 +1,51 @@ +export LD_LIBRARY_PATH=/root/miniconda3/lib:$LD_LIBRARY_PATH +export HCCL_CONNECT_TIMEOUT=1200 +source /usr/local/Ascend/ascend-toolkit/set_env.sh +GPUS_PER_NODE=8 +# Change for multinode config +MASTER_ADDR=localhost +MASTER_PORT=6000 +NNODES=1 +NODE_RANK=0 +WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) + +DATA_PATH=./dataset/enwiki-gpt/gpt_text_sentence +CHECKPOINT_PATH=./ckpt + +DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT" + +python -m torch.distributed.launch $DISTRIBUTED_ARGS \ + pretrain_gpt.py \ + --DDP-impl local \ + --use-contiguous-buffers-in-ddp \ + --tensor-model-parallel-size 2 \ + --pipeline-model-parallel-size 2 \ + --num-layers 8 \ + --hidden-size 4096 \ + --num-attention-heads 32 \ + --micro-batch-size 2 \ + --global-batch-size 16 \ + --seq-length 2048 \ + --max-position-embeddings 2048 \ + --train-iters 500000 \ + --lr-decay-iters 320000 \ + --save $CHECKPOINT_PATH \ + --load $CHECKPOINT_PATH \ + --data-path $DATA_PATH \ + --vocab-file ./dataset/gpt2-vocab.json \ + --merge-file ./dataset/gpt2-merges.txt \ + --data-impl mmap \ + --split 949,50,1 \ + --distributed-backend nccl \ + --lr 0.00015 \ + --lr-decay-style cosine \ + --min-lr 1.0e-5 \ + --weight-decay 1e-2 \ + --clip-grad 1.0 \ + --lr-warmup-fraction .01 \ + --checkpoint-activations \ + --log-interval 10 \ + --save-interval 10000 \ + --eval-interval 1000 \ + --eval-iters 10 \ + --fp16 | tee logs/train.log diff --git a/examples/gpt_task/pretrain_gpt_td_8p.sh b/examples/gpt_task/pretrain_gpt_td_8p.sh new file mode 100644 index 0000000000000000000000000000000000000000..522f1069dff97438cf8bd5cdb2ad76125ed1e8c4 --- /dev/null +++ 
b/examples/gpt_task/pretrain_gpt_td_8p.sh @@ -0,0 +1,50 @@ +export LD_LIBRARY_PATH=/root/miniconda3/lib:$LD_LIBRARY_PATH +export HCCL_CONNECT_TIMEOUT=1200 +source /usr/local/Ascend/ascend-toolkit/set_env.sh +GPUS_PER_NODE=8 +# Change for multinode config +MASTER_ADDR=localhost +MASTER_PORT=6000 +NNODES=1 +NODE_RANK=0 +WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) + +DATA_PATH=./dataset/enwiki-gpt/gpt_text_sentence +CHECKPOINT_PATH=./ckpt + +DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT" + +python -m torch.distributed.launch $DISTRIBUTED_ARGS \ + pretrain_gpt.py \ + --DDP-impl local \ + --use-contiguous-buffers-in-ddp \ + --tensor-model-parallel-size 4 \ + --num-layers 8 \ + --hidden-size 4096 \ + --num-attention-heads 32 \ + --micro-batch-size 2 \ + --global-batch-size 16 \ + --seq-length 2048 \ + --max-position-embeddings 2048 \ + --train-iters 500000 \ + --lr-decay-iters 320000 \ + --save $CHECKPOINT_PATH \ + --load $CHECKPOINT_PATH \ + --data-path $DATA_PATH \ + --vocab-file ./dataset/gpt2-vocab.json \ + --merge-file ./dataset/gpt2-merges.txt \ + --data-impl mmap \ + --split 949,50,1 \ + --distributed-backend nccl \ + --lr 0.00015 \ + --lr-decay-style cosine \ + --min-lr 1.0e-5 \ + --weight-decay 1e-2 \ + --clip-grad 1.0 \ + --lr-warmup-fraction .01 \ + --checkpoint-activations \ + --log-interval 10 \ + --save-interval 10000 \ + --eval-interval 1000 \ + --eval-iters 10 \ + --fp16 | tee logs/train.log diff --git a/examples/llama_task/pretrain_llama_1p.sh b/examples/llama_task/pretrain_llama_1p.sh new file mode 100644 index 0000000000000000000000000000000000000000..7bfc8a6afa0655a39509e49c25245a14e3fe2c69 --- /dev/null +++ b/examples/llama_task/pretrain_llama_1p.sh @@ -0,0 +1,43 @@ +export LD_LIBRARY_PATH=/root/miniconda3/lib:$LD_LIBRARY_PATH +export HCCL_CONNECT_TIMEOUT=1200 +source /usr/local/Ascend/ascend-toolkit/set_env.sh +RANK=0 +WORLD_SIZE=1 + +DATA_PATH=./dataset/enwiki-gpt/gpt_text_sentence +CHECKPOINT_PATH=./ckpt + +export LOCAL_RANK=0 + +python pretrain_llama.py \ + --DDP-impl local \ + --use-contiguous-buffers-in-ddp \ + --num-layers 1 \ + --hidden-size 4096 \ + --num-attention-heads 32 \ + --micro-batch-size 4 \ + --global-batch-size 8 \ + --seq-length 2048 \ + --max-position-embeddings 2048 \ + --train-iters 500000 \ + --lr-decay-iters 320000 \ + --save $CHECKPOINT_PATH \ + --load $CHECKPOINT_PATH \ + --data-path $DATA_PATH \ + --vocab-file ./dataset/gpt2-vocab.json \ + --merge-file ./dataset/gpt2-merges.txt \ + --data-impl mmap \ + --split 949,50,1 \ + --distributed-backend nccl \ + --lr 0.00015 \ + --min-lr 1.0e-5 \ + --lr-decay-style cosine \ + --weight-decay 1e-2 \ + --clip-grad 1.0 \ + --lr-warmup-fraction .01 \ + --checkpoint-activations \ + --log-interval 10 \ + --save-interval 10000 \ + --eval-interval 1000 \ + --eval-iters 10 \ + --fp16 | tee logs/train.log diff --git a/examples/llama_task/pretrain_llama_ptd_8p.sh b/examples/llama_task/pretrain_llama_ptd_8p.sh new file mode 100644 index 0000000000000000000000000000000000000000..a7b691abcaa9f342e16140bc3de1583b61499b7e --- /dev/null +++ b/examples/llama_task/pretrain_llama_ptd_8p.sh @@ -0,0 +1,51 @@ +export LD_LIBRARY_PATH=/root/miniconda3/lib:$LD_LIBRARY_PATH +export HCCL_CONNECT_TIMEOUT=1200 +source /usr/local/Ascend/ascend-toolkit/set_env.sh +GPUS_PER_NODE=8 +# Change for multinode config +MASTER_ADDR=localhost +MASTER_PORT=6000 +NNODES=1 +NODE_RANK=0 +WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) + 
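A note on the *_ptd_8p launch scripts: WORLD_SIZE here is GPUS_PER_NODE * NNODES, the tensor- and pipeline-parallel sizes set below must divide it (the remainder becomes data parallelism), and the global batch is then consumed in micro-batches per data-parallel replica. A minimal sanity-check sketch of that relationship (illustrative only; this helper is not part of the patch):

    def check_parallel_config(world_size, tp, pp, micro_bs, global_bs):
        # TP * PP ranks form one model replica; the remaining factor is data parallelism.
        assert world_size % (tp * pp) == 0, "TP * PP must divide WORLD_SIZE"
        dp = world_size // (tp * pp)
        # Each optimizer step accumulates global_bs / (micro_bs * dp) micro-batches.
        assert global_bs % (micro_bs * dp) == 0, "global batch must be a multiple of micro_bs * DP"
        return dp, global_bs // (micro_bs * dp)

    # pretrain_*_ptd_8p.sh: 8 ranks, TP=2, PP=2, micro-batch 2, global batch 16
    print(check_parallel_config(8, 2, 2, 2, 16))   # -> (2, 4): DP=2, 4 accumulation steps
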
+DATA_PATH=./dataset/enwiki-gpt/gpt_text_sentence +CHECKPOINT_PATH=./ckpt + +DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT" + +python -m torch.distributed.launch $DISTRIBUTED_ARGS \ + pretrain_llama.py \ + --DDP-impl local \ + --use-contiguous-buffers-in-ddp \ + --tensor-model-parallel-size 2 \ + --pipeline-model-parallel-size 2 \ + --num-layers 8 \ + --hidden-size 4096 \ + --num-attention-heads 32 \ + --micro-batch-size 2 \ + --global-batch-size 16 \ + --seq-length 2048 \ + --max-position-embeddings 2048 \ + --train-iters 500000 \ + --lr-decay-iters 320000 \ + --save $CHECKPOINT_PATH \ + --load $CHECKPOINT_PATH \ + --data-path $DATA_PATH \ + --vocab-file ./dataset/gpt2-vocab.json \ + --merge-file ./dataset/gpt2-merges.txt \ + --data-impl mmap \ + --split 949,50,1 \ + --distributed-backend nccl \ + --lr 0.00015 \ + --lr-decay-style cosine \ + --min-lr 1.0e-5 \ + --weight-decay 1e-2 \ + --clip-grad 1.0 \ + --lr-warmup-fraction .01 \ + --checkpoint-activations \ + --log-interval 10 \ + --save-interval 10000 \ + --eval-interval 1000 \ + --eval-iters 10 \ + --fp16 | tee logs/train.log diff --git a/examples/llama_task/pretrain_llama_td_8p.sh b/examples/llama_task/pretrain_llama_td_8p.sh new file mode 100644 index 0000000000000000000000000000000000000000..861d501fab8675cdb8c5d453e654c048c83dabc8 --- /dev/null +++ b/examples/llama_task/pretrain_llama_td_8p.sh @@ -0,0 +1,50 @@ +export LD_LIBRARY_PATH=/root/miniconda3/lib:$LD_LIBRARY_PATH +export HCCL_CONNECT_TIMEOUT=1200 +source /usr/local/Ascend/ascend-toolkit/set_env.sh +GPUS_PER_NODE=8 +# Change for multinode config +MASTER_ADDR=localhost +MASTER_PORT=6000 +NNODES=1 +NODE_RANK=0 +WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) + +DATA_PATH=./dataset/enwiki-gpt/gpt_text_sentence +CHECKPOINT_PATH=./ckpt + +DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT" + +python -m torch.distributed.launch $DISTRIBUTED_ARGS \ + pretrain_llama.py \ + --DDP-impl local \ + --use-contiguous-buffers-in-ddp \ + --tensor-model-parallel-size 4 \ + --num-layers 8 \ + --hidden-size 4096 \ + --num-attention-heads 32 \ + --micro-batch-size 2 \ + --global-batch-size 16 \ + --seq-length 2048 \ + --max-position-embeddings 2048 \ + --train-iters 500000 \ + --lr-decay-iters 320000 \ + --save $CHECKPOINT_PATH \ + --load $CHECKPOINT_PATH \ + --data-path $DATA_PATH \ + --vocab-file ./dataset/gpt2-vocab.json \ + --merge-file ./dataset/gpt2-merges.txt \ + --data-impl mmap \ + --split 949,50,1 \ + --distributed-backend nccl \ + --lr 0.00015 \ + --lr-decay-style cosine \ + --min-lr 1.0e-5 \ + --weight-decay 1e-2 \ + --clip-grad 1.0 \ + --lr-warmup-fraction .01 \ + --checkpoint-activations \ + --log-interval 10 \ + --save-interval 10000 \ + --eval-interval 1000 \ + --eval-iters 10 \ + --fp16 | tee logs/train.log diff --git a/megatron/__init__.py b/megatron/__init__.py index 93894cd3f3364724e830e52fbf1f7d58fbdcbc72..a7160413128449a23f6bbf818bc196a76948bca0 100644 --- a/megatron/__init__.py +++ b/megatron/__init__.py @@ -13,7 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import os +import math +import copy +import apex +from functools import wraps import torch +from torch import distributed as dist from deepspeed.accelerator import get_accelerator from .package_info import ( __description__, @@ -36,6 +41,154 @@ from .global_vars import get_adlr_autoresume from .global_vars import get_timers from .initialize import initialize_megatron +def wrapper_type(fn): + @wraps(fn) + def decorated(*args, **kwargs): + output = fn(*args, **kwargs) + if isinstance(output, str): + if output == 'torch.npu.FloatTensor': + output = 'torch.cuda.FloatTensor' + elif output == 'torch.npu.HalfTensor': + output = 'torch.cuda.HalfTensor' + return output + + return decorated + +torch.Tensor.type = wrapper_type(torch.Tensor.type) # fix for torch.Tensor.type() + +def wrapper_dist_long2int(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + if args[0].dtype == torch.long and not kwargs.get('async_op', False): + new_args = list(copy.deepcopy(args)) + new_args[0] = new_args[0].int() + fn(*new_args, **kwargs) + args[0].copy_(new_args[0].long()) + return + return fn(*args, **kwargs) + + return wrapper + +dist.all_reduce = wrapper_dist_long2int(dist.all_reduce) +dist.broadcast = wrapper_dist_long2int(dist.broadcast) +dist.send = wrapper_dist_long2int(dist.send) +dist.recv = wrapper_dist_long2int(dist.recv) # fix for torch.distributed + +class AdamW(torch.optim.Optimizer): + r"""Implements AdamW algorithm. + + The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. + The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + (default: False) + + .. _Adam\: A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, + weight_decay=1e-2, amsgrad=False): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + defaults = dict(lr=lr, betas=betas, eps=eps, + weight_decay=weight_decay, amsgrad=amsgrad) + super(AdamW, self).__init__(params, defaults) + + def __setstate__(self, state): + super(AdamW, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('amsgrad', False) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. 
+ """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + + # Perform stepweight decay + p.data.mul_(1 - group['lr'] * group['weight_decay']) + + # Perform optimization step + grad = p.grad + if grad.is_sparse: + raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') + amsgrad = group['amsgrad'] + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state['max_exp_avg_sq'] = torch.zeros_like(p) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + if amsgrad: + max_exp_avg_sq = state['max_exp_avg_sq'] + beta1, beta2 = group['betas'] + + state['step'] += 1 + bias_correction1 = 1 - beta1 ** state['step'] + bias_correction2 = 1 - beta2 ** state['step'] + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + if amsgrad: + # Maintains the maximum of all 2nd moment running avg. till now + torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) + # Use the max. for normalizing running avg. of gradient + denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) + else: + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) + + step_size = group['lr'] / bias_correction1 + + p.addcdiv_(exp_avg, denom, value=-step_size) + + return loss + +apex.optimizers.FusedAdam = AdamW +apex.optimizers.FusedSGD = torch.optim.SGD + def print_rank_0(message): """If distributed is initialized, print only on rank 0.""" if torch.distributed.is_initialized(): diff --git a/megatron/initialize.py b/megatron/initialize.py index 4b7e11a6c35a815053954a1c048647a75eac91f4..baf3669a26a4d7e2e3616baaa906918d7aae4792 100644 --- a/megatron/initialize.py +++ b/megatron/initialize.py @@ -21,7 +21,7 @@ import time import numpy as np import torch -from megatron import fused_kernels + from megatron import get_adlr_autoresume from megatron import get_args from megatron import get_tensorboard_writer @@ -31,7 +31,6 @@ from megatron.mpu import (set_tensor_model_parallel_rank, set_tensor_model_parallel_world_size) from deepspeed.accelerator import get_accelerator import deepspeed -import deepspeed.utils.groups as groups def initialize_megatron(extra_args_provider=None, args_defaults={}, ignore_unknown_args=False, allow_no_cuda=False): @@ -92,66 +91,13 @@ def initialize_megatron(extra_args_provider=None, args_defaults={}, def _compile_dependencies(): - - args = get_args() - - # ========================= - # Compile dataset C++ code. - # ========================= - # TODO: move this to ninja - if _is_rank_0(): + if torch.distributed.get_rank() == 0: start_time = time.time() print('> compiling dataset index builder ...') from megatron.data.dataset_utils import compile_helper compile_helper() print('>>> done with dataset index builder. 
Compilation time: {:.3f} ' 'seconds'.format(time.time() - start_time), flush=True) - - if not get_accelerator().device_name() == 'cuda': - print(">fused kernel is only supported in cuda, skip loading fused kernel") - return - # ================== - # Load fused kernels - # ================== - - # Custom kernel constraints check. - seq_len = args.seq_length - attn_batch_size = \ - (args.num_attention_heads / args.tensor_model_parallel_size) * \ - args.micro_batch_size - # Constraints on sequence length and attn_batch_size to enable warp based - # optimization and upper triangular optimization (for causal mask) - custom_kernel_constraint = seq_len > 16 and seq_len <=2048 and \ - seq_len % 4 == 0 and attn_batch_size % 4 == 0 - # Print a warning. - if not ((args.fp16 or args.bf16) and - custom_kernel_constraint and - args.masked_softmax_fusion): - if args.rank == 0: - print('WARNING: constraints for invoking optimized' - ' fused softmax kernel are not met. We default' - ' back to unfused kernel invocations.', flush=True) - - # Always build on rank zero first. - if _is_rank_0(): - start_time = time.time() - print('> compiling and loading fused kernels ...', flush=True) - if get_accelerator().device_count() > 0: # Skip when CPU-only - fused_kernels.load(args) - torch.distributed.barrier() - else: - torch.distributed.barrier() - fused_kernels.load(args) - # Simple barrier to make sure all ranks have passed the - # compilation phase successfully before moving on to the - # rest of the program. We think this might ensure that - # the lock is released. - torch.distributed.barrier() - if _is_rank_0(): - print('>>> done with compiling and loading fused kernels. ' - 'Compilation time: {:.3f} seconds'.format( - time.time() - start_time), flush=True) - def setup_deepspeed_random_and_activation_checkpointing(args): '''Optional DeepSpeed Activation Checkpointing features. diff --git a/megatron/model/__init__.py b/megatron/model/__init__.py index 917b0f34127eec3f0c49f0cbdc290fa648877ed1..40967f41717592406cbf5e7fd0b659f2e08319f5 100644 --- a/megatron/model/__init__.py +++ b/megatron/model/__init__.py @@ -23,4 +23,4 @@ from .gpt_model import GPTModel, GPTModelPipe from .llama_model import LlamaModel, LlamaModelPipe from .t5_model import T5Model from .language_model import get_language_model -from .module import Float16Module \ No newline at end of file +from .module import Float16Module diff --git a/megatron/model/distributed.py b/megatron/model/distributed.py index ad86345c7ea0c4a9f5126ca8594a4b4fe894136f..320261fee5defba1c1f8aa1447fbcf96d91d8f47 100644 --- a/megatron/model/distributed.py +++ b/megatron/model/distributed.py @@ -188,6 +188,7 @@ class DistributedDataParallel(DistributedDataParallelBase): def allreduce_gradients(self): """Reduce gradients across data parallel ranks.""" # If we have buffers, simply reduce the data in the buffer. 
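The wrapper_dist_long2int patch applied to torch.distributed in megatron/__init__.py above round-trips synchronous int64 collectives through int32, presumably because the NPU backend does not handle int64 directly. A standalone sketch of the same idea (assumes the values being reduced fit in int32):

    import torch
    import torch.distributed as dist

    def all_reduce_long_safe(tensor, **kwargs):
        # Synchronous int64 collectives are downcast, reduced, and copied back in place.
        if tensor.dtype == torch.long and not kwargs.get('async_op', False):
            tmp = tensor.int()
            dist.all_reduce(tmp, **kwargs)
            tensor.copy_(tmp.long())
            return
        return dist.all_reduce(tensor, **kwargs)
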
+ if self._grad_buffers is not None: for _, buffer_ in self._grad_buffers.items(): buffer_.data /= mpu.get_data_parallel_world_size() diff --git a/megatron/model/fused_layer_norm.py b/megatron/model/fused_layer_norm.py index f34a05ffebca6cb5b01c60043ae4377c8e2edc44..1583f97ea7cb0531e91dfdc28feaa58f148af2c2 100644 --- a/megatron/model/fused_layer_norm.py +++ b/megatron/model/fused_layer_norm.py @@ -65,10 +65,6 @@ class MixedFusedLayerNorm(torch.nn.Module): def __init__(self, normalized_shape, eps=1e-5): super(MixedFusedLayerNorm, self).__init__() - global fused_mix_prec_layer_norm_cuda - fused_mix_prec_layer_norm_cuda = importlib.import_module( - "fused_mix_prec_layer_norm_cuda") - if isinstance(normalized_shape, numbers.Integral): normalized_shape = (normalized_shape,) self.normalized_shape = torch.Size(normalized_shape) @@ -79,17 +75,9 @@ class MixedFusedLayerNorm(torch.nn.Module): def reset_parameters(self): - init.ones_(self.weight) init.zeros_(self.bias) - def forward(self, input): - # CPU path is here for unittest sake. - if not input.is_cuda: - print("WARNING! The input of FusedLayerNorm should be on the GPU." - "This warning should only be triggered in the FusedLayerNorm unit tests.") - return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps) - return FusedLayerNormAffineFunction.apply( - input, self.weight, self.bias, self.normalized_shape,self.eps) + return torch.nn.functional.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps) diff --git a/megatron/model/fused_softmax.py b/megatron/model/fused_softmax.py index 8d52967d95be2605a29f8ea208844b5fce1fe000..8a4ddf08e211c6facbe953ecf605625ba13cd99b 100644 --- a/megatron/model/fused_softmax.py +++ b/megatron/model/fused_softmax.py @@ -14,150 +14,86 @@ # limitations under the License. import torch +import torch_npu from megatron.model.enums import AttnMaskType -class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): - """ - Fused operation which performs following three operations in sequence - 1. Scale the tensor. - 2. Apply upper triangular mask (typically used in gpt models). - 3. Perform softmax. - """ - - @staticmethod - def forward(ctx, inputs, scale): - import scaled_upper_triang_masked_softmax_cuda - - scale_t = torch.tensor([scale]) - - softmax_results = scaled_upper_triang_masked_softmax_cuda.forward( - inputs, scale_t[0] - ) - ctx.save_for_backward(softmax_results, scale_t) - return softmax_results - - @staticmethod - def backward(ctx, output_grads): - import scaled_upper_triang_masked_softmax_cuda - - softmax_results, scale_t = ctx.saved_tensors - - input_grads = scaled_upper_triang_masked_softmax_cuda.backward( - output_grads, softmax_results, scale_t[0] - ) - return input_grads, None - - -class ScaledMaskedSoftmax(torch.autograd.Function): - """ - Fused operation which performs following three operations in sequence - 1. Scale the tensor. - 2. Apply the mask. - 3. Perform softmax. 
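A quick check for the fused_layer_norm.py change above: MixedFusedLayerNorm now simply delegates to the stock PyTorch functional layer norm, so with unit weight and zero bias its output should match torch.nn.LayerNorm (a small verification sketch, not part of the patch):

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 5, 16)
    weight = torch.ones(16)                       # unit scale
    bias = torch.zeros(16)                        # zero shift
    out = F.layer_norm(x, (16,), weight, bias, 1e-5)
    ref = torch.nn.LayerNorm(16, eps=1e-5)(x)     # default affine init is also ones/zeros
    print(torch.allclose(out, ref, atol=1e-6))    # True
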
- """ - - @staticmethod - def forward(ctx, inputs, mask, scale): - import scaled_masked_softmax_cuda - - scale_t = torch.tensor([scale]) - - softmax_results = scaled_masked_softmax_cuda.forward( - inputs, mask, scale_t[0] - ) - ctx.save_for_backward(softmax_results, scale_t) - return softmax_results - - @staticmethod - def backward(ctx, output_grads): - import scaled_masked_softmax_cuda - - softmax_results, scale_t = ctx.saved_tensors - - input_grads = scaled_masked_softmax_cuda.backward( - output_grads, softmax_results, scale_t[0] - ) - return input_grads, None, None - - class FusedScaleMaskSoftmax(torch.nn.Module): - """ - fused operation: scaling + mask + softmax - Arguments: - input_in_fp16: flag to indicate if input in fp16 data format. - attn_mask_type: attention mask type (pad or causal) - mask_func: mask function to be applied. - softmax_in_fp32: if true, softmax in performed at fp32 precision. - scale: scaling factor used in input tensor scaling. - - """ - def __init__( - self, - input_in_fp16, - input_in_bf16, - attn_mask_type, - scaled_masked_softmax_fusion, - mask_func, - softmax_in_fp32, - scale, + self, + input_in_fp16, + input_in_bf16, + attn_mask_type, + scaled_masked_softmax_fusion, + mask_func, + softmax_in_fp32, + scale, ): super(FusedScaleMaskSoftmax, self).__init__() self.input_in_fp16 = input_in_fp16 self.input_in_bf16 = input_in_bf16 - assert not (self.input_in_fp16 and self.input_in_bf16),\ - 'both fp16 and bf16 flags cannot be active at the same time.' + assert not ( + self.input_in_fp16 and self.input_in_bf16 + ), "both fp16 and bf16 flags cannot be active at the same time." self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16 self.attn_mask_type = attn_mask_type self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion self.mask_func = mask_func self.softmax_in_fp32 = softmax_in_fp32 self.scale = scale + self.mask_tri = None + p = torch.npu.get_device_properties(0) if torch.npu.is_available() else None + self.fused = p.name in ['Ascend910A', 'Ascend910ProB'] if p is not None else False assert ( - self.scale is None or softmax_in_fp32 + self.scale is None or softmax_in_fp32 ), "softmax should be in fp32 when scaled" - + def forward(self, input, mask): # [b, np, sq, sk] assert input.dim() == 4 - data_size = input.size() - query_seq_len = data_size[-2] - key_seq_len = data_size[-1] - attn_batch_size = data_size[0] * data_size[1] - - # constraints on various tensor dimensions to enable warp based - # optimization and upper triangular optimization (for causal mask) - custom_kernel_constraint = key_seq_len > 16 and key_seq_len <= 2048 and \ - query_seq_len % 4 == 0 and attn_batch_size % 4 == 0 - - # invoke custom kernel - if self.input_in_float16 and mask is not None and \ - custom_kernel_constraint and self.scaled_masked_softmax_fusion: - scale = self.scale if self.scale is not None else 1.0 - - if self.attn_mask_type == AttnMaskType.causal: - assert query_seq_len == key_seq_len, \ - "causal mask is only for self attention" - input = input.view(-1, query_seq_len, key_seq_len) - probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale) - probs = probs.view(*data_size) - else: - assert self.attn_mask_type == AttnMaskType.padding - probs = ScaledMaskedSoftmax.apply(input, mask, scale) + + if torch.npu.is_available() and self.fused: + return self.forward_fused_softmax(input, mask) + + return self.forward_torch_softmax(input, mask) + + def forward_fused_softmax(self, input, mask): + if self.softmax_in_fp32: + input = input.float() + + if self.scale is 
None: + self.scale = 1.0 + + if self.attn_mask_type == AttnMaskType.causal: + if self.mask_tri is None: + self.mask_tri = torch.triu(torch.ones(input.shape, device=input.device), diagonal=1).bool() + probs = torch_npu.npu_scaled_masked_softmax(input, self.mask_tri, self.scale, False) else: - if self.input_in_float16 and self.softmax_in_fp32: - input = input.float() + probs = torch_npu.npu_scaled_masked_softmax(input, mask, self.scale, False) - if self.scale is not None: - input = input * self.scale + probs = probs.half() + + return probs + + def forward_torch_softmax(self, input, mask): + if self.input_in_float16 and self.softmax_in_fp32: + input = input.float() + + if self.scale is not None: + input = input * self.scale + + if self.attn_mask_type == AttnMaskType.causal: + mask_tri = torch.triu(torch.ones(input.shape, device=input.device), diagonal=1).bool() + mask_output = self.mask_func(input, mask_tri) + else: mask_output = self.mask_func(input, mask) if mask is not None else input - probs = torch.nn.Softmax(dim=-1)(mask_output) - if self.input_in_float16 and self.softmax_in_fp32: - if self.input_in_fp16: - probs = probs.half() - else: - probs = probs.bfloat16() + probs = torch.nn.Softmax(dim=-1)(mask_output) + + if self.input_in_float16 and self.softmax_in_fp32: + if self.input_in_fp16: + probs = probs.half() + else: + probs = probs.bfloat16() return probs diff --git a/megatron/model/llama_model.py b/megatron/model/llama_model.py index b523c79cbeb614f06c86db679c50399bc17edfa2..213a77dd4ab59e601fc31d17957c554c8712336b 100644 --- a/megatron/model/llama_model.py +++ b/megatron/model/llama_model.py @@ -41,7 +41,8 @@ from deepspeed.pipe import PipelineModule, LayerSpec class RotaryEmbedding(torch.nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) + # inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) + inv_freq = 1.0 / (torch.tensor(base).double() ** (torch.arange(0, dim, 2).float().to(device) / dim).double()) self.register_buffer("inv_freq", inv_freq) # Build here to make `torch.jit.trace` work. 
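To make the causal branch of the rewritten FusedScaleMaskSoftmax above concrete: torch.triu(..., diagonal=1) marks the strictly upper triangle, i.e. the future keys each query must ignore, and mask_func then fills those scores with a large negative value before the softmax. A small torch-only illustration (masked_fill with -inf stands in for the actual mask_func):

    import torch

    scores = torch.randn(1, 1, 4, 4)                    # [b, np, sq, sk]
    causal = torch.triu(torch.ones_like(scores), diagonal=1).bool()
    masked = scores.masked_fill(causal, float('-inf'))  # hide future positions
    probs = torch.softmax(masked, dim=-1)               # rows sum to 1, future weights are 0
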
@@ -271,9 +272,7 @@ class LlamaParallelMLP(MegatronModule): enable_expert_tensor_parallelism=enable_expert_tensor_parallelism) def forward(self, hidden_states): - intermediate_parallel = self.gate_proj(hidden_states)[0] * self.up_proj(hidden_states)[0] - - intermediate_parallel = self.activation_func(intermediate_parallel) + intermediate_parallel = self.activation_func(self.gate_proj(hidden_states)[0]) * self.up_proj(hidden_states)[0] output, _ = self.down_proj(intermediate_parallel) return output @@ -854,7 +853,7 @@ class LlamaModelPipe(PipelineModule, MegatronModule): self.specs.append(LayerSpec(RMSNorm, args.hidden_size, eps=args.layernorm_epsilon)) self.specs.append( - LayerSpec(LlamaLMHeadPipe, hidden_size=args.hidden_size, vocab_size=padded_vocab_size, + LayerSpec(LlamaLMHeadPipe, hidden_size=args.hidden_size, vocab_size=args.padded_vocab_size, init_method=self.init_method, parallel_output=self.parallel_output) ) @@ -883,7 +882,7 @@ class LlamaModel(MegatronModule): """llama Language model.""" def __init__(self, pre_process, post_process, parallel_output=True, add_pooler=False): - super(LlamaModel, self).__init__() + super(LlamaModel, self).__init__(share_word_embeddings=False) args = get_args() self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy self.hidden_size = args.hidden_size @@ -902,7 +901,7 @@ class LlamaModel(MegatronModule): vocab_size=self.padded_vocab_size) # Transformer. - self.transformer = LlamaParallelTransformer( + self.language_model = LlamaParallelTransformer( self.init_method, self.output_layer_init_method, self_attn_mask_type=self.self_attn_mask_type, @@ -922,7 +921,7 @@ class LlamaModel(MegatronModule): def set_input_tensor(self, input_tensor): """See megatron.model.transformer.set_input_tensor()""" - self.transformer.set_input_tensor(input_tensor) + self.language_model.set_input_tensor(input_tensor) def forward(self, input_ids, attention_mask, labels=None, layer_past=None, get_key_value=False): args = get_args() @@ -933,7 +932,7 @@ class LlamaModel(MegatronModule): hidden_states = input_ids # decoder - hidden_states = self.transformer(hidden_states, attention_mask, layer_past=layer_past, + hidden_states = self.language_model(hidden_states, attention_mask, layer_past=layer_past, get_key_value=get_key_value) if self.post_process: @@ -958,4 +957,4 @@ class LlamaModel(MegatronModule): loss = mpu.vocab_parallel_cross_entropy(hidden_states.float(), labels) return loss - return hidden_states \ No newline at end of file + return hidden_states diff --git a/megatron/model/module.py b/megatron/model/module.py index 9f91c8bd1d949a672e8486a0430d7aa2b74136a2..3d5b783ffce00a99dc32dad8e2c53aedd01dec39 100644 --- a/megatron/model/module.py +++ b/megatron/model/module.py @@ -121,30 +121,28 @@ def conversion_helper(val, conversion): rtn = tuple(rtn) return rtn - def fp32_to_float16(val, float16_convertor): - """Convert fp32 `val` to fp16/bf16""" def half_conversion(val): val_typecheck = val - if isinstance(val_typecheck, (Parameter, Variable)): + if isinstance(val_typecheck, (torch.nn.parameter.Parameter, torch.autograd.Variable)): val_typecheck = val.data - if isinstance(val_typecheck, _FLOAT_TYPES): + if val_typecheck.dtype == torch.float32: val = float16_convertor(val) return val + return conversion_helper(val, half_conversion) def float16_to_fp32(val): - """Convert fp16/bf16 `val` to fp32""" def float_conversion(val): val_typecheck = val - if isinstance(val_typecheck, (Parameter, Variable)): + if isinstance(val_typecheck, (torch.nn.parameter.Parameter, 
torch.autograd.Variable)): val_typecheck = val.data - if isinstance(val_typecheck, (_BF16_TYPES, _HALF_TYPES)): + if val_typecheck.dtype in [torch.float16, torch.bfloat16]: val = val.float() return val - return conversion_helper(val, float_conversion) + return conversion_helper(val, float_conversion) class Float16Module(MegatronModule): diff --git a/megatron/mpu/layers.py b/megatron/mpu/layers.py index 5d168c67282ab46676ec4f91725a34236e053844..c08168340bb6a63ebdd20bdb7b4991438d52974b 100644 --- a/megatron/mpu/layers.py +++ b/megatron/mpu/layers.py @@ -43,7 +43,6 @@ _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {'tensor_model_parallel': False, 'partition_dim': -1, 'partition_stride': 1} - def param_is_not_tensor_parallel_duplicate(param): return (hasattr(param, 'tensor_model_parallel') and param.tensor_model_parallel) or ( diff --git a/megatron/mpu/random.py b/megatron/mpu/random.py index beea791b56882a863f162ccec21f770740d75df3..7de26f467534a6e105c8a63272e44f51972d2d3b 100644 --- a/megatron/mpu/random.py +++ b/megatron/mpu/random.py @@ -168,33 +168,33 @@ class CudaRNGStatesTracker: if name in self.states_: raise Exception('cuda rng state {} already exists'.format(name)) # Get the current rng state. - orig_rng_state = get_accelerator().get_rng_state() + # orig_rng_state = torch.cuda.get_rng_state() # Set the new state and store it. - get_accelerator().manual_seed(seed) - self.states_[name] = get_accelerator().get_rng_state() + torch.cuda.manual_seed(seed) + # self.states_[name] = torch.cuda.get_rng_state() # Reset rng state to what it was. - _set_cuda_rng_state(orig_rng_state) + # _set_cuda_rng_state(orig_rng_state) @contextlib.contextmanager def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME): - """Fork the cuda rng state, perform operations, and exit with - the original state.""" - # Check if we have added the state - if name not in self.states_: - print(name, self.states_) - raise Exception('cuda rng state {} is not added'.format(name)) + yield + # """Fork the cuda rng state, perform operations, and exit with + # the original state.""" + # # Check if we have added the state + # if name not in self.states_: + # raise Exception('cuda rng state {} is not added'.format(name)) # Store current rng state. - orig_cuda_rng_state = get_accelerator().get_rng_state() + # orig_cuda_rng_state = torch.cuda.get_rng_state() # Set rng state to the desired one - _set_cuda_rng_state(self.states_[name]) + # _set_cuda_rng_state(self.states_[name]) # Do the stuff we wanted to do. - try: - yield - finally: - # Update the current rng state for later use. - self.states_[name] = get_accelerator().get_rng_state() - # And set the state to the original state we started with. - _set_cuda_rng_state(orig_cuda_rng_state) + # try: + # yield + # finally: + # # Update the current rng state for later use. + # self.states_[name] = torch.cuda.get_rng_state() + # # And set the state to the original state we started with. + # _set_cuda_rng_state(orig_cuda_rng_state) # RNG tracker object. 
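For context on the mpu/random.py change above: fork() is reduced to a bare yield, so model-parallel regions no longer switch to a dedicated RNG stream on NPU. The behaviour being dropped is, per the commented-out code, roughly the following save/swap/restore pattern:

    import contextlib
    import torch

    @contextlib.contextmanager
    def fork_rng_state(states, name):
        orig = torch.cuda.get_rng_state()        # remember the caller's RNG state
        torch.cuda.set_rng_state(states[name])   # switch to the named stream
        try:
            yield
        finally:
            states[name] = torch.cuda.get_rng_state()  # persist where this stream left off
            torch.cuda.set_rng_state(orig)             # restore the caller's state
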
@@ -310,7 +310,7 @@ class CheckpointFunction(torch.autograd.Function): if isinstance(outputs, torch.Tensor): outputs = (outputs,) elif len(outputs) == 2 and isinstance(outputs[1], torch.Tensor) and \ - torch.equal(outputs[1], torch.tensor(0).to(get_accelerator().device_name())): + torch.equal(outputs[1], torch.tensor(0, dtype=outputs[1].dtype).to(get_accelerator().device_name())): # a hacky solution to overcome issue when running old script examples/pretrain_gpt_distributed.sh outputs = (outputs[0],) torch.autograd.backward(outputs, args) diff --git a/megatron/optimizer/clip_grads.py b/megatron/optimizer/clip_grads.py index fdec6ff806bea943ef676dde46d6b733ccbc1049..1563966ed7487b3d5841b4551693e7da0472c3ef 100644 --- a/megatron/optimizer/clip_grads.py +++ b/megatron/optimizer/clip_grads.py @@ -45,9 +45,13 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2): Arguments: parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a single Tensor that will have gradients normalized + grads_for_norm (Iterable[Tensor]): an iterable of Tensors or a single + Tensor that will be used for calculating the grad norm. max_norm (float or int): max norm of the gradients norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for infinity norm. + model_parallel_group (group): given the nature of the distributed + optimizer, this is passed as an argument. Returns: Total norm of the parameters (viewed as a single vector). @@ -55,11 +59,10 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2): if isinstance(parameters, torch.Tensor): parameters = [parameters] + # if isinstance(grads_for_norm, torch.Tensor): + # grads_for_norm = [grads_for_norm] - # Filter parameters based on: - # - grad should not be none - # - parameter should not be shared - # - should not be a replica due to tensor model parallelism + # Grads. grads = [] grads_for_norm = [] for param in parameters: @@ -69,11 +72,12 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2): grad = param.grad.detach() if grad_not_none: # Make sure the grads are in fp32 - assert param.grad.type() == 'torch.{}.FloatTensor'.format(get_accelerator().device_name()) + # assert param.grad.type() == 'torch.{}.FloatTensor'.format(get_accelerator().device_name()) grads.append(grad) if grad_not_none and is_not_shared and is_not_tp_duplicate: grads_for_norm.append(grad) + # Norm parameters. max_norm = float(max_norm) norm_type = float(norm_type) @@ -82,35 +86,16 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2): # Calculate norm. if norm_type == inf: total_norm = max(grad.abs().max() for grad in grads_for_norm) - total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]) + total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) # Take max across all model-parallel GPUs. torch.distributed.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=mpu.get_model_parallel_group()) total_norm = total_norm_cuda[0].item() - else: - if norm_type == 2.0: - if get_accelerator().device_name() == 'cuda': - dummy_overflow_buf = get_accelerator().IntTensor([0]) - # Use apex's multi-tensor applier for efficiency reasons. - # Multi-tensor applier takes a function and a list of list - # and performs the operation on that list all in one kernel. 
- grad_norm, _ = multi_tensor_applier( - amp_C.multi_tensor_l2norm, - dummy_overflow_buf, - [grads_for_norm], - False # no per-parameter norm - ) - else: - grad_norm = torch.norm(grads_for_norm,p=2.0) - # Since we will be summing across data parallel groups, - # we need the pow(norm-type). - total_norm = grad_norm ** norm_type - else: - for grad in grads_for_norm: - grad_norm = torch.norm(grad, norm_type) - total_norm += grad_norm ** norm_type + for grad in grads_for_norm: + grad_norm = torch.norm(grad, norm_type) + total_norm += grad_norm ** norm_type # Sum across all model-parallel GPUs. torch.distributed.all_reduce(total_norm, @@ -121,15 +106,8 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2): # Scale. clip_coeff = max_norm / (total_norm + 1.0e-6) if clip_coeff < 1.0: - if get_accelerator().device_name() == 'cuda': - dummy_overflow_buf = get_accelerator().IntTensor([0]) - multi_tensor_applier(amp_C.multi_tensor_scale, - dummy_overflow_buf, - [grads, grads], - clip_coeff) - else: - for g in grads: - g.detach().mul_(clip_coeff.to(g.device)) + for p in parameters: + p.grad.detach().mul_(clip_coeff) return total_norm diff --git a/megatron/optimizer/optimizer.py b/megatron/optimizer/optimizer.py index 6568bf10a7fa8b6d2846d4fff29bc13de0d185d0..8c5f7a7db744e03336a9816b41bbd368706d974d 100644 --- a/megatron/optimizer/optimizer.py +++ b/megatron/optimizer/optimizer.py @@ -239,11 +239,15 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer): # For all the parameters in this group: for i, param in enumerate(param_group['params']): if param.requires_grad: + if param.type() == "torch.cuda.HalfTensor": + param_type = "torch.npu.HalfTensor" + elif param.type() == "torch.cuda.BFloat16Tensor": + param_type = "torch.npu.BFloat16Tensor" + elif param.type() == "torch.cuda.FloatTensor": + param_type = "torch.npu.FloatTensor" # float16 params: - - - if param.type() in ['torch.{}.HalfTensor'.format(get_accelerator().device_name()), + if param_type in ['torch.{}.HalfTensor'.format(get_accelerator().device_name()), 'torch.{}.BFloat16Tensor'.format(get_accelerator().device_name())]: float16_params_this_group.append(param) # Create a copy @@ -262,7 +266,7 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer): = self.optimizer.state.pop(param) # fp32 params. 
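Tying the clip_grads.py rewrite above together: with the apex multi-tensor path removed, the 2-norm is accumulated per gradient tensor and the clip coefficient is applied tensor by tensor. In isolation (single process, no model-parallel reduction) the computation amounts to the following sketch:

    import torch

    def clip_grad_l2_(grads, max_norm):
        total_norm = torch.zeros(1)
        for g in grads:
            total_norm += torch.norm(g, 2.0) ** 2.0   # sum of squared per-tensor norms
        total_norm = total_norm.sqrt()                # overall L2 norm
        clip_coeff = max_norm / (total_norm + 1.0e-6)
        if clip_coeff < 1.0:
            for g in grads:
                g.detach().mul_(clip_coeff)           # in-place scale, as in the patch
        return total_norm

    grads = [torch.full((4,), 3.0), torch.full((4,), 4.0)]
    print(clip_grad_l2_(grads, max_norm=1.0))         # norm 10.0 before clipping
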
- elif param.type() == 'torch.{}.FloatTensor'.format(format(get_accelerator().device_name())): + elif param_type == 'torch.{}.FloatTensor'.format(format(get_accelerator().device_name())): fp32_params_this_group.append(param) param_group['params'][i] = param @@ -304,6 +308,7 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer): for model_group, main_group in zip(self.float16_groups, self.fp32_from_float16_groups): for model_param, main_param in zip(model_group, main_group): + # if self.params_have_main_grad: if self.params_have_main_grad: main_param.grad = model_param.main_grad.float() else: diff --git a/megatron/p2p_communication.py b/megatron/p2p_communication.py index 21df8b2b3f0a95aa7d967dd6cea5004522154a84..07e66c8019d2355e27f5afec03b483887aa219b0 100644 --- a/megatron/p2p_communication.py +++ b/megatron/p2p_communication.py @@ -94,16 +94,16 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next, torch.distributed.irecv, tensor_recv_prev, mpu.get_pipeline_model_parallel_prev_rank()) ops.append(recv_prev_op) - if tensor_send_next is not None: - send_next_op = torch.distributed.P2POp( - torch.distributed.isend, tensor_send_next, - mpu.get_pipeline_model_parallel_next_rank()) - ops.append(send_next_op) if tensor_recv_next is not None: recv_next_op = torch.distributed.P2POp( torch.distributed.irecv, tensor_recv_next, mpu.get_pipeline_model_parallel_next_rank()) ops.append(recv_next_op) + if tensor_send_next is not None: + send_next_op = torch.distributed.P2POp( + torch.distributed.isend, tensor_send_next, + mpu.get_pipeline_model_parallel_next_rank()) + ops.append(send_next_op) if len(ops) > 0: reqs = torch.distributed.batch_isend_irecv(ops) for req in reqs: diff --git a/megatron/schedules.py b/megatron/schedules.py index a072ac82b44cc3b5c6f08439f8a5b2885bbc692c..57d73c05ca6753a6fd5b5dc1466308a4ad53b302 100644 --- a/megatron/schedules.py +++ b/megatron/schedules.py @@ -15,6 +15,7 @@ from contextlib import contextmanager import torch +import torch_npu from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP from megatron import get_args @@ -27,6 +28,25 @@ from megatron.utils import unwrap_model from megatron.model import DistributedDataParallel as LocalDDP from megatron.model import Float16Module + +def clear_npu_overflow_flag(): + float_status = torch.zeros(8).npu() + result = torch_npu.npu_clear_float_status(float_status) + + +def get_npu_overflow_flag(): + float_status = torch.zeros(8).npu() + result = torch_npu.npu_get_float_status(float_status) + if float_status.cpu()[0] != 0: + return True + else: + return False + + +def set_npu_overflow_flag(): + torch.tensor([65504]).half().npu() + 100 + + def get_forward_backward_func(): args = get_args() if mpu.get_pipeline_model_parallel_world_size() > 1: @@ -96,6 +116,7 @@ def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad, mo if input_tensor is not None: input_tensor.retain_grad() + clear_npu_overflow_flag() if args.deepspeed: model.backward(output_tensor) else: @@ -113,7 +134,6 @@ def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad, mo return input_tensor_grad - @contextmanager def dummy_handler(): try: @@ -131,35 +151,33 @@ def forward_backward_no_pipelining(forward_step_func, data_iterator, model, assert len(model) == 1 model = model[0] - args = get_args() - context_handler = dummy_handler if isinstance(model, torchDDP): context_handler = model.no_sync - if args.deepspeed: - model.set_gradient_accumulation_boundary(False) - losses_reduced = 
[] input_tensor, output_tensor_grad = None, None + overflow_flag_all = False with context_handler(): for i in range(get_num_microbatches() - 1): output_tensor = forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced) if not forward_only: backward_step(optimizer, input_tensor, output_tensor, - output_tensor_grad, model) - - if args.deepspeed: - model.set_gradient_accumulation_boundary(True) + output_tensor_grad) - # Run computation for last microbatch out of context handler (want to - # synchronize gradients). + overflow_flag = get_npu_overflow_flag() + overflow_flag_all = overflow_flag or overflow_flag_all output_tensor = forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced) if not forward_only: - backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad, model) + backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad) + + overflow_flag = get_npu_overflow_flag() + overflow_flag_all = overflow_flag or overflow_flag_all + if overflow_flag_all: + set_npu_overflow_flag() return losses_reduced diff --git a/megatron/training.py b/megatron/training.py index 94133e7aeddbc64880773ff27ecbe60a7326bd89..7281362565abb9a778d87ae6ea617d66bbfc50ed 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -136,6 +136,7 @@ def pretrain(train_valid_test_dataset_provider, model, optimizer, lr_scheduler = setup_model_and_optimizer( model_provider, teacher=False, data_post_process=data_post_process, build_train_valid_test_datasets_provider=train_valid_test_dataset_provider) + timers('model-and-optimizer-setup').stop() print_datetime('after model, optimizer, and learning rate ' 'scheduler are built') @@ -488,7 +489,6 @@ def setup_model_and_optimizer(model_provider_func, teacher=False, # Number of train/valid/test samples. if args.train_samples: train_samples = args.train_samples - update_train_iters(args) else: train_samples = args.train_iters * args.global_batch_size # eval_iters and test_iters here are not actually used, only for @@ -1273,7 +1273,6 @@ def build_train_valid_test_data_iterators( # Number of train/valid/test samples. if args.train_samples: train_samples = args.train_samples - update_train_iters(args) else: train_samples = args.train_iters * args.global_batch_size eval_iters = (args.train_iters // args.eval_interval + 1) * \ diff --git a/pretrain_gpt.py b/pretrain_gpt.py index 355c41450673e3ddb3fa823a1f126b42bcfaac1c..099afbd870e4baa4849512c4a29ad8aec0d3f738 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -16,6 +16,10 @@ """Pretrain GPT""" import torch +import torch_npu +from torch_npu.contrib import transfer_to_npu +import deepspeed_npu + import math from functools import partial from megatron import get_args @@ -322,7 +326,6 @@ def git_ds_info(): if __name__ == "__main__": - git_ds_info() pretrain(train_valid_test_datasets_provider, model_provider, forward_step, args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}, data_post_process=data_post_process) diff --git a/pretrain_llama.py b/pretrain_llama.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..3d19a5dba0011d5faa5da9d7612b6a548ae852c5 100644 --- a/pretrain_llama.py +++ b/pretrain_llama.py @@ -0,0 +1,250 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
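Relating the schedules.py changes above: each microbatch's backward pass is preceded by clear_npu_overflow_flag(), the hardware float-status is polled after it, and if any microbatch overflowed the flag is raised again so the fp16 loss scaler skips the step. Schematically (imports assume the patched module; run_microbatch is a stand-in for forward_step plus backward_step):

    from megatron.schedules import (clear_npu_overflow_flag,
                                    get_npu_overflow_flag,
                                    set_npu_overflow_flag)

    def run_microbatch():
        pass  # placeholder: forward_step(...) followed by backward_step(...)

    overflow_any = False
    for _ in range(4):                     # e.g. 4 gradient-accumulation microbatches
        clear_npu_overflow_flag()          # reset the NPU float-status register
        run_microbatch()
        overflow_any = get_npu_overflow_flag() or overflow_any
    if overflow_any:
        set_npu_overflow_flag()            # re-arm the flag for the optimizer step
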
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pretrain Llama""" +import torch +import torch_npu +from torch_npu.contrib import transfer_to_npu +import deepspeed_npu + +import math +from functools import partial +from megatron import get_args +from megatron import print_rank_0 +from megatron import get_timers +from megatron import get_tokenizer +from megatron import mpu +from megatron.data.gpt_dataset import build_train_valid_test_datasets +from megatron.model import LlamaModel, LlamaModelPipe +from megatron.training import pretrain +from megatron.utils import get_ltor_masks_and_position_ids +from megatron.utils import average_losses_across_data_parallel_group + +import deepspeed +from deepspeed.runtime.utils import see_memory_usage +from deepspeed.accelerator.real_accelerator import get_accelerator + +import os +import subprocess + +from torch import nn +import torch.nn.functional as F + + +def model_provider(pre_process=True, post_process=True): + """Build the model.""" + + print_rank_0('building llama model ...') + see_memory_usage(f"Before Building Model", force=True) + + args = get_args() + with deepspeed.zero.Init(data_parallel_group=mpu.get_data_parallel_group(), + remote_device=None if args.remote_device == 'none' else args.remote_device, + config_dict_or_path=args.deepspeed_config, + enabled=args.zero_stage == 3, + mpu=mpu): + if args.deepspeed and not args.no_pipeline_parallel: + model = LlamaModelPipe(parallel_output=True) + # This is a hack to give us a reference to get_batch_pipe from within training.py + # We need to call model.set_batch_fn after deepspeed.initialize + model._megatron_batch_fn = get_batch_pipe + + # Predompute the attention mask and store it in args. This avoids having to + # pipeline it as an activation during training. The mask is constant, and thus + # we can reuse it. + attention_mask = torch.tril(torch.ones( + (1, args.seq_length, args.seq_length), device=get_accelerator().current_device_name())).view( + 1, 1, args.seq_length, args.seq_length) + + # Convert attention mask to binary: + attention_mask = (attention_mask < 0.5) + if args.fp16: + attention_mask = attention_mask.half() + elif args.bf16: + attention_mask = attention_mask.bfloat16() + + # Attention mask must be bool. + args.attn_mask = attention_mask.to(torch.bool) + + else: + model = LlamaModel( + parallel_output=True, + add_pooler=False, + pre_process=pre_process, + post_process=post_process + ) + see_memory_usage(f"After Building Model", force=True) + return model + + +def get_batch(data_iterator): + """Generate a batch""" + args = get_args() + tokenizer = get_tokenizer() + + # Items and their type. + keys = ['text'] + datatype = torch.int64 + + # Broadcast data. + if data_iterator is not None: + data = next(data_iterator) + else: + data = None + data_b = mpu.broadcast_data(keys, data, datatype) + + # Unpack. + tokens_ = data_b['text'].long() + labels = tokens_[:, 1:].contiguous() + tokens = tokens_[:, :-1].contiguous() + + # Get the masks and postition ids. 
+ attention_mask, loss_mask, _ = get_ltor_masks_and_position_ids( + tokens, + tokenizer.eod, + args.reset_position_ids, + args.reset_attention_mask, + args.eod_mask_loss) + + return tokens, labels, loss_mask, attention_mask + + +def data_post_process(data, data_sampler_state_dict): + args = get_args() + if args.data_efficiency_curriculum_learning: + if 'seqlen_truncate' in data_sampler_state_dict['current_difficulties']: + args.data_efficiency_curriculum_learning_seqlen_type = 'seqlen_truncate' + current_seqlen = data_sampler_state_dict['current_difficulties']['seqlen_truncate'] + if current_seqlen < args.seq_length: + data['text'] = data['text'][:, :(current_seqlen + 1)].contiguous() + elif 'seqlen_reshape' in data_sampler_state_dict['current_difficulties']: + args.data_efficiency_curriculum_learning_seqlen_type = 'seqlen_reshape' + current_seqlen = data_sampler_state_dict['current_difficulties']['seqlen_reshape'] + if current_seqlen < args.seq_length: + orig_num_token = torch.numel(data['text']) + reshape_len = (data['text'].size()[1] // (current_seqlen + 1)) * (current_seqlen + 1) + data['text'] = torch.cat((data['text'][:, :reshape_len].contiguous().view(-1, current_seqlen + 1), + data['text'][:, -(current_seqlen + 1):]), 0).contiguous() + num_row = math.ceil(orig_num_token / (current_seqlen + 1)) + num_row = min(num_row, data['text'].size()[0]) + if num_row > 1 and num_row % 2 != 0: + num_row -= 1 + data['text'] = data['text'][:num_row, :].contiguous() + else: + args.data_efficiency_curriculum_learning_seqlen_type = None + return data + + +def get_batch_pipe(data): + """Modification of `get_batch` to work on `next(data_iterator)` instead of `data_iterator`""" + args = get_args() + tokenizer = get_tokenizer() + + # Items and their type. + keys = ['text'] + datatype = torch.int64 + + # Broadcast data. + data_b = mpu.broadcast_data(keys, data, datatype) + + # Unpack. + tokens_ = data_b['text'].long() + labels = tokens_[:, 1:].contiguous() + tokens = tokens_[:, :-1].contiguous() + + # Get the masks and postition ids. + attention_mask, loss_mask, _ = get_ltor_masks_and_position_ids( + tokens, + tokenizer.eod, + args.reset_position_ids, + args.reset_attention_mask, + args.eod_mask_loss) + return (tokens, attention_mask), (labels, loss_mask) + +def loss_func(loss_mask, output_tensor): + args = get_args() + losses = output_tensor.float() + loss_mask = loss_mask.view(-1).float() + loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() + + # Reduce loss for logging. + averaged_loss = average_losses_across_data_parallel_group([loss]) + return loss, {'lm loss': averaged_loss[0]} + + +def forward_step(data_iterator, model): + """Forward step.""" + args = get_args() + timers = get_timers() + # Get the batch. + timers('batch-generator').start() + tokens, labels, loss_mask, attention_mask = get_batch(data_iterator) + timers('batch-generator').stop() + + output_tensor = model(tokens, attention_mask, labels=labels) + # Output_tensor stores the standard loss, loos_func calculates the total loss. 
+ return output_tensor, partial(loss_func, loss_mask) + + +def train_valid_test_datasets_provider(train_val_test_num_samples): + """Build train, valid, and test datasets.""" + args = get_args() + + print_rank_0('> building train, validation, and test datasets ' + 'for llama ...') + train_ds, valid_ds, test_ds = build_train_valid_test_datasets( + data_prefix=args.data_path, + data_impl=args.data_impl, + splits_string=args.split, + train_valid_test_num_samples=train_val_test_num_samples, + seq_length=args.seq_length, + seed=args.seed, + skip_warmup=(not args.mmap_warmup)) + print_rank_0("> finished creating llama datasets ...") + + return train_ds, valid_ds, test_ds + + +def command_exists(cmd): + result = subprocess.Popen(f'type {cmd}', stdout=subprocess.PIPE, shell=True) + return result.wait() == 0 + + +def git_ds_info(): + from deepspeed.env_report import main as ds_report + ds_report() + + # Write out version/git info + git_hash_cmd = "git rev-parse --short HEAD" + git_branch_cmd = "git rev-parse --abbrev-ref HEAD" + if command_exists('git'): + try: + result = subprocess.check_output(git_hash_cmd, shell=True) + git_hash = result.decode('utf-8').strip() + result = subprocess.check_output(git_branch_cmd, shell=True) + git_branch = result.decode('utf-8').strip() + except subprocess.CalledProcessError: + git_hash = "unknown" + git_branch = "unknown" + else: + git_hash = "unknown" + git_branch = "unknown" + print(f'**** Git info for Megatron: git_hash={git_hash} git_branch={git_branch} ****') + + +if __name__ == "__main__": + pretrain(train_valid_test_datasets_provider, model_provider, forward_step, + args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}, + data_post_process=data_post_process) \ No newline at end of file
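As a closing illustration of the training objective wired up in pretrain_llama.py: get_batch/get_batch_pipe shift the same token tensor by one position to form input/target pairs, and loss_func averages the per-token losses only over positions where loss_mask is set. A tiny numeric sketch:

    import torch

    tokens_ = torch.tensor([[10, 11, 12, 13, 14]])
    tokens = tokens_[:, :-1].contiguous()     # model input : [[10, 11, 12, 13]]
    labels = tokens_[:, 1:].contiguous()      # targets     : [[11, 12, 13, 14]]

    losses = torch.tensor([2.0, 4.0, 6.0, 8.0])      # per-token losses (made up)
    loss_mask = torch.tensor([1.0, 1.0, 0.0, 1.0])   # third position masked out
    loss = torch.sum(losses * loss_mask) / loss_mask.sum()
    print(loss)                               # tensor(4.6667) == (2 + 4 + 8) / 3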