diff --git a/mindformers/dataset/dataloader/blended_megatron_dataloader.py b/mindformers/dataset/dataloader/blended_megatron_dataloader.py
index fbfed7d064df855bfee11d343403b3a7b8f7cfbc..c916cbf74b214c1d6a473f82c4e7dc9292452700 100644
--- a/mindformers/dataset/dataloader/blended_megatron_dataloader.py
+++ b/mindformers/dataset/dataloader/blended_megatron_dataloader.py
@@ -157,7 +157,12 @@ class MegatronDatasetBuilder:
             logger.info(f"This rank is {global_rank_id}, tensor parallel = {tp}, \
                 pipeline stage = {global_rank_id//(dp *tp )}, this rank will build empty data.")
             source = MockBlendedMegatron(blended_config, sizes[0])
-            gen_dataset = GeneratorDataset(source, column_names=source.cols(), shuffle=False)
+            gen_dataset = GeneratorDataset(
+                source,
+                column_names=source.cols(),
+                shuffle=False,
+                num_shards=num_shards,
+                shard_id=shard_id)  # keep the same sharding as the real dataset
             skip_barrier_controller(times=2)
         else:
             gen_dataset = build_gpt_dataset()
diff --git a/mindformers/trainer/base_trainer.py b/mindformers/trainer/base_trainer.py
index 26aba03869946b7ea7822ff0ec8830518f4eda5c..9feb84765b5fb8614eb80bb16501e79a09474546 100644
--- a/mindformers/trainer/base_trainer.py
+++ b/mindformers/trainer/base_trainer.py
@@ -65,7 +65,6 @@ from mindformers.core.callback.callback import (
     ColdHotExpertMonitor,
     TopkBiasBalanceCallback
 )
-from mindformers.dataset.dataloader.blended_megatron_dataloader import is_dataset_built_on_rank
 from mindformers.modules.seq_pipe import SequenceSplit
 from mindformers.utils.load_checkpoint_utils import get_load_path_after_hf_convert
 from ..core.config_args import ConfigArguments
@@ -799,23 +798,12 @@ class BaseTrainer:
             dataset_info = config.train_dataset.data_loader
 
             # reset dataset size to remove redundant data
-            ori_ds = dataset.get_dataset_size()
-            dataset.dataset_size = int(dataset_info.sizes[0]) // self.global_batch_size
-            cur_ds = dataset.get_dataset_size()
-            logger.info(f"Use BlendedMegatronDatasetDataLoader, reset dataset size {ori_ds} to {cur_ds}.")
+            dataset = dataset.take(int(dataset_info.sizes[0]) // self.global_batch_size)
+            logger.info(f"Use BlendedMegatronDatasetDataLoader, reset dataset size to {dataset.get_dataset_size()}.")
 
             # Sync assign eod compression arguments
             if self.config.train_dataset.data_loader.config.create_compressed_eod_mask:
                 self.config.model.model_config.use_eod_attn_mask_compression = True
-
-            # skip data for real dataset
-            if config.data_skip_steps or config.resume_training:
-                rank_id = get_real_rank()
-                parallel_mode = ms.context.get_auto_parallel_context("parallel_mode")
-                if parallel_mode in ("semi_auto_parallel", "auto_parallel") and not is_dataset_built_on_rank():
-                    # not skip fake data in megatron dataset
-                    config.ignore_data_skip = True
-                    logger.info(f"local rank id: {rank_id}, ignore data skip: {config.ignore_data_skip}.")
         return dataset, config
 
     @staticmethod
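
Note on the first hunk (not part of the patch): passing num_shards/shard_id to the mock GeneratorDataset makes the empty-data ranks report the same sharded size as the ranks that build the real dataset, so per-rank step counts stay aligned. The sketch below illustrates the sharding behavior with a made-up MockSource class and made-up sizes; only mindspore.dataset.GeneratorDataset and its num_shards/shard_id parameters are the real API.

    # Standalone sketch: how GeneratorDataset sharding affects the reported size.
    # MockSource stands in for MockBlendedMegatron; all names and sizes are illustrative.
    import numpy as np
    from mindspore.dataset import GeneratorDataset

    class MockSource:
        """Yields zero-filled samples, mimicking the empty data built on non-data ranks."""
        def __init__(self, size, seq_len=8):
            self._size = size
            self._seq_len = seq_len

        def __getitem__(self, index):
            return (np.zeros(self._seq_len, dtype=np.int32),)

        def __len__(self):
            return self._size

        @staticmethod
        def cols():
            return ["input_ids"]

    source = MockSource(size=100)
    # Without num_shards/shard_id every rank would iterate all 100 samples; with 4 shards
    # each rank sees 25, matching the sharding of the real blended Megatron dataset.
    sharded = GeneratorDataset(source, column_names=MockSource.cols(), shuffle=False,
                               num_shards=4, shard_id=0)
    print(sharded.get_dataset_size())  # 25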
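
Note on the base_trainer.py hunk (not part of the patch): dataset.take(n) inserts a real pipeline node that stops iteration after n elements, whereas overwriting the dataset_size attribute only changed the reported size. A minimal sketch of the difference, using a NumpySlicesDataset with made-up values purely for illustration:

    # Standalone sketch: truncating a MindSpore dataset with take().
    from mindspore.dataset import NumpySlicesDataset

    ds = NumpySlicesDataset(list(range(10)), column_names=["data"], shuffle=False)

    # take(4) returns a new dataset that yields at most 4 elements, so the iterator
    # really stops early instead of merely reporting a smaller size.
    truncated = ds.take(4)
    print(truncated.get_dataset_size())  # 4
    print([int(d["data"]) for d in truncated.create_dict_iterator(output_numpy=True)])  # [0, 1, 2, 3]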