Ascend/MindSpeed-LLM: mindspeed_llm/training/training.py

# coding=utf-8
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import sys
import json
from datetime import datetime, timezone
from functools import wraps
import time
import torch
import torch_npu
from megatron.core.transformer.moe.moe_utils import track_moe_metrics
from megatron.training import get_args
from megatron.training import get_timers
from megatron.training import get_signal_handler
from megatron.training import get_tensorboard_writer
from megatron.training import get_wandb_writer
from megatron.training import one_logger_utils
from megatron.core.num_microbatches_calculator import get_num_microbatches, update_num_microbatches
from megatron.core import mpu
from megatron.core.utils import get_model_config
from megatron.core.enums import ModelType
from megatron.training.checkpointing import save_checkpoint
from megatron.training.initialize import initialize_megatron
from megatron.training.initialize import write_args_to_tensorboard
from megatron.training.arguments import core_transformer_config_from_args
from megatron.training.theoretical_memory_usage import report_theoretical_memory
from megatron.training.training import (
train_step, calc_params_l2_norm,
evaluate_and_print_results,
save_checkpoint_and_time, print_datetime,
num_floating_point_operations, get_one_logger,
append_to_progress_log, build_train_valid_test_data_iterators
)
import megatron.training.utils
from megatron.training.utils import (
check_adlr_autoresume_termination,
print_rank_0,
print_rank_last,
report_memory,
)
from megatron.core.distributed import DistributedDataParallel as DDP
from megatron.core.distributed import finalize_model_grads
from mindspeed_llm.training.initialize import set_jit_fusion_options
from mindspeed_llm.tasks.posttrain.lora.utils import is_enable_lora
# The earliest we can measure the start time.
_TRAIN_START_TIME = time.time()
def model_provider_func_wrapper(model_provider_func):
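"""Wrap a Megatron model_provider_func so MindSpeed-LLM patches are applied to the
freshly built model: optional fused SwiGLU MLP forward and, when LoRA is enabled,
PEFT wrapping, recompute hooks and sequence-parallel attribute handling."""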
@wraps(model_provider_func)
def wrapper(*args, **kwargs):
model = model_provider_func(*args, **kwargs)
args = get_args()
if args.use_fused_mlp:
from mindspeed_llm.tasks.models.transformer.fast_mlp import ParallelSwigluMLPForward
from megatron.legacy.model.transformer import ParallelMLP
from megatron.core.transformer.mlp import MLP
ParallelMLP.forward = ParallelSwigluMLPForward
MLP.forward = ParallelSwigluMLPForward
if is_enable_lora():
import peft
from packaging import version
from peft import LoraConfig, get_peft_model, PeftModel, LoraModel
if version.parse(peft.__version__) <= version.parse('0.11.1'):
setattr(peft.tuners.lora.LoraLayer, 'merge', peft.tuners.lora.Linear.merge)
setattr(peft.tuners.lora.LoraLayer, 'unmerge', peft.tuners.lora.Linear.unmerge)
setattr(peft.tuners.lora.LoraLayer, 'get_delta_weight', peft.tuners.lora.Linear.get_delta_weight)
from peft.tuners.lora import tp_layer
from mindspeed_llm.tasks.posttrain.lora.lora_moe import LoraParallelLinearMoE
tp_layer.LoraParallelLinear = LoraParallelLinearMoE
if hasattr(args, 'lora_fusion') and args.lora_fusion:
from peft.tuners.lora.tp_layer import LoraParallelLinear
from mindspeed_llm.tasks.posttrain.lora.cc_lora_forward import CCLoraParallelLinearForward
LoraParallelLinear.forward = CCLoraParallelLinearForward
if args.use_fused_mlp:
from mindspeed_llm.tasks.posttrain.lora.cc_lora_mlp_forward import ParallelSwigluMLPLoRAForward
from megatron.legacy.model.transformer import ParallelMLP
from megatron.core.transformer.mlp import MLP
ParallelMLP.forward = ParallelSwigluMLPLoRAForward
MLP.forward = ParallelSwigluMLPLoRAForward
config = core_transformer_config_from_args(args)
lora_config = LoraConfig(
r=args.lora_r,
lora_alpha=args.lora_alpha,
target_modules=args.lora_target_modules,
lora_dropout=0.0,
bias="none",
megatron_config=config,
megatron_core="megatron.core",
)
model = get_peft_model(model, lora_config)
model.add_module('module', model.get_base_model())
def _hook(_module, _x_in, _x_out):
""" Extract the feature map of model"""
_x_out.requires_grad_(True)
def _create_hooks(_model, layer):
""" Make the hooks function"""
for name, module in _model.named_modules():
if isinstance(module, megatron.core.tensor_parallel.layers.VocabParallelEmbedding):
_name = name.split('.')[-1]
if _name in layer:
module.register_forward_hook(_hook)
if args.recompute_method == 'block' and args.recompute_granularity == 'full':
_create_hooks(model, args.lora_register_forward_hook)
model.print_trainable_parameters()
for module in model.modules():
# LoRA Linear layers need all-reduce.
if isinstance(module, torch.nn.Linear):
setattr(module.weight, 'sequence_parallel', config.sequence_parallel)
# Other layers, if frozen, do not need all-reduce.
for param in module.parameters():
if not param.requires_grad and hasattr(param, 'sequence_parallel'):
delattr(param, 'sequence_parallel')
megatron.training.utils.ALL_MODULE_WRAPPER_CLASSNAMES = tuple(
list(megatron.training.utils.ALL_MODULE_WRAPPER_CLASSNAMES) + [PeftModel, LoraModel]
)
return model
return wrapper
def get_model_wrapper(fn):
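"""Wrap Megatron's get_model so that every model_provider_func is first passed
through model_provider_func_wrapper above."""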
@wraps(fn)
def wrapper(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True):
model_provider_func = model_provider_func_wrapper(model_provider_func)
model = fn(model_provider_func, model_type, wrap_with_ddp)
return model
return wrapper
# This wrapper ensures that DataLoader workers are initialized early to prevent deadlocks during evaluation.
# By accessing the DataLoader's iterator, it triggers the creation of worker subprocesses at initialization.
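# Note: for a PyTorch DataLoader with num_workers > 0, calling iter() is what spawns
# the worker processes, so the loop below forces worker creation up front.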
def build_train_valid_test_data_loaders_wrapper(fn):
@wraps(fn)
def wrapper(build_train_valid_test_datasets_provider):
train_dataloader, valid_dataloader, test_dataloader = fn(build_train_valid_test_datasets_provider)
is_distributed = getattr(build_train_valid_test_datasets_provider, "is_distributed", False)
if is_distributed or mpu.get_tensor_model_parallel_rank() == 0:
for dataloader in [train_dataloader, valid_dataloader, test_dataloader]:
if dataloader is not None:
# Access the DataLoader's iterator to initialize workers
_ = iter(dataloader)
return train_dataloader, valid_dataloader, test_dataloader
return wrapper
def is_profile_enabled():
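"""Return True when NPU profiling should run on this rank; profile_ranks == [-1]
means every rank is profiled."""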
args = get_args()
if not args.profile:
return False
if args.profile_ranks == [-1]:
return True
if torch.distributed.get_rank() in args.profile_ranks:
return True
return False
def get_profiler():
args = get_args()
if args.profile_level == 'level_none':
profiler_level = torch_npu.profiler.ProfilerLevel.Level_none
elif args.profile_level == 'level0':
profiler_level = torch_npu.profiler.ProfilerLevel.Level0
elif args.profile_level == 'level1':
profiler_level = torch_npu.profiler.ProfilerLevel.Level1
elif args.profile_level == 'level2':
profiler_level = torch_npu.profiler.ProfilerLevel.Level2
else:
raise ValueError(f"profiler_level only supports level0,"
f" 1, 2, and level_none, but gets {args.profile_level}")
if args.profile_export_type == 'text':
profile_export_type = torch_npu.profiler.ExportType.Text
elif args.profile_export_type == 'db':
profile_export_type = torch_npu.profiler.ExportType.Db
else:
raise ValueError(f"profile_export_type only supports text or db,"
f"but gets {args.export_type}")
experimental_config = torch_npu.profiler._ExperimentalConfig(
aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization,
profiler_level=profiler_level,
export_type=profile_export_type,
data_simplification=args.profile_data_simplification,
)
skip_first = args.profile_step_start
active = args.profile_step_end - args.profile_step_start
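# One profiling cycle: skip the first `profile_step_start` steps, warm up for one step,
# then record `profile_step_end - profile_step_start` active steps.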
activities = [torch_npu.profiler.ProfilerActivity.NPU]
if args.profile_with_cpu:
activities.append(torch_npu.profiler.ProfilerActivity.CPU)
prof = torch_npu.profiler.profile(
with_stack=args.profile_with_stack,
record_shapes=args.profile_record_shapes,
profile_memory=args.profile_with_memory,
activities=activities,
schedule=torch_npu.profiler.schedule(wait=0, warmup=1, active=active, repeat=1, skip_first=skip_first),
on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(args.profile_save_path),
experimental_config=experimental_config)
prof.add_metadata_json('distributed_args', json.dumps({
'tensor_model_parallel_size': args.tensor_model_parallel_size,
'pipeline_model_parallel_size': args.pipeline_model_parallel_size,
'data_parallel_size': args.data_parallel_size,
'context_parallel_size': args.context_parallel_size,
'expert_model_parallel_size': args.expert_model_parallel_size,
'sequence_parallel': args.sequence_parallel,
'rank': args.rank,
'world_size': args.world_size
}))
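# The returned profiler is driven by train() below: prof.start() before the loop,
# prof.step() once per iteration, prof.stop() after the loop.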
return prof
def setup_model_and_optimizer_wrapper(fn):
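"""Wrap setup_model_and_optimizer so that, with high availability enabled and an
optimizer that supports it, the optimizer is synced to the current iteration."""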
@wraps(fn)
def wrapper(*args, **kwargs):
model, optimizer, opt_param_scheduler = fn(*args, **kwargs)
argument = get_args()
if argument.enable_high_availability and hasattr(optimizer, "set_current_step"):
optimizer.set_current_step(argument.iteration)
return model, optimizer, opt_param_scheduler
return wrapper
def build_train_args(*input_args):
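"""Build the model, optimizer, LR scheduler and train/valid/test data iterators, and
pack them into the argument lists consumed by train() (or tft_train() in the
high-availability path)."""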
args, timers, train_valid_test_dataset_provider, model_provider, model_type, forward_step_func, process_non_loss_data_func, app_metrics = input_args
from megatron.training.training import setup_model_and_optimizer
# Model, optimizer, and learning rate.
timers('model-and-optimizer-setup', log_level=0).start(barrier=True)
app_metrics['app_build_optimizer_start_time'] = one_logger_utils.get_timestamp_in_ms()
model, optimizer, opt_param_scheduler = setup_model_and_optimizer(
model_provider, model_type)
timers('model-and-optimizer-setup').stop()
print_datetime('after model, optimizer, and learning rate '
'scheduler are built')
app_metrics['app_build_optimizer_finish_time'] = one_logger_utils.get_timestamp_in_ms()
config = get_model_config(model[0])
# Data stuff.
app_metrics['app_build_dataiters_start_time'] = one_logger_utils.get_timestamp_in_ms()
timers('train/valid/test-data-iterators-setup', log_level=0).start(
barrier=True)
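# With interleaved (virtual) pipeline parallelism each model chunk gets its own set of
# data iterators; the dualpipev schedule builds two sets; otherwise a single set is built.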
if args.virtual_pipeline_model_parallel_size is not None:
train_data_iterator = []
valid_data_iterator = []
test_data_iterator = []
for i in range(len(model)):
mpu.set_virtual_pipeline_model_parallel_rank(i)
iterators = build_train_valid_test_data_iterators(
train_valid_test_dataset_provider)
train_data_iterator.append(iterators[0])
valid_data_iterator.append(iterators[1])
test_data_iterator.append(iterators[2])
elif args.schedules_method == 'dualpipev':
train_data_iterator = []
valid_data_iterator = []
test_data_iterator = []
for _ in range(2):
iterators = build_train_valid_test_data_iterators(
train_valid_test_dataset_provider)
train_data_iterator.append(iterators[0])
valid_data_iterator.append(iterators[1])
test_data_iterator.append(iterators[2])
else:
train_data_iterator, valid_data_iterator, test_data_iterator \
= build_train_valid_test_data_iterators(
train_valid_test_dataset_provider)
timers('train/valid/test-data-iterators-setup').stop()
print_datetime('after dataloaders are built')
app_metrics['app_build_dataiters_finish_time'] = one_logger_utils.get_timestamp_in_ms()
# Track if training is enabled. Can only be done once args.do_train is assigned after dataloader is built.
one_logger_utils.track_config_flags(args.train_iters, args.skip_train, args.do_train,
args.do_valid, args.do_test, args.dataloader_type,
args.retro_project_dir, args.retro_cyclic_train_iters)
# Print setup timing.
print_rank_0('done with setup ...')
timers.log(['model-and-optimizer-setup',
'train/valid/test-data-iterators-setup'], barrier=True)
train_args = [forward_step_func,
model, optimizer, opt_param_scheduler,
train_data_iterator, valid_data_iterator, process_non_loss_data_func, config]
test_data_iterator_list = [test_data_iterator]
return train_args, test_data_iterator_list
def pretrain(train_valid_test_dataset_provider,
model_provider,
model_type,
forward_step_func,
process_non_loss_data_func=None,
extra_args_provider=None,
args_defaults={}):
"""Main training program.
This function will run the following in the order provided:
1) initialize Megatron.
2) setup model, optimizer and lr schedule using the model_provider.
3) call train_val_test_data_provider to get train/val/test datasets.
4) train the model using the forward_step_func.
Args:
train_valid_test_dataset_provider: a function that takes the size of
train/valid/test dataset and returns `train, valid, test` datasets.
model_provider: a function that returns a vanilla version of the
model. By vanilla we mean a simple model on cpu with no fp16 or ddp.
model_type: an enum that specifies the type of model being trained.
forward_step_func: a function that takes a `data iterator` and `model`,
and returns a `loss` scalar with a dictionary with key:values being
the info we would like to monitor during training, for example
`lm-loss: value`. We also require that this function add
`batch generator` to the timers class.
process_non_loss_data_func: a function to post process outputs of the
network. It can be used for dumping output tensors (e.g images) to
tensorboard. It takes `collected data`(list of tensors),
`current iteration index` and `tensorboard writer` as arguments.
extra_args_provider: a function that takes a parser and adds arguments
to it. It is used for programs to add their own arguments.
args_defaults: a dictionary from argument-name to argument-value. It
is used to set already-parsed arguments.
"""
# Initialize and get arguments, timers, and Tensorboard writer.
initialize_megatron(extra_args_provider=extra_args_provider,
args_defaults=args_defaults)
args = get_args()
timers = get_timers()
if args.enable_high_availability:
raise AssertionError("High availability feature do not support core_r0.8.0")
if args.log_progress:
append_to_progress_log("Starting job")
# Set pytorch JIT layer fusion options and warmup JIT functions.
set_jit_fusion_options()
# Adjust the startup time so it reflects the largest value.
# This will be closer to what scheduler will see (outside of
# image ... launches.
global _TRAIN_START_TIME
start_time_tensor = torch.tensor([_TRAIN_START_TIME],
dtype=torch.float,
device='cuda')
torch.distributed.all_reduce(start_time_tensor,
op=torch.distributed.ReduceOp.MIN)
_TRAIN_START_TIME = start_time_tensor.item()
app_metrics = {}
app_metrics['app_start_time'] = round(_TRAIN_START_TIME * 1000.0)
app_metrics['app_model_init_start_time'] = round(_TRAIN_START_TIME * 1000.0)
print_rank_0('time to initialize megatron (seconds): {:.3f}'.format(
time.time() - _TRAIN_START_TIME))
print_datetime('after megatron is initialized')
app_metrics['app_model_init_finish_time'] = one_logger_utils.get_timestamp_in_ms()
one_logger_utils.on_pretrain_start()
train_args, test_data_iterator_list = build_train_args(args, timers, train_valid_test_dataset_provider,
model_provider,
model_type, forward_step_func, process_non_loss_data_func,
app_metrics)
one_logger = get_one_logger()
one_logger and one_logger.log_metrics(app_metrics)
if not args.do_train and not args.do_valid and not args.do_test:
raise RuntimeError('no data loaded, the data path may be wrong.')
if not args.skip_train:
print_rank_0('training ...')
if args.dataloader_type == 'cyclic' and args.retro_project_dir:
assert args.retro_cyclic_train_iters is not None
args.train_iters = args.retro_cyclic_train_iters
print_rank_0("retro cyclic train iters : %d" % args.train_iters)
iteration = 0
if args.do_train and args.train_iters > 0:
if args.enable_high_availability:
from mindio_ttp.adaptor import tft_init_controller_processor, tft_register_processor, tft_train
tft_init_controller_processor(enable_tls=False, tls_option_top_path='')
tft_register_processor(train_valid_test_dataset_provider, model_provider, model_type)
iteration, num_floating_point_operations_so_far = tft_train(train_args, test_data_iterator_list)
else:
iteration, num_floating_point_operations_so_far = train(*train_args)
test_data_iterator = test_data_iterator_list[0]
forward_step_func, model, optimizer, opt_param_scheduler, train_data_iterator, valid_data_iterator, process_non_loss_data_func, config = train_args
print_datetime('after training is done')
if args.save and iteration != 0 and iteration % args.save_interval != 0:
save_checkpoint(iteration, model, optimizer, opt_param_scheduler,
num_floating_point_operations_so_far)
one_logger and one_logger.log_metrics({
'app_train_loop_finish_time': one_logger_utils.get_timestamp_in_ms()
})
else:
print_rank_0('skipping training (--skip-train is on) ...')
iteration = args.iteration
if args.do_valid:
prefix = f'iteration {iteration} on validation set'
evaluate_and_print_results(prefix, forward_step_func,
valid_data_iterator, model,
iteration, process_non_loss_data_func, config,
verbose=True, write_to_tensorboard=not args.skip_train)
if args.do_test:
prefix = f'iteration {iteration} on test set'
evaluate_and_print_results(prefix, forward_step_func,
test_data_iterator, model,
iteration, process_non_loss_data_func, config,
verbose=True, write_to_tensorboard=not args.skip_train)
one_logger and one_logger.log_metrics({
'app_finish_time': one_logger_utils.get_timestamp_in_ms()
})
one_logger_utils.finish()
def train(forward_step_func, model, optimizer, opt_param_scheduler,
train_data_iterator, valid_data_iterator,
process_non_loss_data_func, config):
"""Train the model function."""
args = get_args()
timers = get_timers()
one_logger = get_one_logger()
# Write args to tensorboard
write_args_to_tensorboard()
# Turn on training mode which enables dropout.
for model_module in model:
model_module.train()
# Tracking loss.
total_loss_dict = {}
# Iterations.
iteration = args.iteration
# Track E2E metrics at the start of training
one_logger_utils.on_train_start(iteration=iteration, consumed_train_samples=args.consumed_train_samples,
train_samples=args.train_samples, seq_length=args.seq_length,
train_iters=args.train_iters, save=args.save, async_save=args.async_save,
log_throughput=args.log_throughput,
num_floating_point_operations_so_far=args.num_floating_point_operations_so_far)
num_floating_point_operations_so_far = 0
# Setup some training config params
config.grad_scale_func = optimizer.scale_loss
config.timers = timers
if isinstance(model[0], DDP) and args.overlap_grad_reduce and config.no_sync_func is None:
assert config.no_sync_func is None, \
('When overlap_grad_reduce is True, config.no_sync_func must be None; '
'a custom no_sync_func is not supported when overlapping grad-reduce')
config.no_sync_func = [model_chunk.no_sync for model_chunk in model]
if len(model) == 1:
config.no_sync_func = config.no_sync_func[0]
if args.delay_grad_reduce:
config.grad_sync_func = [model_chunk.start_grad_sync for model_chunk in model]
if len(model) == 1:
config.grad_sync_func = config.grad_sync_func[0]
if args.overlap_param_gather and args.delay_param_gather:
config.param_sync_func = [lambda x: optimizer.finish_param_sync(model_index, x)
for model_index in range(len(model))]
if len(model) == 1:
config.param_sync_func = config.param_sync_func[0]
config.finalize_model_grads_func = finalize_model_grads
timers('interval-time', log_level=0).start(barrier=True)
print_datetime('before the start of training step')
report_memory_flag = True
exit = False
if args.manual_gc:
# Disable the default garbage collector and perform the collection manually.
# This is to align the timing of garbage collection across ranks.
assert args.manual_gc_interval >= 0, \
'Manual garbage collection interval should be larger than or equal to 0.'
gc.disable()
gc.collect()
total_flops = 0.0
num_microbatches = get_num_microbatches()
eval_duration = 0.0
eval_iterations = 0
def get_e2e_base_metrics():
"""Get base metrics values for one-logger to calculate E2E tracking metrics.
"""
return {
'iteration': iteration,
'train_duration': timers('interval-time').active_time(),
'eval_duration': eval_duration,
'eval_iterations': eval_iterations,
'total_flops': total_flops,
'num_floating_point_operations_so_far': num_floating_point_operations_so_far,
'consumed_train_samples': args.consumed_train_samples,
'world_size': args.world_size,
'seq_length': args.seq_length
}
# Cache into one-logger for callback
if one_logger:
with one_logger.get_context_manager():
one_logger.store_set('get_e2e_base_metrics', get_e2e_base_metrics)
if is_profile_enabled():
prof = get_profiler()
prof.start()
while iteration < args.train_iters:
# Update number of microbatches first without consistency check to decide if a
# checkpoint should be saved. If the number of microbatches is different
# from the previous iteration, save a checkpoint. Then run consistency check
# to make sure training configuration is still valid.
update_num_microbatches(args.consumed_train_samples, consistency_check=False)
if get_num_microbatches() != num_microbatches and iteration != 0:
assert get_num_microbatches() > num_microbatches, \
"number of microbatches should be increasing due to batch size rampup"
save_checkpoint_and_time(iteration, model, optimizer,
opt_param_scheduler,
num_floating_point_operations_so_far,
checkpointing_context=None)
num_microbatches = get_num_microbatches()
update_num_microbatches(args.consumed_train_samples, consistency_check=True)
args.curr_iteration = iteration
loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = \
train_step(forward_step_func,
train_data_iterator,
model,
optimizer,
opt_param_scheduler,
config)
iteration += 1
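# Samples consumed this iteration: data-parallel size * micro-batch size * number of
# microbatches, i.e. the global batch size.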
batch_size = mpu.get_data_parallel_world_size() * \
args.micro_batch_size * \
get_num_microbatches()
args.consumed_train_samples += batch_size
num_fp_ops = num_floating_point_operations(args, batch_size)
num_floating_point_operations_so_far += num_fp_ops
total_flops += num_fp_ops
# Logging.
loss_scale = optimizer.get_loss_scale().item()
params_norm = None
if args.log_params_norm:
params_norm = calc_params_l2_norm(model)
learning_rate = None
decoupled_learning_rate = None
for param_group in optimizer.param_groups:
if param_group['is_decoupled_lr']:
decoupled_learning_rate = param_group['lr']
else:
learning_rate = param_group['lr']
report_memory_flag = training_log(loss_dict, total_loss_dict,
learning_rate,
decoupled_learning_rate,
iteration, loss_scale,
report_memory_flag, skipped_iter,
grad_norm, params_norm, num_zeros_in_grad)
if args.enable_high_availability:
args.num_floating_point_operations_so_far = num_floating_point_operations_so_far
args.iteration = iteration
# Autoresume
if args.adlr_autoresume and \
(iteration % args.adlr_autoresume_interval == 0):
check_adlr_autoresume_termination(iteration, model, optimizer,
opt_param_scheduler)
# Evaluation
if args.eval_interval and iteration % args.eval_interval == 0 and \
args.do_valid:
timers('interval-time').stop()
if args.use_distributed_optimizer and args.overlap_param_gather:
optimizer.disable_pre_hook()
if args.manual_gc and args.manual_gc_eval:
# Collect all objects.
gc.collect()
prefix = 'iteration {}'.format(iteration)
timers('eval-time', log_level=0).start(barrier=True)
evaluate_and_print_results(prefix, forward_step_func,
valid_data_iterator, model,
iteration, process_non_loss_data_func,
config, False)
eval_duration += timers('eval-time').elapsed()
eval_iterations += args.eval_iters
timers('eval-time').stop()
one_logger_utils.track_e2e_metrics()
if args.manual_gc and args.manual_gc_eval:
# Collect only the objects created and used in evaluation.
gc.collect(generation=0)
if args.use_distributed_optimizer and args.overlap_param_gather:
optimizer.enable_pre_hook()
timers('interval-time', log_level=0).start(barrier=True)
# Checkpointing
saved_checkpoint = False
if args.exit_signal_handler:
signal_handler = get_signal_handler()
if any(signal_handler.signals_received()):
save_checkpoint_and_time(iteration, model, optimizer,
opt_param_scheduler,
num_floating_point_operations_so_far,
checkpointing_context=None)
print_datetime('exiting program after receiving SIGTERM.')
exit = True
break
if args.save and args.save_interval and \
iteration % args.save_interval == 0:
save_checkpoint_and_time(iteration, model, optimizer,
opt_param_scheduler,
num_floating_point_operations_so_far,
checkpointing_context=None)
saved_checkpoint = True
# Exiting based on duration
if args.exit_duration_in_mins:
train_time = (time.time() - _TRAIN_START_TIME) / 60.0
done_npu = torch.tensor(
[train_time > args.exit_duration_in_mins],
dtype=torch.int, device='npu')
torch.distributed.all_reduce(
done_npu, op=torch.distributed.ReduceOp.MAX)
done = done_npu.item()
if done:
if not saved_checkpoint:
save_checkpoint_and_time(iteration, model, optimizer,
opt_param_scheduler,
num_floating_point_operations_so_far,
checkpointing_context=None)
print_datetime('exiting program after {} minutes'.format(train_time))
exit = True
break
# Exiting based on iterations
if args.exit_interval and iteration % args.exit_interval == 0:
if args.save and not saved_checkpoint:
save_checkpoint_and_time(iteration, model, optimizer,
opt_param_scheduler,
num_floating_point_operations_so_far,
checkpointing_context=None)
torch.distributed.barrier()
print_datetime('exiting program at iteration {}'.format(iteration))
exit = True
break
if args.manual_gc:
if args.manual_gc_interval != 0 and iteration % args.manual_gc_interval == 0:
gc.collect()
if is_profile_enabled():
prof.step()
if is_profile_enabled():
prof.stop()
one_logger_utils.track_e2e_metrics()
# Flush TensorBoard and WandB writers.
writer = get_tensorboard_writer()
if writer:
writer.flush()
wandb_writer = get_wandb_writer()
if wandb_writer:
wandb_writer.finish()
# Close out pre-hooks if using distributed optimizer and overlapped param gather.
if args.use_distributed_optimizer and args.overlap_param_gather:
optimizer.disable_pre_hook()
# If any exit conditions (signal handler, duration, iterations) have been reached, exit.
if exit:
sys.exit()
return iteration, num_floating_point_operations_so_far
def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_rate, iteration,
loss_scale, report_memory_flag, skipped_iter,
grad_norm, params_norm, num_zeros_in_grad):
"""Log training information such as losses, timing, ...."""
args = get_args()
timers = get_timers()
writer = get_tensorboard_writer()
wandb_writer = get_wandb_writer()
one_logger = get_one_logger()
# Advanced, skipped, and NaN iterations.
advanced_iters_key = 'advanced iterations'
skipped_iters_key = 'skipped iterations'
nan_iters_key = 'nan iterations'
# Advanced iterations.
if not skipped_iter:
total_loss_dict[advanced_iters_key] = total_loss_dict.get(
advanced_iters_key, 0) + 1
else:
if advanced_iters_key not in total_loss_dict:
total_loss_dict[advanced_iters_key] = 0
# Skipped iterations.
total_loss_dict[skipped_iters_key] = total_loss_dict.get(
skipped_iters_key, 0) + skipped_iter
# Update losses and set nan iterations
got_nan = False
for key in loss_dict:
if not skipped_iter:
total_loss_dict[key] = total_loss_dict.get(
key, torch.tensor([0.0], dtype=torch.float, device='cuda')) + loss_dict[key]
else:
value = loss_dict[key].float().sum().item()
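# NaN is the only value that compares unequal to itself, so `value != value` detects it.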
is_nan = value == float('inf') or \
value == -float('inf') or \
value != value
got_nan = got_nan or is_nan
total_loss_dict[nan_iters_key] = total_loss_dict.get(
nan_iters_key, 0) + int(got_nan)
# Logging.
timers_to_log = [
'forward-backward',
'forward-compute',
'backward-compute',
'batch-generator',
'forward-recv',
'forward-send',
'backward-recv',
'backward-send',
'forward-send-forward-recv',
'forward-send-backward-recv',
'backward-send-forward-recv',
'backward-send-backward-recv',
'forward-backward-send-forward-backward-recv',
'layernorm-grads-all-reduce',
'embedding-grads-all-reduce',
'all-grads-sync',
'params-all-gather',
'optimizer-copy-to-main-grad',
'optimizer-unscale-and-check-inf',
'optimizer-clip-main-grad',
'optimizer-count-zeros',
'optimizer-inner-step',
'optimizer-copy-main-to-model-params',
'optimizer']
# Calculate batch size.
batch_size = args.micro_batch_size * args.data_parallel_size * \
get_num_microbatches()
# Track app tag & app tag ID
one_logger_utils.track_app_tag(batch_size, args.world_size, args.seq_length)
total_iterations = total_loss_dict[advanced_iters_key] + \
total_loss_dict[skipped_iters_key]
# Tensorboard values.
# Writing timers requires all the ranks to call.
if args.log_timers_to_tensorboard and \
(iteration % args.tensorboard_log_interval == 0):
timers.write(timers_to_log, writer, iteration,
normalizer=total_iterations)
if writer and (iteration % args.tensorboard_log_interval == 0):
if wandb_writer:
wandb_writer.log({'samples vs steps': args.consumed_train_samples},
iteration)
if args.log_learning_rate_to_tensorboard:
writer.add_scalar('learning-rate', learning_rate, iteration)
if args.decoupled_lr is not None:
writer.add_scalar('decoupled-learning-rate', decoupled_learning_rate, iteration)
writer.add_scalar('learning-rate vs samples', learning_rate,
args.consumed_train_samples)
if wandb_writer:
wandb_writer.log({'learning-rate': learning_rate}, iteration)
if args.log_batch_size_to_tensorboard:
writer.add_scalar('batch-size', batch_size, iteration)
writer.add_scalar('batch-size vs samples', batch_size,
args.consumed_train_samples)
if wandb_writer:
wandb_writer.log({'batch-size': batch_size}, iteration)
for key in loss_dict:
writer.add_scalar(key, loss_dict[key], iteration)
writer.add_scalar(key + ' vs samples', loss_dict[key],
args.consumed_train_samples)
if wandb_writer:
wandb_writer.log({key: loss_dict[key]}, iteration)
if args.log_loss_scale_to_tensorboard:
writer.add_scalar('loss-scale', loss_scale, iteration)
writer.add_scalar('loss-scale vs samples', loss_scale,
args.consumed_train_samples)
if wandb_writer:
wandb_writer.log({'loss-scale': loss_scale}, iteration)
if args.log_world_size_to_tensorboard:
writer.add_scalar('world-size', args.world_size, iteration)
writer.add_scalar('world-size vs samples', args.world_size,
args.consumed_train_samples)
if wandb_writer:
wandb_writer.log({'world-size': args.world_size}, iteration)
if grad_norm is not None:
writer.add_scalar('grad-norm', grad_norm, iteration)
writer.add_scalar('grad-norm vs samples', grad_norm,
args.consumed_train_samples)
if wandb_writer:
wandb_writer.log({'grad-norm': grad_norm}, iteration)
if num_zeros_in_grad is not None:
writer.add_scalar('num-zeros', num_zeros_in_grad, iteration)
writer.add_scalar('num-zeros vs samples', num_zeros_in_grad,
args.consumed_train_samples)
if wandb_writer:
wandb_writer.log({'num-zeros': num_zeros_in_grad}, iteration)
if params_norm is not None:
writer.add_scalar('params-norm', params_norm, iteration)
writer.add_scalar('params-norm vs samples', params_norm,
args.consumed_train_samples)
if wandb_writer:
wandb_writer.log({'params-norm': params_norm}, iteration)
if args.log_memory_to_tensorboard:
mem_stats = torch.cuda.memory_stats()
writer.add_scalar(
"mem-reserved-bytes",
mem_stats["reserved_bytes.all.current"],
iteration,
)
writer.add_scalar(
"mem-allocated-bytes",
mem_stats["allocated_bytes.all.current"],
iteration,
)
writer.add_scalar(
"mem-allocated-count",
mem_stats["allocation.all.current"],
iteration,
)
if args.num_experts is not None:
moe_loss_scale = 1 / get_num_microbatches()
track_moe_metrics(moe_loss_scale, iteration, writer, wandb_writer, total_loss_dict, args.moe_per_layer_logging)
if args.mtp_num_layers:
from mindspeed_llm.core.transformer.multi_token_prediction import MTPLossLoggingHelper
mtp_loss_scale = 1 / get_num_microbatches()
MTPLossLoggingHelper.track_mtp_metrics(
mtp_loss_scale, iteration, writer, wandb_writer, total_loss_dict
)
if iteration % args.log_interval == 0:
elapsed_time = timers('interval-time').elapsed(barrier=True)
elapsed_time_per_iteration = elapsed_time / total_iterations
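# Per-device throughput in TFLOP/s: FLOPs of one global batch divided by
# (iteration time * 1e12 * world size).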
throughput = num_floating_point_operations(args, batch_size) / (
elapsed_time_per_iteration * 10**12 * args.world_size)
one_logger_utils.track_e2e_metrics(args.log_throughput, throughput)
if args.log_timers_to_tensorboard:
if writer:
writer.add_scalar('iteration-time',
elapsed_time_per_iteration, iteration)
if wandb_writer:
wandb_writer.log({'iteration-time': elapsed_time_per_iteration},
iteration)
log_string = f" [{datetime.now(tz=timezone.utc).strftime('%Y-%m-%d %H:%M:%S')}]"
log_string += ' iteration {:8d}/{:8d} |'.format(
iteration, args.train_iters)
log_string += ' consumed samples: {:12d} |'.format(
args.consumed_train_samples)
log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
elapsed_time_per_iteration * 1000.0)
if args.log_throughput:
log_string += f' throughput per GPU (TFLOP/s/GPU): {throughput:.1f} |'
if args.log_timers_to_tensorboard:
if writer:
writer.add_scalar('throughput', throughput, iteration)
if wandb_writer:
wandb_writer.log({'throughput': throughput}, iteration)
assert learning_rate is not None
# decoupled_learning_rate should be non-None only on the first and last pipeline stages.
log_string += ' learning rate: {:.6E} |'.format(learning_rate)
if args.decoupled_lr is not None and (mpu.is_pipeline_first_stage(ignore_virtual=True) or
mpu.is_pipeline_last_stage(ignore_virtual=True)):
assert decoupled_learning_rate is not None
log_string += ' decoupled learning rate: {:.6E} |'.format(decoupled_learning_rate)
else:
assert decoupled_learning_rate is None
log_string += ' global batch size: {:5d} |'.format(batch_size)
for key in total_loss_dict:
if key not in [advanced_iters_key, skipped_iters_key,
nan_iters_key]:
avg = total_loss_dict[key].item() / \
float(max(1, total_loss_dict[advanced_iters_key]))
if avg > 0.0:
log_string += ' {}: {:.6E} |'.format(key, avg)
total_loss_dict[key] = torch.tensor([0.0], dtype=torch.float, device='cuda')
log_string += ' loss scale: {:.1f} |'.format(loss_scale)
if grad_norm is not None:
log_string += ' grad norm: {:.3f} |'.format(grad_norm)
if num_zeros_in_grad is not None:
log_string += ' num zeros: {:.1f} |'.format(num_zeros_in_grad)
if params_norm is not None:
log_string += ' params norm: {:.3f} |'.format(params_norm)
log_string += ' number of skipped iterations: {:3d} |'.format(
total_loss_dict[skipped_iters_key])
log_string += ' number of nan iterations: {:3d} |'.format(
total_loss_dict[nan_iters_key])
total_loss_dict[advanced_iters_key] = 0
total_loss_dict[skipped_iters_key] = 0
total_loss_dict[nan_iters_key] = 0
print_rank_last(log_string)
if report_memory_flag and learning_rate > 0.:
# Report memory after optimizer state has been initialized.
if torch.distributed.get_rank() == 0:
num_microbatches = get_num_microbatches()
report_theoretical_memory(args, num_microbatches=num_microbatches, verbose=True)
report_memory('(after {} iterations)'.format(iteration))
report_memory_flag = False
timers.log(timers_to_log, normalizer=args.log_interval)
return report_memory_flag