From af4acfe509cddd944b79bd10919a7a1e314d05f0 Mon Sep 17 00:00:00 2001 From: fengliangjun Date: Tue, 6 Jun 2023 19:26:16 +0800 Subject: [PATCH 01/10] enable ptd mode for gpt and llama --- amp_C.py | 0 bugfix.py | 822 +++++++++++++++++ examples/gpt_task/pretrain_gpt_1p.sh | 43 + examples/gpt_task/pretrain_gpt_ptd_8p.sh | 50 + examples/gpt_task/pretrain_gpt_td_8p.sh | 49 + examples/llama_task/pretrain_llama_1p.sh | 43 + examples/llama_task/pretrain_llama_8p.sh | 49 + examples/llama_task/pretrain_llama_ptd_8p.sh | 50 + examples/llama_task/pretrain_llama_td_8p.sh | 49 + examples/pretrain_llama_1p.sh | 44 + examples/pretrain_llama_8p | 49 + examples/pretrain_llama_8p.sh | 49 + main.py | 16 + megatron/__init__.py | 153 ++++ megatron/data/indexed_dataset.py | 2 +- megatron/fused_kernels/__init__.py | 130 +++ megatron/fused_kernels/compat.h | 31 + megatron/fused_kernels/layer_norm_cuda.cpp | 201 ++++ .../fused_kernels/layer_norm_cuda_kernel.cu | 866 ++++++++++++++++++ .../fused_kernels/scaled_masked_softmax.cpp | 77 ++ .../fused_kernels/scaled_masked_softmax.h | 492 ++++++++++ .../scaled_masked_softmax_cuda.cu | 114 +++ .../scaled_upper_triang_masked_softmax.cpp | 72 ++ .../scaled_upper_triang_masked_softmax.h | 511 +++++++++++ ...scaled_upper_triang_masked_softmax_cuda.cu | 100 ++ megatron/fused_kernels/type_shim.h | 91 ++ megatron/initialize.py | 58 +- megatron/model/__init__.py | 2 +- megatron/model/bert_model.py | 236 +++++ megatron/model/biencoder_model.py | 295 ++++++ megatron/model/classification.py | 119 +++ megatron/model/distributed.py | 1 + megatron/model/fused_layer_norm.py | 23 +- megatron/model/fused_softmax.py | 319 ++++--- megatron/model/llama_model.py | 19 +- megatron/model/module.py | 14 +- megatron/model/multiple_choice.py | 130 +++ megatron/model/realm_model.py | 204 +++++ megatron/model/t5_model.py | 174 ++++ megatron/mpu/layers.py | 1 - megatron/mpu/random.py | 40 +- megatron/optimizer/clip_grads.py | 52 +- megatron/optimizer/optimizer.py | 13 +- megatron/p2p_communication.py | 10 +- megatron/schedules.py | 44 +- megatron/training.py | 3 +- pretrain_gpt.py | 5 +- pretrain_llama.py | 250 +++++ runner.py | 16 + 49 files changed, 5891 insertions(+), 290 deletions(-) create mode 100644 amp_C.py create mode 100644 bugfix.py create mode 100644 examples/gpt_task/pretrain_gpt_1p.sh create mode 100644 examples/gpt_task/pretrain_gpt_ptd_8p.sh create mode 100644 examples/gpt_task/pretrain_gpt_td_8p.sh create mode 100644 examples/llama_task/pretrain_llama_1p.sh create mode 100644 examples/llama_task/pretrain_llama_8p.sh create mode 100644 examples/llama_task/pretrain_llama_ptd_8p.sh create mode 100644 examples/llama_task/pretrain_llama_td_8p.sh create mode 100644 examples/pretrain_llama_1p.sh create mode 100644 examples/pretrain_llama_8p create mode 100644 examples/pretrain_llama_8p.sh create mode 100644 main.py create mode 100644 megatron/fused_kernels/__init__.py create mode 100644 megatron/fused_kernels/compat.h create mode 100644 megatron/fused_kernels/layer_norm_cuda.cpp create mode 100644 megatron/fused_kernels/layer_norm_cuda_kernel.cu create mode 100644 megatron/fused_kernels/scaled_masked_softmax.cpp create mode 100644 megatron/fused_kernels/scaled_masked_softmax.h create mode 100644 megatron/fused_kernels/scaled_masked_softmax_cuda.cu create mode 100644 megatron/fused_kernels/scaled_upper_triang_masked_softmax.cpp create mode 100644 megatron/fused_kernels/scaled_upper_triang_masked_softmax.h create mode 100644 megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu 
create mode 100644 megatron/fused_kernels/type_shim.h create mode 100644 megatron/model/bert_model.py create mode 100644 megatron/model/biencoder_model.py create mode 100644 megatron/model/classification.py create mode 100644 megatron/model/multiple_choice.py create mode 100644 megatron/model/realm_model.py create mode 100644 megatron/model/t5_model.py create mode 100644 runner.py diff --git a/amp_C.py b/amp_C.py new file mode 100644 index 000000000..e69de29bb diff --git a/bugfix.py b/bugfix.py new file mode 100644 index 000000000..ac3e5c420 --- /dev/null +++ b/bugfix.py @@ -0,0 +1,822 @@ +import sys +import time +import math +import copy +import numpy as np +import torch +import torch_npu +from functools import wraps + +# ====================== +# torch +# ====================== +# +# # INPLACE.1: torch.cuda.get_rng_state +# torch.cuda.get_rng_state = torch.get_rng_state +# torch.cuda.set_rng_state = torch.set_rng_state + + +# INPLACE.2: torch.Tensor.type() +def wrapper_type(fn): + @wraps(fn) + def decorated(*args, **kwargs): + output = fn(*args, **kwargs) + if isinstance(output, str): + if output == 'torch.npu.FloatTensor': + output = 'torch.cuda.FloatTensor' + elif output == 'torch.npu.HalfTensor': + output = 'torch.cuda.HalfTensor' + return output + + return decorated + + +torch.Tensor.type = wrapper_type(torch.Tensor.type) + +# INPLACE.3: torch.ditributed.xx input long --> int +from torch import distributed as dist + + +def wrapper_dist_long2int(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + if args[0].dtype == torch.long and not kwargs.get('async_op', False): + new_args = list(copy.deepcopy(args)) + new_args[0] = new_args[0].int() + fn(*new_args, **kwargs) + args[0].copy_(new_args[0].long()) + return + return fn(*args, **kwargs) + + return wrapper + + +dist.all_reduce = wrapper_dist_long2int(dist.all_reduce) +dist.broadcast = wrapper_dist_long2int(dist.broadcast) +dist.send = wrapper_dist_long2int(dist.send) +dist.recv = wrapper_dist_long2int(dist.recv) + +# ====================== +# apex +# ====================== + +# INPLACE.4: apex.optimizers +import apex + + +class AdamW(torch.optim.Optimizer): + r"""Implements AdamW algorithm. + + The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. + The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + (default: False) + + .. _Adam\: A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. 
_On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, + weight_decay=1e-2, amsgrad=False): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + defaults = dict(lr=lr, betas=betas, eps=eps, + weight_decay=weight_decay, amsgrad=amsgrad) + super(AdamW, self).__init__(params, defaults) + + def __setstate__(self, state): + super(AdamW, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('amsgrad', False) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + + # Perform stepweight decay + p.data.mul_(1 - group['lr'] * group['weight_decay']) + + # Perform optimization step + grad = p.grad + if grad.is_sparse: + raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') + amsgrad = group['amsgrad'] + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state['max_exp_avg_sq'] = torch.zeros_like(p) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + if amsgrad: + max_exp_avg_sq = state['max_exp_avg_sq'] + beta1, beta2 = group['betas'] + + state['step'] += 1 + bias_correction1 = 1 - beta1 ** state['step'] + bias_correction2 = 1 - beta2 ** state['step'] + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + if amsgrad: + # Maintains the maximum of all 2nd moment running avg. till now + torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) + # Use the max. for normalizing running avg. of gradient + denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) + else: + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) + + step_size = group['lr'] / bias_correction1 + + p.addcdiv_(exp_avg, denom, value=-step_size) + + return loss + + +apex.optimizers.FusedAdam = AdamW +apex.optimizers.FusedSGD = torch.optim.SGD + +# ====================== +# megatron +# ====================== +import megatron + +# +# # INPLACE.5: megatron.initialize._compile_dependencies +# def _compile_dependencies(): +# if torch.distributed.get_rank() == 0: +# start_time = time.time() +# print('> compiling dataset index builder ...') +# from megatron.data.dataset_utils import compile_helper +# compile_helper() +# print('>>> done with dataset index builder. 
Compilation time: {:.3f} ' +# 'seconds'.format(time.time() - start_time), flush=True) +# +# +# megatron.initialize._compile_dependencies = _compile_dependencies +# +# # INPLACE.6: fp32_to_float16, float16_to_fp32 +# from torch.autograd import Variable +# from torch.nn.parameter import Parameter +# from megatron.model.module import fp32_to_float16, float16_to_fp32, conversion_helper +# +# +# def fp32_to_float16(val, float16_convertor): +# """Convert fp32 `val` to fp16/bf16""" +# +# def half_conversion(val): +# val_typecheck = val +# if isinstance(val_typecheck, (Parameter, Variable)): +# val_typecheck = val.data +# if val_typecheck.dtype == torch.float32: +# val = float16_convertor(val) +# return val +# +# return conversion_helper(val, half_conversion) +# +# +# def float16_to_fp32(val): +# """Convert fp16/bf16 `val` to fp32""" +# +# def float_conversion(val): +# val_typecheck = val +# if isinstance(val_typecheck, (Parameter, Variable)): +# val_typecheck = val.data +# if val_typecheck.dtype in [torch.float16, torch.bfloat16]: +# val = val.float() +# return val +# +# return conversion_helper(val, float_conversion) +# +# +# megatron.model.module.fp32_to_float16 = fp32_to_float16 +# megatron.model.module.float16_to_fp32 = float16_to_fp32 + +# INPLACE.7: MixedFusedLayerNorm +# from megatron.model.fused_layer_norm import MixedFusedLayerNorm +# import numbers +# +# class MixedFusedLayerNorm(torch.nn.LayerNorm): +# def __init__(self, normalized_shape, eps=1e-5, no_persist_layer_norm=True, sequence_parallel=False): +# super(MixedFusedLayerNorm, self).__init__(normalized_shape, eps, no_persist_layer_norm) +# +# if isinstance(normalized_shape, numbers.Integral): +# normalized_shape = (normalized_shape,) +# +# self.weight = Parameter(torch.Tensor(*normalized_shape)) +# self.bias = Parameter(torch.Tensor(*normalized_shape)) +# +# # set sequence parallelism flag on weight and bias parameters +# self.sequence_parallel = sequence_parallel +# setattr(self.weight, 'sequence_parallel', self.sequence_parallel) +# setattr(self.bias, 'sequence_parallel', self.sequence_parallel) +# +# +# for k in sys.modules: +# if k.startswith('megatron.model'): +# for target in ['LayerNorm', 'MixedFusedLayerNorm']: +# if getattr(sys.modules[k], target, None): +# setattr(sys.modules[k], target, MixedFusedLayerNorm) + +# # INPLACE.8: _unscale_main_grads_and_check_for_nan +# from megatron.optimizer import Float16OptimizerWithFloat16Params +# +# +# def _unscale_main_grads_and_check_for_nan(self): +# # Collect main grads. +# main_grads = self._collect_main_grad_data_for_unscaling() +# +# # Reset found inf. +# self.found_inf.fill_(0.0) +# +# # Unscale and set found inf/nan +# torch._amp_foreach_non_finite_check_and_unscale_( +# main_grads, self.found_inf, self.grad_scaler.inv_scale) +# +# # Update across all model parallel instances. +# torch.distributed.all_reduce(self.found_inf, +# op=torch.distributed.ReduceOp.MAX, +# group=self.get_model_parallel_group()) +# +# # add data_parallel synchronize +# torch.distributed.all_reduce(self.found_inf, +# op=torch.distributed.ReduceOp.MAX, +# group=self.get_data_parallel_group()) +# +# # Check for nan. 
+# found_inf_flag = (self.found_inf.item() > 0) +# +# return found_inf_flag +# +# +# Float16OptimizerWithFloat16Params._unscale_main_grads_and_check_for_nan = _unscale_main_grads_and_check_for_nan + +# INPLACE.9: FusedScaleMaskSoftmax +# from megatron.model.fused_softmax import FusedScaleMaskSoftmax +# from megatron.model.enums import AttnMaskType +# +# +# class FusedScaleMaskSoftmax(torch.nn.Module): +# def __init__( +# self, +# input_in_fp16, +# input_in_bf16, +# attn_mask_type, +# scaled_masked_softmax_fusion, +# mask_func, +# softmax_in_fp32, +# scale, +# ): +# super(FusedScaleMaskSoftmax, self).__init__() +# self.input_in_fp16 = input_in_fp16 +# self.input_in_bf16 = input_in_bf16 +# assert not ( +# self.input_in_fp16 and self.input_in_bf16 +# ), "both fp16 and bf16 flags cannot be active at the same time." +# self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16 +# self.attn_mask_type = attn_mask_type +# self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion +# self.mask_func = mask_func +# self.softmax_in_fp32 = softmax_in_fp32 +# self.scale = scale +# self.mask_tri = None +# p = torch.npu.get_device_properties(0) if torch.npu.is_available() else None +# self.fused = p.name in ['Ascend910A', 'Ascend910ProB'] if p is not None else False +# +# assert ( +# self.scale is None or softmax_in_fp32 +# ), "softmax should be in fp32 when scaled" +# +# def forward(self, input, mask): +# # [b, np, sq, sk] +# assert input.dim() == 4 +# +# if torch.npu.is_available() and self.fused: +# return self.forward_fused_softmax(input, mask) +# +# return self.forward_torch_softmax(input, mask) +# +# def forward_fused_softmax(self, input, mask): +# if self.softmax_in_fp32: +# input = input.float() +# +# if self.scale is None: +# self.scale = 1.0 +# +# if self.attn_mask_type == AttnMaskType.causal: +# if self.mask_tri is None: +# self.mask_tri = torch.triu(torch.ones(input.shape, device=input.device), diagonal=1).bool() +# probs = torch_npu.npu_scaled_masked_softmax(input, self.mask_tri, self.scale, False) +# else: +# probs = torch_npu.npu_scaled_masked_softmax(input, mask, self.scale, False) +# +# probs = probs.half() +# +# return probs +# +# def forward_torch_softmax(self, input, mask): +# if self.input_in_float16 and self.softmax_in_fp32: +# input = input.float() +# +# if self.scale is not None: +# input = input * self.scale +# +# if self.attn_mask_type == AttnMaskType.causal: +# mask_tri = torch.triu(torch.ones(input.shape, device=input.device), diagonal=1).bool() +# mask_output = self.mask_func(input, mask_tri) +# else: +# mask_output = self.mask_func(input, mask) if mask is not None else input +# probs = torch.nn.Softmax(dim=-1)(mask_output) +# +# if self.input_in_float16 and self.softmax_in_fp32: +# if self.input_in_fp16: +# probs = probs.half() +# else: +# probs = probs.bfloat16() +# +# return probs +# +# +# for k in sys.modules: +# if k.startswith('megatron.model'): +# for target in ['FusedScaleMaskSoftmax']: +# if getattr(sys.modules[k], target, None): +# setattr(sys.modules[k], target, FusedScaleMaskSoftmax) + +# INPLACE.10: clip_grad_norm_fp32 +from torch._six import inf +from megatron import mpu +from megatron.model.module import param_is_not_shared +from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate +from megatron.optimizer.clip_grads import clip_grad_norm_fp32 +from deepspeed.accelerator import get_accelerator + + +# def clip_grad_norm_fp32(parameters, max_norm, norm_type=2): +# """Clips gradient norm of an iterable of parameters whose gradients +# are 
in fp32. +# +# This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and +# added functionality to handle model parallel parameters. Note that +# the gradients are modified in place. +# +# Arguments: +# parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a +# single Tensor that will have gradients normalized +# grads_for_norm (Iterable[Tensor]): an iterable of Tensors or a single +# Tensor that will be used for calculating the grad norm. +# max_norm (float or int): max norm of the gradients +# norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for +# infinity norm. +# model_parallel_group (group): given the nature of the distributed +# optimizer, this is passed as an argument. +# +# Returns: +# Total norm of the parameters (viewed as a single vector). +# """ +# +# if isinstance(parameters, torch.Tensor): +# parameters = [parameters] +# # if isinstance(grads_for_norm, torch.Tensor): +# # grads_for_norm = [grads_for_norm] +# +# # Grads. +# grads = [] +# grads_for_norm = [] +# for param in parameters: +# grad_not_none = param.grad is not None +# is_not_shared = param_is_not_shared(param) +# is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param) +# grad = param.grad.detach() +# if grad_not_none: +# # Make sure the grads are in fp32 +# # assert param.grad.type() == 'torch.{}.FloatTensor'.format(get_accelerator().device_name()) +# grads.append(grad) +# if grad_not_none and is_not_shared and is_not_tp_duplicate: +# grads_for_norm.append(grad) +# +# +# # Norm parameters. +# max_norm = float(max_norm) +# norm_type = float(norm_type) +# total_norm = 0.0 +# +# # Calculate norm. +# if norm_type == inf: +# total_norm = max(grad.abs().max() for grad in grads_for_norm) +# total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) +# # Take max across all model-parallel GPUs. +# torch.distributed.all_reduce(total_norm_cuda, +# op=torch.distributed.ReduceOp.MAX, +# group=mpu.get_model_parallel_group()) +# total_norm = total_norm_cuda[0].item() +# else: +# for grad in grads_for_norm: +# grad_norm = torch.norm(grad, norm_type) +# total_norm += grad_norm ** norm_type +# +# # Sum across all model-parallel GPUs. +# torch.distributed.all_reduce(total_norm, +# op=torch.distributed.ReduceOp.SUM, +# group=mpu.get_model_parallel_group()) +# total_norm = total_norm.item() ** (1.0 / norm_type) +# +# # Scale. +# clip_coeff = max_norm / (total_norm + 1.0e-6) +# if clip_coeff < 1.0: +# for p in parameters: +# p.grad.detach().mul_(clip_coeff) +# +# return total_norm +# +# +# megatron.optimizer.clip_grads.clip_grad_norm_fp32 = clip_grad_norm_fp32 +# megatron.optimizer.optimizer.clip_grad_norm_fp32 = clip_grad_norm_fp32 + +# INPLACE.11: _CUDA_RNG_STATE_TRACKER +import contextlib +from megatron.mpu.random import CudaRNGStatesTracker, _CUDA_RNG_STATE_TRACKER, _MODEL_PARALLEL_RNG_TRACKER_NAME + + +class CudaRNGStatesTracker: + """Tracker for the cuda RNG states. + + Using the `add` method, a cuda rng state is initialized based on + the input `seed` and is assigned to `name`. Later, by forking the + rng state, we can perform operations and return to our starting + cuda state. + """ + + def __init__(self): + # Map from a string name to the cuda rng state. + self.states_ = {} + # Seeds are just for book keeping and ensure no seed is set twice. + self.seeds_ = set() + + def reset(self): + """Set to the initial state (no tracker).""" + self.states_ = {} + self.seeds_ = set() + + def get_states(self): + """Get rng states. 
Copy the dictionary so we have direct + pointers to the states, not just a pointer to the dictionary.""" + states = {} + for name in self.states_: + states[name] = self.states_[name] + return states + + def set_states(self, states): + """Set the rng states. For efficiency purposes, we do not check + the size of seed for compatibility.""" + self.states_ = states + + def add(self, name, seed): + """Track the rng state.""" + # Check seed is not already used. + if seed in self.seeds_: + raise Exception('seed {} already exists'.format(seed)) + self.seeds_.add(seed) + # Check that state is not already defined. + if name in self.states_: + raise Exception('cuda rng state {} already exists'.format(name)) + # Get the current rng state. + # orig_rng_state = torch.cuda.get_rng_state() + # Set the new state and store it. + torch.cuda.manual_seed(seed) + # self.states_[name] = torch.cuda.get_rng_state() + # Reset rng state to what it was. + # _set_cuda_rng_state(orig_rng_state) + + @contextlib.contextmanager + def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME): + yield + # """Fork the cuda rng state, perform operations, and exit with + # the original state.""" + # # Check if we have added the state + # if name not in self.states_: + # raise Exception('cuda rng state {} is not added'.format(name)) + # Store current rng state. + # orig_cuda_rng_state = torch.cuda.get_rng_state() + # Set rng state to the desired one + # _set_cuda_rng_state(self.states_[name]) + # Do the stuff we wanted to do. + # try: + # yield + # finally: + # # Update the current rng state for later use. + # self.states_[name] = torch.cuda.get_rng_state() + # # And set the state to the original state we started with. + # _set_cuda_rng_state(orig_cuda_rng_state) + + +megatron.mpu.random.CudaRNGStatesTracker = CudaRNGStatesTracker +megatron.mpu.random._CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() + +# INPLACE.12: _unscale_main_grads_and_check_for_nan +from megatron.optimizer.optimizer import Float16OptimizerWithFloat16Params + +# +# def _unscale_main_grads_and_check_for_nan(self): +# main_grads = [] +# # fp32 params fromm float16 ones. +# for main_group in self.fp32_from_float16_groups: +# for main_param in main_group: +# if main_param.grad is not None: +# main_grads.append(main_param.grad.data) +# # Append fp32 parameters. +# for main_group in self.fp32_from_fp32_groups: +# for main_param in main_group: +# if main_param.grad is not None: +# main_grads.append(main_param.grad.data) +# # Reset found inf. +# self.found_inf.fill_(0.0) +# # Unscale and set found inf/nan +# torch._amp_foreach_non_finite_check_and_unscale_( +# main_grads, self.found_inf, self.grad_scaler.inv_scale) +# # Update across all model parallel instances. +# torch.distributed.all_reduce(self.found_inf, +# op=torch.distributed.ReduceOp.MAX, +# group=mpu.get_model_parallel_group()) +# torch.distributed.all_reduce(self.found_inf, +# op=torch.distributed.ReduceOp.MAX, +# group=mpu.get_data_parallel_group()) +# +# # Check for nan. 
+# found_inf_flag = (self.found_inf.item() > 0) +# return found_inf_flag +# +# +# Float16OptimizerWithFloat16Params._unscale_main_grads_and_check_for_nan = _unscale_main_grads_and_check_for_nan + +# INPLACE.13: refine overflow flag +from megatron import schedules, get_num_microbatches, get_args, get_timers +from megatron.schedules import dummy_handler, forward_step +from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP + + +def clear_npu_overflow_flag(): + float_status = torch.zeros(8).npu() + result = torch_npu.npu_clear_float_status(float_status) + + +def get_npu_overflow_flag(): + float_status = torch.zeros(8).npu() + result = torch_npu.npu_get_float_status(float_status) + if float_status.cpu()[0] != 0: + return True + else: + return False + + +def set_npu_overflow_flag(): + torch.tensor([65504]).half().npu() + 100 + + +def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad, model=None): + """Backward step through passed-in output tensor. + + If last stage, output_tensor_grad is None, otherwise gradient of loss + with respect to stage's output tensor. + + Returns gradient of loss with respect to input tensor (None if first + stage).""" + args = get_args() + + if args.deepspeed: + assert model is not None + + timers = get_timers() + timers('backward-compute').start() + + # Retain the grad on the input_tensor. + if input_tensor is not None: + input_tensor.retain_grad() + + clear_npu_overflow_flag() + if args.deepspeed: + model.backward(output_tensor) + else: + # Backward pass. + if output_tensor_grad is None: + output_tensor = optimizer.scale_loss(output_tensor) + torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad) + + # Collect the grad of the input_tensor. + input_tensor_grad = None + if input_tensor is not None: + input_tensor_grad = input_tensor.grad + + timers('backward-compute').stop() + + return input_tensor_grad + +def forward_backward_no_pipelining(forward_step_func, data_iterator, model, + optimizer, timers, forward_only): + """Run forward and backward passes with no pipeline parallelism + (no inter-stage communication). 
+ + Returns dictionary with losses.""" + assert len(model) == 1 + model = model[0] + + context_handler = dummy_handler + if isinstance(model, torchDDP): + context_handler = model.no_sync + + losses_reduced = [] + input_tensor, output_tensor_grad = None, None + overflow_flag_all = False + with context_handler(): + for i in range(get_num_microbatches() - 1): + output_tensor = forward_step(forward_step_func, data_iterator, model, + input_tensor, losses_reduced) + if not forward_only: + backward_step(optimizer, input_tensor, output_tensor, + output_tensor_grad) + + overflow_flag = get_npu_overflow_flag() + overflow_flag_all = overflow_flag or overflow_flag_all + output_tensor = forward_step(forward_step_func, data_iterator, model, + input_tensor, losses_reduced) + if not forward_only: + backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad) + + overflow_flag = get_npu_overflow_flag() + overflow_flag_all = overflow_flag or overflow_flag_all + + if overflow_flag_all: + set_npu_overflow_flag() + return losses_reduced + + +schedules.forward_backward_no_pipelining = forward_backward_no_pipelining +# +# # INPLACE.14: remove dropout in ParallelTransformerLayer +# from megatron.model.transformer import ParallelTransformerLayer, bias_dropout_add_fused_train, \ +# bias_dropout_add_fused_inference, get_bias_dropout_add +# from megatron.model.enums import AttnMaskType, LayerType, AttnType +# +# +# def forward(self, hidden_states, attention_mask=None, +# encoder_output=None, enc_dec_attn_mask=None, +# layer_past=None, get_key_value=False): +# # hidden_states: [b, s, h] +# +# # Layer norm at the beginning of the transformer layer. +# layernorm_output = self.input_layernorm(hidden_states) +# # Self attention. +# attention_output, attention_bias = \ +# self.attention(layernorm_output, +# attention_mask, +# layer_past=layer_past, +# get_key_value=get_key_value) +# +# if get_key_value: +# attention_output, presents = attention_output +# +# # Residual connection. +# if self.apply_residual_connection_post_layernorm: +# residual = layernorm_output +# else: +# residual = hidden_states +# +# # jit scripting for a nn.module (with dropout) is not +# # trigerring the fusion kernel. For now, we use two +# # different nn.functional routines to account for varying +# # dropout semantics during training and inference phases. +# if self.bias_dropout_fusion: +# if self.training: +# bias_dropout_add_func = bias_dropout_add_fused_train +# else: +# bias_dropout_add_func = bias_dropout_add_fused_inference +# else: +# bias_dropout_add_func = get_bias_dropout_add(self.training) +# +# # re-enable torch grad to enable fused optimization. +# with torch.enable_grad(): +# layernorm_input = bias_dropout_add_func( +# attention_output, +# attention_bias.expand_as(residual), +# residual, +# 0.) +# +# # Layer norm post the self attention. +# layernorm_output = self.post_attention_layernorm(layernorm_input) +# +# if self.layer_type == LayerType.decoder: +# attention_output, attention_bias = \ +# self.inter_attention(layernorm_output, +# enc_dec_attn_mask, +# encoder_output=encoder_output) +# # residual connection +# if self.apply_residual_connection_post_layernorm: +# residual = layernorm_output +# else: +# residual = layernorm_input +# +# # re-enable torch grad to enable fused optimization. +# with torch.enable_grad(): +# layernorm_input = bias_dropout_add_func( +# attention_output, +# attention_bias.expand_as(residual), +# residual, +# 0.) 
+# +# # Layer norm post the decoder attention +# layernorm_output = self.post_inter_attention_layernorm(layernorm_input) +# +# # MLP. +# moe_loss = torch.tensor(0.0, device=layernorm_output.device, dtype=layernorm_output.dtype) +# mlp_bias = torch.tensor(0.0, device=layernorm_output.device, dtype=layernorm_output.dtype) +# +# if self.num_experts == 1: +# mlp_output, mlp_bias = self.mlp(layernorm_output) +# else: +# mlp_output, moe_loss, _ = self.mlp(layernorm_output) +# +# # Second residual connection. +# if self.apply_residual_connection_post_layernorm: +# residual = layernorm_output +# else: +# residual = layernorm_input +# +# # re-enable torch grad to enable fused optimization. +# with torch.enable_grad(): +# # if self.num_experts <= 1: +# output = bias_dropout_add_func( +# mlp_output, +# mlp_bias.expand_as(residual), +# residual, +# 0.) +# # else: +# # output = mlp_output + residual +# +# if get_key_value: +# output = [output, presents] +# +# return output, moe_loss +# +# +# ParallelTransformerLayer.forward = forward + +# +# from megatron import initialize +# from megatron.initialize import _warmup_jit_function +# +# def set_jit_fusion_options(): +# """Set PyTorch JIT layer fusion options.""" +# +# # legacy pytorch fuser +# torch._C._jit_set_profiling_mode(False) +# torch._C._jit_set_profiling_executor(False) +# torch._C._jit_override_can_fuse_on_cpu(True) +# torch._C._jit_override_can_fuse_on_gpu(True) +# +# _warmup_jit_function() +# +# initialize.set_jit_fusion_options = set_jit_fusion_options diff --git a/examples/gpt_task/pretrain_gpt_1p.sh b/examples/gpt_task/pretrain_gpt_1p.sh new file mode 100644 index 000000000..c131b30a0 --- /dev/null +++ b/examples/gpt_task/pretrain_gpt_1p.sh @@ -0,0 +1,43 @@ +export LD_LIBRARY_PATH=/root/miniconda3/lib:$LD_LIBRARY_PATH + +source /usr/local/Ascend/ascend-toolkit/set_env.sh +RANK=0 +WORLD_SIZE=1 + +DATA_PATH=./dataset/enwiki-gpt/gpt_text_sentence +CHECKPOINT_PATH=./ckpt + +export LOCAL_RANK=0 + +python pretrain_gpt.py \ + --DDP-impl local \ + --use-contiguous-buffers-in-ddp \ + --num-layers 1 \ + --hidden-size 4096 \ + --num-attention-heads 32 \ + --micro-batch-size 4 \ + --global-batch-size 8 \ + --seq-length 2048 \ + --max-position-embeddings 2048 \ + --train-iters 500000 \ + --lr-decay-iters 320000 \ + --save $CHECKPOINT_PATH \ + --load $CHECKPOINT_PATH \ + --data-path $DATA_PATH \ + --vocab-file ./dataset/gpt2-vocab.json \ + --merge-file ./dataset/gpt2-merges.txt \ + --data-impl mmap \ + --split 949,50,1 \ + --distributed-backend nccl \ + --lr 0.00015 \ + --min-lr 1.0e-5 \ + --lr-decay-style cosine \ + --weight-decay 1e-2 \ + --clip-grad 1.0 \ + --lr-warmup-fraction .01 \ + --checkpoint-activations \ + --log-interval 10 \ + --save-interval 10000 \ + --eval-interval 1000 \ + --eval-iters 10 \ + --fp16 | tee logs/train.log diff --git a/examples/gpt_task/pretrain_gpt_ptd_8p.sh b/examples/gpt_task/pretrain_gpt_ptd_8p.sh new file mode 100644 index 000000000..a4e427449 --- /dev/null +++ b/examples/gpt_task/pretrain_gpt_ptd_8p.sh @@ -0,0 +1,50 @@ +export LD_LIBRARY_PATH=/root/miniconda3/lib:$LD_LIBRARY_PATH +source /usr/local/Ascend/ascend-toolkit/set_env.sh +GPUS_PER_NODE=8 +# Change for multinode config +MASTER_ADDR=localhost +MASTER_PORT=6000 +NNODES=1 +NODE_RANK=0 +WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) + +DATA_PATH=./dataset/enwiki-gpt/gpt_text_sentence +CHECKPOINT_PATH=./ckpt + +DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT" + +python -m 
torch.distributed.launch $DISTRIBUTED_ARGS \ + pretrain_gpt.py \ + --DDP-impl local \ + --use-contiguous-buffers-in-ddp \ + --tensor-model-parallel-size 2 \ + --pipeline-model-parallel-size 2 \ + --num-layers 8 \ + --hidden-size 4096 \ + --num-attention-heads 32 \ + --micro-batch-size 2 \ + --global-batch-size 16 \ + --seq-length 2048 \ + --max-position-embeddings 2048 \ + --train-iters 500000 \ + --lr-decay-iters 320000 \ + --save $CHECKPOINT_PATH \ + --load $CHECKPOINT_PATH \ + --data-path $DATA_PATH \ + --vocab-file ./dataset/gpt2-vocab.json \ + --merge-file ./dataset/gpt2-merges.txt \ + --data-impl mmap \ + --split 949,50,1 \ + --distributed-backend nccl \ + --lr 0.00015 \ + --lr-decay-style cosine \ + --min-lr 1.0e-5 \ + --weight-decay 1e-2 \ + --clip-grad 1.0 \ + --lr-warmup-fraction .01 \ + --checkpoint-activations \ + --log-interval 10 \ + --save-interval 10000 \ + --eval-interval 1000 \ + --eval-iters 10 \ + --fp16 | tee logs/train.log diff --git a/examples/gpt_task/pretrain_gpt_td_8p.sh b/examples/gpt_task/pretrain_gpt_td_8p.sh new file mode 100644 index 000000000..240425bc4 --- /dev/null +++ b/examples/gpt_task/pretrain_gpt_td_8p.sh @@ -0,0 +1,49 @@ +export LD_LIBRARY_PATH=/root/miniconda3/lib:$LD_LIBRARY_PATH +source /usr/local/Ascend/ascend-toolkit/set_env.sh +GPUS_PER_NODE=8 +# Change for multinode config +MASTER_ADDR=localhost +MASTER_PORT=6000 +NNODES=1 +NODE_RANK=0 +WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) + +DATA_PATH=./dataset/enwiki-gpt/gpt_text_sentence +CHECKPOINT_PATH=./ckpt + +DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT" + +python -m torch.distributed.launch $DISTRIBUTED_ARGS \ + pretrain_gpt.py \ + --DDP-impl local \ + --use-contiguous-buffers-in-ddp \ + --tensor-model-parallel-size 4 \ + --num-layers 8 \ + --hidden-size 4096 \ + --num-attention-heads 32 \ + --micro-batch-size 2 \ + --global-batch-size 16 \ + --seq-length 2048 \ + --max-position-embeddings 2048 \ + --train-iters 500000 \ + --lr-decay-iters 320000 \ + --save $CHECKPOINT_PATH \ + --load $CHECKPOINT_PATH \ + --data-path $DATA_PATH \ + --vocab-file ./dataset/gpt2-vocab.json \ + --merge-file ./dataset/gpt2-merges.txt \ + --data-impl mmap \ + --split 949,50,1 \ + --distributed-backend nccl \ + --lr 0.00015 \ + --lr-decay-style cosine \ + --min-lr 1.0e-5 \ + --weight-decay 1e-2 \ + --clip-grad 1.0 \ + --lr-warmup-fraction .01 \ + --checkpoint-activations \ + --log-interval 10 \ + --save-interval 10000 \ + --eval-interval 1000 \ + --eval-iters 10 \ + --fp16 | tee logs/train.log diff --git a/examples/llama_task/pretrain_llama_1p.sh b/examples/llama_task/pretrain_llama_1p.sh new file mode 100644 index 000000000..a78f4edb7 --- /dev/null +++ b/examples/llama_task/pretrain_llama_1p.sh @@ -0,0 +1,43 @@ +export LD_LIBRARY_PATH=/root/miniconda3/lib:$LD_LIBRARY_PATH + +source /usr/local/Ascend/ascend-toolkit/set_env.sh +RANK=0 +WORLD_SIZE=1 + +DATA_PATH=./dataset/enwiki-gpt/gpt_text_sentence +CHECKPOINT_PATH=./ckpt + +export LOCAL_RANK=0 + +python pretrain_llama.py \ + --DDP-impl local \ + --use-contiguous-buffers-in-ddp \ + --num-layers 1 \ + --hidden-size 4096 \ + --num-attention-heads 32 \ + --micro-batch-size 4 \ + --global-batch-size 8 \ + --seq-length 2048 \ + --max-position-embeddings 2048 \ + --train-iters 500000 \ + --lr-decay-iters 320000 \ + --save $CHECKPOINT_PATH \ + --load $CHECKPOINT_PATH \ + --data-path $DATA_PATH \ + --vocab-file ./dataset/gpt2-vocab.json \ + --merge-file ./dataset/gpt2-merges.txt \ + 
--data-impl mmap \ + --split 949,50,1 \ + --distributed-backend nccl \ + --lr 0.00015 \ + --min-lr 1.0e-5 \ + --lr-decay-style cosine \ + --weight-decay 1e-2 \ + --clip-grad 1.0 \ + --lr-warmup-fraction .01 \ + --checkpoint-activations \ + --log-interval 10 \ + --save-interval 10000 \ + --eval-interval 1000 \ + --eval-iters 10 \ + --fp16 | tee logs/train.log diff --git a/examples/llama_task/pretrain_llama_8p.sh b/examples/llama_task/pretrain_llama_8p.sh new file mode 100644 index 000000000..a51ea9e3b --- /dev/null +++ b/examples/llama_task/pretrain_llama_8p.sh @@ -0,0 +1,49 @@ +export LD_LIBRARY_PATH=/root/miniconda3/lib:$LD_LIBRARY_PATH +source /usr/local/Ascend/ascend-toolkit/set_env.sh +GPUS_PER_NODE=8 +# Change for multinode config +MASTER_ADDR=localhost +MASTER_PORT=6000 +NNODES=1 +NODE_RANK=0 +WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) + +DATA_PATH=/home/flj/AscendSpeed/dataset/enwiki-gpt/gpt_text_sentence +CHECKPOINT_PATH=./ckpt + +DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT" + +python -m torch.distributed.launch $DISTRIBUTED_ARGS \ + pretrain_llama.py \ + --DDP-impl local \ + --use-contiguous-buffers-in-ddp \ + --tensor-model-parallel-size 4 \ + --num-layers 8 \ + --hidden-size 4096 \ + --num-attention-heads 32 \ + --micro-batch-size 2 \ + --global-batch-size 16 \ + --seq-length 2048 \ + --max-position-embeddings 2048 \ + --train-iters 500000 \ + --lr-decay-iters 320000 \ + --save $CHECKPOINT_PATH \ + --load $CHECKPOINT_PATH \ + --data-path $DATA_PATH \ + --vocab-file /home/flj/AscendSpeed/dataset/gpt2-vocab.json \ + --merge-file /home/flj/AscendSpeed/dataset/gpt2-merges.txt \ + --data-impl mmap \ + --split 949,50,1 \ + --distributed-backend nccl \ + --lr 0.00015 \ + --lr-decay-style cosine \ + --min-lr 1.0e-5 \ + --weight-decay 1e-2 \ + --clip-grad 1.0 \ + --lr-warmup-fraction .01 \ + --checkpoint-activations \ + --log-interval 10 \ + --save-interval 10000 \ + --eval-interval 1000 \ + --eval-iters 10 \ + --fp16 | tee logs/train.log diff --git a/examples/llama_task/pretrain_llama_ptd_8p.sh b/examples/llama_task/pretrain_llama_ptd_8p.sh new file mode 100644 index 000000000..81eb7c69f --- /dev/null +++ b/examples/llama_task/pretrain_llama_ptd_8p.sh @@ -0,0 +1,50 @@ +export LD_LIBRARY_PATH=/root/miniconda3/lib:$LD_LIBRARY_PATH +source /usr/local/Ascend/ascend-toolkit/set_env.sh +GPUS_PER_NODE=8 +# Change for multinode config +MASTER_ADDR=localhost +MASTER_PORT=6000 +NNODES=1 +NODE_RANK=0 +WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) + +DATA_PATH=./dataset/enwiki-gpt/gpt_text_sentence +CHECKPOINT_PATH=./ckpt + +DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT" + +python -m torch.distributed.launch $DISTRIBUTED_ARGS \ + pretrain_llama.py \ + --DDP-impl local \ + --use-contiguous-buffers-in-ddp \ + --tensor-model-parallel-size 2 \ + --pipeline-model-parallel-size 2 \ + --num-layers 8 \ + --hidden-size 4096 \ + --num-attention-heads 32 \ + --micro-batch-size 2 \ + --global-batch-size 16 \ + --seq-length 2048 \ + --max-position-embeddings 2048 \ + --train-iters 500000 \ + --lr-decay-iters 320000 \ + --save $CHECKPOINT_PATH \ + --load $CHECKPOINT_PATH \ + --data-path $DATA_PATH \ + --vocab-file ./dataset/gpt2-vocab.json \ + --merge-file ./dataset/gpt2-merges.txt \ + --data-impl mmap \ + --split 949,50,1 \ + --distributed-backend nccl \ + --lr 0.00015 \ + --lr-decay-style cosine \ + --min-lr 1.0e-5 \ + 
--weight-decay 1e-2 \ + --clip-grad 1.0 \ + --lr-warmup-fraction .01 \ + --checkpoint-activations \ + --log-interval 10 \ + --save-interval 10000 \ + --eval-interval 1000 \ + --eval-iters 10 \ + --fp16 | tee logs/train.log diff --git a/examples/llama_task/pretrain_llama_td_8p.sh b/examples/llama_task/pretrain_llama_td_8p.sh new file mode 100644 index 000000000..a2dd944a5 --- /dev/null +++ b/examples/llama_task/pretrain_llama_td_8p.sh @@ -0,0 +1,49 @@ +export LD_LIBRARY_PATH=/root/miniconda3/lib:$LD_LIBRARY_PATH +source /usr/local/Ascend/ascend-toolkit/set_env.sh +GPUS_PER_NODE=8 +# Change for multinode config +MASTER_ADDR=localhost +MASTER_PORT=6000 +NNODES=1 +NODE_RANK=0 +WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) + +DATA_PATH=./dataset/enwiki-gpt/gpt_text_sentence +CHECKPOINT_PATH=./ckpt + +DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT" + +python -m torch.distributed.launch $DISTRIBUTED_ARGS \ + pretrain_llama.py \ + --DDP-impl local \ + --use-contiguous-buffers-in-ddp \ + --tensor-model-parallel-size 4 \ + --num-layers 8 \ + --hidden-size 4096 \ + --num-attention-heads 32 \ + --micro-batch-size 2 \ + --global-batch-size 16 \ + --seq-length 2048 \ + --max-position-embeddings 2048 \ + --train-iters 500000 \ + --lr-decay-iters 320000 \ + --save $CHECKPOINT_PATH \ + --load $CHECKPOINT_PATH \ + --data-path $DATA_PATH \ + --vocab-file ./dataset/gpt2-vocab.json \ + --merge-file ./dataset/gpt2-merges.txt \ + --data-impl mmap \ + --split 949,50,1 \ + --distributed-backend nccl \ + --lr 0.00015 \ + --lr-decay-style cosine \ + --min-lr 1.0e-5 \ + --weight-decay 1e-2 \ + --clip-grad 1.0 \ + --lr-warmup-fraction .01 \ + --checkpoint-activations \ + --log-interval 10 \ + --save-interval 10000 \ + --eval-interval 1000 \ + --eval-iters 10 \ + --fp16 | tee logs/train.log diff --git a/examples/pretrain_llama_1p.sh b/examples/pretrain_llama_1p.sh new file mode 100644 index 000000000..e1973b18d --- /dev/null +++ b/examples/pretrain_llama_1p.sh @@ -0,0 +1,44 @@ +#! /bin/bash + +# Runs the "345M" parameter model + +RANK=0 +WORLD_SIZE=1 + +DATA_PATH=/home/flj/Megatron_Deepspeed_llama/dataset/enwiki-gpt/gpt_text_sentence +CHECKPOINT_PATH=./ckpt + +export LOCAL_RANK=7 + +python pretrain_llama.py \ + --DDP-impl local \ + --use-contiguous-buffers-in-ddp \ + --num-layers 2 \ + --hidden-size 4096 \ + --num-attention-heads 32 \ + --micro-batch-size 4 \ + --global-batch-size 8 \ + --seq-length 2048 \ + --max-position-embeddings 2048 \ + --train-iters 500000 \ + --lr-decay-iters 320000 \ + --save $CHECKPOINT_PATH \ + --load $CHECKPOINT_PATH \ + --data-path $DATA_PATH \ + --vocab-file /home/flj/Megatron_Deepspeed_llama/dataset/gpt2-vocab.json \ + --merge-file /home/flj/Megatron_Deepspeed_llama/dataset/gpt2-merges.txt \ + --data-impl mmap \ + --split 949,50,1 \ + --distributed-backend nccl \ + --lr 0.00015 \ + --min-lr 1.0e-5 \ + --lr-decay-style cosine \ + --weight-decay 1e-2 \ + --clip-grad 1.0 \ + --lr-warmup-fraction .01 \ + --checkpoint-activations \ + --log-interval 100 \ + --save-interval 10000 \ + --eval-interval 1000 \ + --eval-iters 10 \ + --fp16 | tee logs/train.log diff --git a/examples/pretrain_llama_8p b/examples/pretrain_llama_8p new file mode 100644 index 000000000..e9648284d --- /dev/null +++ b/examples/pretrain_llama_8p @@ -0,0 +1,49 @@ +#! 
/bin/bash + +# Runs the "345M" parameter model + +GPUS_PER_NODE=8 +# Change for multinode config +MASTER_ADDR=localhost +MASTER_PORT=6000 +NNODES=1 +NODE_RANK=0 +WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) + +DATA_PATH=/home/flj/Megatron_Deepspeed_llama/dataset/enwiki-gpt/gpt_text_sentence +CHECKPOINT_PATH=./ckpt + +DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT" + +python -m torch.distributed.launch $DISTRIBUTED_ARGS \ + pretrain_gpt.py \ + --tensor-model-parallel-size 8 \ + --num-layers 24 \ + --hidden-size 4096 \ + --num-attention-heads 32 \ + --micro-batch-size 4 \ + --global-batch-size 8 \ + --seq-length 2048 \ + --max-position-embeddings 2048 \ + --train-iters 500000 \ + --lr-decay-iters 320000 \ + --save $CHECKPOINT_PATH \ + --load $CHECKPOINT_PATH \ + --data-path $DATA_PATH \ + --vocab-file /home/flj/Megatron_Deepspeed_llama/dataset/gpt2-vocab.json \ + --merge-file /home/flj/Megatron_Deepspeed_llama/dataset/gpt2-merges.txt \ + --data-impl mmap \ + --split 949,50,1 \ + --distributed-backend nccl \ + --lr 0.00015 \ + --lr-decay-style cosine \ + --min-lr 1.0e-5 \ + --weight-decay 1e-2 \ + --clip-grad 1.0 \ + --lr-warmup-fraction .01 \ + --checkpoint-activations \ + --log-interval 100 \ + --save-interval 10000 \ + --eval-interval 1000 \ + --eval-iters 10 \ + --fp16 | tee logs/train.log diff --git a/examples/pretrain_llama_8p.sh b/examples/pretrain_llama_8p.sh new file mode 100644 index 000000000..e9648284d --- /dev/null +++ b/examples/pretrain_llama_8p.sh @@ -0,0 +1,49 @@ +#! /bin/bash + +# Runs the "345M" parameter model + +GPUS_PER_NODE=8 +# Change for multinode config +MASTER_ADDR=localhost +MASTER_PORT=6000 +NNODES=1 +NODE_RANK=0 +WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) + +DATA_PATH=/home/flj/Megatron_Deepspeed_llama/dataset/enwiki-gpt/gpt_text_sentence +CHECKPOINT_PATH=./ckpt + +DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT" + +python -m torch.distributed.launch $DISTRIBUTED_ARGS \ + pretrain_gpt.py \ + --tensor-model-parallel-size 8 \ + --num-layers 24 \ + --hidden-size 4096 \ + --num-attention-heads 32 \ + --micro-batch-size 4 \ + --global-batch-size 8 \ + --seq-length 2048 \ + --max-position-embeddings 2048 \ + --train-iters 500000 \ + --lr-decay-iters 320000 \ + --save $CHECKPOINT_PATH \ + --load $CHECKPOINT_PATH \ + --data-path $DATA_PATH \ + --vocab-file /home/flj/Megatron_Deepspeed_llama/dataset/gpt2-vocab.json \ + --merge-file /home/flj/Megatron_Deepspeed_llama/dataset/gpt2-merges.txt \ + --data-impl mmap \ + --split 949,50,1 \ + --distributed-backend nccl \ + --lr 0.00015 \ + --lr-decay-style cosine \ + --min-lr 1.0e-5 \ + --weight-decay 1e-2 \ + --clip-grad 1.0 \ + --lr-warmup-fraction .01 \ + --checkpoint-activations \ + --log-interval 100 \ + --save-interval 10000 \ + --eval-interval 1000 \ + --eval-iters 10 \ + --fp16 | tee logs/train.log diff --git a/main.py b/main.py new file mode 100644 index 000000000..5596b4478 --- /dev/null +++ b/main.py @@ -0,0 +1,16 @@ +# This is a sample Python script. + +# Press Shift+F10 to execute it or replace it with your code. +# Press Double Shift to search everywhere for classes, files, tool windows, actions, and settings. + + +def print_hi(name): + # Use a breakpoint in the code line below to debug your script. + print(f'Hi, {name}') # Press Ctrl+F8 to toggle the breakpoint. + + +# Press the green button in the gutter to run the script. 
+if __name__ == '__main__': + print_hi('PyCharm') + +# See PyCharm help at https://www.jetbrains.com/help/pycharm/ diff --git a/megatron/__init__.py b/megatron/__init__.py index 93894cd3f..a71604131 100644 --- a/megatron/__init__.py +++ b/megatron/__init__.py @@ -13,7 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +import math +import copy +import apex +from functools import wraps import torch +from torch import distributed as dist from deepspeed.accelerator import get_accelerator from .package_info import ( __description__, @@ -36,6 +41,154 @@ from .global_vars import get_adlr_autoresume from .global_vars import get_timers from .initialize import initialize_megatron +def wrapper_type(fn): + @wraps(fn) + def decorated(*args, **kwargs): + output = fn(*args, **kwargs) + if isinstance(output, str): + if output == 'torch.npu.FloatTensor': + output = 'torch.cuda.FloatTensor' + elif output == 'torch.npu.HalfTensor': + output = 'torch.cuda.HalfTensor' + return output + + return decorated + +torch.Tensor.type = wrapper_type(torch.Tensor.type) # fix for torch.Tensor.type() + +def wrapper_dist_long2int(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + if args[0].dtype == torch.long and not kwargs.get('async_op', False): + new_args = list(copy.deepcopy(args)) + new_args[0] = new_args[0].int() + fn(*new_args, **kwargs) + args[0].copy_(new_args[0].long()) + return + return fn(*args, **kwargs) + + return wrapper + +dist.all_reduce = wrapper_dist_long2int(dist.all_reduce) +dist.broadcast = wrapper_dist_long2int(dist.broadcast) +dist.send = wrapper_dist_long2int(dist.send) +dist.recv = wrapper_dist_long2int(dist.recv) # fix for torch.distributed + +class AdamW(torch.optim.Optimizer): + r"""Implements AdamW algorithm. + + The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. + The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + (default: False) + + .. _Adam\: A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. 
_On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, + weight_decay=1e-2, amsgrad=False): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + defaults = dict(lr=lr, betas=betas, eps=eps, + weight_decay=weight_decay, amsgrad=amsgrad) + super(AdamW, self).__init__(params, defaults) + + def __setstate__(self, state): + super(AdamW, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('amsgrad', False) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + + # Perform stepweight decay + p.data.mul_(1 - group['lr'] * group['weight_decay']) + + # Perform optimization step + grad = p.grad + if grad.is_sparse: + raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') + amsgrad = group['amsgrad'] + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state['max_exp_avg_sq'] = torch.zeros_like(p) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + if amsgrad: + max_exp_avg_sq = state['max_exp_avg_sq'] + beta1, beta2 = group['betas'] + + state['step'] += 1 + bias_correction1 = 1 - beta1 ** state['step'] + bias_correction2 = 1 - beta2 ** state['step'] + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + if amsgrad: + # Maintains the maximum of all 2nd moment running avg. till now + torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) + # Use the max. for normalizing running avg. of gradient + denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) + else: + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) + + step_size = group['lr'] / bias_correction1 + + p.addcdiv_(exp_avg, denom, value=-step_size) + + return loss + +apex.optimizers.FusedAdam = AdamW +apex.optimizers.FusedSGD = torch.optim.SGD + def print_rank_0(message): """If distributed is initialized, print only on rank 0.""" if torch.distributed.is_initialized(): diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index 28ec14762..20344f889 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -11,7 +11,7 @@ # An empty sentence no longer separates documents. 
# Some of the fixes/improvements are adopted from -# https://github.com/bigscience-workshop/AscendSpeed/blob/main/megatron/data/indexed_dataset.py +# https://github.com/bigscience-workshop/Megatron-DeepSpeed/blob/main/megatron/data/indexed_dataset.py from functools import lru_cache import os diff --git a/megatron/fused_kernels/__init__.py b/megatron/fused_kernels/__init__.py new file mode 100644 index 000000000..bdc654c39 --- /dev/null +++ b/megatron/fused_kernels/__init__.py @@ -0,0 +1,130 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import pathlib +import subprocess + +import torch +from torch.utils import cpp_extension + +# Setting this param to a list has a problem of generating different +# compilation commands (with diferent order of architectures) and +# leading to recompilation of fused kernels. Set it to empty string +# to avoid recompilation and assign arch flags explicity in +# extra_cuda_cflags below +os.environ["TORCH_CUDA_ARCH_LIST"] = "" + + +def load(args): + + # Check if cuda 11 is installed for compute capability 8.0 + cc_flag = [] + if torch.version.hip is None: + _, bare_metal_major, _ = _get_cuda_bare_metal_version( + cpp_extension.CUDA_HOME) + if int(bare_metal_major) >= 11: + cc_flag.append('-gencode') + cc_flag.append('arch=compute_80,code=sm_80') + + # Build path + srcpath = pathlib.Path(__file__).parent.absolute() + buildpath = srcpath / 'build' + _create_build_dir(buildpath) + + # Helper function to build the kernels. + def _cpp_extention_load_helper(name, sources, extra_cuda_flags, extra_include_paths): + if torch.version.hip is not None: + extra_cuda_cflags=['-O3'] + extra_cuda_flags + cc_flag + else: + extra_cuda_cflags=['-O3', + '-gencode', 'arch=compute_70,code=sm_70', + '--use_fast_math'] + extra_cuda_flags + cc_flag + + return cpp_extension.load( + name=name, + sources=sources, + build_directory=buildpath, + extra_cflags=['-O3',], + extra_cuda_cflags=extra_cuda_cflags, + extra_include_paths=extra_include_paths, + verbose=(args.rank == 0) + ) + + # ============== + # Fused softmax. + # ============== + + if torch.version.hip is not None: + extra_include_paths=[os.path.abspath(srcpath)] + else: + extra_include_paths=[] + + if args.masked_softmax_fusion: + if torch.version.hip is not None: + extra_cuda_flags = ['-D__HIP_NO_HALF_OPERATORS__=1', + '-D__HIP_NO_HALF_CONVERSIONS__=1'] + else: + extra_cuda_flags = ['-U__CUDA_NO_HALF_OPERATORS__', + '-U__CUDA_NO_HALF_CONVERSIONS__', + '--expt-relaxed-constexpr', + '--expt-extended-lambda'] + + # Upper triangular softmax. + sources=[srcpath / 'scaled_upper_triang_masked_softmax.cpp', + srcpath / 'scaled_upper_triang_masked_softmax_cuda.cu'] + scaled_upper_triang_masked_softmax_cuda = _cpp_extention_load_helper( + "scaled_upper_triang_masked_softmax_cuda", + sources, extra_cuda_flags, extra_include_paths) + + # Masked softmax. 
+ sources=[srcpath / 'scaled_masked_softmax.cpp', + srcpath / 'scaled_masked_softmax_cuda.cu'] + scaled_masked_softmax_cuda = _cpp_extention_load_helper( + "scaled_masked_softmax_cuda", sources, extra_cuda_flags, extra_include_paths) + + # ================================= + # Mixed precision fused layer norm. + # ================================= + + if torch.version.hip is not None: + extra_cuda_flags = [] + else: + extra_cuda_flags = ['-maxrregcount=50'] + + sources=[srcpath / 'layer_norm_cuda.cpp', + srcpath / 'layer_norm_cuda_kernel.cu'] + fused_mix_prec_layer_norm_cuda = _cpp_extention_load_helper( + "fused_mix_prec_layer_norm_cuda", sources, extra_cuda_flags, extra_include_paths) + + +def _get_cuda_bare_metal_version(cuda_dir): + raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], + universal_newlines=True) + output = raw_output.split() + release_idx = output.index("release") + 1 + release = output[release_idx].split(".") + bare_metal_major = release[0] + bare_metal_minor = release[1][0] + + return raw_output, bare_metal_major, bare_metal_minor + + +def _create_build_dir(buildpath): + try: + os.mkdir(buildpath) + except OSError: + if not os.path.isdir(buildpath): + print(f"Creation of the build directory {buildpath} failed") diff --git a/megatron/fused_kernels/compat.h b/megatron/fused_kernels/compat.h new file mode 100644 index 000000000..92e7eb772 --- /dev/null +++ b/megatron/fused_kernels/compat.h @@ -0,0 +1,31 @@ +/* coding=utf-8 + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*This code is copied fron NVIDIA apex: + * https://github.com/NVIDIA/apex + * with minor changes. */ + + + +#ifndef TORCH_CHECK +#define TORCH_CHECK AT_CHECK +#endif + +#ifdef VERSION_GE_1_3 +#define DATA_PTR data_ptr +#else +#define DATA_PTR data +#endif diff --git a/megatron/fused_kernels/layer_norm_cuda.cpp b/megatron/fused_kernels/layer_norm_cuda.cpp new file mode 100644 index 000000000..8f28e7b4a --- /dev/null +++ b/megatron/fused_kernels/layer_norm_cuda.cpp @@ -0,0 +1,201 @@ +/* coding=utf-8 + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*This code is copied fron NVIDIA apex: + * https://github.com/NVIDIA/apex + * with minor changes. 
+ */
+
+#include <torch/extension.h>
+#include <vector>
+#include <cassert>
+#include "compat.h"
+
+namespace {
+
+void compute_n1_n2(
+    at::Tensor input,
+    at::IntArrayRef normalized_shape,
+    int& n1,
+    int& n2) {
+  int idiff = input.ndimension() - normalized_shape.size();
+  n2 = 1;
+  for (int i = 0; i < (int)normalized_shape.size(); ++i) {
+    assert( input.sizes()[i+idiff] == normalized_shape[i] );
+    n2 *= normalized_shape[i];
+  }
+  n1 = 1;
+  for (int i = 0; i < idiff; ++i) {
+    n1 *= input.sizes()[i];
+  }
+}
+
+void check_args(
+    at::IntArrayRef normalized_shape,
+    at::Tensor gamma,
+    at::Tensor beta
+    )
+{
+  TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape));
+  TORCH_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape));
+}
+
+void check_args(
+    at::Tensor input,
+    at::IntArrayRef normalized_shape,
+    int& n1,
+    int& n2
+    )
+{
+  int64_t normalized_ndim = normalized_shape.size();
+
+  if (normalized_ndim < 1) {
+    std::stringstream ss;
+    ss << "Expected normalized_shape to be at least 1-dimensional, i.e., "
+       << "containing at least one element, but got normalized_shape="
+       << normalized_shape;
+    throw std::runtime_error(ss.str());
+  }
+
+  auto input_shape = input.sizes();
+  auto input_ndim = input.dim();
+
+  if (input_ndim < normalized_ndim ||
+      !input_shape.slice(input_ndim - normalized_ndim).equals(normalized_shape)) {
+    std::stringstream ss;
+    ss << "Given normalized_shape=" << normalized_shape
+       << ", expected input with shape [*";
+    for (auto size : normalized_shape) {
+      ss << ", " << size;
+    }
+    ss << "], but got input of size" << input_shape;
+    throw std::runtime_error(ss.str());
+  }
+
+  compute_n1_n2(input,normalized_shape,n1,n2);
+}
+
+
+void check_args(
+    at::Tensor input,
+    at::IntArrayRef normalized_shape,
+    at::Tensor gamma,
+    at::Tensor beta,
+    int& n1,
+    int& n2
+    )
+{
+  check_args(input,normalized_shape,n1,n2);
+  check_args(normalized_shape,gamma,beta);
+}
+}
+
+void cuda_layer_norm(
+    at::Tensor* output,
+    at::Tensor* mean,
+    at::Tensor* invvar,
+    at::Tensor* input,
+    int n1,
+    int n2,
+    at::IntArrayRef normalized_shape,
+    at::Tensor* gamma,
+    at::Tensor* beta,
+    double epsilon);
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+std::vector<at::Tensor> layer_norm_affine(
+    at::Tensor input,
+    at::IntArrayRef normalized_shape,
+    at::Tensor gamma,
+    at::Tensor beta,
+    double epsilon) {
+
+  CHECK_INPUT(input);
+  CHECK_INPUT(gamma);
+  CHECK_INPUT(beta);
+  int n1, n2;
+  check_args(input, normalized_shape, gamma, beta, n1, n2);
+
+  at::Tensor output = at::empty_like(
+      input, gamma.options().dtype(gamma.scalar_type()));
+  at::Tensor mean = at::empty(
+      {n1}, input.options().dtype(at::ScalarType::Float));
+  at::Tensor invvar = at::empty_like(mean);
+
+  cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2,
+      normalized_shape, &gamma, &beta, epsilon);
+
+  return {output, mean, invvar};
+
+}
+
+
+void cuda_layer_norm_gradient(
+    at::Tensor* dout,
+    at::Tensor* mean,
+    at::Tensor* invvar,
+    at::Tensor* input,
+    int n1,
+    int n2,
+    at::IntArrayRef normalized_shape,
+    at::Tensor* gamma,
+    at::Tensor* beta,
+    double epsilon,
+    at::Tensor* grad_input,
+    at::Tensor* grad_gamma,
+    at::Tensor* grad_beta
+    );
+
+std::vector<at::Tensor> layer_norm_gradient_affine(
+    at::Tensor dout,
+    at::Tensor mean,
+    at::Tensor invvar,
+    at::Tensor input,
+    at::IntArrayRef normalized_shape,
+    at::Tensor gamma,
+    at::Tensor beta,
+    double epsilon) {
+
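+  // Backward of the affine layer norm y = gamma * xhat + beta with
+  // xhat = (x - mean) * invvar and invvar = rsqrt(var + epsilon), using the
+  // mean/invvar tensors saved by the forward pass. Returns
+  // {grad_input, grad_gamma, grad_beta}, where grad_beta reduces dout over
+  // the leading n1 rows and grad_gamma reduces dout * xhat (see
+  // cuComputePartGradGammaBeta in layer_norm_cuda_kernel.cu).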
CHECK_INPUT(dout); + CHECK_INPUT(mean); + CHECK_INPUT(invvar); + CHECK_INPUT(input); + CHECK_INPUT(gamma); + CHECK_INPUT(beta); + int n1, n2; + check_args(input, normalized_shape, gamma, beta, n1, n2); + + at::Tensor grad_input = at::empty_like(input); + at::Tensor grad_gamma = at::empty_like(gamma); + at::Tensor grad_beta = at::empty_like(beta); + + cuda_layer_norm_gradient(&dout, &mean, &invvar, &input, n1, n2, + normalized_shape, &gamma, &beta, epsilon, + &grad_input, &grad_gamma, &grad_beta); + + return {grad_input, grad_gamma, grad_beta}; + +} + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward_affine", &layer_norm_affine, + "LayerNorm forward (CUDA)"); + m.def("backward_affine", &layer_norm_gradient_affine, + "LayerNorm backward (CUDA)"); +} diff --git a/megatron/fused_kernels/layer_norm_cuda_kernel.cu b/megatron/fused_kernels/layer_norm_cuda_kernel.cu new file mode 100644 index 000000000..8a07806b1 --- /dev/null +++ b/megatron/fused_kernels/layer_norm_cuda_kernel.cu @@ -0,0 +1,866 @@ +/* coding=utf-8 + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*This code is copied fron NVIDIA apex: + * https://github.com/NVIDIA/apex + * with minor changes. */ + +#include "ATen/ATen.h" +#include "ATen/AccumulateType.h" +#include "ATen/cuda/CUDAContext.h" +#include "ATen/cuda/DeviceUtils.cuh" + +#include +#include + +#include "type_shim.h" + +template __device__ +void cuWelfordOnlineSum( + const U curr, + U& mu, + U& sigma2, + U& count) +{ + count = count + U(1); + U delta = curr - mu; + U lmean = mu + delta / count; + mu = lmean; + U delta2 = curr - lmean; + sigma2 = sigma2 + delta * delta2; +} + +template __device__ +void cuChanOnlineSum( + const U muB, + const U sigma2B, + const U countB, + U& mu, + U& sigma2, + U& count) +{ + U delta = muB - mu; + U nA = count; + U nB = countB; + count = count + countB; + U nX = count; + if (nX > U(0)) { + nA = nA / nX; + nB = nB / nX; + mu = nA*mu + nB*muB; + sigma2 = sigma2 + sigma2B + delta * delta * nA * nB * nX; + } else { + mu = U(0); + sigma2 = U(0); + } +} + +template __device__ +void cuWelfordMuSigma2( + const T* __restrict__ vals, + const int n1, + const int n2, + const int i1, + U& mu, + U& sigma2, + U* buf, + const int GPU_WARP_SIZE) +{ + // Assumptions: + // 1) blockDim.x == warpSize + // 2) Tensor is contiguous + // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available. 
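+  //
+  // Each thread runs Welford's online update (cuWelfordOnlineSum) over its
+  // strided slice, producing a partial (mean, sum-of-squared-deviations,
+  // count) triple; partials are then merged across lanes and warps with
+  // Chan's parallel combination formula (cuChanOnlineSum). sigma2 holds the
+  // raw squared-deviation sum until it is divided by n2 at the end.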
+ // + // compute variance and mean over n2 + U count = U(0); + mu= U(0); + sigma2 = U(0); + if (i1 < n1) { + // one warp normalizes one n1 index, + // synchronization is implicit + // initialize with standard Welford algorithm + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + const T* lvals = vals + i1*n2; + int l = 4*thrx; + for (; l+3 < n2; l+=4*numx) { + for (int k = 0; k < 4; ++k) { + U curr = static_cast(lvals[l+k]); + cuWelfordOnlineSum(curr,mu,sigma2,count); + } + } + for (; l < n2; ++l) { + U curr = static_cast(lvals[l]); + cuWelfordOnlineSum(curr,mu,sigma2,count); + } + // intra-warp reductions + for (int stride = GPU_WARP_SIZE / 2; stride > 0; stride /= 2) { + U sigma2B = WARP_SHFL_DOWN(sigma2, stride); + U muB = WARP_SHFL_DOWN(mu, stride); + U countB = WARP_SHFL_DOWN(count, stride); + cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); + } + // threadIdx.x == 0 has correct values for each warp + // inter-warp reductions + if (blockDim.y > 1) { + U* ubuf = (U*)buf; + U* ibuf = (U*)(ubuf + blockDim.y); + for (int offset = blockDim.y/2; offset > 0; offset /= 2) { + // upper half of warps write to shared + if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) { + const int wrt_y = threadIdx.y - offset; + ubuf[2*wrt_y] = mu; + ubuf[2*wrt_y+1] = sigma2; + ibuf[wrt_y] = count; + } + __syncthreads(); + // lower half merges + if (threadIdx.x == 0 && threadIdx.y < offset) { + U muB = ubuf[2*threadIdx.y]; + U sigma2B = ubuf[2*threadIdx.y+1]; + U countB = ibuf[threadIdx.y]; + cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count); + } + __syncthreads(); + } + // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values + if (threadIdx.x == 0 && threadIdx.y == 0) { + ubuf[0] = mu; + ubuf[1] = sigma2; + } + __syncthreads(); + mu = ubuf[0]; + sigma2 = ubuf[1]/U(n2); + // don't care about final value of count, we know count == n2 + } else { + mu = WARP_SHFL(mu, 0); + sigma2 = WARP_SHFL(sigma2/U(n2), 0); + } + } +} + +template<> __device__ +void cuWelfordMuSigma2( + const at::Half* __restrict__ vals, + const int n1, + const int n2, + const int i1, + float& mu, + float& sigma2, + float* buf, + const int GPU_WARP_SIZE) +{ + // Assumptions: + // 1) blockDim.x == warpSize + // 2) Tensor is contiguous + // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available. + // + // compute variance and mean over n2 + float count = 0.0f; + mu= float(0); + sigma2 = float(0); + if (i1 < n1) { + // one warp normalizes one n1 index, + // synchronization is implicit + // initialize with standard Welford algorithm + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + const at::Half* lvals = vals + i1*n2; + int l = 8*thrx; + if ((((size_t)lvals)&3) != 0) { + // 16 bit alignment + // first thread consumes first point + if (thrx == 0) { + float curr = static_cast(lvals[0]); + cuWelfordOnlineSum(curr,mu,sigma2,count); + } + ++l; + } + // at this point, lvals[l] are 32 bit aligned for all threads. 
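+      // Main loop: each iteration consumes 8 halves per thread, read as four
+      // __half2 pairs and converted with __half22float2 so both lanes feed
+      // the Welford update; the scalar loop below handles any remainder.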
+ for (; l+7 < n2; l+=8*numx) { + for (int k = 0; k < 8; k+=2) { + float2 curr = __half22float2(*((__half2*)(lvals+l+k))); + cuWelfordOnlineSum(curr.x,mu,sigma2,count); + cuWelfordOnlineSum(curr.y,mu,sigma2,count); + } + } + for (; l < n2; ++l) { + float curr = static_cast(lvals[l]); + cuWelfordOnlineSum(curr,mu,sigma2,count); + } + // intra-warp reductions + for (int stride = GPU_WARP_SIZE / 2; stride > 0; stride /= 2) { + float sigma2B = WARP_SHFL_DOWN(sigma2, stride); + float muB = WARP_SHFL_DOWN(mu, stride); + float countB = WARP_SHFL_DOWN(count, stride); + cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); + } + // threadIdx.x == 0 has correct values for each warp + // inter-warp reductions + if (blockDim.y > 1) { + float* ubuf = (float*)buf; + float* ibuf = (float*)(ubuf + blockDim.y); + for (int offset = blockDim.y/2; offset > 0; offset /= 2) { + // upper half of warps write to shared + if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) { + const int wrt_y = threadIdx.y - offset; + ubuf[2*wrt_y] = mu; + ubuf[2*wrt_y+1] = sigma2; + ibuf[wrt_y] = count; + } + __syncthreads(); + // lower half merges + if (threadIdx.x == 0 && threadIdx.y < offset) { + float muB = ubuf[2*threadIdx.y]; + float sigma2B = ubuf[2*threadIdx.y+1]; + float countB = ibuf[threadIdx.y]; + cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count); + } + __syncthreads(); + } + // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values + if (threadIdx.x == 0 && threadIdx.y == 0) { + ubuf[0] = mu; + ubuf[1] = sigma2; + } + __syncthreads(); + mu = ubuf[0]; + sigma2 = ubuf[1]/float(n2); + // don't care about final value of count, we know count == n2 + } else { + mu = WARP_SHFL(mu, 0); + sigma2 = WARP_SHFL(sigma2/float(n2), 0); + } + } +} +#ifndef __HIP_PLATFORM_HCC__ +template U rsqrt(U v) { +#else +template __device__ U rsqrt(U v) { +#endif + return U(1) / sqrt(v); +} +#ifndef __HIP_PLATFORM_HCC__ +template<> float rsqrt(float v) { +#else +template<> __device__ float rsqrt(float v) { +#endif + return rsqrtf(v); +} +#ifndef __HIP_PLATFORM_HCC__ +template<> double rsqrt(double v) { +#else +template<> __device__ double rsqrt(double v) { +#endif + return rsqrt(v); +} + +namespace { +// This is the un-specialized struct. Note that we prevent instantiation of this +// struct by putting an undefined symbol in the function body so it won't compile. 
+// template +// struct SharedMemory +// { +// // Ensure that we won't compile any un-specialized types +// __device__ T *getPointer() +// { +// extern __device__ void error(void); +// error(); +// return NULL; +// } +// }; +// https://github.com/NVIDIA/apex/issues/246 +template +struct SharedMemory; + +template <> +struct SharedMemory +{ + __device__ float *getPointer() + { + extern __shared__ float s_float[]; + return s_float; + } +}; + +} + +template __global__ +void cuApplyLayerNorm( + V* __restrict__ output_vals, + U* __restrict__ mean, + U* __restrict__ invvar, + const T* __restrict__ vals, + const int n1, + const int n2, + const U epsilon, + const V* __restrict__ gamma, + const V* __restrict__ beta, + const int GPU_WARP_SIZE + ) +{ + // Assumptions: + // 1) blockDim.x == warpSize + // 2) Tensors are contiguous + // +#ifndef __HIP_PLATFORM_HCC__ + for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) { +#else + for (int i1=blockIdx.y; i1 < n1; i1 += gridDim.y) { +#endif + SharedMemory shared; + U* buf = shared.getPointer(); + U mu,sigma2; + cuWelfordMuSigma2(vals,n1,n2,i1,mu,sigma2,buf,GPU_WARP_SIZE); + const T* lvals = vals + i1*n2; + V* ovals = output_vals + i1*n2; + U c_invvar = rsqrt(sigma2 + epsilon); + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + if (gamma != NULL && beta != NULL) { + for (int i = thrx; i < n2; i+=numx) { + U curr = static_cast(lvals[i]); + ovals[i] = gamma[i] * static_cast(c_invvar * (curr - mu)) + beta[i]; + } + } else { + for (int i = thrx; i < n2; i+=numx) { + U curr = static_cast(lvals[i]); + ovals[i] = static_cast(c_invvar * (curr - mu)); + } + } + if (threadIdx.x == 0 && threadIdx.y == 0) { + mean[i1] = mu; + invvar[i1] = c_invvar; + } + __syncthreads(); + } +} + +template __device__ +void cuLoadWriteStridedInputs( + const int i1_block, + const int thr_load_row_off, + const int thr_load_col_off, + const int i2_off, + const int row_stride, + U* warp_buf1, + U* warp_buf2, + const T* input, + const V* dout, + const int i1_end, + const int n2, + const U* __restrict__ mean, + const U* __restrict__ invvar + ) +{ + int i1 = i1_block+thr_load_row_off; + if (i1 < i1_end) { + U curr_mean = mean[i1]; + U curr_invvar = invvar[i1]; + for (int k = 0; k < blockDim.y; ++k) { + int i2 = i2_off + k; + int load_idx = i1*n2+i2; + int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k; + if (i2(input[load_idx]); + U curr_dout = static_cast(dout[load_idx]); + warp_buf1[write_idx] = curr_dout; + warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_invvar; + } else { + warp_buf1[write_idx] = U(0); + warp_buf2[write_idx] = U(0); + } + } + } else { + for (int k = 0; k < blockDim.y; ++k) { + int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k; + warp_buf1[write_idx] = U(0); + warp_buf2[write_idx] = U(0); + } + } +} + +template __device__ +void cuLoadAddStridedInputs( + const int i1_block, + const int thr_load_row_off, + const int thr_load_col_off, + const int i2_off, + const int row_stride, + U* warp_buf1, + U* warp_buf2, + const T* input, + const V* dout, + const int i1_end, + const int n2, + const U* __restrict__ mean, + const U* __restrict__ invvar + ) +{ + int i1 = i1_block+thr_load_row_off; + if (i1 < i1_end) { + U curr_mean = mean[i1]; + U curr_invvar = invvar[i1]; + for (int k = 0; k < blockDim.y; ++k) { + int i2 = i2_off + k; + int load_idx = i1*n2+i2; + int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k; + if (i2(input[load_idx]); + U curr_dout = 
static_cast(dout[load_idx]); + warp_buf1[write_idx] += curr_dout; + warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar; + } + } + } +} + +template __global__ +void cuComputePartGradGammaBeta( + const V* __restrict__ dout, + const T* __restrict__ input, + const int n1, + const int n2, + const U* __restrict__ mean, + const U* __restrict__ invvar, + U epsilon, + U* part_grad_gamma, + U* part_grad_beta) +{ + const int numsegs_n1 = (n1+blockDim.y*blockDim.y-1) / (blockDim.y*blockDim.y); + const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y; + const int i1_beg = blockIdx.y * segs_per_block * blockDim.y*blockDim.y; + const int i1_beg_plus_one = (blockIdx.y+1) * segs_per_block * blockDim.y*blockDim.y; + const int i1_end = i1_beg_plus_one < n1 ? i1_beg_plus_one : n1; + const int row_stride = blockDim.x+1; + const int thr_load_col_off = (threadIdx.x*blockDim.y)&(blockDim.x-1); + const int thr_load_row_off = (threadIdx.x*blockDim.y)/blockDim.x + threadIdx.y*blockDim.y; + const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off; + SharedMemory shared; + U* buf = shared.getPointer(); // buf has at least blockDim.x * blockDim.y * blockDim.y + (blockDim.y - 1)*(blockDim.x/blockDim.y) elements + U* warp_buf1 = (U*)buf; + U* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride; + // compute partial sums from strided inputs + // do this to increase number of loads in flight + cuLoadWriteStridedInputs(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar); + for (int i1_block = i1_beg+blockDim.y*blockDim.y; i1_block < i1_end; i1_block+=blockDim.y*blockDim.y) { + cuLoadAddStridedInputs(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar); + } + __syncthreads(); + // inter-warp reductions + // sum within each warp + U acc1 = U(0); + U acc2 = U(0); + for (int k = 0; k < blockDim.y; ++k) { + int row1 = threadIdx.y + k*blockDim.y; + int idx1 = row1*row_stride + threadIdx.x; + acc1 += warp_buf1[idx1]; + acc2 += warp_buf2[idx1]; + } + warp_buf1[threadIdx.y*row_stride+threadIdx.x] = acc1; + warp_buf2[threadIdx.y*row_stride+threadIdx.x] = acc2; + __syncthreads(); + // sum all warps + for (int offset = blockDim.y/2; offset > 1; offset /= 2) { + if (threadIdx.y < offset) { + int row1 = threadIdx.y; + int row2 = threadIdx.y + offset; + int idx1 = row1*row_stride + threadIdx.x; + int idx2 = row2*row_stride + threadIdx.x; + warp_buf1[idx1] += warp_buf1[idx2]; + warp_buf2[idx1] += warp_buf2[idx2]; + } + __syncthreads(); + } + int i2 = blockIdx.x * blockDim.x + threadIdx.x; + if (threadIdx.y == 0 && i2 < n2) { + int row1 = threadIdx.y; + int row2 = threadIdx.y + 1; + int idx1 = row1*row_stride + threadIdx.x; + int idx2 = row2*row_stride + threadIdx.x; + part_grad_beta[blockIdx.y*n2+i2] = warp_buf1[idx1] + warp_buf1[idx2]; + part_grad_gamma[blockIdx.y*n2+i2] = warp_buf2[idx1] + warp_buf2[idx2]; + } +} + +template __global__ +void cuComputeGradGammaBeta( + const U* part_grad_gamma, + const U* part_grad_beta, + const int part_size, + const int n1, + const int n2, + V* grad_gamma, + V* grad_beta) +{ + // sum partial gradients for gamma and beta + SharedMemory shared; + U* buf = shared.getPointer(); + int i2 = blockIdx.x * blockDim.x + threadIdx.x; + if (i2 < n2) { + // each warp does sequential reductions until reduced part_size is num_warps + int num_warp_reductions = part_size / blockDim.y; + U sum_gamma = U(0); + U sum_beta = U(0); + const U* 
part_grad_gamma_ptr = part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2; + const U* part_grad_beta_ptr = part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2; + for (int warp_offset = 0; warp_offset < num_warp_reductions; ++warp_offset) { + sum_gamma += part_grad_gamma_ptr[warp_offset*n2]; + sum_beta += part_grad_beta_ptr[warp_offset*n2]; + } + // inter-warp reductions + const int nbsize3 = blockDim.x * blockDim.y / 2; + for (int offset = blockDim.y/2; offset >= 1; offset /= 2) { + // top half write to shared memory + if (threadIdx.y >= offset && threadIdx.y < 2*offset) { + const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x; + buf[write_idx] = sum_gamma; + buf[write_idx+nbsize3] = sum_beta; + } + __syncthreads(); + // bottom half sums + if (threadIdx.y < offset) { + const int read_idx = threadIdx.y * blockDim.x + threadIdx.x; + sum_gamma += buf[read_idx]; + sum_beta += buf[read_idx+nbsize3]; + } + __syncthreads(); + } + // write out fully summed gradients + if (threadIdx.y == 0) { + grad_gamma[i2] = sum_gamma; + grad_beta[i2] = sum_beta; + } + } +} + +template __global__ +void cuComputeGradInput( + const V* __restrict__ dout, + const T* __restrict__ input, + const int n1, + const int n2, + const U* __restrict__ mean, + const U* __restrict__ invvar, + U epsilon, + const V* gamma, + T* grad_input) +{ +#ifndef __HIP_PLATFORM_HCC__ + for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) { +#else + for (int i1=blockIdx.y; i1 < n1; i1 += gridDim.y) { +#endif + U sum_loss1 = U(0); + U sum_loss2 = U(0); + const U c_mean = mean[i1]; + const U c_invvar = invvar[i1]; + const T* k_input = input + i1*n2; + const V* k_dout = dout + i1*n2; + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + if (gamma != NULL) { + int l = 4*thrx; + for (; l+3 < n2; l+=4*numx) { + for (int k = 0; k < 4; ++k) { + const U c_h = static_cast(k_input[l+k]); + const U c_loss = static_cast(k_dout[l+k]); + sum_loss1 += c_loss * gamma[l+k]; + sum_loss2 += c_loss * gamma[l+k] * (c_h - c_mean) * c_invvar; + } + } + for (; l < n2; ++l) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + sum_loss1 += c_loss * gamma[l]; + sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar; + } + } else { + int l = 4*thrx; + for (; l+3 < n2; l+=4*numx) { + for (int k = 0; k < 4; ++k) { + const U c_h = static_cast(k_input[l+k]); + const U c_loss = static_cast(k_dout[l+k]); + sum_loss1 += c_loss; + sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; + } + } + for (; l < n2; ++l) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + sum_loss1 += c_loss; + sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; + } + } + // intra-warp reductions + for (int mask = blockDim.x/2; mask > 0; mask /= 2) { + sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask); + sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask); + } + // inter-warp reductions + if (blockDim.y > 1) { + SharedMemory shared; + U* buf = shared.getPointer(); + for (int offset = blockDim.y/2; offset > 0; offset /= 2) { + // upper half of warps write to shared + if (threadIdx.y >= offset && threadIdx.y < 2*offset) { + const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x; + buf[2*wrt_i] = sum_loss1; + buf[2*wrt_i+1] = sum_loss2; + } + __syncthreads(); + // lower half merges + if (threadIdx.y < offset) { + const int read_i = threadIdx.y * blockDim.x + threadIdx.x; + sum_loss1 += buf[2*read_i]; + sum_loss2 += buf[2*read_i+1]; + } + __syncthreads(); + 
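+        // each pass halves the number of warps holding partial sums; after
+        // this loop sum_loss1/sum_loss2 in the first warp row are the full
+        // row reductions of dy*gamma and dy*gamma*xhat (dy and dy*xhat when
+        // gamma is NULL)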
} + if (threadIdx.y == 0) { + buf[2*threadIdx.x] = sum_loss1; + buf[2*threadIdx.x+1] = sum_loss2; + } + __syncthreads(); + if (threadIdx.y !=0) { + sum_loss1 = buf[2*threadIdx.x]; + sum_loss2 = buf[2*threadIdx.x+1]; + } + } + // all threads now have the two sums over l + U fH = (U)n2; + U term1 = (U(1) / fH) * c_invvar; + T* k_grad_input = grad_input + i1*n2; + if (gamma != NULL) { + for (int l = thrx; l < n2; l+=numx) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + U f_grad_input = fH * c_loss * gamma[l]; + f_grad_input -= sum_loss1; + f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; + f_grad_input *= term1; + k_grad_input[l] = static_cast(f_grad_input); + } + } else { + for (int l = thrx; l < n2; l+=numx) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + U f_grad_input = fH * c_loss; + f_grad_input -= sum_loss1; + f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; + f_grad_input *= term1; + k_grad_input[l] = static_cast(f_grad_input); + } + } + // prevent race where buf is written again before reads are done + __syncthreads(); + } +} + + + + +template +void HostApplyLayerNorm( + V* output, + U* mean, + U* invvar, + const T* input, + int n1, + int n2, + double epsilon, + const V* gamma, + const V* beta + ) +{ + auto stream = at::cuda::getCurrentCUDAStream().stream(); + const int warp_size = at::cuda::warp_size(); + dim3 threads(warp_size,4,1); +#ifndef __HIP_PLATFORM_HCC__ + threads.y = 1; +#endif + const uint64_t maxGridY = + at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; + const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1); + int nshared = + threads.y > 1 ? + threads.y*sizeof(U)+(threads.y/2)*sizeof(U) : + 0; + cuApplyLayerNorm<<>>( + output, + mean, + invvar, + input, + n1,n2, + U(epsilon), + gamma, + beta, + warp_size); +} + + +void cuda_layer_norm( + at::Tensor* output, + at::Tensor* mean, + at::Tensor* invvar, + at::Tensor* input, + int n1, + int n2, + #ifdef VERSION_GE_1_1 + at::IntArrayRef normalized_shape, + #else + at::IntList normalized_shape, + #endif + at::Tensor* gamma, + at::Tensor* beta, + double epsilon) +{ + using namespace at; + DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( + input->scalar_type(), output->scalar_type(), "cuda_layer_norm_kernel", + HostApplyLayerNorm( + output->DATA_PTR(), + mean->DATA_PTR(), + invvar->DATA_PTR(), + input->DATA_PTR(), + n1,n2, + epsilon, + gamma != NULL ? gamma->DATA_PTR() : NULL, + beta != NULL ? beta->DATA_PTR() : NULL); + ) +} + + +template +void HostLayerNormGradient( + const V* dout, + const U* mean, + const U* invvar, + at::Tensor* input, + int n1, + int n2, + const V* gamma, + const V* beta, + double epsilon, + T* grad_input, + V* grad_gamma, + V* grad_beta + ) +{ + auto stream = at::cuda::getCurrentCUDAStream().stream(); + const int warp_size = at::cuda::warp_size(); + + if (gamma != NULL && beta != NULL) { + // compute grad_gamma(j) and grad_beta(j) +#ifndef __HIP_PLATFORM_HCC__ + const int part_size = warp_size; +#else + const int part_size = 16; +#endif + const dim3 threads2(warp_size,4,1); + const dim3 blocks2((n2+threads2.x-1)/threads2.x,part_size,1); + const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * + (threads2.x + 1); + const int nshared2_b = threads2.x * threads2.y * sizeof(U); + const int nshared2 = nshared2_a > nshared2_b ? 
nshared2_a : nshared2_b; + at::Tensor part_grad_gamma = at::empty( + {part_size,n2}, input->options().dtype(at::ScalarType::Float)); + at::Tensor part_grad_beta = at::empty_like(part_grad_gamma); + cuComputePartGradGammaBeta<<>>( + dout, + input->DATA_PTR(), + n1,n2, + mean, + invvar, + U(epsilon), + part_grad_gamma.DATA_PTR(), + part_grad_beta.DATA_PTR()); + + const dim3 threads3(warp_size,8,1); + const dim3 blocks3((n2+threads2.x-1)/threads2.x,1,1); + const int nshared3 = threads3.x * threads3.y * sizeof(U); + cuComputeGradGammaBeta<<>>( + part_grad_gamma.DATA_PTR(), + part_grad_beta.DATA_PTR(), + part_size, + n1,n2, + grad_gamma, + grad_beta); + } + + // compute grad_input + const uint64_t maxGridY = + at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; + const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1); + dim3 threads1(warp_size,4,1); +#ifndef __HIP_PLATFORM_HCC__ + threads1.y = 2; +#endif + int nshared = + threads1.y > 1 ? + threads1.y*threads1.x*sizeof(U) : + 0; + cuComputeGradInput<<>>( + dout, + input->DATA_PTR(), + n1,n2, + mean, + invvar, + U(epsilon), + gamma, + grad_input); +} + + +void cuda_layer_norm_gradient( + at::Tensor* dout, + at::Tensor* mean, + at::Tensor* invvar, + at::Tensor* input, + int n1, + int n2, + #ifdef VERSION_GE_1_1 + at::IntArrayRef normalized_shape, + #else + at::IntList normalized_shape, + #endif + at::Tensor* gamma, + at::Tensor* beta, + double epsilon, + at::Tensor* grad_input, + at::Tensor* grad_gamma, + at::Tensor* grad_beta) +{ + using namespace at; + DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( + input->scalar_type(), gamma->scalar_type(), + "cuda_layer_norm_gradient_kernel", + HostLayerNormGradient( + dout->DATA_PTR(), + mean->DATA_PTR(), + invvar->DATA_PTR(), + input, + n1,n2, + // TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta + // if gamma Tensor is NULL on input. + gamma != NULL ? gamma->DATA_PTR() : NULL, + gamma != NULL ? beta->DATA_PTR() : NULL, + epsilon, + grad_input->DATA_PTR(), + gamma != NULL ? grad_gamma->DATA_PTR() : NULL, + gamma != NULL ? grad_beta->DATA_PTR() : NULL); + ) +} diff --git a/megatron/fused_kernels/scaled_masked_softmax.cpp b/megatron/fused_kernels/scaled_masked_softmax.cpp new file mode 100644 index 000000000..d5334710c --- /dev/null +++ b/megatron/fused_kernels/scaled_masked_softmax.cpp @@ -0,0 +1,77 @@ +/* coding=utf-8 + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include + +namespace multihead_attn { +namespace fused_softmax { +namespace scaled_masked_softmax { + +torch::Tensor fwd_cuda( + torch::Tensor const& input, + torch::Tensor const& mask, + float scale_factor); + +torch::Tensor bwd_cuda( + torch::Tensor const& output_grads, + torch::Tensor const& softmax_results, + float scale_factor); + +torch::Tensor fwd( + torch::Tensor const& input, + torch::Tensor const& mask, + float scale_factor) { + AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); + AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || + (input.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + AT_ASSERTM(mask.dim() == 4, "expected 4D tensor"); + + return fwd_cuda(input, mask, scale_factor); +} + +torch::Tensor bwd( + torch::Tensor const& output_grads, + torch::Tensor const& softmax_results, + float scale_factor) { + + AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor"); + AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor"); + + AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || + (output_grads.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || + (softmax_results.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + + return bwd_cuda(output_grads, softmax_results, scale_factor); +} + +} // end namespace scaled_masked_softmax +} // end namespace fused_softmax +} // end namespace multihead_attn + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", + &multihead_attn::fused_softmax::scaled_masked_softmax::fwd, + "Self Multihead Attention scaled, time masked softmax -- Forward."); + m.def("backward", + &multihead_attn::fused_softmax::scaled_masked_softmax::bwd, + "Self Multihead Attention scaled, time masked softmax -- Backward."); +} diff --git a/megatron/fused_kernels/scaled_masked_softmax.h b/megatron/fused_kernels/scaled_masked_softmax.h new file mode 100644 index 000000000..78e97e4ec --- /dev/null +++ b/megatron/fused_kernels/scaled_masked_softmax.h @@ -0,0 +1,492 @@ +/* coding=utf-8 + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace { + +template +__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); + +template <> +__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; } + +template <> +__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); } + +template <> +__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *dst = *src; } + +template <> +__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); } + +template <> +__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) { *dst = *src; } + +template <> +__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); } + +int log2_ceil(int value) { + int log2_value = 0; + while ((1 << log2_value) < value) ++log2_value; + return log2_value; +} + +template +struct Add { + __device__ __forceinline__ T operator()(T a, T b) const { + return a + b; + } +}; + +template +struct Max { + __device__ __forceinline__ T operator()(T a, T b) const { + return a < b ? b : a; + } +}; + +template +__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) +{ +#if CUDA_VERSION >= 9000 + return __shfl_xor_sync(mask, value, laneMask, width); +#else + return __shfl_xor(value, laneMask, width); +#endif +} + +template class ReduceOp> +__device__ __forceinline__ void warp_reduce(acc_t* sum) { + ReduceOp r; + #pragma unroll + for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); + sum[i] = r(sum[i], b); + } + } +} + +/* + * Extended softmax (from native aten pytorch) with following additional features + * 1) input scaling + * 2) Explicit masking + */ +template +__global__ void scaled_masked_softmax_warp_forward( + output_t *dst, + const input_t *src, + const uint8_t *mask, + const acc_t scale, + int micro_batch_size, + int element_count, + int pad_batches) +{ + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_forward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = 4; + + // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) + // gridDim/blockIdx = (seq_len, attn_heads, batches) + int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z))+ threadIdx.y) * WARP_BATCH; + int pad_first_batch = 0; + if (pad_batches != 1) { // bert style + pad_first_batch = (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * WARP_BATCH; + } else { // gpt2 style + pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; + } + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) + local_batches = WARP_BATCH; + + // there might be multiple batches per warp. 
compute the index within the batch + int local_idx = threadIdx.x; + + src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + + // load data from global memory + acc_t elements[WARP_BATCH][WARP_ITERATIONS]; + input_t temp_data[ELEMENTS_PER_LDG_STG]; + uint8_t temp_mask[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : element_count; + + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + + if (element_index < batch_element_count) { + int itr_idx = i*element_count+it*WARP_SIZE; + copy_vector(temp_data, src + itr_idx); + copy_vector(temp_mask, mask + itr_idx); + + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (temp_mask[element] != 1) { + elements[i][it + element] = (acc_t)temp_data[element] * scale; + } else { + elements[i][it + element] = -10000.0; + } + } + } else { + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + elements[i][it + element] = -std::numeric_limits::infinity(); + } + } + } + } + + // compute max_value + acc_t max_value[WARP_BATCH]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + max_value[i] = elements[i][0]; + #pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; + } + } + warp_reduce(max_value); + + acc_t sum[WARP_BATCH] { 0.0f }; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + elements[i][it] = std::exp((elements[i][it] - max_value[i])); + sum[i] += elements[i][it]; + } + } + warp_reduce(sum); + + // store result + output_t out[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) + break; + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < element_count) { + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = elements[i][it + element] / sum[i]; + } + copy_vector(dst + i * element_count + it * WARP_SIZE, out); + } else { + break; + } + } + } +} + +template +__global__ void scaled_masked_softmax_warp_backward( + output_t *gradInput, + input_t *grad, + const input_t *output, + acc_t scale, + int micro_batch_size, + int element_count) +{ + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_backward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = 4; + + // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) + // gridDim/blockIdx = (seq_len, attn_heads, batches) + int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. 
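+    // Per row this kernel evaluates the softmax gradient
+    //   dx = scale * y * (dy - sum_j y_j * dy_j),
+    // with y = softmax output and dy = incoming gradient: grad_reg below
+    // accumulates dy*y, sum[] holds the row-wise total, and the store loop
+    // applies scale * (grad_reg - output_reg * sum).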
+ int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) + local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the batch + int local_idx = threadIdx.x; + + // the first element to process by the current thread + int thread_offset = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + grad += thread_offset; + output += thread_offset; + gradInput += thread_offset; + + // load data from global memory + acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; + acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; + input_t temp_grad[ELEMENTS_PER_LDG_STG]; + input_t temp_output[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : element_count; + + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < batch_element_count) { + copy_vector(temp_grad, grad + i * element_count + it * WARP_SIZE); + copy_vector(temp_output, output + i * element_count + it * WARP_SIZE); + + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + output_reg[i][it + element] = (acc_t)temp_output[element]; + } + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element]; + } + } + } + } + + acc_t sum[WARP_BATCH]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + sum[i] = grad_reg[i][0]; + #pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + sum[i] += grad_reg[i][it]; + } + } + warp_reduce(sum); + + // store result + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) + break; + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < element_count) { + // compute gradients + output_t out[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i])); + } + copy_vector(gradInput + i * element_count + it * WARP_SIZE, out); + } + } + } +} + +} // end of anonymous namespace + +template +void dispatch_scaled_masked_softmax_forward( + output_t *dst, + const input_t *src, + const uint8_t *mask, + const input_t scale, + int query_seq_len, + int key_seq_len, + int batches, + int attn_heads, + int pad_batches) +{ + TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048 ); + if (key_seq_len == 0) { + return; + } else { + int log2_elements = log2_ceil(key_seq_len); + const int next_power_of_two = 1 << log2_elements; + int batch_count = batches * attn_heads * query_seq_len; + + // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward. + int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. + int batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0); + dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches); + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 1: // 2 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 2: // 4 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 3: // 8 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 4: // 16 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 5: // 32 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 6: // 64 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 7: // 128 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 8: // 256 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 9: // 512 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 10: // 1024 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 11: // 2048 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + default: + break; + } + } +} + +template +void dispatch_scaled_masked_softmax_backward( + output_t *grad_input, + input_t *grad, + const input_t *output, + const acc_t scale, + int query_seq_len, + int key_seq_len, + int batches, + int attn_heads) +{ + TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 2048 ); + if (key_seq_len == 0) { + return; + } else { + int log2_elements = log2_ceil(key_seq_len); + const int next_power_of_two = 1 << log2_elements; + int batch_count = batches * attn_heads * query_seq_len; + + // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward. + int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward. + int batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + int blocks = batch_count/batches_per_block; + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 1: // 2 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 2: // 4 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 3: // 8 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 4: // 16 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 5: // 32 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 6: // 64 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 7: // 128 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 8: // 256 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 9: // 512 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 10: // 1024 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 11: // 2048 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + default: + break; + } + } +} diff --git a/megatron/fused_kernels/scaled_masked_softmax_cuda.cu b/megatron/fused_kernels/scaled_masked_softmax_cuda.cu new file mode 100644 index 000000000..c034dc3ad --- /dev/null +++ b/megatron/fused_kernels/scaled_masked_softmax_cuda.cu @@ -0,0 +1,114 @@ +/* coding=utf-8 + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include +#include +#ifndef __HIP_PLATFORM_HCC__ +#include +#endif +#include +#include +#include "scaled_masked_softmax.h" +#include "type_shim.h" + +namespace multihead_attn { +namespace fused_softmax { +namespace scaled_masked_softmax { + +torch::Tensor fwd_cuda( + torch::Tensor const& input, + torch::Tensor const& mask, + float scale_factor) +{ + // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] + const int batches = input.size(0); + const int pad_batches = mask.size(0); + const int attn_heads = input.size(1); + const int query_seq_len = input.size(2); + const int key_seq_len = input.size(3); + TORCH_INTERNAL_ASSERT(key_seq_len <= 2048); + TORCH_INTERNAL_ASSERT(query_seq_len > 1); + TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches); + TORCH_INTERNAL_ASSERT(mask.size(1) == 1); + TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len); + TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len); + + // Output + auto act_options = input.options().requires_grad(false); + torch::Tensor softmax_results = + torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); + + // Softmax Intermediate Result Ptr + void* input_ptr = static_cast(input.data_ptr()); + void* mask_ptr = static_cast(mask.data_ptr()); + void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); + + DISPATCH_HALF_AND_BFLOAT( + input.scalar_type(), + "dispatch_scaled_masked_softmax_forward", + dispatch_scaled_masked_softmax_forward( + reinterpret_cast(softmax_results_ptr), + reinterpret_cast(input_ptr), + reinterpret_cast(mask_ptr), + scale_factor, + query_seq_len, + key_seq_len, + batches, + attn_heads, + pad_batches); + ); + return softmax_results; +} + +torch::Tensor bwd_cuda( + torch::Tensor const& output_grads_, + torch::Tensor const& softmax_results_, + float scale_factor) { + + auto output_grads = output_grads_.contiguous(); + auto softmax_results = softmax_results_.contiguous(); + + //output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] + const int batches = output_grads.size(0); + const int attn_heads = output_grads.size(1); + const int query_seq_len = output_grads.size(2); + const int key_seq_len = output_grads.size(3); + + void* output_grads_ptr = static_cast(output_grads.data_ptr()); + + //Softmax Grad + DISPATCH_HALF_AND_BFLOAT( + output_grads_.scalar_type(), + "dispatch_scaled_masked_softmax_backward", + dispatch_scaled_masked_softmax_backward( + reinterpret_cast(output_grads_ptr), + reinterpret_cast(output_grads_ptr), + reinterpret_cast(softmax_results.data_ptr()), + scale_factor, + query_seq_len, + key_seq_len, + batches, + attn_heads); + ); + + //backward pass is completely in-place + return output_grads; +} +} +} +} diff --git a/megatron/fused_kernels/scaled_upper_triang_masked_softmax.cpp b/megatron/fused_kernels/scaled_upper_triang_masked_softmax.cpp new file mode 100644 index 000000000..ea283588d --- /dev/null +++ b/megatron/fused_kernels/scaled_upper_triang_masked_softmax.cpp @@ -0,0 +1,72 @@ +/* coding=utf-8 + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +namespace multihead_attn { +namespace fused_softmax { +namespace scaled_upper_triang_masked_softmax { + +torch::Tensor fwd_cuda( + torch::Tensor const& input, + float scale_factor); + +torch::Tensor bwd_cuda( + torch::Tensor const& output_grads, + torch::Tensor const& softmax_results, + float scale_factor); + +torch::Tensor fwd(torch::Tensor const& input, float scale_factor) { + AT_ASSERTM(input.dim() == 3, "expected 3D tensor"); + AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || + (input.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + + return fwd_cuda(input, scale_factor); +} + +torch::Tensor bwd( + torch::Tensor const& output_grads, + torch::Tensor const& softmax_results, + float scale_factor) { + + AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor"); + AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor"); + + AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || + (output_grads.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || + (softmax_results.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + + return bwd_cuda(output_grads, softmax_results, scale_factor); +} + +} // end namespace scaled_upper_triang_masked_softmax +} // end namespace fused_softmax +} // end namespace multihead_attn + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", + &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd, + "Self Multihead Attention scaled, time masked softmax -- Forward."); + m.def("backward", + &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd, + "Self Multihead Attention scaled, time masked softmax -- Backward."); +} diff --git a/megatron/fused_kernels/scaled_upper_triang_masked_softmax.h b/megatron/fused_kernels/scaled_upper_triang_masked_softmax.h new file mode 100644 index 000000000..addca0a0a --- /dev/null +++ b/megatron/fused_kernels/scaled_upper_triang_masked_softmax.h @@ -0,0 +1,511 @@ +/* coding=utf-8 + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace { + +template +__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); + +template <> +__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; } + +template <> +__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); } + +template <> +__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *dst = *src; } + +template <> +__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); } + +template <> +__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) { *dst = *src; } + +template <> +__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); } + +template +__device__ __inline__ void copy_zero_vector(Datatype *dst); + +template <> +__device__ __inline__ void copy_zero_vector(c10::BFloat16 *dst) { *dst = 0.0; } + +template <> +__device__ __inline__ void copy_zero_vector(c10::BFloat16 *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); } + +template <> +__device__ __inline__ void copy_zero_vector(c10::Half *dst) { *dst = 0.0; } + +template <> +__device__ __inline__ void copy_zero_vector(c10::Half *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); } + + +int log2_ceil(int value) { + int log2_value = 0; + while ((1 << log2_value) < value) ++log2_value; + return log2_value; +} + +template +struct Add { + __device__ __forceinline__ T operator()(T a, T b) const { + return a + b; + } +}; + +template +struct Max { + __device__ __forceinline__ T operator()(T a, T b) const { + return a < b ? b : a; + } +}; + +template +__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) +{ +#if CUDA_VERSION >= 9000 + return __shfl_xor_sync(mask, value, laneMask, width); +#else + return __shfl_xor(value, laneMask, width); +#endif +} + +template class ReduceOp> +__device__ __forceinline__ void warp_reduce(acc_t* sum) { + ReduceOp r; + #pragma unroll + for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); + sum[i] = r(sum[i], b); + } + } +} + +/* + * Extended softmax (from native aten pytorch) with following additional features + * 1) input scaling + * 2) Implicit time (diagonal masking) + */ +template +__global__ void scaled_upper_triang_masked_softmax_warp_forward( + output_t *dst, + const input_t *src, + const acc_t scale, + int micro_batch_size, + int stride, + int element_count) +{ + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_forward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = 4; + + int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x; + int local_seq = blockIdx.x + 1; + int warp_iteration_limit = (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1)/ WARP_SIZE; + + // micro_batch_size might not be a multiple of WARP_BATCH. 
Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) + local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the batch + int local_idx = threadIdx.x; + + src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; + dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; + + // load data from global memory + acc_t elements[WARP_BATCH][WARP_ITERATIONS]; + input_t temp_data[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : local_seq; + + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + + if (element_index < batch_element_count) { + copy_vector(temp_data, src + i*element_count*stride + it*WARP_SIZE); + + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if ((element_index + element) < batch_element_count) { + elements[i][it+element] = (acc_t)temp_data[element] * scale; + } else { + elements[i][it + element] = -std::numeric_limits::infinity(); + } + } + } else { + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + elements[i][it + element] = -std::numeric_limits::infinity(); + } + } + } + } + + // compute max_value + acc_t max_value[WARP_BATCH]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + max_value[i] = elements[i][0]; + #pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; + } + } + warp_reduce(max_value); + + acc_t sum[WARP_BATCH] { 0.0f }; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + if (it < warp_iteration_limit) { + elements[i][it] = std::exp((elements[i][it] - max_value[i])); + sum[i] += elements[i][it]; + } + } + } + warp_reduce(sum); + + // store result + output_t out[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) + break; + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + + if (element_index < local_seq) { + + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (element_index + element < local_seq) { + out[element] = elements[i][it + element] / sum[i]; + } else { + out[element] = 0; + } + } + copy_vector(dst + i * element_count * stride + it * WARP_SIZE, out); + } else if (element_index < element_count) { + copy_zero_vector(dst + i * element_count * stride + it * WARP_SIZE); + } else { + break; + } + } + } +} + +template +__global__ void scaled_upper_triang_masked_softmax_warp_backward( + output_t *gradInput, + input_t *grad, + const input_t *output, + acc_t scale, + int micro_batch_size, + int stride, + int element_count) +{ + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_backward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 
2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = 4; + + int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x; + int local_seq = blockIdx.x + 1; + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) + local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the batch + int local_idx = threadIdx.x; + + // the first element to process by the current thread + int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; + grad += thread_offset; + output += thread_offset; + gradInput += thread_offset; + + // load data from global memory + acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; + acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; + input_t temp_grad[ELEMENTS_PER_LDG_STG]; + input_t temp_output[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : local_seq; + + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < batch_element_count) { + copy_vector(temp_grad, grad + i * element_count * stride + it * WARP_SIZE); + copy_vector(temp_output, output + i * element_count * stride + it * WARP_SIZE); + + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (element_index + element < batch_element_count) { + output_reg[i][it + element] = (acc_t)temp_output[element]; + } + } + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (element_index + element < batch_element_count) { + grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element]; + } + } + } + } + } + + acc_t sum[WARP_BATCH]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + sum[i] = grad_reg[i][0]; + #pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + sum[i] += grad_reg[i][it]; + } + } + warp_reduce(sum); + + // store result + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) + break; + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < element_count) { + // compute gradients + output_t out[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i])); + } + copy_vector(gradInput + i * element_count * stride + it * WARP_SIZE, out); + } + } + } +} + +} // end of anonymous namespace + +template +void dispatch_scaled_upper_triang_masked_softmax_forward( + output_t *dst, + const input_t *src, + const input_t scale, + int softmax_elements, + int softmax_elements_stride, + int attn_batches) +{ + TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048 ); + if (softmax_elements == 0) { + return; + } else { + int log2_elements = log2_ceil(softmax_elements); + const int next_power_of_two = 1 << log2_elements; + int seq_len = softmax_elements; + int batch_count = attn_batches * seq_len; + + // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward. + int warp_size = (next_power_of_two < C10_WARP_SIZE) ? 
next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0); + int blocks_per_seq = attn_batches / batches_per_block; + dim3 blocks(seq_len, blocks_per_seq, 1); + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 1: // 2 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 2: // 4 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 3: // 8 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 4: // 16 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 5: // 32 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 6: // 64 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 7: // 128 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 8: // 256 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 9: // 512 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 10: // 1024 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 11: // 2048 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + default: + break; + } + } +} + +template +void dispatch_scaled_upper_triang_masked_softmax_backward( + output_t *grad_input, + input_t *grad, + const input_t *output, + const acc_t scale, + int softmax_elements, + int softmax_elements_stride, + int attn_batches) +{ + TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 2048 ); + if (softmax_elements == 0) { + return; + } else { + int log2_elements = log2_ceil(softmax_elements); + const int next_power_of_two = 1 << log2_elements; + int seq_len = softmax_elements; + int batch_count = attn_batches * seq_len; + + // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward. + int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward. + int batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0); + int blocks_per_seq = attn_batches / batches_per_block; + dim3 blocks(seq_len, blocks_per_seq, 1); + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 1: // 2 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 2: // 4 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 3: // 8 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 4: // 16 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 5: // 32 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 6: // 64 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 7: // 128 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 8: // 256 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 9: // 512 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 10: // 1024 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 11: // 2048 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + default: + break; + } + } +} diff --git a/megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu b/megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu new file mode 100644 index 000000000..59e452584 --- /dev/null +++ b/megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu @@ -0,0 +1,100 @@ +/* coding=utf-8 + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#ifndef __HIP_PLATFORM_HCC__ +#include +#endif +#include +#include +#include "scaled_upper_triang_masked_softmax.h" +#include "type_shim.h" + +namespace multihead_attn { +namespace fused_softmax { +namespace scaled_upper_triang_masked_softmax { + +torch::Tensor fwd_cuda( + torch::Tensor const& input, + float scale_factor) +{ + // input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len] + const int attn_batches = input.size(0); + const int seq_len = input.size(1); + TORCH_INTERNAL_ASSERT(seq_len <= 2048); + + // Output + auto act_options = input.options().requires_grad(false); + torch::Tensor softmax_results = + torch::empty({attn_batches, seq_len, seq_len}, act_options); + + // Softmax Intermediate Result Ptr + void* input_ptr = static_cast(input.data_ptr()); + void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); + + DISPATCH_HALF_AND_BFLOAT( + input.scalar_type(), + "dispatch_scaled_upper_triang_masked_softmax_forward", + dispatch_scaled_upper_triang_masked_softmax_forward( + reinterpret_cast(softmax_results_ptr), + reinterpret_cast(input_ptr), + scale_factor, + seq_len, + seq_len, + attn_batches); + ); + return softmax_results; +} + + +torch::Tensor bwd_cuda( + torch::Tensor const& output_grads_, + torch::Tensor const& softmax_results_, + float scale_factor) { + + auto output_grads = output_grads_.contiguous(); + auto softmax_results = softmax_results_.contiguous(); + + //output grads is a 3d tensor with dimensions [attn_batches, seq_len, seq_len] + const int attn_batches = output_grads.size(0); + const int seq_len = output_grads.size(1); + TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2)); + + void* output_grads_ptr = static_cast(output_grads.data_ptr()); + + //Softmax Grad + DISPATCH_HALF_AND_BFLOAT( + output_grads_.scalar_type(), + "dispatch_scaled_upper_triang_masked_softmax_backward", + dispatch_scaled_upper_triang_masked_softmax_backward( + reinterpret_cast(output_grads_ptr), + reinterpret_cast(output_grads_ptr), + reinterpret_cast(softmax_results.data_ptr()), + scale_factor, + seq_len, + seq_len, + attn_batches); + ); + + //backward pass is completely in-place + return output_grads; +} +} +} +} diff --git a/megatron/fused_kernels/type_shim.h b/megatron/fused_kernels/type_shim.h new file mode 100644 index 000000000..6437dcc7c --- /dev/null +++ b/megatron/fused_kernels/type_shim.h @@ -0,0 +1,91 @@ +/* coding=utf-8 + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +#include +#include "compat.h" + + +#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) 
\ + switch(TYPE) \ + { \ + case at::ScalarType::Half: \ + { \ + using scalar_t = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: \ + { \ + using scalar_t = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } + + + +#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \ + switch(TYPEIN) \ + { \ + case at::ScalarType::Float: \ + { \ + using scalar_t_in = float; \ + switch(TYPEOUT) \ + { \ + case at::ScalarType::Float: \ + { \ + using scalar_t_out = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: \ + { \ + using scalar_t_out = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: \ + { \ + using scalar_t_out = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \ + } \ + break; \ + } \ + case at::ScalarType::Half: \ + { \ + using scalar_t_in = at::Half; \ + using scalar_t_out = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: \ + { \ + using scalar_t_in = at::BFloat16; \ + using scalar_t_out = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \ + } + diff --git a/megatron/initialize.py b/megatron/initialize.py index 4b7e11a6c..baf3669a2 100644 --- a/megatron/initialize.py +++ b/megatron/initialize.py @@ -21,7 +21,7 @@ import time import numpy as np import torch -from megatron import fused_kernels + from megatron import get_adlr_autoresume from megatron import get_args from megatron import get_tensorboard_writer @@ -31,7 +31,6 @@ from megatron.mpu import (set_tensor_model_parallel_rank, set_tensor_model_parallel_world_size) from deepspeed.accelerator import get_accelerator import deepspeed -import deepspeed.utils.groups as groups def initialize_megatron(extra_args_provider=None, args_defaults={}, ignore_unknown_args=False, allow_no_cuda=False): @@ -92,66 +91,13 @@ def initialize_megatron(extra_args_provider=None, args_defaults={}, def _compile_dependencies(): - - args = get_args() - - # ========================= - # Compile dataset C++ code. - # ========================= - # TODO: move this to ninja - if _is_rank_0(): + if torch.distributed.get_rank() == 0: start_time = time.time() print('> compiling dataset index builder ...') from megatron.data.dataset_utils import compile_helper compile_helper() print('>>> done with dataset index builder. Compilation time: {:.3f} ' 'seconds'.format(time.time() - start_time), flush=True) - - if not get_accelerator().device_name() == 'cuda': - print(">fused kernel is only supported in cuda, skip loading fused kernel") - return - # ================== - # Load fused kernels - # ================== - - # Custom kernel constraints check. - seq_len = args.seq_length - attn_batch_size = \ - (args.num_attention_heads / args.tensor_model_parallel_size) * \ - args.micro_batch_size - # Constraints on sequence length and attn_batch_size to enable warp based - # optimization and upper triangular optimization (for causal mask) - custom_kernel_constraint = seq_len > 16 and seq_len <=2048 and \ - seq_len % 4 == 0 and attn_batch_size % 4 == 0 - # Print a warning. - if not ((args.fp16 or args.bf16) and - custom_kernel_constraint and - args.masked_softmax_fusion): - if args.rank == 0: - print('WARNING: constraints for invoking optimized' - ' fused softmax kernel are not met. 
We default' - ' back to unfused kernel invocations.', flush=True) - - # Always build on rank zero first. - if _is_rank_0(): - start_time = time.time() - print('> compiling and loading fused kernels ...', flush=True) - if get_accelerator().device_count() > 0: # Skip when CPU-only - fused_kernels.load(args) - torch.distributed.barrier() - else: - torch.distributed.barrier() - fused_kernels.load(args) - # Simple barrier to make sure all ranks have passed the - # compilation phase successfully before moving on to the - # rest of the program. We think this might ensure that - # the lock is released. - torch.distributed.barrier() - if _is_rank_0(): - print('>>> done with compiling and loading fused kernels. ' - 'Compilation time: {:.3f} seconds'.format( - time.time() - start_time), flush=True) - def setup_deepspeed_random_and_activation_checkpointing(args): '''Optional DeepSpeed Activation Checkpointing features. diff --git a/megatron/model/__init__.py b/megatron/model/__init__.py index 917b0f341..40967f417 100644 --- a/megatron/model/__init__.py +++ b/megatron/model/__init__.py @@ -23,4 +23,4 @@ from .gpt_model import GPTModel, GPTModelPipe from .llama_model import LlamaModel, LlamaModelPipe from .t5_model import T5Model from .language_model import get_language_model -from .module import Float16Module \ No newline at end of file +from .module import Float16Module diff --git a/megatron/model/bert_model.py b/megatron/model/bert_model.py new file mode 100644 index 000000000..be4025111 --- /dev/null +++ b/megatron/model/bert_model.py @@ -0,0 +1,236 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""BERT model.""" + +import torch + +from megatron import get_args +from megatron import mpu +from megatron.model.enums import AttnMaskType +from megatron.model.language_model import parallel_lm_logits +from megatron.model.language_model import get_language_model +from megatron.model import LayerNorm +from megatron.model.utils import openai_gelu, erf_gelu +from megatron.model.utils import get_linear_layer +from megatron.model.utils import init_method_normal +from megatron.model.utils import scaled_init_method_normal +from .module import MegatronModule + +def bert_extended_attention_mask(attention_mask): + # We create a 3D attention mask from a 2D tensor mask. 
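To make the shape bookkeeping concrete, here is a toy run of the same mask construction that the helper continues with just below (the values are made up; only the shapes and the boolean convention matter):

    import torch

    # Two sequences of length 4; the second one is padded in its last two positions.
    attention_mask = torch.tensor([[1, 1, 1, 1],
                                   [1, 1, 0, 0]])

    attention_mask_b1s = attention_mask.unsqueeze(1)                    # [b, 1, s]
    attention_mask_bs1 = attention_mask.unsqueeze(2)                    # [b, s, 1]
    attention_mask_bss = attention_mask_b1s * attention_mask_bs1        # [b, s, s]
    extended_attention_mask = (attention_mask_bss.unsqueeze(1) < 0.5)   # [b, 1, s, s], True == masked

    print(extended_attention_mask.shape)  # torch.Size([2, 1, 4, 4])
    print(extended_attention_mask[1, 0])  # rows and columns 2 and 3 are masked for the padded sequence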
+ # [b, 1, s] + attention_mask_b1s = attention_mask.unsqueeze(1) + # [b, s, 1] + attention_mask_bs1 = attention_mask.unsqueeze(2) + # [b, s, s] + attention_mask_bss = attention_mask_b1s * attention_mask_bs1 + # [b, 1, s, s] + extended_attention_mask = attention_mask_bss.unsqueeze(1) + + # Convert attention mask to binary: + extended_attention_mask = (extended_attention_mask < 0.5) + + return extended_attention_mask + +def bert_position_ids(token_ids): + # Create position ids + seq_length = token_ids.size(1) + position_ids = torch.arange(seq_length, dtype=torch.long, + device=token_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(token_ids) + + return position_ids + + +class BertLMHead(MegatronModule): + """Masked LM head for Bert + + Arguments: + mpu_vocab_size: model parallel size of vocabulary. + hidden_size: hidden size + init_method: init method for weight initialization + layernorm_epsilon: tolerance for layer norm divisions + parallel_output: whether output logits being distributed or not. + """ + + def __init__(self, mpu_vocab_size, hidden_size, init_method, + layernorm_epsilon, parallel_output): + + super(BertLMHead, self).__init__() + + args = get_args() + + self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size)) + mpu.set_tensor_model_parallel_attributes(self.bias, True, 0, 1) + self.parallel_output = parallel_output + + self.dense = get_linear_layer(hidden_size, hidden_size, init_method) + self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon) + self.gelu = torch.nn.functional.gelu + if args.openai_gelu: + self.gelu = openai_gelu + elif args.onnx_safe: + self.gelu = erf_gelu + + def forward(self, hidden_states, word_embeddings_weight): + hidden_states = self.dense(hidden_states) + hidden_states = self.gelu(hidden_states) + hidden_states = self.layernorm(hidden_states) + output = parallel_lm_logits(hidden_states, + word_embeddings_weight, + self.parallel_output, + bias=self.bias) + return output + + +def post_language_model_processing(lm_output, pooled_output, + lm_head, binary_head, + lm_labels, + logit_weights, + fp16_lm_cross_entropy): + # Output. 
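post_language_model_processing below projects the transformer output onto the vocabulary using the (tied) word-embedding matrix passed in as logit_weights. Ignoring vocab and tensor parallelism, that projection reduces to a weight-tied linear layer; a minimal sketch with illustrative names, not the Megatron API:

    import torch

    def tied_lm_logits(hidden_states, word_embeddings_weight, bias=None):
        # hidden_states: [b, s, h]; word_embeddings_weight: [vocab, h] -> logits: [b, s, vocab]
        logits = torch.matmul(hidden_states, word_embeddings_weight.t())
        if bias is not None:
            logits = logits + bias
        return logits

In the real code the projection is sharded across tensor-parallel ranks by parallel_lm_logits, and the loss is the vocab-parallel cross entropy used below.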
+ lm_logits = lm_head( + lm_output, logit_weights) + + binary_logits = None + if binary_head is not None: + binary_logits = binary_head(pooled_output) + + if lm_labels is None: + return lm_logits, binary_logits + else: + if fp16_lm_cross_entropy: + assert lm_logits.dtype == torch.half + lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels) + else: + lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(), + lm_labels) + return lm_loss, binary_logits + + +class BertModel(MegatronModule): + """Bert Language model.""" + + def __init__(self, + num_tokentypes=2, + add_binary_head=True, + parallel_output=True, + pre_process=True, + post_process=True): + super(BertModel, self).__init__() + args = get_args() + + self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy + self.add_binary_head = add_binary_head + self.parallel_output = parallel_output + self.pre_process = pre_process + self.post_process = post_process + + init_method = init_method_normal(args.init_method_std) + scaled_init_method = scaled_init_method_normal(args.init_method_std, + args.num_layers) + + self.language_model, self._language_model_key = get_language_model( + num_tokentypes=num_tokentypes, + add_pooler=self.add_binary_head, + encoder_attn_mask_type=AttnMaskType.padding, + init_method=init_method, + scaled_init_method=scaled_init_method, + pre_process=self.pre_process, + post_process=self.post_process) + + self.initialize_word_embeddings(init_method_normal) + if self.post_process: + self.lm_head = BertLMHead( + self.word_embeddings_weight().size(0), + args.hidden_size, init_method, args.layernorm_epsilon, parallel_output) + self._lm_head_key = 'lm_head' + self.binary_head = None + if self.add_binary_head: + self.binary_head = get_linear_layer(args.hidden_size, 2, + init_method) + self._binary_head_key = 'binary_head' + + def set_input_tensor(self, input_tensor): + """See megatron.model.transformer.set_input_tensor()""" + self.language_model.set_input_tensor(input_tensor) + + def forward(self, bert_model_input, attention_mask, + tokentype_ids=None, lm_labels=None): + + extended_attention_mask = bert_extended_attention_mask(attention_mask) + input_ids = bert_model_input + position_ids = bert_position_ids(input_ids) + + lm_output = self.language_model( + input_ids, + position_ids, + extended_attention_mask, + tokentype_ids=tokentype_ids + ) + + if self.post_process and self.add_binary_head: + lm_output, pooled_output = lm_output[0], lm_output[1] + else: + pooled_output = None + + if self.post_process: + return post_language_model_processing(lm_output, pooled_output, + self.lm_head, self.binary_head, + lm_labels, + self.word_embeddings_weight(), + self.fp16_lm_cross_entropy) + else: + return lm_output + + + def state_dict_for_save_checkpoint(self, destination=None, prefix='', + keep_vars=False): + """For easy load when model is combined with other heads, + add an extra key.""" + + state_dict_ = {} + state_dict_[self._language_model_key] \ + = self.language_model.state_dict_for_save_checkpoint( + destination, prefix, keep_vars) + if self.post_process: + state_dict_[self._lm_head_key] \ + = self.lm_head.state_dict_for_save_checkpoint( + destination, prefix, keep_vars) + if self.post_process and self.add_binary_head: + state_dict_[self._binary_head_key] \ + = self.binary_head.state_dict(destination, prefix, keep_vars) + # Save word_embeddings. 
+ if self.post_process and not self.pre_process: + state_dict_[self._word_embeddings_for_head_key] \ + = self.word_embeddings.state_dict(destination, prefix, keep_vars) + return state_dict_ + + def load_state_dict(self, state_dict, strict=True): + """Customized load.""" + + self.language_model.load_state_dict( + state_dict[self._language_model_key], strict=strict) + if self.post_process: + self.lm_head.load_state_dict( + state_dict[self._lm_head_key], strict=strict) + if self.post_process and self.add_binary_head: + self.binary_head.load_state_dict( + state_dict[self._binary_head_key], strict=strict) + # Load word_embeddings. + if self.post_process and not self.pre_process: + self.word_embeddings.load_state_dict( + state_dict[self._word_embeddings_for_head_key], strict=strict) diff --git a/megatron/model/biencoder_model.py b/megatron/model/biencoder_model.py new file mode 100644 index 000000000..51ac0a060 --- /dev/null +++ b/megatron/model/biencoder_model.py @@ -0,0 +1,295 @@ +import os +import torch +import sys + +from megatron import get_args, print_rank_0 +from megatron.checkpointing import fix_query_key_value_ordering +from megatron.checkpointing import get_checkpoint_tracker_filename +from megatron.checkpointing import get_checkpoint_name +from megatron import mpu, get_tokenizer +from megatron.model.bert_model import bert_position_ids +from megatron.model.enums import AttnMaskType +from megatron.model.language_model import get_language_model +from megatron.model.utils import get_linear_layer +from megatron.model.utils import init_method_normal +from megatron.model.utils import scaled_init_method_normal +from .module import MegatronModule + +def biencoder_model_provider(only_query_model=False, + only_context_model=False, + biencoder_shared_query_context_model=False): + """Build the model.""" + args = get_args() + + assert mpu.get_tensor_model_parallel_world_size() == 1 and \ + mpu.get_pipeline_model_parallel_world_size() == 1, \ + "Model parallel size > 1 not supported for ICT" + + print_rank_0('building BiEncoderModel...') + + # simpler to just keep using 2 tokentypes since + # the LM we initialize with has 2 tokentypes + model = BiEncoderModel( + num_tokentypes=2, + parallel_output=False, + only_query_model=only_query_model, + only_context_model=only_context_model, + biencoder_shared_query_context_model=\ + biencoder_shared_query_context_model) + + return model + + +class BiEncoderModel(MegatronModule): + """Bert-based module for Biencoder model.""" + + def __init__(self, + num_tokentypes=1, + parallel_output=True, + only_query_model=False, + only_context_model=False, + biencoder_shared_query_context_model=False): + super(BiEncoderModel, self).__init__() + args = get_args() + + bert_kwargs = dict( + num_tokentypes=num_tokentypes, + parallel_output=parallel_output) + + self.biencoder_shared_query_context_model = \ + biencoder_shared_query_context_model + assert not (only_context_model and only_query_model) + self.use_context_model = not only_query_model + self.use_query_model = not only_context_model + self.biencoder_projection_dim = args.biencoder_projection_dim + + if self.biencoder_shared_query_context_model: + self.model = PretrainedBertModel(**bert_kwargs) + self._model_key = 'shared_model' + self.query_model, self.context_model = self.model, self.model + else: + if self.use_query_model: + # this model embeds (pseudo-)queries - Embed_input in the paper + self.query_model = PretrainedBertModel(**bert_kwargs) + self._query_key = 'query_model' + + if self.use_context_model: + # this 
model embeds evidence blocks - Embed_doc in the paper + self.context_model = PretrainedBertModel(**bert_kwargs) + self._context_key = 'context_model' + + def forward(self, query_tokens, query_attention_mask, query_types, + context_tokens, context_attention_mask, context_types): + """Run a forward pass for each of the models and + return the respective embeddings.""" + + if self.use_query_model: + query_logits = self.embed_text(self.query_model, + query_tokens, + query_attention_mask, + query_types) + else: + raise ValueError("Cannot embed query without the query model.") + if self.use_context_model: + context_logits = self.embed_text(self.context_model, + context_tokens, + context_attention_mask, + context_types) + else: + raise ValueError("Cannot embed block without the block model.") + return query_logits, context_logits + + @staticmethod + def embed_text(model, tokens, attention_mask, token_types): + """Embed a batch of tokens using the model""" + logits = model(tokens, + attention_mask, + token_types) + return logits + + def state_dict_for_save_checkpoint(self, destination=None, \ + prefix='', keep_vars=False): + """Save dict with state dicts of each of the models.""" + state_dict_ = {} + if self.biencoder_shared_query_context_model: + state_dict_[self._model_key] = \ + self.model.state_dict_for_save_checkpoint(destination, + prefix, + keep_vars) + else: + if self.use_query_model: + state_dict_[self._query_key] = \ + self.query_model.state_dict_for_save_checkpoint( + destination, prefix, keep_vars) + + if self.use_context_model: + state_dict_[self._context_key] = \ + self.context_model.state_dict_for_save_checkpoint( + destination, prefix, keep_vars) + + return state_dict_ + + def load_state_dict(self, state_dict, strict=True): + """Load the state dicts of each of the models""" + if self.biencoder_shared_query_context_model: + print_rank_0("Loading shared query-context model") + self.model.load_state_dict(state_dict[self._model_key], \ + strict=strict) + else: + if self.use_query_model: + print_rank_0("Loading query model") + self.query_model.load_state_dict( \ + state_dict[self._query_key], strict=strict) + + if self.use_context_model: + print_rank_0("Loading context model") + self.context_model.load_state_dict( \ + state_dict[self._context_key], strict=strict) + + def init_state_dict_from_bert(self): + """Initialize the state from a pretrained BERT model + on iteration zero of ICT pretraining""" + args = get_args() + + if args.bert_load is None: + print_rank_0("bert-load argument is None") + return + + tracker_filename = get_checkpoint_tracker_filename(args.bert_load) + if not os.path.isfile(tracker_filename): + raise FileNotFoundError("Could not find BERT checkpoint") + with open(tracker_filename, 'r') as f: + iteration = int(f.read().strip()) + assert iteration > 0 + + checkpoint_name = get_checkpoint_name(args.bert_load, iteration, False) + if mpu.get_data_parallel_rank() == 0: + print('global rank {} is loading BERT checkpoint {}'.format( + torch.distributed.get_rank(), checkpoint_name)) + + # Load the checkpoint. + try: + state_dict = torch.load(checkpoint_name, map_location='cpu') + except ModuleNotFoundError: + from megatron.fp16_deprecated import loss_scaler + # For backward compatibility. 
+ print_rank_0(' > deserializing using the old code structure ...') + sys.modules['fp16.loss_scaler'] = sys.modules[ + 'megatron.fp16_deprecated.loss_scaler'] + sys.modules['megatron.fp16.loss_scaler'] = sys.modules[ + 'megatron.fp16_deprecated.loss_scaler'] + state_dict = torch.load(checkpoint_name, map_location='cpu') + sys.modules.pop('fp16.loss_scaler', None) + sys.modules.pop('megatron.fp16.loss_scaler', None) + except BaseException: + print_rank_0('could not load the BERT checkpoint') + sys.exit() + + checkpoint_version = state_dict.get('checkpoint_version', 0) + + # load the LM state dict into each model + model_dict = state_dict['model']['language_model'] + + if self.biencoder_shared_query_context_model: + self.model.language_model.load_state_dict(model_dict) + fix_query_key_value_ordering(self.model, checkpoint_version) + else: + if self.use_query_model: + self.query_model.language_model.load_state_dict(model_dict) + # give each model the same ict_head to begin with as well + if self.biencoder_projection_dim > 0: + query_proj_state_dict = \ + self.state_dict_for_save_checkpoint()\ + [self._query_key]['projection_enc'] + fix_query_key_value_ordering(self.query_model, checkpoint_version) + + if self.use_context_model: + self.context_model.language_model.load_state_dict(model_dict) + if self.query_model is not None and \ + self.biencoder_projection_dim > 0: + self.context_model.projection_enc.load_state_dict\ + (query_proj_state_dict) + fix_query_key_value_ordering(self.context_model, checkpoint_version) + + +class PretrainedBertModel(MegatronModule): + """BERT-based encoder for queries or contexts used for + learned information retrieval.""" + + def __init__(self, num_tokentypes=2, + parallel_output=True): + super(PretrainedBertModel, self).__init__() + + args = get_args() + tokenizer = get_tokenizer() + self.pad_id = tokenizer.pad + self.biencoder_projection_dim = args.biencoder_projection_dim + self.parallel_output = parallel_output + init_method = init_method_normal(args.init_method_std) + scaled_init_method = scaled_init_method_normal( + args.init_method_std, args.num_layers) + + self.language_model, self._language_model_key = get_language_model( + num_tokentypes=num_tokentypes, + add_pooler=False, + encoder_attn_mask_type=AttnMaskType.padding, + init_method=init_method, + scaled_init_method=scaled_init_method) + + if args.biencoder_projection_dim > 0: + self.projection_enc = get_linear_layer(args.hidden_size, + args.biencoder_projection_dim, + init_method) + self._projection_enc_key = 'projection_enc' + + def forward(self, input_ids, attention_mask, tokentype_ids=None): + extended_attention_mask = attention_mask.unsqueeze(1) + #extended_attention_mask = bert_extended_attention_mask(attention_mask) + position_ids = bert_position_ids(input_ids) + + + lm_output = self.language_model(input_ids, + position_ids, + extended_attention_mask, + tokentype_ids=tokentype_ids) + # This mask will be used in average-pooling and max-pooling + pool_mask = (input_ids == self.pad_id).unsqueeze(2) + + # Taking the representation of the [CLS] token of BERT + pooled_output = lm_output[:, 0, :] + + # Converting to float16 dtype + pooled_output = pooled_output.to(lm_output.dtype) + + # Output. 
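The pool_mask computed above marks padding positions for average- or max-pooling variants, while the path actually exercised here takes the first ([CLS]) position. For comparison, a masked mean-pooling alternative, a sketch only and not part of this patch, assuming lm_output is laid out [batch, seq, hidden] as the [CLS] indexing suggests:

    import torch

    def masked_mean_pool(lm_output, input_ids, pad_id):
        # lm_output: [b, s, h]; input_ids: [b, s]
        pad_mask = (input_ids == pad_id).unsqueeze(2)               # [b, s, 1], True at padding
        summed = lm_output.masked_fill(pad_mask, 0.0).sum(dim=1)    # [b, h]
        counts = (~pad_mask).sum(dim=1).clamp(min=1)                # [b, 1]
        return summed / counts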
+ if self.biencoder_projection_dim: + pooled_output = self.projection_enc(pooled_output) + + return pooled_output + + def state_dict_for_save_checkpoint(self, destination=None, prefix='', + keep_vars=False): + """For easy load when model is combined with other heads, + add an extra key.""" + + state_dict_ = {} + state_dict_[self._language_model_key] \ + = self.language_model.state_dict_for_save_checkpoint( + destination, prefix, keep_vars) + + if self.biencoder_projection_dim > 0: + state_dict_[self._projection_enc_key] = \ + self.projection_enc.state_dict(destination, prefix, keep_vars) + + return state_dict_ + + def load_state_dict(self, state_dict, strict=True): + """Customized load.""" + print_rank_0("loading BERT weights") + self.language_model.load_state_dict( + state_dict[self._language_model_key], strict=strict) + + if self.biencoder_projection_dim > 0: + print_rank_0("loading projection head weights") + self.projection_enc.load_state_dict( + state_dict[self._projection_enc_key], strict=strict) diff --git a/megatron/model/classification.py b/megatron/model/classification.py new file mode 100644 index 000000000..d4742c939 --- /dev/null +++ b/megatron/model/classification.py @@ -0,0 +1,119 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Classification model.""" + +import torch + +from megatron import get_args, print_rank_last +from megatron import mpu +from megatron.model.enums import AttnMaskType +from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids +from megatron.model.language_model import get_language_model +from megatron.model.utils import get_linear_layer +from megatron.model.utils import init_method_normal +from megatron.model.utils import scaled_init_method_normal +from .module import MegatronModule + + +class Classification(MegatronModule): + + def __init__(self, + num_classes, + num_tokentypes=2, + pre_process=True, + post_process=True): + super(Classification, self).__init__(share_word_embeddings=False) + args = get_args() + + self.num_classes = num_classes + self.pre_process = pre_process + self.post_process = post_process + init_method = init_method_normal(args.init_method_std) + + self.language_model, self._language_model_key = get_language_model( + num_tokentypes=num_tokentypes, + add_pooler=True, + encoder_attn_mask_type=AttnMaskType.padding, + init_method=init_method, + scaled_init_method=scaled_init_method_normal(args.init_method_std, + args.num_layers), + pre_process=self.pre_process, + post_process=self.post_process) + + # Multi-choice head. 
+ if self.post_process: + self.classification_dropout = torch.nn.Dropout(args.hidden_dropout) + self.classification_head = get_linear_layer(args.hidden_size, + self.num_classes, + init_method) + self._classification_head_key = 'classification_head' + + def set_input_tensor(self, input_tensor): + """See megatron.model.transformer.set_input_tensor()""" + self.language_model.set_input_tensor(input_tensor) + + def forward(self, model_input, attention_mask, tokentype_ids=None): + + extended_attention_mask = bert_extended_attention_mask(attention_mask) + input_ids = model_input + position_ids = bert_position_ids(input_ids) + + lm_output = self.language_model( + input_ids, + position_ids, + extended_attention_mask, + tokentype_ids=tokentype_ids + ) + + if self.post_process: + _, pooled_output = lm_output[0], lm_output[1] + classification_output = self.classification_dropout(pooled_output) + classification_logits = self.classification_head(classification_output) + + # Reshape back to separate choices. + classification_logits = classification_logits.view(-1, self.num_classes) + + return classification_logits + return lm_output + + def state_dict_for_save_checkpoint(self, destination=None, prefix='', + keep_vars=False): + """For easy load when model is combined with other heads, + add an extra key.""" + + state_dict_ = {} + state_dict_[self._language_model_key] \ + = self.language_model.state_dict_for_save_checkpoint( + destination, prefix, keep_vars) + if self.post_process: + state_dict_[self._classification_head_key] \ + = self.classification_head.state_dict( + destination, prefix, keep_vars) + return state_dict_ + + def load_state_dict(self, state_dict, strict=True): + """Customized load.""" + + self.language_model.load_state_dict( + state_dict[self._language_model_key], strict=strict) + if self.post_process: + if self._classification_head_key in state_dict: + self.classification_head.load_state_dict( + state_dict[self._classification_head_key], strict=strict) + else: + print_rank_last('***WARNING*** could not find {} in the checkpoint, ' + 'initializing to random'.format( + self._classification_head_key)) diff --git a/megatron/model/distributed.py b/megatron/model/distributed.py index ad86345c7..320261fee 100644 --- a/megatron/model/distributed.py +++ b/megatron/model/distributed.py @@ -188,6 +188,7 @@ class DistributedDataParallel(DistributedDataParallelBase): def allreduce_gradients(self): """Reduce gradients across data parallel ranks.""" # If we have buffers, simply reduce the data in the buffer. 
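The branch below is the standard data-parallel gradient averaging step: pre-scale the coalesced gradient buffer by 1/world_size, then all-reduce it over the data-parallel group so every rank ends up with the mean gradient. Stripped of the mpu helpers, the pattern is simply:

    import torch
    import torch.distributed as dist

    def allreduce_grad_buffer(buffer, group=None):
        # buffer: a flattened gradient tensor shared by this rank's parameters.
        world_size = dist.get_world_size(group=group)
        buffer /= world_size                   # pre-scale so the summed result is an average
        dist.all_reduce(buffer, group=group)   # in-place sum across data-parallel ranks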
+ if self._grad_buffers is not None: for _, buffer_ in self._grad_buffers.items(): buffer_.data /= mpu.get_data_parallel_world_size() diff --git a/megatron/model/fused_layer_norm.py b/megatron/model/fused_layer_norm.py index f34a05ffe..76a7648c1 100644 --- a/megatron/model/fused_layer_norm.py +++ b/megatron/model/fused_layer_norm.py @@ -65,9 +65,9 @@ class MixedFusedLayerNorm(torch.nn.Module): def __init__(self, normalized_shape, eps=1e-5): super(MixedFusedLayerNorm, self).__init__() - global fused_mix_prec_layer_norm_cuda - fused_mix_prec_layer_norm_cuda = importlib.import_module( - "fused_mix_prec_layer_norm_cuda") + # global fused_mix_prec_layer_norm_cuda + # fused_mix_prec_layer_norm_cuda = importlib.import_module( + # "fused_mix_prec_layer_norm_cuda") if isinstance(normalized_shape, numbers.Integral): normalized_shape = (normalized_shape,) @@ -79,17 +79,16 @@ class MixedFusedLayerNorm(torch.nn.Module): def reset_parameters(self): - init.ones_(self.weight) init.zeros_(self.bias) - def forward(self, input): - # CPU path is here for unittest sake. - if not input.is_cuda: - print("WARNING! The input of FusedLayerNorm should be on the GPU." - "This warning should only be triggered in the FusedLayerNorm unit tests.") - return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps) - return FusedLayerNormAffineFunction.apply( - input, self.weight, self.bias, self.normalized_shape,self.eps) + return torch.nn.functional.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps) + # # CPU path is here for unittest sake. + # if not input.is_cuda: + # print("WARNING! The input of FusedLayerNorm should be on the GPU." + # "This warning should only be triggered in the FusedLayerNorm unit tests.") + # return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps) + # return FusedLayerNormAffineFunction.apply( + # input, self.weight, self.bias, self.normalized_shape,self.eps) diff --git a/megatron/model/fused_softmax.py b/megatron/model/fused_softmax.py index 8d52967d9..eb4c69887 100644 --- a/megatron/model/fused_softmax.py +++ b/megatron/model/fused_softmax.py @@ -14,150 +14,231 @@ # limitations under the License. import torch +import torch_npu from megatron.model.enums import AttnMaskType -class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): - """ - Fused operation which performs following three operations in sequence - 1. Scale the tensor. - 2. Apply upper triangular mask (typically used in gpt models). - 3. Perform softmax. - """ - - @staticmethod - def forward(ctx, inputs, scale): - import scaled_upper_triang_masked_softmax_cuda - - scale_t = torch.tensor([scale]) - - softmax_results = scaled_upper_triang_masked_softmax_cuda.forward( - inputs, scale_t[0] - ) - ctx.save_for_backward(softmax_results, scale_t) - return softmax_results - - @staticmethod - def backward(ctx, output_grads): - import scaled_upper_triang_masked_softmax_cuda - - softmax_results, scale_t = ctx.saved_tensors - - input_grads = scaled_upper_triang_masked_softmax_cuda.backward( - output_grads, softmax_results, scale_t[0] - ) - return input_grads, None - - -class ScaledMaskedSoftmax(torch.autograd.Function): - """ - Fused operation which performs following three operations in sequence - 1. Scale the tensor. - 2. Apply the mask. - 3. Perform softmax. 
- """ - - @staticmethod - def forward(ctx, inputs, mask, scale): - import scaled_masked_softmax_cuda - - scale_t = torch.tensor([scale]) - - softmax_results = scaled_masked_softmax_cuda.forward( - inputs, mask, scale_t[0] - ) - ctx.save_for_backward(softmax_results, scale_t) - return softmax_results - - @staticmethod - def backward(ctx, output_grads): - import scaled_masked_softmax_cuda - - softmax_results, scale_t = ctx.saved_tensors - - input_grads = scaled_masked_softmax_cuda.backward( - output_grads, softmax_results, scale_t[0] - ) - return input_grads, None, None +# class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): +# """ +# Fused operation which performs following three operations in sequence +# 1. Scale the tensor. +# 2. Apply upper triangular mask (typically used in gpt models). +# 3. Perform softmax. +# """ +# +# @staticmethod +# def forward(ctx, inputs, scale): +# import scaled_upper_triang_masked_softmax_cuda +# +# scale_t = torch.tensor([scale]) +# +# softmax_results = scaled_upper_triang_masked_softmax_cuda.forward( +# inputs, scale_t[0] +# ) +# ctx.save_for_backward(softmax_results, scale_t) +# return softmax_results +# +# @staticmethod +# def backward(ctx, output_grads): +# import scaled_upper_triang_masked_softmax_cuda +# +# softmax_results, scale_t = ctx.saved_tensors +# +# input_grads = scaled_upper_triang_masked_softmax_cuda.backward( +# output_grads, softmax_results, scale_t[0] +# ) +# return input_grads, None +# +# +# class ScaledMaskedSoftmax(torch.autograd.Function): +# """ +# Fused operation which performs following three operations in sequence +# 1. Scale the tensor. +# 2. Apply the mask. +# 3. Perform softmax. +# """ +# +# @staticmethod +# def forward(ctx, inputs, mask, scale): +# import scaled_masked_softmax_cuda +# +# scale_t = torch.tensor([scale]) +# +# softmax_results = scaled_masked_softmax_cuda.forward( +# inputs, mask, scale_t[0] +# ) +# ctx.save_for_backward(softmax_results, scale_t) +# return softmax_results +# +# @staticmethod +# def backward(ctx, output_grads): +# import scaled_masked_softmax_cuda +# +# softmax_results, scale_t = ctx.saved_tensors +# +# input_grads = scaled_masked_softmax_cuda.backward( +# output_grads, softmax_results, scale_t[0] +# ) +# return input_grads, None, None class FusedScaleMaskSoftmax(torch.nn.Module): - """ - fused operation: scaling + mask + softmax - Arguments: - input_in_fp16: flag to indicate if input in fp16 data format. - attn_mask_type: attention mask type (pad or causal) - mask_func: mask function to be applied. - softmax_in_fp32: if true, softmax in performed at fp32 precision. - scale: scaling factor used in input tensor scaling. - - """ - def __init__( - self, - input_in_fp16, - input_in_bf16, - attn_mask_type, - scaled_masked_softmax_fusion, - mask_func, - softmax_in_fp32, - scale, + self, + input_in_fp16, + input_in_bf16, + attn_mask_type, + scaled_masked_softmax_fusion, + mask_func, + softmax_in_fp32, + scale, ): super(FusedScaleMaskSoftmax, self).__init__() self.input_in_fp16 = input_in_fp16 self.input_in_bf16 = input_in_bf16 - assert not (self.input_in_fp16 and self.input_in_bf16),\ - 'both fp16 and bf16 flags cannot be active at the same time.' + assert not ( + self.input_in_fp16 and self.input_in_bf16 + ), "both fp16 and bf16 flags cannot be active at the same time." 
self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16 self.attn_mask_type = attn_mask_type self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion self.mask_func = mask_func self.softmax_in_fp32 = softmax_in_fp32 self.scale = scale + self.mask_tri = None + p = torch.npu.get_device_properties(0) if torch.npu.is_available() else None + self.fused = p.name in ['Ascend910A', 'Ascend910ProB'] if p is not None else False assert ( - self.scale is None or softmax_in_fp32 + self.scale is None or softmax_in_fp32 ), "softmax should be in fp32 when scaled" - + def forward(self, input, mask): # [b, np, sq, sk] assert input.dim() == 4 - data_size = input.size() - query_seq_len = data_size[-2] - key_seq_len = data_size[-1] - attn_batch_size = data_size[0] * data_size[1] - - # constraints on various tensor dimensions to enable warp based - # optimization and upper triangular optimization (for causal mask) - custom_kernel_constraint = key_seq_len > 16 and key_seq_len <= 2048 and \ - query_seq_len % 4 == 0 and attn_batch_size % 4 == 0 - - # invoke custom kernel - if self.input_in_float16 and mask is not None and \ - custom_kernel_constraint and self.scaled_masked_softmax_fusion: - scale = self.scale if self.scale is not None else 1.0 - - if self.attn_mask_type == AttnMaskType.causal: - assert query_seq_len == key_seq_len, \ - "causal mask is only for self attention" - input = input.view(-1, query_seq_len, key_seq_len) - probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale) - probs = probs.view(*data_size) - else: - assert self.attn_mask_type == AttnMaskType.padding - probs = ScaledMaskedSoftmax.apply(input, mask, scale) + + if torch.npu.is_available() and self.fused: + return self.forward_fused_softmax(input, mask) + + return self.forward_torch_softmax(input, mask) + + def forward_fused_softmax(self, input, mask): + if self.softmax_in_fp32: + input = input.float() + + if self.scale is None: + self.scale = 1.0 + + if self.attn_mask_type == AttnMaskType.causal: + if self.mask_tri is None: + self.mask_tri = torch.triu(torch.ones(input.shape, device=input.device), diagonal=1).bool() + probs = torch_npu.npu_scaled_masked_softmax(input, self.mask_tri, self.scale, False) else: - if self.input_in_float16 and self.softmax_in_fp32: - input = input.float() + probs = torch_npu.npu_scaled_masked_softmax(input, mask, self.scale, False) + + probs = probs.half() + + return probs + + def forward_torch_softmax(self, input, mask): + if self.input_in_float16 and self.softmax_in_fp32: + input = input.float() + + if self.scale is not None: + input = input * self.scale - if self.scale is not None: - input = input * self.scale + if self.attn_mask_type == AttnMaskType.causal: + mask_tri = torch.triu(torch.ones(input.shape, device=input.device), diagonal=1).bool() + mask_output = self.mask_func(input, mask_tri) + else: mask_output = self.mask_func(input, mask) if mask is not None else input - probs = torch.nn.Softmax(dim=-1)(mask_output) - if self.input_in_float16 and self.softmax_in_fp32: - if self.input_in_fp16: - probs = probs.half() - else: - probs = probs.bfloat16() + probs = torch.nn.Softmax(dim=-1)(mask_output) + + if self.input_in_float16 and self.softmax_in_fp32: + if self.input_in_fp16: + probs = probs.half() + else: + probs = probs.bfloat16() return probs + +# class FusedScaleMaskSoftmax(torch.nn.Module): +# """ +# fused operation: scaling + mask + softmax +# Arguments: +# input_in_fp16: flag to indicate if input in fp16 data format. 
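The FusedScaleMaskSoftmax rewrite above dispatches to torch_npu's fused scaled-masked-softmax operator when an Ascend 910 device is detected and otherwise falls back to a plain PyTorch softmax. A condensed sketch of that dispatch, reusing the same torch_npu.npu_scaled_masked_softmax call as the code above and assuming a mask_func that fills masked positions with a large negative value, as Megatron's attention_mask_func does:

    import torch

    def scale_mask_softmax(scores, mask, scale, use_npu_fused):
        # scores: [b, np, sq, sk]; mask: bool tensor broadcastable to scores, True == masked.
        if use_npu_fused:
            import torch_npu  # only importable in an Ascend/NPU environment
            # One fused NPU op: scale, mask, softmax; the trailing False mirrors the call above.
            return torch_npu.npu_scaled_masked_softmax(scores, mask, scale, False)
        # Fallback: the same arithmetic in eager PyTorch.
        masked = (scores.float() * scale).masked_fill(mask, -10000.0)
        return torch.softmax(masked, dim=-1).to(scores.dtype)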
+# attn_mask_type: attention mask type (pad or causal) +# mask_func: mask function to be applied. +# softmax_in_fp32: if true, softmax in performed at fp32 precision. +# scale: scaling factor used in input tensor scaling. +# +# """ +# +# def __init__( +# self, +# input_in_fp16, +# input_in_bf16, +# attn_mask_type, +# scaled_masked_softmax_fusion, +# mask_func, +# softmax_in_fp32, +# scale, +# ): +# super(FusedScaleMaskSoftmax, self).__init__() +# self.input_in_fp16 = input_in_fp16 +# self.input_in_bf16 = input_in_bf16 +# assert not (self.input_in_fp16 and self.input_in_bf16),\ +# 'both fp16 and bf16 flags cannot be active at the same time.' +# self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16 +# self.attn_mask_type = attn_mask_type +# self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion +# self.mask_func = mask_func +# self.softmax_in_fp32 = softmax_in_fp32 +# self.scale = scale +# +# assert ( +# self.scale is None or softmax_in_fp32 +# ), "softmax should be in fp32 when scaled" +# +# def forward(self, input, mask): +# # [b, np, sq, sk] +# assert input.dim() == 4 +# data_size = input.size() +# query_seq_len = data_size[-2] +# key_seq_len = data_size[-1] +# attn_batch_size = data_size[0] * data_size[1] +# +# # constraints on various tensor dimensions to enable warp based +# # optimization and upper triangular optimization (for causal mask) +# custom_kernel_constraint = key_seq_len > 16 and key_seq_len <= 2048 and \ +# query_seq_len % 4 == 0 and attn_batch_size % 4 == 0 +# +# # invoke custom kernel +# if self.input_in_float16 and mask is not None and \ +# custom_kernel_constraint and self.scaled_masked_softmax_fusion: +# scale = self.scale if self.scale is not None else 1.0 +# +# if self.attn_mask_type == AttnMaskType.causal: +# assert query_seq_len == key_seq_len, \ +# "causal mask is only for self attention" +# input = input.view(-1, query_seq_len, key_seq_len) +# probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale) +# probs = probs.view(*data_size) +# else: +# assert self.attn_mask_type == AttnMaskType.padding +# probs = ScaledMaskedSoftmax.apply(input, mask, scale) +# else: +# if self.input_in_float16 and self.softmax_in_fp32: +# input = input.float() +# +# if self.scale is not None: +# input = input * self.scale +# mask_output = self.mask_func(input, mask) if mask is not None else input +# probs = torch.nn.Softmax(dim=-1)(mask_output) +# if self.input_in_float16 and self.softmax_in_fp32: +# if self.input_in_fp16: +# probs = probs.half() +# else: +# probs = probs.bfloat16() +# +# return probs diff --git a/megatron/model/llama_model.py b/megatron/model/llama_model.py index b523c79cb..213a77dd4 100644 --- a/megatron/model/llama_model.py +++ b/megatron/model/llama_model.py @@ -41,7 +41,8 @@ from deepspeed.pipe import PipelineModule, LayerSpec class RotaryEmbedding(torch.nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) + # inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) + inv_freq = 1.0 / (torch.tensor(base).double() ** (torch.arange(0, dim, 2).float().to(device) / dim).double()) self.register_buffer("inv_freq", inv_freq) # Build here to make `torch.jit.trace` work. 
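The RotaryEmbedding change above computes the inverse frequencies with a float64 power before registering the buffer, instead of a float32 power. A small standalone comparison of the two formulations (a sketch; the patch itself does not state the numerical motivation):

    import torch

    dim, base = 128, 10000
    exponent = torch.arange(0, dim, 2).float() / dim

    inv_freq_fp32 = 1.0 / (base ** exponent)                                  # original: float32 pow
    inv_freq_fp64 = 1.0 / (torch.tensor(base).double() ** exponent.double())  # patched: float64 pow

    # Largest rounding difference introduced by doing the power in float32.
    print((inv_freq_fp32.double() - inv_freq_fp64).abs().max())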
@@ -271,9 +272,7 @@ class LlamaParallelMLP(MegatronModule): enable_expert_tensor_parallelism=enable_expert_tensor_parallelism) def forward(self, hidden_states): - intermediate_parallel = self.gate_proj(hidden_states)[0] * self.up_proj(hidden_states)[0] - - intermediate_parallel = self.activation_func(intermediate_parallel) + intermediate_parallel = self.activation_func(self.gate_proj(hidden_states)[0]) * self.up_proj(hidden_states)[0] output, _ = self.down_proj(intermediate_parallel) return output @@ -854,7 +853,7 @@ class LlamaModelPipe(PipelineModule, MegatronModule): self.specs.append(LayerSpec(RMSNorm, args.hidden_size, eps=args.layernorm_epsilon)) self.specs.append( - LayerSpec(LlamaLMHeadPipe, hidden_size=args.hidden_size, vocab_size=padded_vocab_size, + LayerSpec(LlamaLMHeadPipe, hidden_size=args.hidden_size, vocab_size=args.padded_vocab_size, init_method=self.init_method, parallel_output=self.parallel_output) ) @@ -883,7 +882,7 @@ class LlamaModel(MegatronModule): """llama Language model.""" def __init__(self, pre_process, post_process, parallel_output=True, add_pooler=False): - super(LlamaModel, self).__init__() + super(LlamaModel, self).__init__(share_word_embeddings=False) args = get_args() self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy self.hidden_size = args.hidden_size @@ -902,7 +901,7 @@ class LlamaModel(MegatronModule): vocab_size=self.padded_vocab_size) # Transformer. - self.transformer = LlamaParallelTransformer( + self.language_model = LlamaParallelTransformer( self.init_method, self.output_layer_init_method, self_attn_mask_type=self.self_attn_mask_type, @@ -922,7 +921,7 @@ class LlamaModel(MegatronModule): def set_input_tensor(self, input_tensor): """See megatron.model.transformer.set_input_tensor()""" - self.transformer.set_input_tensor(input_tensor) + self.language_model.set_input_tensor(input_tensor) def forward(self, input_ids, attention_mask, labels=None, layer_past=None, get_key_value=False): args = get_args() @@ -933,7 +932,7 @@ class LlamaModel(MegatronModule): hidden_states = input_ids # decoder - hidden_states = self.transformer(hidden_states, attention_mask, layer_past=layer_past, + hidden_states = self.language_model(hidden_states, attention_mask, layer_past=layer_past, get_key_value=get_key_value) if self.post_process: @@ -958,4 +957,4 @@ class LlamaModel(MegatronModule): loss = mpu.vocab_parallel_cross_entropy(hidden_states.float(), labels) return loss - return hidden_states \ No newline at end of file + return hidden_states diff --git a/megatron/model/module.py b/megatron/model/module.py index 9f91c8bd1..3d5b783ff 100644 --- a/megatron/model/module.py +++ b/megatron/model/module.py @@ -121,30 +121,28 @@ def conversion_helper(val, conversion): rtn = tuple(rtn) return rtn - def fp32_to_float16(val, float16_convertor): - """Convert fp32 `val` to fp16/bf16""" def half_conversion(val): val_typecheck = val - if isinstance(val_typecheck, (Parameter, Variable)): + if isinstance(val_typecheck, (torch.nn.parameter.Parameter, torch.autograd.Variable)): val_typecheck = val.data - if isinstance(val_typecheck, _FLOAT_TYPES): + if val_typecheck.dtype == torch.float32: val = float16_convertor(val) return val + return conversion_helper(val, half_conversion) def float16_to_fp32(val): - """Convert fp16/bf16 `val` to fp32""" def float_conversion(val): val_typecheck = val - if isinstance(val_typecheck, (Parameter, Variable)): + if isinstance(val_typecheck, (torch.nn.parameter.Parameter, torch.autograd.Variable)): val_typecheck = val.data - if 
isinstance(val_typecheck, (_BF16_TYPES, _HALF_TYPES)): + if val_typecheck.dtype in [torch.float16, torch.bfloat16]: val = val.float() return val - return conversion_helper(val, float_conversion) + return conversion_helper(val, float_conversion) class Float16Module(MegatronModule): diff --git a/megatron/model/multiple_choice.py b/megatron/model/multiple_choice.py new file mode 100644 index 000000000..f82948f80 --- /dev/null +++ b/megatron/model/multiple_choice.py @@ -0,0 +1,130 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Multiple choice model.""" + +import torch + +from megatron import get_args, print_rank_last +from megatron import mpu +from megatron.model.enums import AttnMaskType +from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids +from megatron.model.language_model import get_language_model +from megatron.model.utils import get_linear_layer +from megatron.model.utils import init_method_normal +from megatron.model.utils import scaled_init_method_normal +from .module import MegatronModule + + +class MultipleChoice(MegatronModule): + + def __init__(self, + num_tokentypes=2, + pre_process=True, + post_process=True): + super(MultipleChoice, self).__init__(share_word_embeddings=False) + args = get_args() + + init_method = init_method_normal(args.init_method_std) + self.pre_process = pre_process + self.post_process = post_process + + self.language_model, self._language_model_key = get_language_model( + num_tokentypes=num_tokentypes, + add_pooler=True, + encoder_attn_mask_type=AttnMaskType.padding, + init_method=init_method, + scaled_init_method=scaled_init_method_normal(args.init_method_std, + args.num_layers), + pre_process=self.pre_process, + post_process=self.post_process) + + # Multi-choice head. + if self.post_process: + self.multichoice_dropout = torch.nn.Dropout(args.hidden_dropout) + self.multichoice_head = get_linear_layer(args.hidden_size, 1, + init_method) + self._multichoice_head_key = 'multichoice_head' + + def set_input_tensor(self, input_tensor): + """See megatron.model.transformer.set_input_tensor()""" + self.language_model.set_input_tensor(input_tensor) + + def forward(self, model_input, attention_mask, tokentype_ids=None): + + # [batch, choices, sequence] --> [batch * choices, sequence] --> + # transformer --> [batch, choices] --> softmax + + # Ensure the shape is [batch-size, choices, sequence] + assert len(attention_mask.shape) == 3 + num_choices = attention_mask.shape[1] + + # Reshape and treat choice dimension the same as batch. 
+ attention_mask = attention_mask.view(-1, attention_mask.size(-1)) + extended_attention_mask = bert_extended_attention_mask(attention_mask) + + input_ids = model_input + # Do the same as attention_mask for input_ids, tokentype_ids + assert len(input_ids.shape) == 3 + assert len(tokentype_ids.shape) == 3 + input_ids = input_ids.view(-1, input_ids.size(-1)) + tokentype_ids = tokentype_ids.view(-1, tokentype_ids.size(-1)) + position_ids = bert_position_ids(input_ids) + + lm_output = self.language_model( + input_ids, + position_ids, + extended_attention_mask, + tokentype_ids=tokentype_ids + ) + if self.post_process: + _, pooled_output = lm_output[0], lm_output[1] + multichoice_output = self.multichoice_dropout(pooled_output) + multichoice_logits = self.multichoice_head(multichoice_output) + + # Reshape back to separate choices. + multichoice_logits = multichoice_logits.view(-1, num_choices) + + return multichoice_logits + return lm_output + + def state_dict_for_save_checkpoint(self, destination=None, prefix='', + keep_vars=False): + """For easy load when model is combined with other heads, + add an extra key.""" + + state_dict_ = {} + state_dict_[self._language_model_key] \ + = self.language_model.state_dict_for_save_checkpoint( + destination, prefix, keep_vars) + if self.post_process: + state_dict_[self._multichoice_head_key] \ + = self.multichoice_head.state_dict( + destination, prefix, keep_vars) + return state_dict_ + + def load_state_dict(self, state_dict, strict=True): + """Customized load.""" + + self.language_model.load_state_dict( + state_dict[self._language_model_key], strict=strict) + if self.post_process: + if self._multichoice_head_key in state_dict: + self.multichoice_head.load_state_dict( + state_dict[self._multichoice_head_key], strict=strict) + else: + print_rank_last('***WARNING*** could not find {} in the checkpoint, ' + 'initializing to random'.format( + self._multichoice_head_key)) diff --git a/megatron/model/realm_model.py b/megatron/model/realm_model.py new file mode 100644 index 000000000..e74eb2e58 --- /dev/null +++ b/megatron/model/realm_model.py @@ -0,0 +1,204 @@ +import os +import torch + +from megatron import get_args, print_rank_0 +from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name +from megatron.model import BertModel +from .module import MegatronModule +from megatron import mpu +from megatron.model.enums import AttnMaskType +from megatron.model.utils import get_linear_layer +from megatron.model.utils import init_method_normal +from megatron.model.language_model import get_language_model +from megatron.model.utils import scaled_init_method_normal +from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids +from deepspeed.accelerator import get_accelerator + +def general_ict_model_provider(only_query_model=False, only_block_model=False): + """Build the model.""" + args = get_args() + assert args.ict_head_size is not None, \ + "Need to specify --ict-head-size to provide an ICTBertModel" + assert mpu.get_tensor_model_parallel_world_size() == 1 and mpu.get_pipeline_model_parallel_world_size() == 1, \ + "Model parallel size > 1 not supported for ICT" + + print_rank_0('building ICTBertModel...') + + # simpler to just keep using 2 tokentypes since the LM we initialize with has 2 tokentypes + model = ICTBertModel( + ict_head_size=args.ict_head_size, + num_tokentypes=2, + parallel_output=True, + only_query_model=only_query_model, + only_block_model=only_block_model) + + return model + + +class 
ICTBertModel(MegatronModule): + """Bert-based module for Inverse Cloze task.""" + def __init__(self, + ict_head_size, + num_tokentypes=1, + parallel_output=True, + only_query_model=False, + only_block_model=False): + super(ICTBertModel, self).__init__() + bert_kwargs = dict( + ict_head_size=ict_head_size, + num_tokentypes=num_tokentypes, + parallel_output=parallel_output + ) + assert not (only_block_model and only_query_model) + self.use_block_model = not only_query_model + self.use_query_model = not only_block_model + + if self.use_query_model: + # this model embeds (pseudo-)queries - Embed_input in the paper + self.query_model = IREncoderBertModel(**bert_kwargs) + self._query_key = 'question_model' + + if self.use_block_model: + # this model embeds evidence blocks - Embed_doc in the paper + self.block_model = IREncoderBertModel(**bert_kwargs) + self._block_key = 'context_model' + + def forward(self, query_tokens, query_attention_mask, block_tokens, block_attention_mask): + """Run a forward pass for each of the models and return the respective embeddings.""" + query_logits = self.embed_query(query_tokens, query_attention_mask) + block_logits = self.embed_block(block_tokens, block_attention_mask) + return query_logits, block_logits + + def embed_query(self, query_tokens, query_attention_mask): + """Embed a batch of tokens using the query model""" + if self.use_query_model: + query_types = get_accelerator().LongTensor(*query_tokens.shape).fill_(0) + query_ict_logits, _ = self.query_model.forward(query_tokens, query_attention_mask, query_types) + return query_ict_logits + else: + raise ValueError("Cannot embed query without query model.") + + def embed_block(self, block_tokens, block_attention_mask): + """Embed a batch of tokens using the block model""" + if self.use_block_model: + block_types = get_accelerator().LongTensor(*block_tokens.shape).fill_(0) + block_ict_logits, _ = self.block_model.forward(block_tokens, block_attention_mask, block_types) + return block_ict_logits + else: + raise ValueError("Cannot embed block without block model.") + + def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False): + """Save dict with state dicts of each of the models.""" + state_dict_ = {} + if self.use_query_model: + state_dict_[self._query_key] \ + = self.query_model.state_dict_for_save_checkpoint( + destination, prefix, keep_vars) + + if self.use_block_model: + state_dict_[self._block_key] \ + = self.block_model.state_dict_for_save_checkpoint( + destination, prefix, keep_vars) + + return state_dict_ + + def load_state_dict(self, state_dict, strict=True): + """Load the state dicts of each of the models""" + if self.use_query_model: + print("Loading ICT query model", flush=True) + self.query_model.load_state_dict( + state_dict[self._query_key], strict=strict) + + if self.use_block_model: + print("Loading ICT block model", flush=True) + self.block_model.load_state_dict( + state_dict[self._block_key], strict=strict) + + def init_state_dict_from_bert(self): + """Initialize the state from a pretrained BERT model on iteration zero of ICT pretraining""" + args = get_args() + tracker_filename = get_checkpoint_tracker_filename(args.bert_load) + if not os.path.isfile(tracker_filename): + raise FileNotFoundError("Could not find BERT load for ICT") + with open(tracker_filename, 'r') as f: + iteration = int(f.read().strip()) + assert iteration > 0 + + checkpoint_name = get_checkpoint_name(args.bert_load, iteration, False) + if mpu.get_data_parallel_rank() == 0: + print('global rank 
{} is loading checkpoint {}'.format( + torch.distributed.get_rank(), checkpoint_name)) + + try: + state_dict = torch.load(checkpoint_name, map_location='cpu') + except BaseException: + raise ValueError("Could not load checkpoint") + + # load the LM state dict into each model + model_dict = state_dict['model']['language_model'] + self.query_model.language_model.load_state_dict(model_dict) + self.block_model.language_model.load_state_dict(model_dict) + + # give each model the same ict_head to begin with as well + query_ict_head_state_dict = self.state_dict_for_save_checkpoint()[self._query_key]['ict_head'] + self.block_model.ict_head.load_state_dict(query_ict_head_state_dict) + + +class IREncoderBertModel(MegatronModule): + """BERT-based encoder for queries or blocks used for learned information retrieval.""" + def __init__(self, ict_head_size, num_tokentypes=2, parallel_output=True): + super(IREncoderBertModel, self).__init__() + args = get_args() + + self.ict_head_size = ict_head_size + self.parallel_output = parallel_output + init_method = init_method_normal(args.init_method_std) + scaled_init_method = scaled_init_method_normal(args.init_method_std, + args.num_layers) + + self.language_model, self._language_model_key = get_language_model( + num_tokentypes=num_tokentypes, + add_pooler=True, + encoder_attn_mask_type=AttnMaskType.padding, + init_method=init_method, + scaled_init_method=scaled_init_method) + + self.ict_head = get_linear_layer(args.hidden_size, ict_head_size, init_method) + self._ict_head_key = 'ict_head' + + def forward(self, input_ids, attention_mask, tokentype_ids=None): + extended_attention_mask = bert_extended_attention_mask( + attention_mask, next(self.language_model.parameters()).dtype) + position_ids = bert_position_ids(input_ids) + + lm_output, pooled_output = self.language_model( + input_ids, + position_ids, + extended_attention_mask, + tokentype_ids=tokentype_ids) + + # Output. + ict_logits = self.ict_head(pooled_output) + return ict_logits, None + + def state_dict_for_save_checkpoint(self, destination=None, prefix='', + keep_vars=False): + """For easy load when model is combined with other heads, + add an extra key.""" + + state_dict_ = {} + state_dict_[self._language_model_key] \ + = self.language_model.state_dict_for_save_checkpoint( + destination, prefix, keep_vars) + state_dict_[self._ict_head_key] \ + = self.ict_head.state_dict(destination, prefix, keep_vars) + return state_dict_ + + def load_state_dict(self, state_dict, strict=True): + """Customized load.""" + self.language_model.load_state_dict( + state_dict[self._language_model_key], strict=strict) + self.ict_head.load_state_dict( + state_dict[self._ict_head_key], strict=strict) + + diff --git a/megatron/model/t5_model.py b/megatron/model/t5_model.py new file mode 100644 index 000000000..beb4f0ee5 --- /dev/null +++ b/megatron/model/t5_model.py @@ -0,0 +1,174 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""T5 model.""" + +import torch + +from megatron import ( + get_args, + mpu +) +from megatron.model.enums import AttnMaskType +from megatron.model.language_model import parallel_lm_logits, get_language_model +from megatron.model.transformer import LayerNorm +from megatron.model.utils import ( + openai_gelu, + get_linear_layer, + init_method_normal, + scaled_init_method_normal +) +from .module import MegatronModule + + +def t5_extended_attention_mask(attention_mask_list): + + def attn_mask_postprocess(attn_mask): + # [b, 1, s, s] + extended_attention_mask = attn_mask.unsqueeze(1) + return extended_attention_mask + + return [attn_mask_postprocess(attn_mask) for attn_mask in attention_mask_list] + + +def t5_position_ids(token_ids): + # Create position ids + seq_length = token_ids.size(1) + position_ids = torch.arange(seq_length, dtype=torch.long, + device=token_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(token_ids) + + return position_ids + + +class T5LMHead(MegatronModule): + """Masked LM head for T5 + + Arguments: + mpu_vocab_size: model parallel size of vocabulary. + hidden_size: hidden size + init_method: init method for weight initialization + layernorm_epsilon: tolerance for layer norm divisions + parallel_output: wether output logits being distributed or not. + """ + + def __init__(self, mpu_vocab_size, parallel_output): + super(T5LMHead, self).__init__() + + args = get_args() + + self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size)) + self.bias.model_parallel = True + self.bias.partition_dim = 0 + self.bias.stride = 1 + self.parallel_output = parallel_output + + def forward(self, hidden_states, word_embeddings_weight): + output = parallel_lm_logits(hidden_states, + word_embeddings_weight, + self.parallel_output, + bias=self.bias) + return output + + +class T5Model(MegatronModule): + """T5 Language model.""" + + def __init__(self, num_tokentypes=0, parallel_output=True): + super(T5Model, self).__init__() + args = get_args() + + self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy + self.parallel_output = parallel_output + init_method = init_method_normal(args.init_method_std) + scaled_init_method = scaled_init_method_normal(args.init_method_std, + args.num_layers) + + self.language_model, self._language_model_key = get_language_model( + num_tokentypes=num_tokentypes, + add_pooler=False, + add_decoder=True, + encoder_attn_mask_type=AttnMaskType.padding, + init_method=init_method, + scaled_init_method=scaled_init_method) + + self.lm_head = T5LMHead( + self.language_model.embedding.word_embeddings.weight.size(0), + parallel_output) + self._lm_head_key = 'lm_head' + + def set_input_tensor(self, input_tensor): + """See megatron.model.transformer.set_input_tensor()""" + self.language_model.set_input_tensor(input_tensor) + + def forward(self, encoder_input_ids, decoder_input_ids, encoder_attn_mask, + decoder_attn_mask, encoder_decoder_attn_mask, + tokentype_ids=None, lm_labels=None, enc_hidden_states=None): + + # Converting the attention masks to proper parameter settings + encoder_attn_mask, decoder_attn_mask, encoder_decoder_attn_mask = t5_extended_attention_mask( + [encoder_attn_mask, decoder_attn_mask, encoder_decoder_attn_mask]) + + encoder_position_ids = t5_position_ids(encoder_input_ids) + decoder_position_ids = t5_position_ids(decoder_input_ids) + + lm_output = self.language_model(encoder_input_ids, + encoder_position_ids, + encoder_attn_mask, + decoder_input_ids, + decoder_position_ids, + decoder_attn_mask, + encoder_decoder_attn_mask, + 
tokentype_ids=tokentype_ids, + enc_hidden_states=enc_hidden_states) + + decoder_output, encoder_output = lm_output + + # Output. + lm_logits = self.lm_head(decoder_output, + self.language_model.embedding.word_embeddings.weight) + + if lm_labels is None: + return lm_logits, encoder_output + else: + if self.fp16_lm_cross_entropy: + assert lm_logits.dtype == torch.half + lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels) + else: + lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(), + lm_labels) + return lm_loss, encoder_output + + def state_dict_for_save_checkpoint(self, destination=None, prefix='', + keep_vars=False): + """For easy load when model is combined with other heads, + add an extra key.""" + + state_dict_ = {} + state_dict_[self._language_model_key] \ + = self.language_model.state_dict_for_save_checkpoint( + destination, prefix, keep_vars) + state_dict_[self._lm_head_key] \ + = self.lm_head.state_dict_for_save_checkpoint( + destination, prefix, keep_vars) + return state_dict_ + + def load_state_dict(self, state_dict, strict=True): + """Customized load.""" + + self.language_model.load_state_dict( + state_dict[self._language_model_key], strict=strict) + self.lm_head.load_state_dict(state_dict[self._lm_head_key], + strict=strict) diff --git a/megatron/mpu/layers.py b/megatron/mpu/layers.py index 5d168c672..c08168340 100644 --- a/megatron/mpu/layers.py +++ b/megatron/mpu/layers.py @@ -43,7 +43,6 @@ _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {'tensor_model_parallel': False, 'partition_dim': -1, 'partition_stride': 1} - def param_is_not_tensor_parallel_duplicate(param): return (hasattr(param, 'tensor_model_parallel') and param.tensor_model_parallel) or ( diff --git a/megatron/mpu/random.py b/megatron/mpu/random.py index beea791b5..7de26f467 100644 --- a/megatron/mpu/random.py +++ b/megatron/mpu/random.py @@ -168,33 +168,33 @@ class CudaRNGStatesTracker: if name in self.states_: raise Exception('cuda rng state {} already exists'.format(name)) # Get the current rng state. - orig_rng_state = get_accelerator().get_rng_state() + # orig_rng_state = torch.cuda.get_rng_state() # Set the new state and store it. - get_accelerator().manual_seed(seed) - self.states_[name] = get_accelerator().get_rng_state() + torch.cuda.manual_seed(seed) + # self.states_[name] = torch.cuda.get_rng_state() # Reset rng state to what it was. - _set_cuda_rng_state(orig_rng_state) + # _set_cuda_rng_state(orig_rng_state) @contextlib.contextmanager def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME): - """Fork the cuda rng state, perform operations, and exit with - the original state.""" - # Check if we have added the state - if name not in self.states_: - print(name, self.states_) - raise Exception('cuda rng state {} is not added'.format(name)) + yield + # """Fork the cuda rng state, perform operations, and exit with + # the original state.""" + # # Check if we have added the state + # if name not in self.states_: + # raise Exception('cuda rng state {} is not added'.format(name)) # Store current rng state. - orig_cuda_rng_state = get_accelerator().get_rng_state() + # orig_cuda_rng_state = torch.cuda.get_rng_state() # Set rng state to the desired one - _set_cuda_rng_state(self.states_[name]) + # _set_cuda_rng_state(self.states_[name]) # Do the stuff we wanted to do. - try: - yield - finally: - # Update the current rng state for later use. - self.states_[name] = get_accelerator().get_rng_state() - # And set the state to the original state we started with. 
- _set_cuda_rng_state(orig_cuda_rng_state) + # try: + # yield + # finally: + # # Update the current rng state for later use. + # self.states_[name] = torch.cuda.get_rng_state() + # # And set the state to the original state we started with. + # _set_cuda_rng_state(orig_cuda_rng_state) # RNG tracker object. @@ -310,7 +310,7 @@ class CheckpointFunction(torch.autograd.Function): if isinstance(outputs, torch.Tensor): outputs = (outputs,) elif len(outputs) == 2 and isinstance(outputs[1], torch.Tensor) and \ - torch.equal(outputs[1], torch.tensor(0).to(get_accelerator().device_name())): + torch.equal(outputs[1], torch.tensor(0, dtype=outputs[1].dtype).to(get_accelerator().device_name())): # a hacky solution to overcome issue when running old script examples/pretrain_gpt_distributed.sh outputs = (outputs[0],) torch.autograd.backward(outputs, args) diff --git a/megatron/optimizer/clip_grads.py b/megatron/optimizer/clip_grads.py index fdec6ff80..1563966ed 100644 --- a/megatron/optimizer/clip_grads.py +++ b/megatron/optimizer/clip_grads.py @@ -45,9 +45,13 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2): Arguments: parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a single Tensor that will have gradients normalized + grads_for_norm (Iterable[Tensor]): an iterable of Tensors or a single + Tensor that will be used for calculating the grad norm. max_norm (float or int): max norm of the gradients norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for infinity norm. + model_parallel_group (group): given the nature of the distributed + optimizer, this is passed as an argument. Returns: Total norm of the parameters (viewed as a single vector). @@ -55,11 +59,10 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2): if isinstance(parameters, torch.Tensor): parameters = [parameters] + # if isinstance(grads_for_norm, torch.Tensor): + # grads_for_norm = [grads_for_norm] - # Filter parameters based on: - # - grad should not be none - # - parameter should not be shared - # - should not be a replica due to tensor model parallelism + # Grads. grads = [] grads_for_norm = [] for param in parameters: @@ -69,11 +72,12 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2): grad = param.grad.detach() if grad_not_none: # Make sure the grads are in fp32 - assert param.grad.type() == 'torch.{}.FloatTensor'.format(get_accelerator().device_name()) + # assert param.grad.type() == 'torch.{}.FloatTensor'.format(get_accelerator().device_name()) grads.append(grad) if grad_not_none and is_not_shared and is_not_tp_duplicate: grads_for_norm.append(grad) + # Norm parameters. max_norm = float(max_norm) norm_type = float(norm_type) @@ -82,35 +86,16 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2): # Calculate norm. if norm_type == inf: total_norm = max(grad.abs().max() for grad in grads_for_norm) - total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]) + total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) # Take max across all model-parallel GPUs. torch.distributed.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=mpu.get_model_parallel_group()) total_norm = total_norm_cuda[0].item() - else: - if norm_type == 2.0: - if get_accelerator().device_name() == 'cuda': - dummy_overflow_buf = get_accelerator().IntTensor([0]) - # Use apex's multi-tensor applier for efficiency reasons. - # Multi-tensor applier takes a function and a list of list - # and performs the operation on that list all in one kernel. 
- grad_norm, _ = multi_tensor_applier( - amp_C.multi_tensor_l2norm, - dummy_overflow_buf, - [grads_for_norm], - False # no per-parameter norm - ) - else: - grad_norm = torch.norm(grads_for_norm,p=2.0) - # Since we will be summing across data parallel groups, - # we need the pow(norm-type). - total_norm = grad_norm ** norm_type - else: - for grad in grads_for_norm: - grad_norm = torch.norm(grad, norm_type) - total_norm += grad_norm ** norm_type + for grad in grads_for_norm: + grad_norm = torch.norm(grad, norm_type) + total_norm += grad_norm ** norm_type # Sum across all model-parallel GPUs. torch.distributed.all_reduce(total_norm, @@ -121,15 +106,8 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2): # Scale. clip_coeff = max_norm / (total_norm + 1.0e-6) if clip_coeff < 1.0: - if get_accelerator().device_name() == 'cuda': - dummy_overflow_buf = get_accelerator().IntTensor([0]) - multi_tensor_applier(amp_C.multi_tensor_scale, - dummy_overflow_buf, - [grads, grads], - clip_coeff) - else: - for g in grads: - g.detach().mul_(clip_coeff.to(g.device)) + for p in parameters: + p.grad.detach().mul_(clip_coeff) return total_norm diff --git a/megatron/optimizer/optimizer.py b/megatron/optimizer/optimizer.py index 6568bf10a..8c5f7a7db 100644 --- a/megatron/optimizer/optimizer.py +++ b/megatron/optimizer/optimizer.py @@ -239,11 +239,15 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer): # For all the parameters in this group: for i, param in enumerate(param_group['params']): if param.requires_grad: + if param.type() == "torch.cuda.HalfTensor": + param_type = "torch.npu.HalfTensor" + elif param.type() == "torch.cuda.BFloat16Tensor": + param_type = "torch.npu.BFloat16Tensor" + elif param.type() == "torch.cuda.FloatTensor": + param_type = "torch.npu.FloatTensor" # float16 params: - - - if param.type() in ['torch.{}.HalfTensor'.format(get_accelerator().device_name()), + if param_type in ['torch.{}.HalfTensor'.format(get_accelerator().device_name()), 'torch.{}.BFloat16Tensor'.format(get_accelerator().device_name())]: float16_params_this_group.append(param) # Create a copy @@ -262,7 +266,7 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer): = self.optimizer.state.pop(param) # fp32 params. 
- elif param.type() == 'torch.{}.FloatTensor'.format(format(get_accelerator().device_name())): + elif param_type == 'torch.{}.FloatTensor'.format(format(get_accelerator().device_name())): fp32_params_this_group.append(param) param_group['params'][i] = param @@ -304,6 +308,7 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer): for model_group, main_group in zip(self.float16_groups, self.fp32_from_float16_groups): for model_param, main_param in zip(model_group, main_group): + # if self.params_have_main_grad: if self.params_have_main_grad: main_param.grad = model_param.main_grad.float() else: diff --git a/megatron/p2p_communication.py b/megatron/p2p_communication.py index 21df8b2b3..07e66c801 100644 --- a/megatron/p2p_communication.py +++ b/megatron/p2p_communication.py @@ -94,16 +94,16 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next, torch.distributed.irecv, tensor_recv_prev, mpu.get_pipeline_model_parallel_prev_rank()) ops.append(recv_prev_op) - if tensor_send_next is not None: - send_next_op = torch.distributed.P2POp( - torch.distributed.isend, tensor_send_next, - mpu.get_pipeline_model_parallel_next_rank()) - ops.append(send_next_op) if tensor_recv_next is not None: recv_next_op = torch.distributed.P2POp( torch.distributed.irecv, tensor_recv_next, mpu.get_pipeline_model_parallel_next_rank()) ops.append(recv_next_op) + if tensor_send_next is not None: + send_next_op = torch.distributed.P2POp( + torch.distributed.isend, tensor_send_next, + mpu.get_pipeline_model_parallel_next_rank()) + ops.append(send_next_op) if len(ops) > 0: reqs = torch.distributed.batch_isend_irecv(ops) for req in reqs: diff --git a/megatron/schedules.py b/megatron/schedules.py index a072ac82b..57d73c05c 100644 --- a/megatron/schedules.py +++ b/megatron/schedules.py @@ -15,6 +15,7 @@ from contextlib import contextmanager import torch +import torch_npu from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP from megatron import get_args @@ -27,6 +28,25 @@ from megatron.utils import unwrap_model from megatron.model import DistributedDataParallel as LocalDDP from megatron.model import Float16Module + +def clear_npu_overflow_flag(): + float_status = torch.zeros(8).npu() + result = torch_npu.npu_clear_float_status(float_status) + + +def get_npu_overflow_flag(): + float_status = torch.zeros(8).npu() + result = torch_npu.npu_get_float_status(float_status) + if float_status.cpu()[0] != 0: + return True + else: + return False + + +def set_npu_overflow_flag(): + torch.tensor([65504]).half().npu() + 100 + + def get_forward_backward_func(): args = get_args() if mpu.get_pipeline_model_parallel_world_size() > 1: @@ -96,6 +116,7 @@ def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad, mo if input_tensor is not None: input_tensor.retain_grad() + clear_npu_overflow_flag() if args.deepspeed: model.backward(output_tensor) else: @@ -113,7 +134,6 @@ def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad, mo return input_tensor_grad - @contextmanager def dummy_handler(): try: @@ -131,35 +151,33 @@ def forward_backward_no_pipelining(forward_step_func, data_iterator, model, assert len(model) == 1 model = model[0] - args = get_args() - context_handler = dummy_handler if isinstance(model, torchDDP): context_handler = model.no_sync - if args.deepspeed: - model.set_gradient_accumulation_boundary(False) - losses_reduced = [] input_tensor, output_tensor_grad = None, None + overflow_flag_all = False with context_handler(): for i in 
range(get_num_microbatches() - 1): output_tensor = forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced) if not forward_only: backward_step(optimizer, input_tensor, output_tensor, - output_tensor_grad, model) - - if args.deepspeed: - model.set_gradient_accumulation_boundary(True) + output_tensor_grad) - # Run computation for last microbatch out of context handler (want to - # synchronize gradients). + overflow_flag = get_npu_overflow_flag() + overflow_flag_all = overflow_flag or overflow_flag_all output_tensor = forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced) if not forward_only: - backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad, model) + backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad) + + overflow_flag = get_npu_overflow_flag() + overflow_flag_all = overflow_flag or overflow_flag_all + if overflow_flag_all: + set_npu_overflow_flag() return losses_reduced diff --git a/megatron/training.py b/megatron/training.py index 94133e7ae..728136256 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -136,6 +136,7 @@ def pretrain(train_valid_test_dataset_provider, model, optimizer, lr_scheduler = setup_model_and_optimizer( model_provider, teacher=False, data_post_process=data_post_process, build_train_valid_test_datasets_provider=train_valid_test_dataset_provider) + timers('model-and-optimizer-setup').stop() print_datetime('after model, optimizer, and learning rate ' 'scheduler are built') @@ -488,7 +489,6 @@ def setup_model_and_optimizer(model_provider_func, teacher=False, # Number of train/valid/test samples. if args.train_samples: train_samples = args.train_samples - update_train_iters(args) else: train_samples = args.train_iters * args.global_batch_size # eval_iters and test_iters here are not actually used, only for @@ -1273,7 +1273,6 @@ def build_train_valid_test_data_iterators( # Number of train/valid/test samples. if args.train_samples: train_samples = args.train_samples - update_train_iters(args) else: train_samples = args.train_iters * args.global_batch_size eval_iters = (args.train_iters // args.eval_interval + 1) * \ diff --git a/pretrain_gpt.py b/pretrain_gpt.py index 355c41450..099afbd87 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -16,6 +16,10 @@ """Pretrain GPT""" import torch +import torch_npu +from torch_npu.contrib import transfer_to_npu +import deepspeed_npu + import math from functools import partial from megatron import get_args @@ -322,7 +326,6 @@ def git_ds_info(): if __name__ == "__main__": - git_ds_info() pretrain(train_valid_test_datasets_provider, model_provider, forward_step, args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}, data_post_process=data_post_process) diff --git a/pretrain_llama.py b/pretrain_llama.py index e69de29bb..3d19a5dba 100644 --- a/pretrain_llama.py +++ b/pretrain_llama.py @@ -0,0 +1,250 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
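The schedules.py hunks above swap the CUDA overflow check for the Ascend float-status registers: backward_step clears the status before each backward pass, get_npu_overflow_flag is read after every microbatch and OR-ed into overflow_flag_all, and set_npu_overflow_flag deliberately re-triggers an overflow after the last microbatch so the loss scaler still sees that the step should be skipped. A portable sketch of the same accumulate-then-report pattern, assuming generic helpers (grads_overflowed and run_microbatches are illustrative stand-ins, not the torch_npu float-status API):

import torch

def grads_overflowed(parameters):
    # Stand-in for torch_npu.npu_get_float_status(): True if any gradient holds inf/NaN.
    return any(p.grad is not None and not torch.isfinite(p.grad).all()
               for p in parameters)

def run_microbatches(model, microbatches, loss_fn):
    overflow_any = False
    for batch in microbatches:
        loss = loss_fn(model(batch))
        loss.backward()
        # Remember overflow per microbatch instead of aborting immediately.
        overflow_any = overflow_any or grads_overflowed(model.parameters())
    # The caller skips optimizer.step() / backs off the loss scale when this is True.
    return overflow_any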
+ +"""Pretrain Llama""" +import torch +import torch_npu +from torch_npu.contrib import transfer_to_npu +import deepspeed_npu + +import math +from functools import partial +from megatron import get_args +from megatron import print_rank_0 +from megatron import get_timers +from megatron import get_tokenizer +from megatron import mpu +from megatron.data.gpt_dataset import build_train_valid_test_datasets +from megatron.model import LlamaModel, LlamaModelPipe +from megatron.training import pretrain +from megatron.utils import get_ltor_masks_and_position_ids +from megatron.utils import average_losses_across_data_parallel_group + +import deepspeed +from deepspeed.runtime.utils import see_memory_usage +from deepspeed.accelerator.real_accelerator import get_accelerator + +import os +import subprocess + +from torch import nn +import torch.nn.functional as F + + +def model_provider(pre_process=True, post_process=True): + """Build the model.""" + + print_rank_0('building llama model ...') + see_memory_usage(f"Before Building Model", force=True) + + args = get_args() + with deepspeed.zero.Init(data_parallel_group=mpu.get_data_parallel_group(), + remote_device=None if args.remote_device == 'none' else args.remote_device, + config_dict_or_path=args.deepspeed_config, + enabled=args.zero_stage == 3, + mpu=mpu): + if args.deepspeed and not args.no_pipeline_parallel: + model = LlamaModelPipe(parallel_output=True) + # This is a hack to give us a reference to get_batch_pipe from within training.py + # We need to call model.set_batch_fn after deepspeed.initialize + model._megatron_batch_fn = get_batch_pipe + + # Predompute the attention mask and store it in args. This avoids having to + # pipeline it as an activation during training. The mask is constant, and thus + # we can reuse it. + attention_mask = torch.tril(torch.ones( + (1, args.seq_length, args.seq_length), device=get_accelerator().current_device_name())).view( + 1, 1, args.seq_length, args.seq_length) + + # Convert attention mask to binary: + attention_mask = (attention_mask < 0.5) + if args.fp16: + attention_mask = attention_mask.half() + elif args.bf16: + attention_mask = attention_mask.bfloat16() + + # Attention mask must be bool. + args.attn_mask = attention_mask.to(torch.bool) + + else: + model = LlamaModel( + parallel_output=True, + add_pooler=False, + pre_process=pre_process, + post_process=post_process + ) + see_memory_usage(f"After Building Model", force=True) + return model + + +def get_batch(data_iterator): + """Generate a batch""" + args = get_args() + tokenizer = get_tokenizer() + + # Items and their type. + keys = ['text'] + datatype = torch.int64 + + # Broadcast data. + if data_iterator is not None: + data = next(data_iterator) + else: + data = None + data_b = mpu.broadcast_data(keys, data, datatype) + + # Unpack. + tokens_ = data_b['text'].long() + labels = tokens_[:, 1:].contiguous() + tokens = tokens_[:, :-1].contiguous() + + # Get the masks and postition ids. 
+    attention_mask, loss_mask, _ = get_ltor_masks_and_position_ids(
+        tokens,
+        tokenizer.eod,
+        args.reset_position_ids,
+        args.reset_attention_mask,
+        args.eod_mask_loss)
+
+    return tokens, labels, loss_mask, attention_mask
+
+
+def data_post_process(data, data_sampler_state_dict):
+    args = get_args()
+    if args.data_efficiency_curriculum_learning:
+        if 'seqlen_truncate' in data_sampler_state_dict['current_difficulties']:
+            args.data_efficiency_curriculum_learning_seqlen_type = 'seqlen_truncate'
+            current_seqlen = data_sampler_state_dict['current_difficulties']['seqlen_truncate']
+            if current_seqlen < args.seq_length:
+                data['text'] = data['text'][:, :(current_seqlen + 1)].contiguous()
+        elif 'seqlen_reshape' in data_sampler_state_dict['current_difficulties']:
+            args.data_efficiency_curriculum_learning_seqlen_type = 'seqlen_reshape'
+            current_seqlen = data_sampler_state_dict['current_difficulties']['seqlen_reshape']
+            if current_seqlen < args.seq_length:
+                orig_num_token = torch.numel(data['text'])
+                reshape_len = (data['text'].size()[1] // (current_seqlen + 1)) * (current_seqlen + 1)
+                data['text'] = torch.cat((data['text'][:, :reshape_len].contiguous().view(-1, current_seqlen + 1),
+                                          data['text'][:, -(current_seqlen + 1):]), 0).contiguous()
+                num_row = math.ceil(orig_num_token / (current_seqlen + 1))
+                num_row = min(num_row, data['text'].size()[0])
+                if num_row > 1 and num_row % 2 != 0:
+                    num_row -= 1
+                data['text'] = data['text'][:num_row, :].contiguous()
+        else:
+            args.data_efficiency_curriculum_learning_seqlen_type = None
+    return data
+
+
+def get_batch_pipe(data):
+    """Modification of `get_batch` to work on `next(data_iterator)` instead of `data_iterator`"""
+    args = get_args()
+    tokenizer = get_tokenizer()
+
+    # Items and their type.
+    keys = ['text']
+    datatype = torch.int64
+
+    # Broadcast data.
+    data_b = mpu.broadcast_data(keys, data, datatype)
+
+    # Unpack.
+    tokens_ = data_b['text'].long()
+    labels = tokens_[:, 1:].contiguous()
+    tokens = tokens_[:, :-1].contiguous()
+
+    # Get the masks and position ids.
+    attention_mask, loss_mask, _ = get_ltor_masks_and_position_ids(
+        tokens,
+        tokenizer.eod,
+        args.reset_position_ids,
+        args.reset_attention_mask,
+        args.eod_mask_loss)
+    return (tokens, attention_mask), (labels, loss_mask)
+
+def loss_func(loss_mask, output_tensor):
+    args = get_args()
+    losses = output_tensor.float()
+    loss_mask = loss_mask.view(-1).float()
+    loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
+
+    # Reduce loss for logging.
+    averaged_loss = average_losses_across_data_parallel_group([loss])
+    return loss, {'lm loss': averaged_loss[0]}
+
+
+def forward_step(data_iterator, model):
+    """Forward step."""
+    args = get_args()
+    timers = get_timers()
+    # Get the batch.
+    timers('batch-generator').start()
+    tokens, labels, loss_mask, attention_mask = get_batch(data_iterator)
+    timers('batch-generator').stop()
+
+    output_tensor = model(tokens, attention_mask, labels=labels)
+    # Output_tensor stores the standard loss, loss_func calculates the total loss.
+ return output_tensor, partial(loss_func, loss_mask) + + +def train_valid_test_datasets_provider(train_val_test_num_samples): + """Build train, valid, and test datasets.""" + args = get_args() + + print_rank_0('> building train, validation, and test datasets ' + 'for llama ...') + train_ds, valid_ds, test_ds = build_train_valid_test_datasets( + data_prefix=args.data_path, + data_impl=args.data_impl, + splits_string=args.split, + train_valid_test_num_samples=train_val_test_num_samples, + seq_length=args.seq_length, + seed=args.seed, + skip_warmup=(not args.mmap_warmup)) + print_rank_0("> finished creating llama datasets ...") + + return train_ds, valid_ds, test_ds + + +def command_exists(cmd): + result = subprocess.Popen(f'type {cmd}', stdout=subprocess.PIPE, shell=True) + return result.wait() == 0 + + +def git_ds_info(): + from deepspeed.env_report import main as ds_report + ds_report() + + # Write out version/git info + git_hash_cmd = "git rev-parse --short HEAD" + git_branch_cmd = "git rev-parse --abbrev-ref HEAD" + if command_exists('git'): + try: + result = subprocess.check_output(git_hash_cmd, shell=True) + git_hash = result.decode('utf-8').strip() + result = subprocess.check_output(git_branch_cmd, shell=True) + git_branch = result.decode('utf-8').strip() + except subprocess.CalledProcessError: + git_hash = "unknown" + git_branch = "unknown" + else: + git_hash = "unknown" + git_branch = "unknown" + print(f'**** Git info for Megatron: git_hash={git_hash} git_branch={git_branch} ****') + + +if __name__ == "__main__": + pretrain(train_valid_test_datasets_provider, model_provider, forward_step, + args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}, + data_post_process=data_post_process) \ No newline at end of file diff --git a/runner.py b/runner.py new file mode 100644 index 000000000..6c1ff6a7e --- /dev/null +++ b/runner.py @@ -0,0 +1,16 @@ +import torch +import torch_npu +from torch_npu.contrib import transfer_to_npu +import bugfix + +from pretrain_t5 import pretrain, train_valid_test_datasets_provider, model_provider, forward_step +from megatron.model import T5Model, ModelType + +option = {} +option["ACL_OP_COMPILER_CACHE_MODE"] = "enable" +option["ACL_OP_COMPILER_CACHE_DIR"] = "./cache" +print("option:",option) +torch.npu.set_option(option) + +pretrain(train_valid_test_datasets_provider, model_provider, ModelType.encoder_and_decoder, + forward_step, args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'}) -- Gitee From 4a2b2deeb176de1c23d94e70c60c7ee9bf8831e6 Mon Sep 17 00:00:00 2001 From: fengliangjun Date: Tue, 6 Jun 2023 19:40:31 +0800 Subject: [PATCH 02/10] delete redundancy --- bugfix.py | 822 ----------------------------- main.py | 16 - megatron/data/indexed_dataset.py | 2 +- megatron/model/fused_layer_norm.py | 11 - megatron/model/fused_softmax.py | 145 ----- 5 files changed, 1 insertion(+), 995 deletions(-) delete mode 100644 bugfix.py delete mode 100644 main.py diff --git a/bugfix.py b/bugfix.py deleted file mode 100644 index ac3e5c420..000000000 --- a/bugfix.py +++ /dev/null @@ -1,822 +0,0 @@ -import sys -import time -import math -import copy -import numpy as np -import torch -import torch_npu -from functools import wraps - -# ====================== -# torch -# ====================== -# -# # INPLACE.1: torch.cuda.get_rng_state -# torch.cuda.get_rng_state = torch.get_rng_state -# torch.cuda.set_rng_state = torch.set_rng_state - - -# INPLACE.2: torch.Tensor.type() -def wrapper_type(fn): - @wraps(fn) - def decorated(*args, **kwargs): - output = fn(*args, **kwargs) - if 
isinstance(output, str): - if output == 'torch.npu.FloatTensor': - output = 'torch.cuda.FloatTensor' - elif output == 'torch.npu.HalfTensor': - output = 'torch.cuda.HalfTensor' - return output - - return decorated - - -torch.Tensor.type = wrapper_type(torch.Tensor.type) - -# INPLACE.3: torch.ditributed.xx input long --> int -from torch import distributed as dist - - -def wrapper_dist_long2int(fn): - @wraps(fn) - def wrapper(*args, **kwargs): - if args[0].dtype == torch.long and not kwargs.get('async_op', False): - new_args = list(copy.deepcopy(args)) - new_args[0] = new_args[0].int() - fn(*new_args, **kwargs) - args[0].copy_(new_args[0].long()) - return - return fn(*args, **kwargs) - - return wrapper - - -dist.all_reduce = wrapper_dist_long2int(dist.all_reduce) -dist.broadcast = wrapper_dist_long2int(dist.broadcast) -dist.send = wrapper_dist_long2int(dist.send) -dist.recv = wrapper_dist_long2int(dist.recv) - -# ====================== -# apex -# ====================== - -# INPLACE.4: apex.optimizers -import apex - - -class AdamW(torch.optim.Optimizer): - r"""Implements AdamW algorithm. - - The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. - The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. - - Arguments: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups - lr (float, optional): learning rate (default: 1e-3) - betas (Tuple[float, float], optional): coefficients used for computing - running averages of gradient and its square (default: (0.9, 0.999)) - eps (float, optional): term added to the denominator to improve - numerical stability (default: 1e-8) - weight_decay (float, optional): weight decay coefficient (default: 1e-2) - amsgrad (boolean, optional): whether to use the AMSGrad variant of this - algorithm from the paper `On the Convergence of Adam and Beyond`_ - (default: False) - - .. _Adam\: A Method for Stochastic Optimization: - https://arxiv.org/abs/1412.6980 - .. _Decoupled Weight Decay Regularization: - https://arxiv.org/abs/1711.05101 - .. _On the Convergence of Adam and Beyond: - https://openreview.net/forum?id=ryQu7f-RZ - """ - - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, - weight_decay=1e-2, amsgrad=False): - if not 0.0 <= lr: - raise ValueError("Invalid learning rate: {}".format(lr)) - if not 0.0 <= eps: - raise ValueError("Invalid epsilon value: {}".format(eps)) - if not 0.0 <= betas[0] < 1.0: - raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) - if not 0.0 <= betas[1] < 1.0: - raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) - defaults = dict(lr=lr, betas=betas, eps=eps, - weight_decay=weight_decay, amsgrad=amsgrad) - super(AdamW, self).__init__(params, defaults) - - def __setstate__(self, state): - super(AdamW, self).__setstate__(state) - for group in self.param_groups: - group.setdefault('amsgrad', False) - - @torch.no_grad() - def step(self, closure=None): - """Performs a single optimization step. - - Arguments: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. 
- """ - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - for p in group['params']: - if p.grad is None: - continue - - # Perform stepweight decay - p.data.mul_(1 - group['lr'] * group['weight_decay']) - - # Perform optimization step - grad = p.grad - if grad.is_sparse: - raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') - amsgrad = group['amsgrad'] - - state = self.state[p] - - # State initialization - if len(state) == 0: - state['step'] = 0 - # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros_like(p) - # Exponential moving average of squared gradient values - state['exp_avg_sq'] = torch.zeros_like(p) - if amsgrad: - # Maintains max of all exp. moving avg. of sq. grad. values - state['max_exp_avg_sq'] = torch.zeros_like(p) - - exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] - if amsgrad: - max_exp_avg_sq = state['max_exp_avg_sq'] - beta1, beta2 = group['betas'] - - state['step'] += 1 - bias_correction1 = 1 - beta1 ** state['step'] - bias_correction2 = 1 - beta2 ** state['step'] - - # Decay the first and second moment running average coefficient - exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - if amsgrad: - # Maintains the maximum of all 2nd moment running avg. till now - torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) - # Use the max. for normalizing running avg. of gradient - denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) - else: - denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) - - step_size = group['lr'] / bias_correction1 - - p.addcdiv_(exp_avg, denom, value=-step_size) - - return loss - - -apex.optimizers.FusedAdam = AdamW -apex.optimizers.FusedSGD = torch.optim.SGD - -# ====================== -# megatron -# ====================== -import megatron - -# -# # INPLACE.5: megatron.initialize._compile_dependencies -# def _compile_dependencies(): -# if torch.distributed.get_rank() == 0: -# start_time = time.time() -# print('> compiling dataset index builder ...') -# from megatron.data.dataset_utils import compile_helper -# compile_helper() -# print('>>> done with dataset index builder. 
Compilation time: {:.3f} ' -# 'seconds'.format(time.time() - start_time), flush=True) -# -# -# megatron.initialize._compile_dependencies = _compile_dependencies -# -# # INPLACE.6: fp32_to_float16, float16_to_fp32 -# from torch.autograd import Variable -# from torch.nn.parameter import Parameter -# from megatron.model.module import fp32_to_float16, float16_to_fp32, conversion_helper -# -# -# def fp32_to_float16(val, float16_convertor): -# """Convert fp32 `val` to fp16/bf16""" -# -# def half_conversion(val): -# val_typecheck = val -# if isinstance(val_typecheck, (Parameter, Variable)): -# val_typecheck = val.data -# if val_typecheck.dtype == torch.float32: -# val = float16_convertor(val) -# return val -# -# return conversion_helper(val, half_conversion) -# -# -# def float16_to_fp32(val): -# """Convert fp16/bf16 `val` to fp32""" -# -# def float_conversion(val): -# val_typecheck = val -# if isinstance(val_typecheck, (Parameter, Variable)): -# val_typecheck = val.data -# if val_typecheck.dtype in [torch.float16, torch.bfloat16]: -# val = val.float() -# return val -# -# return conversion_helper(val, float_conversion) -# -# -# megatron.model.module.fp32_to_float16 = fp32_to_float16 -# megatron.model.module.float16_to_fp32 = float16_to_fp32 - -# INPLACE.7: MixedFusedLayerNorm -# from megatron.model.fused_layer_norm import MixedFusedLayerNorm -# import numbers -# -# class MixedFusedLayerNorm(torch.nn.LayerNorm): -# def __init__(self, normalized_shape, eps=1e-5, no_persist_layer_norm=True, sequence_parallel=False): -# super(MixedFusedLayerNorm, self).__init__(normalized_shape, eps, no_persist_layer_norm) -# -# if isinstance(normalized_shape, numbers.Integral): -# normalized_shape = (normalized_shape,) -# -# self.weight = Parameter(torch.Tensor(*normalized_shape)) -# self.bias = Parameter(torch.Tensor(*normalized_shape)) -# -# # set sequence parallelism flag on weight and bias parameters -# self.sequence_parallel = sequence_parallel -# setattr(self.weight, 'sequence_parallel', self.sequence_parallel) -# setattr(self.bias, 'sequence_parallel', self.sequence_parallel) -# -# -# for k in sys.modules: -# if k.startswith('megatron.model'): -# for target in ['LayerNorm', 'MixedFusedLayerNorm']: -# if getattr(sys.modules[k], target, None): -# setattr(sys.modules[k], target, MixedFusedLayerNorm) - -# # INPLACE.8: _unscale_main_grads_and_check_for_nan -# from megatron.optimizer import Float16OptimizerWithFloat16Params -# -# -# def _unscale_main_grads_and_check_for_nan(self): -# # Collect main grads. -# main_grads = self._collect_main_grad_data_for_unscaling() -# -# # Reset found inf. -# self.found_inf.fill_(0.0) -# -# # Unscale and set found inf/nan -# torch._amp_foreach_non_finite_check_and_unscale_( -# main_grads, self.found_inf, self.grad_scaler.inv_scale) -# -# # Update across all model parallel instances. -# torch.distributed.all_reduce(self.found_inf, -# op=torch.distributed.ReduceOp.MAX, -# group=self.get_model_parallel_group()) -# -# # add data_parallel synchronize -# torch.distributed.all_reduce(self.found_inf, -# op=torch.distributed.ReduceOp.MAX, -# group=self.get_data_parallel_group()) -# -# # Check for nan. 
-# found_inf_flag = (self.found_inf.item() > 0) -# -# return found_inf_flag -# -# -# Float16OptimizerWithFloat16Params._unscale_main_grads_and_check_for_nan = _unscale_main_grads_and_check_for_nan - -# INPLACE.9: FusedScaleMaskSoftmax -# from megatron.model.fused_softmax import FusedScaleMaskSoftmax -# from megatron.model.enums import AttnMaskType -# -# -# class FusedScaleMaskSoftmax(torch.nn.Module): -# def __init__( -# self, -# input_in_fp16, -# input_in_bf16, -# attn_mask_type, -# scaled_masked_softmax_fusion, -# mask_func, -# softmax_in_fp32, -# scale, -# ): -# super(FusedScaleMaskSoftmax, self).__init__() -# self.input_in_fp16 = input_in_fp16 -# self.input_in_bf16 = input_in_bf16 -# assert not ( -# self.input_in_fp16 and self.input_in_bf16 -# ), "both fp16 and bf16 flags cannot be active at the same time." -# self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16 -# self.attn_mask_type = attn_mask_type -# self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion -# self.mask_func = mask_func -# self.softmax_in_fp32 = softmax_in_fp32 -# self.scale = scale -# self.mask_tri = None -# p = torch.npu.get_device_properties(0) if torch.npu.is_available() else None -# self.fused = p.name in ['Ascend910A', 'Ascend910ProB'] if p is not None else False -# -# assert ( -# self.scale is None or softmax_in_fp32 -# ), "softmax should be in fp32 when scaled" -# -# def forward(self, input, mask): -# # [b, np, sq, sk] -# assert input.dim() == 4 -# -# if torch.npu.is_available() and self.fused: -# return self.forward_fused_softmax(input, mask) -# -# return self.forward_torch_softmax(input, mask) -# -# def forward_fused_softmax(self, input, mask): -# if self.softmax_in_fp32: -# input = input.float() -# -# if self.scale is None: -# self.scale = 1.0 -# -# if self.attn_mask_type == AttnMaskType.causal: -# if self.mask_tri is None: -# self.mask_tri = torch.triu(torch.ones(input.shape, device=input.device), diagonal=1).bool() -# probs = torch_npu.npu_scaled_masked_softmax(input, self.mask_tri, self.scale, False) -# else: -# probs = torch_npu.npu_scaled_masked_softmax(input, mask, self.scale, False) -# -# probs = probs.half() -# -# return probs -# -# def forward_torch_softmax(self, input, mask): -# if self.input_in_float16 and self.softmax_in_fp32: -# input = input.float() -# -# if self.scale is not None: -# input = input * self.scale -# -# if self.attn_mask_type == AttnMaskType.causal: -# mask_tri = torch.triu(torch.ones(input.shape, device=input.device), diagonal=1).bool() -# mask_output = self.mask_func(input, mask_tri) -# else: -# mask_output = self.mask_func(input, mask) if mask is not None else input -# probs = torch.nn.Softmax(dim=-1)(mask_output) -# -# if self.input_in_float16 and self.softmax_in_fp32: -# if self.input_in_fp16: -# probs = probs.half() -# else: -# probs = probs.bfloat16() -# -# return probs -# -# -# for k in sys.modules: -# if k.startswith('megatron.model'): -# for target in ['FusedScaleMaskSoftmax']: -# if getattr(sys.modules[k], target, None): -# setattr(sys.modules[k], target, FusedScaleMaskSoftmax) - -# INPLACE.10: clip_grad_norm_fp32 -from torch._six import inf -from megatron import mpu -from megatron.model.module import param_is_not_shared -from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate -from megatron.optimizer.clip_grads import clip_grad_norm_fp32 -from deepspeed.accelerator import get_accelerator - - -# def clip_grad_norm_fp32(parameters, max_norm, norm_type=2): -# """Clips gradient norm of an iterable of parameters whose gradients -# are 
in fp32. -# -# This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and -# added functionality to handle model parallel parameters. Note that -# the gradients are modified in place. -# -# Arguments: -# parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a -# single Tensor that will have gradients normalized -# grads_for_norm (Iterable[Tensor]): an iterable of Tensors or a single -# Tensor that will be used for calculating the grad norm. -# max_norm (float or int): max norm of the gradients -# norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for -# infinity norm. -# model_parallel_group (group): given the nature of the distributed -# optimizer, this is passed as an argument. -# -# Returns: -# Total norm of the parameters (viewed as a single vector). -# """ -# -# if isinstance(parameters, torch.Tensor): -# parameters = [parameters] -# # if isinstance(grads_for_norm, torch.Tensor): -# # grads_for_norm = [grads_for_norm] -# -# # Grads. -# grads = [] -# grads_for_norm = [] -# for param in parameters: -# grad_not_none = param.grad is not None -# is_not_shared = param_is_not_shared(param) -# is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param) -# grad = param.grad.detach() -# if grad_not_none: -# # Make sure the grads are in fp32 -# # assert param.grad.type() == 'torch.{}.FloatTensor'.format(get_accelerator().device_name()) -# grads.append(grad) -# if grad_not_none and is_not_shared and is_not_tp_duplicate: -# grads_for_norm.append(grad) -# -# -# # Norm parameters. -# max_norm = float(max_norm) -# norm_type = float(norm_type) -# total_norm = 0.0 -# -# # Calculate norm. -# if norm_type == inf: -# total_norm = max(grad.abs().max() for grad in grads_for_norm) -# total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) -# # Take max across all model-parallel GPUs. -# torch.distributed.all_reduce(total_norm_cuda, -# op=torch.distributed.ReduceOp.MAX, -# group=mpu.get_model_parallel_group()) -# total_norm = total_norm_cuda[0].item() -# else: -# for grad in grads_for_norm: -# grad_norm = torch.norm(grad, norm_type) -# total_norm += grad_norm ** norm_type -# -# # Sum across all model-parallel GPUs. -# torch.distributed.all_reduce(total_norm, -# op=torch.distributed.ReduceOp.SUM, -# group=mpu.get_model_parallel_group()) -# total_norm = total_norm.item() ** (1.0 / norm_type) -# -# # Scale. -# clip_coeff = max_norm / (total_norm + 1.0e-6) -# if clip_coeff < 1.0: -# for p in parameters: -# p.grad.detach().mul_(clip_coeff) -# -# return total_norm -# -# -# megatron.optimizer.clip_grads.clip_grad_norm_fp32 = clip_grad_norm_fp32 -# megatron.optimizer.optimizer.clip_grad_norm_fp32 = clip_grad_norm_fp32 - -# INPLACE.11: _CUDA_RNG_STATE_TRACKER -import contextlib -from megatron.mpu.random import CudaRNGStatesTracker, _CUDA_RNG_STATE_TRACKER, _MODEL_PARALLEL_RNG_TRACKER_NAME - - -class CudaRNGStatesTracker: - """Tracker for the cuda RNG states. - - Using the `add` method, a cuda rng state is initialized based on - the input `seed` and is assigned to `name`. Later, by forking the - rng state, we can perform operations and return to our starting - cuda state. - """ - - def __init__(self): - # Map from a string name to the cuda rng state. - self.states_ = {} - # Seeds are just for book keeping and ensure no seed is set twice. - self.seeds_ = set() - - def reset(self): - """Set to the initial state (no tracker).""" - self.states_ = {} - self.seeds_ = set() - - def get_states(self): - """Get rng states. 
Copy the dictionary so we have direct - pointers to the states, not just a pointer to the dictionary.""" - states = {} - for name in self.states_: - states[name] = self.states_[name] - return states - - def set_states(self, states): - """Set the rng states. For efficiency purposes, we do not check - the size of seed for compatibility.""" - self.states_ = states - - def add(self, name, seed): - """Track the rng state.""" - # Check seed is not already used. - if seed in self.seeds_: - raise Exception('seed {} already exists'.format(seed)) - self.seeds_.add(seed) - # Check that state is not already defined. - if name in self.states_: - raise Exception('cuda rng state {} already exists'.format(name)) - # Get the current rng state. - # orig_rng_state = torch.cuda.get_rng_state() - # Set the new state and store it. - torch.cuda.manual_seed(seed) - # self.states_[name] = torch.cuda.get_rng_state() - # Reset rng state to what it was. - # _set_cuda_rng_state(orig_rng_state) - - @contextlib.contextmanager - def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME): - yield - # """Fork the cuda rng state, perform operations, and exit with - # the original state.""" - # # Check if we have added the state - # if name not in self.states_: - # raise Exception('cuda rng state {} is not added'.format(name)) - # Store current rng state. - # orig_cuda_rng_state = torch.cuda.get_rng_state() - # Set rng state to the desired one - # _set_cuda_rng_state(self.states_[name]) - # Do the stuff we wanted to do. - # try: - # yield - # finally: - # # Update the current rng state for later use. - # self.states_[name] = torch.cuda.get_rng_state() - # # And set the state to the original state we started with. - # _set_cuda_rng_state(orig_cuda_rng_state) - - -megatron.mpu.random.CudaRNGStatesTracker = CudaRNGStatesTracker -megatron.mpu.random._CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() - -# INPLACE.12: _unscale_main_grads_and_check_for_nan -from megatron.optimizer.optimizer import Float16OptimizerWithFloat16Params - -# -# def _unscale_main_grads_and_check_for_nan(self): -# main_grads = [] -# # fp32 params fromm float16 ones. -# for main_group in self.fp32_from_float16_groups: -# for main_param in main_group: -# if main_param.grad is not None: -# main_grads.append(main_param.grad.data) -# # Append fp32 parameters. -# for main_group in self.fp32_from_fp32_groups: -# for main_param in main_group: -# if main_param.grad is not None: -# main_grads.append(main_param.grad.data) -# # Reset found inf. -# self.found_inf.fill_(0.0) -# # Unscale and set found inf/nan -# torch._amp_foreach_non_finite_check_and_unscale_( -# main_grads, self.found_inf, self.grad_scaler.inv_scale) -# # Update across all model parallel instances. -# torch.distributed.all_reduce(self.found_inf, -# op=torch.distributed.ReduceOp.MAX, -# group=mpu.get_model_parallel_group()) -# torch.distributed.all_reduce(self.found_inf, -# op=torch.distributed.ReduceOp.MAX, -# group=mpu.get_data_parallel_group()) -# -# # Check for nan. 
-# found_inf_flag = (self.found_inf.item() > 0) -# return found_inf_flag -# -# -# Float16OptimizerWithFloat16Params._unscale_main_grads_and_check_for_nan = _unscale_main_grads_and_check_for_nan - -# INPLACE.13: refine overflow flag -from megatron import schedules, get_num_microbatches, get_args, get_timers -from megatron.schedules import dummy_handler, forward_step -from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP - - -def clear_npu_overflow_flag(): - float_status = torch.zeros(8).npu() - result = torch_npu.npu_clear_float_status(float_status) - - -def get_npu_overflow_flag(): - float_status = torch.zeros(8).npu() - result = torch_npu.npu_get_float_status(float_status) - if float_status.cpu()[0] != 0: - return True - else: - return False - - -def set_npu_overflow_flag(): - torch.tensor([65504]).half().npu() + 100 - - -def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad, model=None): - """Backward step through passed-in output tensor. - - If last stage, output_tensor_grad is None, otherwise gradient of loss - with respect to stage's output tensor. - - Returns gradient of loss with respect to input tensor (None if first - stage).""" - args = get_args() - - if args.deepspeed: - assert model is not None - - timers = get_timers() - timers('backward-compute').start() - - # Retain the grad on the input_tensor. - if input_tensor is not None: - input_tensor.retain_grad() - - clear_npu_overflow_flag() - if args.deepspeed: - model.backward(output_tensor) - else: - # Backward pass. - if output_tensor_grad is None: - output_tensor = optimizer.scale_loss(output_tensor) - torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad) - - # Collect the grad of the input_tensor. - input_tensor_grad = None - if input_tensor is not None: - input_tensor_grad = input_tensor.grad - - timers('backward-compute').stop() - - return input_tensor_grad - -def forward_backward_no_pipelining(forward_step_func, data_iterator, model, - optimizer, timers, forward_only): - """Run forward and backward passes with no pipeline parallelism - (no inter-stage communication). 
- - Returns dictionary with losses.""" - assert len(model) == 1 - model = model[0] - - context_handler = dummy_handler - if isinstance(model, torchDDP): - context_handler = model.no_sync - - losses_reduced = [] - input_tensor, output_tensor_grad = None, None - overflow_flag_all = False - with context_handler(): - for i in range(get_num_microbatches() - 1): - output_tensor = forward_step(forward_step_func, data_iterator, model, - input_tensor, losses_reduced) - if not forward_only: - backward_step(optimizer, input_tensor, output_tensor, - output_tensor_grad) - - overflow_flag = get_npu_overflow_flag() - overflow_flag_all = overflow_flag or overflow_flag_all - output_tensor = forward_step(forward_step_func, data_iterator, model, - input_tensor, losses_reduced) - if not forward_only: - backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad) - - overflow_flag = get_npu_overflow_flag() - overflow_flag_all = overflow_flag or overflow_flag_all - - if overflow_flag_all: - set_npu_overflow_flag() - return losses_reduced - - -schedules.forward_backward_no_pipelining = forward_backward_no_pipelining -# -# # INPLACE.14: remove dropout in ParallelTransformerLayer -# from megatron.model.transformer import ParallelTransformerLayer, bias_dropout_add_fused_train, \ -# bias_dropout_add_fused_inference, get_bias_dropout_add -# from megatron.model.enums import AttnMaskType, LayerType, AttnType -# -# -# def forward(self, hidden_states, attention_mask=None, -# encoder_output=None, enc_dec_attn_mask=None, -# layer_past=None, get_key_value=False): -# # hidden_states: [b, s, h] -# -# # Layer norm at the beginning of the transformer layer. -# layernorm_output = self.input_layernorm(hidden_states) -# # Self attention. -# attention_output, attention_bias = \ -# self.attention(layernorm_output, -# attention_mask, -# layer_past=layer_past, -# get_key_value=get_key_value) -# -# if get_key_value: -# attention_output, presents = attention_output -# -# # Residual connection. -# if self.apply_residual_connection_post_layernorm: -# residual = layernorm_output -# else: -# residual = hidden_states -# -# # jit scripting for a nn.module (with dropout) is not -# # trigerring the fusion kernel. For now, we use two -# # different nn.functional routines to account for varying -# # dropout semantics during training and inference phases. -# if self.bias_dropout_fusion: -# if self.training: -# bias_dropout_add_func = bias_dropout_add_fused_train -# else: -# bias_dropout_add_func = bias_dropout_add_fused_inference -# else: -# bias_dropout_add_func = get_bias_dropout_add(self.training) -# -# # re-enable torch grad to enable fused optimization. -# with torch.enable_grad(): -# layernorm_input = bias_dropout_add_func( -# attention_output, -# attention_bias.expand_as(residual), -# residual, -# 0.) -# -# # Layer norm post the self attention. -# layernorm_output = self.post_attention_layernorm(layernorm_input) -# -# if self.layer_type == LayerType.decoder: -# attention_output, attention_bias = \ -# self.inter_attention(layernorm_output, -# enc_dec_attn_mask, -# encoder_output=encoder_output) -# # residual connection -# if self.apply_residual_connection_post_layernorm: -# residual = layernorm_output -# else: -# residual = layernorm_input -# -# # re-enable torch grad to enable fused optimization. -# with torch.enable_grad(): -# layernorm_input = bias_dropout_add_func( -# attention_output, -# attention_bias.expand_as(residual), -# residual, -# 0.) 
-# -# # Layer norm post the decoder attention -# layernorm_output = self.post_inter_attention_layernorm(layernorm_input) -# -# # MLP. -# moe_loss = torch.tensor(0.0, device=layernorm_output.device, dtype=layernorm_output.dtype) -# mlp_bias = torch.tensor(0.0, device=layernorm_output.device, dtype=layernorm_output.dtype) -# -# if self.num_experts == 1: -# mlp_output, mlp_bias = self.mlp(layernorm_output) -# else: -# mlp_output, moe_loss, _ = self.mlp(layernorm_output) -# -# # Second residual connection. -# if self.apply_residual_connection_post_layernorm: -# residual = layernorm_output -# else: -# residual = layernorm_input -# -# # re-enable torch grad to enable fused optimization. -# with torch.enable_grad(): -# # if self.num_experts <= 1: -# output = bias_dropout_add_func( -# mlp_output, -# mlp_bias.expand_as(residual), -# residual, -# 0.) -# # else: -# # output = mlp_output + residual -# -# if get_key_value: -# output = [output, presents] -# -# return output, moe_loss -# -# -# ParallelTransformerLayer.forward = forward - -# -# from megatron import initialize -# from megatron.initialize import _warmup_jit_function -# -# def set_jit_fusion_options(): -# """Set PyTorch JIT layer fusion options.""" -# -# # legacy pytorch fuser -# torch._C._jit_set_profiling_mode(False) -# torch._C._jit_set_profiling_executor(False) -# torch._C._jit_override_can_fuse_on_cpu(True) -# torch._C._jit_override_can_fuse_on_gpu(True) -# -# _warmup_jit_function() -# -# initialize.set_jit_fusion_options = set_jit_fusion_options diff --git a/main.py b/main.py deleted file mode 100644 index 5596b4478..000000000 --- a/main.py +++ /dev/null @@ -1,16 +0,0 @@ -# This is a sample Python script. - -# Press Shift+F10 to execute it or replace it with your code. -# Press Double Shift to search everywhere for classes, files, tool windows, actions, and settings. - - -def print_hi(name): - # Use a breakpoint in the code line below to debug your script. - print(f'Hi, {name}') # Press Ctrl+F8 to toggle the breakpoint. - - -# Press the green button in the gutter to run the script. -if __name__ == '__main__': - print_hi('PyCharm') - -# See PyCharm help at https://www.jetbrains.com/help/pycharm/ diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index 20344f889..28ec14762 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -11,7 +11,7 @@ # An empty sentence no longer separates documents. 
# Some of the fixes/improvements are adopted from -# https://github.com/bigscience-workshop/Megatron-DeepSpeed/blob/main/megatron/data/indexed_dataset.py +# https://github.com/bigscience-workshop/AscendSpeed/blob/main/megatron/data/indexed_dataset.py from functools import lru_cache import os diff --git a/megatron/model/fused_layer_norm.py b/megatron/model/fused_layer_norm.py index 76a7648c1..1583f97ea 100644 --- a/megatron/model/fused_layer_norm.py +++ b/megatron/model/fused_layer_norm.py @@ -65,10 +65,6 @@ class MixedFusedLayerNorm(torch.nn.Module): def __init__(self, normalized_shape, eps=1e-5): super(MixedFusedLayerNorm, self).__init__() - # global fused_mix_prec_layer_norm_cuda - # fused_mix_prec_layer_norm_cuda = importlib.import_module( - # "fused_mix_prec_layer_norm_cuda") - if isinstance(normalized_shape, numbers.Integral): normalized_shape = (normalized_shape,) self.normalized_shape = torch.Size(normalized_shape) @@ -84,11 +80,4 @@ class MixedFusedLayerNorm(torch.nn.Module): def forward(self, input): return torch.nn.functional.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps) - # # CPU path is here for unittest sake. - # if not input.is_cuda: - # print("WARNING! The input of FusedLayerNorm should be on the GPU." - # "This warning should only be triggered in the FusedLayerNorm unit tests.") - # return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps) - # return FusedLayerNormAffineFunction.apply( - # input, self.weight, self.bias, self.normalized_shape,self.eps) diff --git a/megatron/model/fused_softmax.py b/megatron/model/fused_softmax.py index eb4c69887..8a4ddf08e 100644 --- a/megatron/model/fused_softmax.py +++ b/megatron/model/fused_softmax.py @@ -18,70 +18,6 @@ import torch_npu from megatron.model.enums import AttnMaskType -# class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): -# """ -# Fused operation which performs following three operations in sequence -# 1. Scale the tensor. -# 2. Apply upper triangular mask (typically used in gpt models). -# 3. Perform softmax. -# """ -# -# @staticmethod -# def forward(ctx, inputs, scale): -# import scaled_upper_triang_masked_softmax_cuda -# -# scale_t = torch.tensor([scale]) -# -# softmax_results = scaled_upper_triang_masked_softmax_cuda.forward( -# inputs, scale_t[0] -# ) -# ctx.save_for_backward(softmax_results, scale_t) -# return softmax_results -# -# @staticmethod -# def backward(ctx, output_grads): -# import scaled_upper_triang_masked_softmax_cuda -# -# softmax_results, scale_t = ctx.saved_tensors -# -# input_grads = scaled_upper_triang_masked_softmax_cuda.backward( -# output_grads, softmax_results, scale_t[0] -# ) -# return input_grads, None -# -# -# class ScaledMaskedSoftmax(torch.autograd.Function): -# """ -# Fused operation which performs following three operations in sequence -# 1. Scale the tensor. -# 2. Apply the mask. -# 3. Perform softmax. 
-# """ -# -# @staticmethod -# def forward(ctx, inputs, mask, scale): -# import scaled_masked_softmax_cuda -# -# scale_t = torch.tensor([scale]) -# -# softmax_results = scaled_masked_softmax_cuda.forward( -# inputs, mask, scale_t[0] -# ) -# ctx.save_for_backward(softmax_results, scale_t) -# return softmax_results -# -# @staticmethod -# def backward(ctx, output_grads): -# import scaled_masked_softmax_cuda -# -# softmax_results, scale_t = ctx.saved_tensors -# -# input_grads = scaled_masked_softmax_cuda.backward( -# output_grads, softmax_results, scale_t[0] -# ) -# return input_grads, None, None - - class FusedScaleMaskSoftmax(torch.nn.Module): def __init__( self, @@ -161,84 +97,3 @@ class FusedScaleMaskSoftmax(torch.nn.Module): probs = probs.bfloat16() return probs - -# class FusedScaleMaskSoftmax(torch.nn.Module): -# """ -# fused operation: scaling + mask + softmax -# Arguments: -# input_in_fp16: flag to indicate if input in fp16 data format. -# attn_mask_type: attention mask type (pad or causal) -# mask_func: mask function to be applied. -# softmax_in_fp32: if true, softmax in performed at fp32 precision. -# scale: scaling factor used in input tensor scaling. -# -# """ -# -# def __init__( -# self, -# input_in_fp16, -# input_in_bf16, -# attn_mask_type, -# scaled_masked_softmax_fusion, -# mask_func, -# softmax_in_fp32, -# scale, -# ): -# super(FusedScaleMaskSoftmax, self).__init__() -# self.input_in_fp16 = input_in_fp16 -# self.input_in_bf16 = input_in_bf16 -# assert not (self.input_in_fp16 and self.input_in_bf16),\ -# 'both fp16 and bf16 flags cannot be active at the same time.' -# self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16 -# self.attn_mask_type = attn_mask_type -# self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion -# self.mask_func = mask_func -# self.softmax_in_fp32 = softmax_in_fp32 -# self.scale = scale -# -# assert ( -# self.scale is None or softmax_in_fp32 -# ), "softmax should be in fp32 when scaled" -# -# def forward(self, input, mask): -# # [b, np, sq, sk] -# assert input.dim() == 4 -# data_size = input.size() -# query_seq_len = data_size[-2] -# key_seq_len = data_size[-1] -# attn_batch_size = data_size[0] * data_size[1] -# -# # constraints on various tensor dimensions to enable warp based -# # optimization and upper triangular optimization (for causal mask) -# custom_kernel_constraint = key_seq_len > 16 and key_seq_len <= 2048 and \ -# query_seq_len % 4 == 0 and attn_batch_size % 4 == 0 -# -# # invoke custom kernel -# if self.input_in_float16 and mask is not None and \ -# custom_kernel_constraint and self.scaled_masked_softmax_fusion: -# scale = self.scale if self.scale is not None else 1.0 -# -# if self.attn_mask_type == AttnMaskType.causal: -# assert query_seq_len == key_seq_len, \ -# "causal mask is only for self attention" -# input = input.view(-1, query_seq_len, key_seq_len) -# probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale) -# probs = probs.view(*data_size) -# else: -# assert self.attn_mask_type == AttnMaskType.padding -# probs = ScaledMaskedSoftmax.apply(input, mask, scale) -# else: -# if self.input_in_float16 and self.softmax_in_fp32: -# input = input.float() -# -# if self.scale is not None: -# input = input * self.scale -# mask_output = self.mask_func(input, mask) if mask is not None else input -# probs = torch.nn.Softmax(dim=-1)(mask_output) -# if self.input_in_float16 and self.softmax_in_fp32: -# if self.input_in_fp16: -# probs = probs.half() -# else: -# probs = probs.bfloat16() -# -# return probs -- Gitee From 
65c95929f647e11ff625bacd6b6376cf7bfc91be Mon Sep 17 00:00:00 2001 From: fengliangjun Date: Tue, 6 Jun 2023 20:13:32 +0800 Subject: [PATCH 03/10] remove fused kernel --- examples/gpt_task/pretrain_gpt_1p.sh | 2 +- examples/gpt_task/pretrain_gpt_ptd_8p.sh | 1 + examples/gpt_task/pretrain_gpt_td_8p.sh | 1 + examples/llama_task/pretrain_llama_1p.sh | 2 +- examples/llama_task/pretrain_llama_8p.sh | 49 - examples/llama_task/pretrain_llama_ptd_8p.sh | 1 + examples/llama_task/pretrain_llama_td_8p.sh | 1 + examples/pretrain_llama_1p.sh | 44 - examples/pretrain_llama_8p | 49 - examples/pretrain_llama_8p.sh | 49 - megatron/fused_kernels/__init__.py | 130 --- megatron/fused_kernels/compat.h | 31 - megatron/fused_kernels/layer_norm_cuda.cpp | 201 ---- .../fused_kernels/layer_norm_cuda_kernel.cu | 866 ------------------ .../fused_kernels/scaled_masked_softmax.cpp | 77 -- .../fused_kernels/scaled_masked_softmax.h | 492 ---------- .../scaled_masked_softmax_cuda.cu | 114 --- .../scaled_upper_triang_masked_softmax.cpp | 72 -- .../scaled_upper_triang_masked_softmax.h | 511 ----------- ...scaled_upper_triang_masked_softmax_cuda.cu | 100 -- megatron/fused_kernels/type_shim.h | 91 -- 21 files changed, 6 insertions(+), 2878 deletions(-) delete mode 100644 examples/llama_task/pretrain_llama_8p.sh delete mode 100644 examples/pretrain_llama_1p.sh delete mode 100644 examples/pretrain_llama_8p delete mode 100644 examples/pretrain_llama_8p.sh delete mode 100644 megatron/fused_kernels/__init__.py delete mode 100644 megatron/fused_kernels/compat.h delete mode 100644 megatron/fused_kernels/layer_norm_cuda.cpp delete mode 100644 megatron/fused_kernels/layer_norm_cuda_kernel.cu delete mode 100644 megatron/fused_kernels/scaled_masked_softmax.cpp delete mode 100644 megatron/fused_kernels/scaled_masked_softmax.h delete mode 100644 megatron/fused_kernels/scaled_masked_softmax_cuda.cu delete mode 100644 megatron/fused_kernels/scaled_upper_triang_masked_softmax.cpp delete mode 100644 megatron/fused_kernels/scaled_upper_triang_masked_softmax.h delete mode 100644 megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu delete mode 100644 megatron/fused_kernels/type_shim.h diff --git a/examples/gpt_task/pretrain_gpt_1p.sh b/examples/gpt_task/pretrain_gpt_1p.sh index c131b30a0..5b77e4b1b 100644 --- a/examples/gpt_task/pretrain_gpt_1p.sh +++ b/examples/gpt_task/pretrain_gpt_1p.sh @@ -1,5 +1,5 @@ export LD_LIBRARY_PATH=/root/miniconda3/lib:$LD_LIBRARY_PATH - +export HCCL_CONNECT_TIMEOUT=1200 source /usr/local/Ascend/ascend-toolkit/set_env.sh RANK=0 WORLD_SIZE=1 diff --git a/examples/gpt_task/pretrain_gpt_ptd_8p.sh b/examples/gpt_task/pretrain_gpt_ptd_8p.sh index a4e427449..4ffa16900 100644 --- a/examples/gpt_task/pretrain_gpt_ptd_8p.sh +++ b/examples/gpt_task/pretrain_gpt_ptd_8p.sh @@ -1,4 +1,5 @@ export LD_LIBRARY_PATH=/root/miniconda3/lib:$LD_LIBRARY_PATH +export HCCL_CONNECT_TIMEOUT=1200 source /usr/local/Ascend/ascend-toolkit/set_env.sh GPUS_PER_NODE=8 # Change for multinode config diff --git a/examples/gpt_task/pretrain_gpt_td_8p.sh b/examples/gpt_task/pretrain_gpt_td_8p.sh index 240425bc4..522f1069d 100644 --- a/examples/gpt_task/pretrain_gpt_td_8p.sh +++ b/examples/gpt_task/pretrain_gpt_td_8p.sh @@ -1,4 +1,5 @@ export LD_LIBRARY_PATH=/root/miniconda3/lib:$LD_LIBRARY_PATH +export HCCL_CONNECT_TIMEOUT=1200 source /usr/local/Ascend/ascend-toolkit/set_env.sh GPUS_PER_NODE=8 # Change for multinode config diff --git a/examples/llama_task/pretrain_llama_1p.sh b/examples/llama_task/pretrain_llama_1p.sh index 
a78f4edb7..7bfc8a6af 100644 --- a/examples/llama_task/pretrain_llama_1p.sh +++ b/examples/llama_task/pretrain_llama_1p.sh @@ -1,5 +1,5 @@ export LD_LIBRARY_PATH=/root/miniconda3/lib:$LD_LIBRARY_PATH - +export HCCL_CONNECT_TIMEOUT=1200 source /usr/local/Ascend/ascend-toolkit/set_env.sh RANK=0 WORLD_SIZE=1 diff --git a/examples/llama_task/pretrain_llama_8p.sh b/examples/llama_task/pretrain_llama_8p.sh deleted file mode 100644 index a51ea9e3b..000000000 --- a/examples/llama_task/pretrain_llama_8p.sh +++ /dev/null @@ -1,49 +0,0 @@ -export LD_LIBRARY_PATH=/root/miniconda3/lib:$LD_LIBRARY_PATH -source /usr/local/Ascend/ascend-toolkit/set_env.sh -GPUS_PER_NODE=8 -# Change for multinode config -MASTER_ADDR=localhost -MASTER_PORT=6000 -NNODES=1 -NODE_RANK=0 -WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) - -DATA_PATH=/home/flj/AscendSpeed/dataset/enwiki-gpt/gpt_text_sentence -CHECKPOINT_PATH=./ckpt - -DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT" - -python -m torch.distributed.launch $DISTRIBUTED_ARGS \ - pretrain_llama.py \ - --DDP-impl local \ - --use-contiguous-buffers-in-ddp \ - --tensor-model-parallel-size 4 \ - --num-layers 8 \ - --hidden-size 4096 \ - --num-attention-heads 32 \ - --micro-batch-size 2 \ - --global-batch-size 16 \ - --seq-length 2048 \ - --max-position-embeddings 2048 \ - --train-iters 500000 \ - --lr-decay-iters 320000 \ - --save $CHECKPOINT_PATH \ - --load $CHECKPOINT_PATH \ - --data-path $DATA_PATH \ - --vocab-file /home/flj/AscendSpeed/dataset/gpt2-vocab.json \ - --merge-file /home/flj/AscendSpeed/dataset/gpt2-merges.txt \ - --data-impl mmap \ - --split 949,50,1 \ - --distributed-backend nccl \ - --lr 0.00015 \ - --lr-decay-style cosine \ - --min-lr 1.0e-5 \ - --weight-decay 1e-2 \ - --clip-grad 1.0 \ - --lr-warmup-fraction .01 \ - --checkpoint-activations \ - --log-interval 10 \ - --save-interval 10000 \ - --eval-interval 1000 \ - --eval-iters 10 \ - --fp16 | tee logs/train.log diff --git a/examples/llama_task/pretrain_llama_ptd_8p.sh b/examples/llama_task/pretrain_llama_ptd_8p.sh index 81eb7c69f..a7b691abc 100644 --- a/examples/llama_task/pretrain_llama_ptd_8p.sh +++ b/examples/llama_task/pretrain_llama_ptd_8p.sh @@ -1,4 +1,5 @@ export LD_LIBRARY_PATH=/root/miniconda3/lib:$LD_LIBRARY_PATH +export HCCL_CONNECT_TIMEOUT=1200 source /usr/local/Ascend/ascend-toolkit/set_env.sh GPUS_PER_NODE=8 # Change for multinode config diff --git a/examples/llama_task/pretrain_llama_td_8p.sh b/examples/llama_task/pretrain_llama_td_8p.sh index a2dd944a5..861d501fa 100644 --- a/examples/llama_task/pretrain_llama_td_8p.sh +++ b/examples/llama_task/pretrain_llama_td_8p.sh @@ -1,4 +1,5 @@ export LD_LIBRARY_PATH=/root/miniconda3/lib:$LD_LIBRARY_PATH +export HCCL_CONNECT_TIMEOUT=1200 source /usr/local/Ascend/ascend-toolkit/set_env.sh GPUS_PER_NODE=8 # Change for multinode config diff --git a/examples/pretrain_llama_1p.sh b/examples/pretrain_llama_1p.sh deleted file mode 100644 index e1973b18d..000000000 --- a/examples/pretrain_llama_1p.sh +++ /dev/null @@ -1,44 +0,0 @@ -#! 
/bin/bash - -# Runs the "345M" parameter model - -RANK=0 -WORLD_SIZE=1 - -DATA_PATH=/home/flj/Megatron_Deepspeed_llama/dataset/enwiki-gpt/gpt_text_sentence -CHECKPOINT_PATH=./ckpt - -export LOCAL_RANK=7 - -python pretrain_llama.py \ - --DDP-impl local \ - --use-contiguous-buffers-in-ddp \ - --num-layers 2 \ - --hidden-size 4096 \ - --num-attention-heads 32 \ - --micro-batch-size 4 \ - --global-batch-size 8 \ - --seq-length 2048 \ - --max-position-embeddings 2048 \ - --train-iters 500000 \ - --lr-decay-iters 320000 \ - --save $CHECKPOINT_PATH \ - --load $CHECKPOINT_PATH \ - --data-path $DATA_PATH \ - --vocab-file /home/flj/Megatron_Deepspeed_llama/dataset/gpt2-vocab.json \ - --merge-file /home/flj/Megatron_Deepspeed_llama/dataset/gpt2-merges.txt \ - --data-impl mmap \ - --split 949,50,1 \ - --distributed-backend nccl \ - --lr 0.00015 \ - --min-lr 1.0e-5 \ - --lr-decay-style cosine \ - --weight-decay 1e-2 \ - --clip-grad 1.0 \ - --lr-warmup-fraction .01 \ - --checkpoint-activations \ - --log-interval 100 \ - --save-interval 10000 \ - --eval-interval 1000 \ - --eval-iters 10 \ - --fp16 | tee logs/train.log diff --git a/examples/pretrain_llama_8p b/examples/pretrain_llama_8p deleted file mode 100644 index e9648284d..000000000 --- a/examples/pretrain_llama_8p +++ /dev/null @@ -1,49 +0,0 @@ -#! /bin/bash - -# Runs the "345M" parameter model - -GPUS_PER_NODE=8 -# Change for multinode config -MASTER_ADDR=localhost -MASTER_PORT=6000 -NNODES=1 -NODE_RANK=0 -WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) - -DATA_PATH=/home/flj/Megatron_Deepspeed_llama/dataset/enwiki-gpt/gpt_text_sentence -CHECKPOINT_PATH=./ckpt - -DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT" - -python -m torch.distributed.launch $DISTRIBUTED_ARGS \ - pretrain_gpt.py \ - --tensor-model-parallel-size 8 \ - --num-layers 24 \ - --hidden-size 4096 \ - --num-attention-heads 32 \ - --micro-batch-size 4 \ - --global-batch-size 8 \ - --seq-length 2048 \ - --max-position-embeddings 2048 \ - --train-iters 500000 \ - --lr-decay-iters 320000 \ - --save $CHECKPOINT_PATH \ - --load $CHECKPOINT_PATH \ - --data-path $DATA_PATH \ - --vocab-file /home/flj/Megatron_Deepspeed_llama/dataset/gpt2-vocab.json \ - --merge-file /home/flj/Megatron_Deepspeed_llama/dataset/gpt2-merges.txt \ - --data-impl mmap \ - --split 949,50,1 \ - --distributed-backend nccl \ - --lr 0.00015 \ - --lr-decay-style cosine \ - --min-lr 1.0e-5 \ - --weight-decay 1e-2 \ - --clip-grad 1.0 \ - --lr-warmup-fraction .01 \ - --checkpoint-activations \ - --log-interval 100 \ - --save-interval 10000 \ - --eval-interval 1000 \ - --eval-iters 10 \ - --fp16 | tee logs/train.log diff --git a/examples/pretrain_llama_8p.sh b/examples/pretrain_llama_8p.sh deleted file mode 100644 index e9648284d..000000000 --- a/examples/pretrain_llama_8p.sh +++ /dev/null @@ -1,49 +0,0 @@ -#! 
/bin/bash - -# Runs the "345M" parameter model - -GPUS_PER_NODE=8 -# Change for multinode config -MASTER_ADDR=localhost -MASTER_PORT=6000 -NNODES=1 -NODE_RANK=0 -WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) - -DATA_PATH=/home/flj/Megatron_Deepspeed_llama/dataset/enwiki-gpt/gpt_text_sentence -CHECKPOINT_PATH=./ckpt - -DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT" - -python -m torch.distributed.launch $DISTRIBUTED_ARGS \ - pretrain_gpt.py \ - --tensor-model-parallel-size 8 \ - --num-layers 24 \ - --hidden-size 4096 \ - --num-attention-heads 32 \ - --micro-batch-size 4 \ - --global-batch-size 8 \ - --seq-length 2048 \ - --max-position-embeddings 2048 \ - --train-iters 500000 \ - --lr-decay-iters 320000 \ - --save $CHECKPOINT_PATH \ - --load $CHECKPOINT_PATH \ - --data-path $DATA_PATH \ - --vocab-file /home/flj/Megatron_Deepspeed_llama/dataset/gpt2-vocab.json \ - --merge-file /home/flj/Megatron_Deepspeed_llama/dataset/gpt2-merges.txt \ - --data-impl mmap \ - --split 949,50,1 \ - --distributed-backend nccl \ - --lr 0.00015 \ - --lr-decay-style cosine \ - --min-lr 1.0e-5 \ - --weight-decay 1e-2 \ - --clip-grad 1.0 \ - --lr-warmup-fraction .01 \ - --checkpoint-activations \ - --log-interval 100 \ - --save-interval 10000 \ - --eval-interval 1000 \ - --eval-iters 10 \ - --fp16 | tee logs/train.log diff --git a/megatron/fused_kernels/__init__.py b/megatron/fused_kernels/__init__.py deleted file mode 100644 index bdc654c39..000000000 --- a/megatron/fused_kernels/__init__.py +++ /dev/null @@ -1,130 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import pathlib -import subprocess - -import torch -from torch.utils import cpp_extension - -# Setting this param to a list has a problem of generating different -# compilation commands (with diferent order of architectures) and -# leading to recompilation of fused kernels. Set it to empty string -# to avoid recompilation and assign arch flags explicity in -# extra_cuda_cflags below -os.environ["TORCH_CUDA_ARCH_LIST"] = "" - - -def load(args): - - # Check if cuda 11 is installed for compute capability 8.0 - cc_flag = [] - if torch.version.hip is None: - _, bare_metal_major, _ = _get_cuda_bare_metal_version( - cpp_extension.CUDA_HOME) - if int(bare_metal_major) >= 11: - cc_flag.append('-gencode') - cc_flag.append('arch=compute_80,code=sm_80') - - # Build path - srcpath = pathlib.Path(__file__).parent.absolute() - buildpath = srcpath / 'build' - _create_build_dir(buildpath) - - # Helper function to build the kernels. 
- def _cpp_extention_load_helper(name, sources, extra_cuda_flags, extra_include_paths): - if torch.version.hip is not None: - extra_cuda_cflags=['-O3'] + extra_cuda_flags + cc_flag - else: - extra_cuda_cflags=['-O3', - '-gencode', 'arch=compute_70,code=sm_70', - '--use_fast_math'] + extra_cuda_flags + cc_flag - - return cpp_extension.load( - name=name, - sources=sources, - build_directory=buildpath, - extra_cflags=['-O3',], - extra_cuda_cflags=extra_cuda_cflags, - extra_include_paths=extra_include_paths, - verbose=(args.rank == 0) - ) - - # ============== - # Fused softmax. - # ============== - - if torch.version.hip is not None: - extra_include_paths=[os.path.abspath(srcpath)] - else: - extra_include_paths=[] - - if args.masked_softmax_fusion: - if torch.version.hip is not None: - extra_cuda_flags = ['-D__HIP_NO_HALF_OPERATORS__=1', - '-D__HIP_NO_HALF_CONVERSIONS__=1'] - else: - extra_cuda_flags = ['-U__CUDA_NO_HALF_OPERATORS__', - '-U__CUDA_NO_HALF_CONVERSIONS__', - '--expt-relaxed-constexpr', - '--expt-extended-lambda'] - - # Upper triangular softmax. - sources=[srcpath / 'scaled_upper_triang_masked_softmax.cpp', - srcpath / 'scaled_upper_triang_masked_softmax_cuda.cu'] - scaled_upper_triang_masked_softmax_cuda = _cpp_extention_load_helper( - "scaled_upper_triang_masked_softmax_cuda", - sources, extra_cuda_flags, extra_include_paths) - - # Masked softmax. - sources=[srcpath / 'scaled_masked_softmax.cpp', - srcpath / 'scaled_masked_softmax_cuda.cu'] - scaled_masked_softmax_cuda = _cpp_extention_load_helper( - "scaled_masked_softmax_cuda", sources, extra_cuda_flags, extra_include_paths) - - # ================================= - # Mixed precision fused layer norm. - # ================================= - - if torch.version.hip is not None: - extra_cuda_flags = [] - else: - extra_cuda_flags = ['-maxrregcount=50'] - - sources=[srcpath / 'layer_norm_cuda.cpp', - srcpath / 'layer_norm_cuda_kernel.cu'] - fused_mix_prec_layer_norm_cuda = _cpp_extention_load_helper( - "fused_mix_prec_layer_norm_cuda", sources, extra_cuda_flags, extra_include_paths) - - -def _get_cuda_bare_metal_version(cuda_dir): - raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], - universal_newlines=True) - output = raw_output.split() - release_idx = output.index("release") + 1 - release = output[release_idx].split(".") - bare_metal_major = release[0] - bare_metal_minor = release[1][0] - - return raw_output, bare_metal_major, bare_metal_minor - - -def _create_build_dir(buildpath): - try: - os.mkdir(buildpath) - except OSError: - if not os.path.isdir(buildpath): - print(f"Creation of the build directory {buildpath} failed") diff --git a/megatron/fused_kernels/compat.h b/megatron/fused_kernels/compat.h deleted file mode 100644 index 92e7eb772..000000000 --- a/megatron/fused_kernels/compat.h +++ /dev/null @@ -1,31 +0,0 @@ -/* coding=utf-8 - * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -/*This code is copied fron NVIDIA apex: - * https://github.com/NVIDIA/apex - * with minor changes. */ - - - -#ifndef TORCH_CHECK -#define TORCH_CHECK AT_CHECK -#endif - -#ifdef VERSION_GE_1_3 -#define DATA_PTR data_ptr -#else -#define DATA_PTR data -#endif diff --git a/megatron/fused_kernels/layer_norm_cuda.cpp b/megatron/fused_kernels/layer_norm_cuda.cpp deleted file mode 100644 index 8f28e7b4a..000000000 --- a/megatron/fused_kernels/layer_norm_cuda.cpp +++ /dev/null @@ -1,201 +0,0 @@ -/* coding=utf-8 - * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/*This code is copied fron NVIDIA apex: - * https://github.com/NVIDIA/apex - * with minor changes. */ - -#include -#include -#include -#include "compat.h" - -namespace { - -void compute_n1_n2( - at::Tensor input, - at::IntArrayRef normalized_shape, - int& n1, - int& n2) { - int idiff = input.ndimension() - normalized_shape.size(); - n2 = 1; - for (int i = 0; i < (int)normalized_shape.size(); ++i) { - assert( input.sizes()[i+idiff] == normalized_shape[i] ); - n2 *= normalized_shape[i]; - } - n1 = 1; - for (int i = 0; i < idiff; ++i) { - n1 *= input.sizes()[i]; - } -} - -void check_args( - at::IntArrayRef normalized_shape, - at::Tensor gamma, - at::Tensor beta - ) -{ - TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape)); - TORCH_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape)); -} - -void check_args( - at::Tensor input, - at::IntArrayRef normalized_shape, - int& n1, - int& n2 - ) -{ - int64_t normalized_ndim = normalized_shape.size(); - - if (normalized_ndim < 1) { - std::stringstream ss; - ss << "Expected normalized_shape to be at least 1-dimensional, i.e., " - << "containing at least one element, but got normalized_shape=" - << normalized_shape; - throw std::runtime_error(ss.str()); - } - - auto input_shape = input.sizes(); - auto input_ndim = input.dim(); - - if (input_ndim < normalized_ndim || - !input_shape.slice(input_ndim - normalized_ndim).equals(normalized_shape)) { - std::stringstream ss; - ss << "Given normalized_shape=" << normalized_shape - << ", expected input with shape [*"; - for (auto size : normalized_shape) { - ss << ", " << size; - } - ss << "], but got input of size" << input_shape; - throw std::runtime_error(ss.str()); - } - - compute_n1_n2(input,normalized_shape,n1,n2); -} - - -void check_args( - at::Tensor input, - at::IntArrayRef normalized_shape, - at::Tensor gamma, - at::Tensor beta, - int& n1, - int& n2 - ) -{ - check_args(input,normalized_shape,n1,n2); - check_args(normalized_shape,gamma,beta); -} -} - -void cuda_layer_norm( - at::Tensor* output, - at::Tensor* mean, - at::Tensor* invvar, - at::Tensor* input, - int n1, - int n2, - at::IntArrayRef normalized_shape, - at::Tensor* gamma, - at::Tensor* beta, - double epsilon); - -#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") -#define 
CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) - -std::vector layer_norm_affine( - at::Tensor input, - at::IntArrayRef normalized_shape, - at::Tensor gamma, - at::Tensor beta, - double epsilon) { - - CHECK_INPUT(input); - CHECK_INPUT(gamma); - CHECK_INPUT(beta); - int n1, n2; - check_args(input, normalized_shape, gamma, beta, n1, n2); - - at::Tensor output = at::empty_like( - input, gamma.options().dtype(gamma.scalar_type())); - at::Tensor mean = at::empty( - {n1}, input.options().dtype(at::ScalarType::Float)); - at::Tensor invvar = at::empty_like(mean); - - cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2, - normalized_shape, &gamma, &beta, epsilon); - - return {output, mean, invvar}; - -} - - -void cuda_layer_norm_gradient( - at::Tensor* dout, - at::Tensor* mean, - at::Tensor* invvar, - at::Tensor* input, - int n1, - int n2, - at::IntArrayRef normalized_shape, - at::Tensor* gamma, - at::Tensor* beta, - double epsilon, - at::Tensor* grad_input, - at::Tensor* grad_gamma, - at::Tensor* grad_beta - ); - -std::vector layer_norm_gradient_affine( - at::Tensor dout, - at::Tensor mean, - at::Tensor invvar, - at::Tensor input, - at::IntArrayRef normalized_shape, - at::Tensor gamma, - at::Tensor beta, - double epsilon) { - - CHECK_INPUT(dout); - CHECK_INPUT(mean); - CHECK_INPUT(invvar); - CHECK_INPUT(input); - CHECK_INPUT(gamma); - CHECK_INPUT(beta); - int n1, n2; - check_args(input, normalized_shape, gamma, beta, n1, n2); - - at::Tensor grad_input = at::empty_like(input); - at::Tensor grad_gamma = at::empty_like(gamma); - at::Tensor grad_beta = at::empty_like(beta); - - cuda_layer_norm_gradient(&dout, &mean, &invvar, &input, n1, n2, - normalized_shape, &gamma, &beta, epsilon, - &grad_input, &grad_gamma, &grad_beta); - - return {grad_input, grad_gamma, grad_beta}; - -} - - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward_affine", &layer_norm_affine, - "LayerNorm forward (CUDA)"); - m.def("backward_affine", &layer_norm_gradient_affine, - "LayerNorm backward (CUDA)"); -} diff --git a/megatron/fused_kernels/layer_norm_cuda_kernel.cu b/megatron/fused_kernels/layer_norm_cuda_kernel.cu deleted file mode 100644 index 8a07806b1..000000000 --- a/megatron/fused_kernels/layer_norm_cuda_kernel.cu +++ /dev/null @@ -1,866 +0,0 @@ -/* coding=utf-8 - * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/*This code is copied fron NVIDIA apex: - * https://github.com/NVIDIA/apex - * with minor changes. 
*/ - -#include "ATen/ATen.h" -#include "ATen/AccumulateType.h" -#include "ATen/cuda/CUDAContext.h" -#include "ATen/cuda/DeviceUtils.cuh" - -#include -#include - -#include "type_shim.h" - -template __device__ -void cuWelfordOnlineSum( - const U curr, - U& mu, - U& sigma2, - U& count) -{ - count = count + U(1); - U delta = curr - mu; - U lmean = mu + delta / count; - mu = lmean; - U delta2 = curr - lmean; - sigma2 = sigma2 + delta * delta2; -} - -template __device__ -void cuChanOnlineSum( - const U muB, - const U sigma2B, - const U countB, - U& mu, - U& sigma2, - U& count) -{ - U delta = muB - mu; - U nA = count; - U nB = countB; - count = count + countB; - U nX = count; - if (nX > U(0)) { - nA = nA / nX; - nB = nB / nX; - mu = nA*mu + nB*muB; - sigma2 = sigma2 + sigma2B + delta * delta * nA * nB * nX; - } else { - mu = U(0); - sigma2 = U(0); - } -} - -template __device__ -void cuWelfordMuSigma2( - const T* __restrict__ vals, - const int n1, - const int n2, - const int i1, - U& mu, - U& sigma2, - U* buf, - const int GPU_WARP_SIZE) -{ - // Assumptions: - // 1) blockDim.x == warpSize - // 2) Tensor is contiguous - // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available. - // - // compute variance and mean over n2 - U count = U(0); - mu= U(0); - sigma2 = U(0); - if (i1 < n1) { - // one warp normalizes one n1 index, - // synchronization is implicit - // initialize with standard Welford algorithm - const int numx = blockDim.x * blockDim.y; - const int thrx = threadIdx.x + threadIdx.y * blockDim.x; - const T* lvals = vals + i1*n2; - int l = 4*thrx; - for (; l+3 < n2; l+=4*numx) { - for (int k = 0; k < 4; ++k) { - U curr = static_cast(lvals[l+k]); - cuWelfordOnlineSum(curr,mu,sigma2,count); - } - } - for (; l < n2; ++l) { - U curr = static_cast(lvals[l]); - cuWelfordOnlineSum(curr,mu,sigma2,count); - } - // intra-warp reductions - for (int stride = GPU_WARP_SIZE / 2; stride > 0; stride /= 2) { - U sigma2B = WARP_SHFL_DOWN(sigma2, stride); - U muB = WARP_SHFL_DOWN(mu, stride); - U countB = WARP_SHFL_DOWN(count, stride); - cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); - } - // threadIdx.x == 0 has correct values for each warp - // inter-warp reductions - if (blockDim.y > 1) { - U* ubuf = (U*)buf; - U* ibuf = (U*)(ubuf + blockDim.y); - for (int offset = blockDim.y/2; offset > 0; offset /= 2) { - // upper half of warps write to shared - if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) { - const int wrt_y = threadIdx.y - offset; - ubuf[2*wrt_y] = mu; - ubuf[2*wrt_y+1] = sigma2; - ibuf[wrt_y] = count; - } - __syncthreads(); - // lower half merges - if (threadIdx.x == 0 && threadIdx.y < offset) { - U muB = ubuf[2*threadIdx.y]; - U sigma2B = ubuf[2*threadIdx.y+1]; - U countB = ibuf[threadIdx.y]; - cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count); - } - __syncthreads(); - } - // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values - if (threadIdx.x == 0 && threadIdx.y == 0) { - ubuf[0] = mu; - ubuf[1] = sigma2; - } - __syncthreads(); - mu = ubuf[0]; - sigma2 = ubuf[1]/U(n2); - // don't care about final value of count, we know count == n2 - } else { - mu = WARP_SHFL(mu, 0); - sigma2 = WARP_SHFL(sigma2/U(n2), 0); - } - } -} - -template<> __device__ -void cuWelfordMuSigma2( - const at::Half* __restrict__ vals, - const int n1, - const int n2, - const int i1, - float& mu, - float& sigma2, - float* buf, - const int GPU_WARP_SIZE) -{ - // Assumptions: - // 1) blockDim.x == warpSize - // 2) Tensor is contiguous - // 3) 
2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available. - // - // compute variance and mean over n2 - float count = 0.0f; - mu= float(0); - sigma2 = float(0); - if (i1 < n1) { - // one warp normalizes one n1 index, - // synchronization is implicit - // initialize with standard Welford algorithm - const int numx = blockDim.x * blockDim.y; - const int thrx = threadIdx.x + threadIdx.y * blockDim.x; - const at::Half* lvals = vals + i1*n2; - int l = 8*thrx; - if ((((size_t)lvals)&3) != 0) { - // 16 bit alignment - // first thread consumes first point - if (thrx == 0) { - float curr = static_cast(lvals[0]); - cuWelfordOnlineSum(curr,mu,sigma2,count); - } - ++l; - } - // at this point, lvals[l] are 32 bit aligned for all threads. - for (; l+7 < n2; l+=8*numx) { - for (int k = 0; k < 8; k+=2) { - float2 curr = __half22float2(*((__half2*)(lvals+l+k))); - cuWelfordOnlineSum(curr.x,mu,sigma2,count); - cuWelfordOnlineSum(curr.y,mu,sigma2,count); - } - } - for (; l < n2; ++l) { - float curr = static_cast(lvals[l]); - cuWelfordOnlineSum(curr,mu,sigma2,count); - } - // intra-warp reductions - for (int stride = GPU_WARP_SIZE / 2; stride > 0; stride /= 2) { - float sigma2B = WARP_SHFL_DOWN(sigma2, stride); - float muB = WARP_SHFL_DOWN(mu, stride); - float countB = WARP_SHFL_DOWN(count, stride); - cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); - } - // threadIdx.x == 0 has correct values for each warp - // inter-warp reductions - if (blockDim.y > 1) { - float* ubuf = (float*)buf; - float* ibuf = (float*)(ubuf + blockDim.y); - for (int offset = blockDim.y/2; offset > 0; offset /= 2) { - // upper half of warps write to shared - if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) { - const int wrt_y = threadIdx.y - offset; - ubuf[2*wrt_y] = mu; - ubuf[2*wrt_y+1] = sigma2; - ibuf[wrt_y] = count; - } - __syncthreads(); - // lower half merges - if (threadIdx.x == 0 && threadIdx.y < offset) { - float muB = ubuf[2*threadIdx.y]; - float sigma2B = ubuf[2*threadIdx.y+1]; - float countB = ibuf[threadIdx.y]; - cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count); - } - __syncthreads(); - } - // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values - if (threadIdx.x == 0 && threadIdx.y == 0) { - ubuf[0] = mu; - ubuf[1] = sigma2; - } - __syncthreads(); - mu = ubuf[0]; - sigma2 = ubuf[1]/float(n2); - // don't care about final value of count, we know count == n2 - } else { - mu = WARP_SHFL(mu, 0); - sigma2 = WARP_SHFL(sigma2/float(n2), 0); - } - } -} -#ifndef __HIP_PLATFORM_HCC__ -template U rsqrt(U v) { -#else -template __device__ U rsqrt(U v) { -#endif - return U(1) / sqrt(v); -} -#ifndef __HIP_PLATFORM_HCC__ -template<> float rsqrt(float v) { -#else -template<> __device__ float rsqrt(float v) { -#endif - return rsqrtf(v); -} -#ifndef __HIP_PLATFORM_HCC__ -template<> double rsqrt(double v) { -#else -template<> __device__ double rsqrt(double v) { -#endif - return rsqrt(v); -} - -namespace { -// This is the un-specialized struct. Note that we prevent instantiation of this -// struct by putting an undefined symbol in the function body so it won't compile. 
-// template -// struct SharedMemory -// { -// // Ensure that we won't compile any un-specialized types -// __device__ T *getPointer() -// { -// extern __device__ void error(void); -// error(); -// return NULL; -// } -// }; -// https://github.com/NVIDIA/apex/issues/246 -template -struct SharedMemory; - -template <> -struct SharedMemory -{ - __device__ float *getPointer() - { - extern __shared__ float s_float[]; - return s_float; - } -}; - -} - -template __global__ -void cuApplyLayerNorm( - V* __restrict__ output_vals, - U* __restrict__ mean, - U* __restrict__ invvar, - const T* __restrict__ vals, - const int n1, - const int n2, - const U epsilon, - const V* __restrict__ gamma, - const V* __restrict__ beta, - const int GPU_WARP_SIZE - ) -{ - // Assumptions: - // 1) blockDim.x == warpSize - // 2) Tensors are contiguous - // -#ifndef __HIP_PLATFORM_HCC__ - for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) { -#else - for (int i1=blockIdx.y; i1 < n1; i1 += gridDim.y) { -#endif - SharedMemory shared; - U* buf = shared.getPointer(); - U mu,sigma2; - cuWelfordMuSigma2(vals,n1,n2,i1,mu,sigma2,buf,GPU_WARP_SIZE); - const T* lvals = vals + i1*n2; - V* ovals = output_vals + i1*n2; - U c_invvar = rsqrt(sigma2 + epsilon); - const int numx = blockDim.x * blockDim.y; - const int thrx = threadIdx.x + threadIdx.y * blockDim.x; - if (gamma != NULL && beta != NULL) { - for (int i = thrx; i < n2; i+=numx) { - U curr = static_cast(lvals[i]); - ovals[i] = gamma[i] * static_cast(c_invvar * (curr - mu)) + beta[i]; - } - } else { - for (int i = thrx; i < n2; i+=numx) { - U curr = static_cast(lvals[i]); - ovals[i] = static_cast(c_invvar * (curr - mu)); - } - } - if (threadIdx.x == 0 && threadIdx.y == 0) { - mean[i1] = mu; - invvar[i1] = c_invvar; - } - __syncthreads(); - } -} - -template __device__ -void cuLoadWriteStridedInputs( - const int i1_block, - const int thr_load_row_off, - const int thr_load_col_off, - const int i2_off, - const int row_stride, - U* warp_buf1, - U* warp_buf2, - const T* input, - const V* dout, - const int i1_end, - const int n2, - const U* __restrict__ mean, - const U* __restrict__ invvar - ) -{ - int i1 = i1_block+thr_load_row_off; - if (i1 < i1_end) { - U curr_mean = mean[i1]; - U curr_invvar = invvar[i1]; - for (int k = 0; k < blockDim.y; ++k) { - int i2 = i2_off + k; - int load_idx = i1*n2+i2; - int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k; - if (i2(input[load_idx]); - U curr_dout = static_cast(dout[load_idx]); - warp_buf1[write_idx] = curr_dout; - warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_invvar; - } else { - warp_buf1[write_idx] = U(0); - warp_buf2[write_idx] = U(0); - } - } - } else { - for (int k = 0; k < blockDim.y; ++k) { - int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k; - warp_buf1[write_idx] = U(0); - warp_buf2[write_idx] = U(0); - } - } -} - -template __device__ -void cuLoadAddStridedInputs( - const int i1_block, - const int thr_load_row_off, - const int thr_load_col_off, - const int i2_off, - const int row_stride, - U* warp_buf1, - U* warp_buf2, - const T* input, - const V* dout, - const int i1_end, - const int n2, - const U* __restrict__ mean, - const U* __restrict__ invvar - ) -{ - int i1 = i1_block+thr_load_row_off; - if (i1 < i1_end) { - U curr_mean = mean[i1]; - U curr_invvar = invvar[i1]; - for (int k = 0; k < blockDim.y; ++k) { - int i2 = i2_off + k; - int load_idx = i1*n2+i2; - int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k; - if (i2(input[load_idx]); - U curr_dout = 
static_cast(dout[load_idx]); - warp_buf1[write_idx] += curr_dout; - warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar; - } - } - } -} - -template __global__ -void cuComputePartGradGammaBeta( - const V* __restrict__ dout, - const T* __restrict__ input, - const int n1, - const int n2, - const U* __restrict__ mean, - const U* __restrict__ invvar, - U epsilon, - U* part_grad_gamma, - U* part_grad_beta) -{ - const int numsegs_n1 = (n1+blockDim.y*blockDim.y-1) / (blockDim.y*blockDim.y); - const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y; - const int i1_beg = blockIdx.y * segs_per_block * blockDim.y*blockDim.y; - const int i1_beg_plus_one = (blockIdx.y+1) * segs_per_block * blockDim.y*blockDim.y; - const int i1_end = i1_beg_plus_one < n1 ? i1_beg_plus_one : n1; - const int row_stride = blockDim.x+1; - const int thr_load_col_off = (threadIdx.x*blockDim.y)&(blockDim.x-1); - const int thr_load_row_off = (threadIdx.x*blockDim.y)/blockDim.x + threadIdx.y*blockDim.y; - const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off; - SharedMemory shared; - U* buf = shared.getPointer(); // buf has at least blockDim.x * blockDim.y * blockDim.y + (blockDim.y - 1)*(blockDim.x/blockDim.y) elements - U* warp_buf1 = (U*)buf; - U* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride; - // compute partial sums from strided inputs - // do this to increase number of loads in flight - cuLoadWriteStridedInputs(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar); - for (int i1_block = i1_beg+blockDim.y*blockDim.y; i1_block < i1_end; i1_block+=blockDim.y*blockDim.y) { - cuLoadAddStridedInputs(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar); - } - __syncthreads(); - // inter-warp reductions - // sum within each warp - U acc1 = U(0); - U acc2 = U(0); - for (int k = 0; k < blockDim.y; ++k) { - int row1 = threadIdx.y + k*blockDim.y; - int idx1 = row1*row_stride + threadIdx.x; - acc1 += warp_buf1[idx1]; - acc2 += warp_buf2[idx1]; - } - warp_buf1[threadIdx.y*row_stride+threadIdx.x] = acc1; - warp_buf2[threadIdx.y*row_stride+threadIdx.x] = acc2; - __syncthreads(); - // sum all warps - for (int offset = blockDim.y/2; offset > 1; offset /= 2) { - if (threadIdx.y < offset) { - int row1 = threadIdx.y; - int row2 = threadIdx.y + offset; - int idx1 = row1*row_stride + threadIdx.x; - int idx2 = row2*row_stride + threadIdx.x; - warp_buf1[idx1] += warp_buf1[idx2]; - warp_buf2[idx1] += warp_buf2[idx2]; - } - __syncthreads(); - } - int i2 = blockIdx.x * blockDim.x + threadIdx.x; - if (threadIdx.y == 0 && i2 < n2) { - int row1 = threadIdx.y; - int row2 = threadIdx.y + 1; - int idx1 = row1*row_stride + threadIdx.x; - int idx2 = row2*row_stride + threadIdx.x; - part_grad_beta[blockIdx.y*n2+i2] = warp_buf1[idx1] + warp_buf1[idx2]; - part_grad_gamma[blockIdx.y*n2+i2] = warp_buf2[idx1] + warp_buf2[idx2]; - } -} - -template __global__ -void cuComputeGradGammaBeta( - const U* part_grad_gamma, - const U* part_grad_beta, - const int part_size, - const int n1, - const int n2, - V* grad_gamma, - V* grad_beta) -{ - // sum partial gradients for gamma and beta - SharedMemory shared; - U* buf = shared.getPointer(); - int i2 = blockIdx.x * blockDim.x + threadIdx.x; - if (i2 < n2) { - // each warp does sequential reductions until reduced part_size is num_warps - int num_warp_reductions = part_size / blockDim.y; - U sum_gamma = U(0); - U sum_beta = U(0); - const U* 
part_grad_gamma_ptr = part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2; - const U* part_grad_beta_ptr = part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2; - for (int warp_offset = 0; warp_offset < num_warp_reductions; ++warp_offset) { - sum_gamma += part_grad_gamma_ptr[warp_offset*n2]; - sum_beta += part_grad_beta_ptr[warp_offset*n2]; - } - // inter-warp reductions - const int nbsize3 = blockDim.x * blockDim.y / 2; - for (int offset = blockDim.y/2; offset >= 1; offset /= 2) { - // top half write to shared memory - if (threadIdx.y >= offset && threadIdx.y < 2*offset) { - const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x; - buf[write_idx] = sum_gamma; - buf[write_idx+nbsize3] = sum_beta; - } - __syncthreads(); - // bottom half sums - if (threadIdx.y < offset) { - const int read_idx = threadIdx.y * blockDim.x + threadIdx.x; - sum_gamma += buf[read_idx]; - sum_beta += buf[read_idx+nbsize3]; - } - __syncthreads(); - } - // write out fully summed gradients - if (threadIdx.y == 0) { - grad_gamma[i2] = sum_gamma; - grad_beta[i2] = sum_beta; - } - } -} - -template __global__ -void cuComputeGradInput( - const V* __restrict__ dout, - const T* __restrict__ input, - const int n1, - const int n2, - const U* __restrict__ mean, - const U* __restrict__ invvar, - U epsilon, - const V* gamma, - T* grad_input) -{ -#ifndef __HIP_PLATFORM_HCC__ - for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) { -#else - for (int i1=blockIdx.y; i1 < n1; i1 += gridDim.y) { -#endif - U sum_loss1 = U(0); - U sum_loss2 = U(0); - const U c_mean = mean[i1]; - const U c_invvar = invvar[i1]; - const T* k_input = input + i1*n2; - const V* k_dout = dout + i1*n2; - const int numx = blockDim.x * blockDim.y; - const int thrx = threadIdx.x + threadIdx.y * blockDim.x; - if (gamma != NULL) { - int l = 4*thrx; - for (; l+3 < n2; l+=4*numx) { - for (int k = 0; k < 4; ++k) { - const U c_h = static_cast(k_input[l+k]); - const U c_loss = static_cast(k_dout[l+k]); - sum_loss1 += c_loss * gamma[l+k]; - sum_loss2 += c_loss * gamma[l+k] * (c_h - c_mean) * c_invvar; - } - } - for (; l < n2; ++l) { - const U c_h = static_cast(k_input[l]); - const U c_loss = static_cast(k_dout[l]); - sum_loss1 += c_loss * gamma[l]; - sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar; - } - } else { - int l = 4*thrx; - for (; l+3 < n2; l+=4*numx) { - for (int k = 0; k < 4; ++k) { - const U c_h = static_cast(k_input[l+k]); - const U c_loss = static_cast(k_dout[l+k]); - sum_loss1 += c_loss; - sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; - } - } - for (; l < n2; ++l) { - const U c_h = static_cast(k_input[l]); - const U c_loss = static_cast(k_dout[l]); - sum_loss1 += c_loss; - sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; - } - } - // intra-warp reductions - for (int mask = blockDim.x/2; mask > 0; mask /= 2) { - sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask); - sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask); - } - // inter-warp reductions - if (blockDim.y > 1) { - SharedMemory shared; - U* buf = shared.getPointer(); - for (int offset = blockDim.y/2; offset > 0; offset /= 2) { - // upper half of warps write to shared - if (threadIdx.y >= offset && threadIdx.y < 2*offset) { - const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x; - buf[2*wrt_i] = sum_loss1; - buf[2*wrt_i+1] = sum_loss2; - } - __syncthreads(); - // lower half merges - if (threadIdx.y < offset) { - const int read_i = threadIdx.y * blockDim.x + threadIdx.x; - sum_loss1 += buf[2*read_i]; - sum_loss2 += buf[2*read_i+1]; - } - __syncthreads(); - 
} - if (threadIdx.y == 0) { - buf[2*threadIdx.x] = sum_loss1; - buf[2*threadIdx.x+1] = sum_loss2; - } - __syncthreads(); - if (threadIdx.y !=0) { - sum_loss1 = buf[2*threadIdx.x]; - sum_loss2 = buf[2*threadIdx.x+1]; - } - } - // all threads now have the two sums over l - U fH = (U)n2; - U term1 = (U(1) / fH) * c_invvar; - T* k_grad_input = grad_input + i1*n2; - if (gamma != NULL) { - for (int l = thrx; l < n2; l+=numx) { - const U c_h = static_cast(k_input[l]); - const U c_loss = static_cast(k_dout[l]); - U f_grad_input = fH * c_loss * gamma[l]; - f_grad_input -= sum_loss1; - f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; - f_grad_input *= term1; - k_grad_input[l] = static_cast(f_grad_input); - } - } else { - for (int l = thrx; l < n2; l+=numx) { - const U c_h = static_cast(k_input[l]); - const U c_loss = static_cast(k_dout[l]); - U f_grad_input = fH * c_loss; - f_grad_input -= sum_loss1; - f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; - f_grad_input *= term1; - k_grad_input[l] = static_cast(f_grad_input); - } - } - // prevent race where buf is written again before reads are done - __syncthreads(); - } -} - - - - -template -void HostApplyLayerNorm( - V* output, - U* mean, - U* invvar, - const T* input, - int n1, - int n2, - double epsilon, - const V* gamma, - const V* beta - ) -{ - auto stream = at::cuda::getCurrentCUDAStream().stream(); - const int warp_size = at::cuda::warp_size(); - dim3 threads(warp_size,4,1); -#ifndef __HIP_PLATFORM_HCC__ - threads.y = 1; -#endif - const uint64_t maxGridY = - at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; - const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1); - int nshared = - threads.y > 1 ? - threads.y*sizeof(U)+(threads.y/2)*sizeof(U) : - 0; - cuApplyLayerNorm<<>>( - output, - mean, - invvar, - input, - n1,n2, - U(epsilon), - gamma, - beta, - warp_size); -} - - -void cuda_layer_norm( - at::Tensor* output, - at::Tensor* mean, - at::Tensor* invvar, - at::Tensor* input, - int n1, - int n2, - #ifdef VERSION_GE_1_1 - at::IntArrayRef normalized_shape, - #else - at::IntList normalized_shape, - #endif - at::Tensor* gamma, - at::Tensor* beta, - double epsilon) -{ - using namespace at; - DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( - input->scalar_type(), output->scalar_type(), "cuda_layer_norm_kernel", - HostApplyLayerNorm( - output->DATA_PTR(), - mean->DATA_PTR(), - invvar->DATA_PTR(), - input->DATA_PTR(), - n1,n2, - epsilon, - gamma != NULL ? gamma->DATA_PTR() : NULL, - beta != NULL ? beta->DATA_PTR() : NULL); - ) -} - - -template -void HostLayerNormGradient( - const V* dout, - const U* mean, - const U* invvar, - at::Tensor* input, - int n1, - int n2, - const V* gamma, - const V* beta, - double epsilon, - T* grad_input, - V* grad_gamma, - V* grad_beta - ) -{ - auto stream = at::cuda::getCurrentCUDAStream().stream(); - const int warp_size = at::cuda::warp_size(); - - if (gamma != NULL && beta != NULL) { - // compute grad_gamma(j) and grad_beta(j) -#ifndef __HIP_PLATFORM_HCC__ - const int part_size = warp_size; -#else - const int part_size = 16; -#endif - const dim3 threads2(warp_size,4,1); - const dim3 blocks2((n2+threads2.x-1)/threads2.x,part_size,1); - const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * - (threads2.x + 1); - const int nshared2_b = threads2.x * threads2.y * sizeof(U); - const int nshared2 = nshared2_a > nshared2_b ? 
nshared2_a : nshared2_b; - at::Tensor part_grad_gamma = at::empty( - {part_size,n2}, input->options().dtype(at::ScalarType::Float)); - at::Tensor part_grad_beta = at::empty_like(part_grad_gamma); - cuComputePartGradGammaBeta<<>>( - dout, - input->DATA_PTR(), - n1,n2, - mean, - invvar, - U(epsilon), - part_grad_gamma.DATA_PTR(), - part_grad_beta.DATA_PTR()); - - const dim3 threads3(warp_size,8,1); - const dim3 blocks3((n2+threads2.x-1)/threads2.x,1,1); - const int nshared3 = threads3.x * threads3.y * sizeof(U); - cuComputeGradGammaBeta<<>>( - part_grad_gamma.DATA_PTR(), - part_grad_beta.DATA_PTR(), - part_size, - n1,n2, - grad_gamma, - grad_beta); - } - - // compute grad_input - const uint64_t maxGridY = - at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; - const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1); - dim3 threads1(warp_size,4,1); -#ifndef __HIP_PLATFORM_HCC__ - threads1.y = 2; -#endif - int nshared = - threads1.y > 1 ? - threads1.y*threads1.x*sizeof(U) : - 0; - cuComputeGradInput<<>>( - dout, - input->DATA_PTR(), - n1,n2, - mean, - invvar, - U(epsilon), - gamma, - grad_input); -} - - -void cuda_layer_norm_gradient( - at::Tensor* dout, - at::Tensor* mean, - at::Tensor* invvar, - at::Tensor* input, - int n1, - int n2, - #ifdef VERSION_GE_1_1 - at::IntArrayRef normalized_shape, - #else - at::IntList normalized_shape, - #endif - at::Tensor* gamma, - at::Tensor* beta, - double epsilon, - at::Tensor* grad_input, - at::Tensor* grad_gamma, - at::Tensor* grad_beta) -{ - using namespace at; - DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( - input->scalar_type(), gamma->scalar_type(), - "cuda_layer_norm_gradient_kernel", - HostLayerNormGradient( - dout->DATA_PTR(), - mean->DATA_PTR(), - invvar->DATA_PTR(), - input, - n1,n2, - // TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta - // if gamma Tensor is NULL on input. - gamma != NULL ? gamma->DATA_PTR() : NULL, - gamma != NULL ? beta->DATA_PTR() : NULL, - epsilon, - grad_input->DATA_PTR(), - gamma != NULL ? grad_gamma->DATA_PTR() : NULL, - gamma != NULL ? grad_beta->DATA_PTR() : NULL); - ) -} diff --git a/megatron/fused_kernels/scaled_masked_softmax.cpp b/megatron/fused_kernels/scaled_masked_softmax.cpp deleted file mode 100644 index d5334710c..000000000 --- a/megatron/fused_kernels/scaled_masked_softmax.cpp +++ /dev/null @@ -1,77 +0,0 @@ -/* coding=utf-8 - * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */
-
-#include <cuda_fp16.h>
-#include <torch/extension.h>
-#include <vector>
-
-namespace multihead_attn {
-namespace fused_softmax {
-namespace scaled_masked_softmax {
-
-torch::Tensor fwd_cuda(
-    torch::Tensor const& input,
-    torch::Tensor const& mask,
-    float scale_factor);
-
-torch::Tensor bwd_cuda(
-    torch::Tensor const& output_grads,
-    torch::Tensor const& softmax_results,
-    float scale_factor);
-
-torch::Tensor fwd(
-    torch::Tensor const& input,
-    torch::Tensor const& mask,
-    float scale_factor) {
-  AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
-  AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
-      (input.scalar_type() == at::ScalarType::BFloat16),
-      "Only fp16 and bf16 are supported");
-  AT_ASSERTM(mask.dim() == 4, "expected 4D tensor");
-
-  return fwd_cuda(input, mask, scale_factor);
-}
-
-torch::Tensor bwd(
-    torch::Tensor const& output_grads,
-    torch::Tensor const& softmax_results,
-    float scale_factor) {
-
-  AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor");
-  AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor");
-
-  AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
-      (output_grads.scalar_type() == at::ScalarType::BFloat16),
-      "Only fp16 and bf16 are supported");
-  AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
-      (softmax_results.scalar_type() == at::ScalarType::BFloat16),
-      "Only fp16 and bf16 are supported");
-
-  return bwd_cuda(output_grads, softmax_results, scale_factor);
-}
-
-} // end namespace scaled_masked_softmax
-} // end namespace fused_softmax
-} // end namespace multihead_attn
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("forward",
-        &multihead_attn::fused_softmax::scaled_masked_softmax::fwd,
-        "Self Multihead Attention scaled, time masked softmax -- Forward.");
-  m.def("backward",
-        &multihead_attn::fused_softmax::scaled_masked_softmax::bwd,
-        "Self Multihead Attention scaled, time masked softmax -- Backward.");
-}
diff --git a/megatron/fused_kernels/scaled_masked_softmax.h b/megatron/fused_kernels/scaled_masked_softmax.h
deleted file mode 100644
index 78e97e4ec..000000000
--- a/megatron/fused_kernels/scaled_masked_softmax.h
+++ /dev/null
@@ -1,492 +0,0 @@
-/* coding=utf-8
- * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */ - -#pragma once - -#include -#include -#include -#include -#include -#include -#include - -namespace { - -template -__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); - -template <> -__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; } - -template <> -__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); } - -template <> -__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *dst = *src; } - -template <> -__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); } - -template <> -__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) { *dst = *src; } - -template <> -__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); } - -int log2_ceil(int value) { - int log2_value = 0; - while ((1 << log2_value) < value) ++log2_value; - return log2_value; -} - -template -struct Add { - __device__ __forceinline__ T operator()(T a, T b) const { - return a + b; - } -}; - -template -struct Max { - __device__ __forceinline__ T operator()(T a, T b) const { - return a < b ? b : a; - } -}; - -template -__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) -{ -#if CUDA_VERSION >= 9000 - return __shfl_xor_sync(mask, value, laneMask, width); -#else - return __shfl_xor(value, laneMask, width); -#endif -} - -template class ReduceOp> -__device__ __forceinline__ void warp_reduce(acc_t* sum) { - ReduceOp r; - #pragma unroll - for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); - sum[i] = r(sum[i], b); - } - } -} - -/* - * Extended softmax (from native aten pytorch) with following additional features - * 1) input scaling - * 2) Explicit masking - */ -template -__global__ void scaled_masked_softmax_warp_forward( - output_t *dst, - const input_t *src, - const uint8_t *mask, - const acc_t scale, - int micro_batch_size, - int element_count, - int pad_batches) -{ - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and - // warp_size of method warp_softmax_forward_kernel. - constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; - constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - constexpr int ELEMENTS_PER_LDG_STG = 4; - - // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) - // gridDim/blockIdx = (seq_len, attn_heads, batches) - int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z))+ threadIdx.y) * WARP_BATCH; - int pad_first_batch = 0; - if (pad_batches != 1) { // bert style - pad_first_batch = (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * WARP_BATCH; - } else { // gpt2 style - pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; - } - - // micro_batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = micro_batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; - - // there might be multiple batches per warp. 
compute the index within the batch - int local_idx = threadIdx.x; - - src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; - dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; - mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; - - // load data from global memory - acc_t elements[WARP_BATCH][WARP_ITERATIONS]; - input_t temp_data[ELEMENTS_PER_LDG_STG]; - uint8_t temp_mask[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 0 : element_count; - - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - - if (element_index < batch_element_count) { - int itr_idx = i*element_count+it*WARP_SIZE; - copy_vector(temp_data, src + itr_idx); - copy_vector(temp_mask, mask + itr_idx); - - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if (temp_mask[element] != 1) { - elements[i][it + element] = (acc_t)temp_data[element] * scale; - } else { - elements[i][it + element] = -10000.0; - } - } - } else { - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - elements[i][it + element] = -std::numeric_limits::infinity(); - } - } - } - } - - // compute max_value - acc_t max_value[WARP_BATCH]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - max_value[i] = elements[i][0]; - #pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; - } - } - warp_reduce(max_value); - - acc_t sum[WARP_BATCH] { 0.0f }; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; ++it) { - elements[i][it] = std::exp((elements[i][it] - max_value[i])); - sum[i] += elements[i][it]; - } - } - warp_reduce(sum); - - // store result - output_t out[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < element_count) { - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - out[element] = elements[i][it + element] / sum[i]; - } - copy_vector(dst + i * element_count + it * WARP_SIZE, out); - } else { - break; - } - } - } -} - -template -__global__ void scaled_masked_softmax_warp_backward( - output_t *gradInput, - input_t *grad, - const input_t *output, - acc_t scale, - int micro_batch_size, - int element_count) -{ - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and - // warp_size of method warp_softmax_backward_kernel. - constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; - constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - constexpr int ELEMENTS_PER_LDG_STG = 4; - - // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) - // gridDim/blockIdx = (seq_len, attn_heads, batches) - int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; - - // micro_batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. 
- int local_batches = micro_batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; - - // there might be multiple batches per warp. compute the index within the batch - int local_idx = threadIdx.x; - - // the first element to process by the current thread - int thread_offset = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; - grad += thread_offset; - output += thread_offset; - gradInput += thread_offset; - - // load data from global memory - acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; - acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; - input_t temp_grad[ELEMENTS_PER_LDG_STG]; - input_t temp_output[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 0 : element_count; - - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < batch_element_count) { - copy_vector(temp_grad, grad + i * element_count + it * WARP_SIZE); - copy_vector(temp_output, output + i * element_count + it * WARP_SIZE); - - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - output_reg[i][it + element] = (acc_t)temp_output[element]; - } - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element]; - } - } - } - } - - acc_t sum[WARP_BATCH]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - sum[i] = grad_reg[i][0]; - #pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - sum[i] += grad_reg[i][it]; - } - } - warp_reduce(sum); - - // store result - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < element_count) { - // compute gradients - output_t out[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i])); - } - copy_vector(gradInput + i * element_count + it * WARP_SIZE, out); - } - } - } -} - -} // end of anonymous namespace - -template -void dispatch_scaled_masked_softmax_forward( - output_t *dst, - const input_t *src, - const uint8_t *mask, - const input_t scale, - int query_seq_len, - int key_seq_len, - int batches, - int attn_heads, - int pad_batches) -{ - TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048 ); - if (key_seq_len == 0) { - return; - } else { - int log2_elements = log2_ceil(key_seq_len); - const int next_power_of_two = 1 << log2_elements; - int batch_count = batches * attn_heads * query_seq_len; - - // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward. - int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - - // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. - int batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0); - dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches); - dim3 threads(warp_size, warps_per_block, 1); - // Launch code would be more elegant if C++ supported FOR CONSTEXPR - switch (log2_elements) { - case 0: // 1 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 1: // 2 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 2: // 4 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 3: // 8 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 4: // 16 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 5: // 32 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 6: // 64 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 7: // 128 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 8: // 256 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 9: // 512 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 10: // 1024 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 11: // 2048 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - default: - break; - } - } -} - -template -void dispatch_scaled_masked_softmax_backward( - output_t *grad_input, - input_t *grad, - const input_t *output, - const acc_t scale, - int query_seq_len, - int key_seq_len, - int batches, - int attn_heads) -{ - TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 2048 ); - if (key_seq_len == 0) { - return; - } else { - int log2_elements = log2_ceil(key_seq_len); - const int next_power_of_two = 1 << log2_elements; - int batch_count = batches * attn_heads * query_seq_len; - - // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward. - int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - - // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward. - int batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - int blocks = batch_count/batches_per_block; - dim3 threads(warp_size, warps_per_block, 1); - // Launch code would be more elegant if C++ supported FOR CONSTEXPR - switch (log2_elements) { - case 0: // 1 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 1: // 2 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 2: // 4 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 3: // 8 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 4: // 16 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 5: // 32 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 6: // 64 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 7: // 128 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 8: // 256 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 9: // 512 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 10: // 1024 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 11: // 2048 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - default: - break; - } - } -} diff --git a/megatron/fused_kernels/scaled_masked_softmax_cuda.cu b/megatron/fused_kernels/scaled_masked_softmax_cuda.cu deleted file mode 100644 index c034dc3ad..000000000 --- a/megatron/fused_kernels/scaled_masked_softmax_cuda.cu +++ /dev/null @@ -1,114 +0,0 @@ -/* coding=utf-8 - * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include -#include -#include -#include -#ifndef __HIP_PLATFORM_HCC__ -#include -#endif -#include -#include -#include "scaled_masked_softmax.h" -#include "type_shim.h" - -namespace multihead_attn { -namespace fused_softmax { -namespace scaled_masked_softmax { - -torch::Tensor fwd_cuda( - torch::Tensor const& input, - torch::Tensor const& mask, - float scale_factor) -{ - // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] - const int batches = input.size(0); - const int pad_batches = mask.size(0); - const int attn_heads = input.size(1); - const int query_seq_len = input.size(2); - const int key_seq_len = input.size(3); - TORCH_INTERNAL_ASSERT(key_seq_len <= 2048); - TORCH_INTERNAL_ASSERT(query_seq_len > 1); - TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches); - TORCH_INTERNAL_ASSERT(mask.size(1) == 1); - TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len); - TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len); - - // Output - auto act_options = input.options().requires_grad(false); - torch::Tensor softmax_results = - torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); - - // Softmax Intermediate Result Ptr - void* input_ptr = static_cast(input.data_ptr()); - void* mask_ptr = static_cast(mask.data_ptr()); - void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); - - DISPATCH_HALF_AND_BFLOAT( - input.scalar_type(), - "dispatch_scaled_masked_softmax_forward", - dispatch_scaled_masked_softmax_forward( - reinterpret_cast(softmax_results_ptr), - reinterpret_cast(input_ptr), - reinterpret_cast(mask_ptr), - scale_factor, - query_seq_len, - key_seq_len, - batches, - attn_heads, - pad_batches); - ); - return softmax_results; -} - -torch::Tensor bwd_cuda( - torch::Tensor const& output_grads_, - torch::Tensor const& softmax_results_, - float scale_factor) { - - auto output_grads = output_grads_.contiguous(); - auto softmax_results = softmax_results_.contiguous(); - - //output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] - const int batches = output_grads.size(0); - const int attn_heads = output_grads.size(1); - const int query_seq_len = output_grads.size(2); - const int key_seq_len = output_grads.size(3); - - void* output_grads_ptr = static_cast(output_grads.data_ptr()); - - //Softmax Grad - DISPATCH_HALF_AND_BFLOAT( - output_grads_.scalar_type(), - "dispatch_scaled_masked_softmax_backward", - dispatch_scaled_masked_softmax_backward( - reinterpret_cast(output_grads_ptr), - reinterpret_cast(output_grads_ptr), - reinterpret_cast(softmax_results.data_ptr()), - scale_factor, - query_seq_len, - key_seq_len, - batches, - attn_heads); - ); - - //backward pass is completely in-place - return output_grads; -} -} -} -} diff --git a/megatron/fused_kernels/scaled_upper_triang_masked_softmax.cpp b/megatron/fused_kernels/scaled_upper_triang_masked_softmax.cpp deleted file mode 100644 index ea283588d..000000000 --- a/megatron/fused_kernels/scaled_upper_triang_masked_softmax.cpp +++ /dev/null @@ -1,72 +0,0 @@ -/* coding=utf-8 - * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <cuda_fp16.h>
-#include <torch/extension.h>
-#include <vector>
-
-namespace multihead_attn {
-namespace fused_softmax {
-namespace scaled_upper_triang_masked_softmax {
-
-torch::Tensor fwd_cuda(
-    torch::Tensor const& input,
-    float scale_factor);
-
-torch::Tensor bwd_cuda(
-    torch::Tensor const& output_grads,
-    torch::Tensor const& softmax_results,
-    float scale_factor);
-
-torch::Tensor fwd(torch::Tensor const& input, float scale_factor) {
-  AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
-  AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
-      (input.scalar_type() == at::ScalarType::BFloat16),
-      "Only fp16 and bf16 are supported");
-
-  return fwd_cuda(input, scale_factor);
-}
-
-torch::Tensor bwd(
-    torch::Tensor const& output_grads,
-    torch::Tensor const& softmax_results,
-    float scale_factor) {
-
-  AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
-  AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
-
-  AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
-      (output_grads.scalar_type() == at::ScalarType::BFloat16),
-      "Only fp16 and bf16 are supported");
-  AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
-      (softmax_results.scalar_type() == at::ScalarType::BFloat16),
-      "Only fp16 and bf16 are supported");
-
-  return bwd_cuda(output_grads, softmax_results, scale_factor);
-}
-
-} // end namespace scaled_upper_triang_masked_softmax
-} // end namespace fused_softmax
-} // end namespace multihead_attn
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("forward",
-        &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd,
-        "Self Multihead Attention scaled, time masked softmax -- Forward.");
-  m.def("backward",
-        &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd,
-        "Self Multihead Attention scaled, time masked softmax -- Backward.");
-}
diff --git a/megatron/fused_kernels/scaled_upper_triang_masked_softmax.h b/megatron/fused_kernels/scaled_upper_triang_masked_softmax.h
deleted file mode 100644
index addca0a0a..000000000
--- a/megatron/fused_kernels/scaled_upper_triang_masked_softmax.h
+++ /dev/null
@@ -1,511 +0,0 @@
-/* coding=utf-8
- * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */ - -#pragma once - -#include -#include -#include -#include -#include -#include - -namespace { - -template -__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); - -template <> -__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; } - -template <> -__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); } - -template <> -__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *dst = *src; } - -template <> -__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); } - -template <> -__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) { *dst = *src; } - -template <> -__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); } - -template -__device__ __inline__ void copy_zero_vector(Datatype *dst); - -template <> -__device__ __inline__ void copy_zero_vector(c10::BFloat16 *dst) { *dst = 0.0; } - -template <> -__device__ __inline__ void copy_zero_vector(c10::BFloat16 *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); } - -template <> -__device__ __inline__ void copy_zero_vector(c10::Half *dst) { *dst = 0.0; } - -template <> -__device__ __inline__ void copy_zero_vector(c10::Half *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); } - - -int log2_ceil(int value) { - int log2_value = 0; - while ((1 << log2_value) < value) ++log2_value; - return log2_value; -} - -template -struct Add { - __device__ __forceinline__ T operator()(T a, T b) const { - return a + b; - } -}; - -template -struct Max { - __device__ __forceinline__ T operator()(T a, T b) const { - return a < b ? b : a; - } -}; - -template -__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) -{ -#if CUDA_VERSION >= 9000 - return __shfl_xor_sync(mask, value, laneMask, width); -#else - return __shfl_xor(value, laneMask, width); -#endif -} - -template class ReduceOp> -__device__ __forceinline__ void warp_reduce(acc_t* sum) { - ReduceOp r; - #pragma unroll - for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); - sum[i] = r(sum[i], b); - } - } -} - -/* - * Extended softmax (from native aten pytorch) with following additional features - * 1) input scaling - * 2) Implicit time (diagonal masking) - */ -template -__global__ void scaled_upper_triang_masked_softmax_warp_forward( - output_t *dst, - const input_t *src, - const acc_t scale, - int micro_batch_size, - int stride, - int element_count) -{ - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and - // warp_size of method warp_softmax_forward_kernel. - constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; - constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - constexpr int ELEMENTS_PER_LDG_STG = 4; - - int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x; - int local_seq = blockIdx.x + 1; - int warp_iteration_limit = (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1)/ WARP_SIZE; - - // micro_batch_size might not be a multiple of WARP_BATCH. 
Check how - // many batches have to computed within this WARP. - int local_batches = micro_batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; - - // there might be multiple batches per warp. compute the index within the batch - int local_idx = threadIdx.x; - - src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; - dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; - - // load data from global memory - acc_t elements[WARP_BATCH][WARP_ITERATIONS]; - input_t temp_data[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 0 : local_seq; - - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - - if (element_index < batch_element_count) { - copy_vector(temp_data, src + i*element_count*stride + it*WARP_SIZE); - - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if ((element_index + element) < batch_element_count) { - elements[i][it+element] = (acc_t)temp_data[element] * scale; - } else { - elements[i][it + element] = -std::numeric_limits::infinity(); - } - } - } else { - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - elements[i][it + element] = -std::numeric_limits::infinity(); - } - } - } - } - - // compute max_value - acc_t max_value[WARP_BATCH]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - max_value[i] = elements[i][0]; - #pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; - } - } - warp_reduce(max_value); - - acc_t sum[WARP_BATCH] { 0.0f }; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; ++it) { - if (it < warp_iteration_limit) { - elements[i][it] = std::exp((elements[i][it] - max_value[i])); - sum[i] += elements[i][it]; - } - } - } - warp_reduce(sum); - - // store result - output_t out[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - - if (element_index < local_seq) { - - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if (element_index + element < local_seq) { - out[element] = elements[i][it + element] / sum[i]; - } else { - out[element] = 0; - } - } - copy_vector(dst + i * element_count * stride + it * WARP_SIZE, out); - } else if (element_index < element_count) { - copy_zero_vector(dst + i * element_count * stride + it * WARP_SIZE); - } else { - break; - } - } - } -} - -template -__global__ void scaled_upper_triang_masked_softmax_warp_backward( - output_t *gradInput, - input_t *grad, - const input_t *output, - acc_t scale, - int micro_batch_size, - int stride, - int element_count) -{ - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and - // warp_size of method warp_softmax_backward_kernel. - constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; - constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 
2 : 1; - constexpr int ELEMENTS_PER_LDG_STG = 4; - - int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x; - int local_seq = blockIdx.x + 1; - - // micro_batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = micro_batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; - - // there might be multiple batches per warp. compute the index within the batch - int local_idx = threadIdx.x; - - // the first element to process by the current thread - int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; - grad += thread_offset; - output += thread_offset; - gradInput += thread_offset; - - // load data from global memory - acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; - acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; - input_t temp_grad[ELEMENTS_PER_LDG_STG]; - input_t temp_output[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 0 : local_seq; - - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < batch_element_count) { - copy_vector(temp_grad, grad + i * element_count * stride + it * WARP_SIZE); - copy_vector(temp_output, output + i * element_count * stride + it * WARP_SIZE); - - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if (element_index + element < batch_element_count) { - output_reg[i][it + element] = (acc_t)temp_output[element]; - } - } - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if (element_index + element < batch_element_count) { - grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element]; - } - } - } - } - } - - acc_t sum[WARP_BATCH]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - sum[i] = grad_reg[i][0]; - #pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - sum[i] += grad_reg[i][it]; - } - } - warp_reduce(sum); - - // store result - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < element_count) { - // compute gradients - output_t out[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i])); - } - copy_vector(gradInput + i * element_count * stride + it * WARP_SIZE, out); - } - } - } -} - -} // end of anonymous namespace - -template -void dispatch_scaled_upper_triang_masked_softmax_forward( - output_t *dst, - const input_t *src, - const input_t scale, - int softmax_elements, - int softmax_elements_stride, - int attn_batches) -{ - TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048 ); - if (softmax_elements == 0) { - return; - } else { - int log2_elements = log2_ceil(softmax_elements); - const int next_power_of_two = 1 << log2_elements; - int seq_len = softmax_elements; - int batch_count = attn_batches * seq_len; - - // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward. - int warp_size = (next_power_of_two < C10_WARP_SIZE) ? 
next_power_of_two : C10_WARP_SIZE; - - // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. - int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0); - int blocks_per_seq = attn_batches / batches_per_block; - dim3 blocks(seq_len, blocks_per_seq, 1); - dim3 threads(warp_size, warps_per_block, 1); - // Launch code would be more elegant if C++ supported FOR CONSTEXPR - switch (log2_elements) { - case 0: // 1 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 1: // 2 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 2: // 4 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 3: // 8 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 4: // 16 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 5: // 32 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 6: // 64 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 7: // 128 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 8: // 256 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 9: // 512 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 10: // 1024 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 11: // 2048 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - default: - break; - } - } -} - -template -void dispatch_scaled_upper_triang_masked_softmax_backward( - output_t *grad_input, - input_t *grad, - const input_t *output, - const acc_t scale, - int softmax_elements, - int softmax_elements_stride, - int attn_batches) -{ - TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 2048 ); - if (softmax_elements == 0) { - return; - } else { - int log2_elements = log2_ceil(softmax_elements); - const int next_power_of_two = 1 << log2_elements; - int seq_len = softmax_elements; - int batch_count = attn_batches * seq_len; - - // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward. - int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - - // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward. - int batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0); - int blocks_per_seq = attn_batches / batches_per_block; - dim3 blocks(seq_len, blocks_per_seq, 1); - dim3 threads(warp_size, warps_per_block, 1); - // Launch code would be more elegant if C++ supported FOR CONSTEXPR - switch (log2_elements) { - case 0: // 1 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 1: // 2 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 2: // 4 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 3: // 8 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 4: // 16 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 5: // 32 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 6: // 64 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 7: // 128 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 8: // 256 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 9: // 512 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 10: // 1024 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 11: // 2048 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - default: - break; - } - } -} diff --git a/megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu b/megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu deleted file mode 100644 index 59e452584..000000000 --- a/megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu +++ /dev/null @@ -1,100 +0,0 @@ -/* coding=utf-8 - * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#ifndef __HIP_PLATFORM_HCC__ -#include -#endif -#include -#include -#include "scaled_upper_triang_masked_softmax.h" -#include "type_shim.h" - -namespace multihead_attn { -namespace fused_softmax { -namespace scaled_upper_triang_masked_softmax { - -torch::Tensor fwd_cuda( - torch::Tensor const& input, - float scale_factor) -{ - // input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len] - const int attn_batches = input.size(0); - const int seq_len = input.size(1); - TORCH_INTERNAL_ASSERT(seq_len <= 2048); - - // Output - auto act_options = input.options().requires_grad(false); - torch::Tensor softmax_results = - torch::empty({attn_batches, seq_len, seq_len}, act_options); - - // Softmax Intermediate Result Ptr - void* input_ptr = static_cast(input.data_ptr()); - void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); - - DISPATCH_HALF_AND_BFLOAT( - input.scalar_type(), - "dispatch_scaled_upper_triang_masked_softmax_forward", - dispatch_scaled_upper_triang_masked_softmax_forward( - reinterpret_cast(softmax_results_ptr), - reinterpret_cast(input_ptr), - scale_factor, - seq_len, - seq_len, - attn_batches); - ); - return softmax_results; -} - - -torch::Tensor bwd_cuda( - torch::Tensor const& output_grads_, - torch::Tensor const& softmax_results_, - float scale_factor) { - - auto output_grads = output_grads_.contiguous(); - auto softmax_results = softmax_results_.contiguous(); - - //output grads is a 3d tensor with dimensions [attn_batches, seq_len, seq_len] - const int attn_batches = output_grads.size(0); - const int seq_len = output_grads.size(1); - TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2)); - - void* output_grads_ptr = static_cast(output_grads.data_ptr()); - - //Softmax Grad - DISPATCH_HALF_AND_BFLOAT( - output_grads_.scalar_type(), - "dispatch_scaled_upper_triang_masked_softmax_backward", - dispatch_scaled_upper_triang_masked_softmax_backward( - reinterpret_cast(output_grads_ptr), - reinterpret_cast(output_grads_ptr), - reinterpret_cast(softmax_results.data_ptr()), - scale_factor, - seq_len, - seq_len, - attn_batches); - ); - - //backward pass is completely in-place - return output_grads; -} -} -} -} diff --git a/megatron/fused_kernels/type_shim.h b/megatron/fused_kernels/type_shim.h deleted file mode 100644 index 6437dcc7c..000000000 --- a/megatron/fused_kernels/type_shim.h +++ /dev/null @@ -1,91 +0,0 @@ -/* coding=utf-8 - * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - -#include -#include "compat.h" - - -#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) 
\ - switch(TYPE) \ - { \ - case at::ScalarType::Half: \ - { \ - using scalar_t = at::Half; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::BFloat16: \ - { \ - using scalar_t = at::BFloat16; \ - __VA_ARGS__; \ - break; \ - } \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ - } - - - -#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \ - switch(TYPEIN) \ - { \ - case at::ScalarType::Float: \ - { \ - using scalar_t_in = float; \ - switch(TYPEOUT) \ - { \ - case at::ScalarType::Float: \ - { \ - using scalar_t_out = float; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::Half: \ - { \ - using scalar_t_out = at::Half; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::BFloat16: \ - { \ - using scalar_t_out = at::BFloat16; \ - __VA_ARGS__; \ - break; \ - } \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \ - } \ - break; \ - } \ - case at::ScalarType::Half: \ - { \ - using scalar_t_in = at::Half; \ - using scalar_t_out = at::Half; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::BFloat16: \ - { \ - using scalar_t_in = at::BFloat16; \ - using scalar_t_out = at::BFloat16; \ - __VA_ARGS__; \ - break; \ - } \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \ - } - -- Gitee From 7f2312946befb16ebcfe9f128ae23fe565e3bcf3 Mon Sep 17 00:00:00 2001 From: fengliangjun Date: Tue, 6 Jun 2023 12:14:52 +0000 Subject: [PATCH 04/10] =?UTF-8?q?=E5=88=A0=E9=99=A4=E6=96=87=E4=BB=B6=20me?= =?UTF-8?q?gatron/model/bert=5Fmodel.py?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- megatron/model/bert_model.py | 236 ----------------------------------- 1 file changed, 236 deletions(-) delete mode 100644 megatron/model/bert_model.py diff --git a/megatron/model/bert_model.py b/megatron/model/bert_model.py deleted file mode 100644 index be4025111..000000000 --- a/megatron/model/bert_model.py +++ /dev/null @@ -1,236 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""BERT model.""" - -import torch - -from megatron import get_args -from megatron import mpu -from megatron.model.enums import AttnMaskType -from megatron.model.language_model import parallel_lm_logits -from megatron.model.language_model import get_language_model -from megatron.model import LayerNorm -from megatron.model.utils import openai_gelu, erf_gelu -from megatron.model.utils import get_linear_layer -from megatron.model.utils import init_method_normal -from megatron.model.utils import scaled_init_method_normal -from .module import MegatronModule - -def bert_extended_attention_mask(attention_mask): - # We create a 3D attention mask from a 2D tensor mask. 
- # [b, 1, s] - attention_mask_b1s = attention_mask.unsqueeze(1) - # [b, s, 1] - attention_mask_bs1 = attention_mask.unsqueeze(2) - # [b, s, s] - attention_mask_bss = attention_mask_b1s * attention_mask_bs1 - # [b, 1, s, s] - extended_attention_mask = attention_mask_bss.unsqueeze(1) - - # Convert attention mask to binary: - extended_attention_mask = (extended_attention_mask < 0.5) - - return extended_attention_mask - -def bert_position_ids(token_ids): - # Create position ids - seq_length = token_ids.size(1) - position_ids = torch.arange(seq_length, dtype=torch.long, - device=token_ids.device) - position_ids = position_ids.unsqueeze(0).expand_as(token_ids) - - return position_ids - - -class BertLMHead(MegatronModule): - """Masked LM head for Bert - - Arguments: - mpu_vocab_size: model parallel size of vocabulary. - hidden_size: hidden size - init_method: init method for weight initialization - layernorm_epsilon: tolerance for layer norm divisions - parallel_output: whether output logits being distributed or not. - """ - - def __init__(self, mpu_vocab_size, hidden_size, init_method, - layernorm_epsilon, parallel_output): - - super(BertLMHead, self).__init__() - - args = get_args() - - self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size)) - mpu.set_tensor_model_parallel_attributes(self.bias, True, 0, 1) - self.parallel_output = parallel_output - - self.dense = get_linear_layer(hidden_size, hidden_size, init_method) - self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon) - self.gelu = torch.nn.functional.gelu - if args.openai_gelu: - self.gelu = openai_gelu - elif args.onnx_safe: - self.gelu = erf_gelu - - def forward(self, hidden_states, word_embeddings_weight): - hidden_states = self.dense(hidden_states) - hidden_states = self.gelu(hidden_states) - hidden_states = self.layernorm(hidden_states) - output = parallel_lm_logits(hidden_states, - word_embeddings_weight, - self.parallel_output, - bias=self.bias) - return output - - -def post_language_model_processing(lm_output, pooled_output, - lm_head, binary_head, - lm_labels, - logit_weights, - fp16_lm_cross_entropy): - # Output. 
- lm_logits = lm_head( - lm_output, logit_weights) - - binary_logits = None - if binary_head is not None: - binary_logits = binary_head(pooled_output) - - if lm_labels is None: - return lm_logits, binary_logits - else: - if fp16_lm_cross_entropy: - assert lm_logits.dtype == torch.half - lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels) - else: - lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(), - lm_labels) - return lm_loss, binary_logits - - -class BertModel(MegatronModule): - """Bert Language model.""" - - def __init__(self, - num_tokentypes=2, - add_binary_head=True, - parallel_output=True, - pre_process=True, - post_process=True): - super(BertModel, self).__init__() - args = get_args() - - self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy - self.add_binary_head = add_binary_head - self.parallel_output = parallel_output - self.pre_process = pre_process - self.post_process = post_process - - init_method = init_method_normal(args.init_method_std) - scaled_init_method = scaled_init_method_normal(args.init_method_std, - args.num_layers) - - self.language_model, self._language_model_key = get_language_model( - num_tokentypes=num_tokentypes, - add_pooler=self.add_binary_head, - encoder_attn_mask_type=AttnMaskType.padding, - init_method=init_method, - scaled_init_method=scaled_init_method, - pre_process=self.pre_process, - post_process=self.post_process) - - self.initialize_word_embeddings(init_method_normal) - if self.post_process: - self.lm_head = BertLMHead( - self.word_embeddings_weight().size(0), - args.hidden_size, init_method, args.layernorm_epsilon, parallel_output) - self._lm_head_key = 'lm_head' - self.binary_head = None - if self.add_binary_head: - self.binary_head = get_linear_layer(args.hidden_size, 2, - init_method) - self._binary_head_key = 'binary_head' - - def set_input_tensor(self, input_tensor): - """See megatron.model.transformer.set_input_tensor()""" - self.language_model.set_input_tensor(input_tensor) - - def forward(self, bert_model_input, attention_mask, - tokentype_ids=None, lm_labels=None): - - extended_attention_mask = bert_extended_attention_mask(attention_mask) - input_ids = bert_model_input - position_ids = bert_position_ids(input_ids) - - lm_output = self.language_model( - input_ids, - position_ids, - extended_attention_mask, - tokentype_ids=tokentype_ids - ) - - if self.post_process and self.add_binary_head: - lm_output, pooled_output = lm_output[0], lm_output[1] - else: - pooled_output = None - - if self.post_process: - return post_language_model_processing(lm_output, pooled_output, - self.lm_head, self.binary_head, - lm_labels, - self.word_embeddings_weight(), - self.fp16_lm_cross_entropy) - else: - return lm_output - - - def state_dict_for_save_checkpoint(self, destination=None, prefix='', - keep_vars=False): - """For easy load when model is combined with other heads, - add an extra key.""" - - state_dict_ = {} - state_dict_[self._language_model_key] \ - = self.language_model.state_dict_for_save_checkpoint( - destination, prefix, keep_vars) - if self.post_process: - state_dict_[self._lm_head_key] \ - = self.lm_head.state_dict_for_save_checkpoint( - destination, prefix, keep_vars) - if self.post_process and self.add_binary_head: - state_dict_[self._binary_head_key] \ - = self.binary_head.state_dict(destination, prefix, keep_vars) - # Save word_embeddings. 
- if self.post_process and not self.pre_process: - state_dict_[self._word_embeddings_for_head_key] \ - = self.word_embeddings.state_dict(destination, prefix, keep_vars) - return state_dict_ - - def load_state_dict(self, state_dict, strict=True): - """Customized load.""" - - self.language_model.load_state_dict( - state_dict[self._language_model_key], strict=strict) - if self.post_process: - self.lm_head.load_state_dict( - state_dict[self._lm_head_key], strict=strict) - if self.post_process and self.add_binary_head: - self.binary_head.load_state_dict( - state_dict[self._binary_head_key], strict=strict) - # Load word_embeddings. - if self.post_process and not self.pre_process: - self.word_embeddings.load_state_dict( - state_dict[self._word_embeddings_for_head_key], strict=strict) -- Gitee From 83999bb932dc51494f871079c9c3dac4d99cb7dc Mon Sep 17 00:00:00 2001 From: fengliangjun Date: Tue, 6 Jun 2023 12:15:23 +0000 Subject: [PATCH 05/10] =?UTF-8?q?=E5=88=A0=E9=99=A4=E6=96=87=E4=BB=B6=20me?= =?UTF-8?q?gatron/model/biencoder=5Fmodel.py?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- megatron/model/biencoder_model.py | 295 ------------------------------ 1 file changed, 295 deletions(-) delete mode 100644 megatron/model/biencoder_model.py diff --git a/megatron/model/biencoder_model.py b/megatron/model/biencoder_model.py deleted file mode 100644 index 51ac0a060..000000000 --- a/megatron/model/biencoder_model.py +++ /dev/null @@ -1,295 +0,0 @@ -import os -import torch -import sys - -from megatron import get_args, print_rank_0 -from megatron.checkpointing import fix_query_key_value_ordering -from megatron.checkpointing import get_checkpoint_tracker_filename -from megatron.checkpointing import get_checkpoint_name -from megatron import mpu, get_tokenizer -from megatron.model.bert_model import bert_position_ids -from megatron.model.enums import AttnMaskType -from megatron.model.language_model import get_language_model -from megatron.model.utils import get_linear_layer -from megatron.model.utils import init_method_normal -from megatron.model.utils import scaled_init_method_normal -from .module import MegatronModule - -def biencoder_model_provider(only_query_model=False, - only_context_model=False, - biencoder_shared_query_context_model=False): - """Build the model.""" - args = get_args() - - assert mpu.get_tensor_model_parallel_world_size() == 1 and \ - mpu.get_pipeline_model_parallel_world_size() == 1, \ - "Model parallel size > 1 not supported for ICT" - - print_rank_0('building BiEncoderModel...') - - # simpler to just keep using 2 tokentypes since - # the LM we initialize with has 2 tokentypes - model = BiEncoderModel( - num_tokentypes=2, - parallel_output=False, - only_query_model=only_query_model, - only_context_model=only_context_model, - biencoder_shared_query_context_model=\ - biencoder_shared_query_context_model) - - return model - - -class BiEncoderModel(MegatronModule): - """Bert-based module for Biencoder model.""" - - def __init__(self, - num_tokentypes=1, - parallel_output=True, - only_query_model=False, - only_context_model=False, - biencoder_shared_query_context_model=False): - super(BiEncoderModel, self).__init__() - args = get_args() - - bert_kwargs = dict( - num_tokentypes=num_tokentypes, - parallel_output=parallel_output) - - self.biencoder_shared_query_context_model = \ - biencoder_shared_query_context_model - assert not (only_context_model and only_query_model) - self.use_context_model = not only_query_model - self.use_query_model = 
not only_context_model - self.biencoder_projection_dim = args.biencoder_projection_dim - - if self.biencoder_shared_query_context_model: - self.model = PretrainedBertModel(**bert_kwargs) - self._model_key = 'shared_model' - self.query_model, self.context_model = self.model, self.model - else: - if self.use_query_model: - # this model embeds (pseudo-)queries - Embed_input in the paper - self.query_model = PretrainedBertModel(**bert_kwargs) - self._query_key = 'query_model' - - if self.use_context_model: - # this model embeds evidence blocks - Embed_doc in the paper - self.context_model = PretrainedBertModel(**bert_kwargs) - self._context_key = 'context_model' - - def forward(self, query_tokens, query_attention_mask, query_types, - context_tokens, context_attention_mask, context_types): - """Run a forward pass for each of the models and - return the respective embeddings.""" - - if self.use_query_model: - query_logits = self.embed_text(self.query_model, - query_tokens, - query_attention_mask, - query_types) - else: - raise ValueError("Cannot embed query without the query model.") - if self.use_context_model: - context_logits = self.embed_text(self.context_model, - context_tokens, - context_attention_mask, - context_types) - else: - raise ValueError("Cannot embed block without the block model.") - return query_logits, context_logits - - @staticmethod - def embed_text(model, tokens, attention_mask, token_types): - """Embed a batch of tokens using the model""" - logits = model(tokens, - attention_mask, - token_types) - return logits - - def state_dict_for_save_checkpoint(self, destination=None, \ - prefix='', keep_vars=False): - """Save dict with state dicts of each of the models.""" - state_dict_ = {} - if self.biencoder_shared_query_context_model: - state_dict_[self._model_key] = \ - self.model.state_dict_for_save_checkpoint(destination, - prefix, - keep_vars) - else: - if self.use_query_model: - state_dict_[self._query_key] = \ - self.query_model.state_dict_for_save_checkpoint( - destination, prefix, keep_vars) - - if self.use_context_model: - state_dict_[self._context_key] = \ - self.context_model.state_dict_for_save_checkpoint( - destination, prefix, keep_vars) - - return state_dict_ - - def load_state_dict(self, state_dict, strict=True): - """Load the state dicts of each of the models""" - if self.biencoder_shared_query_context_model: - print_rank_0("Loading shared query-context model") - self.model.load_state_dict(state_dict[self._model_key], \ - strict=strict) - else: - if self.use_query_model: - print_rank_0("Loading query model") - self.query_model.load_state_dict( \ - state_dict[self._query_key], strict=strict) - - if self.use_context_model: - print_rank_0("Loading context model") - self.context_model.load_state_dict( \ - state_dict[self._context_key], strict=strict) - - def init_state_dict_from_bert(self): - """Initialize the state from a pretrained BERT model - on iteration zero of ICT pretraining""" - args = get_args() - - if args.bert_load is None: - print_rank_0("bert-load argument is None") - return - - tracker_filename = get_checkpoint_tracker_filename(args.bert_load) - if not os.path.isfile(tracker_filename): - raise FileNotFoundError("Could not find BERT checkpoint") - with open(tracker_filename, 'r') as f: - iteration = int(f.read().strip()) - assert iteration > 0 - - checkpoint_name = get_checkpoint_name(args.bert_load, iteration, False) - if mpu.get_data_parallel_rank() == 0: - print('global rank {} is loading BERT checkpoint {}'.format( - torch.distributed.get_rank(), 
checkpoint_name)) - - # Load the checkpoint. - try: - state_dict = torch.load(checkpoint_name, map_location='cpu') - except ModuleNotFoundError: - from megatron.fp16_deprecated import loss_scaler - # For backward compatibility. - print_rank_0(' > deserializing using the old code structure ...') - sys.modules['fp16.loss_scaler'] = sys.modules[ - 'megatron.fp16_deprecated.loss_scaler'] - sys.modules['megatron.fp16.loss_scaler'] = sys.modules[ - 'megatron.fp16_deprecated.loss_scaler'] - state_dict = torch.load(checkpoint_name, map_location='cpu') - sys.modules.pop('fp16.loss_scaler', None) - sys.modules.pop('megatron.fp16.loss_scaler', None) - except BaseException: - print_rank_0('could not load the BERT checkpoint') - sys.exit() - - checkpoint_version = state_dict.get('checkpoint_version', 0) - - # load the LM state dict into each model - model_dict = state_dict['model']['language_model'] - - if self.biencoder_shared_query_context_model: - self.model.language_model.load_state_dict(model_dict) - fix_query_key_value_ordering(self.model, checkpoint_version) - else: - if self.use_query_model: - self.query_model.language_model.load_state_dict(model_dict) - # give each model the same ict_head to begin with as well - if self.biencoder_projection_dim > 0: - query_proj_state_dict = \ - self.state_dict_for_save_checkpoint()\ - [self._query_key]['projection_enc'] - fix_query_key_value_ordering(self.query_model, checkpoint_version) - - if self.use_context_model: - self.context_model.language_model.load_state_dict(model_dict) - if self.query_model is not None and \ - self.biencoder_projection_dim > 0: - self.context_model.projection_enc.load_state_dict\ - (query_proj_state_dict) - fix_query_key_value_ordering(self.context_model, checkpoint_version) - - -class PretrainedBertModel(MegatronModule): - """BERT-based encoder for queries or contexts used for - learned information retrieval.""" - - def __init__(self, num_tokentypes=2, - parallel_output=True): - super(PretrainedBertModel, self).__init__() - - args = get_args() - tokenizer = get_tokenizer() - self.pad_id = tokenizer.pad - self.biencoder_projection_dim = args.biencoder_projection_dim - self.parallel_output = parallel_output - init_method = init_method_normal(args.init_method_std) - scaled_init_method = scaled_init_method_normal( - args.init_method_std, args.num_layers) - - self.language_model, self._language_model_key = get_language_model( - num_tokentypes=num_tokentypes, - add_pooler=False, - encoder_attn_mask_type=AttnMaskType.padding, - init_method=init_method, - scaled_init_method=scaled_init_method) - - if args.biencoder_projection_dim > 0: - self.projection_enc = get_linear_layer(args.hidden_size, - args.biencoder_projection_dim, - init_method) - self._projection_enc_key = 'projection_enc' - - def forward(self, input_ids, attention_mask, tokentype_ids=None): - extended_attention_mask = attention_mask.unsqueeze(1) - #extended_attention_mask = bert_extended_attention_mask(attention_mask) - position_ids = bert_position_ids(input_ids) - - - lm_output = self.language_model(input_ids, - position_ids, - extended_attention_mask, - tokentype_ids=tokentype_ids) - # This mask will be used in average-pooling and max-pooling - pool_mask = (input_ids == self.pad_id).unsqueeze(2) - - # Taking the representation of the [CLS] token of BERT - pooled_output = lm_output[:, 0, :] - - # Converting to float16 dtype - pooled_output = pooled_output.to(lm_output.dtype) - - # Output. 
- if self.biencoder_projection_dim: - pooled_output = self.projection_enc(pooled_output) - - return pooled_output - - def state_dict_for_save_checkpoint(self, destination=None, prefix='', - keep_vars=False): - """For easy load when model is combined with other heads, - add an extra key.""" - - state_dict_ = {} - state_dict_[self._language_model_key] \ - = self.language_model.state_dict_for_save_checkpoint( - destination, prefix, keep_vars) - - if self.biencoder_projection_dim > 0: - state_dict_[self._projection_enc_key] = \ - self.projection_enc.state_dict(destination, prefix, keep_vars) - - return state_dict_ - - def load_state_dict(self, state_dict, strict=True): - """Customized load.""" - print_rank_0("loading BERT weights") - self.language_model.load_state_dict( - state_dict[self._language_model_key], strict=strict) - - if self.biencoder_projection_dim > 0: - print_rank_0("loading projection head weights") - self.projection_enc.load_state_dict( - state_dict[self._projection_enc_key], strict=strict) -- Gitee From b468c53b0efc17ea3b973f359e8df229746b461c Mon Sep 17 00:00:00 2001 From: fengliangjun Date: Tue, 6 Jun 2023 12:16:32 +0000 Subject: [PATCH 06/10] =?UTF-8?q?=E5=88=A0=E9=99=A4=E6=96=87=E4=BB=B6=20me?= =?UTF-8?q?gatron/model/realm=5Fmodel.py?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- megatron/model/realm_model.py | 204 ---------------------------------- 1 file changed, 204 deletions(-) delete mode 100644 megatron/model/realm_model.py diff --git a/megatron/model/realm_model.py b/megatron/model/realm_model.py deleted file mode 100644 index e74eb2e58..000000000 --- a/megatron/model/realm_model.py +++ /dev/null @@ -1,204 +0,0 @@ -import os -import torch - -from megatron import get_args, print_rank_0 -from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name -from megatron.model import BertModel -from .module import MegatronModule -from megatron import mpu -from megatron.model.enums import AttnMaskType -from megatron.model.utils import get_linear_layer -from megatron.model.utils import init_method_normal -from megatron.model.language_model import get_language_model -from megatron.model.utils import scaled_init_method_normal -from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids -from deepspeed.accelerator import get_accelerator - -def general_ict_model_provider(only_query_model=False, only_block_model=False): - """Build the model.""" - args = get_args() - assert args.ict_head_size is not None, \ - "Need to specify --ict-head-size to provide an ICTBertModel" - assert mpu.get_tensor_model_parallel_world_size() == 1 and mpu.get_pipeline_model_parallel_world_size() == 1, \ - "Model parallel size > 1 not supported for ICT" - - print_rank_0('building ICTBertModel...') - - # simpler to just keep using 2 tokentypes since the LM we initialize with has 2 tokentypes - model = ICTBertModel( - ict_head_size=args.ict_head_size, - num_tokentypes=2, - parallel_output=True, - only_query_model=only_query_model, - only_block_model=only_block_model) - - return model - - -class ICTBertModel(MegatronModule): - """Bert-based module for Inverse Cloze task.""" - def __init__(self, - ict_head_size, - num_tokentypes=1, - parallel_output=True, - only_query_model=False, - only_block_model=False): - super(ICTBertModel, self).__init__() - bert_kwargs = dict( - ict_head_size=ict_head_size, - num_tokentypes=num_tokentypes, - parallel_output=parallel_output - ) - assert not (only_block_model and 
only_query_model) - self.use_block_model = not only_query_model - self.use_query_model = not only_block_model - - if self.use_query_model: - # this model embeds (pseudo-)queries - Embed_input in the paper - self.query_model = IREncoderBertModel(**bert_kwargs) - self._query_key = 'question_model' - - if self.use_block_model: - # this model embeds evidence blocks - Embed_doc in the paper - self.block_model = IREncoderBertModel(**bert_kwargs) - self._block_key = 'context_model' - - def forward(self, query_tokens, query_attention_mask, block_tokens, block_attention_mask): - """Run a forward pass for each of the models and return the respective embeddings.""" - query_logits = self.embed_query(query_tokens, query_attention_mask) - block_logits = self.embed_block(block_tokens, block_attention_mask) - return query_logits, block_logits - - def embed_query(self, query_tokens, query_attention_mask): - """Embed a batch of tokens using the query model""" - if self.use_query_model: - query_types = get_accelerator().LongTensor(*query_tokens.shape).fill_(0) - query_ict_logits, _ = self.query_model.forward(query_tokens, query_attention_mask, query_types) - return query_ict_logits - else: - raise ValueError("Cannot embed query without query model.") - - def embed_block(self, block_tokens, block_attention_mask): - """Embed a batch of tokens using the block model""" - if self.use_block_model: - block_types = get_accelerator().LongTensor(*block_tokens.shape).fill_(0) - block_ict_logits, _ = self.block_model.forward(block_tokens, block_attention_mask, block_types) - return block_ict_logits - else: - raise ValueError("Cannot embed block without block model.") - - def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False): - """Save dict with state dicts of each of the models.""" - state_dict_ = {} - if self.use_query_model: - state_dict_[self._query_key] \ - = self.query_model.state_dict_for_save_checkpoint( - destination, prefix, keep_vars) - - if self.use_block_model: - state_dict_[self._block_key] \ - = self.block_model.state_dict_for_save_checkpoint( - destination, prefix, keep_vars) - - return state_dict_ - - def load_state_dict(self, state_dict, strict=True): - """Load the state dicts of each of the models""" - if self.use_query_model: - print("Loading ICT query model", flush=True) - self.query_model.load_state_dict( - state_dict[self._query_key], strict=strict) - - if self.use_block_model: - print("Loading ICT block model", flush=True) - self.block_model.load_state_dict( - state_dict[self._block_key], strict=strict) - - def init_state_dict_from_bert(self): - """Initialize the state from a pretrained BERT model on iteration zero of ICT pretraining""" - args = get_args() - tracker_filename = get_checkpoint_tracker_filename(args.bert_load) - if not os.path.isfile(tracker_filename): - raise FileNotFoundError("Could not find BERT load for ICT") - with open(tracker_filename, 'r') as f: - iteration = int(f.read().strip()) - assert iteration > 0 - - checkpoint_name = get_checkpoint_name(args.bert_load, iteration, False) - if mpu.get_data_parallel_rank() == 0: - print('global rank {} is loading checkpoint {}'.format( - torch.distributed.get_rank(), checkpoint_name)) - - try: - state_dict = torch.load(checkpoint_name, map_location='cpu') - except BaseException: - raise ValueError("Could not load checkpoint") - - # load the LM state dict into each model - model_dict = state_dict['model']['language_model'] - self.query_model.language_model.load_state_dict(model_dict) - 
self.block_model.language_model.load_state_dict(model_dict) - - # give each model the same ict_head to begin with as well - query_ict_head_state_dict = self.state_dict_for_save_checkpoint()[self._query_key]['ict_head'] - self.block_model.ict_head.load_state_dict(query_ict_head_state_dict) - - -class IREncoderBertModel(MegatronModule): - """BERT-based encoder for queries or blocks used for learned information retrieval.""" - def __init__(self, ict_head_size, num_tokentypes=2, parallel_output=True): - super(IREncoderBertModel, self).__init__() - args = get_args() - - self.ict_head_size = ict_head_size - self.parallel_output = parallel_output - init_method = init_method_normal(args.init_method_std) - scaled_init_method = scaled_init_method_normal(args.init_method_std, - args.num_layers) - - self.language_model, self._language_model_key = get_language_model( - num_tokentypes=num_tokentypes, - add_pooler=True, - encoder_attn_mask_type=AttnMaskType.padding, - init_method=init_method, - scaled_init_method=scaled_init_method) - - self.ict_head = get_linear_layer(args.hidden_size, ict_head_size, init_method) - self._ict_head_key = 'ict_head' - - def forward(self, input_ids, attention_mask, tokentype_ids=None): - extended_attention_mask = bert_extended_attention_mask( - attention_mask, next(self.language_model.parameters()).dtype) - position_ids = bert_position_ids(input_ids) - - lm_output, pooled_output = self.language_model( - input_ids, - position_ids, - extended_attention_mask, - tokentype_ids=tokentype_ids) - - # Output. - ict_logits = self.ict_head(pooled_output) - return ict_logits, None - - def state_dict_for_save_checkpoint(self, destination=None, prefix='', - keep_vars=False): - """For easy load when model is combined with other heads, - add an extra key.""" - - state_dict_ = {} - state_dict_[self._language_model_key] \ - = self.language_model.state_dict_for_save_checkpoint( - destination, prefix, keep_vars) - state_dict_[self._ict_head_key] \ - = self.ict_head.state_dict(destination, prefix, keep_vars) - return state_dict_ - - def load_state_dict(self, state_dict, strict=True): - """Customized load.""" - self.language_model.load_state_dict( - state_dict[self._language_model_key], strict=strict) - self.ict_head.load_state_dict( - state_dict[self._ict_head_key], strict=strict) - - -- Gitee From 0007d974207a199df24fd4da283293f338184aa6 Mon Sep 17 00:00:00 2001 From: fengliangjun Date: Tue, 6 Jun 2023 12:17:42 +0000 Subject: [PATCH 07/10] =?UTF-8?q?=E5=88=A0=E9=99=A4=E6=96=87=E4=BB=B6=20me?= =?UTF-8?q?gatron/model/classification.py?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- megatron/model/classification.py | 119 ------------------------------- 1 file changed, 119 deletions(-) delete mode 100644 megatron/model/classification.py diff --git a/megatron/model/classification.py b/megatron/model/classification.py deleted file mode 100644 index d4742c939..000000000 --- a/megatron/model/classification.py +++ /dev/null @@ -1,119 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Classification model.""" - -import torch - -from megatron import get_args, print_rank_last -from megatron import mpu -from megatron.model.enums import AttnMaskType -from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids -from megatron.model.language_model import get_language_model -from megatron.model.utils import get_linear_layer -from megatron.model.utils import init_method_normal -from megatron.model.utils import scaled_init_method_normal -from .module import MegatronModule - - -class Classification(MegatronModule): - - def __init__(self, - num_classes, - num_tokentypes=2, - pre_process=True, - post_process=True): - super(Classification, self).__init__(share_word_embeddings=False) - args = get_args() - - self.num_classes = num_classes - self.pre_process = pre_process - self.post_process = post_process - init_method = init_method_normal(args.init_method_std) - - self.language_model, self._language_model_key = get_language_model( - num_tokentypes=num_tokentypes, - add_pooler=True, - encoder_attn_mask_type=AttnMaskType.padding, - init_method=init_method, - scaled_init_method=scaled_init_method_normal(args.init_method_std, - args.num_layers), - pre_process=self.pre_process, - post_process=self.post_process) - - # Multi-choice head. - if self.post_process: - self.classification_dropout = torch.nn.Dropout(args.hidden_dropout) - self.classification_head = get_linear_layer(args.hidden_size, - self.num_classes, - init_method) - self._classification_head_key = 'classification_head' - - def set_input_tensor(self, input_tensor): - """See megatron.model.transformer.set_input_tensor()""" - self.language_model.set_input_tensor(input_tensor) - - def forward(self, model_input, attention_mask, tokentype_ids=None): - - extended_attention_mask = bert_extended_attention_mask(attention_mask) - input_ids = model_input - position_ids = bert_position_ids(input_ids) - - lm_output = self.language_model( - input_ids, - position_ids, - extended_attention_mask, - tokentype_ids=tokentype_ids - ) - - if self.post_process: - _, pooled_output = lm_output[0], lm_output[1] - classification_output = self.classification_dropout(pooled_output) - classification_logits = self.classification_head(classification_output) - - # Reshape back to separate choices. 
- classification_logits = classification_logits.view(-1, self.num_classes) - - return classification_logits - return lm_output - - def state_dict_for_save_checkpoint(self, destination=None, prefix='', - keep_vars=False): - """For easy load when model is combined with other heads, - add an extra key.""" - - state_dict_ = {} - state_dict_[self._language_model_key] \ - = self.language_model.state_dict_for_save_checkpoint( - destination, prefix, keep_vars) - if self.post_process: - state_dict_[self._classification_head_key] \ - = self.classification_head.state_dict( - destination, prefix, keep_vars) - return state_dict_ - - def load_state_dict(self, state_dict, strict=True): - """Customized load.""" - - self.language_model.load_state_dict( - state_dict[self._language_model_key], strict=strict) - if self.post_process: - if self._classification_head_key in state_dict: - self.classification_head.load_state_dict( - state_dict[self._classification_head_key], strict=strict) - else: - print_rank_last('***WARNING*** could not find {} in the checkpoint, ' - 'initializing to random'.format( - self._classification_head_key)) -- Gitee From 7b33f369206104ec05db87e339e5c0b65ae8a8fa Mon Sep 17 00:00:00 2001 From: fengliangjun Date: Tue, 6 Jun 2023 12:18:16 +0000 Subject: [PATCH 08/10] =?UTF-8?q?=E5=88=A0=E9=99=A4=E6=96=87=E4=BB=B6=20me?= =?UTF-8?q?gatron/model/multiple=5Fchoice.py?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- megatron/model/multiple_choice.py | 130 ------------------------------ 1 file changed, 130 deletions(-) delete mode 100644 megatron/model/multiple_choice.py diff --git a/megatron/model/multiple_choice.py b/megatron/model/multiple_choice.py deleted file mode 100644 index f82948f80..000000000 --- a/megatron/model/multiple_choice.py +++ /dev/null @@ -1,130 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Multiple choice model.""" - -import torch - -from megatron import get_args, print_rank_last -from megatron import mpu -from megatron.model.enums import AttnMaskType -from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids -from megatron.model.language_model import get_language_model -from megatron.model.utils import get_linear_layer -from megatron.model.utils import init_method_normal -from megatron.model.utils import scaled_init_method_normal -from .module import MegatronModule - - -class MultipleChoice(MegatronModule): - - def __init__(self, - num_tokentypes=2, - pre_process=True, - post_process=True): - super(MultipleChoice, self).__init__(share_word_embeddings=False) - args = get_args() - - init_method = init_method_normal(args.init_method_std) - self.pre_process = pre_process - self.post_process = post_process - - self.language_model, self._language_model_key = get_language_model( - num_tokentypes=num_tokentypes, - add_pooler=True, - encoder_attn_mask_type=AttnMaskType.padding, - init_method=init_method, - scaled_init_method=scaled_init_method_normal(args.init_method_std, - args.num_layers), - pre_process=self.pre_process, - post_process=self.post_process) - - # Multi-choice head. - if self.post_process: - self.multichoice_dropout = torch.nn.Dropout(args.hidden_dropout) - self.multichoice_head = get_linear_layer(args.hidden_size, 1, - init_method) - self._multichoice_head_key = 'multichoice_head' - - def set_input_tensor(self, input_tensor): - """See megatron.model.transformer.set_input_tensor()""" - self.language_model.set_input_tensor(input_tensor) - - def forward(self, model_input, attention_mask, tokentype_ids=None): - - # [batch, choices, sequence] --> [batch * choices, sequence] --> - # transformer --> [batch, choices] --> softmax - - # Ensure the shape is [batch-size, choices, sequence] - assert len(attention_mask.shape) == 3 - num_choices = attention_mask.shape[1] - - # Reshape and treat choice dimension the same as batch. - attention_mask = attention_mask.view(-1, attention_mask.size(-1)) - extended_attention_mask = bert_extended_attention_mask(attention_mask) - - input_ids = model_input - # Do the same as attention_mask for input_ids, tokentype_ids - assert len(input_ids.shape) == 3 - assert len(tokentype_ids.shape) == 3 - input_ids = input_ids.view(-1, input_ids.size(-1)) - tokentype_ids = tokentype_ids.view(-1, tokentype_ids.size(-1)) - position_ids = bert_position_ids(input_ids) - - lm_output = self.language_model( - input_ids, - position_ids, - extended_attention_mask, - tokentype_ids=tokentype_ids - ) - if self.post_process: - _, pooled_output = lm_output[0], lm_output[1] - multichoice_output = self.multichoice_dropout(pooled_output) - multichoice_logits = self.multichoice_head(multichoice_output) - - # Reshape back to separate choices. 
- multichoice_logits = multichoice_logits.view(-1, num_choices) - - return multichoice_logits - return lm_output - - def state_dict_for_save_checkpoint(self, destination=None, prefix='', - keep_vars=False): - """For easy load when model is combined with other heads, - add an extra key.""" - - state_dict_ = {} - state_dict_[self._language_model_key] \ - = self.language_model.state_dict_for_save_checkpoint( - destination, prefix, keep_vars) - if self.post_process: - state_dict_[self._multichoice_head_key] \ - = self.multichoice_head.state_dict( - destination, prefix, keep_vars) - return state_dict_ - - def load_state_dict(self, state_dict, strict=True): - """Customized load.""" - - self.language_model.load_state_dict( - state_dict[self._language_model_key], strict=strict) - if self.post_process: - if self._multichoice_head_key in state_dict: - self.multichoice_head.load_state_dict( - state_dict[self._multichoice_head_key], strict=strict) - else: - print_rank_last('***WARNING*** could not find {} in the checkpoint, ' - 'initializing to random'.format( - self._multichoice_head_key)) -- Gitee From a2f3de877d91a37e3ad19e8c8bcb15b79f041ffc Mon Sep 17 00:00:00 2001 From: fengliangjun Date: Tue, 6 Jun 2023 12:18:59 +0000 Subject: [PATCH 09/10] =?UTF-8?q?=E5=88=A0=E9=99=A4=E6=96=87=E4=BB=B6=20me?= =?UTF-8?q?gatron/model/t5=5Fmodel.py?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- megatron/model/t5_model.py | 174 ------------------------------------- 1 file changed, 174 deletions(-) delete mode 100644 megatron/model/t5_model.py diff --git a/megatron/model/t5_model.py b/megatron/model/t5_model.py deleted file mode 100644 index beb4f0ee5..000000000 --- a/megatron/model/t5_model.py +++ /dev/null @@ -1,174 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""T5 model.""" - -import torch - -from megatron import ( - get_args, - mpu -) -from megatron.model.enums import AttnMaskType -from megatron.model.language_model import parallel_lm_logits, get_language_model -from megatron.model.transformer import LayerNorm -from megatron.model.utils import ( - openai_gelu, - get_linear_layer, - init_method_normal, - scaled_init_method_normal -) -from .module import MegatronModule - - -def t5_extended_attention_mask(attention_mask_list): - - def attn_mask_postprocess(attn_mask): - # [b, 1, s, s] - extended_attention_mask = attn_mask.unsqueeze(1) - return extended_attention_mask - - return [attn_mask_postprocess(attn_mask) for attn_mask in attention_mask_list] - - -def t5_position_ids(token_ids): - # Create position ids - seq_length = token_ids.size(1) - position_ids = torch.arange(seq_length, dtype=torch.long, - device=token_ids.device) - position_ids = position_ids.unsqueeze(0).expand_as(token_ids) - - return position_ids - - -class T5LMHead(MegatronModule): - """Masked LM head for T5 - - Arguments: - mpu_vocab_size: model parallel size of vocabulary. 
- hidden_size: hidden size - init_method: init method for weight initialization - layernorm_epsilon: tolerance for layer norm divisions - parallel_output: wether output logits being distributed or not. - """ - - def __init__(self, mpu_vocab_size, parallel_output): - super(T5LMHead, self).__init__() - - args = get_args() - - self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size)) - self.bias.model_parallel = True - self.bias.partition_dim = 0 - self.bias.stride = 1 - self.parallel_output = parallel_output - - def forward(self, hidden_states, word_embeddings_weight): - output = parallel_lm_logits(hidden_states, - word_embeddings_weight, - self.parallel_output, - bias=self.bias) - return output - - -class T5Model(MegatronModule): - """T5 Language model.""" - - def __init__(self, num_tokentypes=0, parallel_output=True): - super(T5Model, self).__init__() - args = get_args() - - self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy - self.parallel_output = parallel_output - init_method = init_method_normal(args.init_method_std) - scaled_init_method = scaled_init_method_normal(args.init_method_std, - args.num_layers) - - self.language_model, self._language_model_key = get_language_model( - num_tokentypes=num_tokentypes, - add_pooler=False, - add_decoder=True, - encoder_attn_mask_type=AttnMaskType.padding, - init_method=init_method, - scaled_init_method=scaled_init_method) - - self.lm_head = T5LMHead( - self.language_model.embedding.word_embeddings.weight.size(0), - parallel_output) - self._lm_head_key = 'lm_head' - - def set_input_tensor(self, input_tensor): - """See megatron.model.transformer.set_input_tensor()""" - self.language_model.set_input_tensor(input_tensor) - - def forward(self, encoder_input_ids, decoder_input_ids, encoder_attn_mask, - decoder_attn_mask, encoder_decoder_attn_mask, - tokentype_ids=None, lm_labels=None, enc_hidden_states=None): - - # Converting the attention masks to proper parameter settings - encoder_attn_mask, decoder_attn_mask, encoder_decoder_attn_mask = t5_extended_attention_mask( - [encoder_attn_mask, decoder_attn_mask, encoder_decoder_attn_mask]) - - encoder_position_ids = t5_position_ids(encoder_input_ids) - decoder_position_ids = t5_position_ids(decoder_input_ids) - - lm_output = self.language_model(encoder_input_ids, - encoder_position_ids, - encoder_attn_mask, - decoder_input_ids, - decoder_position_ids, - decoder_attn_mask, - encoder_decoder_attn_mask, - tokentype_ids=tokentype_ids, - enc_hidden_states=enc_hidden_states) - - decoder_output, encoder_output = lm_output - - # Output. 
- lm_logits = self.lm_head(decoder_output, - self.language_model.embedding.word_embeddings.weight) - - if lm_labels is None: - return lm_logits, encoder_output - else: - if self.fp16_lm_cross_entropy: - assert lm_logits.dtype == torch.half - lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels) - else: - lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(), - lm_labels) - return lm_loss, encoder_output - - def state_dict_for_save_checkpoint(self, destination=None, prefix='', - keep_vars=False): - """For easy load when model is combined with other heads, - add an extra key.""" - - state_dict_ = {} - state_dict_[self._language_model_key] \ - = self.language_model.state_dict_for_save_checkpoint( - destination, prefix, keep_vars) - state_dict_[self._lm_head_key] \ - = self.lm_head.state_dict_for_save_checkpoint( - destination, prefix, keep_vars) - return state_dict_ - - def load_state_dict(self, state_dict, strict=True): - """Customized load.""" - - self.language_model.load_state_dict( - state_dict[self._language_model_key], strict=strict) - self.lm_head.load_state_dict(state_dict[self._lm_head_key], - strict=strict) -- Gitee From 11a592dc36344c08ca6fbacae147cca4f0d1da14 Mon Sep 17 00:00:00 2001 From: fengliangjun Date: Tue, 6 Jun 2023 12:33:38 +0000 Subject: [PATCH 10/10] =?UTF-8?q?=E5=88=A0=E9=99=A4=E6=96=87=E4=BB=B6=20ru?= =?UTF-8?q?nner.py?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- runner.py | 16 ---------------- 1 file changed, 16 deletions(-) delete mode 100644 runner.py diff --git a/runner.py b/runner.py deleted file mode 100644 index 6c1ff6a7e..000000000 --- a/runner.py +++ /dev/null @@ -1,16 +0,0 @@ -import torch -import torch_npu -from torch_npu.contrib import transfer_to_npu -import bugfix - -from pretrain_t5 import pretrain, train_valid_test_datasets_provider, model_provider, forward_step -from megatron.model import T5Model, ModelType - -option = {} -option["ACL_OP_COMPILER_CACHE_MODE"] = "enable" -option["ACL_OP_COMPILER_CACHE_DIR"] = "./cache" -print("option:",option) -torch.npu.set_option(option) - -pretrain(train_valid_test_datasets_provider, model_provider, ModelType.encoder_and_decoder, - forward_step, args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'}) -- Gitee
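
For reference, a minimal sketch of how the DISPATCH_HALF_AND_BFLOAT macro from the deleted type_shim.h is typically consumed: the macro switches on a tensor's runtime dtype and defines scalar_t (at::Half or at::BFloat16) around the caller-supplied code, which is how the deleted softmax bindings earlier in this series invoke their templated launchers. The names scale_elements and scaled_copy below are hypothetical stand-ins, not symbols from this repository, and the serial loop stands in for a real kernel launch.

    #include <ATen/ATen.h>
    #include <torch/extension.h>
    #include "type_shim.h"  // provides DISPATCH_HALF_AND_BFLOAT (removed by this patch series)

    // Hypothetical templated "launcher"; real code would launch a device kernel here.
    template <typename scalar_t>
    void scale_elements(scalar_t* dst, const scalar_t* src, float scale, int64_t n) {
      for (int64_t i = 0; i < n; ++i) {
        dst[i] = static_cast<scalar_t>(static_cast<float>(src[i]) * scale);
      }
    }

    torch::Tensor scaled_copy(torch::Tensor const& input, float scale) {
      auto output = torch::empty_like(input);
      // Expands to a switch over Half/BFloat16; scalar_t is bound per case
      // before the code passed as __VA_ARGS__ is emitted.
      DISPATCH_HALF_AND_BFLOAT(
          input.scalar_type(),
          "scaled_copy",
          scale_elements<scalar_t>(
              output.data_ptr<scalar_t>(),
              input.data_ptr<scalar_t>(),
              scale,
              input.numel());
      );
      return output;
    }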