diff --git a/megatron/__init__.py b/megatron/__init__.py index 3d247c581ea228e7bb9fcf6c9f0f4a263636fdbc..75f6ce1a9acd67f4eeccdd98a55371dfb7f8609f 100644 --- a/megatron/__init__.py +++ b/megatron/__init__.py @@ -13,11 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +import torch import copy +import torch_npu from functools import wraps - -import torch +from torch_npu.contrib import transfer_to_npu from torch import distributed as dist + from deepspeed.accelerator import get_accelerator from .package_info import ( __description__, @@ -53,17 +55,20 @@ def wrapper_type(fn): return decorated -torch.Tensor.type = wrapper_type(torch.Tensor.type) # fix for torch.Tensor.type() +torch.Tensor.type = wrapper_type(torch.Tensor.type) def wrapper_dist_long2int(fn): @wraps(fn) def wrapper(*args, **kwargs): + if not args: + args = (kwargs.pop('tensor'),) + if args[0].dtype == torch.long and not kwargs.get('async_op', False): new_args = list(copy.deepcopy(args)) new_args[0] = new_args[0].int() fn(*new_args, **kwargs) args[0].copy_(new_args[0].long()) - return None + return return fn(*args, **kwargs) return wrapper diff --git a/megatron/arguments.py b/megatron/arguments.py index bdd17450bfd0a34f584058bbb8f8daaea9aed907..d1224730285aaebcb2f6bae0deffe32ca72314c5 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -20,6 +20,7 @@ import os import torch import deepspeed +from megatron.enums import PositionEmbeddingType def parse_args(extra_args_provider=None, defaults={}, ignore_unknown_args=False): @@ -228,6 +229,15 @@ def parse_args(extra_args_provider=None, defaults={}, assert args.encoder_seq_length is not None args.seq_length = args.encoder_seq_length + if args.position_embedding_type == PositionEmbeddingType.absolute or args.position_embedding_type == PositionEmbeddingType.alibi: + assert args.max_position_embeddings is not None + if args.seq_length is not None: + assert args.max_position_embeddings >= args.seq_length + if args.decoder_seq_length is not None: + assert args.max_position_embeddings >= args.decoder_seq_length + else: + assert args.max_position_embeddings is None + if args.seq_length is not None: assert args.max_position_embeddings >= args.seq_length if args.decoder_seq_length is not None: @@ -306,12 +316,21 @@ def _add_network_size_args(parser): 'attention. This is set to ' ' args.hidden_size // args.num_attention_heads ' 'if not provided.') + group.add_argument('--embed-layernorm', action='store_true', + help='use layernorm for embedding') group.add_argument('--max-position-embeddings', type=int, default=None, help='Maximum number of position embeddings to use. ' 'This is the size of position embedding.') + group.add_argument('--position-embedding-type', type=lambda x: PositionEmbeddingType[x], + choices=list(PositionEmbeddingType), default=PositionEmbeddingType.absolute, + help='Define position embedding type ("absolute" | "rotary" | "alibi"). "absolute" by default.') group.add_argument('--make-vocab-size-divisible-by', type=int, default=128, help='Pad the vocab size to be divisible by this value.' 'This is added for computational efficieny reasons.') + group.add_argument('--pad-vocab-size-to', type=int, default=None, + help='Pad the vocab size to this value.' 
+ 'This value must be greater than the initial size of the tokenizer,' + 'needs to be divisible by TP size and `make-vocab-size-divisible-by`.') group.add_argument('--layernorm-epsilon', type=float, default=1e-5, help='Layer norm epsilon.') group.add_argument('--apply-residual-connection-post-layernorm', @@ -339,6 +358,24 @@ def _add_logging_args(parser): help='If set, calculate and log parameters norm.') group.add_argument('--log-num-zeros-in-grad', action='store_true', help='If set, calculate and log the number of zeros in gradient.') + group.add_argument('--timing-log-level', type=int, + default=0, choices=range(0,3), + help='Granularity level to measure and report timing. ' + ' 0: report only iteration time and make sure timing ' + ' does not introduce extra overhead.' + ' 1: report timing for operations that are executed ' + ' very limited times (basically once) during ' + ' each iteration (such as gradient all-reduce) ' + ' 2: report timing for operations that migh be ' + ' executed numerous times during each iteration. ' + 'Note that setting the level to 1 or 2 might ' + 'cause increase in iteration time.') + group.add_argument('--timing-log-option', type=str, default='minmax', + choices=['max', 'minmax', 'all'], + help='Options for logging timing:' + ' max: report the max timing across all ranks' + ' minmax: report min and max timings across all ranks' + ' all: report timings of all ranks.') group.add_argument('--tensorboard-log-interval', type=int, default=1, help='Report to tensorboard interval.') group.add_argument('--tensorboard-queue-size', type=int, default=1000, @@ -739,8 +776,12 @@ def _add_data_args(parser): default=None, choices=['BertWordPieceLowerCase', 'BertWordPieceCase', - 'GPT2BPETokenizer'], + 'GPT2BPETokenizer', + 'PretrainedFromHF'], help='What type of tokenizer to use.') + group.add_argument("--tokenizer-name-or-path", type=str, default=None, + help="Name or path of the huggingface tokenizer.") + group.add_argument('--data-impl', type=str, default='infer', choices=['lazy', 'cached', 'mmap', 'infer'], help='Implementation of indexed datasets.') @@ -751,6 +792,8 @@ def _add_data_args(parser): 'end-of-document token.') group.add_argument('--eod-mask-loss', action='store_true', help='Mask loss for the end of document tokens.') + group.add_argument('--loss-on-targets-only', action='store_true', + help='Mask loss on input sequence.') group.add_argument('--train-data-exact-num-epochs', type=int, default=None, help='When building the train dataset, force it to be ' 'an exact number of epochs of the raw data') diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py index fbef9da003be33c95fbca63e0a444e429c21e9ca..2154228099e18da7d657d6b3f38eadeaf999d5b2 100644 --- a/megatron/checkpointing.py +++ b/megatron/checkpointing.py @@ -21,6 +21,7 @@ import sys import numpy as np from deepspeed.accelerator import get_accelerator import torch +from megatron.enums import PositionEmbeddingType from megatron import (get_args, is_rank_0, @@ -62,7 +63,12 @@ def check_checkpoint_args(checkpoint_args): _compare('num_layers') _compare('hidden_size') _compare('num_attention_heads') - _compare('max_position_embeddings') + + _compare('position_embedding_type') + # with alibi we can change `max_position_embeddings` + if args.position_embedding_type != PositionEmbeddingType.alibi: + _compare('max_position_embeddings') + if args.vocab_file: _compare('make_vocab_size_divisible_by') _compare('padded_vocab_size') diff --git a/megatron/data/gpt_dataset.py b/megatron/data/gpt_dataset.py 
index a14179bb9fd6c9ef1d89e33288849732fcde95ba..a99904a9af2b5be05304ee1b69e2e942043d48d1 100644 --- a/megatron/data/gpt_dataset.py +++ b/megatron/data/gpt_dataset.py @@ -219,20 +219,11 @@ def _build_index_mappings(name, data_prefix, documents, sizes, sample_idx_filename = _filename + '_sample_idx.npy' shuffle_idx_filename = _filename + '_shuffle_idx.npy' - if name == 'train': - # force to use certain index files - if args.train_doc_idx_path is not None: - doc_idx_filename = args.train_doc_idx_path - if args.train_sample_idx_path is not None: - sample_idx_filename = args.train_sample_idx_path - if args.train_shuffle_idx_path is not None: - shuffle_idx_filename = args.train_shuffle_idx_path - # Build the indexed mapping if not exist. if is_rank_0(): if (not os.path.isfile(doc_idx_filename)) or \ - (not os.path.isfile(sample_idx_filename)) or \ - (not os.path.isfile(shuffle_idx_filename)): + (not os.path.isfile(sample_idx_filename)) or \ + (not os.path.isfile(shuffle_idx_filename)): print_rank_0(' > WARNING: could not find index map files, building ' 'the indices on rank 0 ...') diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index 380fe4e23546a229300dbddeb7fead4a3b411772..28ec14762b1cebf57f08227d70354ee9d1fb5798 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -11,6 +11,7 @@ # An empty sentence no longer separates documents. # Some of the fixes/improvements are adopted from +# https://github.com/bigscience-workshop/AscendSpeed/blob/main/megatron/data/indexed_dataset.py from functools import lru_cache import os diff --git a/megatron/global_vars.py b/megatron/global_vars.py index 21b647dac1c30f7d90142ba1b9d30a938b97a5d7..fb723532c5f15e86f8c2bb1369e241b6b30be72c 100644 --- a/megatron/global_vars.py +++ b/megatron/global_vars.py @@ -17,7 +17,6 @@ import os import sys -import time import torch @@ -25,6 +24,8 @@ from megatron.tokenizer import build_tokenizer from .arguments import parse_args from .microbatches import build_num_microbatches_calculator from deepspeed.accelerator import get_accelerator +from .timers import Timers + _GLOBAL_ARGS = None _GLOBAL_NUM_MICROBATCHES_CALCULATOR = None _GLOBAL_TOKENIZER = None @@ -83,11 +84,11 @@ def set_global_variables(extra_args_provider=None, args_defaults={}, defaults=args_defaults, ignore_unknown_args=ignore_unknown_args) _build_num_microbatches_calculator(args) - if args.vocab_file: + if args.vocab_file or args.tokenizer_name_or_path: _ = _build_tokenizer(args) _set_tensorboard_writer(args) _set_adlr_autoresume(args) - _set_timers() + _set_timers(args) def _parse_args(extra_args_provider=None, defaults={}, @@ -163,11 +164,11 @@ def _set_adlr_autoresume(args): _GLOBAL_ADLR_AUTORESUME = AutoResume -def _set_timers(): +def _set_timers(args): """Initialize timers.""" global _GLOBAL_TIMERS _ensure_var_is_not_initialized(_GLOBAL_TIMERS, 'timers') - _GLOBAL_TIMERS = Timers() + _GLOBAL_TIMERS = Timers(args.timing_log_level, args.timing_log_option) def _ensure_var_is_initialized(var, name): @@ -177,86 +178,4 @@ def _ensure_var_is_initialized(var, name): def _ensure_var_is_not_initialized(var, name): """Make sure the input variable is not None.""" - assert var is None, '{} is already initialized.'.format(name) - - -class _Timer: - """Timer.""" - - def __init__(self, name): - self.name_ = name - self.elapsed_ = 0.0 - self.started_ = False - self.start_time = time.time() - - def start(self): - """Start the timer.""" - assert not self.started_, 'timer has already been started' - 
get_accelerator().synchronize() - self.start_time = time.time() - self.started_ = True - - def stop(self): - """Stop the timer.""" - assert self.started_, 'timer is not started' - get_accelerator().synchronize() - self.elapsed_ += (time.time() - self.start_time) - self.started_ = False - - def reset(self): - """Reset timer.""" - self.elapsed_ = 0.0 - self.started_ = False - - def elapsed(self, reset=True): - """Calculate the elapsed time.""" - started_ = self.started_ - # If the timing in progress, end it first. - if self.started_: - self.stop() - # Get the elapsed time. - elapsed_ = self.elapsed_ - # Reset the elapsed time - if reset: - self.reset() - # If timing was in progress, set it back. - if started_: - self.start() - return elapsed_ - - -class Timers: - """Group of timers.""" - - def __init__(self): - self.timers = {} - - def __call__(self, name): - if name not in self.timers: - self.timers[name] = _Timer(name) - return self.timers[name] - - def write(self, names, writer, iteration, normalizer=1.0, reset=False): - """Write timers to a tensorboard writer""" - # currently when using add_scalars, - # torch.utils.add_scalars makes each timer its own run, which - # polutes the runs list, so we just add each as a scalar - assert normalizer > 0.0 - for name in names: - value = self.timers[name].elapsed(reset=reset) / normalizer - writer.add_scalar(name + '-time', value, iteration) - - def log(self, names, normalizer=1.0, reset=True): - """Log a group of timers.""" - assert normalizer > 0.0 - string = 'time (ms)' - for name in names: - elapsed_time = self.timers[name].elapsed( - reset=reset) * 1000.0 / normalizer - string += ' | {}: {:.2f}'.format(name, elapsed_time) - if torch.distributed.is_initialized(): - if torch.distributed.get_rank() == ( - torch.distributed.get_world_size() - 1): - print(string, flush=True) - else: - print(string, flush=True) + assert var is None, '{} is already initialized.'.format(name) \ No newline at end of file diff --git a/megatron/initialize.py b/megatron/initialize.py index e83717e941ecce0d77f55646d562e4711ca8de1c..0d97d921ecb32a49e87dd10fabbd1ee530b8d651 100644 --- a/megatron/initialize.py +++ b/megatron/initialize.py @@ -21,7 +21,6 @@ import time import numpy as np import torch - from megatron import get_adlr_autoresume from megatron import get_args from megatron import get_tensorboard_writer @@ -29,8 +28,11 @@ from megatron import mpu from megatron.global_vars import set_global_variables from megatron.mpu import (set_tensor_model_parallel_rank, set_tensor_model_parallel_world_size) + + from deepspeed.accelerator import get_accelerator import deepspeed +import deepspeed.utils.groups as groups def initialize_megatron(extra_args_provider=None, args_defaults=None, ignore_unknown_args=False, allow_no_cuda=False): diff --git a/megatron/model/__init__.py b/megatron/model/__init__.py index 3033f6a4865376f14a35fb91b3fd0bd381c513bc..22b822a7b0a51801634dfb756324bb4c84de6b4b 100644 --- a/megatron/model/__init__.py +++ b/megatron/model/__init__.py @@ -22,3 +22,4 @@ from .gpt_model import GPTModel, GPTModelPipe from .llama_model import LlamaModel, LlamaModelPipe from .language_model import get_language_model from .module import Float16Module +from .enums import ModelType diff --git a/megatron/model/enums.py b/megatron/model/enums.py index b6992fefafeda2fc15be0f61c08924385a7c0933..c4f4e27f09dc97c081b757e0facd6b2c63539db3 100644 --- a/megatron/model/enums.py +++ b/megatron/model/enums.py @@ -15,6 +15,10 @@ import enum +class ModelType(enum.Enum): + encoder_or_decoder = 
1 + encoder_and_decoder = 2 + class LayerType(enum.Enum): encoder = 1 decoder = 2 @@ -25,4 +29,6 @@ class AttnType(enum.Enum): class AttnMaskType(enum.Enum): padding = 1 - causal = 2 + causal = 2 # Overrides `attention_mask` to be a lower triangular matrix + prefix = 3 + custom = 4 # Forces one to pass an `attention_mask` that's 1 if we need to mask. Tensor that can be broadcast to [micro_batch_size, n_head, seq_length, seq_length] diff --git a/megatron/model/fused_layer_norm.py b/megatron/model/fused_layer_norm.py index c7af4df47dd8918579a22fa9beb27c035770f826..c8c28c35cc2ae1ef8000bfaee82009685532e4ca 100644 --- a/megatron/model/fused_layer_norm.py +++ b/megatron/model/fused_layer_norm.py @@ -18,12 +18,19 @@ with some changes. """ import numbers +from megatron.mpu.utils import make_viewless_tensor import torch from torch.nn.parameter import Parameter from torch.nn import init import importlib from torch.nn import functional as F +try: + from apex.contrib.layer_norm.layer_norm import FastLayerNormFN + HAVE_PERSIST_LAYER_NORM = True +except: + HAVE_PERSIST_LAYER_NORM = False + global fused_mix_prec_layer_norm_cuda fused_mix_prec_layer_norm_cuda = None @@ -62,21 +69,30 @@ class FusedLayerNormAffineFunction(torch.autograd.Function): class MixedFusedLayerNorm(torch.nn.Module): - def __init__(self, normalized_shape, eps=1e-5): - super(MixedFusedLayerNorm, self).__init__() - - if isinstance(normalized_shape, numbers.Integral): - normalized_shape = (normalized_shape,) - self.normalized_shape = torch.Size(normalized_shape) - self.eps = eps - self.weight = Parameter(torch.Tensor(*normalized_shape)) - self.bias = Parameter(torch.Tensor(*normalized_shape)) - self.reset_parameters() - + def __init__(self, normalized_shape, eps=1e-5, no_persist_layer_norm=True, sequence_parallel=False): + super(MixedFusedLayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + self.normalized_shape = torch.Size(normalized_shape) + self.eps = eps + self.weight = torch.nn.parameter.Parameter(torch.Tensor(*normalized_shape)) + self.bias = torch.nn.parameter.Parameter(torch.Tensor(*normalized_shape)) + self.reset_parameters() + self.no_persist_layer_norm = True + self.sequence_parallel = sequence_parallel + + # set sequence parallelism flag on weight and bias parameters + setattr(self.weight, 'sequence_parallel', self.sequence_parallel) + setattr(self.bias, 'sequence_parallel', self.sequence_parallel) def reset_parameters(self): init.ones_(self.weight) init.zeros_(self.bias) def forward(self, input): - return torch.nn.functional.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps) + if self.no_persist_layer_norm: + return torch.nn.functional.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps) + else: + output = FastLayerNormFN.apply(input, self.weight, self.bias, self.eps) + output = make_viewless_tensor(inp=output, requires_grad=input.requires_grad, keep_graph=True) + return output \ No newline at end of file diff --git a/megatron/model/fused_softmax.py b/megatron/model/fused_softmax.py index 3ef745ee3fd5e67758af7e366dd05fb7aa0058ad..7d7cfd13894af5eb189f4381727cce66a7258445 100644 --- a/megatron/model/fused_softmax.py +++ b/megatron/model/fused_softmax.py @@ -15,10 +15,88 @@ import torch import torch_npu +import torch.nn as nn from megatron.model.enums import AttnMaskType -class NPUFusedScaleMaskSoftmax(torch.nn.Module): +class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): + """ + Fused operation 
which performs following three operations in sequence + 1. Scale the tensor. + 2. Apply upper triangular mask (typically used in gpt models). + 3. Perform softmax. + """ + + @staticmethod + def forward(ctx, inputs, scale): + import scaled_upper_triang_masked_softmax_cuda + + scale_t = torch.tensor([scale]) + + softmax_results = scaled_upper_triang_masked_softmax_cuda.forward( + inputs, scale_t[0] + ) + ctx.save_for_backward(softmax_results, scale_t) + return softmax_results + + @staticmethod + def backward(ctx, output_grads): + import scaled_upper_triang_masked_softmax_cuda + + softmax_results, scale_t = ctx.saved_tensors + + input_grads = scaled_upper_triang_masked_softmax_cuda.backward( + output_grads, softmax_results, scale_t[0] + ) + return input_grads, None + + +class ScaledMaskedSoftmax(torch.autograd.Function): + """ + Fused operation which performs following three operations in sequence + 1. Scale the tensor. + 2. Apply the mask. + 3. Perform softmax. + """ + + @staticmethod + def forward(ctx, inputs, mask, scale): + import scaled_masked_softmax_cuda + + scale_t = torch.tensor([scale]) + + softmax_results = scaled_masked_softmax_cuda.forward( + inputs, mask, scale_t[0] + ) + ctx.save_for_backward(softmax_results, scale_t) + return softmax_results + + @staticmethod + def backward(ctx, output_grads): + import scaled_masked_softmax_cuda + + softmax_results, scale_t = ctx.saved_tensors + + input_grads = scaled_masked_softmax_cuda.backward( + output_grads, softmax_results, scale_t[0] + ) + return input_grads, None, None + + +class NPUFusedScaleMaskSoftmax(nn.Module): + """ + fused operation: scaling + mask + softmax + + Arguments: + input_in_fp16: flag to indicate if input in fp16 data format. + input_in_bf16: flag to indicate if input in bf16 data format. + attn_mask_type: attention mask type (pad or causal) + scaled_masked_softmax_fusion: flag to indicate user want to use softmax fusion + mask_func: mask function to be applied. + softmax_in_fp32: if true, softmax in performed at fp32 precision. + scale: scaling factor used in input tensor scaling. 
+ """ + def __init__( self, input_in_fp16, @@ -28,7 +106,7 @@ class NPUFusedScaleMaskSoftmax(torch.nn.Module): mask_func, softmax_in_fp32, scale, - ): + ): super(NPUFusedScaleMaskSoftmax, self).__init__() self.input_in_fp16 = input_in_fp16 self.input_in_bf16 = input_in_bf16 @@ -41,8 +119,6 @@ class NPUFusedScaleMaskSoftmax(torch.nn.Module): self.mask_func = mask_func self.softmax_in_fp32 = softmax_in_fp32 self.scale = scale - self.mask_tri = None - p = torch.npu.get_device_properties(0) if torch.npu.is_available() else None assert ( self.scale is None or softmax_in_fp32 @@ -52,28 +128,22 @@ class NPUFusedScaleMaskSoftmax(torch.nn.Module): # [b, np, sq, sk] assert input_.dim() == 4 - if torch.npu.is_available(): + if self.is_kernel_available(mask, *input_.size()): return self.npu_forward_fused_softmax(input_, mask) - - return self.npu_forward_torch_softmax(input_, mask) - - def npu_forward_fused_softmax(self, input_, mask): - if self.softmax_in_fp32: - input_ = input_.float() - - if self.scale is None: - self.scale = 1.0 - - if self.attn_mask_type == AttnMaskType.causal: - if self.mask_tri is None: - self.mask_tri = torch.triu(torch.ones(input_.shape, device=input_.device), diagonal=1).bool() - probs = torch_npu.npu_scaled_masked_softmax(input_, self.mask_tri, self.scale, False) else: - probs = torch_npu.npu_scaled_masked_softmax(input_, mask, self.scale, False) + return self.npu_forward_torch_softmax(input_, mask) - probs = probs.half() + def is_kernel_available(self, mask, b, np, sq, sk): + return ( + self.scaled_masked_softmax_fusion # user want to fuse + and self.input_in_float16 # input must be fp16 + and 32 < sk <= 2048 # sk must be 32 ~ 2048 + and sq % 16 == 0 # sq must be divisor of 16 + and sk % 16 == 0 # sk must be divisor of 16 + ) - return probs + def npu_forward_fused_softmax(self, input, mask): + return torch_npu.npu_scaled_masked_softmax(input, mask, self.scale, False) def npu_forward_torch_softmax(self, input_, mask): if self.input_in_float16 and self.softmax_in_fp32: @@ -81,14 +151,15 @@ class NPUFusedScaleMaskSoftmax(torch.nn.Module): if self.scale is not None: input_ = input_ * self.scale - - if self.attn_mask_type == AttnMaskType.causal: - mask_tri = torch.triu(torch.ones(input_.shape, device=input_.device), diagonal=1).bool() - mask_output = self.mask_func(input_, mask_tri) - else: - mask_output = self.mask_func(input_, mask) if mask is not None else input_ + mask_output = self.mask_func(input_, mask) if mask is not None else input probs = torch.nn.Softmax(dim=-1)(mask_output) if self.input_in_float16 and self.softmax_in_fp32: probs = probs.half() if self.input_in_fp16 else probs.bfloat16() return probs + + @staticmethod + def get_batch_per_block(sq, sk, b, np): + import scaled_masked_softmax_cuda + + return scaled_masked_softmax_cuda.get_batch_per_block(sq, sk, b, np) \ No newline at end of file diff --git a/megatron/model/gpt_model.py b/megatron/model/gpt_model.py index 62b951064e3de315df01b80c877aea24286728a8..4c2322f0d9b3ec0bcad5b5015d302bee69e2a45d 100644 --- a/megatron/model/gpt_model.py +++ b/megatron/model/gpt_model.py @@ -72,7 +72,8 @@ class GPTModel(MegatronModule): parallel_output=True, pre_process=True, post_process=True, - return_moe_loss=True): + prefix_lm=False, + return_moe_loss=False): super(GPTModel, self).__init__() args = get_args() @@ -84,10 +85,11 @@ class GPTModel(MegatronModule): self.language_model, self._language_model_key = get_language_model( num_tokentypes=num_tokentypes, add_pooler=False, - encoder_attn_mask_type=AttnMaskType.causal, + 
encoder_attn_mask_type=AttnMaskType.prefix if prefix_lm else AttnMaskType.causal, + # encoder_attn_mask_type=AttnMaskType.causal, init_method=init_method_normal(args.init_method_std), scaled_init_method=scaled_init_method_normal(args.init_method_std, args.num_layers), - num_experts=args.num_experts, + # num_experts=args.num_experts, pre_process=self.pre_process, post_process=self.post_process) @@ -177,23 +179,50 @@ class GPTModel(MegatronModule): self.language_model.load_state_dict(state_dict, strict=strict) -def CrossEntropy(output, labels): - labels, loss_mask = labels[0], labels[1] +def get_cross_entropy(is_prefix: bool): + def CrossEntropy(output, labels): + labels, loss_mask = labels[0], labels[1] - args = get_args() + args = get_args() - losses = mpu.vocab_parallel_cross_entropy(output.contiguous().float(), labels) - loss_mask = loss_mask.view(-1) - loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() - return loss + losses = mpu.vocab_parallel_cross_entropy(output.contiguous().float(), labels) + + if is_prefix: + micro_batch_size, sequence_length = loss_mask.shape + average_tokens_per_sample: torch.Tensor + if args.loss_on_targets_only: + # HACK: This is useful when we obtain loss masks that are microbatch dependent. Consequently, if we want to + # preserve the notion that all tokens have the same impact on the loss, we can only normalise using a + # microbatch independent value. It should be expected weight over a microbatch. + # Here we still use `sequence_length`, that's batch size dependent, in order to be backwards compatible with + # current experiment on vanilla gpt. + if args.reweight_loss_based_on_position_frequency: + reweight = torch.arange( + sequence_length, 0, -1, dtype=torch.float, device=loss_mask.device + ) / (sequence_length + 1) * 2 + average_tokens_per_sample = reweight.flip(-1).cumsum(-1).mean() + else: + average_tokens_per_sample = (sequence_length + 1) / 2 + else: + average_tokens_per_sample = sequence_length + expected_number_of_tokens = average_tokens_per_sample * micro_batch_size + else: + expected_number_of_tokens = loss_mask.sum() + loss_mask = loss_mask.view(-1) + loss = torch.sum(losses.view(-1) * loss_mask) / expected_number_of_tokens + return loss + return CrossEntropy class GPTModelPipe(PipelineModule,MegatronModule): """GPT-2 Language model.""" - def __init__(self, - num_tokentypes=0, - parallel_output=True): + def __init__( + self, + num_tokentypes=0, + parallel_output=True, + attn_mask_type: AttnMaskType = AttnMaskType.causal + ): args = get_args() self.parallel_output = parallel_output @@ -221,11 +250,19 @@ class GPTModelPipe(PipelineModule,MegatronModule): init_method=init_method, num_tokentypes=num_tokentypes, tied_weight_attr='word_embeddings_weight')) - + if args.fp32_residual_connection: - self.specs.append(lambda x: x.transpose(0, 1).contiguous().float()) + if getattr(args, 'pretrain_causal_attention', False): + self.specs.append(lambda x: x.transpose(0, 1).contiguous().float()) + else: + # EmbeddingPipe returns attention mask as well + self.specs.append(lambda x: (x[0].transpose(0, 1).contiguous().float(), *x[1:])) else: - self.specs.append(lambda x: x.transpose(0, 1).contiguous()) + if getattr(args, 'pretrain_causal_attention', False): + self.specs.append(lambda x: x.transpose(0, 1).contiguous()) + else: + # EmbeddingPipe returns attention mask as well + self.specs.append(lambda x: (x[0].transpose(0, 1).contiguous(), *x[1:])) for layer_idx in range(args.num_layers): self.specs.append( @@ -234,11 +271,15 @@ class 
GPTModelPipe(PipelineModule,MegatronModule): output_layer_init_method=scaled_init_method_normal(args.init_method_std, args.num_layers), layer_number=layer_idx, - self_attn_mask_type=AttnMaskType.causal)) - - + # TODO: Change naming of class from GPT to something that encapsulate prefix lm. + self_attn_mask_type=attn_mask_type)) + # Undo data format change - self.specs.append(lambda x: x.transpose(0, 1).contiguous()) + def undo(x): + if not getattr(args, 'pretrain_causal_attention', False): + x = x[0] + return x.transpose(0, 1).contiguous() + self.specs.append(undo) # Final layernorm after transformer layers self.specs.append( @@ -258,7 +299,6 @@ class GPTModelPipe(PipelineModule,MegatronModule): EmbeddingPipe, args.hidden_size, args.padded_vocab_size, - args.max_position_embeddings, args.hidden_dropout, init_method=init_method, num_tokentypes=num_tokentypes, @@ -274,14 +314,26 @@ class GPTModelPipe(PipelineModule,MegatronModule): interval = args.checkpoint_num_layers else: interval = 0 - + from deepspeed.runtime.pipe.topology import PipeModelDataParallelTopology topo = PipeModelDataParallelTopology(num_pp=mpu.get_pipeline_model_parallel_world_size(), num_mp=mpu.get_tensor_model_parallel_world_size(), num_dp=mpu.get_data_parallel_world_size()) + # here one can extend the regex to include more layers to be counted towards partitioning, + # e.g. 'type:transformer|embedding' will add up all the transformer blocks and also the first + # and last embedding layers and then partition that transformers+2 layers - so to get a good + # balance you may want to use less transformer layers + # + # caveat emptor: the current implementation of PP fails unless each stage has at least one + # transformer layer + # if args.pp_partition_method is not None: + # partition_method = args.pp_partition_method + # else: + partition_method = 'type:transformer' + super().__init__(layers=self.specs, - loss_fn=CrossEntropy, + loss_fn=get_cross_entropy(is_prefix=attn_mask_type is AttnMaskType.prefix), topology=topo, activation_checkpoint_interval=interval, - partition_method='type:transformer') + partition_method=partition_method) diff --git a/megatron/model/language_model.py b/megatron/model/language_model.py index 2c6802a74afbd15332439c91688152b682fa7d95..fa098e5f483956036963594a01aa255a95d4322d 100644 --- a/megatron/model/language_model.py +++ b/megatron/model/language_model.py @@ -25,6 +25,8 @@ from megatron.model.enums import LayerType, AttnMaskType from megatron.model.transformer import ParallelTransformer from megatron.model.utils import get_linear_layer from megatron.model.utils import init_method_normal, scaled_init_method_normal +from ..enums import PositionEmbeddingType + def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, bias=None): @@ -137,11 +139,17 @@ class Embedding(MegatronModule): self._word_embeddings_key = 'word_embeddings' # Position embedding (serial). - self.position_embeddings = torch.nn.Embedding( - max_sequence_length, self.hidden_size) - self._position_embeddings_key = 'position_embeddings' - # Initialize the position embeddings. 
- self.init_method(self.position_embeddings.weight) + self.position_embedding_type = args.position_embedding_type + if self.position_embedding_type == PositionEmbeddingType.absolute: + max_position_embeddings = args.max_position_embeddings + assert max_position_embeddings is not None + self.position_embeddings = torch.nn.Embedding( + max_position_embeddings, self.hidden_size) + self._position_embeddings_key = 'position_embeddings' + # Initialize the position embeddings. + self.init_method(self.position_embeddings.weight) + else: + self.position_embeddings = None # Token type embedding. # Add this as an optional field that can be added through @@ -179,8 +187,14 @@ class Embedding(MegatronModule): def forward(self, input_ids, position_ids, tokentype_ids=None): # Embeddings. words_embeddings = self.word_embeddings(input_ids) - position_embeddings = self.position_embeddings(position_ids) - embeddings = words_embeddings + position_embeddings + embeddings = words_embeddings + + if self.position_embedding_type == PositionEmbeddingType.absolute: + assert self.position_embeddings is not None + embeddings = embeddings + self.position_embeddings(position_ids) + else: + assert self.position_embeddings is None + if tokentype_ids is not None: assert self.tokentype_embeddings is not None embeddings = embeddings + self.tokentype_embeddings(tokentype_ids) @@ -199,9 +213,10 @@ class Embedding(MegatronModule): state_dict_ = {} state_dict_[self._word_embeddings_key] \ = self.word_embeddings.state_dict(destination, prefix, keep_vars) - state_dict_[self._position_embeddings_key] \ - = self.position_embeddings.state_dict( - destination, prefix, keep_vars) + if self.position_embedding_type == PositionEmbeddingType.absolute: + state_dict_[self._position_embeddings_key] \ + = self.position_embeddings.state_dict( + destination, prefix, keep_vars) if self.num_tokentypes > 0: state_dict_[self._tokentype_embeddings_key] \ = self.tokentype_embeddings.state_dict( @@ -225,16 +240,17 @@ class Embedding(MegatronModule): self.word_embeddings.load_state_dict(state_dict_, strict=strict) # Position embedding. - if self._position_embeddings_key in state_dict: - state_dict_ = state_dict[self._position_embeddings_key] - else: - # for backward compatibility. - state_dict_ = {} - for key in state_dict.keys(): - if 'position_embeddings' in key: - state_dict_[key.split('position_embeddings.')[1]] \ - = state_dict[key] - self.position_embeddings.load_state_dict(state_dict_, strict=strict) + if self.position_embedding_type == PositionEmbeddingType.absolute: + if self._position_embeddings_key in state_dict: + state_dict_ = state_dict[self._position_embeddings_key] + else: + # for backward compatibility. + state_dict_ = {} + for key in state_dict.keys(): + if 'position_embeddings' in key: + state_dict_[key.split('position_embeddings.')[1]] \ + = state_dict[key] + self.position_embeddings.load_state_dict(state_dict_, strict=strict) # Tokentype embedding. if self.num_tokentypes > 0: @@ -254,7 +270,6 @@ class Embedding(MegatronModule): print('***WARNING*** expected tokentype embeddings in the ' 'checkpoint but could not find it', flush=True) - class EmbeddingPipe(Embedding): def forward(self, inputs, **kwargs): @@ -272,7 +287,7 @@ class EmbeddingPipe(Embedding): tokentype_ids = inputs[3] else: tokentype_ids = None - + embeddings = super().forward(input_ids, position_ids, tokentype_ids=tokentype_ids) # If cmd args has attn_mask, we don't forward it as an activation.
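Editor's note: the change above makes learned absolute position embeddings optional on the embedding layer. The following is an illustrative, standalone sketch of that pattern (not part of the patch); ToyEmbedding and the three-valued enum are assumptions that mirror megatron.enums.PositionEmbeddingType.

import enum
import torch

class PositionEmbeddingType(enum.Enum):  # assumed to match megatron.enums
    absolute = 1
    rotary = 2
    alibi = 3

class ToyEmbedding(torch.nn.Module):
    def __init__(self, vocab_size, hidden_size, max_position_embeddings,
                 position_embedding_type=PositionEmbeddingType.absolute):
        super().__init__()
        self.word_embeddings = torch.nn.Embedding(vocab_size, hidden_size)
        self.position_embedding_type = position_embedding_type
        if position_embedding_type == PositionEmbeddingType.absolute:
            # only absolute embeddings need a learned position table
            self.position_embeddings = torch.nn.Embedding(max_position_embeddings, hidden_size)
        else:
            # rotary/alibi inject position information inside attention instead
            self.position_embeddings = None

    def forward(self, input_ids, position_ids):
        embeddings = self.word_embeddings(input_ids)
        if self.position_embedding_type == PositionEmbeddingType.absolute:
            embeddings = embeddings + self.position_embeddings(position_ids)
        return embeddings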
diff --git a/megatron/model/llama_model.py b/megatron/model/llama_model.py index b07954abea8215964af2aecb6c5f1bcbfbf3904e..f7cc7dd2743d31853f0cccdc36802cc31c2d0213 100644 --- a/megatron/model/llama_model.py +++ b/megatron/model/llama_model.py @@ -953,4 +953,4 @@ class LlamaModel(MegatronModule): loss = mpu.vocab_parallel_cross_entropy(hidden_states.float(), labels) return loss - return hidden_states + return hidden_states \ No newline at end of file diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 97fa030438e210c050f99b81479a2beeeaaef137..7dd4ff940c5914cac96dbe1713b0c4d37ebd3613 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -26,10 +26,13 @@ from megatron.model.enums import AttnMaskType, LayerType, AttnType from megatron.model.fused_bias_gelu import bias_gelu_impl from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu from megatron.model.module import MegatronModule -from torch import distributed as dist +from megatron.enums import PositionEmbeddingType + import deepspeed -from deepspeed.moe.layer import MoE from deepspeed.accelerator import get_accelerator +from deepspeed.moe.layer import MoE + + # flags required to enable jit fusion kernels torch._C._jit_set_profiling_mode(False) torch._C._jit_set_profiling_executor(False) @@ -123,6 +126,7 @@ class ParallelAttention(MegatronModule): args = get_args() self.fp16 = args.fp16 self.bf16 = args.bf16 + self.position_embedding_type = args.position_embedding_type self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32 @@ -131,7 +135,7 @@ class ParallelAttention(MegatronModule): self.layer_number = max(1, layer_number) self.attention_type = attention_type self.attn_mask_type = attn_mask_type - self.num_attention_heads = args.num_attention_heads + projection_size = args.kv_channels * args.num_attention_heads # Per attention head and per partition values. 
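Editor's note: for the "per attention head and per partition values" computed in ParallelAttention above, here is a small sketch (not part of the patch) of how those sizes are typically derived under tensor model parallelism; divide() mirrors megatron.mpu.utils.divide and world_size stands in for the tensor-model-parallel world size.

def divide(numerator, denominator):
    # mirrors megatron.mpu.utils.divide: enforce exact divisibility
    assert numerator % denominator == 0
    return numerator // denominator

def attention_partition_sizes(kv_channels, num_attention_heads, world_size):
    projection_size = kv_channels * num_attention_heads
    hidden_size_per_partition = divide(projection_size, world_size)
    hidden_size_per_attention_head = divide(projection_size, num_attention_heads)
    num_attention_heads_per_partition = divide(num_attention_heads, world_size)
    return (hidden_size_per_partition,
            hidden_size_per_attention_head,
            num_attention_heads_per_partition)

# e.g. 128 kv channels, 32 heads, TP world size 8 -> (512, 128, 4)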
@@ -196,8 +200,11 @@ class ParallelAttention(MegatronModule): get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker checkpoint = deepspeed.checkpointing.checkpoint + if self.position_embedding_type == PositionEmbeddingType.rotary: + self.rotary_emb = RotaryEmbedding(self.hidden_size_per_attention_head, precision=args.params_dtype) + def forward(self, hidden_states, attention_mask, layer_past=None, - get_key_value=False, encoder_output=None): + get_key_value=False, encoder_output=None, alibi=None): # hidden_states: [sq, b, h] # ===================== @@ -271,23 +278,55 @@ class ParallelAttention(MegatronModule): output_size[0] * output_size[1], -1) # preallocting result tensor: [b * np, sq, sk] - matmul_result = torch.empty( - output_size[0]*output_size[1], - output_size[2], - output_size[3], - dtype=query_layer.dtype, - device=get_accelerator().current_device_name()) + if alibi is None: + matmul_result = torch.empty( + output_size[0]*output_size[1], + output_size[2], + output_size[3], + dtype=query_layer.dtype, + device=get_accelerator().current_device_name()) + else: + matmul_result = alibi[:output_size[0]*output_size[1], :, :output_size[3]] + + # Rotary embeddings + if self.position_embedding_type == PositionEmbeddingType.rotary: + apply_rotary_fn = apply_rotary_pos_emb_torch if self.bf16 else apply_rotary_pos_emb + + seq_len = key_layer.shape[0] + offset = 0 + if layer_past is not None and layer_past.numel() > 0: + offset = layer_past[0].shape[0] + seq_len += offset + cos, sin = self.rotary_emb(value_layer, seq_len=seq_len) + query_layer, key_layer = apply_rotary_fn(query_layer, key_layer, cos, sin, offset=offset) # Raw attention scores. [b * np, sq, sk] - matmul_result = torch.baddbmm( - matmul_result, - query_layer.transpose(0, 1), # [b * np, sq, hn] - key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] - beta=0.0, alpha=(1.0/self.norm_factor)) + if alibi is None: + matmul_result = torch.baddbmm( + matmul_result, + query_layer.transpose(0, 1), # [b * np, sq, hn] + key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] + beta=0.0, alpha=(1.0/self.norm_factor)) + else: + if not hasattr(self, "logged_alibi"): + if torch.distributed.is_initialized(): + if torch.distributed.get_rank() == 0: + print("Using Alibi", flush=True) + self.logged_alibi = True + + if self.apply_query_key_layer_scaling: + beta = 1.0 / self.layer_number + else: + beta = 1.0 + + matmul_result = torch.baddbmm( + matmul_result, + query_layer.transpose(0, 1), # [b * np, sq, hn] + key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] + beta=beta, alpha=(1.0 / self.norm_factor)) # change view to [b, np, sq, sk] attention_scores = matmul_result.view(*output_size) - # ================================================== # Update attention mask for inference. [b, np, sq, sk] # ================================================== @@ -418,7 +457,7 @@ class ParallelTransformerLayer(MegatronModule): eps=args.layernorm_epsilon) # Self attention. 
- self.attention = ParallelAttention( + self.self_attention = ParallelAttention( init_method, output_layer_init_method, layer_number, @@ -463,7 +502,18 @@ class ParallelTransformerLayer(MegatronModule): eval_capacity_factor=args.moe_eval_capacity_factor, min_capacity=args.moe_min_capacity, drop_tokens=args.moe_token_dropping, use_tutel=args.use_tutel, - enable_expert_tensor_parallelism=enable_expert_tensor_parallelism) + enable_expert_tensor_parallelism=enable_expert_tensor_parallelism) + + # Alibi + if args.position_embedding_type == PositionEmbeddingType.alibi: + self.alibi = self._build_alibi_tensor(args.seq_length, args.num_attention_heads, + args.micro_batch_size).to(torch.cuda.current_device()) + if args.params_dtype == torch.float16: + self.alibi = self.alibi.to(torch.float16) + elif args.params_dtype == torch.bfloat16: + self.alibi = self.alibi.to(torch.bfloat16) + else: + self.alibi = None def forward(self, hidden_states, attention_mask=None, encoder_output=None, enc_dec_attn_mask=None, @@ -474,10 +524,11 @@ class ParallelTransformerLayer(MegatronModule): layernorm_output = self.input_layernorm(hidden_states) # Self attention. attention_output, attention_bias = \ - self.attention(layernorm_output, + self.self_attention(layernorm_output, attention_mask, layer_past=layer_past, - get_key_value=get_key_value) + get_key_value=get_key_value, + alibi=self.alibi) if get_key_value: attention_output, presents = attention_output @@ -564,6 +615,36 @@ class ParallelTransformerLayer(MegatronModule): return output, moe_loss + @staticmethod + def _build_alibi_tensor(max_seq_len, num_attention_heads, batch_size): + # Based on https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742 + """Returns tensor shaped (batch_size * num_attention_heads, 1, max_seq_len)""" + + def get_slopes(n): + def get_slopes_power_of_2(n): + start = (2 ** (-2 ** -(math.log2(n) - 3))) + ratio = start + return [start * ratio ** i for i in range(n)] + + if math.log2(n).is_integer(): + return get_slopes_power_of_2(n) + else: + closest_power_of_2 = 2 ** math.floor(math.log2(n)) + return get_slopes_power_of_2(closest_power_of_2) + get_slopes(2 * closest_power_of_2)[0::2][ + :n - closest_power_of_2] + + slopes = torch.Tensor(get_slopes(num_attention_heads)) + alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(max_seq_len).unsqueeze(0).unsqueeze(0).expand( + num_attention_heads, -1, -1) + + # Select the part of the tensor that corresponds to our tensor parallel index. + tp_world_size = mpu.get_tensor_model_parallel_world_size() + tp_index = mpu.get_tensor_model_parallel_rank() + alibi = alibi.reshape((tp_world_size, -1, *alibi.shape[1:]))[tp_index] + + alibi = alibi.repeat(batch_size, 1, 1) + return alibi + class ParallelTransformerLayerPipe(ParallelTransformerLayer): """Extends ParallelTransformerLayer to forward attention_mask through the pipeline. 
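Editor's note: to make the ALiBi construction in _build_alibi_tensor above easier to follow, here is a self-contained sketch (not part of the patch) that reproduces the slope schedule and the per-head linear bias, without the tensor-parallel slicing or the batch repeat.

import math
import torch

def alibi_slopes(num_heads):
    # one slope per head; geometric schedule from the ALiBi paper
    def slopes_power_of_2(n):
        start = 2 ** (-2 ** -(math.log2(n) - 3))
        return [start * (start ** i) for i in range(n)]
    if math.log2(num_heads).is_integer():
        return slopes_power_of_2(num_heads)
    closest = 2 ** math.floor(math.log2(num_heads))
    return (slopes_power_of_2(closest)
            + alibi_slopes(2 * closest)[0::2][:num_heads - closest])

def alibi_bias(num_heads, max_seq_len):
    """Returns a [num_heads, 1, max_seq_len] tensor added to raw attention scores."""
    slopes = torch.tensor(alibi_slopes(num_heads))
    positions = torch.arange(max_seq_len)
    # head h, key position j gets bias slopes[h] * j
    return slopes[:, None, None] * positions[None, None, :]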
@@ -603,7 +684,6 @@ class ParallelTransformerLayerPipe(ParallelTransformerLayer): else: raise RuntimeError('Received more inputs than understood.') - class ParallelTransformer(MegatronModule): """Transformer class.""" diff --git a/megatron/mpu/__init__.py b/megatron/mpu/__init__.py index 32bb5fc3dd6ec916a65cd773bc577c71bf348583..e4d06bb155df647cbe46195d6719d06429976d55 100644 --- a/megatron/mpu/__init__.py +++ b/megatron/mpu/__init__.py @@ -42,6 +42,7 @@ from .initialize import get_virtual_pipeline_model_parallel_rank, set_virtual_pi from .initialize import initialize_model_parallel from .initialize import model_parallel_is_initialized from .initialize import get_model_parallel_world_size, get_model_parallel_rank +from .initialize import is_pipeline_stage_after_split from .layers import ColumnParallelLinear from .layers import RowParallelLinear diff --git a/megatron/mpu/cross_entropy.py b/megatron/mpu/cross_entropy.py index 8c790cd3e1a2e7fd6f108e1e719b01e53f651aff..36e8807290b63c73699a4b1acc7615d381dfaa02 100644 --- a/megatron/mpu/cross_entropy.py +++ b/megatron/mpu/cross_entropy.py @@ -55,7 +55,7 @@ class _VocabParallelCrossEntropy(torch.autograd.Function): masked_target_1d = masked_target.view(-1) arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device) - predicted_logits_1d = logits_2d[arange_1d, masked_target_1d] + predicted_logits_1d = logits_2d[arange_1d, masked_target_1d.long()] predicted_logits_1d = predicted_logits_1d.clone().contiguous() predicted_logits = predicted_logits_1d.view_as(target) predicted_logits[target_mask] = 0.0 @@ -96,7 +96,7 @@ class _VocabParallelCrossEntropy(torch.autograd.Function): # Add the gradient from matching classes. arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device) - grad_2d[arange_1d, masked_target_1d] -= ( + grad_2d[arange_1d, masked_target_1d.long()] -= ( 1.0 - target_mask.view(-1).float()) # Finally elementwise multiplication with the output gradients. 
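Editor's note on the loss path touched above: the get_cross_entropy wrapper added in megatron/model/gpt_model.py earlier in this patch normalizes the masked loss by an expected token count rather than by loss_mask.sum() when training a prefix LM. A condensed sketch of that normalization (not part of the patch; names are illustrative):

import torch

def prefix_lm_loss(losses, loss_mask, loss_on_targets_only=True):
    # losses, loss_mask: [micro_batch_size, seq_length]
    micro_batch_size, sequence_length = loss_mask.shape
    if loss_on_targets_only:
        # on average roughly half the positions of a prefix-LM sample carry loss,
        # so divide by a microbatch-independent expected count instead of loss_mask.sum()
        average_tokens_per_sample = (sequence_length + 1) / 2
    else:
        average_tokens_per_sample = sequence_length
    expected_number_of_tokens = average_tokens_per_sample * micro_batch_size
    return torch.sum(losses.view(-1) * loss_mask.view(-1)) / expected_number_of_tokens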
diff --git a/megatron/mpu/initialize.py b/megatron/mpu/initialize.py index c24d1179ada7587f2d96ebdaabc5e3b2d86c8ddd..6bc4b8514fb6e9110447a514df56f88a020ef8f1 100644 --- a/megatron/mpu/initialize.py +++ b/megatron/mpu/initialize.py @@ -372,3 +372,19 @@ def destroy_model_parallel(): _PIPELINE_MODEL_PARALLEL_GROUP = None global _DATA_PARALLEL_GROUP _DATA_PARALLEL_GROUP = None + + +def is_pipeline_stage_after_split(rank=None): + """Return True if pipeline stage executes decoder block for a model + with both encoder and decoder.""" + if get_pipeline_model_parallel_world_size() == 1: + return True + if rank is None: + rank = get_pipeline_model_parallel_rank() + global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK + if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None: + return True + if rank >= _PIPELINE_MODEL_PARALLEL_SPLIT_RANK: + return True + return False + diff --git a/megatron/mpu/layers.py b/megatron/mpu/layers.py index c08168340bb6a63ebdd20bdb7b4991438d52974b..aa03774569f92e6a6e91f459d6c19be94d55a784 100644 --- a/megatron/mpu/layers.py +++ b/megatron/mpu/layers.py @@ -35,7 +35,8 @@ from .random import get_cuda_rng_tracker from .utils import divide from .utils import split_tensor_along_last_dim from .utils import VocabUtility -from megatron import get_args +from ..model.fused_layer_norm import MixedFusedLayerNorm as LayerNorm +from megatron import get_args, mpu import deepspeed.runtime.activation_checkpointing.checkpointing as ds_checkpointing from deepspeed.accelerator import get_accelerator @@ -43,6 +44,7 @@ _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {'tensor_model_parallel': False, 'partition_dim': -1, 'partition_stride': 1} + def param_is_not_tensor_parallel_duplicate(param): return (hasattr(param, 'tensor_model_parallel') and param.tensor_model_parallel) or ( @@ -165,6 +167,9 @@ class VocabParallelEmbedding(torch.nn.Module): # Allocate weights and initialize. args = get_args() + if mpu.is_pipeline_first_stage() and args.embed_layernorm: + self.norm = LayerNorm(embedding_dim) + if args.use_cpu_initialization: self.weight = Parameter(torch.empty( self.num_embeddings_per_partition, self.embedding_dim, @@ -199,6 +204,9 @@ class VocabParallelEmbedding(torch.nn.Module): output_parallel[input_mask, :] = 0.0 # Reduce across all the model parallel GPUs. output = reduce_from_tensor_model_parallel_region(output_parallel) + if hasattr(self, 'norm'): + output = self.norm(output) + return output diff --git a/megatron/mpu/random.py b/megatron/mpu/random.py index 8d54867f978f31fe95c9e43051873eb8426e65ea..beea791b56882a863f162ccec21f770740d75df3 100644 --- a/megatron/mpu/random.py +++ b/megatron/mpu/random.py @@ -168,13 +168,33 @@ class CudaRNGStatesTracker: if name in self.states_: raise Exception('cuda rng state {} already exists'.format(name)) # Get the current rng state. + orig_rng_state = get_accelerator().get_rng_state() # Set the new state and store it. - torch.cuda.manual_seed(seed) + get_accelerator().manual_seed(seed) + self.states_[name] = get_accelerator().get_rng_state() # Reset rng state to what it was. + _set_cuda_rng_state(orig_rng_state) @contextlib.contextmanager def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME): - yield + """Fork the cuda rng state, perform operations, and exit with + the original state.""" + # Check if we have added the state + if name not in self.states_: + print(name, self.states_) + raise Exception('cuda rng state {} is not added'.format(name)) + # Store current rng state. 
+ orig_cuda_rng_state = get_accelerator().get_rng_state() + # Set rng state to the desired one + _set_cuda_rng_state(self.states_[name]) + # Do the stuff we wanted to do. + try: + yield + finally: + # Update the current rng state for later use. + self.states_[name] = get_accelerator().get_rng_state() + # And set the state to the original state we started with. + _set_cuda_rng_state(orig_cuda_rng_state) # RNG tracker object. @@ -290,7 +310,7 @@ class CheckpointFunction(torch.autograd.Function): if isinstance(outputs, torch.Tensor): outputs = (outputs,) elif len(outputs) == 2 and isinstance(outputs[1], torch.Tensor) and \ - torch.equal(outputs[1], torch.tensor(0, dtype=outputs[1].dtype).to(get_accelerator().device_name())): + torch.equal(outputs[1], torch.tensor(0).to(get_accelerator().device_name())): # a hacky solution to overcome issue when running old script examples/pretrain_gpt_distributed.sh outputs = (outputs[0],) torch.autograd.backward(outputs, args) diff --git a/megatron/mpu/utils.py b/megatron/mpu/utils.py index 56ed1c76e1404389f18ab3be01e13dfdc678d942..2431a1c39c334c5a40b1a8c46cfb6428ac40f552 100644 --- a/megatron/mpu/utils.py +++ b/megatron/mpu/utils.py @@ -29,6 +29,60 @@ def divide(numerator, denominator): ensure_divisibility(numerator, denominator) return numerator // denominator +def _kernel_make_viewless_tensor(inp, requires_grad): + '''Make a viewless tensor. + + View tensors have the undesirable side-affect of retaining a reference + to the originally-viewed tensor, even after manually setting the '.data' + field. This method creates a new tensor that links to the old tensor's + data, without linking the viewed tensor, referenced via the '._base' + field. + ''' + out = torch.empty( + (1,), + dtype = inp.dtype, + device = inp.device, + requires_grad = requires_grad, + ) + with torch.no_grad(): + out.set_(inp.data) + return out + +class MakeViewlessTensor(torch.autograd.Function): + ''' + Autograd function to make a viewless tensor. + + This function should be used in cases where the computation graph needs + to be propagated, but we only want a viewless tensor (e.g., + ParallelTransformer's hidden_states). Call this function by passing + 'keep_graph = True' to 'make_viewless_tensor()'. + ''' + @staticmethod + def forward(ctx, inp, requires_grad): + return _kernel_make_viewless_tensor(inp, requires_grad) + @staticmethod + def backward(ctx, grad_output): + return grad_output, None + +def make_viewless_tensor(inp, requires_grad, keep_graph): + ''' + Entry-point for creating viewless tensors. + + This method should be used, rather than calling 'MakeViewlessTensor' + or '_kernel_make_viewless_tensor' directly. This method acts as a + switch for determining if an autograd function or a regular method + should be used to create the tensor. + ''' + + # return tensor as-is, if not a 'view' + if inp._base is None: + return inp + + # create viewless tensor + if keep_graph: + return MakeViewlessTensor.apply(inp, requires_grad) + else: + return _kernel_make_viewless_tensor(inp, requires_grad) def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False): diff --git a/megatron/optimizer/__init__.py b/megatron/optimizer/__init__.py index 67968efe53dfeda550f6551c3d2bf36592172d31..12df7348d93c67406c13dd6f7300de1a89a61212 100644 --- a/megatron/optimizer/__init__.py +++ b/megatron/optimizer/__init__.py @@ -12,11 +12,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. - -import math -import apex -import torch from deepspeed.accelerator import get_accelerator +if get_accelerator().device_name() == 'cuda': + from apex.optimizers import FusedAdam as Adam + from apex.optimizers import FusedSGD as SGD +else: + from megatron.optimizer.adam import AdamW as Adam + from torch.optim import SGD + from megatron import get_args from megatron.model import LayerNorm @@ -24,129 +27,6 @@ from megatron.model import LayerNorm from .grad_scaler import ConstantGradScaler, DynamicGradScaler from .optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer - -class AdamW(torch.optim.Optimizer): - r"""Implements AdamW algorithm. - - The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. - The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. - - Arguments: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups - lr (float, optional): learning rate (default: 1e-3) - betas (Tuple[float, float], optional): coefficients used for computing - running averages of gradient and its square (default: (0.9, 0.999)) - eps (float, optional): term added to the denominator to improve - numerical stability (default: 1e-8) - weight_decay (float, optional): weight decay coefficient (default: 1e-2) - amsgrad (boolean, optional): whether to use the AMSGrad variant of this - algorithm from the paper `On the Convergence of Adam and Beyond`_ - (default: False) - - .. _Adam\: A Method for Stochastic Optimization: - https://arxiv.org/abs/1412.6980 - .. _Decoupled Weight Decay Regularization: - https://arxiv.org/abs/1711.05101 - .. _On the Convergence of Adam and Beyond: - https://openreview.net/forum?id=ryQu7f-RZ - """ - - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, - weight_decay=1e-2, amsgrad=False): - if not 0.0 <= lr: - raise ValueError("Invalid learning rate: {}".format(lr)) - if not 0.0 <= eps: - raise ValueError("Invalid epsilon value: {}".format(eps)) - if not 0.0 <= betas[0] < 1.0: - raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) - if not 0.0 <= betas[1] < 1.0: - raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) - defaults = dict(lr=lr, betas=betas, eps=eps, - weight_decay=weight_decay, amsgrad=amsgrad) - super(AdamW, self).__init__(params, defaults) - - def __setstate__(self, state): - super(AdamW, self).__setstate__(state) - for group in self.param_groups: - group.setdefault('amsgrad', False) - - @torch.no_grad() - def step(self, closure=None): - """Performs a single optimization step. - - Arguments: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. 
- """ - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - for p in group['params']: - if p.grad is None: - continue - - # Perform stepweight decay - p.data.mul_(1 - group['lr'] * group['weight_decay']) - - # Perform optimization step - grad = p.grad - if grad.is_sparse: - raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') - amsgrad = group['amsgrad'] - - state = self.state[p] - - # State initialization - if len(state) == 0: - state['step'] = 0 - # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros_like(p) - # Exponential moving average of squared gradient values - state['exp_avg_sq'] = torch.zeros_like(p) - if amsgrad: - # Maintains max of all exp. moving avg. of sq. grad. values - state['max_exp_avg_sq'] = torch.zeros_like(p) - - exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] - if amsgrad: - max_exp_avg_sq = state['max_exp_avg_sq'] - beta1, beta2 = group['betas'] - - state['step'] += 1 - bias_correction1 = 1 - beta1 ** state['step'] - bias_correction2 = 1 - beta2 ** state['step'] - - # Decay the first and second moment running average coefficient - exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - if amsgrad: - # Maintains the maximum of all 2nd moment running avg. till now - torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) - # Use the max. for normalizing running avg. of gradient - denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) - else: - denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) - - step_size = group['lr'] / bias_correction1 - - p.addcdiv_(exp_avg, denom, value=-step_size) - - return loss - -apex.optimizers.FusedAdam = AdamW -apex.optimizers.FusedSGD = torch.optim.SGD - -if get_accelerator().device_name() == 'cuda': - from apex.optimizers import FusedAdam as Adam - from apex.optimizers import FusedSGD as SGD -else: - from torch.optim import Adam - from torch.optim import SGD - def _get_params_for_weight_decay_optimization(modules): """Divide params into with-weight-decay and without-weight-decay groups. Layernorms and baises will have no weight decay but the rest will. diff --git a/megatron/optimizer/adam.py b/megatron/optimizer/adam.py new file mode 100644 index 0000000000000000000000000000000000000000..dbb6411795a633f59c0e64f15aa7b513efa0f04b --- /dev/null +++ b/megatron/optimizer/adam.py @@ -0,0 +1,200 @@ +from typing import List, Optional + +import torch +import torch_npu +from torch import Tensor +from torch.optim.optimizer import Optimizer + + +def adamw(params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[int], + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + maximize: bool): + r"""Functional API that performs AdamW algorithm computation. + See :class:`~torch.optim.AdamW` for details. 
+ """ + for i, param in enumerate(params): + grad = grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step = state_steps[i] + + # Perform stepweight decay + ## param.mul_(1 - lr * weight_decay) + bias_correction1 = beta1 ** step + bias_correction2 = beta2 ** step + + param.data, exp_avg, exp_avg_sq = torch_npu.npu_apply_adam_w( + bias_correction1, + bias_correction2, + lr, + weight_decay, + beta1, + beta2, + eps, + grad, + None, + amsgrad, + maximize, + out=(param.data, exp_avg, exp_avg_sq) + ) + +class AdamW(Optimizer): + r"""Implements AdamW algorithm. + .. math:: + \begin{aligned} + &\rule{110mm}{0.4pt} \\ + &\textbf{input} : \gamma \text{(lr)}, \: \beta_1, \beta_2 + \text{(betas)}, \: \theta_0 \text{(params)}, \: f(\theta) \text{(objective)}, + \: \epsilon \text{ (epsilon)} \\ + &\hspace{13mm} \lambda \text{(weight decay)}, \: \textit{amsgrad}, + \: \textit{maximize} \\ + &\textbf{initialize} : m_0 \leftarrow 0 \text{ (first moment)}, v_0 \leftarrow 0 + \text{ ( second moment)}, \: \widehat{v_0}^{max}\leftarrow 0 \\[-1.ex] + &\rule{110mm}{0.4pt} \\ + &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ + &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\ + &\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm}\textbf{else} \\ + &\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm} \theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1} \\ + &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\ + &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\ + &\hspace{5mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\ + &\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\ + &\hspace{5mm}\textbf{if} \: amsgrad \\ + &\hspace{10mm}\widehat{v_t}^{max} \leftarrow \mathrm{max}(\widehat{v_t}^{max}, + \widehat{v_t}) \\ + &\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/ + \big(\sqrt{\widehat{v_t}^{max}} + \epsilon \big) \\ + &\hspace{5mm}\textbf{else} \\ + &\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/ + \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\ + &\rule{110mm}{0.4pt} \\[-1.ex] + &\bf{return} \: \theta_t \\[-1.ex] + &\rule{110mm}{0.4pt} \\[-1.ex] + \end{aligned} + For further details regarding the algorithm we refer to `Decoupled Weight Decay Regularization`_. + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + (default: False) + maximize (bool, optional): maximize the params based on the objective, instead of + minimizing (default: False) + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. 
_On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, + weight_decay=1e-2, amsgrad=False, *, maximize: bool = False): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0.0 <= weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + defaults = dict(lr=lr, betas=betas, eps=eps, + weight_decay=weight_decay, amsgrad=amsgrad, maximize=maximize) + super(AdamW, self).__init__(params, defaults) + + def __setstate__(self, state): + super(AdamW, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('amsgrad', False) + group.setdefault('maximize', False) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + state_sums = [] + max_exp_avg_sqs = [] + state_steps = [] + amsgrad = group['amsgrad'] + beta1, beta2 = group['betas'] + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError('AdamW does not support sparse gradients') + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + + if amsgrad: + max_exp_avg_sqs.append(state['max_exp_avg_sq']) + + # update the steps for each param group update + state['step'] += 1 + # record the step after step update + state_steps.append(state['step']) + + # adamw_torch(params_with_grad, + adamw(params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps'], + maximize=group['maximize']) + + return loss \ No newline at end of file diff --git a/megatron/optimizer/clip_grads.py b/megatron/optimizer/clip_grads.py index e125fddb9b00db31bbfa0ed85cad00a4ee8b59a5..2619139d1f8c014b8786edfa27e5760f720b6967 100644 --- a/megatron/optimizer/clip_grads.py +++ b/megatron/optimizer/clip_grads.py @@ -45,13 +45,9 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2): Arguments: parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a single Tensor that will have gradients normalized - grads_for_norm (Iterable[Tensor]): an iterable of Tensors or a single - Tensor that will be used for calculating the grad norm. 
max_norm (float or int): max norm of the gradients norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for infinity norm. - model_parallel_group (group): given the nature of the distributed - optimizer, this is passed as an argument. Returns: Total norm of the parameters (viewed as a single vector). @@ -60,7 +56,10 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2): if isinstance(parameters, torch.Tensor): parameters = [parameters] - # Grads. + # Filter parameters based on: + # - grad should not be none + # - parameter should not be shared + # - should not be a replica due to tensor model parallelism grads = [] grads_for_norm = [] for param in parameters: @@ -70,11 +69,11 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2): grad = param.grad.detach() if grad_not_none: # Make sure the grads are in fp32 + assert param.grad.type() == 'torch.{}.FloatTensor'.format(get_accelerator().device_name()) grads.append(grad) if grad_not_none and is_not_shared and is_not_tp_duplicate: grads_for_norm.append(grad) - # Norm parameters. max_norm = float(max_norm) norm_type = float(norm_type) @@ -83,16 +82,23 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2): # Calculate norm. if norm_type == inf: total_norm = max(grad.abs().max() for grad in grads_for_norm) - total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) + total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]) # Take max across all model-parallel GPUs. torch.distributed.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=mpu.get_model_parallel_group()) total_norm = total_norm_cuda[0].item() + else: - for grad in grads_for_norm: - grad_norm = torch.norm(grad, norm_type) - total_norm += grad_norm ** norm_type + if norm_type == 2.0: + grad_norm = torch.norm(grads_for_norm, p=2.0) + # Since we will be summing across data parallel groups, + # we need the pow(norm-type). + total_norm = grad_norm ** norm_type + else: + for grad in grads_for_norm: + grad_norm = torch.norm(grad, norm_type) + total_norm += grad_norm ** norm_type # Sum across all model-parallel GPUs. torch.distributed.all_reduce(total_norm, @@ -103,8 +109,8 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2): # Scale. 
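Annotation: each tensor-parallel rank only sees its shard of the gradients, so the 2-norm path above accumulates `norm ** 2` locally and all-reduces that sum across the model-parallel group before taking the root. Stock `torch.norm` expects a single tensor rather than a Python list, so the `torch.norm(grads_for_norm, p=2.0)` call presumably relies on an NPU-side extension; a portable sketch of the same reduction, with a hypothetical `model_parallel_group` handle and an initialized process group assumed, looks like this. The clipping coefficient that follows then divides `max_norm` by this total norm.

```python
import torch
import torch.distributed as dist

def global_grad_l2_norm(local_grads, model_parallel_group=None):
    """Sum of squared per-tensor norms locally, all-reduce, then one sqrt."""
    total_norm_sq = torch.zeros(1, dtype=torch.float32,
                                device=local_grads[0].device)
    for grad in local_grads:
        total_norm_sq += torch.norm(grad, 2.0) ** 2

    # Every model-parallel rank contributes its shard's share of the norm.
    dist.all_reduce(total_norm_sq, op=dist.ReduceOp.SUM,
                    group=model_parallel_group)
    return total_norm_sq.sqrt().item()
```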
clip_coeff = max_norm / (total_norm + 1.0e-6) if clip_coeff < 1.0: - for p in parameters: - p.grad.detach().mul_(clip_coeff) + for g in grads: + g.detach().mul_(clip_coeff.to(g.device)) return total_norm diff --git a/megatron/optimizer/optimizer.py b/megatron/optimizer/optimizer.py index 8c5f7a7db744e03336a9816b41bbd368706d974d..55f9ac61485640f0b2e235bf4475c841c6f53d9c 100644 --- a/megatron/optimizer/optimizer.py +++ b/megatron/optimizer/optimizer.py @@ -26,6 +26,8 @@ from megatron import mpu from megatron import print_rank_0 from deepspeed.accelerator import get_accelerator from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32 +from ..model.module import param_is_not_shared +from ..mpu.layers import param_is_not_tensor_parallel_duplicate def _zero_grad_group_helper(group, set_to_none): @@ -87,9 +89,36 @@ class MegatronOptimizer(ABC): return params + def get_main_grads_for_grad_norm(self): + + # Filter parameters based on: + # - grad should not be none + # - parameter should not be shared + # - should not be a replica due to tensor model parallelism + params = self.get_parameters() + grads_for_norm = [] + for param in params: + grad = param.grad + grad_not_none = grad is not None + is_not_shared = param_is_not_shared(param) + is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param) + if grad_not_none and is_not_shared and is_not_tp_duplicate: + grads_for_norm.append(grad) + + return grads_for_norm + + + def get_model_parallel_group(self): + """Default returned here, but the distributed optimizer overrides this.""" + return mpu.get_model_parallel_group() + + def clip_grad_norm(self, clip_grad): params = self.get_parameters() - return clip_grad_norm_fp32(params, clip_grad) + grads_for_norm = self.get_main_grads_for_grad_norm() + return clip_grad_norm_fp32( + params, grads_for_norm, clip_grad, + model_parallel_group=self.get_model_parallel_group()) def count_zeros(self): @@ -239,15 +268,11 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer): # For all the parameters in this group: for i, param in enumerate(param_group['params']): if param.requires_grad: - if param.type() == "torch.cuda.HalfTensor": - param_type = "torch.npu.HalfTensor" - elif param.type() == "torch.cuda.BFloat16Tensor": - param_type = "torch.npu.BFloat16Tensor" - elif param.type() == "torch.cuda.FloatTensor": - param_type = "torch.npu.FloatTensor" # float16 params: - if param_type in ['torch.{}.HalfTensor'.format(get_accelerator().device_name()), + + + if param.type() in ['torch.{}.HalfTensor'.format(get_accelerator().device_name()), 'torch.{}.BFloat16Tensor'.format(get_accelerator().device_name())]: float16_params_this_group.append(param) # Create a copy @@ -266,7 +291,7 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer): = self.optimizer.state.pop(param) # fp32 params. 
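Annotation: `Float16OptimizerWithFloat16Params` now trusts `param.type()` directly instead of remapping `torch.cuda.*` names to `torch.npu.*`, keeping an fp32 "main" copy of every half/bfloat16 parameter while genuine fp32 parameters stay in place (the `# fp32 params.` branch that follows). A reduced sketch of that grouping, assuming a generic `device_name` from `get_accelerator().device_name()` and omitting the tensor-parallel attribute copying and optimizer-state handover the real code performs:

```python
def split_param_groups(param_group, device_name):
    """Separate fp16/bf16 params (which get fp32 main copies) from fp32 params."""
    float16_params, fp32_from_float16, fp32_params = [], [], []
    half_types = {'torch.{}.HalfTensor'.format(device_name),
                  'torch.{}.BFloat16Tensor'.format(device_name)}

    for param in param_group['params']:
        if not param.requires_grad:
            continue
        if param.type() in half_types:
            float16_params.append(param)
            # fp32 master weights: the optimizer math runs on these copies.
            fp32_from_float16.append(param.detach().clone().float())
        elif param.type() == 'torch.{}.FloatTensor'.format(device_name):
            fp32_params.append(param)
    return float16_params, fp32_from_float16, fp32_params
```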
- elif param_type == 'torch.{}.FloatTensor'.format(format(get_accelerator().device_name())): + elif param.type() == 'torch.{}.FloatTensor'.format(format(get_accelerator().device_name())): fp32_params_this_group.append(param) param_group['params'][i] = param @@ -308,7 +333,6 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer): for model_group, main_group in zip(self.float16_groups, self.fp32_from_float16_groups): for model_param, main_param in zip(model_group, main_group): - # if self.params_have_main_grad: if self.params_have_main_grad: main_param.grad = model_param.main_grad.float() else: diff --git a/megatron/p2p_communication.py b/megatron/p2p_communication.py index 07e66c8019d2355e27f5afec03b483887aa219b0..e9a10c0d097a65095807f9e76588f829cf729a78 100644 --- a/megatron/p2p_communication.py +++ b/megatron/p2p_communication.py @@ -21,10 +21,12 @@ from megatron import get_args from megatron import mpu -def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next, - use_ring_exchange=False): - """Communicate tensors between stages. Used as helper method in other - communication methods that are used in megatron/schedules.py. +def _communicate_shapes(tensor_send_next, tensor_send_prev, + recv_prev, recv_next): + """Communicate tensor shapes between stages. Used to communicate + tensor shapes before the actual tensor communication happens. + This is required when the sequence lengths across micro batches + are not uniform. Takes the following arguments: tensor_send_next: tensor to send to next rank (no tensor sent if @@ -35,40 +37,148 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next, previous rank. recv_next: boolean for whether tensor should be received from next rank. - use_ring_exchange: boolean for whether torch.distributed.ring_exchange() - API should be used. 
- Returns: - (tensor_recv_prev, tensor_recv_next) + (recv_prev_shape, recv_next_shape) """ + + args = get_args() + recv_prev_shape_tensor = None + recv_next_shape_tensor = None + send_prev_shape_tensor = None + send_next_shape_tensor = None + if recv_prev: + recv_prev_shape_tensor = torch.empty((3), + device=torch.cuda.current_device(), + dtype=torch.int64) + if recv_next: + recv_next_shape_tensor = torch.empty((3), + device=torch.cuda.current_device(), + dtype=torch.int64) + if tensor_send_prev is not None: + send_prev_shape_tensor = torch.tensor(tensor_send_prev.size(), + device=torch.cuda.current_device(), + dtype=torch.int64) + if tensor_send_next is not None: + send_next_shape_tensor = torch.tensor(tensor_send_next.size(), + device=torch.cuda.current_device(), + dtype=torch.int64) + + if args.use_ring_exchange_p2p: + torch.distributed.ring_exchange(tensor_send_prev=send_prev_shape_tensor, + tensor_recv_prev=recv_prev_shape_tensor, + tensor_send_next=send_next_shape_tensor, + tensor_recv_next=recv_next_shape_tensor, + group=mpu.get_pipeline_model_parallel_group()) + else: + ops = [] + if send_prev_shape_tensor is not None: + send_prev_op = torch.distributed.P2POp( + torch.distributed.isend, send_prev_shape_tensor, + mpu.get_pipeline_model_parallel_prev_rank()) + ops.append(send_prev_op) + if recv_prev_shape_tensor is not None: + recv_prev_op = torch.distributed.P2POp( + torch.distributed.irecv, recv_prev_shape_tensor, + mpu.get_pipeline_model_parallel_prev_rank()) + ops.append(recv_prev_op) + if send_next_shape_tensor is not None: + send_next_op = torch.distributed.P2POp( + torch.distributed.isend, send_next_shape_tensor, + mpu.get_pipeline_model_parallel_next_rank()) + ops.append(send_next_op) + if recv_next_shape_tensor is not None: + recv_next_op = torch.distributed.P2POp( + torch.distributed.irecv, recv_next_shape_tensor, + mpu.get_pipeline_model_parallel_next_rank()) + ops.append(recv_next_op) + if len(ops) > 0: + reqs = torch.distributed.batch_isend_irecv(ops) + for req in reqs: + req.wait() + + # To protect against race condition when using batch_isend_irecv(). + # should take this out once the bug with batch_isend_irecv is resolved. + torch.cuda.synchronize() + + recv_prev_shape = [0, 0, 0] + if recv_prev_shape_tensor is not None: + recv_prev_shape = recv_prev_shape_tensor.tolist() + + recv_next_shape = [0, 0, 0] + if recv_next_shape_tensor is not None: + recv_next_shape = recv_next_shape_tensor.tolist() + + return recv_prev_shape, recv_next_shape + + +def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next, tensor_shape, dtype_=None): args = get_args() # Create placeholder tensors for receive in forward and backward directions # if needed. tensor_recv_prev = None tensor_recv_next = None - tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size) - if args.scatter_gather_tensors_in_pipeline: - tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1) // \ - mpu.get_tensor_model_parallel_world_size() + + # Some legacy inference code doesn't set the tensor shape, do so now + # for the normal values for gpt/bert. This could be removed if inference + # code is changed to provide tensor_shape. 
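Annotation: the point of `_communicate_shapes` is that with variable sequence lengths the receiver cannot pre-allocate its buffer, so a tiny 3-element int64 shape tensor is exchanged first and the payload buffer is sized from it. A blocking, single-direction sketch of that idea (the patch instead batches `isend`/`irecv` ops in both directions); `src_rank` and `recv_with_dynamic_shape` are illustrative names, not part of the patch:

```python
import torch
import torch.distributed as dist

def recv_with_dynamic_shape(src_rank, device, dtype=torch.float16):
    """Receive a tensor whose shape is not known statically (conceptual sketch)."""
    # Step 1: receive the 3-element shape descriptor.
    shape = torch.empty(3, dtype=torch.int64, device=device)
    dist.recv(shape, src=src_rank)

    # Step 2: allocate the real buffer and receive the payload into it.
    buffer = torch.empty(tuple(shape.tolist()), dtype=dtype, device=device)
    dist.recv(buffer, src=src_rank)
    return buffer
```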
+ if not args.variable_seq_lengths: + if tensor_shape is None: + recv_prev_shape = (args.seq_length, args.micro_batch_size, args.hidden_size) + recv_next_shape = (args.seq_length, args.micro_batch_size, args.hidden_size) + else: + recv_prev_shape = tensor_shape + recv_next_shape = tensor_shape else: - tensor_chunk_shape = tensor_shape + recv_prev_shape, recv_next_shape = \ + _communicate_shapes(tensor_send_next, + tensor_send_prev, + recv_prev, + recv_next) + + override_scatter_gather_tensors_in_pipeline = False + if args.scatter_gather_tensors_in_pipeline and \ + not args.sequence_parallel: + recv_prev_chunk_shape = reduce(operator.mul, recv_prev_shape, 1) + recv_next_chunk_shape = reduce(operator.mul, recv_next_shape, 1) + if recv_prev_chunk_shape % mpu.get_tensor_model_parallel_world_size() == 0 and \ + recv_next_chunk_shape % mpu.get_tensor_model_parallel_world_size() == 0: + recv_prev_chunk_shape = recv_prev_chunk_shape // \ + mpu.get_tensor_model_parallel_world_size() + recv_next_chunk_shape = recv_next_chunk_shape // \ + mpu.get_tensor_model_parallel_world_size() + else: + recv_prev_chunk_shape = recv_prev_shape + recv_next_chunk_shape = recv_next_shape + override_scatter_gather_tensors_in_pipeline = True + else: + recv_prev_chunk_shape = recv_prev_shape + recv_next_chunk_shape = recv_next_shape + dtype = args.params_dtype if args.fp32_residual_connection: dtype = torch.float + + requires_grad = True + if dtype_ is not None: + dtype = dtype_ + requires_grad = False + if recv_prev: - tensor_recv_prev = torch.empty(tensor_chunk_shape, - requires_grad=True, - device=get_accelerator().current_device_name(), + tensor_recv_prev = torch.empty(recv_prev_chunk_shape, + requires_grad=requires_grad, + device=torch.cuda.current_device(), dtype=dtype) if recv_next: - tensor_recv_next = torch.empty(tensor_chunk_shape, - requires_grad=True, - device=get_accelerator().current_device_name(), + tensor_recv_next = torch.empty(recv_next_chunk_shape, + requires_grad=requires_grad, + device=torch.cuda.current_device(), dtype=dtype) # Split tensor into smaller chunks if using scatter-gather optimization. - if args.scatter_gather_tensors_in_pipeline: + if not override_scatter_gather_tensors_in_pipeline and \ + args.scatter_gather_tensors_in_pipeline and \ + not args.sequence_parallel: if tensor_send_next is not None: tensor_send_next = mpu.split_tensor_into_1d_equal_chunks(tensor_send_next) @@ -76,7 +186,7 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next, tensor_send_prev = mpu.split_tensor_into_1d_equal_chunks(tensor_send_prev) # Send tensors in both the forward and backward directions as appropriate. - if use_ring_exchange: + if args.use_ring_exchange_p2p: torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev, tensor_recv_prev=tensor_recv_prev, tensor_send_next=tensor_send_next, @@ -108,22 +218,28 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next, reqs = torch.distributed.batch_isend_irecv(ops) for req in reqs: req.wait() - # To protect against race condition when using batch_isend_irecv(). - get_accelerator().synchronize() + # To protect against race condition when using batch_isend_irecv(). + torch.cuda.synchronize() # If using scatter-gather optimization, gather smaller chunks. 
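Annotation: the scatter-gather optimization sends only `1/tp_world_size` of each activation over the pipeline link and rebuilds the full tensor on the receiving side; the new `override_scatter_gather_tensors_in_pipeline` flag simply falls back to full-tensor sends when the element count is not divisible by the TP world size. A sketch of the split/gather pair standing in for `mpu.split_tensor_into_1d_equal_chunks` and `mpu.gather_split_1d_tensor` (assumes an initialized tensor-parallel group); the gather step shown next in the patch reverses the split in exactly this way:

```python
import torch
import torch.distributed as dist

def split_into_1d_chunk(tensor, tp_rank, tp_world_size):
    """Return this rank's contiguous 1/tp_world_size slice of the flattened tensor."""
    flat = tensor.view(-1)
    chunk_numel = flat.numel() // tp_world_size   # caller checks divisibility
    return flat[tp_rank * chunk_numel:(tp_rank + 1) * chunk_numel]

def gather_1d_chunks(chunk, full_shape, tp_group=None):
    """All-gather the chunks from every TP rank and restore the original shape."""
    world_size = dist.get_world_size(group=tp_group)
    gathered = torch.empty(world_size * chunk.numel(),
                           dtype=chunk.dtype, device=chunk.device)
    # Same private API the patch uses in timers.py for all-gathering into one buffer.
    torch.distributed._all_gather_base(gathered, chunk, group=tp_group)
    return gathered.view(full_shape)
```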
- if args.scatter_gather_tensors_in_pipeline: + if not override_scatter_gather_tensors_in_pipeline and \ + args.scatter_gather_tensors_in_pipeline and \ + not args.sequence_parallel: if recv_prev: tensor_recv_prev = mpu.gather_split_1d_tensor( - tensor_recv_prev).view(tensor_shape).requires_grad_() + tensor_recv_prev).view(recv_prev_shape).requires_grad_() + tensor_recv_prev = mpu.utils.make_viewless_tensor(tensor_recv_prev, + requires_grad=True, + keep_graph=False) if recv_next: tensor_recv_next = mpu.gather_split_1d_tensor( - tensor_recv_next).view(tensor_shape).requires_grad_() + tensor_recv_next).view(recv_next_shape).requires_grad_() + tensor_recv_next = mpu.utils.make_viewless_tensor(tensor_recv_next, + requires_grad=True, + keep_graph=False) return tensor_recv_prev, tensor_recv_next - - def recv_forward(timers=None): """Receive tensor from previous rank in pipeline (forward receive).""" if mpu.is_pipeline_first_stage(): diff --git a/megatron/schedules.py b/megatron/schedules.py index 97dbdae3812f55d38fa5e512c70bc7aedb84c6bf..4e6d4c77ce9182291e3f8a520a85034f0901888f 100644 --- a/megatron/schedules.py +++ b/megatron/schedules.py @@ -15,7 +15,7 @@ from contextlib import contextmanager import torch -import torch_npu +from torch.autograd.variable import Variable from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP from megatron import get_args @@ -58,83 +58,138 @@ def get_forward_backward_func(): forward_backward_func = forward_backward_no_pipelining return forward_backward_func - -def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced): +def custom_backward(output, grad_output): + '''Directly call C++ autograd engine. + + To make the 'deallocate_output_tensor' (above) optimization work, the C++ + autograd engine must be called directly, bypassing Pytorch's + torch.autograd.backward. Pytorch's 'backward' checks that the output and + grad have the same shape, while C++'s 'backward' does not. + ''' + + assert output.numel() == 1, \ + "output should be pseudo-'freed' in schedule, to optimize memory" + assert isinstance(output, torch.Tensor), \ + "output == '%s'." % type(output).__name__ + assert isinstance(grad_output, (torch.Tensor, type(None))), \ + "grad_output == '%s'." % type(grad_output).__name__ + + # Handle scalar output + if grad_output is None: + assert output.numel() == 1, "implicit grad requires scalar output." + grad_output = torch.ones_like( + output, + memory_format = torch.preserve_format, + ) + + # Call c++ engine [ see torch/csrc/autograd/python_engine.cpp ] + Variable._execution_engine.run_backward( + tensors = (output,), + grad_tensors = (grad_output,), + keep_graph = False, + create_graph = False, + inputs = tuple(), + allow_unreachable=True, + accumulate_grad=True, + ) + + +def forward_step(forward_step_func, + data_iterator, + model, + input_tensor, + forward_data_store, + timers, + collect_non_loss_data=False): """Forward step for passed-in model. If first stage, input tensor is obtained from data_iterator, otherwise passed-in input_tensor is used. - Returns output tensor. 
- """ - timers = get_timers() - + Returns output tensor.""" args = get_args() - timers('forward-compute').start() + if timers is not None: + timers('forward-compute', log_level=2).start() unwrapped_model = unwrap_model( model, (torchDDP, LocalDDP, Float16Module)) - if not args.deepspeed: - unwrapped_model.set_input_tensor(input_tensor) - else: - unwrapped_model.module.set_input_tensor(input_tensor) - # Note: it's recommended to NOT add any new argument to forward_step_func() - # because it is an abstract API used by many different models and tasks. - # Changing this API requires changing it in all models/tasks. Instead, - # it's recommended to use args to pass additional arguments. + unwrap_output_tensor = False + if not isinstance(input_tensor, list): + input_tensor = [input_tensor] + unwrap_output_tensor = True + + unwrapped_model.set_input_tensor(input_tensor) output_tensor, loss_func = forward_step_func(data_iterator, model) if mpu.is_pipeline_last_stage(): - output_tensor = loss_func(output_tensor) - loss, loss_reduced = output_tensor - if not args.no_pipeline_parallel: + if not collect_non_loss_data: + output_tensor = loss_func(output_tensor) + loss, loss_reduced = output_tensor output_tensor = loss / get_num_microbatches() + forward_data_store.append(loss_reduced) else: - output_tensor = loss - losses_reduced.append(loss_reduced) - timers('forward-compute').stop() - - return output_tensor + data = loss_func(output_tensor, non_loss_data=True) + forward_data_store.append(data) + if timers is not None: + timers('forward-compute').stop() -def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad, model=None): - """Backward step through passed-in output tensor. + if unwrap_output_tensor: + return output_tensor + return [output_tensor] - If last stage, output_tensor_grad is None, otherwise gradient of loss - with respect to stage's output tensor. - Returns gradient of loss with respect to input tensor (None if first - stage).""" +def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad, timers): args = get_args() - - if args.deepspeed: - assert model is not None - - timers = get_timers() - timers('backward-compute').start() + if timers is not None: + timers('backward-compute', log_level=2).start() # Retain the grad on the input_tensor. - if input_tensor is not None: - input_tensor.retain_grad() - + unwrap_input_tensor_grad = False + if not isinstance(input_tensor, list): + input_tensor = [input_tensor] + unwrap_input_tensor_grad = True + for x in input_tensor: + if x is not None: + x.retain_grad() + + if not isinstance(output_tensor, list): + output_tensor = [output_tensor] + if not isinstance(output_tensor_grad, list): + output_tensor_grad = [output_tensor_grad] + + # Backward pass. clear_npu_overflow_flag() - if args.deepspeed: - model.backward(output_tensor) - else: - # Backward pass. - if output_tensor_grad is None: - output_tensor = optimizer.scale_loss(output_tensor) - torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad) + if output_tensor_grad[0] is None: + output_tensor = optimizer.scale_loss(output_tensor[0]) + custom_backward(output_tensor[0], output_tensor_grad[0]) # Collect the grad of the input_tensor. 
- input_tensor_grad = None + input_tensor_grad = [None] if input_tensor is not None: - input_tensor_grad = input_tensor.grad + input_tensor_grad = [] + for x in input_tensor: + if x is None: + input_tensor_grad.append(None) + else: + input_tensor_grad.append(x.grad) - timers('backward-compute').stop() + # Handle single skip connection if it exists (encoder_hidden_state in + # model with encoder and decoder). + if mpu.get_pipeline_model_parallel_world_size() > 1 and \ + mpu.is_pipeline_stage_after_split() and \ + args.model_type == ModelType.encoder_and_decoder: + if output_tensor_grad[1] is not None: + input_tensor_grad[-1].add_(output_tensor_grad[1]) + if unwrap_input_tensor_grad: + input_tensor_grad = input_tensor_grad[0] + + if timers is not None: + timers('backward-compute').stop() return input_tensor_grad + @contextmanager def dummy_handler(): try: @@ -143,43 +198,44 @@ def dummy_handler(): pass -def forward_backward_no_pipelining(forward_step_func, data_iterator, model, - optimizer, timers, forward_only): - """Run forward and backward passes with no pipeline parallelism - (no inter-stage communication). - - Returns dictionary with losses.""" +def forward_backward_no_pipelining(forward_step_func, data_iterator, model, optimizer, timers, forward_only, + collect_non_loss_data=False): assert len(model) == 1 model = model[0] context_handler = dummy_handler - if isinstance(model, torchDDP): + if isinstance(model, torch.nn.parallel.distributed.DistributedDataParallel): context_handler = model.no_sync - losses_reduced = [] + forward_data_store = [] input_tensor, output_tensor_grad = None, None overflow_flag_all = False with context_handler(): for i in range(get_num_microbatches() - 1): - output_tensor = forward_step(forward_step_func, data_iterator, model, - input_tensor, losses_reduced) + output_tensor = forward_step(forward_step_func, data_iterator, + model, input_tensor, forward_data_store, + timers, collect_non_loss_data) if not forward_only: backward_step(optimizer, input_tensor, output_tensor, - output_tensor_grad) + output_tensor_grad, timers) overflow_flag = get_npu_overflow_flag() overflow_flag_all = overflow_flag or overflow_flag_all - output_tensor = forward_step(forward_step_func, data_iterator, model, - input_tensor, losses_reduced) + + # Run computation for last microbatch out of context handler (want to + # synchronize gradients). + output_tensor = forward_step(forward_step_func, data_iterator, + model, input_tensor, forward_data_store, + timers, collect_non_loss_data) if not forward_only: - backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad) + backward_step(optimizer, input_tensor, output_tensor, + output_tensor_grad, timers) overflow_flag = get_npu_overflow_flag() overflow_flag_all = overflow_flag or overflow_flag_all - if overflow_flag_all: set_npu_overflow_flag() - return losses_reduced + return forward_data_store def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator, model, @@ -188,7 +244,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat communication between pipeline stages as needed. 
Returns dictionary with losses if the last stage, empty dict otherwise.""" - + input_tensors = [[] for _ in range(len(model))] output_tensors = [[] for _ in range(len(model))] losses_reduced = [] diff --git a/megatron/timers.py b/megatron/timers.py new file mode 100644 index 0000000000000000000000000000000000000000..03810cbd20fd74a01693261b91a889ec481fba7a --- /dev/null +++ b/megatron/timers.py @@ -0,0 +1,283 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +"""Megatron timers.""" + +from abc import ABC +from abc import abstractmethod +import time + +import torch + + + +class TimerBase(ABC): + + def __init__(self, name): + self.name = name + + @abstractmethod + def start(self, barrier=False): + pass + + @abstractmethod + def stop(self, barrier=False): + pass + + @abstractmethod + def reset(self): + pass + + @abstractmethod + def elapsed(self, reset=True, barrier=False): + pass + + + +class DummyTimer(TimerBase): + + def __init__(self): + super().__init__('dummy timer') + + def start(self, barrier=False): + return + + def stop(self, barrier=False): + return + + def reset(self): + return + + def elapsed(self, reset=True, barrier=False): + raise Exception('dummy timer should not be used to ' + 'calculate elapsed time') + + + +class Timer(TimerBase): + """ + Comment on using `barrier`: If this flag is passed, then all + the caller processes will wait till all reach the timing routine. + It is up to the user to make sure all the ranks in `barrier_group` + call it otherwise, it will result in a hang. + Comment on `barrier_group`: By default it is set to None which + in torch distributed land, it will result in the global communicator. + """ + + def __init__(self, name): + super().__init__(name) + self._elapsed = 0.0 + self._started = False + # Note that None will default to the global process group + self._barrier_group = None + self._start_time = time.time() + + + def set_barrier_group(self, barrier_group): + self._barrier_group = barrier_group + + + def start(self, barrier=False): + """Start the timer.""" + assert not self._started, 'timer has already been started' + if barrier: + torch.distributed.barrier(group=self._barrier_group) + torch.cuda.synchronize() + self._start_time = time.time() + self._started = True + + + def stop(self, barrier=False): + """Stop the timer.""" + assert self._started, 'timer is not started' + if barrier: + torch.distributed.barrier(group=self._barrier_group) + torch.cuda.synchronize() + self._elapsed += (time.time() - self._start_time) + self._started = False + + + def reset(self): + """Reset timer.""" + self._elapsed = 0.0 + self._started = False + + + def elapsed(self, reset=True, barrier=False): + """Calculate the elapsed time.""" + _started = self._started + # If the timing in progress, end it first. + if self._started: + self.stop(barrier=barrier) + # Get the elapsed time. + _elapsed = self._elapsed + # Reset the elapsed time + if reset: + self.reset() + # If timing was in progress, set it back. 
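Annotation: the new `megatron/timers.py` follows upstream Megatron-LM's API, although as written here `Timers.__call__` accepts `log_level` without consulting it, so the `DummyTimer` path is never taken. A hypothetical standalone usage matching the `timers('name', log_level=2).start()` calls added elsewhere in this patch; `start()`/`stop()` synchronize the device, so this assumes an initialized NPU/GPU context, and `log()`/`write()` all-gather across ranks, so they must be called collectively:

```python
from megatron.timers import Timers

timers = Timers(log_level=2, log_option='minmax')

timers('forward-compute', log_level=2).start()
# ... run the forward pass ...
timers('forward-compute').stop()

# elapsed() returns seconds and resets the timer by default.
forward_seconds = timers('forward-compute').elapsed()

# Once torch.distributed is initialized, report min/max across all ranks:
# timers.log(['forward-compute'])
```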
+ if _started: + self.start(barrier=barrier) + return _elapsed + + + +class Timers: + """Group of timers.""" + + def __init__(self, log_level, log_option): + self._log_level = log_level + self._log_option = log_option + self._timers = {} + self._log_levels = {} + self._dummy_timer = DummyTimer() + self._max_log_level = 2 + + + def __call__(self, name, log_level=None): + if name not in self._timers: + self._timers[name] = Timer(name=name) + return self._timers[name] + + + def _get_elapsed_time_all_ranks(self, names, reset, barrier): + """ + Assumptions: + - All the ranks call this function. + - `names` are identical on all ranks. + If the above assumptions are not met, calling this function will + result in hang. + Arguments: + - names: list of timer names + - reset: reset the timer after recording the elapsed time + - barrier: if set, do a global barrier before time measurments + """ + + # First make sure all the callers are in sync. + if barrier: + torch.distributed.barrier() + + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + + # Here we can use gather on the rank we want to print the + # timing, however, there is no gather_base support in + # pytorch yet. It is simpler to deal with a single tensor + # and since we are only gathering a small amount of data, + # it should be ok to use all-gather instead of gather. + rank_name_to_time = torch.zeros((world_size, len(names)), + dtype=torch.float, + device=torch.cuda.current_device()) + for i, name in enumerate(names): + if name in self._timers: + # Here we don't need to pass the barrier flag as all + # the processes are already in sync. This avoids the + # issue of different timers having different barrier + # groups inside their class. + rank_name_to_time[rank, i] = self._timers[name].elapsed( + reset=reset) + + # See the note above for why we are not using gather. 
+ torch.distributed._all_gather_base(rank_name_to_time.view(-1), + rank_name_to_time[rank, :].view(-1)) + + return rank_name_to_time + + + def _get_global_min_max_time(self, names, reset, barrier, normalizer): + """Report only min and max times across all ranks.""" + + rank_name_to_time = self._get_elapsed_time_all_ranks(names, reset, + barrier) + name_to_min_max_time = {} + for i, name in enumerate(names): + rank_to_time = rank_name_to_time[:, i] + # filter out the ones we did not have any timings for + rank_to_time = rank_to_time[rank_to_time > 0.0] + # If the timer exists: + if rank_to_time.numel() > 0: + name_to_min_max_time[name] = ( + rank_to_time.min().item() / normalizer, + rank_to_time.max().item() / normalizer) + return name_to_min_max_time + + + def _get_global_min_max_time_string(self, names, reset, barrier, + normalizer, max_only): + name_to_min_max_time = self._get_global_min_max_time( + names, reset, barrier, normalizer) + if not name_to_min_max_time: + return None + output_string = '(min, max) time across ranks (ms):' + for name in name_to_min_max_time: + min_time, max_time = name_to_min_max_time[name] + if max_only: + output_string += '\n {}: {:.2f}'.format( + (name+' ').ljust(48, '.'), max_time) + else: + output_string += '\n {}: ({:.2f}, {:.2f})'.format( + (name+' ').ljust(48, '.'), min_time, max_time) + return output_string + + + def _get_all_ranks_time_string(self, names, reset, barrier, normalizer): + """Report times across all ranks.""" + rank_name_to_time = self._get_elapsed_time_all_ranks(names, reset, + barrier) + + output_string = 'times across ranks (ms):' + no_reported_timing = True + for i, name in enumerate(names): + not_yet_found = True + for rank in range(torch.distributed.get_world_size()): + if rank_name_to_time[rank, i] > 0: + no_reported_timing = False + if not_yet_found: + not_yet_found = False + output_string += '\n {}:'.format(name) + output_string += '\n rank {:2d}: {:.2f}'.format( + rank, rank_name_to_time[rank, i] / normalizer) + if no_reported_timing: + return None + return output_string + + + def log(self, names, rank=None, normalizer=1.0, reset=True, barrier=False): + """Log a group of timers.""" + + # Print. + assert normalizer > 0.0 + if self._log_option in ['max', 'minmax']: + max_only = False + if self._log_option == 'max': + max_only = True + output_string = self._get_global_min_max_time_string( + names, reset, barrier, normalizer/1000.0, max_only) + elif self._log_option == 'all': + output_string = self._get_all_ranks_time_string(names, + reset, barrier, + normalizer/1000.0) + else: + raise Exception('unknown timing log option {}'.format( + self._log_option)) + + # If no input rank is provided, log on last rank. + if rank is None: + rank = torch.distributed.get_world_size() - 1 + if rank == torch.distributed.get_rank() and output_string is not None: + print(output_string, flush=True) + + + def write(self, names, writer, iteration, normalizer=1.0, + reset=False, barrier=False): + """Write timers to a tensorboard writer + Note that we only report maximum time across ranks to tensorboard. 
+ """ + # currently when using add_scalars, + # torch.utils.add_scalars makes each timer its own run, which + # polutes the runs list, so we just add each as a scalar + assert normalizer > 0.0 + name_to_min_max_time = self._get_global_min_max_time( + names, reset, barrier, normalizer) + if writer is not None: + for name in name_to_min_max_time: + _, max_time = name_to_min_max_time[name] + writer.add_scalar(name + '-time', max_time, iteration) diff --git a/megatron/tokenizer/tokenizer.py b/megatron/tokenizer/tokenizer.py index 13085a81c9feb4f4ba2ffe3f254a3bd7d3f4c144..b260c247551ccacaf36a77fabd253360629ed996 100644 --- a/megatron/tokenizer/tokenizer.py +++ b/megatron/tokenizer/tokenizer.py @@ -17,7 +17,7 @@ from abc import ABC from abc import abstractmethod - +from transformers import AutoTokenizer from .bert_tokenization import FullTokenizer as FullBertTokenizer from .gpt2_tokenization import GPT2Tokenizer @@ -29,7 +29,7 @@ def build_tokenizer(args): flush=True) # Select and instantiate the tokenizer. - assert args.vocab_file is not None + assert args.vocab_file is not None or args.tokenizer_type == 'PretrainedFromHF' if args.tokenizer_type == 'BertWordPieceLowerCase': tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file, lower_case=True, @@ -41,6 +41,21 @@ def build_tokenizer(args): elif args.tokenizer_type == 'GPT2BPETokenizer': assert args.merge_file is not None tokenizer = _GPT2BPETokenizer(args.vocab_file, args.merge_file) + elif args.tokenizer_type == "PretrainedFromHF": + assert args.tokenizer_name_or_path is not None + + # prevent transformers from logging info and warnings on each rank + import transformers + import logging + if args.rank == 0: + transformers.utils.logging.set_verbosity(logging.INFO) + else: + # shut the warnings on replicas + transformers.utils.logging.set_verbosity(logging.ERROR) + + if args.rank == 0: + print(" vocab file is un-used. loading tokenizer from pre-trained model") + tokenizer = _AutoTokenizer(args.tokenizer_name_or_path, vocab_extra_ids=args.vocab_extra_ids) else: raise NotImplementedError('{} tokenizer is not ' 'implemented.'.format(args.tokenizer_type)) @@ -53,14 +68,25 @@ def build_tokenizer(args): def _vocab_size_with_padding(orig_vocab_size, args): - """Pad vocab size so it is divisible by model parallel size and - still having GPU friendly size.""" - - after = orig_vocab_size - multiple = args.make_vocab_size_divisible_by * \ - args.tensor_model_parallel_size - while (after % multiple) != 0: - after += 1 + """Apply the requested rules to change the size of the vocabulary""" + if args.pad_vocab_size_to is not None: + if args.pad_vocab_size_to < orig_vocab_size: + raise ValueError( + f"You asked to pad the vocabulary to {args.pad_vocab_size_to} when the initial vocabulary size is " + f"{orig_vocab_size}. You can only pad to a higher value." + ) + + if args.make_vocab_size_divisible_by is not None and (args.pad_vocab_size_to % args.make_vocab_size_divisible_by) != 0: + raise ValueError(f"{args.pad_vocab_size_to} is not divisible by {args.make_vocab_size_divisible_by}") + + after = args.pad_vocab_size_to + else: + # Pad vocab size so it is divisible by model parallel size and still having GPU friendly size. 
+ after = orig_vocab_size + multiple = args.make_vocab_size_divisible_by * \ + args.tensor_model_parallel_size + while (after % multiple) != 0: + after += 1 if args.rank == 0: print(' > padded vocab (size: {}) with {} dummy tokens ' '(new size: {})'.format( @@ -289,3 +315,86 @@ class _GPT2BPETokenizer(AbstractTokenizer): @property def eod(self): return self.eod_id + + +class _AutoTokenizer(AbstractTokenizer): + """AutoTokenizer for Hf Pretrained model loading.""" + + def __init__(self, tokenizer_name_or_path, vocab_extra_ids): + name = tokenizer_name_or_path + super().__init__(name) + hf_tokenizer_kwargs = {} + if vocab_extra_ids > 0: + # TODO @thomasw21 we might need to concatenate to a pre-existing list? + hf_tokenizer_kwargs["additional_special_tokens"] = [f"" for _id in range(vocab_extra_ids)] + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, **hf_tokenizer_kwargs) + self.encoder = self.tokenizer.get_vocab() + self.decoder = {v: k for k, v in self.encoder.items()} + + @property + def vocab_size(self): + return len(self.tokenizer) # vocab_size doesn't contain additional tokens + + @property + def vocab(self): + # TODO @thomasw21 make sure that special tokens don't collapse with vocab tokens. + return { + **{special_token: self.tokenizer.convert_tokens_to_ids(special_token) for special_token in self.tokenizer.additional_special_tokens}, + **self.tokenizer.vocab, + } + + @property + def inv_vocab(self): + return {v: k for k, v in self.vocab.items()} + + def tokenize(self, text): + return self.tokenizer.encode(text) + + def detokenize(self, token_ids): + return self.tokenizer.decode(token_ids) + + @property + def eod(self): + # TODO @thomasw21 might conflict with + return self.eos + + @property + def cls(self): + candidate = self.tokenizer.cls_token_id + return self._check_token_candidate(candidate) + + @property + def sep(self): + candidate = self.tokenizer.sep_token_id + return self._check_token_candidate(candidate) + + @property + def pad(self): + candidate = self.tokenizer.pad_token_id + return self._check_token_candidate(candidate) + + @property + def mask(self): + candidate = self.tokenizer.mask_token_id + return self._check_token_candidate(candidate) + + @property + def bos(self): + raise NotImplementedError("Missing ") + + @property + def eos(self): + # TODO @thomasw21 might conflict with the notion of + candidate = self.tokenizer.eos_token_id + return self._check_token_candidate(candidate) + + @property + def additional_special_tokens_ids(self): + """ All the additional special tokens you may want to use (list of strings).""" + return self.tokenizer.additional_special_tokens_ids + + @staticmethod + def _check_token_candidate(candidate): + if candidate is None: + raise AttributeError("Token doesn't exist") + return candidate \ No newline at end of file diff --git a/megatron/training.py b/megatron/training.py index 36c76ec252676a2e621da5c092c30fc0ca9f29c4..fc5657b8fa3ef7b1309f56ad3c50e3c20b8e74ef 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -184,6 +184,8 @@ def pretrain(train_valid_test_dataset_provider, print_rank_0('training ...') iteration = 0 + save_checkpoint(iteration, model, optimizer, lr_scheduler) + if args.do_train and args.train_iters > 0: iteration = train(forward_step_func, model, optimizer, lr_scheduler, @@ -327,7 +329,9 @@ def get_model(model_provider_func): # GPU allocation. 
for model_module in model: - model_module.to(get_accelerator().current_device_name()) + device_name = get_accelerator().current_device_name() + print_rank_0(f"model to {device_name}") + model_module.to(device_name) # Fp16 conversion. @@ -489,6 +493,7 @@ def setup_model_and_optimizer(model_provider_func, teacher=False, # Number of train/valid/test samples. if args.train_samples: train_samples = args.train_samples + update_train_iters(args) else: train_samples = args.train_iters * args.global_batch_size # eval_iters and test_iters here are not actually used, only for @@ -522,6 +527,7 @@ def setup_model_and_optimizer(model_provider_func, teacher=False, lr_scheduler=lr_scheduler, mpu=mpu if args.no_pipeline_parallel else None ) + assert model.fp16_enabled() == args.fp16, "megatron fp16 config does not match deepspeed" if isinstance(model, deepspeed.PipelineEngine): # hack to get batch_fn from pretrain_gpt.py model.set_batch_fn(model.module._megatron_batch_fn) @@ -728,7 +734,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration, timers_to_log = [] def add_to_logging(name): - if name in timers.timers: + if name in timers._timers: timers_to_log.append(name) add_to_logging('forward-compute') add_to_logging('forward-recv') @@ -1273,6 +1279,7 @@ def build_train_valid_test_data_iterators( # Number of train/valid/test samples. if args.train_samples: train_samples = args.train_samples + update_train_iters(args) else: train_samples = args.train_iters * args.global_batch_size eval_iters = (args.train_iters // args.eval_interval + 1) * \ diff --git a/megatron/utils.py b/megatron/utils.py index e1f23ec4b682661457be5b67aec992babf3cec13..277dbc97f0e4a48598a4ac0df90e8ddb93f804e8 100644 --- a/megatron/utils.py +++ b/megatron/utils.py @@ -33,6 +33,7 @@ from megatron.model.module import param_is_not_shared from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate from megatron import get_num_microbatches from deepspeed.accelerator import get_accelerator + def unwrap_model(model, module_instances=(torchDDP)): return_list = True if not isinstance(model, list): @@ -155,14 +156,23 @@ def get_ltor_masks_and_position_ids(data, eod_token, reset_position_ids, reset_attention_mask, - eod_mask_loss): - """Build masks and position id for left to right model.""" + eod_mask_loss, + prefix_indices=None, + loss_on_targets_only=False): + """ + Build masks and position id for left to right model. + :param prefix_indices: argument can have multiple types: + - None signifies that the model is fully autoregressive. + - List[int] the argument holds all prefix indices that split a row into an input and a target + - List[List[int]] the argument holds all prefix indices that split documents between input and target. + :param loss_on_targets_only: bool to determine if we should mask loss on prefix. + """ # Extract batch size and sequence length. micro_batch_size, seq_length = data.size() # Attention mask (lower triangular). - if reset_attention_mask: + if reset_attention_mask or prefix_indices is not None: att_mask_batch = micro_batch_size else: att_mask_batch = 1 @@ -183,12 +193,20 @@ def get_ltor_masks_and_position_ids(data, if reset_position_ids: position_ids = position_ids.clone() - if reset_position_ids or reset_attention_mask: + if reset_position_ids or reset_attention_mask or prefix_indices is not None: # Loop through the batches: for b in range(micro_batch_size): # Find indecies where EOD token is. 
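Annotation: the larger change in `get_ltor_masks_and_position_ids` is the optional prefix handling: for a given `prefix_indices` entry the attention mask is opened up bidirectionally over the prefix, and with `loss_on_targets_only` the loss is zeroed over it (the last prefix token still predicts the first target, hence the `- 1`). A toy row-level illustration of the same operations, not the patch's exact `[batch, 1, seq, seq]` layout:

```python
import torch

seq_length, prefix_index = 6, 3

# Start from the usual causal (lower-triangular) mask.
attention_mask = torch.tril(torch.ones(seq_length, seq_length))

# Prefix-LM: tokens inside the prefix may attend to each other freely.
attention_mask[:prefix_index, :prefix_index] = 1

loss_mask = torch.ones(seq_length)
# loss_on_targets_only: no loss on the prefix itself.
loss_mask[:prefix_index - 1] = 0.0

# The function finally flips to a boolean "masked-out" convention.
attention_mask = attention_mask < 0.5
```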
eod_index = position_ids[b, data[b] == eod_token] + + # If the last eod token is not the last token of the sequence, we suppose that there is a partial document + # We treat this case as if we add an eod token at the end of the sequence. + if data[b][-1] != eod_token: + eod_index = torch.cat( + (eod_index, torch.tensor([len(data[b])], dtype=eod_index.dtype, device=eod_index.device)) + ) + # Detach indecies from positions if going to modify positions. if reset_position_ids: eod_index = eod_index.clone() @@ -197,13 +215,33 @@ def get_ltor_masks_and_position_ids(data, prev_index = 0 for j in range(eod_index.size()[0]): i = eod_index[j] - # Mask attention loss. + if reset_attention_mask: + # Prevent cross document interactions. attention_mask[b, 0, (i + 1):, :(i + 1)] = 0 + + # Prefix lm per document. + if prefix_indices: + assert isinstance(prefix_indices[b], list), f"prefix for a row has to be document specific, and consequently return a list, got {prefix_indices[b]}" + attention_mask[b, 0, prev_index: prefix_indices[b][j], prev_index: prefix_indices[b][j]] = 1 + if loss_on_targets_only: + # Last token of the prefix should predict the prefix_index id + loss_mask[b, prev_index: prefix_indices[b][j] - 1] = 0.0 + # Reset positions. if reset_position_ids: position_ids[b, (i + 1):] -= (i + 1 - prev_index) - prev_index = i + 1 + + prev_index = i + 1 + + # Prefix lm per row. + if prefix_indices is not None and (reset_attention_mask is False): + assert isinstance(prefix_indices[b], int), \ + f"prefix for a row has to be row specific, and consequently return an int, got {prefix_indices[b]}" + attention_mask[b, 0, :prefix_indices[b], :prefix_indices[b]] = 1 + if loss_on_targets_only: + # Last token of the prefix should predict the prefix_index id + loss_mask[b, :prefix_indices[b] - 1] = 0.0 # Convert attention mask to binary: attention_mask = (attention_mask < 0.5) diff --git a/pretrain_gpt.py b/pretrain_gpt.py index 128a39de8d5a5b158909416788af3429501192b0..c582966e89f017c9afebf892e9a0a52ec8cdc2da 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -14,32 +14,39 @@ # limitations under the License. 
"""Pretrain GPT""" - -import math -from functools import partial - +import time import torch import torch_npu from torch_npu.contrib import transfer_to_npu -import deepspeed_npu +torch.cuda.init = torch.npu.init +torch.npu.init() +torch.cuda.default_generators = torch_npu.npu.default_generators +import deepspeed +import deepspeed_npu +from functools import partial from megatron import get_args from megatron import print_rank_0 from megatron import get_timers from megatron import get_tokenizer from megatron import mpu from megatron.data.gpt_dataset import build_train_valid_test_datasets -from megatron.model import GPTModel, GPTModelPipe +from megatron.model import GPTModel, ModelType, GPTModelPipe +from megatron.enums import AttnMaskType from megatron.training import pretrain from megatron.utils import get_ltor_masks_and_position_ids from megatron.utils import average_losses_across_data_parallel_group -import deepspeed from deepspeed.runtime.utils import see_memory_usage -from deepspeed.accelerator.real_accelerator import get_accelerator +from deepspeed.accelerator import get_accelerator +import os -from torch import nn -import torch.nn.functional as F +try: + from torch.distributed.elastic.multiprocessing.errors import record +except ImportError: + # noop + def record(fn): + return fn def model_provider(pre_process=True, post_process=True): """Build the model.""" @@ -53,7 +60,8 @@ def model_provider(pre_process=True, post_process=True): config_dict_or_path=args.deepspeed_config, enabled=args.zero_stage == 3, mpu=mpu): - if args.deepspeed and not args.no_pipeline_parallel: + if args.deepspeed: + args.pretrain_causal_attention = True model = GPTModelPipe( num_tokentypes=0, parallel_output=True @@ -78,7 +86,6 @@ def model_provider(pre_process=True, post_process=True): # Attention mask must be bool. args.attn_mask = attention_mask.to(torch.bool) - else: model = GPTModel( num_tokentypes=0, @@ -107,7 +114,7 @@ def get_batch(data_iterator): data_b = mpu.broadcast_data(keys, data, datatype) # Unpack. - tokens_ = data_b['text'].long() + tokens_ = data_b['text'].int() labels = tokens_[:, 1:].contiguous() tokens = tokens_[:, :-1].contiguous() @@ -163,31 +170,33 @@ def get_batch_pipe(data): labels = tokens_[:, 1:].contiguous() tokens = tokens_[:, :-1].contiguous() - # Get the masks and postition ids. + # Get the masks and position ids. 
attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( tokens, tokenizer.eod, args.reset_position_ids, args.reset_attention_mask, - args.eod_mask_loss) - if args.curriculum_learning_legacy and args.curriculum_seqlen < tokens.size()[1]: - # seqlen-based curriculum learning - # tokens, position_ids, labels, loss_mask have size [batch size, seqlen] - tokens = tokens[:, :args.curriculum_seqlen].contiguous() - position_ids = position_ids[:, :args.curriculum_seqlen].contiguous() - if labels is not None: - labels = labels[:, :args.curriculum_seqlen].contiguous() - loss_mask = loss_mask[:, :args.curriculum_seqlen].contiguous() + args.eod_mask_loss, + prefix_indices=None, + loss_on_targets_only=args.loss_on_targets_only + ) + # TEMP_IGNORE + # if args.curriculum_learning and args.curriculum_seqlen < tokens.size()[1]: + # # seqlen-based curriculum learning + # # tokens, position_ids, labels, loss_mask have size [batch size, seqlen] + # tokens = tokens[:, :args.curriculum_seqlen].contiguous() + # position_ids = position_ids[:, :args.curriculum_seqlen].contiguous() + # labels = labels[:, :args.curriculum_seqlen].contiguous() + # loss_mask = loss_mask[:, :args.curriculum_seqlen].contiguous() return (tokens, position_ids, attention_mask), (labels, loss_mask) -def loss_func(loss_mask, moe_loss, mos_loss, output_tensor): +def loss_func(loss_mask, output_tensor): args = get_args() losses = output_tensor.float() loss_mask = loss_mask.view(-1).float() loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() - # Reduce loss for logging. averaged_loss = average_losses_across_data_parallel_group([loss]) if args.mos or args.kd: @@ -242,45 +251,15 @@ def forward_step(data_iterator, model): timers = get_timers() # Get the batch. - timers('batch-generator').start() + timers('batch-generator', log_level=2).start() tokens, labels, loss_mask, attention_mask, position_ids = get_batch( data_iterator) timers('batch-generator').stop() - if args.data_efficiency_curriculum_learning: - args.curriculum_seqlen = tokens.size()[1] - if hasattr(args, 'data_efficiency_curriculum_learning_seqlen_type') and \ - args.data_efficiency_curriculum_learning_seqlen_type == 'seqlen_reshape': - args.data_efficiency_curriculum_learning_numel = torch.numel(tokens) + output_tensor = model(tokens, position_ids, attention_mask, + labels=labels) - if args.mos or args.kd: - # The forward func can return either the loss or the logits, depending on whether passing in the labels or not. 
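Annotation: with the MoE/MoS/curriculum branches removed below, `forward_step` settles on the standard Megatron contract: return the raw model output plus a loss callback with `loss_mask` already bound, so the pipeline schedule decides when (and whether) to evaluate the loss on the last stage. A minimal sketch of that contract with a hypothetical `toy_model`; the patch's real `loss_func` additionally averages the loss across the data-parallel group:

```python
from functools import partial
import torch

def loss_func(loss_mask, output_tensor):
    losses = output_tensor.float()
    loss_mask = loss_mask.view(-1).float()
    loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
    return loss, {'lm loss': loss.detach()}

def forward_step_sketch(batch, toy_model):
    tokens, labels, loss_mask, attention_mask, position_ids = batch
    output_tensor = toy_model(tokens, position_ids, attention_mask, labels=labels)
    # The schedule calls the returned partial only on the last pipeline stage.
    return output_tensor, partial(loss_func, loss_mask)
```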
- stu_output, *other_losses = model(tokens, position_ids, attention_mask) - if args.curriculum_learning_legacy and args.curriculum_seqlen < args.seq_length: - assert args.curriculum_seqlen is not None - labels = labels[:, :args.curriculum_seqlen].contiguous() - output_tensor = mpu.vocab_parallel_cross_entropy(stu_output.contiguous().float(), labels) - else: - output_tensor, *other_losses = model(tokens, position_ids, attention_mask, - labels=labels) - if args.curriculum_learning_legacy and args.curriculum_seqlen < args.seq_length: - loss_mask = loss_mask[:, :args.curriculum_seqlen].contiguous() - - moe_losses = [] - for moe_loss in other_losses: - if moe_loss is not None: - moe_losses.append(moe_loss) - moe_loss = sum(moe_losses) * args.moe_loss_coeff - - mos_loss = 0 - if args.mos or args.kd: - assert model.training - if args.teacher_forward and args.teacher_model is not None: - mos_loss = calculate_mos_loss(args, stu_output, - args.teacher_model[0], tokens, position_ids, attention_mask) - - # Output_tensor stores the standard loss, loos_func calculates the total loss. - return output_tensor, partial(loss_func, loss_mask, moe_loss, mos_loss) + return output_tensor, partial(loss_func, loss_mask) def train_valid_test_datasets_provider(train_val_test_num_samples): @@ -289,6 +268,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples): print_rank_0('> building train, validation, and test datasets ' 'for GPT ...') + # TEMP IGNORE: xxxx_data_path is no used in original code hub train_ds, valid_ds, test_ds = build_train_valid_test_datasets( data_prefix=args.data_path, data_impl=args.data_impl, @@ -303,6 +283,10 @@ def train_valid_test_datasets_provider(train_val_test_num_samples): if __name__ == "__main__": - pretrain(train_valid_test_datasets_provider, model_provider, forward_step, - args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}, - data_post_process=data_post_process) + torch_npu.npu.set_compile_mode(jit_compile=True) + + pretrain(train_valid_test_datasets_provider, model_provider, + ModelType.encoder_or_decoder, + forward_step, + args_defaults={'tokenizer_type': 'GPT2BPETokenizer'} + ) \ No newline at end of file