diff --git a/env_npu.sh b/env_npu.sh
new file mode 100644
index 0000000000000000000000000000000000000000..5a653b390739e768051757f7645aaa82c754e29a
--- /dev/null
+++ b/env_npu.sh
@@ -0,0 +1,77 @@
+#!/bin/bash
+CANN_INSTALL_PATH_CONF='/etc/Ascend/ascend_cann_install.info'
+
+if [ -f $CANN_INSTALL_PATH_CONF ]; then
+    CANN_INSTALL_PATH=$(cat $CANN_INSTALL_PATH_CONF | grep Install_Path | cut -d "=" -f 2)
+else
+    CANN_INSTALL_PATH="/usr/local/Ascend"
+fi
+
+#CANN_INSTALL_PATH=/home/j00648035/ascend/
+CANN_INSTALL_PATH=/home/wangyixian/Ascend/
+
+if [ -d ${CANN_INSTALL_PATH}/ascend-toolkit/latest ]; then
+    source ${CANN_INSTALL_PATH}/ascend-toolkit/set_env.sh
+else
+    source ${CANN_INSTALL_PATH}/nnae/set_env.sh
+fi
+
+echo $CANN_INSTALL_PATH
+
+# Set the device-side log level to error
+msnpureport -g error -d 0
+msnpureport -g error -d 1
+msnpureport -g error -d 2
+msnpureport -g error -d 3
+msnpureport -g error -d 4
+msnpureport -g error -d 5
+msnpureport -g error -d 6
+msnpureport -g error -d 7
+
+# Disable device-side event logging
+msnpureport -e disable
+
+# Output host logs to the serial port: 0 - disable / 1 - enable
+export ASCEND_SLOG_PRINT_TO_STDOUT=0
+# Default log level: 0 - debug / 1 - info / 2 - warning / 3 - error
+export ASCEND_GLOBAL_LOG_LEVEL=3
+# Event log switch: 0 - disable / 1 - enable
+export ASCEND_GLOBAL_EVENT_ENABLE=0
+# Whether to enable taskque: 0 - disable / 1 - enable
+export TASK_QUEUE_ENABLE=1
+# Whether to enable PTCopy: 0 - disable / 1 - enable
+export PTCOPY_ENABLE=1
+# Whether to enable the combined flag for two non-contiguous tensors: 0 - disable / 1 - enable
+export COMBINED_ENABLE=0
+# Whether special scenarios need recompilation; no modification required
+export DYNAMIC_OP="ADD#MUL"
+# HCCL whitelist switch: 1 - disable / 0 - enable
+export HCCL_WHITELIST_DISABLE=1
+# Set the HCCL connection timeout
+export HCCL_CONNECT_TIMEOUT=1200
+ulimit -SHn 512000
+
+path_lib=$(python3.7 -c """
+import sys
+import re
+result=''
+for index in range(len(sys.path)):
+    match_sit = re.search('-packages', sys.path[index])
+    if match_sit is not None:
+        match_lib = re.search('lib', sys.path[index])
+
+        if match_lib is not None:
+            end=match_lib.span()[1]
+            result += sys.path[index][0:end] + ':'
+
+        result+=sys.path[index] + '/torch/lib:'
+print(result)"""
+)
+
+echo ${path_lib}
+
+export LD_LIBRARY_PATH=/usr/local/python3.7.5/lib/:${path_lib}:$LD_LIBRARY_PATH
+# solve libisl
+export LD_LIBRARY_PATH=/usr/local/lib/:$LD_LIBRARY_PATH
+export HCCL_WHITELIST_DISABLE=1
+#export HCCL_IF_IP=$(hostname -I |awk '{print $1}')
diff --git a/megatron/fused_kernels/__init__.py b/megatron/fused_kernels/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7646ddb1a6197eb0eea0844625ef17346519fbe1
--- /dev/null
+++ b/megatron/fused_kernels/__init__.py
@@ -0,0 +1,118 @@
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+
+import os
+import pathlib
+import subprocess
+
+from torch.utils import cpp_extension
+
+# Setting this param to a list has a problem of generating different
+# compilation commands (with different order of architectures) and
+# leading to recompilation of fused kernels. Set it to empty string
+# to avoid recompilation and assign arch flags explicitly in
+# extra_cuda_cflags below
+os.environ["TORCH_CUDA_ARCH_LIST"] = ""
+
+
+def load(args):
+
+    # Check if cuda 11 is installed for compute capability 8.0
+    cc_flag = []
+    _, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version(
+        cpp_extension.CUDA_HOME)
+    if int(bare_metal_major) >= 11:
+        cc_flag.append('-gencode')
+        cc_flag.append('arch=compute_80,code=sm_80')
+        if int(bare_metal_minor) >= 7:
+            cc_flag.append('-gencode')
+            cc_flag.append('arch=compute_90,code=sm_90')
+
+    # Build path
+    srcpath = pathlib.Path(__file__).parent.absolute()
+    buildpath = srcpath / 'build'
+    _create_build_dir(buildpath)
+
+    # Helper function to build the kernels.
+    def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
+        return cpp_extension.load(
+            name=name,
+            sources=sources,
+            build_directory=buildpath,
+            extra_cflags=['-O3',],
+            extra_cuda_cflags=['-O3',
+                               '-gencode', 'arch=compute_70,code=sm_70',
+                               '--use_fast_math'] + extra_cuda_flags + cc_flag,
+            verbose=(args.rank == 0)
+        )
+
+    # ==============
+    # Fused softmax.
+    # ==============
+
+    if args.masked_softmax_fusion:
+        extra_cuda_flags = ['-U__CUDA_NO_HALF_OPERATORS__',
+                            '-U__CUDA_NO_HALF_CONVERSIONS__',
+                            '--expt-relaxed-constexpr',
+                            '--expt-extended-lambda']
+
+        # Upper triangular softmax.
+        sources=[srcpath / 'scaled_upper_triang_masked_softmax.cpp',
+                 srcpath / 'scaled_upper_triang_masked_softmax_cuda.cu']
+        scaled_upper_triang_masked_softmax_cuda = _cpp_extention_load_helper(
+            "scaled_upper_triang_masked_softmax_cuda",
+            sources, extra_cuda_flags)
+
+        # Masked softmax.
+        sources=[srcpath / 'scaled_masked_softmax.cpp',
+                 srcpath / 'scaled_masked_softmax_cuda.cu']
+        scaled_masked_softmax_cuda = _cpp_extention_load_helper(
+            "scaled_masked_softmax_cuda", sources, extra_cuda_flags)
+
+        # Softmax
+        sources=[srcpath / 'scaled_softmax.cpp',
+                 srcpath / 'scaled_softmax_cuda.cu']
+        scaled_softmax_cuda = _cpp_extention_load_helper(
+            "scaled_softmax_cuda", sources, extra_cuda_flags)
+
+    # =================================
+    # Mixed precision fused layer norm.
+    # =================================
+
+    extra_hopper_flags = ['-U__CUDA_NO_HALF_OPERATORS__',
+                          '-U__CUDA_NO_HALF_CONVERSIONS__']
+
+    extra_cuda_flags = ['-maxrregcount=50']
+    sources=[srcpath / 'layer_norm_cuda.cpp',
+             srcpath / 'layer_norm_cuda_kernel.cu']
+    fused_mix_prec_layer_norm_cuda = _cpp_extention_load_helper(
+        "fused_mix_prec_layer_norm_cuda", sources, extra_cuda_flags + extra_hopper_flags)
+
+    # =================================
+    # Fused gradient accumulation to weight gradient computation of linear layer
+    # =================================
+
+    if args.gradient_accumulation_fusion:
+        sources=[srcpath / 'fused_weight_gradient_dense.cpp',
+                 srcpath / 'fused_weight_gradient_dense.cu']
+        fused_dense_cuda = _cpp_extention_load_helper(
+            "fused_dense_cuda", sources, extra_hopper_flags)
+
+
+def _get_cuda_bare_metal_version(cuda_dir):
+    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
+                                         universal_newlines=True)
+    output = raw_output.split()
+    release_idx = output.index("release") + 1
+    release = output[release_idx].split(".")
+    bare_metal_major = release[0]
+    bare_metal_minor = release[1][0]
+
+    return raw_output, bare_metal_major, bare_metal_minor
+
+
+def _create_build_dir(buildpath):
+    try:
+        os.mkdir(buildpath)
+    except OSError:
+        if not os.path.isdir(buildpath):
+            print(f"Creation of the build directory {buildpath} failed")
diff --git a/megatron/fused_kernels/compat.h b/megatron/fused_kernels/compat.h
new file mode 100644
index 0000000000000000000000000000000000000000..5495d7807762d8b4e3dbc11b28dba15f85bd8108
--- /dev/null
+++ b/megatron/fused_kernels/compat.h
@@ -0,0 +1,17 @@
+/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */
+
+/* This code is copied from NVIDIA apex:
+ * https://github.com/NVIDIA/apex
+ * with minor changes.
*/ + + + +#ifndef TORCH_CHECK +#define TORCH_CHECK AT_CHECK +#endif + +#ifdef VERSION_GE_1_3 +#define DATA_PTR data_ptr +#else +#define DATA_PTR data +#endif diff --git a/megatron/fused_kernels/fused_weight_gradient_dense.cpp b/megatron/fused_kernels/fused_weight_gradient_dense.cpp new file mode 100644 index 0000000000000000000000000000000000000000..194ee59353d2a8c9da24e50c592f4e086806d078 --- /dev/null +++ b/megatron/fused_kernels/fused_weight_gradient_dense.cpp @@ -0,0 +1,47 @@ +#include +#include + +#include +#include + +#include "type_shim.h" + + +template +int wgrad_gemm_accum_fp32_cuda(T *input, T *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim); + +void wgrad_gemm_accum_fp32(const at::Tensor input, const at::Tensor d_output, at::Tensor d_weight) { + at::Tensor input_2d, d_output_2d; + // input tensor: collapse to the first dim + auto in_sizes = input.sizes(); + if (input.dim() > 2) { + input_2d = input.view({-1, in_sizes[in_sizes.size() - 1]}); + } else { + input_2d = input; + } + // d_output tensor: collapse to the first dim + auto d_out_sizes = d_output.sizes(); + if (d_output.dim() > 2) { + d_output_2d = d_output.view({-1, d_out_sizes[d_out_sizes.size() - 1]}); + } else { + d_output_2d = d_output; + } + + int hidden_dim = input_2d.size(0); + int in_dim = input_2d.size(1); + int out_dim = d_weight.size(0); + + DISPATCH_HALF_BFLOAT_AND_FLOAT(input_2d.scalar_type(), "wgrad_gemm_accum_fp32", + int result = wgrad_gemm_accum_fp32_cuda( + input_2d.data_ptr(), + d_output_2d.data_ptr(), + d_weight.data_ptr(), + in_dim, + hidden_dim, + out_dim); + ); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("wgrad_gemm_accum_fp32", &wgrad_gemm_accum_fp32, "wgrad gemm accum in fp32"); +} diff --git a/megatron/fused_kernels/fused_weight_gradient_dense.cu b/megatron/fused_kernels/fused_weight_gradient_dense.cu new file mode 100644 index 0000000000000000000000000000000000000000..7dc10e65d37e531b54d2c875011f9c47eeb8ff7f --- /dev/null +++ b/megatron/fused_kernels/fused_weight_gradient_dense.cu @@ -0,0 +1,157 @@ +#include +#include +#include +#include +#include +#include +#include + +/* Includes, cuda */ +#include +#include + + +// BF16 Tensor core wrapper around cublas GEMMEx +cublasStatus_t gemmex_wrapper( + cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const float* alpha, + at::BFloat16* A, + int lda, + at::BFloat16* B, + int ldb, + const float* beta, + float* C, + int ldc) { + return cublasGemmEx( + handle, + transa, + transb, + m, + n, + k, + alpha, + A, + CUDA_R_16BF, + lda, + B, + CUDA_R_16BF, + ldb, + beta, + C, + CUDA_R_32F, + ldc, + CUDA_R_32F, + CUBLAS_GEMM_DEFAULT_TENSOR_OP); +} + +// FP16 Tensor core wrapper around cublas GEMMEx +cublasStatus_t gemmex_wrapper( + cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const float* alpha, + at::Half* A, + int lda, + at::Half* B, + int ldb, + const float* beta, + float* C, + int ldc) { + return cublasGemmEx( + handle, + transa, + transb, + m, + n, + k, + alpha, + A, + CUDA_R_16F, + lda, + B, + CUDA_R_16F, + ldb, + beta, + C, + CUDA_R_32F, + ldc, + CUDA_R_32F, + CUBLAS_GEMM_DEFAULT_TENSOR_OP); +} + +// FP32 Tensor core wrapper around cublas GEMMEx +cublasStatus_t gemmex_wrapper( + cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const float* alpha, + float* A, + int lda, + float* B, + int ldb, + const float* beta, + float* C, + int ldc) { + return 
cublasGemmEx( + handle, + transa, + transb, + m, + n, + k, + alpha, + A, + CUDA_R_32F, + lda, + B, + CUDA_R_32F, + ldb, + beta, + C, + CUDA_R_32F, + ldc, + CUDA_R_32F, + CUBLAS_GEMM_DEFAULT_TENSOR_OP); +} + +template +int wgrad_gemm_accum_fp32_cuda(T *input, T *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim) { + cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); + cudaStream_t stream; + cublasGetStream(handle, &stream); + const float alpha = 1.0; + const float beta = 1.0; + int status = 1; + + status = gemmex_wrapper( + handle, + CUBLAS_OP_N, + CUBLAS_OP_T, + in_dim, + out_dim, + hidden_dim, + &alpha, + input, + in_dim, + d_output, + out_dim, + &beta, + d_weight, + in_dim); + return status; +} + +template int wgrad_gemm_accum_fp32_cuda(at::Half *input, at::Half *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim); +template int wgrad_gemm_accum_fp32_cuda(at::BFloat16 *input, at::BFloat16 *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim); +template int wgrad_gemm_accum_fp32_cuda(float *input, float *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim); diff --git a/megatron/fused_kernels/layer_norm_cuda.cpp b/megatron/fused_kernels/layer_norm_cuda.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f0925fcdd06738a8c3db864d91bde9c7d3012919 --- /dev/null +++ b/megatron/fused_kernels/layer_norm_cuda.cpp @@ -0,0 +1,187 @@ +/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */ + +/*This code is copied fron NVIDIA apex: + * https://github.com/NVIDIA/apex + * with minor changes. */ + +#include +#include +#include +#include "compat.h" + +namespace { + +void compute_n1_n2( + at::Tensor input, + at::IntArrayRef normalized_shape, + int& n1, + int& n2) { + int idiff = input.ndimension() - normalized_shape.size(); + n2 = 1; + for (int i = 0; i < (int)normalized_shape.size(); ++i) { + assert( input.sizes()[i+idiff] == normalized_shape[i] ); + n2 *= normalized_shape[i]; + } + n1 = 1; + for (int i = 0; i < idiff; ++i) { + n1 *= input.sizes()[i]; + } +} + +void check_args( + at::IntArrayRef normalized_shape, + at::Tensor gamma, + at::Tensor beta + ) +{ + TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape)); + TORCH_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape)); +} + +void check_args( + at::Tensor input, + at::IntArrayRef normalized_shape, + int& n1, + int& n2 + ) +{ + int64_t normalized_ndim = normalized_shape.size(); + + if (normalized_ndim < 1) { + std::stringstream ss; + ss << "Expected normalized_shape to be at least 1-dimensional, i.e., " + << "containing at least one element, but got normalized_shape=" + << normalized_shape; + throw std::runtime_error(ss.str()); + } + + auto input_shape = input.sizes(); + auto input_ndim = input.dim(); + + if (input_ndim < normalized_ndim || + !input_shape.slice(input_ndim - normalized_ndim).equals(normalized_shape)) { + std::stringstream ss; + ss << "Given normalized_shape=" << normalized_shape + << ", expected input with shape [*"; + for (auto size : normalized_shape) { + ss << ", " << size; + } + ss << "], but got input of size" << input_shape; + throw std::runtime_error(ss.str()); + } + + compute_n1_n2(input,normalized_shape,n1,n2); +} + + +void check_args( + at::Tensor input, + at::IntArrayRef normalized_shape, + at::Tensor gamma, + at::Tensor beta, + int& n1, + int& n2 + ) +{ + check_args(input,normalized_shape,n1,n2); + check_args(normalized_shape,gamma,beta); +} +} + +void cuda_layer_norm( + at::Tensor* output, 
+ at::Tensor* mean, + at::Tensor* invvar, + at::Tensor* input, + int n1, + int n2, + at::IntArrayRef normalized_shape, + at::Tensor* gamma, + at::Tensor* beta, + double epsilon); + +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +std::vector layer_norm_affine( + at::Tensor input, + at::IntArrayRef normalized_shape, + at::Tensor gamma, + at::Tensor beta, + double epsilon) { + + CHECK_INPUT(input); + CHECK_INPUT(gamma); + CHECK_INPUT(beta); + int n1, n2; + check_args(input, normalized_shape, gamma, beta, n1, n2); + + at::Tensor output = at::empty_like( + input, gamma.options().dtype(gamma.scalar_type())); + at::Tensor mean = at::empty( + {n1}, input.options().dtype(at::ScalarType::Float)); + at::Tensor invvar = at::empty_like(mean); + + cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2, + normalized_shape, &gamma, &beta, epsilon); + + return {output, mean, invvar}; + +} + + +void cuda_layer_norm_gradient( + at::Tensor* dout, + at::Tensor* mean, + at::Tensor* invvar, + at::Tensor* input, + int n1, + int n2, + at::IntArrayRef normalized_shape, + at::Tensor* gamma, + at::Tensor* beta, + double epsilon, + at::Tensor* grad_input, + at::Tensor* grad_gamma, + at::Tensor* grad_beta + ); + +std::vector layer_norm_gradient_affine( + at::Tensor dout, + at::Tensor mean, + at::Tensor invvar, + at::Tensor input, + at::IntArrayRef normalized_shape, + at::Tensor gamma, + at::Tensor beta, + double epsilon) { + + CHECK_INPUT(dout); + CHECK_INPUT(mean); + CHECK_INPUT(invvar); + CHECK_INPUT(input); + CHECK_INPUT(gamma); + CHECK_INPUT(beta); + int n1, n2; + check_args(input, normalized_shape, gamma, beta, n1, n2); + + at::Tensor grad_input = at::empty_like(input); + at::Tensor grad_gamma = at::empty_like(gamma); + at::Tensor grad_beta = at::empty_like(beta); + + cuda_layer_norm_gradient(&dout, &mean, &invvar, &input, n1, n2, + normalized_shape, &gamma, &beta, epsilon, + &grad_input, &grad_gamma, &grad_beta); + + return {grad_input, grad_gamma, grad_beta}; + +} + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward_affine", &layer_norm_affine, + "LayerNorm forward (CUDA)"); + m.def("backward_affine", &layer_norm_gradient_affine, + "LayerNorm backward (CUDA)"); +} diff --git a/megatron/fused_kernels/layer_norm_cuda_kernel.cu b/megatron/fused_kernels/layer_norm_cuda_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..30b376501a8b8e6e45f098f8606e3004e5d4c69b --- /dev/null +++ b/megatron/fused_kernels/layer_norm_cuda_kernel.cu @@ -0,0 +1,818 @@ +/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */ + +/*This code is copied fron NVIDIA apex: + * https://github.com/NVIDIA/apex + * with minor changes. 
*/ + +#include "ATen/ATen.h" +#include "ATen/AccumulateType.h" +#include "ATen/cuda/CUDAContext.h" +#include "ATen/cuda/DeviceUtils.cuh" + +#include +#include + +#include "type_shim.h" + +template __device__ +void cuWelfordOnlineSum( + const U curr, + U& mu, + U& sigma2, + U& count) +{ + count = count + U(1); + U delta = curr - mu; + U lmean = mu + delta / count; + mu = lmean; + U delta2 = curr - lmean; + sigma2 = sigma2 + delta * delta2; +} + +template __device__ +void cuChanOnlineSum( + const U muB, + const U sigma2B, + const U countB, + U& mu, + U& sigma2, + U& count) +{ + U delta = muB - mu; + U nA = count; + U nB = countB; + count = count + countB; + U nX = count; + if (nX > U(0)) { + nA = nA / nX; + nB = nB / nX; + mu = nA*mu + nB*muB; + sigma2 = sigma2 + sigma2B + delta * delta * nA * nB * nX; + } else { + mu = U(0); + sigma2 = U(0); + } +} + +template __device__ +void cuWelfordMuSigma2( + const T* __restrict__ vals, + const int n1, + const int n2, + const int i1, + U& mu, + U& sigma2, + U* buf) +{ + // Assumptions: + // 1) blockDim.x == warpSize + // 2) Tensor is contiguous + // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available. + // + // compute variance and mean over n2 + U count = U(0); + mu= U(0); + sigma2 = U(0); + if (i1 < n1) { + // one warp normalizes one n1 index, + // synchronization is implicit + // initialize with standard Welford algorithm + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + const T* lvals = vals + i1*n2; + int l = 4*thrx; + for (; l+3 < n2; l+=4*numx) { + for (int k = 0; k < 4; ++k) { + U curr = static_cast(lvals[l+k]); + cuWelfordOnlineSum(curr,mu,sigma2,count); + } + } + for (; l < n2; ++l) { + U curr = static_cast(lvals[l]); + cuWelfordOnlineSum(curr,mu,sigma2,count); + } + // intra-warp reductions + for (int l = 0; l <= 4; ++l) { + int srcLaneB = (threadIdx.x+(1<(muB,sigma2B,countB,mu,sigma2,count); + } + // threadIdx.x == 0 has correct values for each warp + // inter-warp reductions + if (blockDim.y > 1) { + U* ubuf = (U*)buf; + U* ibuf = (U*)(ubuf + blockDim.y); + for (int offset = blockDim.y/2; offset > 0; offset /= 2) { + // upper half of warps write to shared + if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) { + const int wrt_y = threadIdx.y - offset; + ubuf[2*wrt_y] = mu; + ubuf[2*wrt_y+1] = sigma2; + ibuf[wrt_y] = count; + } + __syncthreads(); + // lower half merges + if (threadIdx.x == 0 && threadIdx.y < offset) { + U muB = ubuf[2*threadIdx.y]; + U sigma2B = ubuf[2*threadIdx.y+1]; + U countB = ibuf[threadIdx.y]; + cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count); + } + __syncthreads(); + } + // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values + if (threadIdx.x == 0 && threadIdx.y == 0) { + ubuf[0] = mu; + ubuf[1] = sigma2; + } + __syncthreads(); + mu = ubuf[0]; + sigma2 = ubuf[1]/U(n2); + // don't care about final value of count, we know count == n2 + } else { + mu = WARP_SHFL(mu, 0); + sigma2 = WARP_SHFL(sigma2/U(n2), 0); + } + } +} + +template<> __device__ +void cuWelfordMuSigma2( + const at::Half* __restrict__ vals, + const int n1, + const int n2, + const int i1, + float& mu, + float& sigma2, + float* buf) +{ + // Assumptions: + // 1) blockDim.x == warpSize + // 2) Tensor is contiguous + // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available. 
+ // + // compute variance and mean over n2 + float count = 0.0f; + mu= float(0); + sigma2 = float(0); + if (i1 < n1) { + // one warp normalizes one n1 index, + // synchronization is implicit + // initialize with standard Welford algorithm + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + const at::Half* lvals = vals + i1*n2; + int l = 8*thrx; + if ((((size_t)lvals)&3) != 0) { + // 16 bit alignment + // first thread consumes first point + if (thrx == 0) { + float curr = static_cast(lvals[0]); + cuWelfordOnlineSum(curr,mu,sigma2,count); + } + ++l; + } + // at this point, lvals[l] are 32 bit aligned for all threads. + for (; l+7 < n2; l+=8*numx) { + for (int k = 0; k < 8; k+=2) { + float2 curr = __half22float2(*((__half2*)(lvals+l+k))); + cuWelfordOnlineSum(curr.x,mu,sigma2,count); + cuWelfordOnlineSum(curr.y,mu,sigma2,count); + } + } + for (; l < n2; ++l) { + float curr = static_cast(lvals[l]); + cuWelfordOnlineSum(curr,mu,sigma2,count); + } + // intra-warp reductions + for (int l = 0; l <= 4; ++l) { + int srcLaneB = (threadIdx.x+(1< 1) { + float* ubuf = (float*)buf; + float* ibuf = (float*)(ubuf + blockDim.y); + for (int offset = blockDim.y/2; offset > 0; offset /= 2) { + // upper half of warps write to shared + if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) { + const int wrt_y = threadIdx.y - offset; + ubuf[2*wrt_y] = mu; + ubuf[2*wrt_y+1] = sigma2; + ibuf[wrt_y] = count; + } + __syncthreads(); + // lower half merges + if (threadIdx.x == 0 && threadIdx.y < offset) { + float muB = ubuf[2*threadIdx.y]; + float sigma2B = ubuf[2*threadIdx.y+1]; + float countB = ibuf[threadIdx.y]; + cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count); + } + __syncthreads(); + } + // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values + if (threadIdx.x == 0 && threadIdx.y == 0) { + ubuf[0] = mu; + ubuf[1] = sigma2; + } + __syncthreads(); + mu = ubuf[0]; + sigma2 = ubuf[1]/float(n2); + // don't care about final value of count, we know count == n2 + } else { + mu = WARP_SHFL(mu, 0); + sigma2 = WARP_SHFL(sigma2/float(n2), 0); + } + } +} + +template U rsqrt(U v) { + return U(1) / sqrt(v); +} +template<> float rsqrt(float v) { + return rsqrtf(v); +} +template<> double rsqrt(double v) { + return rsqrt(v); +} + +namespace { +// This is the un-specialized struct. Note that we prevent instantiation of this +// struct by putting an undefined symbol in the function body so it won't compile. 
+// template +// struct SharedMemory +// { +// // Ensure that we won't compile any un-specialized types +// __device__ T *getPointer() +// { +// extern __device__ void error(void); +// error(); +// return NULL; +// } +// }; +// https://github.com/NVIDIA/apex/issues/246 +template +struct SharedMemory; + +template <> +struct SharedMemory +{ + __device__ float *getPointer() + { + extern __shared__ float s_float[]; + return s_float; + } +}; + +} + +template __global__ +void cuApplyLayerNorm( + V* __restrict__ output_vals, + U* __restrict__ mean, + U* __restrict__ invvar, + const T* __restrict__ vals, + const int n1, + const int n2, + const U epsilon, + const V* __restrict__ gamma, + const V* __restrict__ beta + ) +{ + // Assumptions: + // 1) blockDim.x == warpSize + // 2) Tensors are contiguous + // + for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) { + SharedMemory shared; + U* buf = shared.getPointer(); + U mu,sigma2; + cuWelfordMuSigma2(vals,n1,n2,i1,mu,sigma2,buf); + const T* lvals = vals + i1*n2; + V* ovals = output_vals + i1*n2; + U c_invvar = rsqrt(sigma2 + epsilon); + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + if (gamma != NULL && beta != NULL) { + for (int i = thrx; i < n2; i+=numx) { + U curr = static_cast(lvals[i]); + ovals[i] = gamma[i] * static_cast(c_invvar * (curr - mu)) + beta[i]; + } + } else { + for (int i = thrx; i < n2; i+=numx) { + U curr = static_cast(lvals[i]); + ovals[i] = static_cast(c_invvar * (curr - mu)); + } + } + if (threadIdx.x == 0 && threadIdx.y == 0) { + mean[i1] = mu; + invvar[i1] = c_invvar; + } + __syncthreads(); + } +} + +template __device__ +void cuLoadWriteStridedInputs( + const int i1_block, + const int thr_load_row_off, + const int thr_load_col_off, + const int i2_off, + const int row_stride, + U* warp_buf1, + U* warp_buf2, + const T* input, + const V* dout, + const int i1_end, + const int n2, + const U* __restrict__ mean, + const U* __restrict__ invvar + ) +{ + int i1 = i1_block+thr_load_row_off; + if (i1 < i1_end) { + U curr_mean = mean[i1]; + U curr_invvar = invvar[i1]; + for (int k = 0; k < blockDim.y; ++k) { + int i2 = i2_off + k; + int load_idx = i1*n2+i2; + int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k; + if (i2(input[load_idx]); + U curr_dout = static_cast(dout[load_idx]); + warp_buf1[write_idx] = curr_dout; + warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_invvar; + } else { + warp_buf1[write_idx] = U(0); + warp_buf2[write_idx] = U(0); + } + } + } else { + for (int k = 0; k < blockDim.y; ++k) { + int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k; + warp_buf1[write_idx] = U(0); + warp_buf2[write_idx] = U(0); + } + } +} + +template __device__ +void cuLoadAddStridedInputs( + const int i1_block, + const int thr_load_row_off, + const int thr_load_col_off, + const int i2_off, + const int row_stride, + U* warp_buf1, + U* warp_buf2, + const T* input, + const V* dout, + const int i1_end, + const int n2, + const U* __restrict__ mean, + const U* __restrict__ invvar + ) +{ + int i1 = i1_block+thr_load_row_off; + if (i1 < i1_end) { + U curr_mean = mean[i1]; + U curr_invvar = invvar[i1]; + for (int k = 0; k < blockDim.y; ++k) { + int i2 = i2_off + k; + int load_idx = i1*n2+i2; + int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k; + if (i2(input[load_idx]); + U curr_dout = static_cast(dout[load_idx]); + warp_buf1[write_idx] += curr_dout; + warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar; + } + } + } +} + 
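+// cuComputePartGradGammaBeta (below) computes partial reductions of the gamma and
+// beta gradients: each block accumulates, over its strip of rows and tile of
+// columns, sums of dout (for grad_beta) and dout * (x - mean) * invvar (for
+// grad_gamma) into part_grad_beta / part_grad_gamma. cuComputeGradGammaBeta then
+// reduces these per-block partials into the final grad_gamma and grad_beta.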
+template __global__ +void cuComputePartGradGammaBeta( + const V* __restrict__ dout, + const T* __restrict__ input, + const int n1, + const int n2, + const U* __restrict__ mean, + const U* __restrict__ invvar, + U epsilon, + U* part_grad_gamma, + U* part_grad_beta) +{ + const int numsegs_n1 = (n1+blockDim.y*blockDim.y-1) / (blockDim.y*blockDim.y); + const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y; + const int i1_beg = blockIdx.y * segs_per_block * blockDim.y*blockDim.y; + const int i1_beg_plus_one = (blockIdx.y+1) * segs_per_block * blockDim.y*blockDim.y; + const int i1_end = i1_beg_plus_one < n1 ? i1_beg_plus_one : n1; + const int row_stride = blockDim.x+1; + const int thr_load_col_off = (threadIdx.x*blockDim.y)&(blockDim.x-1); + const int thr_load_row_off = (threadIdx.x*blockDim.y)/blockDim.x + threadIdx.y*blockDim.y; + const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off; + SharedMemory shared; + U* buf = shared.getPointer(); // buf has at least blockDim.x * blockDim.y * blockDim.y + (blockDim.y - 1)*(blockDim.x/blockDim.y) elements + U* warp_buf1 = (U*)buf; + U* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride; + // compute partial sums from strided inputs + // do this to increase number of loads in flight + cuLoadWriteStridedInputs(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar); + for (int i1_block = i1_beg+blockDim.y*blockDim.y; i1_block < i1_end; i1_block+=blockDim.y*blockDim.y) { + cuLoadAddStridedInputs(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar); + } + __syncthreads(); + // inter-warp reductions + // sum within each warp + U acc1 = U(0); + U acc2 = U(0); + for (int k = 0; k < blockDim.y; ++k) { + int row1 = threadIdx.y + k*blockDim.y; + int idx1 = row1*row_stride + threadIdx.x; + acc1 += warp_buf1[idx1]; + acc2 += warp_buf2[idx1]; + } + warp_buf1[threadIdx.y*row_stride+threadIdx.x] = acc1; + warp_buf2[threadIdx.y*row_stride+threadIdx.x] = acc2; + __syncthreads(); + // sum all warps + for (int offset = blockDim.y/2; offset > 1; offset /= 2) { + if (threadIdx.y < offset) { + int row1 = threadIdx.y; + int row2 = threadIdx.y + offset; + int idx1 = row1*row_stride + threadIdx.x; + int idx2 = row2*row_stride + threadIdx.x; + warp_buf1[idx1] += warp_buf1[idx2]; + warp_buf2[idx1] += warp_buf2[idx2]; + } + __syncthreads(); + } + int i2 = blockIdx.x * blockDim.x + threadIdx.x; + if (threadIdx.y == 0 && i2 < n2) { + int row1 = threadIdx.y; + int row2 = threadIdx.y + 1; + int idx1 = row1*row_stride + threadIdx.x; + int idx2 = row2*row_stride + threadIdx.x; + part_grad_beta[blockIdx.y*n2+i2] = warp_buf1[idx1] + warp_buf1[idx2]; + part_grad_gamma[blockIdx.y*n2+i2] = warp_buf2[idx1] + warp_buf2[idx2]; + } +} + +template __global__ +void cuComputeGradGammaBeta( + const U* part_grad_gamma, + const U* part_grad_beta, + const int part_size, + const int n1, + const int n2, + V* grad_gamma, + V* grad_beta) +{ + // sum partial gradients for gamma and beta + SharedMemory shared; + U* buf = shared.getPointer(); + int i2 = blockIdx.x * blockDim.x + threadIdx.x; + if (i2 < n2) { + // each warp does sequential reductions until reduced part_size is num_warps + int num_warp_reductions = part_size / blockDim.y; + U sum_gamma = U(0); + U sum_beta = U(0); + const U* part_grad_gamma_ptr = part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2; + const U* part_grad_beta_ptr = part_grad_beta + threadIdx.y * num_warp_reductions * n2 
+ i2; + for (int warp_offset = 0; warp_offset < num_warp_reductions; ++warp_offset) { + sum_gamma += part_grad_gamma_ptr[warp_offset*n2]; + sum_beta += part_grad_beta_ptr[warp_offset*n2]; + } + // inter-warp reductions + const int nbsize3 = blockDim.x * blockDim.y / 2; + for (int offset = blockDim.y/2; offset >= 1; offset /= 2) { + // top half write to shared memory + if (threadIdx.y >= offset && threadIdx.y < 2*offset) { + const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x; + buf[write_idx] = sum_gamma; + buf[write_idx+nbsize3] = sum_beta; + } + __syncthreads(); + // bottom half sums + if (threadIdx.y < offset) { + const int read_idx = threadIdx.y * blockDim.x + threadIdx.x; + sum_gamma += buf[read_idx]; + sum_beta += buf[read_idx+nbsize3]; + } + __syncthreads(); + } + // write out fully summed gradients + if (threadIdx.y == 0) { + grad_gamma[i2] = sum_gamma; + grad_beta[i2] = sum_beta; + } + } +} + +template __global__ +void cuComputeGradInput( + const V* __restrict__ dout, + const T* __restrict__ input, + const int n1, + const int n2, + const U* __restrict__ mean, + const U* __restrict__ invvar, + U epsilon, + const V* gamma, + T* grad_input) +{ + for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) { + U sum_loss1 = U(0); + U sum_loss2 = U(0); + const U c_mean = mean[i1]; + const U c_invvar = invvar[i1]; + const T* k_input = input + i1*n2; + const V* k_dout = dout + i1*n2; + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + if (gamma != NULL) { + int l = 4*thrx; + for (; l+3 < n2; l+=4*numx) { + for (int k = 0; k < 4; ++k) { + const U c_h = static_cast(k_input[l+k]); + const U c_loss = static_cast(k_dout[l+k]); + sum_loss1 += c_loss * gamma[l+k]; + sum_loss2 += c_loss * gamma[l+k] * (c_h - c_mean) * c_invvar; + } + } + for (; l < n2; ++l) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + sum_loss1 += c_loss * gamma[l]; + sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar; + } + } else { + int l = 4*thrx; + for (; l+3 < n2; l+=4*numx) { + for (int k = 0; k < 4; ++k) { + const U c_h = static_cast(k_input[l+k]); + const U c_loss = static_cast(k_dout[l+k]); + sum_loss1 += c_loss; + sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; + } + } + for (; l < n2; ++l) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + sum_loss1 += c_loss; + sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; + } + } + // intra-warp reductions + for (int mask = blockDim.x/2; mask > 0; mask /= 2) { + sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask); + sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask); + } + // inter-warp reductions + if (blockDim.y > 1) { + SharedMemory shared; + U* buf = shared.getPointer(); + for (int offset = blockDim.y/2; offset > 0; offset /= 2) { + // upper half of warps write to shared + if (threadIdx.y >= offset && threadIdx.y < 2*offset) { + const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x; + buf[2*wrt_i] = sum_loss1; + buf[2*wrt_i+1] = sum_loss2; + } + __syncthreads(); + // lower half merges + if (threadIdx.y < offset) { + const int read_i = threadIdx.y * blockDim.x + threadIdx.x; + sum_loss1 += buf[2*read_i]; + sum_loss2 += buf[2*read_i+1]; + } + __syncthreads(); + } + if (threadIdx.y == 0) { + buf[2*threadIdx.x] = sum_loss1; + buf[2*threadIdx.x+1] = sum_loss2; + } + __syncthreads(); + if (threadIdx.y !=0) { + sum_loss1 = buf[2*threadIdx.x]; + sum_loss2 = buf[2*threadIdx.x+1]; + } + } + // all threads now have the two sums over l + 
U fH = (U)n2; + U term1 = (U(1) / fH) * c_invvar; + T* k_grad_input = grad_input + i1*n2; + if (gamma != NULL) { + for (int l = thrx; l < n2; l+=numx) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + U f_grad_input = fH * c_loss * gamma[l]; + f_grad_input -= sum_loss1; + f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; + f_grad_input *= term1; + k_grad_input[l] = static_cast(f_grad_input); + } + } else { + for (int l = thrx; l < n2; l+=numx) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + U f_grad_input = fH * c_loss; + f_grad_input -= sum_loss1; + f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; + f_grad_input *= term1; + k_grad_input[l] = static_cast(f_grad_input); + } + } + // prevent race where buf is written again before reads are done + __syncthreads(); + } +} + + + + +template +void HostApplyLayerNorm( + V* output, + U* mean, + U* invvar, + const T* input, + int n1, + int n2, + double epsilon, + const V* gamma, + const V* beta + ) +{ + auto stream = at::cuda::getCurrentCUDAStream().stream(); + const dim3 threads(32,4,1); + const uint64_t maxGridY = + at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; + const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1); + int nshared = + threads.y > 1 ? + threads.y*sizeof(U)+(threads.y/2)*sizeof(U) : + 0; + cuApplyLayerNorm<<>>( + output, + mean, + invvar, + input, + n1,n2, + U(epsilon), + gamma,beta); +} + + +void cuda_layer_norm( + at::Tensor* output, + at::Tensor* mean, + at::Tensor* invvar, + at::Tensor* input, + int n1, + int n2, + #ifdef VERSION_GE_1_1 + at::IntArrayRef normalized_shape, + #else + at::IntList normalized_shape, + #endif + at::Tensor* gamma, + at::Tensor* beta, + double epsilon) +{ + using namespace at; + DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( + input->scalar_type(), output->scalar_type(), "cuda_layer_norm_kernel", + HostApplyLayerNorm( + output->DATA_PTR(), + mean->DATA_PTR(), + invvar->DATA_PTR(), + input->DATA_PTR(), + n1,n2, + epsilon, + gamma != NULL ? gamma->DATA_PTR() : NULL, + beta != NULL ? beta->DATA_PTR() : NULL); + ) +} + + +template +void HostLayerNormGradient( + const V* dout, + const U* mean, + const U* invvar, + at::Tensor* input, + int n1, + int n2, + const V* gamma, + const V* beta, + double epsilon, + T* grad_input, + V* grad_gamma, + V* grad_beta + ) +{ + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + if (gamma != NULL && beta != NULL) { + // compute grad_gamma(j) and grad_beta(j) + const int part_size = 16; + const dim3 threads2(32,4,1); + const dim3 blocks2((n2+threads2.x-1)/threads2.x,part_size,1); + const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * + (threads2.x + 1); + const int nshared2_b = threads2.x * threads2.y * sizeof(U); + const int nshared2 = nshared2_a > nshared2_b ? 
nshared2_a : nshared2_b; + at::Tensor part_grad_gamma = at::empty( + {part_size,n2}, input->options().dtype(at::ScalarType::Float)); + at::Tensor part_grad_beta = at::empty_like(part_grad_gamma); + cuComputePartGradGammaBeta<<>>( + dout, + input->DATA_PTR(), + n1,n2, + mean, + invvar, + U(epsilon), + part_grad_gamma.DATA_PTR(), + part_grad_beta.DATA_PTR()); + + const dim3 threads3(32,8,1); + const dim3 blocks3((n2+threads2.x-1)/threads2.x,1,1); + const int nshared3 = threads3.x * threads3.y * sizeof(U); + cuComputeGradGammaBeta<<>>( + part_grad_gamma.DATA_PTR(), + part_grad_beta.DATA_PTR(), + part_size, + n1,n2, + grad_gamma, + grad_beta); + } + + // compute grad_input + const uint64_t maxGridY = + at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; + const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1); + const dim3 threads1(32,4,1); + int nshared = + threads1.y > 1 ? + threads1.y*threads1.x*sizeof(U) : + 0; + cuComputeGradInput<<>>( + dout, + input->DATA_PTR(), + n1,n2, + mean, + invvar, + U(epsilon), + gamma, + grad_input); +} + + +void cuda_layer_norm_gradient( + at::Tensor* dout, + at::Tensor* mean, + at::Tensor* invvar, + at::Tensor* input, + int n1, + int n2, + #ifdef VERSION_GE_1_1 + at::IntArrayRef normalized_shape, + #else + at::IntList normalized_shape, + #endif + at::Tensor* gamma, + at::Tensor* beta, + double epsilon, + at::Tensor* grad_input, + at::Tensor* grad_gamma, + at::Tensor* grad_beta) +{ + using namespace at; + DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( + input->scalar_type(), gamma->scalar_type(), + "cuda_layer_norm_gradient_kernel", + HostLayerNormGradient( + dout->DATA_PTR(), + mean->DATA_PTR(), + invvar->DATA_PTR(), + input, + n1,n2, + // TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta + // if gamma Tensor is NULL on input. + gamma != NULL ? gamma->DATA_PTR() : NULL, + gamma != NULL ? beta->DATA_PTR() : NULL, + epsilon, + grad_input->DATA_PTR(), + gamma != NULL ? grad_gamma->DATA_PTR() : NULL, + gamma != NULL ? grad_beta->DATA_PTR() : NULL); + ) +} diff --git a/megatron/fused_kernels/scaled_masked_softmax.cpp b/megatron/fused_kernels/scaled_masked_softmax.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4c8a8c2ee39bcc0f9d04b23b2bf19032d8327e44 --- /dev/null +++ b/megatron/fused_kernels/scaled_masked_softmax.cpp @@ -0,0 +1,83 @@ +/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
*/ + +#include +#include +#include + +namespace multihead_attn { +namespace fused_softmax { +namespace scaled_masked_softmax { + +torch::Tensor fwd_cuda( + torch::Tensor const& input, + torch::Tensor const& mask, + float scale_factor); + +torch::Tensor bwd_cuda( + torch::Tensor const& output_grads, + torch::Tensor const& softmax_results, + float scale_factor); + +int get_batch_per_block_cuda( + int query_seq_len, + int key_seq_len, + int batches, + int attn_heads); + +torch::Tensor fwd( + torch::Tensor const& input, + torch::Tensor const& mask, + float scale_factor) { + AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); + AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || + (input.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + AT_ASSERTM(mask.dim() == 4, "expected 4D tensor"); + + return fwd_cuda(input, mask, scale_factor); +} + +torch::Tensor bwd( + torch::Tensor const& output_grads, + torch::Tensor const& softmax_results, + float scale_factor) { + + AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor"); + AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor"); + + AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || + (output_grads.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || + (softmax_results.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + + return bwd_cuda(output_grads, softmax_results, scale_factor); +} + +int get_batch_per_block( + int query_seq_len, + int key_seq_len, + int batches, + int attn_heads) { + return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads); +} + +} // end namespace scaled_masked_softmax +} // end namespace fused_softmax +} // end namespace multihead_attn + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", + &multihead_attn::fused_softmax::scaled_masked_softmax::fwd, + "Self Multihead Attention scaled, time masked softmax -- Forward."); + + m.def("backward", + &multihead_attn::fused_softmax::scaled_masked_softmax::bwd, + "Self Multihead Attention scaled, time masked softmax -- Backward."); + + m.def("get_batch_per_block", + &multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block, + "Return Batch per block size." + ); +} diff --git a/megatron/fused_kernels/scaled_masked_softmax.h b/megatron/fused_kernels/scaled_masked_softmax.h new file mode 100644 index 0000000000000000000000000000000000000000..21ebbd52284203a64f6c7acab82e36fdb6cd7f6f --- /dev/null +++ b/megatron/fused_kernels/scaled_masked_softmax.h @@ -0,0 +1,710 @@ +/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
*/ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace { + +template +__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); + +template <> +__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; } + +template <> +__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); } + +template <> +__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *dst = *src; } + +template <> +__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); } + +template <> +__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) { *dst = *src; } + +template <> +__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); } + +int log2_ceil(int value) { + int log2_value = 0; + while ((1 << log2_value) < value) ++log2_value; + return log2_value; +} + +template +struct Add { + __device__ __forceinline__ T operator()(T a, T b) const { + return a + b; + } +}; + +template +struct Max { + __device__ __forceinline__ T operator()(T a, T b) const { + return a < b ? b : a; + } +}; + +template +__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) +{ +#if CUDA_VERSION >= 9000 + return __shfl_xor_sync(mask, value, laneMask, width); +#else + return __shfl_xor(value, laneMask, width); +#endif +} + +template class ReduceOp> +__device__ __forceinline__ void warp_reduce(acc_t* sum) { + ReduceOp r; + #pragma unroll + for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); + sum[i] = r(sum[i], b); + } + } +} + + +/* + * Extended softmax (from native aten pytorch) with following additional features + * 1) input scaling + */ +template +__global__ void scaled_softmax_warp_forward( + output_t *dst, + const input_t *src, + const acc_t scale, + int micro_batch_size, + int element_count) +{ + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_forward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) + // gridDim/blockIdx = (seq_len, attn_heads, batches) + int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z))+ threadIdx.y) * WARP_BATCH; + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) + local_batches = WARP_BATCH; + + // there might be multiple batches per warp. 
compute the index within the batch + int local_idx = threadIdx.x; + + src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + + // load data from global memory + acc_t elements[WARP_BATCH][WARP_ITERATIONS]; + input_t temp_data[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : element_count; + + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + + if (element_index < batch_element_count) { + int itr_idx = i*element_count+it*WARP_SIZE; + copy_vector(temp_data, src + itr_idx); + + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + elements[i][it + element] = (acc_t)temp_data[element] * scale; + } + } else { + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + elements[i][it + element] = -std::numeric_limits::infinity(); + } + } + } + } + + // compute max_value + acc_t max_value[WARP_BATCH]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + max_value[i] = elements[i][0]; + #pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; + } + } + warp_reduce(max_value); + + acc_t sum[WARP_BATCH] { 0.0f }; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + elements[i][it] = std::exp((elements[i][it] - max_value[i])); + sum[i] += elements[i][it]; + } + } + warp_reduce(sum); + + // store result + output_t out[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) + break; + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < element_count) { + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = elements[i][it + element] / sum[i]; + } + copy_vector(dst + i * element_count + it * WARP_SIZE, out); + } else { + break; + } + } + } +} + + +/* + * Extended softmax (from native aten pytorch) with following additional features + * 1) input scaling + * 2) Explicit masking + */ +template +__global__ void scaled_masked_softmax_warp_forward( + output_t *dst, + const input_t *src, + const uint8_t *mask, + const acc_t scale, + int micro_batch_size, + int element_count, + int pad_batches) +{ + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_forward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 
1 : 4; + + // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) + // gridDim/blockIdx = (seq_len, attn_heads, batches) + int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z))+ threadIdx.y) * WARP_BATCH; + int pad_first_batch = 0; + if (pad_batches != 1) { // bert style + pad_first_batch = (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * WARP_BATCH; + } else { // gpt2 style + pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; + } + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) + local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the batch + int local_idx = threadIdx.x; + + src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + + // load data from global memory + acc_t elements[WARP_BATCH][WARP_ITERATIONS]; + input_t temp_data[ELEMENTS_PER_LDG_STG]; + uint8_t temp_mask[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : element_count; + + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + + if (element_index < batch_element_count) { + int itr_idx = i*element_count+it*WARP_SIZE; + copy_vector(temp_data, src + itr_idx); + copy_vector(temp_mask, mask + itr_idx); + + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (temp_mask[element] != 1) { + elements[i][it + element] = (acc_t)temp_data[element] * scale; + } else { + elements[i][it + element] = -10000.0; + } + } + } else { + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + elements[i][it + element] = -std::numeric_limits::infinity(); + } + } + } + } + + // compute max_value + acc_t max_value[WARP_BATCH]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + max_value[i] = elements[i][0]; + #pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; + } + } + warp_reduce(max_value); + + // compute scale value to account for full mask + acc_t scale_value[WARP_BATCH]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + scale_value[i] = (max_value[i] == -10000.0) ? 
0.0 : 1.0; + } + + acc_t sum[WARP_BATCH] { 0.0f }; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + elements[i][it] = std::exp((elements[i][it] - max_value[i])); + sum[i] += elements[i][it]; + } + } + warp_reduce(sum); + + // store result + output_t out[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) + break; + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < element_count) { + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = elements[i][it + element] * scale_value[i] / sum[i]; + } + copy_vector(dst + i * element_count + it * WARP_SIZE, out); + } else { + break; + } + } + } +} + +template +__global__ void scaled_masked_softmax_warp_backward( + output_t *gradInput, + input_t *grad, + const input_t *output, + acc_t scale, + int micro_batch_size, + int element_count) +{ + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_backward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) + // gridDim/blockIdx = (seq_len, attn_heads, batches) + int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) + local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the batch + int local_idx = threadIdx.x; + + // the first element to process by the current thread + int thread_offset = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + grad += thread_offset; + output += thread_offset; + gradInput += thread_offset; + + // load data from global memory + acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; + acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; + input_t temp_grad[ELEMENTS_PER_LDG_STG]; + input_t temp_output[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 
0 : element_count; + + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < batch_element_count) { + copy_vector(temp_grad, grad + i * element_count + it * WARP_SIZE); + copy_vector(temp_output, output + i * element_count + it * WARP_SIZE); + + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + output_reg[i][it + element] = (acc_t)temp_output[element]; + } + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element]; + } + } + } + } + + acc_t sum[WARP_BATCH]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + sum[i] = grad_reg[i][0]; + #pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + sum[i] += grad_reg[i][it]; + } + } + warp_reduce(sum); + + // store result + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) + break; + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < element_count) { + // compute gradients + output_t out[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i])); + } + copy_vector(gradInput + i * element_count + it * WARP_SIZE, out); + } + } + } +} +} // end of anonymous namespace + +int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads){ + int log2_elements = log2_ceil(key_seq_len); + const int next_power_of_two = 1 << log2_elements; + + int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + constexpr int threads_per_block = 128; + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + + return batches_per_block; +} + +template +void dispatch_scaled_softmax_forward( + output_t *dst, + const input_t *src, + const input_t scale, + int query_seq_len, + int key_seq_len, + int batches, + int attn_heads) +{ + TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 4096 ); + if (key_seq_len == 0) { + return; + } else { + int log2_elements = log2_ceil(key_seq_len); + const int next_power_of_two = 1 << log2_elements; + int batch_count = batches * attn_heads * query_seq_len; + + // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward. + int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. + int batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0); + dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches); + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 1: // 2 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 2: // 4 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 3: // 8 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 4: // 16 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 5: // 32 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 6: // 64 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 7: // 128 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 8: // 256 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 9: // 512 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 10: // 1024 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 11: // 2048 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + case 12: // 4096 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; + default: + break; + } + } +} + +template +void dispatch_scaled_masked_softmax_forward( + output_t *dst, + const input_t *src, + const uint8_t *mask, + const input_t scale, + int query_seq_len, + int key_seq_len, + int batches, + int attn_heads, + int pad_batches) +{ + TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 4096 ); + if (key_seq_len == 0) { + return; + } else { + int log2_elements = log2_ceil(key_seq_len); + const int next_power_of_two = 1 << log2_elements; + int batch_count = batches * attn_heads * query_seq_len; + + // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward. + int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. + int batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0); + dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches); + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 1: // 2 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 2: // 4 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 3: // 8 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 4: // 16 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 5: // 32 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 6: // 64 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 7: // 128 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 8: // 256 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 9: // 512 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 10: // 1024 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 11: // 2048 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 12: // 4096 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + default: + break; + } + } +} + +template +void dispatch_scaled_masked_softmax_backward( + output_t *grad_input, + input_t *grad, + const input_t *output, + const acc_t scale, + int query_seq_len, + int key_seq_len, + int batches, + int attn_heads) +{ + TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 4096 ); + if (key_seq_len == 0) { + return; + } else { + int log2_elements = log2_ceil(key_seq_len); + const int next_power_of_two = 1 << log2_elements; + int batch_count = batches * attn_heads * query_seq_len; + + // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward. + int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward. + int batches_per_warp = (next_power_of_two <= 128) ? 
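    // Editor's note (illustrative, not part of the original patch): the warp backward
    // kernel launched below computes the standard softmax gradient
    //   dX_j = scale * (y_j * dY_j - y_j * sum_k(y_k * dY_k))
    // where y is the saved softmax output; grad_reg already holds y_j * dY_j, so the
    // stored value is scale * (grad_reg - output_reg * sum).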
2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + int blocks = batch_count/batches_per_block; + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 1: // 2 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 2: // 4 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 3: // 8 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 4: // 16 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 5: // 32 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 6: // 64 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 7: // 128 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 8: // 256 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 9: // 512 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 10: // 1024 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 11: // 2048 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 12: // 4096 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + + default: + break; + } + } +} diff --git a/megatron/fused_kernels/scaled_masked_softmax_cuda.cu b/megatron/fused_kernels/scaled_masked_softmax_cuda.cu new file mode 100644 index 0000000000000000000000000000000000000000..a8be57c0525f2693245b2f46c2a07e00c6e0dd67 --- /dev/null +++ b/megatron/fused_kernels/scaled_masked_softmax_cuda.cu @@ -0,0 +1,107 @@ +/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
*/ + +#include +#include +#include +#include +#include +#include +#include +#include "scaled_masked_softmax.h" +#include "type_shim.h" + +namespace multihead_attn { +namespace fused_softmax { +namespace scaled_masked_softmax { + +int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads){ + return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads); +} + + +torch::Tensor fwd_cuda( + torch::Tensor const& input, + torch::Tensor const& mask, + float scale_factor) +{ + // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] + const int batches = input.size(0); + const int pad_batches = mask.size(0); + const int attn_heads = input.size(1); + const int query_seq_len = input.size(2); + const int key_seq_len = input.size(3); + TORCH_INTERNAL_ASSERT(key_seq_len <= 4096); + TORCH_INTERNAL_ASSERT(query_seq_len > 1); + TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches); + TORCH_INTERNAL_ASSERT(mask.size(1) == 1); + TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len); + TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len); + + // Output + auto act_options = input.options().requires_grad(false); + torch::Tensor softmax_results = + torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); + + // Softmax Intermediate Result Ptr + void* input_ptr = static_cast(input.data_ptr()); + void* mask_ptr = static_cast(mask.data_ptr()); + void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); + + DISPATCH_HALF_AND_BFLOAT( + input.scalar_type(), + "dispatch_scaled_masked_softmax_forward", + dispatch_scaled_masked_softmax_forward( + reinterpret_cast(softmax_results_ptr), + reinterpret_cast(input_ptr), + reinterpret_cast(mask_ptr), + scale_factor, + query_seq_len, + key_seq_len, + batches, + attn_heads, + pad_batches); + ); + return softmax_results; +} + +torch::Tensor bwd_cuda( + torch::Tensor const& output_grads_, + torch::Tensor const& softmax_results_, + float scale_factor) { + + auto output_grads = output_grads_.contiguous(); + auto softmax_results = softmax_results_.contiguous(); + + //output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] + const int batches = output_grads.size(0); + const int attn_heads = output_grads.size(1); + const int query_seq_len = output_grads.size(2); + const int key_seq_len = output_grads.size(3); + + auto act_options = output_grads.options().requires_grad(false); + torch::Tensor input_grads = + torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); + + void* output_grads_ptr = static_cast(output_grads.data_ptr()); + void* input_grads_ptr = static_cast(input_grads.data_ptr()); + + //Softmax Grad + DISPATCH_HALF_AND_BFLOAT( + output_grads_.scalar_type(), + "dispatch_scaled_masked_softmax_backward", + dispatch_scaled_masked_softmax_backward( + reinterpret_cast(input_grads_ptr), + reinterpret_cast(output_grads_ptr), + reinterpret_cast(softmax_results.data_ptr()), + scale_factor, + query_seq_len, + key_seq_len, + batches, + attn_heads); + ); + + return input_grads; +} +} +} +} diff --git a/megatron/fused_kernels/scaled_softmax.cpp b/megatron/fused_kernels/scaled_softmax.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e10cd77e7fb35247fac2b547d8bccc18c3dca58e --- /dev/null +++ b/megatron/fused_kernels/scaled_softmax.cpp @@ -0,0 +1,61 @@ +/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
*/ + +#include +#include +#include + +namespace multihead_attn { +namespace fused_softmax { +namespace scaled_softmax { + +torch::Tensor fwd_cuda( + torch::Tensor const& input, + float scale_factor); + +torch::Tensor bwd_cuda( + torch::Tensor const& output_grads, + torch::Tensor const& softmax_results, + float scale_factor); + +torch::Tensor fwd( + torch::Tensor const& input, + float scale_factor) { + AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); + AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || + (input.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + + return fwd_cuda(input, scale_factor); +} + +torch::Tensor bwd( + torch::Tensor const& output_grads, + torch::Tensor const& softmax_results, + float scale_factor) { + + AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor"); + AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor"); + + AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || + (output_grads.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || + (softmax_results.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + + return bwd_cuda(output_grads, softmax_results, scale_factor); +} + +} // end namespace scaled_softmax +} // end namespace fused_softmax +} // end namespace multihead_attn + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", + &multihead_attn::fused_softmax::scaled_softmax::fwd, + "Self Multihead Attention scaled, softmax -- Forward."); + m.def("backward", + &multihead_attn::fused_softmax::scaled_softmax::bwd, + "Self Multihead Attention scaled, softmax -- Backward."); +} + diff --git a/megatron/fused_kernels/scaled_softmax_cuda.cu b/megatron/fused_kernels/scaled_softmax_cuda.cu new file mode 100644 index 0000000000000000000000000000000000000000..ecc6eb06e83e5fee38707c1ccd4bc362b0d1df49 --- /dev/null +++ b/megatron/fused_kernels/scaled_softmax_cuda.cu @@ -0,0 +1,90 @@ +/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
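 *
 * Editor's note (not part of the original patch): fwd_cuda below launches the unmasked
 * dispatch_scaled_softmax_forward, while bwd_cuda reuses
 * dispatch_scaled_masked_softmax_backward; the softmax gradient depends only on the
 * saved softmax output, not on the additive mask, so the masked and unmasked backward
 * paths coincide and output_grads is overwritten in place.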
*/ + +#include +#include +#include +#include +#include +#include +#include +#include "scaled_masked_softmax.h" +#include "type_shim.h" + +namespace multihead_attn { +namespace fused_softmax { +namespace scaled_softmax { + +torch::Tensor fwd_cuda( + torch::Tensor const& input, + float scale_factor) +{ + // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] + const int batches = input.size(0); + const int attn_heads = input.size(1); + const int query_seq_len = input.size(2); + const int key_seq_len = input.size(3); + TORCH_INTERNAL_ASSERT(key_seq_len <= 4096); + TORCH_INTERNAL_ASSERT(query_seq_len > 1); + + // Output + auto act_options = input.options().requires_grad(false); + torch::Tensor softmax_results = + torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); + + // Softmax Intermediate Result Ptr + void* input_ptr = static_cast(input.data_ptr()); + void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); + + DISPATCH_HALF_AND_BFLOAT( + input.scalar_type(), + "dispatch_scaled_softmax_forward", + dispatch_scaled_softmax_forward( + reinterpret_cast(softmax_results_ptr), + reinterpret_cast(input_ptr), + scale_factor, + query_seq_len, + key_seq_len, + batches, + attn_heads); + ); + return softmax_results; +} + +torch::Tensor bwd_cuda( + torch::Tensor const& output_grads_, + torch::Tensor const& softmax_results_, + float scale_factor) { + + auto output_grads = output_grads_.contiguous(); + auto softmax_results = softmax_results_.contiguous(); + + //output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] + const int batches = output_grads.size(0); + const int attn_heads = output_grads.size(1); + const int query_seq_len = output_grads.size(2); + const int key_seq_len = output_grads.size(3); + + void* output_grads_ptr = static_cast(output_grads.data_ptr()); + + //Softmax Grad + DISPATCH_HALF_AND_BFLOAT( + output_grads_.scalar_type(), + "dispatch_scaled_masked_softmax_backward", + dispatch_scaled_masked_softmax_backward( + reinterpret_cast(output_grads_ptr), + reinterpret_cast(output_grads_ptr), + reinterpret_cast(softmax_results.data_ptr()), + scale_factor, + query_seq_len, + key_seq_len, + batches, + attn_heads); + ); + + //backward pass is completely in-place + return output_grads; +} +} +} +} + diff --git a/megatron/fused_kernels/scaled_upper_triang_masked_softmax.cpp b/megatron/fused_kernels/scaled_upper_triang_masked_softmax.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ddfc8646a3dd109b31d633f109079e2c1af98e9d --- /dev/null +++ b/megatron/fused_kernels/scaled_upper_triang_masked_softmax.cpp @@ -0,0 +1,58 @@ +/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
*/ + +#include +#include +#include + +namespace multihead_attn { +namespace fused_softmax { +namespace scaled_upper_triang_masked_softmax { + +torch::Tensor fwd_cuda( + torch::Tensor const& input, + float scale_factor); + +torch::Tensor bwd_cuda( + torch::Tensor const& output_grads, + torch::Tensor const& softmax_results, + float scale_factor); + +torch::Tensor fwd(torch::Tensor const& input, float scale_factor) { + AT_ASSERTM(input.dim() == 3, "expected 3D tensor"); + AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || + (input.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + + return fwd_cuda(input, scale_factor); +} + +torch::Tensor bwd( + torch::Tensor const& output_grads, + torch::Tensor const& softmax_results, + float scale_factor) { + + AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor"); + AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor"); + + AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || + (output_grads.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || + (softmax_results.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + + return bwd_cuda(output_grads, softmax_results, scale_factor); +} + +} // end namespace scaled_upper_triang_masked_softmax +} // end namespace fused_softmax +} // end namespace multihead_attn + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", + &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd, + "Self Multihead Attention scaled, time masked softmax -- Forward."); + m.def("backward", + &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd, + "Self Multihead Attention scaled, time masked softmax -- Backward."); +} diff --git a/megatron/fused_kernels/scaled_upper_triang_masked_softmax.h b/megatron/fused_kernels/scaled_upper_triang_masked_softmax.h new file mode 100644 index 0000000000000000000000000000000000000000..98aaf884c9ed99b8c75f6179ef73156699d644c4 --- /dev/null +++ b/megatron/fused_kernels/scaled_upper_triang_masked_softmax.h @@ -0,0 +1,499 @@ +/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
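 *
 * Editor's note (not part of the original patch): this header provides the vectorized
 * copy helpers (copy_vector / copy_zero_vector, 1 or 4 elements via float2 loads),
 * log2_ceil, a butterfly warp_reduce built on __shfl_xor_sync with offsets
 * WARP_SIZE/2, ..., 2, 1, and the causal ("upper triangular") softmax warp kernels:
 * query row blockIdx.x only attends to its first blockIdx.x + 1 keys; later positions
 * are filled with -inf before the max/sum reduction and written back as zeros.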
*/ + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace { + +template +__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); + +template <> +__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; } + +template <> +__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); } + +template <> +__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *dst = *src; } + +template <> +__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); } + +template <> +__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) { *dst = *src; } + +template <> +__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); } + +template +__device__ __inline__ void copy_zero_vector(Datatype *dst); + +template <> +__device__ __inline__ void copy_zero_vector(c10::BFloat16 *dst) { *dst = 0.0; } + +template <> +__device__ __inline__ void copy_zero_vector(c10::BFloat16 *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); } + +template <> +__device__ __inline__ void copy_zero_vector(c10::Half *dst) { *dst = 0.0; } + +template <> +__device__ __inline__ void copy_zero_vector(c10::Half *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); } + + +int log2_ceil(int value) { + int log2_value = 0; + while ((1 << log2_value) < value) ++log2_value; + return log2_value; +} + +template +struct Add { + __device__ __forceinline__ T operator()(T a, T b) const { + return a + b; + } +}; + +template +struct Max { + __device__ __forceinline__ T operator()(T a, T b) const { + return a < b ? b : a; + } +}; + +template +__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) +{ +#if CUDA_VERSION >= 9000 + return __shfl_xor_sync(mask, value, laneMask, width); +#else + return __shfl_xor(value, laneMask, width); +#endif +} + +template class ReduceOp> +__device__ __forceinline__ void warp_reduce(acc_t* sum) { + ReduceOp r; + #pragma unroll + for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); + sum[i] = r(sum[i], b); + } + } +} + +/* + * Extended softmax (from native aten pytorch) with following additional features + * 1) input scaling + * 2) Implicit time (diagonal masking) + */ +template +__global__ void scaled_upper_triang_masked_softmax_warp_forward( + output_t *dst, + const input_t *src, + const acc_t scale, + int micro_batch_size, + int stride, + int element_count) +{ + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_forward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x; + int local_seq = blockIdx.x + 1; + int warp_iteration_limit = (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1)/ WARP_SIZE; + + // micro_batch_size might not be a multiple of WARP_BATCH. 
Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) + local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the batch + int local_idx = threadIdx.x; + + src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; + dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; + + // load data from global memory + acc_t elements[WARP_BATCH][WARP_ITERATIONS]; + input_t temp_data[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : local_seq; + + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + + if (element_index < batch_element_count) { + copy_vector(temp_data, src + i*element_count*stride + it*WARP_SIZE); + + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if ((element_index + element) < batch_element_count) { + elements[i][it+element] = (acc_t)temp_data[element] * scale; + } else { + elements[i][it + element] = -std::numeric_limits::infinity(); + } + } + } else { + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + elements[i][it + element] = -std::numeric_limits::infinity(); + } + } + } + } + + // compute max_value + acc_t max_value[WARP_BATCH]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + max_value[i] = elements[i][0]; + #pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; + } + } + warp_reduce(max_value); + + acc_t sum[WARP_BATCH] { 0.0f }; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + if (it < warp_iteration_limit) { + elements[i][it] = std::exp((elements[i][it] - max_value[i])); + sum[i] += elements[i][it]; + } + } + } + warp_reduce(sum); + + // store result + output_t out[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) + break; + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + + if (element_index < local_seq) { + + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (element_index + element < local_seq) { + out[element] = elements[i][it + element] / sum[i]; + } else { + out[element] = 0; + } + } + copy_vector(dst + i * element_count * stride + it * WARP_SIZE, out); + } else if (element_index < element_count) { + copy_zero_vector(dst + i * element_count * stride + it * WARP_SIZE); + } else { + break; + } + } + } +} + +template +__global__ void scaled_upper_triang_masked_softmax_warp_backward( + output_t *gradInput, + input_t *grad, + const input_t *output, + acc_t scale, + int micro_batch_size, + int stride, + int element_count) +{ + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_backward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 
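    // Editor's note (not part of the original patch): ELEMENTS_PER_LDG_STG, defined just
    // below, switches from scalar copies to 4-element vectorized copies (the float2-based
    // copy_vector specializations above) once each thread owns at least 4 iterations.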
2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x; + int local_seq = blockIdx.x + 1; + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) + local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the batch + int local_idx = threadIdx.x; + + // the first element to process by the current thread + int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; + grad += thread_offset; + output += thread_offset; + gradInput += thread_offset; + + // load data from global memory + acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; + acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; + input_t temp_grad[ELEMENTS_PER_LDG_STG]; + input_t temp_output[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : local_seq; + + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < batch_element_count) { + copy_vector(temp_grad, grad + i * element_count * stride + it * WARP_SIZE); + copy_vector(temp_output, output + i * element_count * stride + it * WARP_SIZE); + + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (element_index + element < batch_element_count) { + output_reg[i][it + element] = (acc_t)temp_output[element]; + } + } + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (element_index + element < batch_element_count) { + grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element]; + } + } + } + } + } + + acc_t sum[WARP_BATCH]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + sum[i] = grad_reg[i][0]; + #pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + sum[i] += grad_reg[i][it]; + } + } + warp_reduce(sum); + + // store result + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) + break; + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < element_count) { + // compute gradients + output_t out[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i])); + } + copy_vector(gradInput + i * element_count * stride + it * WARP_SIZE, out); + } + } + } +} + +} // end of anonymous namespace + +template +void dispatch_scaled_upper_triang_masked_softmax_forward( + output_t *dst, + const input_t *src, + const input_t scale, + int softmax_elements, + int softmax_elements_stride, + int attn_batches) +{ + TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048 ); + if (softmax_elements == 0) { + return; + } else { + int log2_elements = log2_ceil(softmax_elements); + const int next_power_of_two = 1 << log2_elements; + int seq_len = softmax_elements; + int batch_count = attn_batches * seq_len; + + // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward. 
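    // Illustrative arithmetic (editor's sketch, not part of the original patch):
    // for softmax_elements = 1024 -> log2_elements = 10, next_power_of_two = 1024,
    // warp_size = 32, batches_per_warp = 1, warps_per_block = 128 / 32 = 4,
    // batches_per_block = 4, so the grid below is (seq_len, attn_batches / 4, 1),
    // each block covering four attention batches for one query row.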
+ int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0); + + int blocks_per_seq = attn_batches / batches_per_block; + dim3 blocks(seq_len, blocks_per_seq, 1); + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 1: // 2 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 2: // 4 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 3: // 8 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 4: // 16 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 5: // 32 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 6: // 64 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 7: // 128 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 8: // 256 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 9: // 512 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 10: // 1024 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 11: // 2048 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + default: + break; + } + } +} + +template +void dispatch_scaled_upper_triang_masked_softmax_backward( + output_t *grad_input, + input_t *grad, + const input_t *output, + const acc_t scale, + int softmax_elements, + int softmax_elements_stride, + int attn_batches) +{ + TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 2048 ); + if (softmax_elements == 0) { + return; + } else { + int log2_elements = log2_ceil(softmax_elements); + const int next_power_of_two = 1 << log2_elements; + int seq_len = softmax_elements; + int batch_count = attn_batches * seq_len; + + // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward. + int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward. 
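    // Editor's note (not part of the original patch): the causal dispatchers cap
    // softmax_elements at 2048 (log2_elements <= 11), whereas the padded-mask and
    // unmasked dispatchers above allow key_seq_len up to 4096 (log2_elements <= 12);
    // both limits are enforced by the TORCH_INTERNAL_ASSERT at the top of each dispatcher.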
+ int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0); + + int blocks_per_seq = attn_batches / batches_per_block; + dim3 blocks(seq_len, blocks_per_seq, 1); + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 1: // 2 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 2: // 4 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 3: // 8 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 4: // 16 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 5: // 32 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 6: // 64 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 7: // 128 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 8: // 256 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 9: // 512 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 10: // 1024 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 11: // 2048 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + default: + break; + } + } +} diff --git a/megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu b/megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu new file mode 100644 index 0000000000000000000000000000000000000000..c21e5fb4ee181ee60bbdc468e3b6bea832ac1d24 --- /dev/null +++ b/megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu @@ -0,0 +1,84 @@ +/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
*/ + +#include +#include +#include +#include +#include +#include +#include +#include "scaled_upper_triang_masked_softmax.h" +#include "type_shim.h" + +namespace multihead_attn { +namespace fused_softmax { +namespace scaled_upper_triang_masked_softmax { + +torch::Tensor fwd_cuda( + torch::Tensor const& input, + float scale_factor) +{ + // input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len] + const int attn_batches = input.size(0); + const int seq_len = input.size(1); + TORCH_INTERNAL_ASSERT(seq_len <= 2048); + + // Output + auto act_options = input.options().requires_grad(false); + torch::Tensor softmax_results = + torch::empty({attn_batches, seq_len, seq_len}, act_options); + + // Softmax Intermediate Result Ptr + void* input_ptr = static_cast(input.data_ptr()); + void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); + + DISPATCH_HALF_AND_BFLOAT( + input.scalar_type(), + "dispatch_scaled_upper_triang_masked_softmax_forward", + dispatch_scaled_upper_triang_masked_softmax_forward( + reinterpret_cast(softmax_results_ptr), + reinterpret_cast(input_ptr), + scale_factor, + seq_len, + seq_len, + attn_batches); + ); + return softmax_results; +} + + +torch::Tensor bwd_cuda( + torch::Tensor const& output_grads_, + torch::Tensor const& softmax_results_, + float scale_factor) { + + auto output_grads = output_grads_.contiguous(); + auto softmax_results = softmax_results_.contiguous(); + + //output grads is a 3d tensor with dimensions [attn_batches, seq_len, seq_len] + const int attn_batches = output_grads.size(0); + const int seq_len = output_grads.size(1); + TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2)); + + void* output_grads_ptr = static_cast(output_grads.data_ptr()); + + //Softmax Grad + DISPATCH_HALF_AND_BFLOAT( + output_grads_.scalar_type(), + "dispatch_scaled_upper_triang_masked_softmax_backward", + dispatch_scaled_upper_triang_masked_softmax_backward( + reinterpret_cast(output_grads_ptr), + reinterpret_cast(output_grads_ptr), + reinterpret_cast(softmax_results.data_ptr()), + scale_factor, + seq_len, + seq_len, + attn_batches); + ); + + //backward pass is completely in-place + return output_grads; +} +} +} +} diff --git a/megatron/fused_kernels/tests/__init__.py b/megatron/fused_kernels/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/megatron/fused_kernels/tests/test_fused_kernels.py b/megatron/fused_kernels/tests/test_fused_kernels.py new file mode 100644 index 0000000000000000000000000000000000000000..88d5247e863327abe8671b304af22f73abc5dd4a --- /dev/null +++ b/megatron/fused_kernels/tests/test_fused_kernels.py @@ -0,0 +1,389 @@ +import math + +import torch +from torch.nn import LayerNorm + +from megatron.model.enums import AttnMaskType +from megatron.model.fused_layer_norm import MixedFusedLayerNorm +from megatron.model.fused_softmax import FusedScaleMaskSoftmax +from megatron.model.utils import attention_mask_func +from megatron.fused_kernels import load + +def test_load_fused_kernels(): + try: + import fused_mix_prec_layer_norm_cuda + import scaled_masked_softmax_cuda + import scaled_upper_triang_masked_softmax_cuda + import torch + + print("[Success] load_fused_kernels") + except ImportError as e: + print("[Fail] load_fused_kernels") + raise e + + +def test_fused_softmax(): + bert = BertModel.from_pretrained("bert-base-cased").cuda().half() + tokenizer = BertTokenizer.from_pretrained("bert-base-cased") + test_text = ( + "Hello. How are you? 
I am fine thank you and you? yes Good. " + "hi hi hi hi hi hi hi hi hi hi hi hi hi" # 32 + ) + + tokens = tokenizer( + [test_text] * 4, + return_tensors="pt", + ) + + embedding_output = bert.embeddings( + input_ids=tokens["input_ids"].cuda(), + position_ids=None, + token_type_ids=tokens["token_type_ids"].cuda(), + inputs_embeds=None, + past_key_values_length=0, + ) + + # (bsz, 1, 1, seq_len) + mask = bert.get_extended_attention_mask( + attention_mask=tokens["attention_mask"].cuda(), + input_shape=tokens["input_ids"].shape, + device=bert.device, + ) + # (bsz, 1, seq_len, seq_len) + mask = mask.repeat(1, 1, mask.size()[-1], 1) + + attention = bert.encoder.layer[0].attention.self + key_layer = attention.transpose_for_scores(attention.key(embedding_output)) + query_layer = attention.transpose_for_scores(attention.query(embedding_output)) + + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores /= math.sqrt(key_layer.size()[-1]) + + fused_softmax = ( + FusedScaleMaskSoftmax( + input_in_fp16=True, + input_in_bf16=False, + mask_func=attention_mask_func, + scale=None, + softmax_in_fp32=False, + attn_mask_type=AttnMaskType.padding, + scaled_masked_softmax_fusion=True, + ) + .cuda() + .half() + ) + + fused_softmax_output = fused_softmax( + attention_scores, + (mask != 0), + ) + + torch_softmax = ( + FusedScaleMaskSoftmax( + input_in_fp16=True, + input_in_bf16=False, + mask_func=attention_mask_func, + scale=None, + softmax_in_fp32=False, + attn_mask_type=AttnMaskType.padding, + scaled_masked_softmax_fusion=False, + ) + .cuda() + .half() + ) + + torch_softmax_output = torch_softmax( + attention_scores, + (mask != 0), + ) + + test_result = (fused_softmax_output - torch_softmax_output).abs() + + while test_result.dim() != 1: + test_result = test_result.mean(dim=-1) + + diff = test_result.mean(dim=-1) + + if diff <= 1e-3: + print( + f"\n[Success] test_fused_softmax" + f"\n > mean_difference={diff}" + f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}" + f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}" + ) + else: + print( + f"\n[Fail] test_fused_softmax" + f"\n > mean_difference={diff}, " + f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}, " + f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}" + ) + + +def test_fused_upper_triangle_mask_softmax(): + gpt = GPT2Model.from_pretrained("gpt2").cuda().half() + tokenizer = GPT2Tokenizer.from_pretrained("gpt2") + test_text = ( + "Hello. How are you? I am fine thank you and you? yes Good. 
" + "hi hi hi hi hi hi hi" # 24 + ) + + tokens = tokenizer( + [test_text] * 4, + return_tensors="pt", + ) + + attention_mask = tokens["attention_mask"].cuda() + attention_mask = attention_mask.view(attention_mask.size(0), -1) + attention_mask = attention_mask[:, None, None, :] + attention_mask = (1.0 - attention_mask) * -10000.0 + attention_mask = attention_mask.repeat(1, 1, attention_mask.size()[-1], 1) + attn = gpt.h[0] + + hidden_states = gpt.wte(tokens["input_ids"].cuda()) + q, k, v = attn.attn.c_attn(hidden_states).split(768, dim=-1) + q = attn.attn._split_heads(q, attn.attn.num_heads, attn.attn.head_dim) + k = attn.attn._split_heads(k, attn.attn.num_heads, attn.attn.head_dim) + attn_weights = torch.matmul(q, k.transpose(-1, -2)) + + sq, sk = q.size(-2), k.size(-2) + causal_mask = attn.attn.bias[:, :, sk - sq : sk, :sk].bool() + total_mask = ~(causal_mask & (attention_mask == 0)) + """ + tensor([[[[False, True, True, ..., True, True, True], + [False, False, True, ..., True, True, True], + [False, False, False, ..., True, True, True], + ..., + [False, False, False, ..., False, True, True], + [False, False, False, ..., False, False, True], + [False, False, False, ..., False, False, False]]] + """ + + fused_softmax = ( + FusedScaleMaskSoftmax( + input_in_fp16=True, + input_in_bf16=False, + mask_func=attention_mask_func, + scale=None, + softmax_in_fp32=False, + attn_mask_type=AttnMaskType.causal, + scaled_masked_softmax_fusion=True, + ) + .cuda() + .half() + ) + + fused_softmax_output = fused_softmax( + attn_weights, + total_mask, + ) + + torch_softmax = ( + FusedScaleMaskSoftmax( + input_in_fp16=True, + input_in_bf16=False, + mask_func=attention_mask_func, + scale=None, + softmax_in_fp32=False, + attn_mask_type=AttnMaskType.causal, + scaled_masked_softmax_fusion=False, + ) + .cuda() + .half() + ) + + torch_softmax_output = torch_softmax( + attn_weights, + total_mask, + ) + + test_result = (fused_softmax_output - torch_softmax_output).abs() + + while test_result.dim() != 1: + test_result = test_result.mean(dim=-1) + + diff = test_result.mean(dim=-1) + + if diff <= 1e-3: + print( + f"\n[Success] test_fused_upper_triangle_mask_softmax" + f"\n > mean_difference={diff}" + f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}" + f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}" + ) + else: + print( + f"\n[Fail] test_fused_upper_triangle_mask_softmax" + f"\n > mean_difference={diff}, " + f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}, " + f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}" + ) + + +def test_layer_norm(): + bert = BertModel.from_pretrained("bert-base-cased").cuda().half() + tokenizer = BertTokenizer.from_pretrained("bert-base-cased") + test_text = ( + "Hello. How are you? I am fine thank you and you? yes Good. 
" + "hi hi hi hi hi hi hi hi hi hi hi hi hi" # 32 + ) + + tokens = tokenizer( + [test_text] * 4, + return_tensors="pt", + ) + + # [bsz, seq_len, d_model] + embedding_output = ( + bert.embeddings( + input_ids=tokens["input_ids"].cuda(), + position_ids=None, + token_type_ids=tokens["token_type_ids"].cuda(), + inputs_embeds=None, + past_key_values_length=0, + ) + .cuda() + .half() + ) + + fused_layernorm_layer = ( + MixedFusedLayerNorm(normalized_shape=embedding_output.size(-1)).cuda().half() + ) + + torch_layernorm_layer = ( + LayerNorm(normalized_shape=embedding_output.size(-1)).cuda().half() + ) + + fused_output = fused_layernorm_layer(embedding_output) + torch_output = torch_layernorm_layer(embedding_output) + test_result = (fused_output - torch_output).abs() + + while test_result.dim() != 1: + test_result = test_result.mean(dim=-1) + + diff = test_result.mean(dim=-1) + + if diff <= 1e-3: + print( + f"\n[Success] test_layer_norm" + f"\n > mean_difference={diff}" + f"\n > fused_values={fused_output[-1][-1][:5].tolist()}" + f"\n > torch_values={torch_output[-1][-1][:5].tolist()}" + ) + else: + print( + f"\n[Fail] test_layer_norm" + f"\n > mean_difference={diff}, " + f"\n > fused_values={fused_output[-1][-1][:5].tolist()}, " + f"\n > torch_values={torch_output[-1][-1][:5].tolist()}" + ) + + +def attention_mask_func(attention_scores, attention_mask): + attention_scores.masked_fill_(attention_mask, -10000.0) + return attention_scores + + +def forward_torch_softmax(input, mask, scale): + input = input * scale + mask_output = attention_mask_func(input, mask) if mask is not None else input + probs = torch.nn.Softmax(dim=-1)(mask_output) + return probs + + +def test_masked_softmax_forward(): + import scaled_masked_softmax_cuda + + batch = 2 + attn = 16 + scale_t = torch.tensor([1.0]) + for qlen in [128, 256, 1024, 2048, 4096]: + for klen in [128, 256, 1024, 2048]: + inputs = torch.normal(0, 2, (batch, attn, qlen, klen), dtype=torch.float16, device='cuda:0') + masks = torch.randint(0, 2, (batch, 1, qlen, klen), dtype=torch.bool, device='cuda:0') + softmax_results = scaled_masked_softmax_cuda.forward(inputs, masks, scale_t[0].item()) + softmax_results_torch = forward_torch_softmax(inputs, masks, scale_t[0].item()) + error = (softmax_results_torch - softmax_results).abs().max() + assert error < 1e-3 + +def test_masked_softmax_backward(): + import scaled_masked_softmax_cuda + + batch = 2 + attn = 16 + scale_t = torch.tensor([1.0]) + for qlen in [128, 256, 1024, 2048, 4096]: + for klen in [128, 256, 1024, 2048]: + inputs = torch.normal(0, 2, (batch, attn, qlen, klen), dtype=torch.float16, device='cuda:0') + backward = torch.rand_like(inputs, dtype=torch.float16, device='cuda:0') + masks = torch.randint(0, 2, (batch, 1, qlen, klen), dtype=torch.bool, device='cuda:0') + softmax_results = scaled_masked_softmax_cuda.forward(inputs, masks, scale_t[0].item()) + back_grad = scaled_masked_softmax_cuda.backward(backward, softmax_results, scale_t[0].item()) + + inputs.requires_grad = True + softmax_results_torch = forward_torch_softmax(inputs, masks, scale_t[0].item()) + softmax_results_torch.backward(backward) + error = (back_grad - inputs.grad).abs().max() + assert error < 1e-3 + + +def test_allmasked_softmax_forward(): + import scaled_masked_softmax_cuda + + batch = 2 + attn = 16 + scale_t = torch.tensor([1.0]) + for qlen in [128, 256, 1024, 2048, 4096]: + for klen in [128, 256, 1024, 2048]: + inputs = torch.normal(0, 2, (batch, attn, qlen, klen), dtype=torch.float16, device='cuda:0') + masks = 
torch.ones((batch, 1, qlen, klen), dtype=torch.bool, device='cuda:0') + softmax_results = scaled_masked_softmax_cuda.forward(inputs, masks, scale_t[0].item()) + softmax_results_torch = torch.zeros_like(inputs) + error = (softmax_results_torch - softmax_results).abs().max() + assert error == 0.0 + + +def test_allmasked_softmax_backward(): + import scaled_masked_softmax_cuda + + batch = 2 + attn = 16 + scale_t = torch.tensor([1.0]) + for qlen in [128, 256, 1024, 2048, 4096]: + for klen in [128, 256, 1024, 2048]: + inputs = torch.normal(0, 2, (batch, attn, qlen, klen), dtype=torch.float16, device='cuda:0') + backward = torch.rand_like(inputs, dtype=torch.float16, device='cuda:0') + masks = torch.ones((batch, 1, qlen, klen), dtype=torch.bool, device='cuda:0') + softmax_results = scaled_masked_softmax_cuda.forward(inputs, masks, scale_t[0].item()) + back_grad = scaled_masked_softmax_cuda.backward(backward, softmax_results, scale_t[0].item()) + inputs.requires_grad = True + softmax_results_torch = forward_torch_softmax(inputs, masks, scale_t[0].item()) + softmax_results_torch.backward(backward) + error = (back_grad - inputs.grad).abs().max() + assert error < 1e-3 + + +if __name__ == "__main__": + try: + from transformers import BertTokenizer, GPT2Tokenizer + from transformers.models.bert.modeling_bert import BertModel + from transformers.models.gpt2.modeling_gpt2 import GPT2Model + import transformers + + transformers.logging.set_verbosity( + transformers.logging.FATAL, + ) + + except: + print("\n[Fail] Please install `transformers` package to test fused kernels\n") + exit(-1) + + load() + test_masked_softmax_forward() + test_masked_softmax_backward() + test_allmasked_softmax_forward() + test_allmasked_softmax_backward() + test_load_fused_kernels() + test_fused_softmax() + test_fused_upper_triangle_mask_softmax() + test_layer_norm() diff --git a/megatron/fused_kernels/type_shim.h b/megatron/fused_kernels/type_shim.h new file mode 100644 index 0000000000000000000000000000000000000000..d60a6f8c6fb50e241f9ddcc852adec71e963e1b2 --- /dev/null +++ b/megatron/fused_kernels/type_shim.h @@ -0,0 +1,103 @@ +/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */ + + +#include +#include "compat.h" + + +#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \ + switch(TYPE) \ + { \ + case at::ScalarType::Half: \ + { \ + using scalar_t = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: \ + { \ + using scalar_t = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } + + +#define DISPATCH_HALF_BFLOAT_AND_FLOAT(TYPE, NAME, ...) \ + switch(TYPE) \ + { \ + case at::ScalarType::Half: \ + { \ + using scalar_t = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: \ + { \ + using scalar_t = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Float: \ + { \ + using scalar_t = float; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } + + + +#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) 
\ + switch(TYPEIN) \ + { \ + case at::ScalarType::Float: \ + { \ + using scalar_t_in = float; \ + switch(TYPEOUT) \ + { \ + case at::ScalarType::Float: \ + { \ + using scalar_t_out = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: \ + { \ + using scalar_t_out = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: \ + { \ + using scalar_t_out = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \ + } \ + break; \ + } \ + case at::ScalarType::Half: \ + { \ + using scalar_t_in = at::Half; \ + using scalar_t_out = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: \ + { \ + using scalar_t_in = at::BFloat16; \ + using scalar_t_out = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \ + } + diff --git a/megatron/model/fused_softmax.py b/megatron/model/fused_softmax.py index 8d52967d95be2605a29f8ea208844b5fce1fe000..1ae208c3e1ca07f96e5beda22cc1df0124e3b68e 100644 --- a/megatron/model/fused_softmax.py +++ b/megatron/model/fused_softmax.py @@ -14,6 +14,7 @@ # limitations under the License. import torch +import torch.nn as nn from megatron.model.enums import AttnMaskType @@ -80,17 +81,18 @@ class ScaledMaskedSoftmax(torch.autograd.Function): ) return input_grads, None, None - -class FusedScaleMaskSoftmax(torch.nn.Module): +class FusedScaleMaskSoftmax(nn.Module): """ fused operation: scaling + mask + softmax + Arguments: input_in_fp16: flag to indicate if input in fp16 data format. + input_in_bf16: flag to indicate if input in bf16 data format. attn_mask_type: attention mask type (pad or causal) + scaled_masked_softmax_fusion: flag to indicate user want to use softmax fusion mask_func: mask function to be applied. softmax_in_fp32: if true, softmax in performed at fp32 precision. scale: scaling factor used in input tensor scaling. - """ def __init__( @@ -106,8 +108,9 @@ class FusedScaleMaskSoftmax(torch.nn.Module): super(FusedScaleMaskSoftmax, self).__init__() self.input_in_fp16 = input_in_fp16 self.input_in_bf16 = input_in_bf16 - assert not (self.input_in_fp16 and self.input_in_bf16),\ - 'both fp16 and bf16 flags cannot be active at the same time.' + assert not ( + self.input_in_fp16 and self.input_in_bf16 + ), "both fp16 and bf16 flags cannot be active at the same time." 
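        # Editor's note (not part of the original patch): the fused kernel path is taken
        # only when is_kernel_available() below holds, i.e. fusion is enabled, the input
        # is fp16/bf16, 16 < sk <= 4096, sq % 4 == 0, sk % 4 == 0, b * np % 4 == 0, and
        # the batch-per-block divisibility check passes (attn_batches for the causal
        # mask, sq for the padding mask).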
self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16 self.attn_mask_type = attn_mask_type self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion @@ -118,46 +121,75 @@ class FusedScaleMaskSoftmax(torch.nn.Module): assert ( self.scale is None or softmax_in_fp32 ), "softmax should be in fp32 when scaled" - + def forward(self, input, mask): # [b, np, sq, sk] assert input.dim() == 4 - data_size = input.size() - query_seq_len = data_size[-2] - key_seq_len = data_size[-1] - attn_batch_size = data_size[0] * data_size[1] - - # constraints on various tensor dimensions to enable warp based - # optimization and upper triangular optimization (for causal mask) - custom_kernel_constraint = key_seq_len > 16 and key_seq_len <= 2048 and \ - query_seq_len % 4 == 0 and attn_batch_size % 4 == 0 - - # invoke custom kernel - if self.input_in_float16 and mask is not None and \ - custom_kernel_constraint and self.scaled_masked_softmax_fusion: - scale = self.scale if self.scale is not None else 1.0 - - if self.attn_mask_type == AttnMaskType.causal: - assert query_seq_len == key_seq_len, \ - "causal mask is only for self attention" - input = input.view(-1, query_seq_len, key_seq_len) - probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale) - probs = probs.view(*data_size) - else: - assert self.attn_mask_type == AttnMaskType.padding - probs = ScaledMaskedSoftmax.apply(input, mask, scale) + + if self.is_kernel_available(mask, *input.size()): + return self.forward_fused_softmax(input, mask) else: - if self.input_in_float16 and self.softmax_in_fp32: - input = input.float() - - if self.scale is not None: - input = input * self.scale - mask_output = self.mask_func(input, mask) if mask is not None else input - probs = torch.nn.Softmax(dim=-1)(mask_output) - if self.input_in_float16 and self.softmax_in_fp32: - if self.input_in_fp16: - probs = probs.half() + return self.forward_torch_softmax(input, mask) + + def is_kernel_available(self, mask, b, np, sq, sk): + attn_batches = b * np + + if ( + self.scaled_masked_softmax_fusion # user want to fuse + and self.input_in_float16 # input must be fp16 + and 16 < sk <= 4096 # sk must be 16 ~ 2048 + and sq % 4 == 0 # sq must be divisor of 4 + and sk % 4 == 0 # sk must be divisor of 4 + and attn_batches % 4 == 0 # np * b must be divisor of 4 + ): + if 0 <= sk <= 4096: + batch_per_block = self.get_batch_per_block(sq, sk, b, np) + + if self.attn_mask_type == AttnMaskType.causal: + if attn_batches % batch_per_block == 0: + return True else: - probs = probs.bfloat16() + if sq % batch_per_block == 0: + return True + return False + + def forward_fused_softmax(self, input, mask): + b, np, sq, sk = input.size() + scale = self.scale if self.scale is not None else 1.0 + + if self.attn_mask_type == AttnMaskType.causal: + assert sq == sk, "causal mask is only for self attention" + + # input is 3D tensor (attn_batches, sq, sk) + input = input.view(-1, sq, sk) + probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale) + return probs.view(b, np, sq, sk) + else: + # input is 4D tensor (b, np, sq, sk) + if mask is not None: + return ScaledMaskedSoftmax.apply(input, mask, scale) + else: + return ScaledSoftmax.apply(input, scale) + + def forward_torch_softmax(self, input, mask): + if self.input_in_float16 and self.softmax_in_fp32: + input = input.float() + + if self.scale is not None: + input = input * self.scale + mask_output = self.mask_func(input, mask) if mask is not None else input + probs = torch.nn.Softmax(dim=-1)(mask_output) + + if self.input_in_float16 and 
self.softmax_in_fp32: + if self.input_in_fp16: + probs = probs.half() + else: + probs = probs.bfloat16() return probs + + @staticmethod + def get_batch_per_block(sq, sk, b, np): + import scaled_masked_softmax_cuda + + return scaled_masked_softmax_cuda.get_batch_per_block(sq, sk, b, np) diff --git a/megatron/model/gpt_model.py b/megatron/model/gpt_model.py index 62b951064e3de315df01b80c877aea24286728a8..91f5091e882c882d2278225264905f765b2fd7db 100644 --- a/megatron/model/gpt_model.py +++ b/megatron/model/gpt_model.py @@ -72,7 +72,8 @@ class GPTModel(MegatronModule): parallel_output=True, pre_process=True, post_process=True, - return_moe_loss=True): + prefix_lm=False): + # return_moe_loss=True): super(GPTModel, self).__init__() args = get_args() @@ -80,14 +81,15 @@ class GPTModel(MegatronModule): self.pre_process = pre_process self.post_process = post_process self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy - self.return_moe_loss = return_moe_loss + # self.return_moe_loss = return_moe_loss self.language_model, self._language_model_key = get_language_model( num_tokentypes=num_tokentypes, add_pooler=False, - encoder_attn_mask_type=AttnMaskType.causal, + encoder_attn_mask_type=AttnMaskType.prefix if prefix_lm else AttnMaskType.causal, + # encoder_attn_mask_type=AttnMaskType.causal, init_method=init_method_normal(args.init_method_std), scaled_init_method=scaled_init_method_normal(args.init_method_std, args.num_layers), - num_experts=args.num_experts, + # num_experts=args.num_experts, pre_process=self.pre_process, post_process=self.post_process) @@ -134,10 +136,10 @@ class GPTModel(MegatronModule): forward_method_parallel_output, self.fp16_lm_cross_entropy) - if self.return_moe_loss: - return (lm_output, *moe_losses) - else: - return lm_output + # if self.return_moe_loss: + # return (lm_output, *moe_losses) + # else: + return lm_output def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False): diff --git a/megatron/mpu/utils.py b/megatron/mpu/utils.py index 56ed1c76e1404389f18ab3be01e13dfdc678d942..3f6e9b7d713d853694151354296fc32a1c0b3112 100644 --- a/megatron/mpu/utils.py +++ b/megatron/mpu/utils.py @@ -29,6 +29,59 @@ def divide(numerator, denominator): ensure_divisibility(numerator, denominator) return numerator // denominator +def _kernel_make_viewless_tensor(inp, requires_grad): + '''Make a viewless tensor. + + View tensors have the undesirable side-affect of retaining a reference + to the originally-viewed tensor, even after manually setting the '.data' + field. This method creates a new tensor that links to the old tensor's + data, without linking the viewed tensor, referenced via the '._base' + field. + ''' + out = torch.empty( + (1,), + dtype = inp.dtype, + device = inp.device, + requires_grad = requires_grad, + ) + out.data = inp.data + return out + +class MakeViewlessTensor(torch.autograd.Function): + ''' + Autograd function to make a viewless tensor. + + This function should be used in cases where the computation graph needs + to be propagated, but we only want a viewless tensor (e.g., + ParallelTransformer's hidden_states). Call this function by passing + 'keep_graph = True' to 'make_viewless_tensor()'. + ''' + @staticmethod + def forward(ctx, inp, requires_grad): + return _kernel_make_viewless_tensor(inp, requires_grad) + @staticmethod + def backward(ctx, grad_output): + return grad_output, None + +def make_viewless_tensor(inp, requires_grad, keep_graph): + ''' + Entry-point for creating viewless tensors. 
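    Illustrative usage (editor's sketch, not part of the original patch; names are
    hypothetical):

        hidden = some_tensor.view(-1, hidden_size)   # a view: hidden._base is not None
        hidden = make_viewless_tensor(hidden, requires_grad=True, keep_graph=True)
        # the result shares storage with the view but no longer references the
        # originally-viewed tensor through '._base'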
+ + This method should be used, rather than calling 'MakeViewlessTensor' + or '_kernel_make_viewless_tensor' directly. This method acts as a + switch for determining if an autograd function or a regular method + should be used to create the tensor. + ''' + + # return tensor as-is, if not a 'view' + if inp._base is None: + return inp + + # create viewless tensor + if keep_graph: + return MakeViewlessTensor.apply(inp, requires_grad) + else: + return _kernel_make_viewless_tensor(inp, requires_grad) def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False): diff --git a/megatron/optimizer/optimizer.py b/megatron/optimizer/optimizer.py index 6568bf10a7fa8b6d2846d4fff29bc13de0d185d0..fc50623601dc45de9e575f1a41cd0cb3c07045b1 100644 --- a/megatron/optimizer/optimizer.py +++ b/megatron/optimizer/optimizer.py @@ -26,6 +26,8 @@ from megatron import mpu from megatron import print_rank_0 from deepspeed.accelerator import get_accelerator from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32 +from ..model.module import param_is_not_shared +from ..mpu.layers import param_is_not_tensor_parallel_duplicate def _zero_grad_group_helper(group, set_to_none): @@ -87,9 +89,36 @@ class MegatronOptimizer(ABC): return params + def get_main_grads_for_grad_norm(self): + + # Filter parameters based on: + # - grad should not be none + # - parameter should not be shared + # - should not be a replica due to tensor model parallelism + params = self.get_parameters() + grads_for_norm = [] + for param in params: + grad = param.grad + grad_not_none = grad is not None + is_not_shared = param_is_not_shared(param) + is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param) + if grad_not_none and is_not_shared and is_not_tp_duplicate: + grads_for_norm.append(grad) + + return grads_for_norm + + + def get_model_parallel_group(self): + """Default returned here, but the distributed optimizer overrides this.""" + return mpu.get_model_parallel_group() + + def clip_grad_norm(self, clip_grad): params = self.get_parameters() - return clip_grad_norm_fp32(params, clip_grad) + grads_for_norm = self.get_main_grads_for_grad_norm() + return clip_grad_norm_fp32( + params, grads_for_norm, clip_grad, + model_parallel_group=self.get_model_parallel_group()) def count_zeros(self): @@ -161,6 +190,167 @@ class MegatronOptimizer(ABC): param_groups = property(_get_param_groups, _set_param_groups) +class MixedPrecisionOptimizer(MegatronOptimizer): + """Base class for both the float-16 and the distributed optimizer. + + Arguments: + optimizer: base optimizer such as Adam or SGD + clip_grad: clip gradeints with this global L2 norm. Note + that clipping is ignored if clip_grad == 0 + log_num_zeros_in_grad: return number of zeros in the gradients. + params_have_main_grad: flag indicating if parameters have + a `main_grad` field. If this is set, we are assuming + that the model parameters are store in the `main_grad` + field instead of the typical `grad` field. This happens + for the DDP cases where there is a continuous buffer + holding the gradients. For example for bfloat16, we want + to do gradient accumulation and all-reduces in float32 + and as a result we store those gradients in the main_grad. + Note that main grad is not necessarily in float32. + use_contiguous_buffers_in_local_ddp: if true, the local DDP model + is using a contiguous buffer to hold the model grads. + fp16: if true, the model is running in fp16. + bf16: if true, the model is running in bfloat16. 
+ params_dtype: used by distributed optimizer. + grad_scaler: used for scaling gradients. Note that this can be + None. This case happens when `bf16 = True` and we don't + use any loss scale. Note that for `bf16 = True`, we can have + a constnat gradient scaler. Also for `bf16 = False`, we + always require a grad scaler. + models: list of models (i.e., the virtual pipelining models). This + is used by the distributed optimizer for mapping parameters. + """ + + def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad, + params_have_main_grad, use_contiguous_buffers_in_local_ddp, + fp16, bf16, params_dtype, grad_scaler, + models): + + super().__init__( + optimizer, clip_grad, log_num_zeros_in_grad, + params_have_main_grad, use_contiguous_buffers_in_local_ddp, + models) + + self.fp16 = fp16 + self.bf16 = bf16 + self.params_dtype = params_dtype + self.grad_scaler = grad_scaler + + # None grad scaler is only supported for bf16. + if self.grad_scaler is None: + assert not self.fp16, 'fp16 expects a grad scaler.' + + # Tensor used to determine if a nan/if has happend. + # Any non-zero value indicates inf/nan. + # Note that we keep this for the cases that grad scaler is none. + # We still record nan/inf if we have a bfloat16 with a grad scaler. + if self.grad_scaler: + self.found_inf = torch.cuda.FloatTensor([0.0]) + + # Dummy tensor needed for apex multi-apply tensor. + # For bfloat, we don't have multi-tensor apply and for now + # we set it to none so the multi-tensor apply gets ignored. + if bf16: + self._dummy_overflow_buf = None + else: + self._dummy_overflow_buf = torch.cuda.IntTensor([0]) + + # In case grad scaler is not passed, define the unity scale. + if self.grad_scaler is None: + self._scale_one = torch.cuda.FloatTensor([1.0]) + + + def get_loss_scale(self): + if self.grad_scaler is None: + return self._scale_one + return self.grad_scaler.scale + + + def reload_model_params(self): + self._copy_model_params_to_main_params() + + + def _unscale_main_grads_and_check_for_nan(self): + + # Collect main grads. + main_grads = self._collect_main_grad_data_for_unscaling() + + # Reset found inf. + self.found_inf.fill_(0.0) + + # Unscale and set found inf/nan + torch._amp_foreach_non_finite_check_and_unscale_( + main_grads, self.found_inf, self.grad_scaler.inv_scale) + + # Update across all model parallel instances. + torch.distributed.all_reduce(self.found_inf, + op=torch.distributed.ReduceOp.MAX, + group=self.get_model_parallel_group()) + + # Check for nan. + found_inf_flag = (self.found_inf.item() > 0) + + return found_inf_flag + + + @torch.no_grad() + def step(self, args, timers): + + # Copy gradients from model params to main params. + timers('optimizer-copy-to-main-grad', log_level=1).start( + barrier=args.barrier_with_L1_time) + self._copy_model_grads_to_main_grads() + timers('optimizer-copy-to-main-grad').stop() + + # Do unscale, check for inf, and update grad scaler only for + # the case that grad scaler is provided. + if self.grad_scaler: + + # Unscale and check for inf/nan. + timers('optimizer-unscale-and-check-inf', log_level=1).start( + barrier=args.barrier_with_L1_time) + found_inf_flag = self._unscale_main_grads_and_check_for_nan() + timers('optimizer-unscale-and-check-inf').stop() + + # We are done with scaling gradients + # so we can update the loss scale. + self.grad_scaler.update(found_inf_flag) + + # If we found inf/nan, skip the update. + if found_inf_flag: + return False, None, None + + # Clip the main gradients. 
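+        # Clipping below uses get_main_grads_for_grad_norm(), so the norm is
+        # computed only over non-shared, non-tensor-parallel-duplicated grads
+        # and is reduced across get_model_parallel_group() before the main
+        # grads are scaled.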
+ timers('optimizer-clip-main-grad', log_level=1).start( + barrier=args.barrier_with_L1_time) + grad_norm = None + if self.clip_grad > 0.0: + grad_norm = self.clip_grad_norm(self.clip_grad) + timers('optimizer-clip-main-grad').stop() + + # Count the zeros in the grads. + timers('optimizer-count-zeros', log_level=1).start( + barrier=args.barrier_with_L1_time) + num_zeros_in_grad = self.count_zeros() if \ + self.log_num_zeros_in_grad else None + timers('optimizer-count-zeros').stop() + + # Step the optimizer. + timers('optimizer-inner-step', log_level=1).start( + barrier=args.barrier_with_L1_time) + self.optimizer.step() + timers('optimizer-inner-step').stop() + + # Update params from main params. + timers('optimizer-copy-main-to-model-params', log_level=1).start( + barrier=args.barrier_with_L1_time) + self._copy_main_params_to_model_params() + timers('optimizer-copy-main-to-model-params').stop() + + # Successful update. + return True, grad_norm, num_zeros_in_grad + + class Float16OptimizerWithFloat16Params(MegatronOptimizer): """Float16 optimizer for fp16 and bf16 data types. diff --git a/megatron/timers.py b/megatron/timers.py new file mode 100644 index 0000000000000000000000000000000000000000..03810cbd20fd74a01693261b91a889ec481fba7a --- /dev/null +++ b/megatron/timers.py @@ -0,0 +1,283 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +"""Megatron timers.""" + +from abc import ABC +from abc import abstractmethod +import time + +import torch + + + +class TimerBase(ABC): + + def __init__(self, name): + self.name = name + + @abstractmethod + def start(self, barrier=False): + pass + + @abstractmethod + def stop(self, barrier=False): + pass + + @abstractmethod + def reset(self): + pass + + @abstractmethod + def elapsed(self, reset=True, barrier=False): + pass + + + +class DummyTimer(TimerBase): + + def __init__(self): + super().__init__('dummy timer') + + def start(self, barrier=False): + return + + def stop(self, barrier=False): + return + + def reset(self): + return + + def elapsed(self, reset=True, barrier=False): + raise Exception('dummy timer should not be used to ' + 'calculate elapsed time') + + + +class Timer(TimerBase): + """ + Comment on using `barrier`: If this flag is passed, then all + the caller processes will wait till all reach the timing routine. + It is up to the user to make sure all the ranks in `barrier_group` + call it otherwise, it will result in a hang. + Comment on `barrier_group`: By default it is set to None which + in torch distributed land, it will result in the global communicator. 
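+    Example usage (illustrative):
+        timer = Timer('forward')
+        timer.start()
+        ...  # timed region
+        timer.stop()
+        elapsed_seconds = timer.elapsed(reset=True)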
+ """ + + def __init__(self, name): + super().__init__(name) + self._elapsed = 0.0 + self._started = False + # Note that None will default to the global process group + self._barrier_group = None + self._start_time = time.time() + + + def set_barrier_group(self, barrier_group): + self._barrier_group = barrier_group + + + def start(self, barrier=False): + """Start the timer.""" + assert not self._started, 'timer has already been started' + if barrier: + torch.distributed.barrier(group=self._barrier_group) + torch.cuda.synchronize() + self._start_time = time.time() + self._started = True + + + def stop(self, barrier=False): + """Stop the timer.""" + assert self._started, 'timer is not started' + if barrier: + torch.distributed.barrier(group=self._barrier_group) + torch.cuda.synchronize() + self._elapsed += (time.time() - self._start_time) + self._started = False + + + def reset(self): + """Reset timer.""" + self._elapsed = 0.0 + self._started = False + + + def elapsed(self, reset=True, barrier=False): + """Calculate the elapsed time.""" + _started = self._started + # If the timing in progress, end it first. + if self._started: + self.stop(barrier=barrier) + # Get the elapsed time. + _elapsed = self._elapsed + # Reset the elapsed time + if reset: + self.reset() + # If timing was in progress, set it back. + if _started: + self.start(barrier=barrier) + return _elapsed + + + +class Timers: + """Group of timers.""" + + def __init__(self, log_level, log_option): + self._log_level = log_level + self._log_option = log_option + self._timers = {} + self._log_levels = {} + self._dummy_timer = DummyTimer() + self._max_log_level = 2 + + + def __call__(self, name, log_level=None): + if name not in self._timers: + self._timers[name] = Timer(name=name) + return self._timers[name] + + + def _get_elapsed_time_all_ranks(self, names, reset, barrier): + """ + Assumptions: + - All the ranks call this function. + - `names` are identical on all ranks. + If the above assumptions are not met, calling this function will + result in hang. + Arguments: + - names: list of timer names + - reset: reset the timer after recording the elapsed time + - barrier: if set, do a global barrier before time measurments + """ + + # First make sure all the callers are in sync. + if barrier: + torch.distributed.barrier() + + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + + # Here we can use gather on the rank we want to print the + # timing, however, there is no gather_base support in + # pytorch yet. It is simpler to deal with a single tensor + # and since we are only gathering a small amount of data, + # it should be ok to use all-gather instead of gather. + rank_name_to_time = torch.zeros((world_size, len(names)), + dtype=torch.float, + device=torch.cuda.current_device()) + for i, name in enumerate(names): + if name in self._timers: + # Here we don't need to pass the barrier flag as all + # the processes are already in sync. This avoids the + # issue of different timers having different barrier + # groups inside their class. + rank_name_to_time[rank, i] = self._timers[name].elapsed( + reset=reset) + + # See the note above for why we are not using gather. 
+ torch.distributed._all_gather_base(rank_name_to_time.view(-1), + rank_name_to_time[rank, :].view(-1)) + + return rank_name_to_time + + + def _get_global_min_max_time(self, names, reset, barrier, normalizer): + """Report only min and max times across all ranks.""" + + rank_name_to_time = self._get_elapsed_time_all_ranks(names, reset, + barrier) + name_to_min_max_time = {} + for i, name in enumerate(names): + rank_to_time = rank_name_to_time[:, i] + # filter out the ones we did not have any timings for + rank_to_time = rank_to_time[rank_to_time > 0.0] + # If the timer exists: + if rank_to_time.numel() > 0: + name_to_min_max_time[name] = ( + rank_to_time.min().item() / normalizer, + rank_to_time.max().item() / normalizer) + return name_to_min_max_time + + + def _get_global_min_max_time_string(self, names, reset, barrier, + normalizer, max_only): + name_to_min_max_time = self._get_global_min_max_time( + names, reset, barrier, normalizer) + if not name_to_min_max_time: + return None + output_string = '(min, max) time across ranks (ms):' + for name in name_to_min_max_time: + min_time, max_time = name_to_min_max_time[name] + if max_only: + output_string += '\n {}: {:.2f}'.format( + (name+' ').ljust(48, '.'), max_time) + else: + output_string += '\n {}: ({:.2f}, {:.2f})'.format( + (name+' ').ljust(48, '.'), min_time, max_time) + return output_string + + + def _get_all_ranks_time_string(self, names, reset, barrier, normalizer): + """Report times across all ranks.""" + rank_name_to_time = self._get_elapsed_time_all_ranks(names, reset, + barrier) + + output_string = 'times across ranks (ms):' + no_reported_timing = True + for i, name in enumerate(names): + not_yet_found = True + for rank in range(torch.distributed.get_world_size()): + if rank_name_to_time[rank, i] > 0: + no_reported_timing = False + if not_yet_found: + not_yet_found = False + output_string += '\n {}:'.format(name) + output_string += '\n rank {:2d}: {:.2f}'.format( + rank, rank_name_to_time[rank, i] / normalizer) + if no_reported_timing: + return None + return output_string + + + def log(self, names, rank=None, normalizer=1.0, reset=True, barrier=False): + """Log a group of timers.""" + + # Print. + assert normalizer > 0.0 + if self._log_option in ['max', 'minmax']: + max_only = False + if self._log_option == 'max': + max_only = True + output_string = self._get_global_min_max_time_string( + names, reset, barrier, normalizer/1000.0, max_only) + elif self._log_option == 'all': + output_string = self._get_all_ranks_time_string(names, + reset, barrier, + normalizer/1000.0) + else: + raise Exception('unknown timing log option {}'.format( + self._log_option)) + + # If no input rank is provided, log on last rank. + if rank is None: + rank = torch.distributed.get_world_size() - 1 + if rank == torch.distributed.get_rank() and output_string is not None: + print(output_string, flush=True) + + + def write(self, names, writer, iteration, normalizer=1.0, + reset=False, barrier=False): + """Write timers to a tensorboard writer + Note that we only report maximum time across ranks to tensorboard. 
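+        Each timer is written as one scalar named '<name>-time' holding the
+        max elapsed time across ranks, which keeps the runs list flat.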
+ """ + # currently when using add_scalars, + # torch.utils.add_scalars makes each timer its own run, which + # polutes the runs list, so we just add each as a scalar + assert normalizer > 0.0 + name_to_min_max_time = self._get_global_min_max_time( + names, reset, barrier, normalizer) + if writer is not None: + for name in name_to_min_max_time: + _, max_time = name_to_min_max_time[name] + writer.add_scalar(name + '-time', max_time, iteration) diff --git a/megatron/tokenizer/tokenizer.py b/megatron/tokenizer/tokenizer.py index 13085a81c9feb4f4ba2ffe3f254a3bd7d3f4c144..b260c247551ccacaf36a77fabd253360629ed996 100644 --- a/megatron/tokenizer/tokenizer.py +++ b/megatron/tokenizer/tokenizer.py @@ -17,7 +17,7 @@ from abc import ABC from abc import abstractmethod - +from transformers import AutoTokenizer from .bert_tokenization import FullTokenizer as FullBertTokenizer from .gpt2_tokenization import GPT2Tokenizer @@ -29,7 +29,7 @@ def build_tokenizer(args): flush=True) # Select and instantiate the tokenizer. - assert args.vocab_file is not None + assert args.vocab_file is not None or args.tokenizer_type == 'PretrainedFromHF' if args.tokenizer_type == 'BertWordPieceLowerCase': tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file, lower_case=True, @@ -41,6 +41,21 @@ def build_tokenizer(args): elif args.tokenizer_type == 'GPT2BPETokenizer': assert args.merge_file is not None tokenizer = _GPT2BPETokenizer(args.vocab_file, args.merge_file) + elif args.tokenizer_type == "PretrainedFromHF": + assert args.tokenizer_name_or_path is not None + + # prevent transformers from logging info and warnings on each rank + import transformers + import logging + if args.rank == 0: + transformers.utils.logging.set_verbosity(logging.INFO) + else: + # shut the warnings on replicas + transformers.utils.logging.set_verbosity(logging.ERROR) + + if args.rank == 0: + print(" vocab file is un-used. loading tokenizer from pre-trained model") + tokenizer = _AutoTokenizer(args.tokenizer_name_or_path, vocab_extra_ids=args.vocab_extra_ids) else: raise NotImplementedError('{} tokenizer is not ' 'implemented.'.format(args.tokenizer_type)) @@ -53,14 +68,25 @@ def build_tokenizer(args): def _vocab_size_with_padding(orig_vocab_size, args): - """Pad vocab size so it is divisible by model parallel size and - still having GPU friendly size.""" - - after = orig_vocab_size - multiple = args.make_vocab_size_divisible_by * \ - args.tensor_model_parallel_size - while (after % multiple) != 0: - after += 1 + """Apply the requested rules to change the size of the vocabulary""" + if args.pad_vocab_size_to is not None: + if args.pad_vocab_size_to < orig_vocab_size: + raise ValueError( + f"You asked to pad the vocabulary to {args.pad_vocab_size_to} when the initial vocabulary size is " + f"{orig_vocab_size}. You can only pad to a higher value." + ) + + if args.make_vocab_size_divisible_by is not None and (args.pad_vocab_size_to % args.make_vocab_size_divisible_by) != 0: + raise ValueError(f"{args.pad_vocab_size_to} is not divisible by {args.make_vocab_size_divisible_by}") + + after = args.pad_vocab_size_to + else: + # Pad vocab size so it is divisible by model parallel size and still having GPU friendly size. 
+        after = orig_vocab_size
+        multiple = args.make_vocab_size_divisible_by * \
+                   args.tensor_model_parallel_size
+        while (after % multiple) != 0:
+            after += 1
     if args.rank == 0:
         print(' > padded vocab (size: {}) with {} dummy tokens '
               '(new size: {})'.format(
@@ -289,3 +315,86 @@ class _GPT2BPETokenizer(AbstractTokenizer):
     @property
     def eod(self):
         return self.eod_id
+
+
+class _AutoTokenizer(AbstractTokenizer):
+    """AutoTokenizer for Hf Pretrained model loading."""
+
+    def __init__(self, tokenizer_name_or_path, vocab_extra_ids):
+        name = tokenizer_name_or_path
+        super().__init__(name)
+        hf_tokenizer_kwargs = {}
+        if vocab_extra_ids > 0:
+            # TODO @thomasw21 we might need to concatenate to a pre-existing list?
+            hf_tokenizer_kwargs["additional_special_tokens"] = [f"" for _id in range(vocab_extra_ids)]
+        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, **hf_tokenizer_kwargs)
+        self.encoder = self.tokenizer.get_vocab()
+        self.decoder = {v: k for k, v in self.encoder.items()}
+
+    @property
+    def vocab_size(self):
+        return len(self.tokenizer)  # vocab_size doesn't contain additional tokens
+
+    @property
+    def vocab(self):
+        # TODO @thomasw21 make sure that special tokens don't collapse with vocab tokens.
+        return {
+            **{special_token: self.tokenizer.convert_tokens_to_ids(special_token) for special_token in self.tokenizer.additional_special_tokens},
+            **self.tokenizer.vocab,
+        }
+
+    @property
+    def inv_vocab(self):
+        return {v: k for k, v in self.vocab.items()}
+
+    def tokenize(self, text):
+        return self.tokenizer.encode(text)
+
+    def detokenize(self, token_ids):
+        return self.tokenizer.decode(token_ids)
+
+    @property
+    def eod(self):
+        # TODO @thomasw21 might conflict with
+        return self.eos
+
+    @property
+    def cls(self):
+        candidate = self.tokenizer.cls_token_id
+        return self._check_token_candidate(candidate)
+
+    @property
+    def sep(self):
+        candidate = self.tokenizer.sep_token_id
+        return self._check_token_candidate(candidate)
+
+    @property
+    def pad(self):
+        candidate = self.tokenizer.pad_token_id
+        return self._check_token_candidate(candidate)
+
+    @property
+    def mask(self):
+        candidate = self.tokenizer.mask_token_id
+        return self._check_token_candidate(candidate)
+
+    @property
+    def bos(self):
+        raise NotImplementedError("Missing ")
+
+    @property
+    def eos(self):
+        # TODO @thomasw21 might conflict with the notion of
+        candidate = self.tokenizer.eos_token_id
+        return self._check_token_candidate(candidate)
+
+    @property
+    def additional_special_tokens_ids(self):
+        """ All the additional special tokens you may want to use (list of strings)."""
+        return self.tokenizer.additional_special_tokens_ids
+
+    @staticmethod
+    def _check_token_candidate(candidate):
+        if candidate is None:
+            raise AttributeError("Token doesn't exist")
+        return candidate
\ No newline at end of file
diff --git a/megatron_npu/.gitignore b/megatron_npu/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..afb27ab2f6bb1bbcf04fd1939bc457970c50f542
--- /dev/null
+++ b/megatron_npu/.gitignore
@@ -0,0 +1,3 @@
+__pycache__
+.idea
+*tmp*
\ No newline at end of file
diff --git a/megatron_npu/README.md b/megatron_npu/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..817c101e540af967f64259c55cc1079009c72e5e
--- /dev/null
+++ b/megatron_npu/README.md
@@ -0,0 +1,139 @@
+# Megatron-LM
+
+## Introduction
+
+Megatron is a powerful large-scale Transformer training repository developed by NVIDIA's Applied Deep Learning Research team. This repository is the Ascend adaptation of the original GitHub repository; the following features have already been adapted:
+
+- Data parallelism (Data parallel)
+- Tensor model parallelism (Tensor parallel)
+- Sequence parallelism (Sequence parallel)
+- Pipeline parallelism (Pipeline parallel)
+- Distributed optimizer
+
+## Preparing the Environment
+
+- The firmware/driver, CANN, and PyTorch versions supported by this model are listed in the table below.
+
+  **Table 1** Version matrix
+
+| Component | Version |
+| ----- | ----- |
+| Firmware & driver | [22.0.RC3](https://www.hiascend.com/hardware/firmware-drivers?tag=commercial) |
+| CANN | [6.1.RC1](https://www.hiascend.com/software/cann/commercial?version=6.1.RC1) |
+| PyTorch | [1.11](https://gitee.com/ascend/pytorch/tree/master/) |
+
+- Environment preparation guide.
+
+  See [Preparing the PyTorch Framework Training Environment](https://www.hiascend.com/document/detail/zh/ModelZoo/pytorchframework/ptes).
+
+- Clone the original repository
+  ```
+  git clone https://github.com/NVIDIA/Megatron-LM.git
+  cd Megatron-LM
+  git checkout 285068c8108e0e8e6538f54fe27c3ee86c5217a2
+  ```
+
+- Download and install Megatron_npu
+  ```
+  git clone https://gitee.com/ascend/Megatron-LM.git -b dev megatron_npu
+  cd megatron_npu
+  pip install -e .
+  ```
+
+- Install dependencies (add the dependencies your model actually needs).
+  ```
+  pip install -r requirements.txt
+  ```
+
+## Preparing the Dataset
+
+1. Obtain the dataset.
+
+   ```bash ./tests/dataset_preprocess_t5.sh```
+
+2. Dataset directory structure
+   By default the dataset is placed under ```./dataset/en_wiki/preprocess/```, with the following layout:
+
+   ```
+   ├── ./dataset/en_wiki/preprocess/
+         ├── bert-large-uncased-vocab.txt
+         ├── my-t5_text_sentence.bin
+         ├── my-t5_text_sentence.idx
+   ```
+
+> **Note:**
+> The training scripts for this dataset are provided only as a reference example.
+
+## Obtaining a Pretrained Model (Optional)
+
+- Not applicable to this model.
+
+## Running the Unit Tests (Optional)
+
+```
+bash tests/test.sh
+```
+
+# Training
+
+## Training the Model
+
+1. Import the megatron_npu package before running the model.
+   ```
+   import megatron_npu
+   ```
+
+2. Enter the root directory of the extracted source package.
+
+   ```
+   cd ./${模型文件夹名称}
+   ```
+
+3. Run the training script.
+
+   The model supports single-node single-card and single-node 8-card training.
+
+   - Single-node 8-card training
+
+     Launch 8-card training.
+
+     ```
+     bash ./tests/train_full_8p.sh
+     ```
+
+   After training completes, checkpoint files are saved under ./checkpoint, and training accuracy and performance information is printed.
+
+# Training Results
+
+**Table 2** Training results
+
+| NAME | PPL | samples/s | Steps |
+| ------- | ----- |----------:| ------ |
+| 8p-competitor A | 8.688 | 232 | 100000 |
+| 8p-NPU | 8.701 | 100 | 100000 |
+| 32p-NPU | 6.319 | 393 | 100000 |
+
+Note: results must be reported for both the competitor device and the NPU.
+
+# Release Notes
+
+## Changes
+
+2022.08.26: Initial release.
+
+## Known Issues
+
+**_Description of issues present in the current release._**
+
+None.
+
+
+
+
+
+
+
+
+
+
+
diff --git a/megatron_npu/README_ez.md b/megatron_npu/README_ez.md
new file mode 100644
index 0000000000000000000000000000000000000000..f9b3559dcce10f11e4b5fc2f5c502f56fb0b8ae4
--- /dev/null
+++ b/megatron_npu/README_ez.md
@@ -0,0 +1,44 @@
+# Megatron-LM
+
+## ENV
+- github https://github.com/NVIDIA/Megatron-LM.git
+- commit_id 0bb597b42c53355a567aba2a1357cc34b9d99ddd
+
+## UT test
+
+- script: ```bash test/test.sh```
+
+- ut_list: (megatron/mpu/tests)
+  - [x] test_cross_entropy.py
+  - [x] test_data.py
+  - [x] test_initialize.py
+  - [x] test_layers.py
+  - [ ] test_random.py
+
+- note
+  - get_rng_state does not work yet, [issue link](https://gitee.com/ascend/pytorch-develop/issues/I50ZH1?from=project-issue)
+  - setting the seed on NPU does not work yet, [issue link](https://gitee.com/ascend/pytorch-develop/issues/I50ZLR?from=project-issue)
+
+## model support plan
+
+### T5
+
+#### func
+
+1. [] pretrain_t5.sh
+   - ```bash test/pretrain_t5.sh```
+2. [ ] pretrain_t5_distributed_with_mp.sh
+   - ```bash test/pretrain_t5_distributed_with_mp.sh```
+3. [ ] pretrain_t5_xxB.sh
+4. [ ] pretrain_t5_xxxB.sh
+
+
+### GPT-3
+
+#### func
+
+1. [ ] pretrain_gpt.sh
+2. [ ] pretrain_gpt_distributed.sh
+3. [ ] pretrain_gpt_distributed_with_mp.sh
+4.
[ ] pretrain_gpt3_175B.sh + diff --git a/megatron_npu/megatron_npu/__init__.py b/megatron_npu/megatron_npu/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1a86f23fafe8580329545f63701855fe2f9adfec --- /dev/null +++ b/megatron_npu/megatron_npu/__init__.py @@ -0,0 +1,60 @@ +import os +import copy +import sys + +import torch +import torch_npu +from functools import wraps +from torch_npu.contrib import transfer_to_npu +from . import adaptor_amp_c + +if 'amp_C' in sys.modules: + del sys.modules['amp_C'] +sys.modules['amp_C'] = __import__('megatron_npu.adaptor_amp_c') + +from . import adaptor_core_tensor_parallel +from . import adaptor_core_utils +from . import adaptor_data_gpt_dataset +from . import adaptor_initialize +from . import adaptor_model_fused_layer_norm +from . import adaptor_model_fused_softmax +from . import adaptor_model_module +from . import adaptor_optimizer_clip_grads +from . import adaptor_optimizer_distrib_optimizer +from . import adaptor_optimizer_optimizer +from . import adaptor_p2p_communication +from . import adaptor_schedules + + +def wrapper_type(fn): + @wraps(fn) + def decorated(*args, **kwargs): + output = fn(*args, **kwargs) + if isinstance(output, str): + if output == 'torch.npu.FloatTensor': + output = 'torch.cuda.FloatTensor' + elif output == 'torch.npu.HalfTensor': + output = 'torch.cuda.HalfTensor' + return output + + return decorated + + +# deprecated +def wrapper_dist(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + if args[0].dtype == torch.long and not kwargs.get('async_op', False): + new_args = list(copy.deepcopy(args)) + new_args[0] = new_args[0].int() + fn(*new_args, **kwargs) + args[0].copy_(new_args[0].long()) + return + return fn(*args, **kwargs) + + return wrapper + + +os.environ['CUDA_DEVICE_MAX_CONNECTIONS'] = '1' +torch.Tensor.type = wrapper_type(torch.Tensor.type) +torch.distributed.all_reduce = wrapper_dist(torch.distributed.all_reduce) diff --git a/megatron_npu/megatron_npu/adaptor_amp_c.py b/megatron_npu/megatron_npu/adaptor_amp_c.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/megatron_npu/megatron_npu/adaptor_core_tensor_parallel.py b/megatron_npu/megatron_npu/adaptor_core_tensor_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..d5908c192f8e0de17a78b13af1f55335da3a68fd --- /dev/null +++ b/megatron_npu/megatron_npu/adaptor_core_tensor_parallel.py @@ -0,0 +1,34 @@ +import torch +import megatron +#print("-----------------------", megatron.__file__, dir(megatron)) +from torch import _C +from torch_npu.npu import _lazy_call, device as device_ctx_manager + + +def _set_cuda_rng_state(new_state, device=-1): + if hasattr(_C, '_cuda_setRNGState') and callable(_C._cuda_setRNGState): + # older PyTorch + def cb(): + with device_ctx_manager(device): + _C._cuda_setRNGState(new_state) + else: + # newer PyTorch + if device == -1: + device = torch.device('cuda') + elif isinstance(device, str): + device = torch.device(device) + elif isinstance(device, int): + device = torch.device('cuda', device) + + def cb(): + idx = device.index + if idx is None: + idx = torch.cuda.current_device() + default_generator = torch.npu.default_generators[idx] + default_generator.set_state(new_state) + + _lazy_call(cb) + + +#megatron.core.tensor_parallel.random._set_cuda_rng_state = _set_cuda_rng_state +megatron.mpu.random._set_cuda_rng_state = _set_cuda_rng_state diff --git a/megatron_npu/megatron_npu/adaptor_core_utils.py 
b/megatron_npu/megatron_npu/adaptor_core_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a0523dcb997239d0c143e011f021d47113601de2 --- /dev/null +++ b/megatron_npu/megatron_npu/adaptor_core_utils.py @@ -0,0 +1,12 @@ +import torch +import megatron + + +def _kernel_make_viewless_tensor(inp, requires_grad): + out = torch.empty((1,), dtype=inp.dtype, device=inp.device, requires_grad=requires_grad, ) + with torch.no_grad(): + out.set_(inp.data) + return out + + +megatron.mpu.utils._kernel_make_viewless_tensor = _kernel_make_viewless_tensor diff --git a/megatron_npu/megatron_npu/adaptor_data_gpt_dataset.py b/megatron_npu/megatron_npu/adaptor_data_gpt_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..7710a29e96d6a033a69bcdd08d5a3b17cfef0e33 --- /dev/null +++ b/megatron_npu/megatron_npu/adaptor_data_gpt_dataset.py @@ -0,0 +1,147 @@ +import os +import time +import numpy as np +import torch +import megatron +from megatron import print_rank_0, get_args +from megatron import mpu +from megatron.data.gpt_dataset import _num_tokens, _num_epochs, _build_shuffle_idx, _build_doc_idx + + +def _build_index_mappings(name, data_prefix, documents, sizes, + num_samples, seq_length, seed): + """Build doc-idx, sample-idx, and shuffle-idx. + doc-idx: is an array (ordered) of documents to be used in training. + sample-idx: is the start document index and document offset for each + training sample. + shuffle-idx: maps the sample index into a random index into sample-idx. + """ + # Number of tokens in each epoch and number of required epochs. + tokens_per_epoch = _num_tokens(documents, sizes) + num_epochs = _num_epochs(tokens_per_epoch, seq_length, num_samples) + # rng state + np_rng = np.random.RandomState(seed=seed) + + # Filename of the index mappings. + _filename = data_prefix + _filename += '_{}_indexmap'.format(name) + _filename += '_{}ns'.format(num_samples) + _filename += '_{}sl'.format(seq_length) + _filename += '_{}s'.format(seed) + doc_idx_filename = _filename + '_doc_idx.npy' + sample_idx_filename = _filename + '_sample_idx.npy' + shuffle_idx_filename = _filename + '_shuffle_idx.npy' + + # Build the indexed mapping if not exist. + if int(os.environ['LOCAL_RANK']) == 0: + if (not os.path.isfile(doc_idx_filename)) or \ + (not os.path.isfile(sample_idx_filename)) or \ + (not os.path.isfile(shuffle_idx_filename)): + + print_rank_0(' > WARNING: could not find index map files, building ' + 'the indices on rank 0 ...') + + # For the last epoch, decide whether include the entire epoch + # in the global shuffle or not. + + # If we need only one epoch, then separating last epoch does + # not mean anything. + if num_epochs == 1: + separate_last_epoch = False + print(' > only one epoch required, setting ' + 'separate_last_epoch to False', flush=True) + + else: + # Get the number of samples for the last epoch + num_samples_from_epochs_minus_one = ( + (num_epochs - 1) * tokens_per_epoch - 1) // seq_length + last_epoch_num_samples = num_samples - \ + num_samples_from_epochs_minus_one + assert last_epoch_num_samples >= 0, \ + 'last epoch number of samples should be non-negative.' + num_samples_per_epoch = (tokens_per_epoch - 1) // seq_length + assert last_epoch_num_samples < (num_samples_per_epoch + 1), \ + 'last epoch number of samples exceeded max value.' + # If we have less than 80% of the samples for the last epoch, + # seperate out the epoch and treat it differently. + # Note: the 80% number is just based on common sense and can + # be adjusted if needed. 
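+                # Illustrative example: if num_samples_per_epoch is 1000 and
+                # last_epoch_num_samples is 700, then 700 < 800 and
+                # separate_last_epoch is set to True below.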
+ separate_last_epoch = (last_epoch_num_samples < + int(0.80 * num_samples_per_epoch)) + if separate_last_epoch: + string = ' > last epoch number of samples ({}) is smaller ' \ + 'than 80% of number of samples per epoch ({}), ' \ + 'setting separate_last_epoch to True' + else: + string = ' > last epoch number of samples ({}) is larger ' \ + 'than 80% of number of samples per epoch ({}), ' \ + 'setting separate_last_epoch to False' + print(string.format(last_epoch_num_samples, + num_samples_per_epoch), flush=True) + + # doc-idx. + start_time = time.time() + doc_idx = _build_doc_idx(documents, num_epochs, np_rng, + separate_last_epoch) + np.save(doc_idx_filename, doc_idx, allow_pickle=True) + print_rank_0(' > elasped time to build and save doc-idx mapping ' + '(seconds): {:4f}'.format(time.time() - start_time)) + # sample-idx. + start_time = time.time() + # Use C++ implementation for speed. + # First compile and then import. + from megatron.data import helpers + assert doc_idx.dtype == np.int32 + assert sizes.dtype == np.int32 + sample_idx = helpers.build_sample_idx(sizes, doc_idx, seq_length, + num_epochs, tokens_per_epoch) + # sample_idx = _build_sample_idx(sizes, doc_idx, seq_length, + # num_epochs, tokens_per_epoch) + np.save(sample_idx_filename, sample_idx, allow_pickle=True) + print_rank_0(' > elasped time to build and save sample-idx mapping ' + '(seconds): {:4f}'.format(time.time() - start_time)) + # shuffle-idx. + start_time = time.time() + # -1 is due to data structure used to retieve the index: + # sample i --> [sample_idx[i], sample_idx[i+1]) + if separate_last_epoch: + num_samples_ = num_samples_from_epochs_minus_one + else: + num_samples_ = sample_idx.shape[0] - 1 + shuffle_idx = _build_shuffle_idx(num_samples_, + sample_idx.shape[0] - 1, np_rng) + np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True) + print_rank_0(' > elasped time to build and save shuffle-idx mapping' + ' (seconds): {:4f}'.format(time.time() - start_time)) + + # This should be a barrier but nccl barrier assumes + # device_index=rank which is not the case for model + # parallel case + counts = torch.cuda.LongTensor([1]) + torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group()) + torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group()) + assert counts[0].item() == ( + torch.distributed.get_world_size() // + torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group())) + + # Load mappings. 
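+    # The three .npy index files are opened with mmap_mode='r' below, so the
+    # doc/sample/shuffle mappings are memory-mapped instead of being read
+    # fully into host memory on every rank.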
+ start_time = time.time() + print_rank_0(' > loading doc-idx mapping from {}'.format( + doc_idx_filename)) + doc_idx = np.load(doc_idx_filename, allow_pickle=True, mmap_mode='r') + print_rank_0(' > loading sample-idx mapping from {}'.format( + sample_idx_filename)) + sample_idx = np.load(sample_idx_filename, allow_pickle=True, mmap_mode='r') + print_rank_0(' > loading shuffle-idx mapping from {}'.format( + shuffle_idx_filename)) + shuffle_idx = np.load(shuffle_idx_filename, allow_pickle=True, mmap_mode='r') + print_rank_0(' loaded indexed file in {:3.3f} seconds'.format( + time.time() - start_time)) + print_rank_0(' total number of samples: {}'.format( + sample_idx.shape[0])) + print_rank_0(' total number of epochs: {}'.format(num_epochs)) + + return doc_idx, sample_idx, shuffle_idx + + +megatron.data.gpt_dataset._build_index_mappings = _build_index_mappings diff --git a/megatron_npu/megatron_npu/adaptor_initialize.py b/megatron_npu/megatron_npu/adaptor_initialize.py new file mode 100644 index 0000000000000000000000000000000000000000..f2cd4b8072c1158fcc6917d7936a7fb3afc0987c --- /dev/null +++ b/megatron_npu/megatron_npu/adaptor_initialize.py @@ -0,0 +1,44 @@ +import sys +import time +import torch +import megatron +from megatron.initialize import _warmup_jit_function + +def _compile_dependencies(): + if torch.distributed.get_rank() == 0: + start_time = time.time() + print('> compiling dataset index builder ...') + from megatron.data.dataset_utils import compile_helper + compile_helper() + print('>>> done with dataset index builder. Compilation time: {:.3f} ' + 'seconds'.format(time.time() - start_time), flush=True) + + +def set_jit_fusion_options(): + """Set PyTorch JIT layer fusion options.""" + # flags required to enable jit fusion kernels + TORCH_MAJOR = int(torch.__version__.split('.')[0]) + TORCH_MINOR = int(torch.__version__.split('.')[1]) + if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10): + # nvfuser + torch._C._jit_set_profiling_executor(True) + torch._C._jit_set_profiling_mode(True) + torch._C._jit_override_can_fuse_on_cpu(False) + torch._C._jit_override_can_fuse_on_gpu(False) + torch._C._jit_set_texpr_fuser_enabled(False) + # torch._C._jit_set_nvfuser_enabled(True) + torch._C._debug_set_autodiff_subgraph_inlining(False) + else: + # legacy pytorch fuser + torch._C._jit_set_profiling_mode(False) + torch._C._jit_set_profiling_executor(False) + torch._C._jit_override_can_fuse_on_cpu(True) + torch._C._jit_override_can_fuse_on_gpu(True) + _warmup_jit_function() + + +megatron.initialize._compile_dependencies = _compile_dependencies + +for k, v in sys.modules.items(): + if 'megatron' in k and hasattr(v, 'set_jit_fusion_options'): + setattr(v, 'set_jit_fusion_options', set_jit_fusion_options) diff --git a/megatron_npu/megatron_npu/adaptor_model_fused_layer_norm.py b/megatron_npu/megatron_npu/adaptor_model_fused_layer_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..50f365b3cfd3ebccfb0610502c5a937f475acf5b --- /dev/null +++ b/megatron_npu/megatron_npu/adaptor_model_fused_layer_norm.py @@ -0,0 +1,35 @@ +import torch +import numbers +import megatron +from megatron.model.fused_layer_norm import MixedFusedLayerNorm, HAVE_PERSIST_LAYER_NORM +from megatron.mpu.utils import make_viewless_tensor + + +def MixedFusedLayerNormInit(self, normalized_shape, eps=1e-5, no_persist_layer_norm=True, sequence_parallel=False): + super(MixedFusedLayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + 
self.normalized_shape = torch.Size(normalized_shape) + self.eps = eps + self.weight = torch.nn.parameter.Parameter(torch.Tensor(*normalized_shape)) + self.bias = torch.nn.parameter.Parameter(torch.Tensor(*normalized_shape)) + self.reset_parameters() + self.no_persist_layer_norm = True + self.sequence_parallel = sequence_parallel + + # set sequence parallelism flag on weight and bias parameters + setattr(self.weight, 'sequence_parallel', self.sequence_parallel) + setattr(self.bias, 'sequence_parallel', self.sequence_parallel) + + +def MixedFusedLayerNormForward(self, input): + if self.no_persist_layer_norm: + return torch.nn.functional.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps) + else: + output = FastLayerNormFN.apply(input, self.weight, self.bias, self.eps) + output = make_viewless_tensor(inp=output, requires_grad=input.requires_grad, keep_graph=True) + return output + + +megatron.model.fused_layer_norm.MixedFusedLayerNorm.__init__ = MixedFusedLayerNormInit +megatron.model.fused_layer_norm.MixedFusedLayerNorm.forward = MixedFusedLayerNormForward diff --git a/megatron_npu/megatron_npu/adaptor_model_fused_softmax.py b/megatron_npu/megatron_npu/adaptor_model_fused_softmax.py new file mode 100644 index 0000000000000000000000000000000000000000..79f7dffd63416d438222576f9175a1856e051aff --- /dev/null +++ b/megatron_npu/megatron_npu/adaptor_model_fused_softmax.py @@ -0,0 +1,21 @@ +import torch +import torch_npu +import megatron + + +def is_kernel_available(self, mask, b, np, sq, sk): + return ( + self.scaled_masked_softmax_fusion # user want to fuse + and self.input_in_float16 # input must be fp16 + and 32 < sk <= 2048 # sk must be 32 ~ 2048 + and sq % 16 == 0 # sq must be divisor of 16 + and sk % 16 == 0 # sk must be divisor of 16 + ) + + +def forward_fused_softmax(self, input, mask): + return torch_npu.npu_scaled_masked_softmax(input, mask, self.scale, False) + + +megatron.model.fused_softmax.FusedScaleMaskSoftmax.is_kernel_available = is_kernel_available +megatron.model.fused_softmax.FusedScaleMaskSoftmax.forward_fused_softmax = forward_fused_softmax diff --git a/megatron_npu/megatron_npu/adaptor_model_module.py b/megatron_npu/megatron_npu/adaptor_model_module.py new file mode 100644 index 0000000000000000000000000000000000000000..38b6d794770a1559a83d9923ee3c88f436d587b4 --- /dev/null +++ b/megatron_npu/megatron_npu/adaptor_model_module.py @@ -0,0 +1,31 @@ +import torch +import megatron +from megatron.model.module import conversion_helper + + +def fp32_to_float16(val, float16_convertor): + def half_conversion(val): + val_typecheck = val + if isinstance(val_typecheck, (torch.nn.parameter.Parameter, torch.autograd.Variable)): + val_typecheck = val.data + if val_typecheck.dtype == torch.float32: + val = float16_convertor(val) + return val + + return conversion_helper(val, half_conversion) + + +def float16_to_fp32(val): + def float_conversion(val): + val_typecheck = val + if isinstance(val_typecheck, (torch.nn.parameter.Parameter, torch.autograd.Variable)): + val_typecheck = val.data + if val_typecheck.dtype in [torch.float16, torch.bfloat16]: + val = val.float() + return val + + return conversion_helper(val, float_conversion) + + +megatron.model.module.fp32_to_float16 = fp32_to_float16 +megatron.model.module.float16_to_fp32 = float16_to_fp32 diff --git a/megatron_npu/megatron_npu/adaptor_optimizer_clip_grads.py b/megatron_npu/megatron_npu/adaptor_optimizer_clip_grads.py new file mode 100644 index 
0000000000000000000000000000000000000000..8d9f6adeb4843e42b0bda446d3a63a33233a037f --- /dev/null +++ b/megatron_npu/megatron_npu/adaptor_optimizer_clip_grads.py @@ -0,0 +1,51 @@ +import sys +import math +import torch +import megatron.optimizer + + +def clip_grad_norm_fp32(parameters, grads_for_norm, max_norm, norm_type=2, model_parallel_group=None): + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + if isinstance(grads_for_norm, torch.Tensor): + grads_for_norm = [grads_for_norm] + + # Grads. + grads = [] + for param in parameters: + if param.grad is not None: + assert param.grad.type() == 'torch.cuda.FloatTensor' + grads.append(param.grad.detach()) + + # Norm parameters. + max_norm = float(max_norm) + norm_type = float(norm_type) + total_norm = 0.0 + + # Calculate norm. + if norm_type == math.inf: + total_norm = max(grad.abs().max() for grad in grads_for_norm) + total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) + # Take max across all model-parallel GPUs. + torch.distributed.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=model_parallel_group) + total_norm = total_norm_cuda[0].item() + else: + for grad in grads_for_norm: + grad_norm = torch.norm(grad, norm_type) + total_norm += grad_norm ** norm_type + + # Sum across all model-parallel GPUs. + torch.distributed.all_reduce(total_norm, op=torch.distributed.ReduceOp.SUM, group=model_parallel_group) + total_norm = total_norm.item() ** (1.0 / norm_type) + + # Scale. + clip_coeff = max_norm / (total_norm + 1.0e-6) + if clip_coeff < 1.0: + for p in parameters: + p.grad.detach().mul_(clip_coeff) + return total_norm + + +for k, v in sys.modules.items(): + if 'megatron' in k and hasattr(v, 'clip_grad_norm_fp32'): + setattr(v, 'clip_grad_norm_fp32', clip_grad_norm_fp32) diff --git a/megatron_npu/megatron_npu/adaptor_optimizer_distrib_optimizer.py b/megatron_npu/megatron_npu/adaptor_optimizer_distrib_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..fed4412f9191781a0df574d979879f17eaa862ad --- /dev/null +++ b/megatron_npu/megatron_npu/adaptor_optimizer_distrib_optimizer.py @@ -0,0 +1,63 @@ +import torch +import megatron.optimizer + + +def DistributedOptimizerInit(self, optimizer, clip_grad, log_num_zeros_in_grad, params_have_main_grad, + use_contiguous_buffers_in_local_ddp, fp16, bf16, params_dtype, grad_scaler, models): + super(megatron.optimizer.distrib_optimizer.DistributedOptimizer, self).__init__( + optimizer, clip_grad, log_num_zeros_in_grad, + params_have_main_grad, use_contiguous_buffers_in_local_ddp, + fp16, bf16, params_dtype, grad_scaler, models) + + # Verify that contiguous buffers are being used. + # - Note: this should already be checked in arguments.py. + assert use_contiguous_buffers_in_local_ddp + + # Model grad buffer ranges. + self.model_gbuf_ranges = [] + for model_index, model in enumerate(self.models): + self.model_gbuf_ranges.append(self.build_model_gbuf_range_map(model)) + self.model_param_gbuf_map = \ + self.build_model_param_gbuf_map(self.model_gbuf_ranges) + + # Optimizer ranges. + self.opt_group_ranges = self.build_optimizer_group_ranges( + self.optimizer.param_groups, + self.model_gbuf_ranges) + + # Allocate main param shards. + ( + self.model_float16_groups, + self.model_fp32_groups, + self.shard_float16_groups, + self.shard_fp32_groups, + self.shard_fp32_from_float16_groups, + ) = self.build_model_and_main_param_groups(self.model_gbuf_ranges, + self.model_param_gbuf_map, + self.opt_group_ranges) + + # Initialize param buffers. 
+ # - These are views on the DDP model's grad buffers, that share + # storage & have their own dtype. This is safe because the param + # dtype size is always <= grad dtype size. + self.param_buffers = [] + for model_index, model in enumerate(self.models): + current_param_buffers = {} + for dtype, grad_buffer in model._grad_buffers.items(): + param_buffer = torch.tensor(torch.flatten(grad_buffer.data), # grad_buffer.data.storage()._untyped(), + dtype=params_dtype, + device=grad_buffer.data.device) + param_buffer = param_buffer[:grad_buffer.numel_padded] + current_param_buffers[dtype] = param_buffer + self.param_buffers.append(current_param_buffers) + + # Update optimizer groups. + # - Also, leverage state_dict() and load_state_dict() to + # recast preexisting per-param state tensors. + self.optimizer.param_groups = \ + [g["orig_group"] for g in self.opt_group_ranges] + self.optimizer.load_state_dict(self.optimizer.state_dict()) + + +# no need for BLOOM +#megatron.optimizer.distrib_optimizer.DistributedOptimizer.__init__ = DistributedOptimizerInit diff --git a/megatron_npu/megatron_npu/adaptor_optimizer_optimizer.py b/megatron_npu/megatron_npu/adaptor_optimizer_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..41b37785852b6e13feee8d2980d6b22aafbc55ef --- /dev/null +++ b/megatron_npu/megatron_npu/adaptor_optimizer_optimizer.py @@ -0,0 +1,268 @@ +import sys +import torch +import megatron.optimizer +from megatron import mpu + +import math +import torch +import torch_npu +from torch import Tensor +from typing import List, Optional + + +from torch.optim.optimizer import Optimizer + + +def adamw_torch(params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[int], + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + maximize: bool): + r"""Functional API that performs AdamW algorithm computation. + See :class:`~torch.optim.AdamW` for details. + """ + for i, param in enumerate(params): + grad = grads[i] if not maximize else -grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step = state_steps[i] + + # Perform stepweight decay + param.mul_(1 - lr * weight_decay) + + bias_correction1 = 1 - beta1 ** step + bias_correction2 = 1 - beta2 ** step + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + if amsgrad: + # Maintains the maximum of all 2nd moment running avg. till now + torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i]) + # Use the max. for normalizing running avg. of gradient + denom = (max_exp_avg_sqs[i].sqrt() / math.sqrt(bias_correction2)).add_(eps) + else: + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps) + + step_size = lr / bias_correction1 + + param.addcdiv_(exp_avg, denom, value=-step_size) + + +def adamw(params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[int], + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + maximize: bool): + r"""Functional API that performs AdamW algorithm computation. + See :class:`~torch.optim.AdamW` for details. 
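+    Unlike the reference adamw_torch() implementation above, this variant
+    delegates the fused parameter and moment update to
+    torch_npu.npu_apply_adam_w.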
+ """ + for i, param in enumerate(params): + grad = grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step = state_steps[i] + + # Perform stepweight decay + ## param.mul_(1 - lr * weight_decay) + bias_correction1 = beta1 ** step + bias_correction2 = beta2 ** step + + param.data, exp_avg, exp_avg_sq = torch_npu.npu_apply_adam_w( + bias_correction1, + bias_correction2, + lr, + weight_decay, + beta1, + beta2, + eps, + grad, + None, + amsgrad, + maximize, + out=(param.data, exp_avg, exp_avg_sq) + ) + +class AdamW(Optimizer): + r"""Implements AdamW algorithm. + .. math:: + \begin{aligned} + &\rule{110mm}{0.4pt} \\ + &\textbf{input} : \gamma \text{(lr)}, \: \beta_1, \beta_2 + \text{(betas)}, \: \theta_0 \text{(params)}, \: f(\theta) \text{(objective)}, + \: \epsilon \text{ (epsilon)} \\ + &\hspace{13mm} \lambda \text{(weight decay)}, \: \textit{amsgrad}, + \: \textit{maximize} \\ + &\textbf{initialize} : m_0 \leftarrow 0 \text{ (first moment)}, v_0 \leftarrow 0 + \text{ ( second moment)}, \: \widehat{v_0}^{max}\leftarrow 0 \\[-1.ex] + &\rule{110mm}{0.4pt} \\ + &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ + &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\ + &\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm}\textbf{else} \\ + &\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm} \theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1} \\ + &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\ + &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\ + &\hspace{5mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\ + &\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\ + &\hspace{5mm}\textbf{if} \: amsgrad \\ + &\hspace{10mm}\widehat{v_t}^{max} \leftarrow \mathrm{max}(\widehat{v_t}^{max}, + \widehat{v_t}) \\ + &\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/ + \big(\sqrt{\widehat{v_t}^{max}} + \epsilon \big) \\ + &\hspace{5mm}\textbf{else} \\ + &\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/ + \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\ + &\rule{110mm}{0.4pt} \\[-1.ex] + &\bf{return} \: \theta_t \\[-1.ex] + &\rule{110mm}{0.4pt} \\[-1.ex] + \end{aligned} + For further details regarding the algorithm we refer to `Decoupled Weight Decay Regularization`_. + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + (default: False) + maximize (bool, optional): maximize the params based on the objective, instead of + minimizing (default: False) + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. 
_On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, + weight_decay=1e-2, amsgrad=False, *, maximize: bool = False): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0.0 <= weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + defaults = dict(lr=lr, betas=betas, eps=eps, + weight_decay=weight_decay, amsgrad=amsgrad, maximize=maximize) + super(AdamW, self).__init__(params, defaults) + + def __setstate__(self, state): + super(AdamW, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('amsgrad', False) + group.setdefault('maximize', False) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + state_sums = [] + max_exp_avg_sqs = [] + state_steps = [] + amsgrad = group['amsgrad'] + beta1, beta2 = group['betas'] + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError('AdamW does not support sparse gradients') + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. 
values + state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + + if amsgrad: + max_exp_avg_sqs.append(state['max_exp_avg_sq']) + + # update the steps for each param group update + state['step'] += 1 + # record the step after step update + state_steps.append(state['step']) + + # adamw_torch(params_with_grad, + adamw(params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps'], + maximize=group['maximize']) + + return loss + + +def _unscale_main_grads_and_check_for_nan(self): + main_grads = self._collect_main_grad_data_for_unscaling() + self.found_inf.fill_(0.0) + torch._amp_foreach_non_finite_check_and_unscale_(main_grads, self.found_inf, self.grad_scaler.inv_scale) + torch.distributed.all_reduce(self.found_inf, op=torch.distributed.ReduceOp.MAX, group=self.get_model_parallel_group()) + torch.distributed.all_reduce(self.found_inf, op=torch.distributed.ReduceOp.MAX, group=mpu.get_data_parallel_group()) + found_inf_flag = (self.found_inf.item() > 0) + return found_inf_flag + + +# megatron.optimizer.Adam = torch.optim.AdamW +megatron.optimizer.Adam = AdamW + +megatron.optimizer.optimizer.MixedPrecisionOptimizer._unscale_main_grads_and_check_for_nan = _unscale_main_grads_and_check_for_nan diff --git a/megatron_npu/megatron_npu/adaptor_p2p_communication.py b/megatron_npu/megatron_npu/adaptor_p2p_communication.py new file mode 100644 index 0000000000000000000000000000000000000000..9811066571622273bf7a31e20574c50fd16238fa --- /dev/null +++ b/megatron_npu/megatron_npu/adaptor_p2p_communication.py @@ -0,0 +1,141 @@ +import torch +import operator +import megatron +from functools import reduce +from megatron import get_args +from megatron import mpu +from megatron.p2p_communication import _communicate_shapes + + +def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next, tensor_shape, dtype_=None): + args = get_args() + + # Create placeholder tensors for receive in forward and backward directions + # if needed. + tensor_recv_prev = None + tensor_recv_next = None + + # Some legacy inference code doesn't set the tensor shape, do so now + # for the normal values for gpt/bert. This could be removed if inference + # code is changed to provide tensor_shape. 
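+    # When variable sequence lengths are disabled, fall back to the static
+    # (seq_length, micro_batch_size, hidden_size) receive shape; otherwise
+    # exchange the actual shapes first via _communicate_shapes().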
+ if not args.variable_seq_lengths: + if tensor_shape is None: + recv_prev_shape = (args.seq_length, args.micro_batch_size, args.hidden_size) + recv_next_shape = (args.seq_length, args.micro_batch_size, args.hidden_size) + else: + recv_prev_shape = tensor_shape + recv_next_shape = tensor_shape + else: + recv_prev_shape, recv_next_shape = \ + _communicate_shapes(tensor_send_next, + tensor_send_prev, + recv_prev, + recv_next) + + override_scatter_gather_tensors_in_pipeline = False + if args.scatter_gather_tensors_in_pipeline and \ + not args.sequence_parallel: + recv_prev_chunk_shape = reduce(operator.mul, recv_prev_shape, 1) + recv_next_chunk_shape = reduce(operator.mul, recv_next_shape, 1) + if recv_prev_chunk_shape % mpu.get_tensor_model_parallel_world_size() == 0 and \ + recv_next_chunk_shape % mpu.get_tensor_model_parallel_world_size() == 0: + recv_prev_chunk_shape = recv_prev_chunk_shape // \ + mpu.get_tensor_model_parallel_world_size() + recv_next_chunk_shape = recv_next_chunk_shape // \ + mpu.get_tensor_model_parallel_world_size() + else: + recv_prev_chunk_shape = recv_prev_shape + recv_next_chunk_shape = recv_next_shape + override_scatter_gather_tensors_in_pipeline = True + else: + recv_prev_chunk_shape = recv_prev_shape + recv_next_chunk_shape = recv_next_shape + + dtype = args.params_dtype + if args.fp32_residual_connection: + dtype = torch.float + + requires_grad = True + if dtype_ is not None: + dtype = dtype_ + requires_grad = False + + if recv_prev: + tensor_recv_prev = torch.empty(recv_prev_chunk_shape, + requires_grad=requires_grad, + device=torch.cuda.current_device(), + dtype=dtype) + if recv_next: + tensor_recv_next = torch.empty(recv_next_chunk_shape, + requires_grad=requires_grad, + device=torch.cuda.current_device(), + dtype=dtype) + + # Split tensor into smaller chunks if using scatter-gather optimization. + if not override_scatter_gather_tensors_in_pipeline and \ + args.scatter_gather_tensors_in_pipeline and \ + not args.sequence_parallel: + if tensor_send_next is not None: + tensor_send_next = mpu.split_tensor_into_1d_equal_chunks(tensor_send_next) + + if tensor_send_prev is not None: + tensor_send_prev = mpu.split_tensor_into_1d_equal_chunks(tensor_send_prev) + + # Send tensors in both the forward and backward directions as appropriate. + if args.use_ring_exchange_p2p: + torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev, + tensor_recv_prev=tensor_recv_prev, + tensor_send_next=tensor_send_next, + tensor_recv_next=tensor_recv_next, + group=mpu.get_pipeline_model_parallel_group()) + else: + ops = [] + if tensor_send_prev is not None: + send_prev_op = torch.distributed.P2POp( + torch.distributed.isend, tensor_send_prev, + mpu.get_pipeline_model_parallel_prev_rank()) + ops.append(send_prev_op) + if tensor_recv_prev is not None: + recv_prev_op = torch.distributed.P2POp( + torch.distributed.irecv, tensor_recv_prev, + mpu.get_pipeline_model_parallel_prev_rank()) + ops.append(recv_prev_op) + if tensor_recv_next is not None: + recv_next_op = torch.distributed.P2POp( + torch.distributed.irecv, tensor_recv_next, + mpu.get_pipeline_model_parallel_next_rank()) + ops.append(recv_next_op) + if tensor_send_next is not None: + send_next_op = torch.distributed.P2POp( + torch.distributed.isend, tensor_send_next, + mpu.get_pipeline_model_parallel_next_rank()) + ops.append(send_next_op) + if len(ops) > 0: + reqs = torch.distributed.batch_isend_irecv(ops) + for req in reqs: + req.wait() + # To protect against race condition when using batch_isend_irecv(). 
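For reference, a sketch (separate from the patch itself) of the batched point-to-point pattern used above. It assumes torch.distributed is already initialised and that next_rank/prev_rank are the neighbouring pipeline ranks; as the surrounding comment notes, upstream Megatron follows the batched waits with a device synchronize to guard against a race with batch_isend_irecv(), and the adaptor keeps that.

    import torch
    import torch.distributed as dist

    def exchange_with_neighbours(send_next, recv_prev_buf, next_rank, prev_rank):
        """Queue sends/recvs as one batch, then wait on every request."""
        ops = []
        if send_next is not None:
            ops.append(dist.P2POp(dist.isend, send_next, next_rank))
        if recv_prev_buf is not None:
            ops.append(dist.P2POp(dist.irecv, recv_prev_buf, prev_rank))
        if ops:
            for req in dist.batch_isend_irecv(ops):
                req.wait()
        # Mirrors the code above: synchronize the device after the waits to
        # rule out the batch_isend_irecv() race mentioned in the comment.
        torch.cuda.synchronize()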
+ torch.cuda.synchronize() + + # If using scatter-gather optimization, gather smaller chunks. + if not override_scatter_gather_tensors_in_pipeline and \ + args.scatter_gather_tensors_in_pipeline and \ + not args.sequence_parallel: + if recv_prev: + tensor_recv_prev = mpu.gather_split_1d_tensor( + tensor_recv_prev).view(recv_prev_shape).requires_grad_() + tensor_recv_prev = mpu.utils.make_viewless_tensor(tensor_recv_prev, + requires_grad=True, + keep_graph=False) + + if recv_next: + tensor_recv_next = mpu.gather_split_1d_tensor( + tensor_recv_next).view(recv_next_shape).requires_grad_() + tensor_recv_next = mpu.utils.make_viewless_tensor(tensor_recv_next, + requires_grad=True, + keep_graph=False) + + return tensor_recv_prev, tensor_recv_next + + +megatron.p2p_communication._communicate = _communicate diff --git a/megatron_npu/megatron_npu/adaptor_schedules.py b/megatron_npu/megatron_npu/adaptor_schedules.py new file mode 100644 index 0000000000000000000000000000000000000000..2736546b3a4f585ca3cd225936a351cf72661b71 --- /dev/null +++ b/megatron_npu/megatron_npu/adaptor_schedules.py @@ -0,0 +1,132 @@ +import torch +import megatron +from megatron import get_args +from megatron import mpu +from megatron.model import ModelType +from megatron.schedules import custom_backward, dummy_handler, forward_step, get_num_microbatches + + +def clear_npu_overflow_flag(): + # TODO: async operation and 1980B compatible adaptation + float_status = torch.zeros(8).npu() + result = torch.npu_clear_float_status(float_status) + + +def get_npu_overflow_flag(): + # TODO: async operation and 1980B compatible adaptation + float_status = torch.zeros(8).npu() + result = torch.npu_get_float_status(float_status) + if float_status.cpu()[0] != 0: + return True + else: + return False + + +def set_npu_overflow_flag(): + # TODO: async operation and 1980B compatible adaptation + torch.tensor([65504]).half().npu() + 100 + + +def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad, timers): + args = get_args() + if timers is not None: + timers('backward-compute', log_level=2).start() + + # Retain the grad on the input_tensor. + unwrap_input_tensor_grad = False + if not isinstance(input_tensor, list): + input_tensor = [input_tensor] + unwrap_input_tensor_grad = True + for x in input_tensor: + if x is not None: + x.retain_grad() + + if not isinstance(output_tensor, list): + output_tensor = [output_tensor] + if not isinstance(output_tensor_grad, list): + output_tensor_grad = [output_tensor_grad] + + # Backward pass. + clear_npu_overflow_flag() + if output_tensor_grad[0] is None: + output_tensor = optimizer.scale_loss(output_tensor[0]) + custom_backward(output_tensor[0], output_tensor_grad[0]) + + # Collect the grad of the input_tensor. + input_tensor_grad = [None] + if input_tensor is not None: + input_tensor_grad = [] + for x in input_tensor: + if x is None: + input_tensor_grad.append(None) + else: + input_tensor_grad.append(x.grad) + + # Handle single skip connection if it exists (encoder_hidden_state in + # model with encoder and decoder). 
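The float-status helpers defined at the top of this file are what backward_step and forward_backward_no_pipelining use to detect fp16 overflow on NPU: clear the flag, run backward, read the flag, and OR the results across microbatches. A CPU stand-in sketch (separate from the patch itself) of that flow, using torch.isfinite purely as a substitute for torch.npu_get_float_status:

    import torch

    def backward_overflowed(loss, params):
        """Run backward and report whether any gradient became non-finite."""
        loss.backward()
        return any(p.grad is not None and not torch.isfinite(p.grad).all()
                   for p in params)

    w = torch.tensor([1.0], requires_grad=True)
    overflow_any = False
    for scale in (1.0, float("inf")):   # the second "microbatch" overflows
        w.grad = None                   # stands in for clear_npu_overflow_flag()
        overflow_any = backward_overflowed((w * scale).sum(), [w]) or overflow_any
    print(overflow_any)                 # True -> the optimizer step would be skipped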
+ if mpu.get_pipeline_model_parallel_world_size() > 1 and \ + mpu.is_pipeline_stage_after_split() and \ + args.model_type == ModelType.encoder_and_decoder: + if output_tensor_grad[1] is not None: + input_tensor_grad[-1].add_(output_tensor_grad[1]) + if unwrap_input_tensor_grad: + input_tensor_grad = input_tensor_grad[0] + + if timers is not None: + timers('backward-compute').stop() + + return input_tensor_grad + + +def forward_backward_no_pipelining(forward_step_func, data_iterator, model, optimizer, timers, forward_only, + collect_non_loss_data=False): + assert len(model) == 1 + model = model[0] + + context_handler = dummy_handler + if isinstance(model, torch.nn.parallel.distributed.DistributedDataParallel): + context_handler = model.no_sync + + forward_data_store = [] + input_tensor, output_tensor_grad = None, None + overflow_flag_all = False + with context_handler(): + for i in range(get_num_microbatches() - 1): + output_tensor = forward_step(forward_step_func, data_iterator, + model, input_tensor, forward_data_store, + timers, collect_non_loss_data) + if not forward_only: + backward_step(optimizer, input_tensor, output_tensor, + output_tensor_grad, timers) + + overflow_flag = get_npu_overflow_flag() + overflow_flag_all = overflow_flag or overflow_flag_all + + # Run computation for last microbatch out of context handler (want to + # synchronize gradients). + output_tensor = forward_step(forward_step_func, data_iterator, + model, input_tensor, forward_data_store, + timers, collect_non_loss_data) + if not forward_only: + backward_step(optimizer, input_tensor, output_tensor, + output_tensor_grad, timers) + + overflow_flag = get_npu_overflow_flag() + overflow_flag_all = overflow_flag or overflow_flag_all + if overflow_flag_all: + set_npu_overflow_flag() + return forward_data_store + + +def deallocate_output_tensor(out): + if out is None: + return + assert isinstance(out, torch.Tensor), "expected Tensor, found %s." % type(out).__name__ + assert out._base is None, "counter-productive to free a view of another tensor." 
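deallocate_output_tensor frees the bulk of an already-sent activation while keeping the Python tensor object alive for the schedule's bookkeeping. A small sketch (separate from the patch itself) of what the set_() call below does to the underlying storage:

    import torch

    out = torch.randn(512, 16, 768)                 # a pretend pipeline-stage output
    print(out.nelement() * out.element_size())      # ~25 MB of activation storage
    with torch.no_grad():
        out.set_(torch.empty((1,), device=out.device, dtype=out.dtype))
    print(out.shape, out.nelement() * out.element_size())  # torch.Size([1]) 4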
+    with torch.no_grad():
+        out.set_(torch.empty((1,), device=out.device, dtype=out.dtype))
+
+
+megatron.schedules.backward_step = backward_step
+megatron.schedules.forward_backward_no_pipelining = forward_backward_no_pipelining
+megatron.schedules.deallocate_output_tensor = deallocate_output_tensor
diff --git a/megatron_npu/requirements.txt b/megatron_npu/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..a2c483ab2ba8a77f2254a459fb6c44625dfedf03
--- /dev/null
+++ b/megatron_npu/requirements.txt
@@ -0,0 +1,8 @@
+regex
+pybind11
+six
+torchvision==0.9.1
+decorator
+sympy
+scipy
+attrs
\ No newline at end of file
diff --git a/megatron_npu/setup.py b/megatron_npu/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8c6804059da75bc7b7b0df1450225e05552a1df
--- /dev/null
+++ b/megatron_npu/setup.py
@@ -0,0 +1,21 @@
+import setuptools
+
+setuptools.setup(
+    name="megatron_npu",
+    version="0.1",
+    description="An adaptor for megatron on Ascend NPU",
+    packages=['megatron_npu'],
+    install_package_data=True,
+    include_package_data=True,
+    license='Apache2',
+    license_file='./LICENSE',
+    classifiers=[
+        "Operating System :: OS Independent",
+        "Programming Language :: Python :: 3",
+        "Programming Language :: Python :: 3.6",
+        "Programming Language :: Python :: 3.7",
+        "Programming Language :: Python :: 3.8",
+        "Programming Language :: Python :: 3.9",
+    ],
+    python_requires=">=3.7",
+)
diff --git a/megatron_npu/tests/dataset_preprocess_t5.sh b/megatron_npu/tests/dataset_preprocess_t5.sh
new file mode 100644
index 0000000000000000000000000000000000000000..fc6f736f41dbc67671a00f283aac7da1c3eb1552
--- /dev/null
+++ b/megatron_npu/tests/dataset_preprocess_t5.sh
@@ -0,0 +1,42 @@
+# Step 1: Download enwiki
+wget https://dumps.wikimedia.org/enwiki/latest/enwiki-latest-pages-articles.xml.bz2
+bzip2 -dk enwiki-latest-pages-articles.xml.bz2
+
+
+# Step 2: Download WikiExtractor
+pip3 install wikiextractor
+python3 -m wikiextractor.WikiExtractor enwiki-latest-pages-articles.xml --json
+
+
+# Step 3: Concat json
+WIKI_DIR=./text
+OUTDIR=./output
+
+mkdir -p $OUTDIR
+rm -f $OUTDIR/wiki_all.json
+touch $OUTDIR/wiki_all.json
+
+find "$WIKI_DIR" -type f -print0 |
+    while IFS= read -r -d '' line; do
+        filename=$(echo "$line" | rev | cut -d'/' -f 1 | rev)
+        subfilename=$(echo "$line" | rev | cut -d'/' -f 2 | rev)
+        prefix="${subfilename}_${filename}"
+        new_name=$(echo "$line")
+        echo "Processing $prefix, $filename, $new_name"
+        cat $new_name >> $OUTDIR/wiki_all.json
+    done
+
+
+# Step 4: Download Vocab and Do Preprocess
+wget https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt
+VOCAB=./bert-large-uncased-vocab.txt
+python3 tools/preprocess_data.py \
+    --input $OUTDIR/wiki_all.json \
+    --output-prefix $OUTDIR/my-t5 \
+    --vocab $VOCAB \
+    --dataset-impl mmap \
+    --tokenizer-type BertWordPieceLowerCase \
+    --split-sentences \
+    --workers $(nproc)
+
+
diff --git a/megatron_npu/tests/env_npu.sh b/megatron_npu/tests/env_npu.sh
new file mode 100644
index 0000000000000000000000000000000000000000..74fd11c2b5de7789f1abf0beafb8df3e6bd9eec3
--- /dev/null
+++ b/megatron_npu/tests/env_npu.sh
@@ -0,0 +1,83 @@
+#!/bin/bash
+export install_path=/usr/local/Ascend
+
+if [ -d ${install_path}/toolkit ]; then
+    export LD_LIBRARY_PATH=${install_path}/fwkacllib/lib64/:/usr/include/hdf5/lib/:/usr/local/:/usr/local/lib/:/usr/lib/:${install_path}/driver/lib64/common/:${install_path}/driver/lib64/driver/:${install_path}/add-ons:${path_lib}:${LD_LIBRARY_PATH}
+    export
PATH=${install_path}/fwkacllib/ccec_compiler/bin:${install_path}/fwkacllib/bin:$PATH + export PYTHONPATH=${install_path}/fwkacllib/python/site-packages:${install_path}/tfplugin/python/site-packages:${install_path}/toolkit/python/site-packages:$PYTHONPATH + export PYTHONPATH=/usr/local/python3.7.5/lib/python3.7/site-packages:$PYTHONPATH + export ASCEND_OPP_PATH=${install_path}/opp +else + if [ -d ${install_path}/nnae/latest ];then + export LD_LIBRARY_PATH=${install_path}/nnae/latest/fwkacllib/lib64/:/usr/local/:/usr/local/python3.7.5/lib/:/usr/local/openblas/lib:/usr/local/lib/:/usr/lib64/:/usr/lib/:${install_path}/driver/lib64/common/:${install_path}/driver/lib64/driver/:${install_path}/add-ons/:/usr/lib/aarch64_64-linux-gnu:$LD_LIBRARY_PATH + export PATH=$PATH:${install_path}/nnae/latest/fwkacllib/ccec_compiler/bin/:${install_path}/nnae/latest/toolkit/tools/ide_daemon/bin/ + export ASCEND_OPP_PATH=${install_path}/nnae/latest/opp/ + export OPTION_EXEC_EXTERN_PLUGIN_PATH=${install_path}/nnae/latest/fwkacllib/lib64/plugin/opskernel/libfe.so:${install_path}/nnae/latest/fwkacllib/lib64/plugin/opskernel/libaicpu_engine.so:${install_path}/nnae/latest/fwkacllib/lib64/plugin/opskernel/libge_local_engine.so + export PYTHONPATH=${install_path}/nnae/latest/fwkacllib/python/site-packages/:${install_path}/nnae/latest/fwkacllib/python/site-packages/auto_tune.egg/auto_tune:${install_path}/nnae/latest/fwkacllib/python/site-packages/schedule_search.egg:$PYTHONPATH + export ASCEND_AICPU_PATH=${install_path}/nnae/latest + else + export LD_LIBRARY_PATH=${install_path}/ascend-toolkit/latest/fwkacllib/lib64/:/usr/local/:/usr/local/lib/:/usr/lib64/:/usr/lib/:/usr/local/python3.7.5/lib/:/usr/local/openblas/lib:${install_path}/driver/lib64/common/:${install_path}/driver/lib64/driver/:${install_path}/add-ons/:/usr/lib/aarch64-linux-gnu:$LD_LIBRARY_PATH + export PATH=$PATH:${install_path}/ascend-toolkit/latest/fwkacllib/ccec_compiler/bin/:${install_path}/ascend-toolkit/latest/toolkit/tools/ide_daemon/bin/ + export ASCEND_OPP_PATH=${install_path}/ascend-toolkit/latest/opp/ + export OPTION_EXEC_EXTERN_PLUGIN_PATH=${install_path}/ascend-toolkit/latest/fwkacllib/lib64/plugin/opskernel/libfe.so:${install_path}/ascend-toolkit/latest/fwkacllib/lib64/plugin/opskernel/libaicpu_engine.so:${install_path}/ascend-toolkit/latest/fwkacllib/lib64/plugin/opskernel/libge_local_engine.so + export PYTHONPATH=${install_path}/ascend-toolkit/latest/fwkacllib/python/site-packages/:${install_path}/ascend-toolkit/latest/fwkacllib/python/site-packages/auto_tune.egg/auto_tune:${install_path}/ascend-toolkit/latest/fwkacllib/python/site-packages/schedule_search.egg:$PYTHONPATH + export ASCEND_AICPU_PATH=${install_path}/ascend-toolkit/latest + fi +fi + +${install_path}/driver/tools/msnpureport -g error -d 0 +${install_path}/driver/tools/msnpureport -g error -d 1 +${install_path}/driver/tools/msnpureport -g error -d 2 +${install_path}/driver/tools/msnpureport -g error -d 3 +${install_path}/driver/tools/msnpureport -g error -d 4 +${install_path}/driver/tools/msnpureport -g error -d 5 +${install_path}/driver/tools/msnpureport -g error -d 6 +${install_path}/driver/tools/msnpureport -g error -d 7 + +#将Host日志输出到串口,0-关闭/1-开启 +export ASCEND_SLOG_PRINT_TO_STDOUT=0 +#设置默认日志级别,0-debug/1-info/2-warning/3-error +export ASCEND_GLOBAL_LOG_LEVEL=3 +#设置Event日志开启标志,0-关闭/1-开启 +export ASCEND_GLOBAL_EVENT_ENABLE=0 +#设置是否开启taskque,0-关闭/1-开启 +export TASK_QUEUE_ENABLE=1 +#设置是否开启PTCopy,0-关闭/1-开启 +export PTCOPY_ENABLE=1 +#设置是否开启2个非连续combined标志,0-关闭/1-开启 +export 
COMBINED_ENABLE=1 +#设置是否开启3个非连续combined标志,0-关闭/1-开启 +export TRI_COMBINED_ENABLE=1 +#设置特殊场景是否需要重新编译,不需要修改 +export DYNAMIC_OP="ADD#MUL" +# HCCL白名单开关,1-关闭/0-开启 +export HCCL_WHITELIST_DISABLE=1 +# HCCL默认超时时间120s较少,修改为1800s对齐PyTorch默认设置 +export HCCL_CONNECT_TIMEOUT=1800 +# 关闭部分算子的CACHE功能以取得正常的精度 +export DISABLE_CACHE_OP="ReduceMaxD,ReduceSumD" + + +ulimit -SHn 512000 + +path_lib=$(python3.7 -c """ +import sys +import re +result='' +for index in range(len(sys.path)): + match_sit = re.search('-packages', sys.path[index]) + if match_sit is not None: + match_lib = re.search('lib', sys.path[index]) + + if match_lib is not None: + end=match_lib.span()[1] + result += sys.path[index][0:end] + ':' + + result+=sys.path[index] + '/torch/lib:' +print(result)""" +) + +echo ${path_lib} + +export LD_LIBRARY_PATH=/usr/local/python3.7.5/lib/:${path_lib}:$LD_LIBRARY_PATH +export PYTHONPATH=$PYTHONPATH:./Megatron-LM diff --git a/megatron_npu/tests/pretrain_t5.sh b/megatron_npu/tests/pretrain_t5.sh new file mode 100644 index 0000000000000000000000000000000000000000..67355c72ef728590751abfd6f88de5034db1acde --- /dev/null +++ b/megatron_npu/tests/pretrain_t5.sh @@ -0,0 +1,42 @@ +#!/bin/bash +source ./tests/env_npu.sh +export MASTER_ADDR=localhost +export MASTER_PORT=6000 + +RANK=0 +WORLD_SIZE=1 +DATA_PATH= +VOCAB_FILE= +CHECKPOINT_PATH= + +python3 runner.py \ + --num-layers 12 \ + --hidden-size 768 \ + --num-attention-heads 12 \ + --kv-channels 64 \ + --ffn-hidden-size 3072 \ + --encoder-seq-length 512 \ + --decoder-seq-length 128 \ + --micro-batch-size 16 \ + --global-batch-size 16 \ + --max-position-embeddings 512 \ + --train-iters 1000000 \ + --lr-decay-iters 1000000 \ + --save $CHECKPOINT_PATH \ + --load $CHECKPOINT_PATH \ + --data-path $DATA_PATH \ + --vocab-file $VOCAB_FILE \ + --data-impl mmap \ + --split 949,50,1 \ + --lr 0.0001 \ + --min-lr 0.00001 \ + --lr-decay-style linear \ + --lr-warmup-fraction .01 \ + --weight-decay 1e-2 \ + --clip-grad 1.0 \ + --log-interval 100 \ + --save-interval 10000 \ + --eval-interval 1000 \ + --eval-iters 10 \ + --fp16 \ + --vocab-extra-ids 100 diff --git a/megatron_npu/tests/pretrain_t5_distributed_with_mp.sh b/megatron_npu/tests/pretrain_t5_distributed_with_mp.sh new file mode 100644 index 0000000000000000000000000000000000000000..9b7638097a46ae788ef15d4da056dcd6e91c54d7 --- /dev/null +++ b/megatron_npu/tests/pretrain_t5_distributed_with_mp.sh @@ -0,0 +1,50 @@ +#!/bin/bash +source ./tests/env_npu.sh + +GPUS_PER_NODE=8 +# Change for multinode config +export MASTER_ADDR=localhost +export MASTER_PORT=6000 +NNODES=1 +NODE_RANK=0 +WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) + +DATA_PATH= +VOCAB_FILE= +CHECKPOINT_PATH= + +DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT" + +python3 -m torch.distributed.launch $DISTRIBUTED_ARGS \ + launcher.py \ + --tensor-model-parallel-size 2 \ + --num-layers 12 \ + --hidden-size 768 \ + --num-attention-heads 12 \ + --kv-channels 64 \ + --ffn-hidden-size 3072 \ + --encoder-seq-length 512 \ + --decoder-seq-length 128 \ + --micro-batch-size 16 \ + --global-batch-size 128 \ + --max-position-embeddings 512 \ + --train-iters 1000000 \ + --lr-decay-iters 1000000 \ + --save $CHECKPOINT_PATH \ + --load $CHECKPOINT_PATH \ + --data-path $DATA_PATH \ + --vocab-file $VOCAB_FILE \ + --data-impl mmap \ + --split 949,50,1 \ + --lr 0.0001 \ + --min-lr 0.00001 \ + --lr-decay-style linear \ + --lr-warmup-fraction .01 \ + --weight-decay 1e-2 \ + --clip-grad 1.0 \ + 
--log-interval 100 \
+    --save-interval 10000 \
+    --eval-interval 1000 \
+    --eval-iters 10 \
+    --fp16 \
+    --vocab-extra-ids 100
diff --git a/megatron_npu/tests/test.sh b/megatron_npu/tests/test.sh
new file mode 100644
index 0000000000000000000000000000000000000000..90695b92dfa481d81561baa47df285caf24bd5f4
--- /dev/null
+++ b/megatron_npu/tests/test.sh
@@ -0,0 +1,11 @@
+source tests/env_npu.sh
+
+
+python3 -m torch.distributed.launch --nproc_per_node=8 ./tests/ut/test_cross_entropy.py
+python3 -m torch.distributed.launch --nproc_per_node=8 ./tests/ut/test_data.py
+python3 -m torch.distributed.launch --nproc_per_node=8 ./tests/ut/test_initialize.py
+python3 -m torch.distributed.launch --nproc_per_node=8 ./tests/ut/test_layers.py
+python3 -m torch.distributed.launch --nproc_per_node=8 ./tests/ut/test_random.py
+
+cd -
+
diff --git a/megatron_npu/tests/train_full_32p.sh b/megatron_npu/tests/train_full_32p.sh
new file mode 100644
index 0000000000000000000000000000000000000000..6bbe245e7da2ad5271d6e70f338eabe79939a45f
--- /dev/null
+++ b/megatron_npu/tests/train_full_32p.sh
@@ -0,0 +1,58 @@
+#!/bin/bash
+source ./tests/env_npu.sh
+
+# NPU
+export HCCL_IF_IP=xxxx # IP address of the current machine; must be set manually
+export HCCL_CONNECT_TIMEOUT=3600
+export HCCL_EXEC_TIMEOUT=3600
+
+GPUS_PER_NODE=8
+# Change for multinode config
+export MASTER_ADDR=xxxx # address of the master node; must be set manually
+export MASTER_PORT=23333
+NNODES=4
+NODE_RANK=0 # set to 0, 1, 2, 3 on the respective nodes; must be set manually
+WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))
+
+DATASET_ROOT_PATH=./dataset/en_wiki/preprocess
+DATA_PATH=$DATASET_ROOT_PATH/my-t5_text_sentence
+VOCAB_FILE=$DATASET_ROOT_PATH/bert-large-uncased-vocab.txt
+CHECKPOINT_PATH=./checkpoint
+
+DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"
+
+python3 -m torch.distributed.launch $DISTRIBUTED_ARGS \
+    runner.py \
+    --tensor-model-parallel-size 2 \
+    --num-layers 12 \
+    --hidden-size 768 \
+    --num-attention-heads 12 \
+    --kv-channels 64 \
+    --ffn-hidden-size 3072 \
+    --encoder-seq-length 512 \
+    --decoder-seq-length 128 \
+    --micro-batch-size 16 \
+    --global-batch-size 512 \
+    --max-position-embeddings 512 \
+    --train-iters 1000000 \
+    --lr-decay-iters 1000000 \
+    --save $CHECKPOINT_PATH \
+    --data-path $DATA_PATH \
+    --vocab-file $VOCAB_FILE \
+    --data-impl mmap \
+    --split 949,50,1 \
+    --lr 0.0001 \
+    --min-lr 0.00001 \
+    --lr-decay-style linear \
+    --lr-warmup-fraction .01 \
+    --weight-decay 1e-2 \
+    --clip-grad 1.0 \
+    --log-interval 1 \
+    --save-interval 10000 \
+    --eval-interval 1000 \
+    --eval-iters 10 \
+    --fp16 \
+    --vocab-extra-ids 100 \
+    --use-cpu-initialization \
+    --no-bias-gelu-fusion
+
diff --git a/megatron_npu/tests/train_full_8p.sh b/megatron_npu/tests/train_full_8p.sh
new file mode 100644
index 0000000000000000000000000000000000000000..6cf822634f0c8c7904fba832f7dc245626942d5e
--- /dev/null
+++ b/megatron_npu/tests/train_full_8p.sh
@@ -0,0 +1,77 @@
+#!/bin/bash
+source ./tests/env_npu.sh
+
+GPUS_PER_NODE=8
+# Change for multinode config
+export MASTER_ADDR=localhost
+export MASTER_PORT=6000
+NNODES=1
+NODE_RANK=0
+WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))
+
+DATASET_ROOT_PATH=./dataset/en_wiki/preprocess
+DATA_PATH=$DATASET_ROOT_PATH/my-t5_text_sentence
+VOCAB_FILE=$DATASET_ROOT_PATH/bert-large-uncased-vocab.txt
+CHECKPOINT_PATH=./checkpoint
+
+GlobalBatchSize=128
+ExitInterval=100001
+
+DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"
+
+nohup \
+python3 -m
torch.distributed.launch $DISTRIBUTED_ARGS \ + runner.py \ + --tensor-model-parallel-size 2 \ + --num-layers 12 \ + --hidden-size 768 \ + --num-attention-heads 12 \ + --kv-channels 64 \ + --ffn-hidden-size 3072 \ + --encoder-seq-length 512 \ + --decoder-seq-length 128 \ + --micro-batch-size 16 \ + --global-batch-size ${GlobalBatchSize} \ + --max-position-embeddings 512 \ + --train-iters 1000000 \ + --lr-decay-iters 1000000 \ + --save $CHECKPOINT_PATH \ + --data-path $DATA_PATH \ + --vocab-file $VOCAB_FILE \ + --data-impl mmap \ + --split 949,50,1 \ + --lr 0.0001 \ + --min-lr 0.00001 \ + --lr-decay-style linear \ + --lr-warmup-fraction .01 \ + --weight-decay 1e-2 \ + --clip-grad 1.0 \ + --log-interval 1 \ + --save-interval 10000 \ + --eval-interval 1000 \ + --eval-iters 10 \ + --fp16 \ + --vocab-extra-ids 100 \ + --use-cpu-initialization \ + --no-bias-gelu-fusion \ + --exit-interval ${ExitInterval} \ +> ./training.log & + +wait + +#结果打印,不需要修改 +echo "------------------ Final result ------------------" +#输出性能FPS,需要模型审视修改 +IterTime=$(grep "elapsed time per iteration" ./training.log | tail -n 1000 | awk '{print $14}' | awk '{sum+=$1} END {print sum/NR}') +FPS=$(echo "${GlobalBatchSize} * 1000 / ${IterTime}"|bc) + +#打印,不需要修改 +echo "Final Performance samples/sec : $FPS" + +#输出训练精度,需要模型审视修改 +train_accuracy=$(grep "PPL" ./training.log | awk 'END {print $15}') + +#打印,不需要修改 +echo "Final Train Accuracy : ${train_accuracy}" + + diff --git a/megatron_npu/tests/train_performance_8p.sh b/megatron_npu/tests/train_performance_8p.sh new file mode 100644 index 0000000000000000000000000000000000000000..f4a21c0ffcb47951c8a42a5a3cdb1aceddf30b93 --- /dev/null +++ b/megatron_npu/tests/train_performance_8p.sh @@ -0,0 +1,77 @@ +#!/bin/bash +source ./tests/env_npu.sh + +GPUS_PER_NODE=8 +# Change for multinode config +export MASTER_ADDR=localhost +export MASTER_PORT=6000 +NNODES=1 +NODE_RANK=0 +WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) + +DATASET_ROOT_PATH=./dataset/en_wiki/preprocess +DATA_PATH=$DATASET_ROOT_PATH/my-t5_text_sentence +VOCAB_FILE=$DATASET_ROOT_PATH/bert-large-uncased-vocab.txt +CHECKPOINT_PATH=./checkpoint + +GlobalBatchSize=128 +ExitInterval=1001 + +DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT" + +nohup \ +python3 -m torch.distributed.launch $DISTRIBUTED_ARGS \ + runner.py \ + --tensor-model-parallel-size 2 \ + --num-layers 12 \ + --hidden-size 768 \ + --num-attention-heads 12 \ + --kv-channels 64 \ + --ffn-hidden-size 3072 \ + --encoder-seq-length 512 \ + --decoder-seq-length 128 \ + --micro-batch-size 16 \ + --global-batch-size ${GlobalBatchSize} \ + --max-position-embeddings 512 \ + --train-iters 1000000 \ + --lr-decay-iters 1000000 \ + --save $CHECKPOINT_PATH \ + --data-path $DATA_PATH \ + --vocab-file $VOCAB_FILE \ + --data-impl mmap \ + --split 949,50,1 \ + --lr 0.0001 \ + --min-lr 0.00001 \ + --lr-decay-style linear \ + --lr-warmup-fraction .01 \ + --weight-decay 1e-2 \ + --clip-grad 1.0 \ + --log-interval 1 \ + --save-interval 10000 \ + --eval-interval 1000 \ + --eval-iters 10 \ + --fp16 \ + --vocab-extra-ids 100 \ + --use-cpu-initialization \ + --no-bias-gelu-fusion \ + --exit-interval ${ExitInterval} \ +> ./training.log & + +wait + +#结果打印,不需要修改 +echo "------------------ Final result ------------------" +#输出性能FPS,需要模型审视修改 +IterTime=$(grep "elapsed time per iteration" ./training.log | tail -n 1000 | awk '{print $14}' | awk '{sum+=$1} END {print sum/NR}') +FPS=$(echo "${GlobalBatchSize} * 1000 / 
${IterTime}"|bc) + +#打印,不需要修改 +echo "Final Performance samples/sec : $FPS" + +#输出训练精度,需要模型审视修改 +train_accuracy=$(grep "PPL" ./training.log | awk 'END {print $15}') + +#打印,不需要修改 +echo "Final Train Accuracy : ${train_accuracy}" + + diff --git a/megatron_npu/tests/ut/__init__.py b/megatron_npu/tests/ut/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/megatron_npu/tests/ut/commons.py b/megatron_npu/tests/ut/commons.py new file mode 100644 index 0000000000000000000000000000000000000000..fa273337147b8e2d46740d0e10e1904ba0406b3d --- /dev/null +++ b/megatron_npu/tests/ut/commons.py @@ -0,0 +1,83 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import random +import numpy +import torch + +from megatron import mpu + + +class IdentityLayer(torch.nn.Module): + def __init__(self, size, scale=1.0): + super(IdentityLayer, self).__init__() + self.weight = torch.nn.Parameter(scale * torch.randn(size)) + + def forward(self): + return self.weight + + +def set_random_seed(seed): + """Set random seed for reproducability.""" + random.seed(seed) + numpy.random.seed(seed) + torch.manual_seed(seed) + mpu.model_parallel_cuda_manual_seed(seed) + + +def initialize_distributed(backend='hccl'): + """Initialize torch.distributed.""" + # Get local rank in case it is provided. + parser = argparse.ArgumentParser() + parser.add_argument('--local_rank', type=int, default=None, + help='local rank passed from distributed launcher') + args = parser.parse_args() + local_rank = args.local_rank + + # Get rank and world size. + rank = int(os.getenv('RANK', '0')) + world_size = int(os.getenv("WORLD_SIZE", '1')) + + print('> initializing torch.distributed with local rank: {}, ' + 'rank: {}, world size: {}'.format(local_rank, rank, world_size)) + + # Set the device id. + device = rank % torch.npu.device_count() + if local_rank is not None: + device = local_rank + torch.npu.set_device(device) + + # Call the init process. + init_method = 'tcp://' + master_ip = os.getenv('MASTER_ADDR', 'localhost') + master_port = os.getenv('MASTER_PORT', '6000') + init_method += master_ip + ':' + master_port + torch.distributed.init_process_group( + backend=backend, + world_size=world_size, + rank=rank, + init_method=init_method) + + +def print_separator(message): + torch.distributed.barrier() + filler_len = (78 - len(message)) // 2 + filler = '-' * filler_len + string = '\n' + filler + ' {} '.format(message) + filler + if torch.distributed.get_rank() == 0: + print(string, flush=True) + torch.distributed.barrier() diff --git a/megatron_npu/tests/ut/test_cross_entropy.py b/megatron_npu/tests/ut/test_cross_entropy.py new file mode 100644 index 0000000000000000000000000000000000000000..07a69c576f8683461e3c106284c50848a77bed49 --- /dev/null +++ b/megatron_npu/tests/ut/test_cross_entropy.py @@ -0,0 +1,114 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. 
All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +if torch.__version__>="1.8.0": + try: + import torch_npu + from torch_npu.contrib import transfer_to_npu + import bugfix + except: + print('WARNING! torch_npu is not imported.. Please using without npu..') +from commons import set_random_seed +from commons import IdentityLayer +from commons import print_separator +from commons import initialize_distributed +from megatron.mpu.cross_entropy import vocab_parallel_cross_entropy +from megatron import mpu +import torch.nn.functional as F +import random +import sys + + +def torch_cross_entropy(batch_size, seq_length, vocab_size, + logits_scale, seed): + set_random_seed(seed) + identity = IdentityLayer((batch_size, seq_length, vocab_size), + scale=logits_scale).npu() + logits = identity() + target = torch.LongTensor( + size=(batch_size, seq_length)).random_(0, vocab_size).npu() + loss = F.cross_entropy(logits.view(-1, logits.size()[-1]), + target.view(-1), + reduction='none').view_as(target).mean() + loss.backward() + return loss, identity.weight.grad + + +def mpu_cross_entropy(batch_size, seq_length, vocab_size, + logits_scale, seed): + set_random_seed(seed) + identity = IdentityLayer((batch_size, seq_length, vocab_size), + scale=logits_scale).npu() + logits = identity() + logits_parallel = mpu.scatter_to_tensor_model_parallel_region(logits) + target = torch.LongTensor( + size=(batch_size, seq_length)).random_(0, vocab_size).npu() + loss = vocab_parallel_cross_entropy(logits_parallel, target).mean() + loss.backward() + return loss, identity.weight.grad + + +def test_cross_entropy(tensor_model_parallel_size): + + if torch.distributed.get_rank() == 0: + print('> testing cross entropy with model parallel size {} ...'. 
+ format(tensor_model_parallel_size)) + + mpu.initialize_model_parallel(tensor_model_parallel_size) + tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size() + + batch_size = 13 + seq_length = 17 + vocab_size_per_partition = 11 + logits_scale = 1000.0 + vocab_size = vocab_size_per_partition * tensor_model_parallel_size + seed = 1234 + + loss_torch, grad_torch = torch_cross_entropy(batch_size, seq_length, + vocab_size, logits_scale, + seed) + loss_mpu, grad_mpu = mpu_cross_entropy(batch_size, seq_length, + vocab_size, logits_scale, + seed) + + error = loss_torch.sub_(loss_mpu).abs().max() + print(' max error in loss on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-6 + + error = grad_torch.sub_(grad_mpu).abs().max() + print(' max error in grad on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-6 + + # Reset groups + mpu.destroy_model_parallel() + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print('>> passed the test :-)') + + +if __name__ == '__main__': + + initialize_distributed() + world_size = torch.distributed.get_world_size() + + tensor_model_parallel_size = 1 + while tensor_model_parallel_size <= world_size: + print_separator('test cross entropy') + test_cross_entropy(tensor_model_parallel_size) + tensor_model_parallel_size *= 2 diff --git a/megatron_npu/tests/ut/test_data.py b/megatron_npu/tests/ut/test_data.py new file mode 100644 index 0000000000000000000000000000000000000000..bcf29cbdd472647f9dd619e2686d3008c1cce6a9 --- /dev/null +++ b/megatron_npu/tests/ut/test_data.py @@ -0,0 +1,94 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +if torch.__version__>="1.8.0": + try: + import torch_npu + from torch_npu.contrib import transfer_to_npu + import bugfix + except: + print('WARNING! torch_npu is not imported.. Please using without npu..') +from commons import print_separator +from commons import initialize_distributed +from megatron.mpu import data as data_utils +from megatron import mpu +import functools +import operator +import sys + + +def test_broadcast_data(tensor_model_parallel_size): + + if torch.distributed.get_rank() == 0: + print('> testing broadcast_data with model parallel size {} ...'. 
+ format(tensor_model_parallel_size)) + + mpu.initialize_model_parallel(tensor_model_parallel_size) + torch.manual_seed(1234 + mpu.get_data_parallel_rank()) + tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size() + + key_size_t = {'key1': [7, 11], + 'key2': [8, 2, 1], + 'key3': [13], + 'key4': [5, 1, 2], + 'key5': [5, 12]} + keys = list(key_size_t.keys()) + + data = {} + data_t = {} + for key in key_size_t: + data[key] = torch.LongTensor(size=key_size_t[key]).random_(0, 1000) + data_t[key] = data[key].clone() + data['keyX'] = torch.FloatTensor(size=(5, )).random_(0, 1000) + data_t['keyX'] = data['keyX'].clone() + if mpu.get_tensor_model_parallel_rank() != 0: + data = None + + data_utils._check_data_types(keys, data_t, torch.int64) + key_size, key_numel, \ + total_numel = data_utils._build_key_size_numel_dictionaries(keys, data) + for key in keys: + assert key_size[key] == key_size_t[key] + total_numel_t = 0 + for key in keys: + target_size = functools.reduce(operator.mul, key_size_t[key], 1) + assert key_numel[key] == target_size + total_numel_t += target_size + assert total_numel == total_numel_t + + data_b = data_utils.broadcast_data(keys, data, torch.int64) + for key in keys: + tensor = data_t[key].npu() + assert data_b[key].sub(tensor).abs().max() == 0 + + # Reset groups + mpu.destroy_model_parallel() + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print('>> passed the test :-)') + + +if __name__ == '__main__': + + initialize_distributed() + world_size = torch.distributed.get_world_size() + + tensor_model_parallel_size = 1 + while tensor_model_parallel_size <= world_size: + print_separator('test test broadcast data') + test_broadcast_data(tensor_model_parallel_size) + tensor_model_parallel_size *= 2 diff --git a/megatron_npu/tests/ut/test_initialize.py b/megatron_npu/tests/ut/test_initialize.py new file mode 100644 index 0000000000000000000000000000000000000000..5fbfb5e595aa9615324f153cb1f685441f2ed9ba --- /dev/null +++ b/megatron_npu/tests/ut/test_initialize.py @@ -0,0 +1,101 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +if torch.__version__>="1.8.0": + try: + import torch_npu + from torch_npu.contrib import transfer_to_npu + import bugfix + except: + print('WARNING! torch_npu is not imported.. Please using without npu..') +from commons import print_separator +from commons import initialize_distributed +from megatron import mpu +import sys + + +def test_initialize_model_parallel(tensor_model_parallel_size): + + if torch.distributed.get_rank() == 0: + print('> testing initialize_model_parallel with size {} ...'.format( + tensor_model_parallel_size)) + tensor_model_parallel_size_ = min(tensor_model_parallel_size, + torch.distributed.get_world_size()) + assert not mpu.model_parallel_is_initialized() + mpu.initialize_model_parallel(tensor_model_parallel_size_) + assert mpu.model_parallel_is_initialized() + + # Checks. 
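The checks below come down to simple rank arithmetic: for a world of world_size ranks and tensor-model-parallel size tp, a rank's tensor-parallel rank is rank % tp and its data-parallel rank is rank // tp. A quick sketch (separate from the patch itself) with a hypothetical 8-rank world:

    world_size, tp = 8, 2
    for rank in range(world_size):
        tp_rank = rank % tp      # position inside this rank's tensor-parallel group
        dp_rank = rank // tp     # which data-parallel replica the rank belongs to
        print(rank, tp_rank, dp_rank)
    # The data-parallel world size is world_size // tp == 4, matching
    # mpu.get_data_parallel_world_size() in the assertions below.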
+ def check(group, world_size, rank): + assert world_size == torch.distributed.get_world_size(group=group) + assert rank == torch.distributed.get_rank(group=group) + + # Model parallel. + world_size = tensor_model_parallel_size_ + rank = torch.distributed.get_rank() % tensor_model_parallel_size_ + assert world_size == mpu.get_tensor_model_parallel_world_size() + assert rank == mpu.get_tensor_model_parallel_rank() + check(mpu.get_tensor_model_parallel_group(), world_size, rank) + + # Data parallel. + world_size = torch.distributed.get_world_size() // tensor_model_parallel_size_ + rank = torch.distributed.get_rank() // tensor_model_parallel_size + assert world_size == mpu.get_data_parallel_world_size() + assert rank == mpu.get_data_parallel_rank() + check(mpu.get_data_parallel_group(), world_size, rank) + + # Reset groups + mpu.destroy_model_parallel() + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print('>> passed the test :-)') + + +def test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size_): + + if torch.distributed.get_rank() == 0: + print('> testing get_tensor_model_parallel_src_rank with size {} ...'.format( + tensor_model_parallel_size_)) + tensor_model_parallel_size = min(tensor_model_parallel_size_, + torch.distributed.get_world_size()) + assert not mpu.model_parallel_is_initialized() + mpu.initialize_model_parallel(tensor_model_parallel_size) + assert mpu.model_parallel_is_initialized() + + # Checks + src_rank = torch.distributed.get_rank() - mpu.get_tensor_model_parallel_rank() + assert mpu.get_tensor_model_parallel_src_rank() == src_rank + + # Reset groups + mpu.destroy_model_parallel() + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print('>> passed the test :-)') + + +if __name__ == '__main__': + + initialize_distributed() + world_size = torch.distributed.get_world_size() + tensor_model_parallel_size = 1 + while tensor_model_parallel_size <= world_size: + print_separator('test initialize model parallel') + test_initialize_model_parallel(tensor_model_parallel_size) + print_separator('test model parallel source rank') + test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size) + tensor_model_parallel_size *= 2 diff --git a/megatron_npu/tests/ut/test_layers.py b/megatron_npu/tests/ut/test_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..9d453f13fb64c4648c118a960419f6fdaadba4e5 --- /dev/null +++ b/megatron_npu/tests/ut/test_layers.py @@ -0,0 +1,1000 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +if torch.__version__>="1.8.0": + try: + import torch_npu + from torch_npu.contrib import transfer_to_npu + import bugfix + + option = {} + option["ACL_OP_COMPILER_CACHE_MODE"] = "disable" # cache功能启用 + option["ACL_OP_COMPILER_CACHE_DIR"] = "./cache" # cache所在文件夹 + print("option:", option) + torch.npu.set_option(option) + except: + print('WARNING! torch_npu is not imported.. 
Please using without npu..') +from megatron.mpu import layers +from commons import set_random_seed +from commons import print_separator +from commons import initialize_distributed +from megatron import mpu, get_global_memory_buffer +from torch.nn.parameter import Parameter +import torch.nn.init as init +import random +import sys +import torch.nn as nn +import math +import torch.nn.functional as F +from megatron.initialize import initialize_megatron + +initialize_megatron( + args_defaults = { + 'micro_batch_size': 1, + 'num_layers': 1, + 'hidden_size': 32, + 'num_attention_heads': 16, + 'max_position_embeddings': 128, + 'encoder_seq_length': 32, + 'use_cpu_initialization': True + }) + +def split_tensor_along_last_dim(tensor, num_partitions, + contiguous_split_chunks=False): + """Split a tensor along its last dimension. + Arguments: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. + """ + # Get the size and dimension. + last_dim = tensor.dim() - 1 + last_dim_size = divide(tensor.size()[last_dim], num_partitions) + # Split. + tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) + # Note: torch.split does not create contiguous tensors by default. + if contiguous_split_chunks: + return tuple(chunk.contiguous() for chunk in tensor_list) + + return tensor_list + + +def ensure_divisibility(numerator, denominator): + """Ensure that numerator is divisible by the denominator.""" + assert numerator % denominator == 0, '{} is not divisible by {}'.format( + numerator, denominator) + + +def divide(numerator, denominator): + """Ensure that numerator is divisible by the denominator and return + the division value.""" + ensure_divisibility(numerator, denominator) + return numerator // denominator + + +def _initialize_affine_weight(weight, output_size, input_size, + per_partition_size, partition_dim, init_method, + stride=1, return_master_weight=False): + """Initialize affine weight for model parallel. + + Build the master weight on all processes and scatter + the relevant chunk.""" + # If we only use 1 process for model parallelism, bypass scatter. + world_size = mpu.get_tensor_model_parallel_world_size() + if world_size == 1: + init_method(weight) + if return_master_weight: + return weight + return None + + # Initialize master weight + master_weight = torch.empty(output_size, input_size, + dtype=weight.dtype, + requires_grad=False) + init_method(master_weight) + + # Split and copy + per_partition_per_stride_size = divide(per_partition_size, stride) + weight_list = torch.split(master_weight, per_partition_per_stride_size, + dim=partition_dim) + rank = mpu.get_tensor_model_parallel_rank() + my_weight_list = weight_list[rank::world_size] + + with torch.no_grad(): + torch.cat(my_weight_list, dim=partition_dim, out=weight) + if return_master_weight: + return master_weight + return None + + +class BertParallelTransformerOutput(torch.nn.Module): + """The output layer used after self attention and intermediate + parts of transformer layer.""" + def __init__(self, input_size, output_size, dropout_prob, + layernorm_epsilon=1.0e-12, input_is_parallel=False, + init_method=init.xavier_normal_): + super(BertParallelTransformerOutput, self).__init__() + # Components. 
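_initialize_affine_weight above builds the full master weight on every rank and then keeps only this rank's interleaved shard. A toy sketch (separate from the patch itself) of that split-and-select step, with a hypothetical tensor-parallel world of 2 and stride 1:

    import torch

    world_size, rank, stride = 2, 0, 1
    output_size, input_size = 8, 4
    master = torch.arange(output_size * input_size, dtype=torch.float32).view(output_size, input_size)

    per_partition = output_size // world_size          # rows owned by each rank
    per_stride = per_partition // stride
    pieces = torch.split(master, per_stride, dim=0)    # world_size * stride pieces
    mine = torch.cat(pieces[rank::world_size], dim=0)  # this rank's shard, as in the code above
    print(mine.shape)                                  # torch.Size([4, 4])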
+ self.dense = mpu.RowParallelLinear(input_size, + output_size, + input_is_parallel=input_is_parallel, + init_method=init_method) + self.dropout = torch.nn.Dropout(dropout_prob) + self.layernorm = nn.LayerNorm(output_size, eps=layernorm_epsilon) + + def forward(self, hidden_states, input_tensor): + hidden_states, bias = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + layernorm_input = hidden_states + input_tensor + hidden_states = self.layernorm(layernorm_input) + return hidden_states + + +class BertParallelTransformerLayer(torch.nn.Module): + """A single layer transformer for Bert. + + We use the following notation: + h: hidden size + n: number of attention heads + b: batch size + s: sequence length + Transformore layer takes input with size [b, s, h] and returns an + output of the same size. + + Arguments: + hidden_size: The hidden size of the self attention. + intermediate_size: size of the intermediate state after + self attention. In both BERT and GPT + this is set to be 4 times the hidden + size. + num_attention_heads: number of attention head in the self + attention. + attention_dropout_prob: dropout probability of the attention + score in self attention. + output_dropout_prob: dropout probability for the outputs + after self attention and final output. + intermediate_activation_fn: activation function for output + of intermediate. + layernorm_epsilon: epsilon used in layernorm to avoid + division by zero. + init_method: initialization method used for the weights. Note + that all biases are initialized to zero and + layernorm weight are initialized to one. + """ + def __init__(self, + hidden_size, + intermediate_size, + num_attention_heads, + attention_dropout_prob, + output_dropout_prob, + intermediate_activation_fn, + layernorm_epsilon, + init_method=init.xavier_normal_): + super(BertParallelTransformerLayer, self).__init__() + + # Self attention. + self.attention = BertParallelSelfAttention(hidden_size, + num_attention_heads, + attention_dropout_prob, + output_parallel=True, + init_method=init_method) + # Self attention output. + self.self_output = BertParallelTransformerOutput( + hidden_size, hidden_size, output_dropout_prob, + layernorm_epsilon=layernorm_epsilon, + input_is_parallel=True, + init_method=init_method) + # Intermediate. + self.intermediate = mpu.ColumnParallelLinear(hidden_size, intermediate_size, + gather_output=False, + init_method=init_method) + self.intermediate_activation_fn = intermediate_activation_fn + # Output. + self.output = BertParallelTransformerOutput( + intermediate_size, hidden_size, output_dropout_prob, + layernorm_epsilon=layernorm_epsilon, + input_is_parallel=True, + init_method=init_method) + + def forward(self, hidden_states, attention_mask): + # [b, s, hp] + attention_output_parallel = self.attention(hidden_states, + attention_mask) + # [b, s, h] + attention_self_output = self.self_output(attention_output_parallel, + hidden_states) + # [b, s, ip] + intermediate_output_parallel, bias = self.intermediate(attention_self_output) + intermediate_output_parallel = self.intermediate_activation_fn( + intermediate_output_parallel) + # [b, s, h] + layer_output = self.output(intermediate_output_parallel, + attention_self_output) + + return layer_output + + +class BertParallelSelfAttention(torch.nn.Module): + """Parallel self-attention layer for BERT. + + Self-attention layer takes input with size [b, s, h] where b is + the batch size, s is the sequence lenght, and h is the hidden size + and creates output of the same size. 
+ Arguments: + hidden_size: total hidden size of the layer (h). + num_attention_heads: number of attention heads (n). Note that we + require n to be divisible by number of GPUs + used to parallelize the model. Also, we + require hidden size be divisible by n. + dropout_prob: dropout probability for the attention scores. + output_parallel: If true, no all-gather is done on the output and + the output values will be per partition. + We use the following notation: + h: hidden_size + n: num_attention_heads + p: number of partitions + np: n/p + hp: h/p + hn: h/n + b: batch size + s: sequence length + """ + def __init__(self, hidden_size, num_attention_heads, + dropout_prob, output_parallel=False, + init_method=init.xavier_normal_): + super(BertParallelSelfAttention, self).__init__() + # Input configuration. + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.dropout_prob = dropout_prob + self.output_parallel = output_parallel + # Per attention head and per partition values. + world_size = mpu.get_tensor_model_parallel_world_size() + self.hidden_size_per_partition = divide(hidden_size, world_size) + self.hidden_size_per_attention_head = divide(hidden_size, + num_attention_heads) + self.num_attention_heads_per_partition = divide(num_attention_heads, + world_size) + # Strided linear layer. + self.query_key_value = mpu.ColumnParallelLinear(hidden_size, 3*hidden_size, + stride=3, + gather_output=False, + init_method=init_method) + # Dropout. Note that for a single iteration, this layer will generate + # different outputs on different number of parallel partitions but + # on average it should not be partition dependent. + self.dropout = torch.nn.Dropout(dropout_prob) + + def _transpose_for_scores(self, tensor): + """Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with + size [b, np, s, hn]. + """ + new_tensor_shape = tensor.size()[:-1] + \ + (self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head) + tensor = tensor.view(*new_tensor_shape) + return tensor.permute(0, 2, 1, 3) + + def forward(self, hidden_states, attention_mask): + + # Attention heads. [b, s, hp] + mixed_x_layer, bias = self.query_key_value(hidden_states) + (mixed_query_layer, + mixed_key_layer, + mixed_value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3) + + # Reshape and transpose [b, np, s, hn] + query_layer = self._transpose_for_scores(mixed_query_layer) + key_layer = self._transpose_for_scores(mixed_key_layer) + value_layer = self._transpose_for_scores(mixed_value_layer) + + # Raw attention scores. [b, np, s, s] + norm_factor = math.sqrt(math.sqrt(self.hidden_size_per_attention_head)) + attention_scores = torch.matmul(query_layer/norm_factor, + key_layer.transpose(-1, -2)/norm_factor) + # Apply the attention mask. + attention_scores += attention_mask + + # Attention probabilities. [b, np, s, s] + attention_probs = torch.nn.Softmax(dim=-1)(attention_scores) + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + + # get_cuda_rng_tracker is not supported on npu now.. + # with get_cuda_rng_tracker().fork(): + attention_probs = self.dropout(attention_probs) + + # Context layer. 
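Note the scaling above: query and key are each divided by sqrt(sqrt(hn)) before the matmul, so the product carries the usual 1/sqrt(hn) attention scaling while keeping each matmul operand smaller, which helps avoid fp16 overflow. A quick check (separate from the patch itself), in float64 for a hypothetical head size of 64:

    import math
    import torch

    hn = 64
    q = torch.randn(2, 4, 5, hn, dtype=torch.float64)
    k = torch.randn(2, 4, 5, hn, dtype=torch.float64)
    norm = math.sqrt(math.sqrt(hn))
    scores_split = torch.matmul(q / norm, k.transpose(-1, -2) / norm)
    scores_plain = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(hn)
    print(torch.allclose(scores_split, scores_plain))  # True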
+ # [b, np, s, hn] + context_layer = torch.matmul(attention_probs, value_layer) + # [b, s, np, hn] + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + \ + (self.hidden_size_per_partition,) + # [b, s, hp] + context_layer = context_layer.view(*new_context_layer_shape) + + # Output. [b, s, h] + if self.output_parallel: + output = context_layer + else: + output = mpu.gather_from_tensor_model_parallel_region(context_layer) + + return output + +class ParallelEmbedding(torch.nn.Module): + """Embedding parallelized in the embedding dimension. + + This is mainly adapted from torch.nn.Embedding and all the default + values are kept. + Arguments: + num_embeddings: vocabulary size. + embedding_dim: size of hidden state. + init_method: method to initialize weights. + """ + def __init__(self, num_embeddings, embedding_dim, + init_method=init.xavier_normal_, + keep_master_weight_for_test=False): + super(ParallelEmbedding, self).__init__() + # Keep the input dimensions. + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + # Set some detauls for compatibility. + self.padding_idx = None + self.max_norm = None + self.norm_type = 2. + self.scale_grad_by_freq = False + self.sparse = False + self._weight = None + # Divide the weight matrix along the embedding dimension. + world_size = mpu.get_tensor_model_parallel_world_size() + self.embedding_dim_per_partition = divide(self.embedding_dim, + world_size) + + # Allocate weights. + self.weight = Parameter(torch.Tensor(self.num_embeddings, + self.embedding_dim_per_partition)) + self.weight.model_parallel = True + # And initialize. + _initialize_affine_weight( + self.weight, self.num_embeddings, self.embedding_dim, + self.embedding_dim_per_partition, 1, init_method, + stride=1, return_master_weight=False) + + def forward(self, input_): + input_parallel = mpu.copy_to_tensor_model_parallel_region(input_) + output_parallel = F.embedding(input_parallel, self.weight, + self.padding_idx, self.max_norm, + self.norm_type, self.scale_grad_by_freq, + self.sparse) + output = mpu.gather_from_tensor_model_parallel_region(output_parallel) + return output + + +def test_parallel_embedding(tensor_model_parallel_size): + + if torch.distributed.get_rank() == 0: + print('> testing parallel embedding with model parallel size {} ...'. 
+ format(tensor_model_parallel_size)) + + mpu.initialize_model_parallel(tensor_model_parallel_size) + tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size() + + batch_size = 17 + seq_length = 23 + vocab_size = 48 + hidden_size = 16 + seed = 1236 + + set_random_seed(123) + input_data = torch.LongTensor( + size=(batch_size, seq_length)).random_(0, vocab_size).npu() + loss_weight = torch.randn([batch_size, seq_length, hidden_size]).npu() + + set_random_seed(seed) + embedding_original = torch.nn.Embedding(vocab_size, hidden_size).npu() + + output = embedding_original(input_data) + loss_original = torch.mul(output, loss_weight).sum() + loss_original.backward() + + set_random_seed(seed) + embedding_parallel = ParallelEmbedding( + vocab_size, hidden_size, init_method=init.normal_).npu() + output = embedding_parallel(input_data) + loss_parallel = torch.mul(output, loss_weight).sum() + loss_parallel.backward() + + set_random_seed(seed) + embedding_vocab_parallel = mpu.VocabParallelEmbedding( + vocab_size, hidden_size, init_method=init.normal_).npu() + output = embedding_vocab_parallel(input_data) + loss_vocab_parallel = torch.mul(output, loss_weight).sum() + loss_vocab_parallel.backward() + + torch.distributed.barrier() + error = loss_parallel.sub(loss_original).abs() + print(' error in loss (parallel) on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-4, 'error: {}'.format(error) + + torch.distributed.barrier() + error = loss_vocab_parallel.sub(loss_original).abs() + print(' error in loss (vocab parallel) on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-4, 'error: {}'.format(error) + + weight_grad_orig = torch.split(embedding_original.weight.grad, + hidden_size // tensor_model_parallel_size, + 1)[mpu.get_tensor_model_parallel_rank()] + error = embedding_parallel.weight.grad.sub(weight_grad_orig).abs().max() + print(' error in grad (parallel) on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-4, 'error: {}'.format(error) + + weight_grad_orig = torch.split(embedding_original.weight.grad, + vocab_size // tensor_model_parallel_size, + 0)[mpu.get_tensor_model_parallel_rank()] + error = embedding_vocab_parallel.weight.grad.sub( + weight_grad_orig).abs().max() + print(' error in grad (vocab parallel) on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-4, 'error: {}'.format(error) + + # Reset groups + mpu.destroy_model_parallel() + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print('>> passed the test :-)') + + +def test_initialize_affine_weight(tensor_model_parallel_size): + + mpu.initialize_model_parallel(tensor_model_parallel_size) + if torch.distributed.get_rank() == 0: + print('> testing initialize_affine_weight with model parallel ' + 'size: {}'.format(tensor_model_parallel_size)) + tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size() + + seed = 12345 + input_size_coeff = 13 + input_size = input_size_coeff * tensor_model_parallel_size + output_size_coeff = 17 + output_size = output_size_coeff * tensor_model_parallel_size + + # --------------- + # Column parallel + # --------------- + weight = torch.empty(output_size_coeff, input_size) + set_random_seed(seed) + _initialize_affine_weight(weight, output_size, input_size, + output_size_coeff, 0, + torch.nn.init.normal_) + # Target. 
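The embedding test above checks each parallel layer's gradient against a slice of the unsharded module's gradient: ParallelEmbedding is sharded along the hidden dimension, VocabParallelEmbedding along the vocab dimension. A toy sketch (separate from the patch itself), for a hypothetical vocab=8, hidden=4, tensor-parallel size 2, rank 0:

    import torch

    vocab, hidden, tp, rank = 8, 4, 2, 0
    full_grad = torch.randn(vocab, hidden)  # gradient of the unsharded embedding weight
    col_shard = torch.split(full_grad, hidden // tp, dim=1)[rank]  # ParallelEmbedding's share
    row_shard = torch.split(full_grad, vocab // tp, dim=0)[rank]   # VocabParallelEmbedding's share
    print(col_shard.shape, row_shard.shape)  # torch.Size([8, 2]) torch.Size([4, 4])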
+ set_random_seed(seed) + master_weight = torch.empty(output_size, input_size) + torch.nn.init.normal_(master_weight) + rank = mpu.get_tensor_model_parallel_rank() + my_weight = torch.split(master_weight, output_size_coeff, + dim=0)[rank].contiguous().clone() + + # Compare. + error = weight.sub(my_weight).abs().max() + torch.distributed.barrier() + print(' column parallel max error (should be zero) on global rank ' + '{}: {}'.format(torch.distributed.get_rank(), error)) + assert error < 1.0e-6 + + # ------------ + # Row parallel + # ------------ + weight = torch.empty(output_size, input_size_coeff) + set_random_seed(seed) + _initialize_affine_weight(weight, output_size, input_size, + input_size_coeff, 1, + torch.nn.init.normal_) + # Target. + set_random_seed(seed) + master_weight = torch.empty(output_size, input_size) + torch.nn.init.normal_(master_weight) + rank = mpu.get_tensor_model_parallel_rank() + my_weight = torch.split(master_weight, input_size_coeff, + dim=1)[rank].contiguous().clone() + + # Compare. + error = weight.sub(my_weight).abs().max() + torch.distributed.barrier() + print(' row parallel max error (should be zero) on global rank ' + '{}: {}'.format(torch.distributed.get_rank(), error)) + assert error < 1.0e-6 + + # Reset groups + mpu.destroy_model_parallel() + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print(' >> passed the test :-)') + + +class IdentityLayer2D(torch.nn.Module): + def __init__(self, m, n): + super(IdentityLayer2D, self).__init__() + self.weight = Parameter(torch.Tensor(m, n)) + torch.nn.init.xavier_normal_(self.weight) + + def forward(self): + return self.weight + +class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function): + """ + Linear layer execution with asynchronous communication and gradient accumulation + fusion in backprop. 
+    """
+
+    @staticmethod
+    def forward(ctx, input, weight, bias, gradient_accumulation_fusion,
+                async_grad_allreduce, sequence_parallel):
+        ctx.save_for_backward(input, weight)
+        ctx.use_bias = bias is not None
+        ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
+        ctx.async_grad_allreduce = async_grad_allreduce
+        ctx.sequence_parallel = sequence_parallel
+
+        if sequence_parallel:
+            world_size = mpu.get_tensor_model_parallel_world_size()
+            dim_size = list(input.size())
+            dim_size[0] = dim_size[0] * world_size
+
+            all_gather_buffer = \
+                get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu")
+            torch.distributed._all_gather_base(
+                all_gather_buffer,
+                input,
+                group=mpu.initialize.get_tensor_model_parallel_group())
+            total_input = all_gather_buffer
+        else:
+            total_input = input
+
+        output = torch.matmul(total_input, weight.t())
+        if bias is not None:
+            output = output + bias
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input, weight = ctx.saved_tensors
+        use_bias = ctx.use_bias
+
+        if ctx.sequence_parallel:
+            world_size = mpu.get_tensor_model_parallel_world_size()
+            dim_size = list(input.size())
+            dim_size[0] = dim_size[0] * world_size
+
+            all_gather_buffer = \
+                get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu")
+            handle = torch.distributed._all_gather_base(
+                all_gather_buffer,
+                input,
+                group=mpu.initialize.get_tensor_model_parallel_group(), async_op=True)
+
+            # Delay the start of input gradient computation shortly (3us) to have
+            # gather scheduled first and have GPU resources allocated
+            _ = torch.empty(1, device=grad_output.device) + 1
+            total_input = all_gather_buffer
+        else:
+            total_input = input
+        grad_input = grad_output.matmul(weight)
+
+        if ctx.sequence_parallel:
+            handle.wait()
+
+        # Convert the tensor shapes to 2D for execution compatibility
+        if len(grad_output.shape) > 2:
+            grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1],
+                                           grad_output.shape[2])
+            total_input = total_input.view(total_input.shape[0] * total_input.shape[1],
+                                           total_input.shape[2])
+
+        if ctx.async_grad_allreduce:
+            # Asynchronous all-reduce
+            handle = torch.distributed.all_reduce(
+                grad_input, group=mpu.initialize.get_tensor_model_parallel_group(), async_op=True)
+            # Delay the start of weight gradient computation shortly (3us) to have
+            # all-reduce scheduled first and have GPU resources allocated
+            _ = torch.empty(1, device=grad_output.device) + 1
+
+        if ctx.sequence_parallel:
+            assert not ctx.async_grad_allreduce
+            dim_size = list(input.size())
+            sub_grad_input = torch.empty(dim_size, dtype=input.dtype,
+                                         device=torch.cuda.current_device(),
+                                         requires_grad=False)
+            # reduce_scatter
+            handle = torch.distributed._reduce_scatter_base(sub_grad_input, grad_input,
+                                                            group=mpu.initialize.get_tensor_model_parallel_group(),
+                                                            async_op=True)
+            # Delay the start of weight gradient computation shortly (3us) to have
+            # reduce scatter scheduled first and have GPU resources allocated
+            _ = torch.empty(1, device=grad_output.device) + 1
+
+        if ctx.gradient_accumulation_fusion:
+            import fused_dense_cuda
+            fused_dense_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, weight.main_grad)
+            grad_weight = None
+        else:
+            grad_weight = grad_output.t().matmul(total_input)
+        grad_bias = grad_output.sum(dim=0) if use_bias else None
+
+        if ctx.sequence_parallel:
+            handle.wait()
+            return sub_grad_input, grad_weight, grad_bias, None, None, None
+
+        if ctx.async_grad_allreduce:
+            handle.wait()
+
+        return grad_input, grad_weight, grad_bias, None, None,
None +mpu.LinearWithGradAccumulationAndAsyncCommunication = LinearWithGradAccumulationAndAsyncCommunication +mpu.layers.LinearWithGradAccumulationAndAsyncCommunication = LinearWithGradAccumulationAndAsyncCommunication + +def test_column_parallel_linear(tensor_model_parallel_size): + + mpu.initialize_model_parallel(tensor_model_parallel_size) + if torch.distributed.get_rank() == 0: + print('> testing ColumnParallelLinear with model parallel ' + 'size: {}'.format(tensor_model_parallel_size)) + tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size() + + seed = 12345 + set_random_seed(seed) + input_size_coeff = 13 + input_size = input_size_coeff * tensor_model_parallel_size + output_size_coeff = 17 + output_size = output_size_coeff * tensor_model_parallel_size + batch_size = 7 + + # Network + identity_layer = IdentityLayer2D(batch_size, input_size).npu() + linear_layer = mpu.ColumnParallelLinear( + input_size, output_size, keep_master_weight_for_test=True).npu() + linear_layer.async_tensor_model_parallel_allreduce = False + loss_weight = torch.randn([batch_size, output_size]).npu() + # Forward + input_ = identity_layer() + output, bias = linear_layer(input_) + loss = torch.mul(output, loss_weight).sum() + # Backward + loss.backward() + + # Values. + dLdY = loss_weight + X = identity_layer.weight + A = linear_layer.master_weight.npu() + dLdA = torch.matmul(dLdY.t(), X) + dLdb = torch.matmul(torch.ones(batch_size, 1).npu().t(), dLdY).view(-1) + dLdX = torch.matmul(dLdY, A) + + rank = mpu.get_tensor_model_parallel_rank() + my_dLdA = torch.split(dLdA, output_size_coeff, + dim=0)[rank].contiguous().clone() + error = my_dLdA.sub(linear_layer.weight.grad).abs().max() + torch.distributed.barrier() + print(' error in dLdA on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-5 + + my_dLdb = torch.split(dLdb, output_size_coeff, + dim=0)[rank].contiguous().clone() + error = my_dLdb.sub(linear_layer.bias.grad).abs().max() + torch.distributed.barrier() + print(' error in dLdb on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-5 + + error = dLdX.sub(identity_layer.weight.grad).abs().max() + torch.distributed.barrier() + print(' error in dLdX on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-5 + + # Reset groups + mpu.destroy_model_parallel() + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print(' >> passed the test :-)') + + +def test_row_parallel_linear(tensor_model_parallel_size): + + mpu.initialize_model_parallel(tensor_model_parallel_size) + if torch.distributed.get_rank() == 0: + print('> testing RowParallelLinear with model parallel ' + 'size: {}'.format(tensor_model_parallel_size)) + tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size() + + seed = 12345 + set_random_seed(seed) + input_size_coeff = 13 + input_size = input_size_coeff * tensor_model_parallel_size + output_size_coeff = 17 + output_size = output_size_coeff * tensor_model_parallel_size + batch_size = 7 + + # Network + identity_layer = IdentityLayer2D(batch_size, input_size).npu() + linear_layer = mpu.RowParallelLinear( + input_size, output_size, keep_master_weight_for_test=True).npu() + loss_weight = torch.randn([batch_size, output_size]).npu() + # Forward + input_ = identity_layer() + output, bias = linear_layer(input_) + loss = torch.mul(output, loss_weight).sum() + # Backward + loss.backward() + + # Values. 
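+    # Analytic reference values: the layer computes Y = X A^T + b and the
+    # loss is L = sum(Y * loss_weight), hence dL/dY = loss_weight,
+    # dL/dA = dL/dY^T X, dL/db = 1^T dL/dY and dL/dX = dL/dY A.
+    # RowParallelLinear shards A along the input dimension (dim 1), so each
+    # rank's weight gradient is compared against its input_size_coeff-wide
+    # slice of dL/dA below.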
+ dLdY = loss_weight + X = identity_layer.weight + A = linear_layer.master_weight.npu() + dLdA = torch.matmul(dLdY.t(), X) + dLdb = torch.matmul(torch.ones(batch_size, 1).npu().t(), dLdY).view(-1) + dLdX = torch.matmul(dLdY, A) + + rank = mpu.get_tensor_model_parallel_rank() + my_dLdA = torch.split(dLdA, input_size_coeff, + dim=1)[rank].contiguous().clone() + error = my_dLdA.sub(linear_layer.weight.grad).abs().max() + torch.distributed.barrier() + print(' error in dLdA on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-6 + + error = dLdb.sub(linear_layer.bias.grad).abs().max() + torch.distributed.barrier() + print(' error in dLdb on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-6 + + error = dLdX.sub(identity_layer.weight.grad).abs().max() + torch.distributed.barrier() + print(' error in dLdX on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-6 + + # Reset groups + mpu.destroy_model_parallel() + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print(' >> passed the test :-)') + + +class IdentityLayer3D(torch.nn.Module): + def __init__(self, m, n, k): + super(IdentityLayer3D, self).__init__() + self.weight = Parameter(torch.Tensor(m, n, k)) + torch.nn.init.xavier_normal_(self.weight) + + def forward(self): + return self.weight + + +def parallel_self_attention(tensor_model_parallel_size, num_att_heads_per_partition, + hidden_size_per_att_head, dropout_prob, batch_size, + sequence_length): + mpu.initialize_model_parallel(tensor_model_parallel_size) + tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size() + + seed = 12345 + set_random_seed(seed) + + num_att_heads = num_att_heads_per_partition * \ + torch.distributed.get_world_size() + hidden_size = hidden_size_per_att_head * num_att_heads + + # Network + identity_layer = IdentityLayer3D(batch_size, sequence_length, + hidden_size).npu() + attention_layer = BertParallelSelfAttention(hidden_size, num_att_heads, + dropout_prob).npu() + loss_weight = torch.randn([batch_size, sequence_length, hidden_size]).npu() + attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).npu() + # Forward + input_ = identity_layer() + output = attention_layer(input_, attention_mask) + loss = torch.mul(output, loss_weight).sum() + # Backward + loss.backward() + + rank = mpu.get_tensor_model_parallel_rank() + mpu.destroy_model_parallel() + return rank, hidden_size, tensor_model_parallel_size, loss, \ + attention_layer, identity_layer + + +def test_parallel_self_attention(tensor_model_parallel_size): + + if torch.distributed.get_rank() == 0: + print('> testing ParallelSelfAttention with model parallel ' + 'size: {}'.format(tensor_model_parallel_size)) + + num_att_heads_per_partition = 3 + hidden_size_per_att_head = 7 + dropout_prob = 0.0 # has to be zero + batch_size = 5 + sequence_length = 13 + + rank_1, hideen_size_1, tensor_model_parallel_size_1, loss_1, \ + attention_layer_1, identity_layer_1 = parallel_self_attention( + 1, num_att_heads_per_partition, + hidden_size_per_att_head, dropout_prob, batch_size, sequence_length) + + rank, hidden_size, tensor_model_parallel_size, loss, \ + attention_layer, identity_layer = parallel_self_attention( + tensor_model_parallel_size, num_att_heads_per_partition, + hidden_size_per_att_head, dropout_prob, batch_size, sequence_length) + assert hideen_size_1 == hidden_size + + error = loss_1.sub(loss).abs().max() + torch.distributed.barrier() + print(' loss 
error on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-5 + + my_lin_grad_list = torch.split( + attention_layer_1.query_key_value.weight.grad, + hidden_size // tensor_model_parallel_size, 0)[rank::tensor_model_parallel_size] + my_lin_grad = torch.cat(my_lin_grad_list, dim=0) + error = my_lin_grad.sub( + attention_layer.query_key_value.weight.grad).abs().max() + torch.distributed.barrier() + print(' weight gradient error on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-5 + + error = identity_layer_1.weight.grad.sub( + identity_layer.weight.grad).abs().max() + torch.distributed.barrier() + print(' input gradient error on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-5 + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print(' >> passed the test :-)') + + +def parallel_transformer(tensor_model_parallel_size, num_att_heads_per_partition, + hidden_size_per_att_head, batch_size, sequence_length): + + mpu.initialize_model_parallel(tensor_model_parallel_size) + tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size() + + seed = 12345 + set_random_seed(seed) + + num_att_heads = num_att_heads_per_partition * \ + torch.distributed.get_world_size() + hidden_size = hidden_size_per_att_head * num_att_heads + intermediate_size = 4 * hidden_size + + # Network + identity_layer = IdentityLayer3D(batch_size, sequence_length, + hidden_size).npu() + transformer_layer = BertParallelTransformerLayer( + hidden_size, intermediate_size, num_att_heads, 0.0, 0.0, + torch.nn.functional.relu, 1.0e-5).npu() + + loss_weight = torch.randn([batch_size, sequence_length, hidden_size]).npu() + attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).npu() + # Forward + input_ = identity_layer() + output = transformer_layer(input_, attention_mask) + loss = torch.mul(output, loss_weight).sum() + # Backward + loss.backward() + + rank = mpu.get_tensor_model_parallel_rank() + mpu.destroy_model_parallel() + return rank, hidden_size, tensor_model_parallel_size, loss, \ + transformer_layer, identity_layer + + +def test_parallel_transformer_layer(tensor_model_parallel_size): + + if torch.distributed.get_rank() == 0: + print('> testing ParallelTransformerLayer with model parallel ' + 'size: {}'.format(tensor_model_parallel_size)) + + num_att_heads_per_partition = 3 + hidden_size_per_att_head = 7 + batch_size = 5 + sequence_length = 13 + + rank_1, hidden_size_1, tensor_model_parallel_size_1, loss_1, \ + transformer_layer_1, identity_layer_1 = parallel_transformer( + 1, num_att_heads_per_partition, + hidden_size_per_att_head, batch_size, sequence_length) + + rank, hidden_size, tensor_model_parallel_size, loss, \ + transformer_layer, identity_layer = parallel_transformer( + tensor_model_parallel_size, num_att_heads_per_partition, + hidden_size_per_att_head, batch_size, sequence_length) + + error = loss_1.sub(loss).abs().max() + torch.distributed.barrier() + print(' loss error on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-3, 'error: {}'.format(error) + + error = identity_layer_1.weight.grad.sub( + identity_layer.weight.grad).abs().max() + torch.distributed.barrier() + print(' input gradient error on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-3, 'error: {}'.format(error) + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print(' >> passed the test 
:-)')
+
+
+if __name__ == '__main__':
+
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+
+    initialize_distributed()
+    world_size = torch.distributed.get_world_size()
+    mpu.destroy_model_parallel()
+
+    print_separator('test initialize affine weight')
+    tensor_model_parallel_size = 1
+    while tensor_model_parallel_size <= world_size:
+        test_initialize_affine_weight(tensor_model_parallel_size)
+        tensor_model_parallel_size *= 2
+
+    tensor_model_parallel_size = 1
+    while tensor_model_parallel_size <= world_size:
+        print_separator('test parallel embedding')
+        test_parallel_embedding(tensor_model_parallel_size)
+        tensor_model_parallel_size *= 2
+
+    print_separator('test column-parallel linear')
+    tensor_model_parallel_size = 1
+    while tensor_model_parallel_size <= world_size:
+        test_column_parallel_linear(tensor_model_parallel_size)
+        tensor_model_parallel_size *= 2
+
+    print_separator('test row-parallel linear')
+    tensor_model_parallel_size = 1
+    while tensor_model_parallel_size <= world_size:
+        test_row_parallel_linear(tensor_model_parallel_size)
+        tensor_model_parallel_size *= 2
+
+    print_separator('test parallel self-attention')
+    tensor_model_parallel_size = 1
+    while tensor_model_parallel_size <= world_size:
+        test_parallel_self_attention(tensor_model_parallel_size)
+        tensor_model_parallel_size *= 2
+
+    print_separator('test parallel transformer')
+    tensor_model_parallel_size = 1
+    while tensor_model_parallel_size <= world_size:
+        test_parallel_transformer_layer(tensor_model_parallel_size)
+        tensor_model_parallel_size *= 2
diff --git a/megatron_npu/tests/ut/test_random.py b/megatron_npu/tests/ut/test_random.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ffcf1c25f73124a6b6f711e7a680d90275a77ff
--- /dev/null
+++ b/megatron_npu/tests/ut/test_random.py
@@ -0,0 +1,239 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+if torch.__version__>="1.8.0":
+    try:
+        import torch_npu
+        from torch_npu.contrib import transfer_to_npu
+    except:
+        print('WARNING! torch_npu is not imported. Please run without npu.')
+from commons import print_separator
+from commons import initialize_distributed
+from megatron import mpu
+import sys
+import os
+
+def _set_cuda_rng_state(new_state, device=-1):
+    """Sets the random number generator state of the current GPU.
+
+    Arguments:
+        new_state (torch.ByteTensor): The desired state
+    This function is adapted from PyTorch repo (torch.cuda.set_rng_state)
+    with a single change: the input state is not cloned. Cloning caused
+    major performance issues for +4 GPU cases.
+ """ + if hasattr(torch._C, '_npu_setRNGState') and callable(torch._C._npu_setRNGState): + # older PyTorch + def cb(): + with device_ctx_manager(device): + torch._C._npu_setRNGState(new_state) + else: + # newer PyTorch + if device == -1: + device = torch.device('npu') + elif isinstance(device, str): + device = torch.device(device) + elif isinstance(device, int): + device = torch.device('npu', device) + + def cb(): + idx = device.index + if idx is None: + idx = torch.npu.current_device() + default_generator = torch.npu.default_generators[idx] + default_generator.set_state(new_state) + + torch.npu._lazy_call(cb) + +for k in sys.modules: + if k.startswith('megatron'): + for target in ['_set_cuda_rng_state']: + if getattr(sys.modules[k], '_set_cuda_rng_state', None): + setattr(sys.modules[k], '_set_cuda_rng_state', _set_cuda_rng_state) + + +def test_set_cuda_rng_state(tensor_model_parallel_size): + + if torch.distributed.get_rank() == 0: + print('> testing set_rng_state with size {} ...'. + format(tensor_model_parallel_size)) + + mpu.initialize_model_parallel(tensor_model_parallel_size) + tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size() + + size = 123 + seed = 1234 + torch.npu.manual_seed(1234) + tensor = torch.npu.FloatTensor(size) + + # Get the state + rng_state = torch.npu.get_rng_state() + rng_state_copy = rng_state.clone() + + # Do some stuff. + for _ in range(5): + result_1 = torch.bernoulli(tensor, p=0.5) + + assert rng_state.sub(rng_state_copy).max() == 0 + assert torch.npu.get_rng_state().sub(rng_state_copy).max() > 0 + + # State should be different. + new_rng_state = torch.npu.get_rng_state() + max_diff = new_rng_state.sub(rng_state).max() + print(' max diff in rng state (should be non-zero) on global rank {}: {}'. + format(torch.distributed.get_rank(), max_diff)) + assert max_diff > 0 + + # Reset the rng state and do the same stuff. + _set_cuda_rng_state(rng_state) + for _ in range(5): + torch.bernoulli(tensor, p=0.5) + _set_cuda_rng_state(rng_state) + for _ in range(5): + result_2 = torch.bernoulli(tensor, p=0.5) + + # Results should be the same + error = result_2.sub(result_1).abs().max() + print(' max error in generated tensors (should be zero) on ' + 'global rank {}: {}'.format(torch.distributed.get_rank(), error)) + assert error < 1.0e-6 + + # Input state should have remained intact. + error = rng_state.sub(rng_state_copy).max() + print(' max error in rng state (should be zero) on global rank {}: {}'. + format(torch.distributed.get_rank(), error)) + assert error == 0 + + # Reset groups + mpu.destroy_model_parallel() + + #torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print('>> passed the test :-)') + + +def test_cuda_rng_tracker(tensor_model_parallel_size): + + if torch.distributed.get_rank() == 0: + print('> testing cuda rng tracker with size {} ...'. + format(tensor_model_parallel_size)) + + mpu.initialize_model_parallel(tensor_model_parallel_size) + tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size() + + seed_1 = 1234 + seed_2 = 4321 + size = [12, 21] + tensor = torch.npu.FloatTensor(*size) + + # Set to seed_1 and generate two tensors. + torch.npu.manual_seed(seed_1) + target_11 = torch.bernoulli(tensor, p=0.5) + target_12 = torch.bernoulli(tensor, p=0.5) + + # Set to seed_2 and generate two tensors. 
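+    # Together with target_11/target_12 above, these become the reference
+    # values: further below the default generator is re-seeded with seed_1
+    # while the tracker's 'test' state is created from seed_2, and the
+    # interleaved draws must reproduce target_1x and target_2x respectively,
+    # i.e. the tracker has to keep the two RNG streams fully isolated.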
+ torch.npu.manual_seed(seed_2) + target_21 = torch.bernoulli(tensor, p=0.5) + target_22 = torch.bernoulli(tensor, p=0.5) + + # Now if we interleave seed_1 and seed_2, + # we should still get the same tensors + torch.npu.manual_seed(seed_1) + mpu.get_cuda_rng_tracker().add('test', seed_2) + + # torch.randn(size, out=tensor) + result_11 = torch.bernoulli(tensor, p=0.5) + + with mpu.get_cuda_rng_tracker().fork('test'): + result_21 = torch.bernoulli(tensor, p=0.5) + + result_12 = torch.bernoulli(tensor, p=0.5) + + with mpu.get_cuda_rng_tracker().fork('test'): + result_22 = torch.bernoulli(tensor, p=0.5) + + diff = result_11.sub(result_21).abs().max() + diff = min(diff, result_12.sub(result_22).abs().max()) + print(' max diff in generated tensors (should be non-zero) on ' + 'global rank {}: {}'.format(torch.distributed.get_rank(), diff)) + assert diff > 1.0e-6 + error = max(result_11.sub(target_11).abs().max(), + result_12.sub(target_12).abs().max()) + error = max(error, result_21.sub(target_21).abs().max()) + error = max(error, result_22.sub(target_22).abs().max()) + print(' max error in generated tensors (should be zero) on ' + 'global rank {}: {}'.format(torch.distributed.get_rank(), error)) + assert error < 1.0e-6 + + # Reset the tracker + mpu.get_cuda_rng_tracker().reset() + + # Reset groups + mpu.destroy_model_parallel() + + #torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print('>> passed the test :-)') + + +def test_model_parallel_cuda_manual_seed(tensor_model_parallel_size): + + if torch.distributed.get_rank() == 0: + print('> testing model parallel cuda manual seed with size {} ...'. + format(tensor_model_parallel_size)) + + mpu.initialize_model_parallel(tensor_model_parallel_size) + tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size() + + mpu.model_parallel_cuda_manual_seed(12345) + assert torch.npu.initial_seed() == 12345 + with mpu.get_cuda_rng_tracker().fork(): + assert torch.npu.initial_seed() == (12345 + 2718 + + mpu.get_tensor_model_parallel_rank()) + + # Reset the tracker + mpu.get_cuda_rng_tracker().reset() + + # Reset groups + mpu.destroy_model_parallel() + + #torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print('>> passed the test :-)') + + +if __name__ == '__main__': + + initialize_distributed() + world_size = torch.distributed.get_world_size() + + tensor_model_parallel_size = 1 + while tensor_model_parallel_size <= world_size: + print_separator('test set rng state') + test_set_cuda_rng_state(tensor_model_parallel_size) + tensor_model_parallel_size *= 2 + + tensor_model_parallel_size = 1 + while tensor_model_parallel_size <= world_size: + print_separator('test cuda rng tracker') + test_cuda_rng_tracker(tensor_model_parallel_size) + tensor_model_parallel_size *= 2 + + tensor_model_parallel_size = 1 + while tensor_model_parallel_size <= world_size: + print_separator('test model parallel cuda manual seed') + test_model_parallel_cuda_manual_seed(tensor_model_parallel_size) + tensor_model_parallel_size *= 2
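+
+# Note: as with the other mpu unit tests, this script expects one process per
+# device; initialize_distributed() from commons presumably builds the process
+# group from the launcher-provided environment (RANK, WORLD_SIZE, MASTER_ADDR).
+# A typical invocation (assumed, adjust nproc_per_node to the NPU count) is:
+#   python -m torch.distributed.launch --nproc_per_node=8 test_random.py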