From 064f6350d05a8f03eb6c2ac955c560693c84c6f6 Mon Sep 17 00:00:00 2001 From: kingsleyandher <13825737389@163.com> Date: Mon, 10 Jul 2023 14:44:42 +0800 Subject: [PATCH 1/2] =?UTF-8?q?=E6=8F=90=E4=BA=A4SP=E7=AE=97=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ascendspeed/arguments.py | 7 +- ascendspeed/model/language_model.py | 24 ++- ascendspeed/model/llama_model.py | 30 ++- ascendspeed/model/transformer.py | 12 +- ascendspeed/mpu/initialize.py | 24 ++- ascendspeed/mpu/layers.py | 203 +++++++++++++++++- ascendspeed/mpu/mappings.py | 125 +++++++++++ ascendspeed/mpu/utils.py | 23 ++ ascendspeed/p2p_communication.py | 12 +- ascendspeed/training.py | 16 +- .../llama_task/pretrain_llama_td_8p_sp.sh | 53 +++++ tests/st/test_llama/test_llama_ptd.sh | 1 + 12 files changed, 499 insertions(+), 31 deletions(-) create mode 100644 examples/llama_task/pretrain_llama_td_8p_sp.sh diff --git a/ascendspeed/arguments.py b/ascendspeed/arguments.py index bf58e9832..bb0b4559b 100644 --- a/ascendspeed/arguments.py +++ b/ascendspeed/arguments.py @@ -235,6 +235,9 @@ def parse_args(extra_args_provider=None, defaults={}, assert args.hidden_size % args.num_attention_heads == 0 args.kv_channels = args.hidden_size // args.num_attention_heads + if args.tensor_model_parallel_size == 1: + args.sequence_parallel = False + if args.seq_length is not None: assert args.encoder_seq_length is None args.encoder_seq_length = args.seq_length @@ -285,7 +288,6 @@ def parse_args(extra_args_provider=None, defaults={}, _print_args(args) return args - def _print_args(args): """Print arguments.""" if args.rank == 0: @@ -695,6 +697,9 @@ def _add_distributed_args(parser): group.add_argument('--enable-expert-tensor-parallelism', action='store_true', default=False, help="use tensor parallelism for expert layers in MoE") + group.add_argument('--sequence-parallel', action='store_true', + default=False, + help="use sequence parallelism") group.add_argument('--pipeline-model-parallel-size', type=int, default=1, help='Degree of pipeline model parallelism.') group.add_argument('--moe-expert-parallel-size', type=int, default=1, diff --git a/ascendspeed/model/language_model.py b/ascendspeed/model/language_model.py index cecc1d10f..2af0c7b78 100644 --- a/ascendspeed/model/language_model.py +++ b/ascendspeed/model/language_model.py @@ -26,13 +26,27 @@ from ascendspeed.model.module import MegatronModule from ascendspeed.model.transformer import ParallelTransformer from ascendspeed.model.utils import get_linear_layer from ascendspeed.model.utils import init_method_normal, scaled_init_method_normal - +from ascendspeed.mpu.mappings import gather_from_sequence_parallel_region +from ascendspeed.mpu.initialize import get_global_memory_buffer def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, bias=None): """LM logits using word embedding weights.""" + args = get_args() # Parallel logits. - input_parallel = mpu.copy_to_tensor_model_parallel_region(input_) + if args.sequence_parallel: + world_size = get_tensor_model_parallel_world_size() + dim_size = list(input_.size()) + dim_size[0] = dim_size[0] * world_size + + input_parallel = \ + get_global_memory_buffer().get_tensor(dim_size, input_.dtype, "mpu") + torch.distributed._all_gather_base( + input_parallel, + input_, + group=get_tensor_model_parallel_group()) + else: + input_parallel = mpu.copy_to_tensor_model_parallel_region(input_) # Matrix multiply. 
if bias is None: logits_parallel = F.linear(input_parallel, word_embeddings_weight) @@ -92,11 +106,17 @@ class Pooler(MegatronModule): def __init__(self, hidden_size, init_method): super(Pooler, self).__init__() + args = get_args() self.dense = get_linear_layer(hidden_size, hidden_size, init_method) + self.sequence_parallel = args.sequence_parallel def forward(self, hidden_states, sequence_index=0): # hidden_states: [b, s, h] # sequence_index: index of the token to pool. + if self.sequence_parallel: + hidden_states = gather_from_sequence_parallel_region( + hidden_states, + tensor_parallel_output_grad=False) pooled = hidden_states[:, sequence_index, :] pooled = self.dense(pooled) pooled = torch.tanh(pooled) diff --git a/ascendspeed/model/llama_model.py b/ascendspeed/model/llama_model.py index fd02fef6e..3a571f341 100644 --- a/ascendspeed/model/llama_model.py +++ b/ascendspeed/model/llama_model.py @@ -31,6 +31,7 @@ from ascendspeed.model.enums import AttnMaskType, LayerType, AttnType from ascendspeed.model.utils import get_linear_layer, init_method_normal, scaled_init_method_normal, attention_mask_func, \ openai_gelu, erf_gelu +from ascendspeed.mpu.mappings import scatter_to_sequence_parallel_region from ascendspeed.model.fused_softmax import NPUFusedScaleMaskSoftmax from ascendspeed.model.language_model import Pooler @@ -133,10 +134,13 @@ class LlamaLMHead(MegatronModule): bias=False, gather_output=not self.parallel_output, skip_bias_add=True, - init_method=self.init_method, ) + init_method=self.init_method, + sequence_parallel_enabled=args.sequence_parallel) def forward(self, inputs): + inputs = inputs.transpose(0, 1).contiguous() logits, _ = self.lm_head(inputs) + logits = logits.transpose(0, 1).contiguous() # SBH-->BSH return logits @@ -188,10 +192,16 @@ class LlamaEmbedding(MegatronModule): # Word embeddings (parallel). self.word_embeddings = mpu.VocabParallelEmbedding(vocab_size, self.hidden_size, init_method=self.init_method) + self.sequence_parallel = args.sequence_parallel def forward(self, input_ids): # Embeddings. 
embeddings = self.word_embeddings(input_ids) + if self.sequence_parallel: + embeddings = embeddings.transpose(0, 1).contiguous() + embeddings = scatter_to_sequence_parallel_region(embeddings) + embeddings = embeddings.transpose(0, 1).contiguous() + return embeddings @@ -243,7 +253,8 @@ class LlamaParallelMLP(MegatronModule): init_method=self.init_method, skip_bias_add=True, moe=moe, - enable_expert_tensor_parallelism=enable_expert_tensor_parallelism + enable_expert_tensor_parallelism=enable_expert_tensor_parallelism, + sequence_parallel_enabled=args.sequence_parallel ) self.up_proj = mpu.ColumnParallelLinear( @@ -254,7 +265,8 @@ class LlamaParallelMLP(MegatronModule): init_method=self.init_method, skip_bias_add=True, moe=moe, - enable_expert_tensor_parallelism=enable_expert_tensor_parallelism + enable_expert_tensor_parallelism=enable_expert_tensor_parallelism, + sequence_parallel_enabled=args.sequence_parallel ) self.activation_func = F.silu @@ -268,7 +280,8 @@ class LlamaParallelMLP(MegatronModule): init_method=self.output_layer_init_method, skip_bias_add=True, moe=moe, - enable_expert_tensor_parallelism=enable_expert_tensor_parallelism) + enable_expert_tensor_parallelism=enable_expert_tensor_parallelism, + sequence_parallel_enabled=args.sequence_parallel) def forward(self, hidden_states): intermediate_parallel = self.activation_func(self.gate_proj(hidden_states)[0]) * self.up_proj(hidden_states)[0] @@ -296,7 +309,7 @@ class LlamaParallelAttention(MegatronModule): args = get_args() self.fp16 = args.fp16 self.bf16 = args.bf16 - + self.sequence_parallel = args.sequence_parallel self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32 if self.apply_query_key_layer_scaling: @@ -326,7 +339,8 @@ class LlamaParallelAttention(MegatronModule): 3 * projection_size, bias=False, gather_output=False, - init_method=self.init_method) + init_method=self.init_method, + sequence_parallel_enabled=self.sequence_parallel) coeff = None self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) @@ -352,7 +366,8 @@ class LlamaParallelAttention(MegatronModule): bias=False, input_is_parallel=True, init_method=self.output_layer_init_method, - skip_bias_add=True) + skip_bias_add=True, + sequence_parallel_enabled=self.sequence_parallel) if deepspeed.checkpointing.is_configured(): global get_cuda_rng_tracker, checkpoint @@ -559,6 +574,7 @@ class LlamaParallelTransformerLayer(MegatronModule): eps=args.layernorm_epsilon) # MLP + self.rank = args.rank self.mlp = LlamaParallelMLP(self.init_method, self.output_layer_init_method) def forward(self, hidden_states, attention_mask=None, diff --git a/ascendspeed/model/transformer.py b/ascendspeed/model/transformer.py index f3e4058c3..f5571c83e 100644 --- a/ascendspeed/model/transformer.py +++ b/ascendspeed/model/transformer.py @@ -160,20 +160,23 @@ class ParallelAttention(MegatronModule): args.hidden_size, 3 * projection_size, gather_output=False, - init_method=init_method) + init_method=init_method, + sequence_parallel_enabled=args.sequence_parallel) else: assert attention_type == AttnType.cross_attn self.query = mpu.ColumnParallelLinear( args.hidden_size, projection_size, gather_output=False, - init_method=init_method) + init_method=init_method, + sequence_parallel_enabled=args.sequence_parallel) self.key_value = mpu.ColumnParallelLinear( args.hidden_size, 2 * projection_size, gather_output=False, - init_method=init_method) + init_method=init_method, + 
sequence_parallel_enabled=args.sequence_parallel) coeff = None self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) @@ -200,7 +203,8 @@ class ParallelAttention(MegatronModule): args.hidden_size, input_is_parallel=True, init_method=output_layer_init_method, - skip_bias_add=True) + skip_bias_add=True, + sequence_parallel_enabled=args.sequence_parallel) if deepspeed.checkpointing.is_configured(): global get_cuda_rng_tracker, checkpoint diff --git a/ascendspeed/mpu/initialize.py b/ascendspeed/mpu/initialize.py index 301a98d1b..98a544cc9 100644 --- a/ascendspeed/mpu/initialize.py +++ b/ascendspeed/mpu/initialize.py @@ -18,7 +18,7 @@ import torch -from .utils import ensure_divisibility +from .utils import ensure_divisibility, GlobalMemoryBuffer # Intra-layer model parallel group that the current rank belongs to. @@ -45,6 +45,9 @@ _MPU_PIPELINE_MODEL_PARALLEL_RANK = None # rank when broadcasting from the first or last pipeline stage _PIPELINE_GLOBAL_RANKS = None +# Memory buffers to avoid dynamic memory allocation +_GLOBAL_MEMORY_BUFFER = None + def is_unitialized(): """Useful for code segments that may be accessed with or without mpu initialization""" return _DATA_PARALLEL_GROUP is None @@ -166,6 +169,12 @@ def initialize_model_parallel(tensor_model_parallel_size_=1, group = torch.distributed.new_group(embedding_ranks) if rank in embedding_ranks: _EMBEDDING_GROUP = group + + # Initialize global memory buffer + # This isn't really "parallel state" but there isn't another good place to + # put this. If we end up with a more generic initialization of megatron-core + # we could stick it there + _set_global_memory_buffer() def model_parallel_is_initialized(): @@ -364,6 +373,19 @@ def get_data_parallel_rank(): return torch.distributed.get_rank(group=get_data_parallel_group()) +def _set_global_memory_buffer(): + """Initialize global buffer""" + global _GLOBAL_MEMORY_BUFFER + assert _GLOBAL_MEMORY_BUFFER is None, 'global memory buffer is already initialized' + _GLOBAL_MEMORY_BUFFER = GlobalMemoryBuffer() + + +def get_global_memory_buffer(): + """Return the global GlobalMemoryBuffer object""" + assert _GLOBAL_MEMORY_BUFFER is not None, 'global memory buffer is not initialized' + return _GLOBAL_MEMORY_BUFFER + + def destroy_model_parallel(): """Set the groups to none.""" global _TENSOR_MODEL_PARALLEL_GROUP diff --git a/ascendspeed/mpu/layers.py b/ascendspeed/mpu/layers.py index 7eed59567..15f03a5d4 100644 --- a/ascendspeed/mpu/layers.py +++ b/ascendspeed/mpu/layers.py @@ -19,7 +19,7 @@ import math - +from typing import List, Optional import deepspeed.runtime.activation_checkpointing.checkpointing as ds_checkpointing import torch import torch.nn.functional as F @@ -29,11 +29,14 @@ from deepspeed.accelerator import get_accelerator from ascendspeed import get_args, mpu from ascendspeed.mpu.initialize import get_tensor_model_parallel_rank +from ascendspeed.mpu.initialize import get_tensor_model_parallel_group from ascendspeed.mpu.initialize import get_tensor_model_parallel_world_size +from ascendspeed.mpu.initialize import get_global_memory_buffer from ascendspeed.mpu.mappings import copy_to_tensor_model_parallel_region from ascendspeed.mpu.mappings import gather_from_tensor_model_parallel_region from ascendspeed.mpu.mappings import reduce_from_tensor_model_parallel_region from ascendspeed.mpu.mappings import scatter_to_tensor_model_parallel_region +from ascendspeed.mpu.mappings import reduce_scatter_to_sequence_parallel_region from ascendspeed.mpu.random import get_cuda_rng_tracker from 
ascendspeed.mpu.utils import divide, split_tensor_along_last_dim, VocabUtility from ascendspeed.model.fused_layer_norm import MixedFusedLayerNorm as LayerNorm @@ -210,6 +213,157 @@ class VocabParallelEmbedding(torch.nn.Module): return output +class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function): + """See linear_with_grad_accumulation_and_async_allreduce""" + + @staticmethod + def forward(ctx, input_, weight, bias, sequence_parallel): + ctx.save_for_backward(input_, weight) + ctx.use_bias = bias is not None + ctx.sequence_parallel = sequence_parallel + + + if sequence_parallel: + world_size = get_tensor_model_parallel_world_size() + dim_size = list(input_.size()) + dim_size[0] = dim_size[0] * world_size + + all_gather_buffer = \ + get_global_memory_buffer().get_tensor(dim_size, input_.dtype, "mpu") + torch.distributed._all_gather_base( + all_gather_buffer, + input_, + group=get_tensor_model_parallel_group()) + + total_input = all_gather_buffer + else: + total_input = input_ + + output = torch.matmul(total_input, weight.t()) + if bias is not None: + output = output + bias + return output + + @staticmethod + def backward(ctx, grad_output): + input_, weight = ctx.saved_tensors + use_bias = ctx.use_bias + + if ctx.sequence_parallel: + world_size = get_tensor_model_parallel_world_size() + dim_size = list(input_.size()) + dim_size[0] = dim_size[0] * world_size + + all_gather_buffer = \ + get_global_memory_buffer().get_tensor(dim_size, input_.dtype, "mpu") + handle = torch.distributed._all_gather_base( + all_gather_buffer, + input_, + group=get_tensor_model_parallel_group(), async_op=True) + + # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the + # gather is scheduled before the input gradient computation + total_input = all_gather_buffer + else: + total_input = input_ + grad_input = grad_output.matmul(weight) + + if ctx.sequence_parallel: + handle.wait() + + # Convert the tensor shapes to 2D for execution compatibility + grad_output = grad_output.reshape(grad_output.shape[0] * grad_output.shape[1], + grad_output.shape[2]) + total_input = total_input.reshape(total_input.shape[0] * total_input.shape[1], + total_input.shape[2]) + + if ctx.sequence_parallel: + dim_size = list(input_.size()) + sub_grad_input = torch.empty(dim_size, dtype=input_.dtype, + device=torch.cuda.current_device(), + requires_grad=False) + handle = torch.distributed._reduce_scatter_base(sub_grad_input, grad_input, + group=get_tensor_model_parallel_group(), + async_op=True) + # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the + # reduce scatter is scheduled before the weight gradient computation + + + + grad_weight = grad_output.t().matmul(total_input) + grad_bias = grad_output.sum(dim=0) if use_bias else None + + if ctx.sequence_parallel: + handle.wait() + return sub_grad_input, grad_weight, grad_bias, None, None, None + + return grad_input, grad_weight, grad_bias, None, None, None + + +def linear_with_grad_accumulation_and_async_allreduce( + input_: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + sequence_parallel_enabled: bool, +) -> torch.Tensor: + """Linear layer execution with asynchronous communication and + gradient accumulation fusion in backprop. + + This has the option to accumulate the result of backprop + calculation into an existing gradient buffer, preventing the need + to do an additional addition kernel after the gradient + calculation. 
+ + Additionally, the tensor parallel all reduce of the input + gradients can be done asynchronously with the calculation of + the weight gradients. + + In the case of sequence parallelism, the reduce scatter of the + input gradients is done asynchronously with the calcluation of the + weight gradients. + + Use of this module requires that the environment variable + CUDA_DEVICE_MAX_CONNECTIONS=1. There are a few collective + operations, noted in the code, that should be scheduled before + compute kernels to overlap the communication with the computation, + which is necessary for a speedup but not for correctness so that + ordering isn't imposed by the scheduler. Setting + CUDA_DEVICE_MAX_CONNECTIONS=1 forces the kernels to be scheduled + in the order they are called. + + Arguments: + + input (torch.Tensor required): input like torch.nn.functional.linear + + weight (torch.Tensor required): weight like torch.nn.functional.linear + + bias (torch.Tensor optional): bias like torch.nn.functional.linear + + gradient_accumulation_fusion (bool required): Perform the gradient + accumulation fusion, requires the custom CUDA extension + fused_weight_gradient_mlp_cuda module. To use + gradient_accumulation_fusion you must install APEX with + --cpp_ext and --cuda_ext. For example: "pip install + --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext .\" + " Note that the extension requires CUDA>=11. Otherwise, you + must turn off gradient accumulation fusion." + + sequence_parallel_enabled (bool required): Indicates that sequence + parallelism is used and thus in the forward pass the input is + all gathered, and the backward pass the input gradients are + reduce scattered. + """ + args = [ + input_, + weight, + bias, + sequence_parallel_enabled, + ] + + with torch.cuda.amp.autocast(enabled=False): + return LinearWithGradAccumulationAndAsyncCommunication.apply(*args) +linear_with_grad_accumulation_and_async_allreduce.warned = False + class ColumnParallelLinear(torch.nn.Module): """Linear layer with column parallelism. @@ -237,7 +391,10 @@ class ColumnParallelLinear(torch.nn.Module): def __init__(self, input_size, output_size, bias=True, gather_output=True, init_method=init.xavier_normal_, stride=1, keep_master_weight_for_test=False, - skip_bias_add=False, moe=False, enable_expert_tensor_parallelism=False): + skip_bias_add=False, moe=False, + enable_expert_tensor_parallelism=False, + sequence_parallel_enabled: bool = False + ): super(ColumnParallelLinear, self).__init__() # Keep input parameters @@ -290,12 +447,20 @@ class ColumnParallelLinear(torch.nn.Module): self.bias.zero_() else: self.register_parameter('bias', None) + if sequence_parallel_enabled: + if world_size <= 1: + warnings.warn( + f"`sequence_parallel_enabled` is set to `True`, but tensor model parallel size is {world_size}. " + f"Disabling sequence parallel." + ) + sequence_parallel_enabled = False + self.sequence_parallel_enabled = sequence_parallel_enabled def forward(self, input_): # Set up backprop all-reduce. - if self.is_expert_without_slicing: # non-expert only tensor parallelism + if self.is_expert_without_slicing or self.sequence_parallel_enabled: # non-expert only tensor parallelism input_parallel = input_ else: input_parallel = copy_to_tensor_model_parallel_region(input_) @@ -303,9 +468,17 @@ class ColumnParallelLinear(torch.nn.Module): # Matrix multiply. 
bias = self.bias if not self.skip_bias_add else None - output_parallel = F.linear(input_parallel, self.weight, bias) + + output_parallel = linear_with_grad_accumulation_and_async_allreduce( + input_=input_parallel, + weight=self.weight, + bias=bias, + sequence_parallel_enabled=self.sequence_parallel_enabled, + ) + if self.gather_output and not self.is_expert_without_slicing: # All-gather across the partitions. + assert not self.sequence_parallel_enabled output = gather_from_tensor_model_parallel_region(output_parallel) else: output = output_parallel @@ -347,7 +520,9 @@ class RowParallelLinear(torch.nn.Module): input_is_parallel=False, init_method=init.xavier_normal_, stride=1, keep_master_weight_for_test=False, - skip_bias_add=False, moe=False, enable_expert_tensor_parallelism=False): + skip_bias_add=False, moe=False, + enable_expert_tensor_parallelism=False, + sequence_parallel_enabled: bool = False): super(RowParallelLinear, self).__init__() # Keep input parameters @@ -365,7 +540,9 @@ class RowParallelLinear(torch.nn.Module): self.input_size_per_partition = divide(input_size, world_size) self.skip_bias_add = skip_bias_add - + self.sequence_parallel_enabled = sequence_parallel_enabled + if self.sequence_parallel_enabled and not self.input_is_parallel: + raise RuntimeError("To enable `sequence_parallel_enabled`, `input_is_parallel` must be `True`") # Parameters. # Note: torch.nn.functional.linear performs XA^T + b and as a result # we allocate the transpose. @@ -406,14 +583,24 @@ class RowParallelLinear(torch.nn.Module): if self.input_is_parallel or self.is_expert_without_slicing: input_parallel = input_ else: + assert not self.sequence_parallel_enabled input_parallel = scatter_to_tensor_model_parallel_region(input_) # Matrix multiply. - output_parallel = F.linear(input_parallel, self.weight) + output_parallel = linear_with_grad_accumulation_and_async_allreduce( + input_=input_parallel, + weight=self.weight, + bias=None, + sequence_parallel_enabled=False, + ) + # All-reduce across all the partitions. if self.is_expert_without_slicing: # non-expert only tensor-parallelism output_ = output_parallel else: - output_ = reduce_from_tensor_model_parallel_region(output_parallel) + if self.sequence_parallel_enabled: + output_ = reduce_scatter_to_sequence_parallel_region(output_parallel) + else: + output_ = reduce_from_tensor_model_parallel_region(output_parallel) if not self.skip_bias_add: output = output_ + self.bias if self.bias is not None else output_ diff --git a/ascendspeed/mpu/mappings.py b/ascendspeed/mpu/mappings.py index 821d9acfe..e70fb91b0 100644 --- a/ascendspeed/mpu/mappings.py +++ b/ascendspeed/mpu/mappings.py @@ -72,6 +72,105 @@ def _gather(input_): return output +def _gather_along_first_dim(input_): + """Gather tensors and concatinate along the first dimension.""" + + world_size = get_tensor_model_parallel_world_size() + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + + dim_size = list(input_.size()) + dim_size[0] = dim_size[0] * world_size + + output = torch.empty(dim_size, dtype=input_.dtype, + device=torch.cuda.current_device()) + torch.distributed._all_gather_base(output, input_.contiguous(), + group=get_tensor_model_parallel_group()) + + return output + +def _reduce_scatter_along_first_dim(input_): + """Reduce-scatter the input tensor across model parallel group.""" + world_size = get_tensor_model_parallel_world_size() + # Bypass the function if we are using only 1 GPU. 
+ if world_size == 1: + return input_ + + dim_size = list(input_.size()) + assert dim_size[0] % world_size == 0, \ + "First dimension of the tensor should be divisible by tensor parallel size" + + dim_size[0] = dim_size[0] // world_size + + output = torch.empty(dim_size, dtype=input_.dtype, + device=torch.cuda.current_device()) + torch.distributed._reduce_scatter_base(output, input_.contiguous(), + group=get_tensor_model_parallel_group()) + return output + +def _split_along_first_dim(input_): + """Split the tensor along its first dimension and keep the + corresponding slice.""" + + world_size = get_tensor_model_parallel_world_size() + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + + # Split along first dimension. + dim_size = input_.size()[0] + assert dim_size % world_size == 0, \ + "First dimension of the tensor should be divisible by tensor parallel size" + local_dim_size = dim_size // world_size + rank = get_tensor_model_parallel_rank() + dim_offset = rank * local_dim_size + + output = input_[dim_offset:dim_offset+local_dim_size].contiguous() + + return output + +class _ScatterToSequenceParallelRegion(torch.autograd.Function): + """Split the input and keep only the corresponding chuck to the rank.""" + + @staticmethod + def symbolic(graph, input_): + return _split_along_first_dim(input_) + + @staticmethod + def forward(ctx, input_): + return _split_along_first_dim(input_) + + @staticmethod + def backward(ctx, grad_output): + return _gather_along_first_dim(grad_output) + + +class _GatherFromSequenceParallelRegion(torch.autograd.Function): + """Gather the input from sequence parallel region and concatinate.""" + + @staticmethod + def symbolic(graph, input_, tensor_parallel_output_grad=True): + return _gather_along_first_dim(input_) + + @staticmethod + def forward(ctx, input_, tensor_parallel_output_grad=True): + ctx.tensor_parallel_output_grad = tensor_parallel_output_grad + return _gather_along_first_dim(input_) + + @staticmethod + def backward(ctx, grad_output): + tensor_parallel_output_grad = ctx.tensor_parallel_output_grad + + # If the computation graph after the gather operation is + # in the tensor parallel mode, output gradients need to reduce + # scattered and whereas if the computation is duplicated, + # output gradients need to be scattered. 
+ if tensor_parallel_output_grad: + return _reduce_scatter_along_first_dim(grad_output), None + else: + return _split_along_first_dim(grad_output), None + class _CopyToModelParallelRegion(torch.autograd.Function): """Pass the input to the model parallel region.""" @@ -121,6 +220,23 @@ class _ScatterToModelParallelRegion(torch.autograd.Function): return _gather(grad_output) +class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function): + """Reduce scatter the input from the model parallel region.""" + + @staticmethod + def symbolic(graph, input_): + return _reduce_scatter_along_first_dim(input_) + + @staticmethod + def forward(ctx, input_): + return _reduce_scatter_along_first_dim(input_) + + @staticmethod + def backward(ctx, grad_output): + return _gather_along_first_dim(grad_output) + + + class _GatherFromModelParallelRegion(torch.autograd.Function): """Gather the input from model parallel region and concatinate.""" @@ -153,5 +269,14 @@ def scatter_to_tensor_model_parallel_region(input_): return _ScatterToModelParallelRegion.apply(input_) +def gather_from_sequence_parallel_region(input_, tensor_parallel_output_grad=True): + return _GatherFromSequenceParallelRegion.apply(input_, tensor_parallel_output_grad) + def gather_from_tensor_model_parallel_region(input_): return _GatherFromModelParallelRegion.apply(input_) + +def reduce_scatter_to_sequence_parallel_region(input_): + return _ReduceScatterToSequenceParallelRegion.apply(input_) + +def scatter_to_sequence_parallel_region(input_): + return _ScatterToSequenceParallelRegion.apply(input_) \ No newline at end of file diff --git a/ascendspeed/mpu/utils.py b/ascendspeed/mpu/utils.py index 2431a1c39..229a98f06 100644 --- a/ascendspeed/mpu/utils.py +++ b/ascendspeed/mpu/utils.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import reduce +import operator import torch @@ -29,6 +31,27 @@ def divide(numerator, denominator): ensure_divisibility(numerator, denominator) return numerator // denominator +class GlobalMemoryBuffer: + """Global buffer to avoid dynamic memory allocations. + Caller should ensure that buffers of the same name + are not used concurrently. + """ + + def __init__(self): + self.buffer = {} + + def get_tensor(self, tensor_shape, dtype, name): + required_len = reduce(operator.mul, tensor_shape, 1) + if self.buffer.get((name, dtype), None) is None or \ + self.buffer[(name, dtype)].numel() < required_len: + self.buffer[(name, dtype)] = \ + torch.empty(required_len, + dtype=dtype, + device=torch.cuda.current_device(), + requires_grad=False) + + return self.buffer[(name, dtype)][0:required_len].view(*tensor_shape) + def _kernel_make_viewless_tensor(inp, requires_grad): '''Make a viewless tensor. 
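
The GlobalMemoryBuffer added above hands out views over a single growing flat allocation per (name, dtype) pair, which is what lets the repeated sequence-parallel all-gathers avoid a fresh allocation on every call. A minimal usage sketch, assuming CPU tensors purely for illustration (the patch allocates on the current accelerator device):

    import operator
    from functools import reduce

    import torch


    class GlobalMemoryBuffer:
        """Same logic as the class added to ascendspeed/mpu/utils.py,
        with the device argument dropped so the sketch runs anywhere."""

        def __init__(self):
            self.buffer = {}

        def get_tensor(self, tensor_shape, dtype, name):
            required_len = reduce(operator.mul, tensor_shape, 1)
            key = (name, dtype)
            if self.buffer.get(key) is None or self.buffer[key].numel() < required_len:
                # Grow (or create) the flat backing storage.
                self.buffer[key] = torch.empty(required_len, dtype=dtype, requires_grad=False)
            # Hand back a reshaped view over the first required_len elements.
            return self.buffer[key][0:required_len].view(*tensor_shape)


    buf = GlobalMemoryBuffer()
    a = buf.get_tensor([4, 2, 8], torch.float32, "mpu")  # allocates 64 elements
    b = buf.get_tensor([2, 2, 8], torch.float32, "mpu")  # reuses the same storage
    assert a.data_ptr() == b.data_ptr()                  # no new allocation happened

Because later calls reuse the same storage, callers must not keep two live tensors under the same buffer name at once, which is exactly the caveat stated in the class docstring.
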
diff --git a/ascendspeed/p2p_communication.py b/ascendspeed/p2p_communication.py index 4bd06a3fb..056159a6b 100644 --- a/ascendspeed/p2p_communication.py +++ b/ascendspeed/p2p_communication.py @@ -57,9 +57,13 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next, tensor_shape = recv_tensor_shape if args.optimized_pipeline and (recv_prev or recv_next) \ else (args.seq_length, args.micro_batch_size, args.hidden_size) - if args.scatter_gather_tensors_in_pipeline: + if args.sequence_parallel: + seq_length = args.seq_length // mpu.get_tensor_model_parallel_world_size() + tensor_shape = (seq_length, args.micro_batch_size, args.hidden_size) + + if args.scatter_gather_tensors_in_pipeline and not args.sequence_parallel: tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1) // \ - mpu.get_tensor_model_parallel_world_size() + mpu.get_tensor_model_parallel_world_size() else: tensor_chunk_shape = tensor_shape dtype = args.params_dtype @@ -77,7 +81,7 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next, dtype=dtype) # Split tensor into smaller chunks if using scatter-gather optimization. - if args.scatter_gather_tensors_in_pipeline: + if args.scatter_gather_tensors_in_pipeline and not args.sequence_parallel: if tensor_send_next is not None: tensor_send_next = mpu.split_tensor_into_1d_equal_chunks(tensor_send_next) @@ -141,7 +145,7 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next, get_accelerator().synchronize() # If using scatter-gather optimization, gather smaller chunks. - if args.scatter_gather_tensors_in_pipeline: + if args.scatter_gather_tensors_in_pipeline and not args.sequence_parallel: if recv_prev: tensor_recv_prev = mpu.gather_split_1d_tensor( tensor_recv_prev).view(tensor_shape).requires_grad_() diff --git a/ascendspeed/training.py b/ascendspeed/training.py index 3b3e91319..9c3d286b4 100644 --- a/ascendspeed/training.py +++ b/ascendspeed/training.py @@ -20,8 +20,6 @@ import math import sys import time import json -# The earliest we can measure the start time. -_TRAIN_START_TIME = time.time() import torch from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP @@ -55,12 +53,16 @@ from ascendspeed.schedules import optimized_forward_backward_pipelining from ascendspeed.utils import report_memory, throughput_calculator, checkpoint_throughput_calculator from ascendspeed.model.transformer import ParallelTransformerLayer + import deepspeed from deepspeed.accelerator import get_accelerator from deepspeed.compression.compress import init_compression, redundancy_clean from deepspeed.runtime.data_pipeline.data_routing.helper import convert_to_random_ltd +# The earliest we can measure the start time. 
+_TRAIN_START_TIME = time.time() + def print_datetime(string): """Note that this call will sync across all ranks.""" torch.distributed.barrier() @@ -903,9 +905,15 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration, opt_stats[5] += torch.norm(optimizer.state[param]['exp_avg_sq'].sqrt(),p=1).item() opt_stats[6] += torch.norm(optimizer.state[param]['exp_avg'],p=1).item() opt_stats[7] += torch.norm(param,p=1).item() - opt_stats_2[0] = max(opt_stats_2[0], abs(optimizer.state[param]['exp_avg_sq'].max().item()), abs(optimizer.state[param]['exp_avg_sq'].min().item())) + opt_stats_2[0] = max( + opt_stats_2[0], + abs(optimizer.state[param]['exp_avg_sq'].max().item()), + abs(optimizer.state[param]['exp_avg_sq'].min().item())) opt_stats_2[1] = max(opt_stats_2[1], optimizer.state[param]['exp_avg_sq'].sqrt().abs_().max().item()) - opt_stats_2[2] = max(opt_stats_2[2], abs(optimizer.state[param]['exp_avg'].max().item()), abs(optimizer.state[param]['exp_avg'].min().item())) + opt_stats_2[2] = max( + opt_stats_2[2], + abs(optimizer.state[param]['exp_avg'].max().item()), + abs(optimizer.state[param]['exp_avg'].min().item())) opt_stats_2[3] = max(opt_stats_2[3], abs(param.max().item()), abs(param.min().item())) if args.zero_stage > 0: diff --git a/examples/llama_task/pretrain_llama_td_8p_sp.sh b/examples/llama_task/pretrain_llama_td_8p_sp.sh new file mode 100644 index 000000000..e12c78e62 --- /dev/null +++ b/examples/llama_task/pretrain_llama_td_8p_sp.sh @@ -0,0 +1,53 @@ +# This is an example: train llama using TD, +# the number of parameters is not aligned +export LD_LIBRARY_PATH=/usr/local/lib:/root/miniconda3/lib:$LD_LIBRARY_PATH +export HCCL_CONNECT_TIMEOUT=1200 +source /usr/local/Ascend/ascend-toolkit/set_env.sh +GPUS_PER_NODE=8 +# Change for multinode config +MASTER_ADDR=localhost +MASTER_PORT=6000 +NNODES=1 +NODE_RANK=0 +WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) + +DATA_PATH=./dataset/llama_text_document +CHECKPOINT_PATH=./ckpt + +DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT" + +python -m torch.distributed.launch $DISTRIBUTED_ARGS \ + pretrain_llama.py \ + --DDP-impl local \ + --use-contiguous-buffers-in-ddp \ + --tensor-model-parallel-size 4 \ + --num-layers 8 \ + --hidden-size 4096 \ + --num-attention-heads 32 \ + --micro-batch-size 2 \ + --global-batch-size 16 \ + --seq-length 2048 \ + --max-position-embeddings 2048 \ + --train-iters 500000 \ + --lr-decay-iters 320000 \ + --save $CHECKPOINT_PATH \ + --load $CHECKPOINT_PATH \ + --data-path $DATA_PATH \ + --tokenizer-name-or-path ./dataset/llama/ \ + --tokenizer-not-use-fast \ + --data-impl mmap \ + --split 949,50,1 \ + --distributed-backend nccl \ + --lr 0.00015 \ + --lr-decay-style cosine \ + --min-lr 1.0e-5 \ + --weight-decay 1e-2 \ + --clip-grad 1.0 \ + --lr-warmup-fraction .01 \ + --checkpoint-activations \ + --log-interval 10 \ + --save-interval 10000 \ + --eval-interval 1000 \ + --eval-iters 10 \ + --sequence-parallel \ + --fp16 | tee logs/train.log diff --git a/tests/st/test_llama/test_llama_ptd.sh b/tests/st/test_llama/test_llama_ptd.sh index d17fa0c5c..fe6352299 100644 --- a/tests/st/test_llama/test_llama_ptd.sh +++ b/tests/st/test_llama/test_llama_ptd.sh @@ -51,4 +51,5 @@ python3 -m torch.distributed.launch $DISTRIBUTED_ARGS \ --save-interval 10000 \ --eval-interval 1000 \ --eval-iters 10 \ + --sequence-parallel \ --fp16 -- Gitee From 4cd1587c46f28a7278a1781dafeba6494b9386fd Mon Sep 17 00:00:00 2001 From: kingsleyandher 
<13825737389@163.com> Date: Tue, 11 Jul 2023 20:12:39 +0800 Subject: [PATCH 2/2] merge code --- ascendspeed/training.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/ascendspeed/training.py b/ascendspeed/training.py index 90be6c663..658b77a38 100644 --- a/ascendspeed/training.py +++ b/ascendspeed/training.py @@ -54,14 +54,11 @@ from ascendspeed.schedules import optimized_forward_backward_pipelining from ascendspeed.utils import report_memory, throughput_calculator, checkpoint_throughput_calculator from ascendspeed.model.transformer import ParallelTransformerLayer - import deepspeed from deepspeed.accelerator import get_accelerator from deepspeed.compression.compress import init_compression, redundancy_clean from deepspeed.runtime.data_pipeline.data_routing.helper import convert_to_random_ltd -# The earliest we can measure the start time. -_TRAIN_START_TIME = time.time() # The earliest we can measure the start time. _TRAIN_START_TIME = time.time() -- Gitee
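
The patches above wire sequence parallelism through the embedding, attention, MLP and LM-head layers. As a reference for the intended data flow, here is a single-process sketch in which the collectives are emulated with chunk/cat/sum (no HCCL/NCCL involved); all names are illustrative stand-ins, not the patched APIs:

    import torch

    tp = 4               # tensor model parallel world size
    s, b, h = 8, 2, 16   # full sequence length, micro batch size, hidden size

    full = torch.randn(s, b, h)

    # scatter_to_sequence_parallel_region: each rank keeps s/tp of the sequence.
    shards = list(full.chunk(tp, dim=0))           # each [s/tp, b, h]

    # ColumnParallelLinear with sequence_parallel_enabled=True: the forward pass
    # all-gathers the shards back to the full sequence before the matmul.
    gathered = torch.cat(shards, dim=0)            # [s, b, h]
    assert torch.equal(gathered, full)

    # RowParallelLinear with sequence_parallel_enabled=True: the partial outputs
    # of the tp ranks are reduce-scattered, i.e. summed and re-split along the
    # sequence dimension.
    partial_outputs = [torch.randn(s, b, h) for _ in range(tp)]
    reduced = torch.stack(partial_outputs).sum(0)  # the "reduce" part
    rs_shards = list(reduced.chunk(tp, dim=0))     # the "scatter" part
    assert rs_shards[0].shape == (s // tp, b, h)

The splits and gathers act on the first (sequence) dimension, which is why LlamaLMHead.forward and LlamaEmbedding.forward in the patch transpose between [b, s, h] and [s, b, h] around the sharded regions.
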
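
With sequence_parallel disabled, LinearWithGradAccumulationAndAsyncCommunication is expected to reproduce a plain linear layer, since the all-gather and reduce-scatter branches are skipped. The following stand-alone re-implementation of just that degenerate forward/backward math (not the patched class, and with no distributed calls) checks the gradients against torch.nn.functional.linear on a single rank:

    import torch


    class _LinearNoSP(torch.autograd.Function):
        """Forward/backward math of the patched autograd function when
        sequence_parallel=False; the communication paths are omitted."""

        @staticmethod
        def forward(ctx, input_, weight, bias):
            ctx.save_for_backward(input_, weight)
            ctx.use_bias = bias is not None
            output = torch.matmul(input_, weight.t())
            return output + bias if bias is not None else output

        @staticmethod
        def backward(ctx, grad_output):
            input_, weight = ctx.saved_tensors
            grad_input = grad_output.matmul(weight)
            # Flatten to 2D, as the patch does, before forming weight/bias grads.
            grad_output_2d = grad_output.reshape(-1, grad_output.shape[-1])
            input_2d = input_.reshape(-1, input_.shape[-1])
            grad_weight = grad_output_2d.t().matmul(input_2d)
            grad_bias = grad_output_2d.sum(dim=0) if ctx.use_bias else None
            return grad_input, grad_weight, grad_bias


    x = torch.randn(4, 2, 8, requires_grad=True)
    w = torch.randn(16, 8, requires_grad=True)
    _LinearNoSP.apply(x, w, None).sum().backward()

    x_ref = x.detach().clone().requires_grad_(True)
    w_ref = w.detach().clone().requires_grad_(True)
    torch.nn.functional.linear(x_ref, w_ref).sum().backward()

    assert torch.allclose(x.grad, x_ref.grad)
    assert torch.allclose(w.grad, w_ref.grad)
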
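
The p2p_communication.py change also shrinks the pipeline send/receive buffers, because each rank now holds only 1/tp of the sequence, and it bypasses the scatter-gather pipeline optimization since the activations are already split. A small sketch of that shape bookkeeping, with plain variables standing in for args and mpu (values are illustrative):

    # Illustrative values; in the real code these come from args and mpu.
    seq_length, micro_batch_size, hidden_size = 2048, 2, 4096
    tp_world_size = 4
    sequence_parallel = True
    scatter_gather_tensors_in_pipeline = True

    if sequence_parallel:
        tensor_shape = (seq_length // tp_world_size, micro_batch_size, hidden_size)
    else:
        tensor_shape = (seq_length, micro_batch_size, hidden_size)

    # The scatter/gather optimization would split the tensor across
    # tensor-parallel ranks a second time, so it is skipped whenever
    # sequence parallelism is on.
    use_scatter_gather = scatter_gather_tensors_in_pipeline and not sequence_parallel

    print(tensor_shape, use_scatter_gather)  # (512, 2, 4096) False
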