use_legacy: False # Whether to use the legacy model architecture; `False` means the mcore model is used.
pretrained_model_dir: "path/to/model_dir" # The directory path where the Hugging Face model configuration is located.
seed: 0 # Set the global seed. For details, refer to https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.set_seed.html
output_dir: './output' # Set the path where logs, checkpoints, strategy files, etc. are saved
load_checkpoint: '' # File or folder path of the weights to load
auto_trans_ckpt: True # If True, automatically transform the weights in `load_checkpoint` to match the distributed parallel strategy when loading
resume_training: False # Enable resumable training after breakpoint. For details, refer to https://www.mindspore.cn/mindformers/docs/en/master/feature/resume_training.html#resumable-training
run_mode: 'train' # Set the running mode of the model: `train`, `finetune` or `predict`
use_parallel: True # Enable parallel mode
load_ckpt_format: 'safetensors' # The format of loading checkpoint, either `ckpt` or `safetensors`
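# A minimal sketch (the path below is a placeholder, not a real default) of how these top-level
# entries are typically combined when fine-tuning from an existing safetensors checkpoint:
#   run_mode: 'finetune'
#   load_checkpoint: '/path/to/checkpoint_dir'  # folder containing the safetensors weight files
#   load_ckpt_format: 'safetensors'
#   auto_trans_ckpt: True                       # let the framework reshard the weights for the parallel layout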
# dataset
train_dataset: &train_dataset
data_loader:
type: BlendedMegatronDatasetDataLoader
datasets_type: "GPTDataset"
sizes:
- 1000 # Number of training set data samples
- 0 # Number of test set samples; configuring this is currently not supported
- 0 # Number of evaluation set samples; configuring this is currently not supported
config: # GPTDataset configuration
seed: 1234 # Random seed for data sampling
split: "1, 0, 0" # Split ratio of the training, test, and evaluation sets; configuring this is currently not supported
seq_length: 4096 # The sequence length of the data returned by the dataset
eod_mask_loss: False # Whether to compute the loss at EOD tokens
reset_position_ids: False # Whether to reset position_ids at EOD tokens
create_attention_mask: True # Whether to return an attention mask
reset_attention_mask: False # Whether to reset the attention_mask at EOD tokens and return a stepped attention_mask
create_compressed_eod_mask: False # Whether to return the compressed attention mask
eod_pad_length: 128 # Set the length of the attention mask after compression
eod: 1 # Token ID of EOD in the dataset
pad: -1 # Token ID of pad in the dataset
data_path: # Sampling ratio and path of Megatron dataset
- '1' # Proportion of dataset sampling
- "/path/megatron_data" # The bin file path of the dataset (excluding the .bin suffix)
input_columns: [ "input_ids", "labels", "loss_mask", "position_ids", "attention_mask" ] # Set the input data columns for the training dataset
construct_args_key: [ "input_ids", "labels", "loss_mask", "position_ids", "attention_mask" ] # Set the dataset keys passed to the model's `construct` method, listed in the order of the `construct` input parameters; used when the model's argument order does not match the order of the dataset columns
num_parallel_workers: 8 # The number of parallel workers
python_multiprocessing: False # Whether to enable Python multiprocessing to improve data processing performance
drop_remainder: True # Whether to discard the last batch of data if it contains fewer samples than batch_size
numa_enable: False # Whether to enable NUMA binding
prefetch_size: 1 # Set the amount of data to prefetch
seed: 1234 # Random seed for dataset sampling. Megatron datasets use this value to randomly sample and concatenate samples. Default: `1234`
train_dataset_task:
type: CausalLanguageModelDataset # Set up the dataset class, which is used to encapsulate the data loading class and other related configurations.
dataset_config: *train_dataset # Typically set as a reference to `train_dataset`, containing all configuration entries for `train_dataset`.
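# A hedged example (weights and paths are illustrative placeholders) of blending two Megatron
# datasets by listing alternating sampling-weight / bin-file-prefix pairs in data_path:
#   data_path:
#     - '0.3'
#     - "/path/megatron_data_a"   # .bin suffix omitted, as above
#     - '0.7'
#     - "/path/megatron_data_b"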
# model config
model:
model_config:
qkv_concat: True # Whether to concatenate (fuse) the query, key, and value projection weights
input_sliced_sig: True # Whether the dataset has already been processed to the model's seq_length
compute_dtype: "bfloat16" # Data type used for computation
layernorm_compute_dtype: "float32" # The computation type of layernorm
softmax_compute_dtype: "float32" # The computation type of softmax
rotary_dtype: "float32" # The dtype of rotary embeddings.
router_dense_type: "float32" # Router score data type
params_dtype: "float32" # Parameter initialization type
offset: 0 # Offset of transformer layers assigned to each stage when the number of pipeline stages is set
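# A hedged sketch of how `offset` interacts with pipeline parallelism (the layer count below is
# illustrative, and the per-stage list form of `offset` is an assumption to verify against the docs):
# with a 30-layer model and pipeline_stage: 2, an even split gives 15 layers per stage, and
#   offset: [1, -1]
# would move one layer from the last stage to the first, i.e. 16 layers on stage 0 and 14 on stage 1.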
# recompute config
recompute_config:
recompute: False # Whether to enable recomputation
select_recompute: False # Whether to enable selective recomputation, which recomputes only the operators in the attention layer
parallel_optimizer_comm_recompute: False # Whether the AllGather communication introduced by optimizer parallelism is recomputed
mp_comm_recompute: False # Whether the communication operations introduced by model parallelism are recomputed
recompute_slice_activation: False # Whether to slice the Cell outputs stored in memory
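# A hedged note: `recompute` can also be given per pipeline stage as a list (an assumption based on
# common MindSpore Transformers configurations; verify against the recompute documentation), e.g.
#   recompute: [2, 2]
# to recompute 2 transformer layers in each of the two pipeline stages and trade compute for memory.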
# optimizer
optimizer:
type: AdamW # Set the optimizer class; the optimizer uses the computed gradients to update the model parameters during training
betas: [0.9, 0.999] # Exponential decay rates of `moment1` and `moment2`; each value must be in the range (0.0, 1.0)
eps: 1.e-6 # Term added to the denominator to improve numerical stability; must be greater than 0
weight_decay: 0.01 # Set the optimizer weight decay coefficient
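# For reference, these hyperparameters feed the standard decoupled AdamW update (written here as a
# worked equation; bias-correction details may differ slightly between implementations):
#   m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
#   v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
#   w_t = w_{t-1} - lr * ( m_t / (sqrt(v_t) + eps) + weight_decay * w_{t-1} )
# where (beta1, beta2) = betas, and eps and weight_decay are the values configured above.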
# lr schedule
lr_schedule:
type: CosineWithWarmUpLR # Set the lr_schedule class
learning_rate: 1.e-6 # Set the initial learning rate
lr_end: 1.e-6 # Final value of the learning rate
warmup_ratio: 0 # Ratio of warmup phase to total training steps
total_steps: -1 # Total training steps; -1 means the value is derived from the size of the dataset
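# A hedged worked example (step counts are illustrative): if total_steps resolves to 10000 and
# warmup_ratio were set to 0.01, the warmup phase would cover about 10000 * 0.01 = 100 steps,
# after which the learning rate decays from `learning_rate` toward `lr_end` along a cosine curve.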
# Default parallel configuration for 8 Ascend 910B devices
parallel_config:
data_parallel: &dp 2 # Set the data parallel size
model_parallel: 2 # Set the model parallel size
pipeline_stage: 2 # Set the number of pipeline stages
context_parallel: 1 # Set the context (sequence) parallel size
use_seq_parallel: True # Whether to enable sequence parallelism, corresponding to Megatron short sequence parallelism
micro_batch_num: 2 # Set the number of micro batches for pipeline parallelism; when `parallel_config.pipeline_stage` is greater than 1, it should satisfy `parallel_config.micro_batch_num` >= `parallel_config.pipeline_stage`
vocab_emb_dp: False # Shard the embedding table in data parallel (True) or model parallel (False) mode; if True, the embedding lookup is sharded along the data parallel dimension
# When model parallel is greater than 1, setting micro_batch_interleave_num=2 may accelerate training.
micro_batch_interleave_num: 1 # Set the number of interleaved micro-batch splits per training step (multi-copy parallelism); this value is also used when calculating the actual loss
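# Sanity check, derived purely from the values above: the parallel dimensions must multiply to the
# device count, i.e. data_parallel * model_parallel * pipeline_stage * context_parallel
# = 2 * 2 * 2 * 1 = 8, matching the 8-device setup this default configuration targets.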
# callbacks
callbacks:
- type: CheckpointMonitor # Set the callbacks class
prefix: "llm" # Set the prefix for saving file names
save_checkpoint_steps: 5000 # Set the number of interval steps for saving model weights
keep_checkpoint_max: 1 # Set the maximum number of model weight files to keep; if the number of weight files in the save path exceeds this value, files are deleted starting from the earliest created so that the total does not exceed `keep_checkpoint_max`
integrated_save: False # Whether to aggregate the sharded weights into a single file when saving in parallel mode
async_save: False # Whether to save the model weight files asynchronously
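# A worked example of the callback settings above: with save_checkpoint_steps: 5000 and
# keep_checkpoint_max: 1, a weight file prefixed with "llm" is written every 5000 steps and only
# the most recently saved file is kept; older files are deleted automatically.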
# parallel context config
parallel:
parallel_mode: 1 # 0-data parallel, 1-semi-auto parallel, 2-auto parallel, 3-hybrid parallel
gradients_mean: False # Whether to execute the averaging operator after the gradient AllReduce. Typically set to `False` in semi-automatic parallel mode and `True` in data parallel mode
enable_alltoall: False # Whether to allow generation of the AllToAll communication operator during communication. Typically set to `True` only in MoE scenarios; the default value is `False`
full_batch: False # Whether to load the full batch of data from the dataset in parallel mode. Setting it to `True` means all ranks will load the full batch of data. Setting it to `False` means each rank will only load the corresponding batch of data
dataset_strategy: [[*dp, 1], [*dp, 1], [*dp, 1], [*dp, 1], [*dp, 1, 1, 1]] # Only supports `List of List` type and is effective only when `full_batch=False`. The number of sublists in the list must be equal to the length of `train_dataset.input_columns`. Each sublist in the list must have the same shape as the data returned by the dataset. Generally, data parallel splitting is done along the first dimension, so the first dimension of the sublist should be configured to match `data_parallel`, while the other dimensions should be set to `1`
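# For reference, the five sublists above correspond, in order, to the five entries of
# `train_dataset.input_columns`: input_ids, labels, loss_mask and position_ids are 2-D tensors
# split along the batch (data parallel) dimension, while attention_mask is 4-D (typically
# [batch, 1, seq, seq]; the exact shape is an assumption here), hence the extra `1` dimensions.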
search_mode: "sharding_propagation" # Set fully-automatic parallel strategy search mode, options are `recursive_programming`, `dynamic_programming` and `sharding_propagation`, only works in fully-automatic parallel mode, experimental interface
strategy_ckpt_config:
save_file: "./ckpt_strategy.ckpt" # Path where the parallel strategy file is saved
only_trainable_params: False # Whether to save (or load) the slicing strategy information for trainable parameters only; the default is True. Set this parameter to `False` when the network contains frozen parameters that still need to be sliced
enable_parallel_optimizer: True # Whether to enable optimizer parallelism: 1. in data parallel mode, model weight parameters are sliced by the number of devices; 2. in semi-automatic parallel mode, model weight parameters are sliced by `parallel_config.data_parallel`
parallel_optimizer_config:
gradient_accumulation_shard: False # Set whether the accumulated gradient variable is sharded along the data-parallel dimension; effective only if `enable_parallel_optimizer=True`
parallel_optimizer_threshold: 64 # Set the parameter size threshold below which weights are not sharded by the optimizer; effective only if `enable_parallel_optimizer=True`
# mindspore context init config
context:
mode: 0 # 0--Graph Mode; 1--Pynative Mode
device_target: "Ascend" # Set the backend execution device. MindSpore Transformers is only supported on `Ascend` devices
max_device_memory: "58GB" # Set the maximum memory available to the device, in the format "xxGB"; the default value is `1024GB`
save_graphs: False # Whether to save the compilation graphs during execution
save_graphs_path: "./graph" # Path for saving the compilation graphs
memory_optimize_level: "O1" # The memory optimize level
jit_config:
jit_level: "O0" # The jit level, could be O0, O1 or O2
ascend_config:
parallel_speed_up_json_path: "path/to/parallel_speed_up.json" # The path to the parallel speed-up JSON file; for the configuration format, refer to `parallel_speed_up.json`
# trainer config
trainer:
type: CausalLanguageModelingTrainer # Set the trainer class, usually different models for different application scenarios will set different trainer classes
model_name: 'llm' # Set the model name in the format '{name}_xxb', indicating a certain specification of the model
# runner config
runner_config:
epochs: 2 # Set the number of training epochs
batch_size: 1 # Set the sample size of the batch data, which overrides the `batch_size` in the dataset configuration
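# A hedged estimate of the training length (assuming the usual convention that the global batch
# size is batch_size * data_parallel * micro_batch_num; verify this for your version):
# 1 * 2 * 2 = 4 samples per step, so the 1000 training samples configured in
# `train_dataset.data_loader.sizes` give roughly 1000 / 4 = 250 steps per epoch, about 500 steps
# over the 2 epochs configured here.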
# wrapper cell config
runner_wrapper:
type: MFTrainOneStepCell # Set the wrapper class; generally set to 'MFTrainOneStepCell'
scale_sense: 1.0 # Gradient scaling configuration
use_clip_grad: True # Whether to enable gradient clipping; enabling it helps avoid divergence when the backward gradients are too large
profile: False # Whether to enable the performance analysis tool
profile_start_step: 1 # Set the number of steps to start collecting performance data
profile_stop_step: 10 # Set the number of steps to stop collecting performance data
init_start_profile: False # Set whether to turn on collecting performance data when the Profiler is initialized; this parameter does not take effect when `profile_start_step` is set. This parameter needs to be set to `True` when `profile_memory` is turned on
profile_communication: False # Set whether to collect communication performance data in multi-device training; this parameter has no effect in single-card training
profile_memory: True # Set whether to collect Tensor memory data
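# A minimal sketch (step values are illustrative) of collecting performance data only for a short
# window of training steps, using the keys defined above:
#   profile: True
#   profile_start_step: 5
#   profile_stop_step: 10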