diff --git a/ascendspeed/arguments.py b/ascendspeed/arguments.py index bf58e9832334d35555d8d77e61f863106cd5ef24..bb0b4559b9ade8d877fc340de425914c2b23489a 100644 --- a/ascendspeed/arguments.py +++ b/ascendspeed/arguments.py @@ -235,6 +235,9 @@ def parse_args(extra_args_provider=None, defaults={}, assert args.hidden_size % args.num_attention_heads == 0 args.kv_channels = args.hidden_size // args.num_attention_heads + if args.tensor_model_parallel_size == 1: + args.sequence_parallel = False + if args.seq_length is not None: assert args.encoder_seq_length is None args.encoder_seq_length = args.seq_length @@ -285,7 +288,6 @@ def parse_args(extra_args_provider=None, defaults={}, _print_args(args) return args - def _print_args(args): """Print arguments.""" if args.rank == 0: @@ -695,6 +697,9 @@ def _add_distributed_args(parser): group.add_argument('--enable-expert-tensor-parallelism', action='store_true', default=False, help="use tensor parallelism for expert layers in MoE") + group.add_argument('--sequence-parallel', action='store_true', + default=False, + help="use sequence parallelism") group.add_argument('--pipeline-model-parallel-size', type=int, default=1, help='Degree of pipeline model parallelism.') group.add_argument('--moe-expert-parallel-size', type=int, default=1, diff --git a/ascendspeed/model/language_model.py b/ascendspeed/model/language_model.py index cecc1d10f1ff8aa7499401255b70a18767296c4b..2af0c7b78c5fa04c14e60f688f965e5ca5022ed9 100644 --- a/ascendspeed/model/language_model.py +++ b/ascendspeed/model/language_model.py @@ -26,13 +26,27 @@ from ascendspeed.model.module import MegatronModule from ascendspeed.model.transformer import ParallelTransformer from ascendspeed.model.utils import get_linear_layer from ascendspeed.model.utils import init_method_normal, scaled_init_method_normal - +from ascendspeed.mpu.mappings import gather_from_sequence_parallel_region +from ascendspeed.mpu.initialize import get_global_memory_buffer def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, bias=None): """LM logits using word embedding weights.""" + args = get_args() # Parallel logits. - input_parallel = mpu.copy_to_tensor_model_parallel_region(input_) + if args.sequence_parallel: + world_size = get_tensor_model_parallel_world_size() + dim_size = list(input_.size()) + dim_size[0] = dim_size[0] * world_size + + input_parallel = \ + get_global_memory_buffer().get_tensor(dim_size, input_.dtype, "mpu") + torch.distributed._all_gather_base( + input_parallel, + input_, + group=get_tensor_model_parallel_group()) + else: + input_parallel = mpu.copy_to_tensor_model_parallel_region(input_) # Matrix multiply. if bias is None: logits_parallel = F.linear(input_parallel, word_embeddings_weight) @@ -92,11 +106,17 @@ class Pooler(MegatronModule): def __init__(self, hidden_size, init_method): super(Pooler, self).__init__() + args = get_args() self.dense = get_linear_layer(hidden_size, hidden_size, init_method) + self.sequence_parallel = args.sequence_parallel def forward(self, hidden_states, sequence_index=0): # hidden_states: [b, s, h] # sequence_index: index of the token to pool. + if self.sequence_parallel: + hidden_states = gather_from_sequence_parallel_region( + hidden_states, + tensor_parallel_output_grad=False) pooled = hidden_states[:, sequence_index, :] pooled = self.dense(pooled) pooled = torch.tanh(pooled) diff --git a/ascendspeed/model/llama_model.py b/ascendspeed/model/llama_model.py index fd02fef6e1769694ddb9354d1f3d2a441062d93e..3a571f341d9a732797e909b2c49c66d381565443 100644 --- a/ascendspeed/model/llama_model.py +++ b/ascendspeed/model/llama_model.py @@ -31,6 +31,7 @@ from ascendspeed.model.enums import AttnMaskType, LayerType, AttnType from ascendspeed.model.utils import get_linear_layer, init_method_normal, scaled_init_method_normal, attention_mask_func, \ openai_gelu, erf_gelu +from ascendspeed.mpu.mappings import scatter_to_sequence_parallel_region from ascendspeed.model.fused_softmax import NPUFusedScaleMaskSoftmax from ascendspeed.model.language_model import Pooler @@ -133,10 +134,13 @@ class LlamaLMHead(MegatronModule): bias=False, gather_output=not self.parallel_output, skip_bias_add=True, - init_method=self.init_method, ) + init_method=self.init_method, + sequence_parallel_enabled=args.sequence_parallel) def forward(self, inputs): + inputs = inputs.transpose(0, 1).contiguous() logits, _ = self.lm_head(inputs) + logits = logits.transpose(0, 1).contiguous() # SBH-->BSH return logits @@ -188,10 +192,16 @@ class LlamaEmbedding(MegatronModule): # Word embeddings (parallel). self.word_embeddings = mpu.VocabParallelEmbedding(vocab_size, self.hidden_size, init_method=self.init_method) + self.sequence_parallel = args.sequence_parallel def forward(self, input_ids): # Embeddings. embeddings = self.word_embeddings(input_ids) + if self.sequence_parallel: + embeddings = embeddings.transpose(0, 1).contiguous() + embeddings = scatter_to_sequence_parallel_region(embeddings) + embeddings = embeddings.transpose(0, 1).contiguous() + return embeddings @@ -243,7 +253,8 @@ class LlamaParallelMLP(MegatronModule): init_method=self.init_method, skip_bias_add=True, moe=moe, - enable_expert_tensor_parallelism=enable_expert_tensor_parallelism + enable_expert_tensor_parallelism=enable_expert_tensor_parallelism, + sequence_parallel_enabled=args.sequence_parallel ) self.up_proj = mpu.ColumnParallelLinear( @@ -254,7 +265,8 @@ class LlamaParallelMLP(MegatronModule): init_method=self.init_method, skip_bias_add=True, moe=moe, - enable_expert_tensor_parallelism=enable_expert_tensor_parallelism + enable_expert_tensor_parallelism=enable_expert_tensor_parallelism, + sequence_parallel_enabled=args.sequence_parallel ) self.activation_func = F.silu @@ -268,7 +280,8 @@ class LlamaParallelMLP(MegatronModule): init_method=self.output_layer_init_method, skip_bias_add=True, moe=moe, - enable_expert_tensor_parallelism=enable_expert_tensor_parallelism) + enable_expert_tensor_parallelism=enable_expert_tensor_parallelism, + sequence_parallel_enabled=args.sequence_parallel) def forward(self, hidden_states): intermediate_parallel = self.activation_func(self.gate_proj(hidden_states)[0]) * self.up_proj(hidden_states)[0] @@ -296,7 +309,7 @@ class LlamaParallelAttention(MegatronModule): args = get_args() self.fp16 = args.fp16 self.bf16 = args.bf16 - + self.sequence_parallel = args.sequence_parallel self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32 if self.apply_query_key_layer_scaling: @@ -326,7 +339,8 @@ class LlamaParallelAttention(MegatronModule): 3 * projection_size, bias=False, gather_output=False, - init_method=self.init_method) + init_method=self.init_method, + sequence_parallel_enabled=self.sequence_parallel) coeff = None self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) @@ -352,7 +366,8 @@ class LlamaParallelAttention(MegatronModule): bias=False, input_is_parallel=True, init_method=self.output_layer_init_method, - skip_bias_add=True) + skip_bias_add=True, + sequence_parallel_enabled=self.sequence_parallel) if deepspeed.checkpointing.is_configured(): global get_cuda_rng_tracker, checkpoint @@ -559,6 +574,7 @@ class LlamaParallelTransformerLayer(MegatronModule): eps=args.layernorm_epsilon) # MLP + self.rank = args.rank self.mlp = LlamaParallelMLP(self.init_method, self.output_layer_init_method) def forward(self, hidden_states, attention_mask=None, diff --git a/ascendspeed/model/transformer.py b/ascendspeed/model/transformer.py index f3e4058c32897df8297175e34eaf976df8e752fc..f5571c83e54d9436ca57463f56ca753da9d00b54 100644 --- a/ascendspeed/model/transformer.py +++ b/ascendspeed/model/transformer.py @@ -160,20 +160,23 @@ class ParallelAttention(MegatronModule): args.hidden_size, 3 * projection_size, gather_output=False, - init_method=init_method) + init_method=init_method, + sequence_parallel_enabled=args.sequence_parallel) else: assert attention_type == AttnType.cross_attn self.query = mpu.ColumnParallelLinear( args.hidden_size, projection_size, gather_output=False, - init_method=init_method) + init_method=init_method, + sequence_parallel_enabled=args.sequence_parallel) self.key_value = mpu.ColumnParallelLinear( args.hidden_size, 2 * projection_size, gather_output=False, - init_method=init_method) + init_method=init_method, + sequence_parallel_enabled=args.sequence_parallel) coeff = None self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) @@ -200,7 +203,8 @@ class ParallelAttention(MegatronModule): args.hidden_size, input_is_parallel=True, init_method=output_layer_init_method, - skip_bias_add=True) + skip_bias_add=True, + sequence_parallel_enabled=args.sequence_parallel) if deepspeed.checkpointing.is_configured(): global get_cuda_rng_tracker, checkpoint diff --git a/ascendspeed/mpu/initialize.py b/ascendspeed/mpu/initialize.py index 301a98d1b71ce7c41cd0744d99d319d6829bf866..98a544cc98bdb13d208577189c0620e76451fa4b 100644 --- a/ascendspeed/mpu/initialize.py +++ b/ascendspeed/mpu/initialize.py @@ -18,7 +18,7 @@ import torch -from .utils import ensure_divisibility +from .utils import ensure_divisibility, GlobalMemoryBuffer # Intra-layer model parallel group that the current rank belongs to. @@ -45,6 +45,9 @@ _MPU_PIPELINE_MODEL_PARALLEL_RANK = None # rank when broadcasting from the first or last pipeline stage _PIPELINE_GLOBAL_RANKS = None +# Memory buffers to avoid dynamic memory allocation +_GLOBAL_MEMORY_BUFFER = None + def is_unitialized(): """Useful for code segments that may be accessed with or without mpu initialization""" return _DATA_PARALLEL_GROUP is None @@ -166,6 +169,12 @@ def initialize_model_parallel(tensor_model_parallel_size_=1, group = torch.distributed.new_group(embedding_ranks) if rank in embedding_ranks: _EMBEDDING_GROUP = group + + # Initialize global memory buffer + # This isn't really "parallel state" but there isn't another good place to + # put this. If we end up with a more generic initialization of megatron-core + # we could stick it there + _set_global_memory_buffer() def model_parallel_is_initialized(): @@ -364,6 +373,19 @@ def get_data_parallel_rank(): return torch.distributed.get_rank(group=get_data_parallel_group()) +def _set_global_memory_buffer(): + """Initialize global buffer""" + global _GLOBAL_MEMORY_BUFFER + assert _GLOBAL_MEMORY_BUFFER is None, 'global memory buffer is already initialized' + _GLOBAL_MEMORY_BUFFER = GlobalMemoryBuffer() + + +def get_global_memory_buffer(): + """Return the global GlobalMemoryBuffer object""" + assert _GLOBAL_MEMORY_BUFFER is not None, 'global memory buffer is not initialized' + return _GLOBAL_MEMORY_BUFFER + + def destroy_model_parallel(): """Set the groups to none.""" global _TENSOR_MODEL_PARALLEL_GROUP diff --git a/ascendspeed/mpu/layers.py b/ascendspeed/mpu/layers.py index 7eed59567adf6e5e80aae7691e8cd342ffd09659..15f03a5d49e23659298fb71f3ab1fb9b5c981aa1 100644 --- a/ascendspeed/mpu/layers.py +++ b/ascendspeed/mpu/layers.py @@ -19,7 +19,7 @@ import math - +from typing import List, Optional import deepspeed.runtime.activation_checkpointing.checkpointing as ds_checkpointing import torch import torch.nn.functional as F @@ -29,11 +29,14 @@ from deepspeed.accelerator import get_accelerator from ascendspeed import get_args, mpu from ascendspeed.mpu.initialize import get_tensor_model_parallel_rank +from ascendspeed.mpu.initialize import get_tensor_model_parallel_group from ascendspeed.mpu.initialize import get_tensor_model_parallel_world_size +from ascendspeed.mpu.initialize import get_global_memory_buffer from ascendspeed.mpu.mappings import copy_to_tensor_model_parallel_region from ascendspeed.mpu.mappings import gather_from_tensor_model_parallel_region from ascendspeed.mpu.mappings import reduce_from_tensor_model_parallel_region from ascendspeed.mpu.mappings import scatter_to_tensor_model_parallel_region +from ascendspeed.mpu.mappings import reduce_scatter_to_sequence_parallel_region from ascendspeed.mpu.random import get_cuda_rng_tracker from ascendspeed.mpu.utils import divide, split_tensor_along_last_dim, VocabUtility from ascendspeed.model.fused_layer_norm import MixedFusedLayerNorm as LayerNorm @@ -210,6 +213,157 @@ class VocabParallelEmbedding(torch.nn.Module): return output +class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function): + """See linear_with_grad_accumulation_and_async_allreduce""" + + @staticmethod + def forward(ctx, input_, weight, bias, sequence_parallel): + ctx.save_for_backward(input_, weight) + ctx.use_bias = bias is not None + ctx.sequence_parallel = sequence_parallel + + + if sequence_parallel: + world_size = get_tensor_model_parallel_world_size() + dim_size = list(input_.size()) + dim_size[0] = dim_size[0] * world_size + + all_gather_buffer = \ + get_global_memory_buffer().get_tensor(dim_size, input_.dtype, "mpu") + torch.distributed._all_gather_base( + all_gather_buffer, + input_, + group=get_tensor_model_parallel_group()) + + total_input = all_gather_buffer + else: + total_input = input_ + + output = torch.matmul(total_input, weight.t()) + if bias is not None: + output = output + bias + return output + + @staticmethod + def backward(ctx, grad_output): + input_, weight = ctx.saved_tensors + use_bias = ctx.use_bias + + if ctx.sequence_parallel: + world_size = get_tensor_model_parallel_world_size() + dim_size = list(input_.size()) + dim_size[0] = dim_size[0] * world_size + + all_gather_buffer = \ + get_global_memory_buffer().get_tensor(dim_size, input_.dtype, "mpu") + handle = torch.distributed._all_gather_base( + all_gather_buffer, + input_, + group=get_tensor_model_parallel_group(), async_op=True) + + # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the + # gather is scheduled before the input gradient computation + total_input = all_gather_buffer + else: + total_input = input_ + grad_input = grad_output.matmul(weight) + + if ctx.sequence_parallel: + handle.wait() + + # Convert the tensor shapes to 2D for execution compatibility + grad_output = grad_output.reshape(grad_output.shape[0] * grad_output.shape[1], + grad_output.shape[2]) + total_input = total_input.reshape(total_input.shape[0] * total_input.shape[1], + total_input.shape[2]) + + if ctx.sequence_parallel: + dim_size = list(input_.size()) + sub_grad_input = torch.empty(dim_size, dtype=input_.dtype, + device=torch.cuda.current_device(), + requires_grad=False) + handle = torch.distributed._reduce_scatter_base(sub_grad_input, grad_input, + group=get_tensor_model_parallel_group(), + async_op=True) + # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the + # reduce scatter is scheduled before the weight gradient computation + + + + grad_weight = grad_output.t().matmul(total_input) + grad_bias = grad_output.sum(dim=0) if use_bias else None + + if ctx.sequence_parallel: + handle.wait() + return sub_grad_input, grad_weight, grad_bias, None, None, None + + return grad_input, grad_weight, grad_bias, None, None, None + + +def linear_with_grad_accumulation_and_async_allreduce( + input_: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + sequence_parallel_enabled: bool, +) -> torch.Tensor: + """Linear layer execution with asynchronous communication and + gradient accumulation fusion in backprop. + + This has the option to accumulate the result of backprop + calculation into an existing gradient buffer, preventing the need + to do an additional addition kernel after the gradient + calculation. + + Additionally, the tensor parallel all reduce of the input + gradients can be done asynchronously with the calculation of + the weight gradients. + + In the case of sequence parallelism, the reduce scatter of the + input gradients is done asynchronously with the calcluation of the + weight gradients. + + Use of this module requires that the environment variable + CUDA_DEVICE_MAX_CONNECTIONS=1. There are a few collective + operations, noted in the code, that should be scheduled before + compute kernels to overlap the communication with the computation, + which is necessary for a speedup but not for correctness so that + ordering isn't imposed by the scheduler. Setting + CUDA_DEVICE_MAX_CONNECTIONS=1 forces the kernels to be scheduled + in the order they are called. + + Arguments: + + input (torch.Tensor required): input like torch.nn.functional.linear + + weight (torch.Tensor required): weight like torch.nn.functional.linear + + bias (torch.Tensor optional): bias like torch.nn.functional.linear + + gradient_accumulation_fusion (bool required): Perform the gradient + accumulation fusion, requires the custom CUDA extension + fused_weight_gradient_mlp_cuda module. To use + gradient_accumulation_fusion you must install APEX with + --cpp_ext and --cuda_ext. For example: "pip install + --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext .\" + " Note that the extension requires CUDA>=11. Otherwise, you + must turn off gradient accumulation fusion." + + sequence_parallel_enabled (bool required): Indicates that sequence + parallelism is used and thus in the forward pass the input is + all gathered, and the backward pass the input gradients are + reduce scattered. + """ + args = [ + input_, + weight, + bias, + sequence_parallel_enabled, + ] + + with torch.cuda.amp.autocast(enabled=False): + return LinearWithGradAccumulationAndAsyncCommunication.apply(*args) +linear_with_grad_accumulation_and_async_allreduce.warned = False + class ColumnParallelLinear(torch.nn.Module): """Linear layer with column parallelism. @@ -237,7 +391,10 @@ class ColumnParallelLinear(torch.nn.Module): def __init__(self, input_size, output_size, bias=True, gather_output=True, init_method=init.xavier_normal_, stride=1, keep_master_weight_for_test=False, - skip_bias_add=False, moe=False, enable_expert_tensor_parallelism=False): + skip_bias_add=False, moe=False, + enable_expert_tensor_parallelism=False, + sequence_parallel_enabled: bool = False + ): super(ColumnParallelLinear, self).__init__() # Keep input parameters @@ -290,12 +447,20 @@ class ColumnParallelLinear(torch.nn.Module): self.bias.zero_() else: self.register_parameter('bias', None) + if sequence_parallel_enabled: + if world_size <= 1: + warnings.warn( + f"`sequence_parallel_enabled` is set to `True`, but tensor model parallel size is {world_size}. " + f"Disabling sequence parallel." + ) + sequence_parallel_enabled = False + self.sequence_parallel_enabled = sequence_parallel_enabled def forward(self, input_): # Set up backprop all-reduce. - if self.is_expert_without_slicing: # non-expert only tensor parallelism + if self.is_expert_without_slicing or self.sequence_parallel_enabled: # non-expert only tensor parallelism input_parallel = input_ else: input_parallel = copy_to_tensor_model_parallel_region(input_) @@ -303,9 +468,17 @@ class ColumnParallelLinear(torch.nn.Module): # Matrix multiply. bias = self.bias if not self.skip_bias_add else None - output_parallel = F.linear(input_parallel, self.weight, bias) + + output_parallel = linear_with_grad_accumulation_and_async_allreduce( + input_=input_parallel, + weight=self.weight, + bias=bias, + sequence_parallel_enabled=self.sequence_parallel_enabled, + ) + if self.gather_output and not self.is_expert_without_slicing: # All-gather across the partitions. + assert not self.sequence_parallel_enabled output = gather_from_tensor_model_parallel_region(output_parallel) else: output = output_parallel @@ -347,7 +520,9 @@ class RowParallelLinear(torch.nn.Module): input_is_parallel=False, init_method=init.xavier_normal_, stride=1, keep_master_weight_for_test=False, - skip_bias_add=False, moe=False, enable_expert_tensor_parallelism=False): + skip_bias_add=False, moe=False, + enable_expert_tensor_parallelism=False, + sequence_parallel_enabled: bool = False): super(RowParallelLinear, self).__init__() # Keep input parameters @@ -365,7 +540,9 @@ class RowParallelLinear(torch.nn.Module): self.input_size_per_partition = divide(input_size, world_size) self.skip_bias_add = skip_bias_add - + self.sequence_parallel_enabled = sequence_parallel_enabled + if self.sequence_parallel_enabled and not self.input_is_parallel: + raise RuntimeError("To enable `sequence_parallel_enabled`, `input_is_parallel` must be `True`") # Parameters. # Note: torch.nn.functional.linear performs XA^T + b and as a result # we allocate the transpose. @@ -406,14 +583,24 @@ class RowParallelLinear(torch.nn.Module): if self.input_is_parallel or self.is_expert_without_slicing: input_parallel = input_ else: + assert not self.sequence_parallel_enabled input_parallel = scatter_to_tensor_model_parallel_region(input_) # Matrix multiply. - output_parallel = F.linear(input_parallel, self.weight) + output_parallel = linear_with_grad_accumulation_and_async_allreduce( + input_=input_parallel, + weight=self.weight, + bias=None, + sequence_parallel_enabled=False, + ) + # All-reduce across all the partitions. if self.is_expert_without_slicing: # non-expert only tensor-parallelism output_ = output_parallel else: - output_ = reduce_from_tensor_model_parallel_region(output_parallel) + if self.sequence_parallel_enabled: + output_ = reduce_scatter_to_sequence_parallel_region(output_parallel) + else: + output_ = reduce_from_tensor_model_parallel_region(output_parallel) if not self.skip_bias_add: output = output_ + self.bias if self.bias is not None else output_ diff --git a/ascendspeed/mpu/mappings.py b/ascendspeed/mpu/mappings.py index 821d9acfecd6c0349ca0e9e6c1a4842aef5ace8b..e70fb91b0960d48a61ab6265898b949935bdfffd 100644 --- a/ascendspeed/mpu/mappings.py +++ b/ascendspeed/mpu/mappings.py @@ -72,6 +72,105 @@ def _gather(input_): return output +def _gather_along_first_dim(input_): + """Gather tensors and concatinate along the first dimension.""" + + world_size = get_tensor_model_parallel_world_size() + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + + dim_size = list(input_.size()) + dim_size[0] = dim_size[0] * world_size + + output = torch.empty(dim_size, dtype=input_.dtype, + device=torch.cuda.current_device()) + torch.distributed._all_gather_base(output, input_.contiguous(), + group=get_tensor_model_parallel_group()) + + return output + +def _reduce_scatter_along_first_dim(input_): + """Reduce-scatter the input tensor across model parallel group.""" + world_size = get_tensor_model_parallel_world_size() + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + + dim_size = list(input_.size()) + assert dim_size[0] % world_size == 0, \ + "First dimension of the tensor should be divisible by tensor parallel size" + + dim_size[0] = dim_size[0] // world_size + + output = torch.empty(dim_size, dtype=input_.dtype, + device=torch.cuda.current_device()) + torch.distributed._reduce_scatter_base(output, input_.contiguous(), + group=get_tensor_model_parallel_group()) + return output + +def _split_along_first_dim(input_): + """Split the tensor along its first dimension and keep the + corresponding slice.""" + + world_size = get_tensor_model_parallel_world_size() + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + + # Split along first dimension. + dim_size = input_.size()[0] + assert dim_size % world_size == 0, \ + "First dimension of the tensor should be divisible by tensor parallel size" + local_dim_size = dim_size // world_size + rank = get_tensor_model_parallel_rank() + dim_offset = rank * local_dim_size + + output = input_[dim_offset:dim_offset+local_dim_size].contiguous() + + return output + +class _ScatterToSequenceParallelRegion(torch.autograd.Function): + """Split the input and keep only the corresponding chuck to the rank.""" + + @staticmethod + def symbolic(graph, input_): + return _split_along_first_dim(input_) + + @staticmethod + def forward(ctx, input_): + return _split_along_first_dim(input_) + + @staticmethod + def backward(ctx, grad_output): + return _gather_along_first_dim(grad_output) + + +class _GatherFromSequenceParallelRegion(torch.autograd.Function): + """Gather the input from sequence parallel region and concatinate.""" + + @staticmethod + def symbolic(graph, input_, tensor_parallel_output_grad=True): + return _gather_along_first_dim(input_) + + @staticmethod + def forward(ctx, input_, tensor_parallel_output_grad=True): + ctx.tensor_parallel_output_grad = tensor_parallel_output_grad + return _gather_along_first_dim(input_) + + @staticmethod + def backward(ctx, grad_output): + tensor_parallel_output_grad = ctx.tensor_parallel_output_grad + + # If the computation graph after the gather operation is + # in the tensor parallel mode, output gradients need to reduce + # scattered and whereas if the computation is duplicated, + # output gradients need to be scattered. + if tensor_parallel_output_grad: + return _reduce_scatter_along_first_dim(grad_output), None + else: + return _split_along_first_dim(grad_output), None + class _CopyToModelParallelRegion(torch.autograd.Function): """Pass the input to the model parallel region.""" @@ -121,6 +220,23 @@ class _ScatterToModelParallelRegion(torch.autograd.Function): return _gather(grad_output) +class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function): + """Reduce scatter the input from the model parallel region.""" + + @staticmethod + def symbolic(graph, input_): + return _reduce_scatter_along_first_dim(input_) + + @staticmethod + def forward(ctx, input_): + return _reduce_scatter_along_first_dim(input_) + + @staticmethod + def backward(ctx, grad_output): + return _gather_along_first_dim(grad_output) + + + class _GatherFromModelParallelRegion(torch.autograd.Function): """Gather the input from model parallel region and concatinate.""" @@ -153,5 +269,14 @@ def scatter_to_tensor_model_parallel_region(input_): return _ScatterToModelParallelRegion.apply(input_) +def gather_from_sequence_parallel_region(input_, tensor_parallel_output_grad=True): + return _GatherFromSequenceParallelRegion.apply(input_, tensor_parallel_output_grad) + def gather_from_tensor_model_parallel_region(input_): return _GatherFromModelParallelRegion.apply(input_) + +def reduce_scatter_to_sequence_parallel_region(input_): + return _ReduceScatterToSequenceParallelRegion.apply(input_) + +def scatter_to_sequence_parallel_region(input_): + return _ScatterToSequenceParallelRegion.apply(input_) \ No newline at end of file diff --git a/ascendspeed/mpu/utils.py b/ascendspeed/mpu/utils.py index 2431a1c39c334c5a40b1a8c46cfb6428ac40f552..229a98f068916f88fe6a6d1c4001097c76fb1d5e 100644 --- a/ascendspeed/mpu/utils.py +++ b/ascendspeed/mpu/utils.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import reduce +import operator import torch @@ -29,6 +31,27 @@ def divide(numerator, denominator): ensure_divisibility(numerator, denominator) return numerator // denominator +class GlobalMemoryBuffer: + """Global buffer to avoid dynamic memory allocations. + Caller should ensure that buffers of the same name + are not used concurrently. + """ + + def __init__(self): + self.buffer = {} + + def get_tensor(self, tensor_shape, dtype, name): + required_len = reduce(operator.mul, tensor_shape, 1) + if self.buffer.get((name, dtype), None) is None or \ + self.buffer[(name, dtype)].numel() < required_len: + self.buffer[(name, dtype)] = \ + torch.empty(required_len, + dtype=dtype, + device=torch.cuda.current_device(), + requires_grad=False) + + return self.buffer[(name, dtype)][0:required_len].view(*tensor_shape) + def _kernel_make_viewless_tensor(inp, requires_grad): '''Make a viewless tensor. diff --git a/ascendspeed/p2p_communication.py b/ascendspeed/p2p_communication.py index 4bd06a3fbbf6a8c54389947c7de244f7f74ff810..056159a6b95e61d39abc01b79fbdebdff51b422f 100644 --- a/ascendspeed/p2p_communication.py +++ b/ascendspeed/p2p_communication.py @@ -57,9 +57,13 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next, tensor_shape = recv_tensor_shape if args.optimized_pipeline and (recv_prev or recv_next) \ else (args.seq_length, args.micro_batch_size, args.hidden_size) - if args.scatter_gather_tensors_in_pipeline: + if args.sequence_parallel: + seq_length = args.seq_length // mpu.get_tensor_model_parallel_world_size() + tensor_shape = (seq_length, args.micro_batch_size, args.hidden_size) + + if args.scatter_gather_tensors_in_pipeline and not args.sequence_parallel: tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1) // \ - mpu.get_tensor_model_parallel_world_size() + mpu.get_tensor_model_parallel_world_size() else: tensor_chunk_shape = tensor_shape dtype = args.params_dtype @@ -77,7 +81,7 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next, dtype=dtype) # Split tensor into smaller chunks if using scatter-gather optimization. - if args.scatter_gather_tensors_in_pipeline: + if args.scatter_gather_tensors_in_pipeline and not args.sequence_parallel: if tensor_send_next is not None: tensor_send_next = mpu.split_tensor_into_1d_equal_chunks(tensor_send_next) @@ -141,7 +145,7 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next, get_accelerator().synchronize() # If using scatter-gather optimization, gather smaller chunks. - if args.scatter_gather_tensors_in_pipeline: + if args.scatter_gather_tensors_in_pipeline and not args.sequence_parallel: if recv_prev: tensor_recv_prev = mpu.gather_split_1d_tensor( tensor_recv_prev).view(tensor_shape).requires_grad_() diff --git a/ascendspeed/training.py b/ascendspeed/training.py index d9fcb8b4d41a09d401141e15ae7eff73aea4137d..658b77a3892c30b92bcc2cdf1ce5ce13332b77ae 100644 --- a/ascendspeed/training.py +++ b/ascendspeed/training.py @@ -59,6 +59,7 @@ from deepspeed.accelerator import get_accelerator from deepspeed.compression.compress import init_compression, redundancy_clean from deepspeed.runtime.data_pipeline.data_routing.helper import convert_to_random_ltd + # The earliest we can measure the start time. _TRAIN_START_TIME = time.time() diff --git a/examples/llama_task/pretrain_llama_td_8p_sp.sh b/examples/llama_task/pretrain_llama_td_8p_sp.sh new file mode 100644 index 0000000000000000000000000000000000000000..e12c78e6272738218bf6f4be67e5693f9023aea2 --- /dev/null +++ b/examples/llama_task/pretrain_llama_td_8p_sp.sh @@ -0,0 +1,53 @@ +# This is an example: train llama using TD, +# the number of parameters is not aligned +export LD_LIBRARY_PATH=/usr/local/lib:/root/miniconda3/lib:$LD_LIBRARY_PATH +export HCCL_CONNECT_TIMEOUT=1200 +source /usr/local/Ascend/ascend-toolkit/set_env.sh +GPUS_PER_NODE=8 +# Change for multinode config +MASTER_ADDR=localhost +MASTER_PORT=6000 +NNODES=1 +NODE_RANK=0 +WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) + +DATA_PATH=./dataset/llama_text_document +CHECKPOINT_PATH=./ckpt + +DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT" + +python -m torch.distributed.launch $DISTRIBUTED_ARGS \ + pretrain_llama.py \ + --DDP-impl local \ + --use-contiguous-buffers-in-ddp \ + --tensor-model-parallel-size 4 \ + --num-layers 8 \ + --hidden-size 4096 \ + --num-attention-heads 32 \ + --micro-batch-size 2 \ + --global-batch-size 16 \ + --seq-length 2048 \ + --max-position-embeddings 2048 \ + --train-iters 500000 \ + --lr-decay-iters 320000 \ + --save $CHECKPOINT_PATH \ + --load $CHECKPOINT_PATH \ + --data-path $DATA_PATH \ + --tokenizer-name-or-path ./dataset/llama/ \ + --tokenizer-not-use-fast \ + --data-impl mmap \ + --split 949,50,1 \ + --distributed-backend nccl \ + --lr 0.00015 \ + --lr-decay-style cosine \ + --min-lr 1.0e-5 \ + --weight-decay 1e-2 \ + --clip-grad 1.0 \ + --lr-warmup-fraction .01 \ + --checkpoint-activations \ + --log-interval 10 \ + --save-interval 10000 \ + --eval-interval 1000 \ + --eval-iters 10 \ + --sequence-parallel \ + --fp16 | tee logs/train.log diff --git a/tests/st/test_llama/test_llama_ptd.sh b/tests/st/test_llama/test_llama_ptd.sh index d17fa0c5c4f8312ab17edfc193be507a859a3cc3..fe635229972c05f6ee832603385ab5a4164f47f6 100644 --- a/tests/st/test_llama/test_llama_ptd.sh +++ b/tests/st/test_llama/test_llama_ptd.sh @@ -51,4 +51,5 @@ python3 -m torch.distributed.launch $DISTRIBUTED_ARGS \ --save-interval 10000 \ --eval-interval 1000 \ --eval-iters 10 \ + --sequence-parallel \ --fp16