From c0ffd4efc48c75d9103e6d8ea10d1b096d972854 Mon Sep 17 00:00:00 2001 From: niujunhao Date: Fri, 14 Nov 2025 15:43:59 +0800 Subject: [PATCH] add hf dataset streaming --- .../dataset/causal_language_model_dataset.py | 8 +- .../dataset/dataloader/hf_dataloader.py | 542 +++++++++++++++--- .../dataset/dataloader/mock_dataloader.py | 2 +- mindformers/trainer/base_trainer.py | 10 +- .../test_hf_dataloader_streaming.py | 137 +++++ 5 files changed, 609 insertions(+), 90 deletions(-) create mode 100644 tests/st/test_ut/test_dataset/test_dataloader/test_hf_dataloader/test_hf_dataloader_streaming.py diff --git a/mindformers/dataset/causal_language_model_dataset.py b/mindformers/dataset/causal_language_model_dataset.py index 1b3dc35e3..e5452511b 100644 --- a/mindformers/dataset/causal_language_model_dataset.py +++ b/mindformers/dataset/causal_language_model_dataset.py @@ -60,7 +60,8 @@ def dyn_batch_wrapper(*cols, divisor, remainder, pad_token_id=None): outputs = [] for col_idx, col in enumerate(columns): # set dynamic batch max length - max_length = max([len(sample) for sample in col]) + sample_len = [len(sample) for sample in col] + max_length = max(sample_len) if divisor and remainder: max_length = ((max_length - remainder - 1) // divisor + 1) * divisor + remainder else: @@ -385,6 +386,7 @@ class CausalLanguageModelDataset(BaseDataset): 'num_parallel_workers': dataset_config.get('num_parallel_workers', 1), 'column_names': dataset_config.input_columns, 'dataset_dir': dataset_config.data_loader.pop("dataset_dir", None), + 'batch_size': dataset_config.get('batch_size', 1), } dataloader = build_dataset_loader(dataset_config.data_loader, default_args=custom_dataloader_args) return dataloader @@ -482,7 +484,7 @@ class TokenCounter: filename = os.path.join(self.saved_directory, f"rank_{rank_id}_token_counts.csv") # Clear existing file content - with open(filename, 'w', newline='') as csvfile: + with open(filename, 'w', newline='', encoding='utf-8') as csvfile: _ = csv.writer(csvfile) set_safe_mode_for_file_or_dir(filename) @@ -517,7 +519,7 @@ class TokenCounter: rank_id = get_rank() filename = os.path.join(self.saved_directory, f"rank_{rank_id}_token_counts.csv") - with open(filename, mode='a', newline='') as csvfile: + with open(filename, mode='a', newline='', encoding='utf-8') as csvfile: csv_writer = csv.writer(csvfile) if not self.token_count_pairs_header_written: diff --git a/mindformers/dataset/dataloader/hf_dataloader.py b/mindformers/dataset/dataloader/hf_dataloader.py index df86b9d78..8c9fe1531 100644 --- a/mindformers/dataset/dataloader/hf_dataloader.py +++ b/mindformers/dataset/dataloader/hf_dataloader.py @@ -14,12 +14,16 @@ # ============================================================================ """HF DataLoader""" +import os import inspect +import json +import itertools from dataclasses import dataclass from typing import Optional, Union from copy import deepcopy import numpy as np import datasets +from datasets.distributed import split_dataset_by_node import mindspore as ms from mindspore import Tensor, ops @@ -30,6 +34,8 @@ from mindformers.tools.logger import logger from mindformers.tools.utils import get_real_group_size from mindformers.version_control import skip_barrier_controller import mindformers.dataset.handler as custom_process +from mindformers.tools.utils import get_real_rank +from mindformers.tools.utils import FILE_PERMISSION from .utils import is_dataset_built_on_rank from .mock_dataloader import BaseMockDataLoader @@ -65,51 +71,50 @@ def _is_pack_mode(config): for sub_process in config.process: if sub_process.get('type') == 'PackingHandler': return True - return config.create_attention_mask + return bool(getattr(config, "create_attention_mask", False)) def process_legacy_args(**kwargs): """Process and adapt legacy configuration arguments into the non-legacy format.""" - replace_args = dict() + replace_args = {} - # process handler + # Process handlers packing = kwargs.pop('packing', None) handler = kwargs.pop('handler', None) if handler and isinstance(handler, list): - replace_args['handler'] = list() + cur_handler = [] for sub_handler in handler: - if sub_handler.get('type') == 'AlpacaInstructDataHandler': + # Disable padding if packing or dynamic length is used padding = sub_handler.get('padding', True) if packing or sub_handler.get('is_dynamic', False): padding = False - alpaca_args = { + cur_handler.append({ 'type': 'AlpacaInstructDataHandler', 'seq_length': sub_handler.get('seq_length'), 'tokenizer': sub_handler.get('tokenizer'), 'padding': padding - } - replace_args['handler'].append(alpaca_args) - + }) elif sub_handler.get('type') == 'PackingHandler': pack_strategy = sub_handler.get('pack_strategy', packing) - packing_args = { + cur_handler.append({ 'type': 'PackingHandler', 'seq_length': sub_handler.get('seq_length'), 'pack_strategy': pack_strategy - } - replace_args['handler'].append(packing_args) - + }) else: - replace_args['handler'].append(sub_handler) + cur_handler.append(sub_handler) + replace_args['handler'] = cur_handler - # process adaptor config + # Process adaptor config adaptor_config = kwargs.pop('adaptor_config', None) - if adaptor_config and isinstance(adaptor_config, dict): - replace_args['create_compressed_eod_mask'] = adaptor_config.get('compress_mask', False) - replace_args['compressed_eod_mask_length'] = adaptor_config.get('eod_pad_length', 128) + if isinstance(adaptor_config, dict): + replace_args.update({ + 'create_compressed_eod_mask': adaptor_config.get('compress_mask', False), + 'compressed_eod_mask_length': adaptor_config.get('eod_pad_length', 128) + }) - # merge transformed arguments back into kwargs + # Merge updated arguments kwargs.update(replace_args) return kwargs @@ -159,6 +164,43 @@ class HFDataLoaderConfig: raise ValueError("`load` in HFDataLoader must be a dict, but got None.") +@dataclass +class HFStreamingConfig(HFDataLoaderConfig): + """ + Extended configuration for streaming Hugging Face datasets. + + Adds state management and checkpointing parameters for large-scale + streaming dataset loading. + + Attributes: + streaming (bool): Enable streaming mode. + size (int): Total dataset size across all shards. + dataset_state_dir (str): Directory path for saving dataset states. + save_step (int): Frequency (in steps) to save dataset state. + resume_step (int): Step number to resume dataset iteration from. + batch_size (int): Batch size used in iteration (for step calculations). + """ + + streaming: bool = False + + size: int = None + + dataset_state_dir: str = None + + save_step: int = None + + resume_step: int = None + + batch_size: int = 1 + + def __post_init__(self): + """Ensure mandatory streaming parameters are set.""" + if self.size is None or self.dataset_state_dir is None: + raise ValueError( + "`size` and `dataset_state_dir` must be provided when `streaming=True`." + ) + + @MindFormerRegister.register(MindFormerModuleType.DATASET_LOADER) class HFDataLoader: """ @@ -244,7 +286,7 @@ class HFDataLoader: skip_barrier_controller() # Ensure all ranks reach the same point before proceeding if config.use_broadcast_data: - cls._broadcast_dataset_info(dataset.source.dataset) + cls._broadcast_dataset_info(dataset.source) return dataset @@ -252,14 +294,16 @@ class HFDataLoader: def load_dataset(cls, config: HFDataLoaderConfig): """Load datasets, support HF dataset loading methods now.""" load_func_name = config.load.pop('load_func', 'load_dataset') - logger.info(f" > use function `{load_func_name}` to load dataset, " - f"if you need to use other methods, modify the configuration `load_func`.") + logger.info( + f" > using `datasets.{load_func_name}` to load dataset. " + f"Pass 'load_func' in config.load to change this behavior." + ) load_func = getattr(datasets, load_func_name) dataset = load_func(**config.load) if config.load.get('split') is None and not isinstance(dataset, datasets.Dataset): - logger.info(" > `split` argument in load function is not set, use 'train' default.") + logger.info(" > `split` not provided for datasets, use 'train' default.") dataset = dataset.get('train') return dataset @@ -273,12 +317,16 @@ class HFDataLoader: for process_args in config.process: sub_process = deepcopy(process_args) if not isinstance(sub_process, dict) or 'type' not in sub_process: - raise ValueError(f"`process` in HFDataLoader must be a dict or list, " - f"and dict should contain the 'type' key.") + raise ValueError("`process` in HFDataLoader must be a dict or list, " + "and dict should contain the 'type' key.") process_type = sub_process.pop('type') # 1. processing dataset with methods in `handler` module if existed if hasattr(custom_process, process_type): + # In streaming mode, skip packing handler (packing requires random access). + if process_type == 'PackingHandler' and getattr(config, 'streaming', False): + logger.info(" > skipping PackingHandler because streaming mode is enabled.") + continue process_func = getattr(custom_process, process_type) logger.info(f" > processing `{process_type}` in `handler` module ...") dataset = cls._process_custom(process_func, dataset, **sub_process) @@ -304,7 +352,10 @@ class HFDataLoader: python_multiprocessing=False, num_parallel_workers=1): """Wrap source dataset with Mindspore Dataset.""" - hf_dataset = HFDataset(config, dataset) + if getattr(config, 'streaming', False): + hf_dataset = HFIterableDataset(config, dataset, num_shards, shard_id) + else: + hf_dataset = HFDataset(config, dataset) dataset = GeneratorDataset( hf_dataset, column_names=hf_dataset.column_names, @@ -346,12 +397,36 @@ class HFDataLoader: handler = kwargs.pop('handler', None) shuffle = kwargs.pop('shuffle', False) + streaming = kwargs.get('streaming', False) + size = kwargs.pop('size', None) + dataset_state_dir = kwargs.pop('dataset_state_dir', None) + save_step = kwargs.pop('save_step', None) + resume_step = kwargs.pop('resume_step', None) + batch_size = kwargs.pop('batch_size', 1) + # filter out invalid parameters passed from upper-level interfaces. invalid_args = ['dataset_dir', 'column_names', 'type'] for args in invalid_args: kwargs.pop(args, None) - config = HFDataLoaderConfig( + if streaming: + return HFStreamingConfig( + load=kwargs, + process=handler, + create_attention_mask=create_attention_mask, + create_compressed_eod_mask=create_compressed_eod_mask, + compressed_eod_mask_length=compressed_eod_mask_length, + use_broadcast_data=use_broadcast_data, + shuffle=shuffle, + streaming=streaming, + size=size, + dataset_state_dir=dataset_state_dir, + save_step=save_step, + resume_step=resume_step, + batch_size=batch_size, + ) + + return HFDataLoaderConfig( load=kwargs, process=handler, create_attention_mask=create_attention_mask, @@ -360,7 +435,6 @@ class HFDataLoader: use_broadcast_data=use_broadcast_data, shuffle=shuffle, ) - return config @staticmethod def _build_mock_dataset(config, num_shards=None, shard_id=None): @@ -378,11 +452,12 @@ class HFDataLoader: # Receive dataset metadata from main rank (rank 0) dataset_size, num_columns, seq_length = ops.Broadcast(0)(received_data) - mock_data = dict( - dataset_size=dataset_size.numpy()[0], # Total number of samples - num_columns=num_columns.numpy()[0], # Number of dataset columns - seq_length=seq_length.numpy()[0] # Sequence length of each sample - ) + mock_data = { + 'dataset_size': dataset_size.numpy()[0], # Total number of samples + 'num_columns': num_columns.numpy()[0], # Number of dataset columns + 'seq_length': seq_length.numpy()[0] # Sequence length of each sample + } + logger.info(f"\n > received dataset info: \n" f" size: {mock_data.get('dataset_size')} \n" f" num_columns: {mock_data.get('num_columns')} \n" @@ -406,12 +481,12 @@ class HFDataLoader: and use it to construct a mock dataset with the same structure. """ # iter dataset - sample = next(iter(dataset)) - sample_columns = list(sample.keys()) + sample = next(iter(itertools.tee(dataset, 1)[0])) + sample_columns = dataset.column_names dataset_size = len(dataset) num_columns = len(sample_columns) - seq_length = len(sample.get(sample_columns[0])) # Assuming the first column is `input_ids` + seq_length = len(sample[0]) # Assuming the first column is `input_ids` logger.info(f"\n > build real dataset completed, broadcast dataset info: \n" f" size: {dataset_size} \n" @@ -452,7 +527,7 @@ class HFDataset: else: self.column_names = ['input_ids', 'labels', 'loss_mask', 'position_ids', 'attention_mask'] else: - self.column_names = list(next(iter(dataset)).keys()) + self.column_names = list(next(iter(deepcopy(dataset))).keys()) self._check_columns(self.column_names) self.dataset_size = len(dataset) @@ -495,54 +570,19 @@ class HFDataset: tuple: Tokens, labels, loss mask, position IDs, and attention mask. """ sample = self.dataset[idx] - tokens = np.array(sample['input_ids']) - labels = np.array(sample['labels']) - actual_seq_len = np.array(sample['actual_seq_len']) - - seq_length = len(tokens) - - # loss mask - loss_mask = (labels != -100) - - # position ids and attention mask - position_ids = [] - if self.create_compressed_eod_mask: - attention_mask = None - else: - attention_mask = np.expand_dims(np.tril(np.ones((seq_length, seq_length))), axis=0) - pre_seq = 0 - for seq in actual_seq_len: - sub_pos = np.arange(seq - pre_seq, dtype=np.float32) - position_ids.append(sub_pos) - pre_seq = seq - if attention_mask is not None: - attention_mask[0, seq:, : seq] = 0 - - position_ids.append(np.arange(seq_length - actual_seq_len[-1], dtype=np.float32)) - position_ids = np.concatenate(position_ids) - - if self.create_compressed_eod_mask: - if self.compressed_eod_mask_length < len(actual_seq_len): - raise ValueError( - f"The actual_seq_len: {len(actual_seq_len)} in the dataset exceeds the " - f"compressed_eod_mask_length: {self.compressed_eod_mask_length}, please check data or " - f"increase the compressed_eod_mask_length.") - - actual_seq_len = np.pad( - actual_seq_len, (0, self.compressed_eod_mask_length - len(actual_seq_len)), - mode='constant', - constant_values=seq_length) - attention_mask = actual_seq_len - else: - # reverse attention mask - attention_mask = attention_mask < 0.5 - - return ( - tokens.astype(np.int32), - labels.astype(np.int32), - loss_mask.astype(np.int32), - position_ids.astype(np.int32), - attention_mask.astype(np.int32) + input_ids = sample.get('input_ids') + labels = sample.get('labels') + actual_seq_len = sample.get('actual_seq_len') + + if input_ids is None or labels is None or actual_seq_len is None: + raise ValueError("Packed dataset sample missing required keys: 'input_ids','labels' or 'actual_seq_len'.") + + return _get_packed_data( + input_ids, + labels, + actual_seq_len, + self.create_compressed_eod_mask, + self.compressed_eod_mask_length, ) @staticmethod @@ -609,3 +649,335 @@ class MockHFDataLoader(BaseMockDataLoader): data_dtypes = ['int32'] * len(mock_columns) super().__init__(mock_columns, data_shapes, data_dtypes, dataset_size) + + +def _resume_hf_iterable_dataset(dataset, step): + """ + Resume a Hugging Face iterable dataset from a given training step. + """ + inner_dataset = dataset + max_depth = 20 # Prevent infinite recursion when traversing nested structures + cur_depth = 0 + + # Traverse into nested dataset wrappers to find the actual HF dataset source + while not hasattr(inner_dataset, 'source'): + if isinstance(inner_dataset, list): + inner_dataset = inner_dataset[0] + elif hasattr(inner_dataset, 'children'): + inner_dataset = inner_dataset.children + + cur_depth += 1 + if cur_depth >= max_depth: + # Safety check: stop if nesting is unexpectedly deep + return + + source = inner_dataset.source + # Ensure the dataset supports `_load_state` before assigning resume step + if not hasattr(source, '_load_state'): + return + + logger.info(f"Set HFIterableDataset resume_step={step}") + # Set the resume step + source.resume_step = step + + +class HFIterableDataset: + """ + Iterable wrapper for streaming HF datasets used with MindSpore GeneratorDataset. + + This class supports: + - sharding via `datasets.distributed.split_dataset_by_node` when `num_shards` is provided; + - packing multiple samples into a fixed `seq_length` when PackingHandler is enabled; + - saving and loading dataset iterator state for resumable streaming. + + Notes on __getitem__: + - MindSpore may call __getitem__ with arbitrary indices; we treat each call as "advance once" + and manage an internal iterator to produce the next sample or packed batch. + - `idx` is not treated as a strict sample index; it's used only for a few control actions + (e.g., first-call resume). This keeps the iterator behavior stable under MindSpore's expectations. + """ + + def __init__(self, config, dataset, num_shards, shard_id): + self.size = config.size + self.state_dir = config.dataset_state_dir + self.save_step = config.save_step + self.resume_step = config.resume_step + + self.create_compressed_eod_mask = config.create_compressed_eod_mask + self.compressed_eod_mask_length = config.compressed_eod_mask_length + self.pack_config = self._get_pack_config(config) + + # Define column names based on whether packed mode is used + if self.pack_config: + if config.create_compressed_eod_mask: + self.column_names = ['input_ids', 'labels', 'loss_mask', 'position_ids', 'actual_seq_len'] + else: + self.column_names = ['input_ids', 'labels', 'loss_mask', 'position_ids', 'attention_mask'] + else: + self.column_names = list(next(iter(deepcopy(dataset))).keys()) + + self.shard_id = shard_id if shard_id is not None else 0 + if num_shards: + dataset = split_dataset_by_node(dataset, shard_id, num_shards) + per_shard_size = config.size // num_shards + else: + per_shard_size = config.size + self.batch_size = config.batch_size + self.max_iters = per_shard_size - (per_shard_size % self.batch_size) + logger.info(f"HFIterableDataset: per_shard_size={per_shard_size}, " + f"batch_size={self.batch_size}, max_iters={self.max_iters}") + + self.source = deepcopy(dataset) + self.dataset = None + self.dataset_iter = None + + # iteration state + self.cur_epoch = 1 + self.cur_iter = 0 + self.cur_step = 0 + self.global_step = 0 + self.total_steps = self.max_iters // self.batch_size + + # initialize iterator state + self._init_state() + + self.local_rank = get_real_rank() + + def __getitem__(self, idx): + """ + Return the next sample (or packed sample) for the generator. + + Important: + - `idx` is not treated as absolute index. MindSpore may call __getitem__ with + many worker-specific indices; we only use idx for first-call resume behaviour. + """ + if int(idx) == self.shard_id or int(idx) == 0: + self.cur_iter = 0 + self.cur_step = 0 + self._init_state() + + if self.resume_step and int(idx) > self.shard_id: + self._load_state(self.resume_step) + self.resume_step = None + + if self.pack_config: + sample = self._query_packed_samples() + else: + sample = self._query_single_sample() + + # update internal counters + self.cur_iter += 1 + if (self.cur_iter - 1) % self.batch_size == 0: + self.cur_step += 1 + self.global_step = (self.cur_epoch - 1) * self.total_steps + self.cur_step + if self.cur_iter >= self.max_iters: + self.cur_epoch += 1 + + if self.save_step and self.global_step % self.save_step == 0: + self._save_state(self.global_step) + + return sample + + def _query_single_sample(self): + """ + Fetch one sample from the underlying iterator, re-initializing if iterator exhausted. + Returns a tuple of column values (ordered by self.column_names). + """ + sample = None + if self.cur_iter >= self.max_iters: + is_init_state = True + else: + is_init_state = False + try: + sample = tuple(next(self.dataset_iter).values()) + except StopIteration: + is_init_state = True + + if is_init_state: + self._init_state() + sample = tuple(next(self.dataset_iter).values()) + return sample + + def _load_state(self, step): + """ + Load dataset iterator state and training counters from JSON saved by `_save_state`. + """ + state_path = f"{self.state_dir}/step_{step}/dataset_state_rank{self.local_rank}.json" + if not os.path.exists(state_path): + raise FileNotFoundError(f"Dataset state file not found: {state_path}") + logger.info(f'Load dataset state form {state_path}.') + + with open(state_path, 'r', encoding='utf-8') as fp: + state_dict = json.load(fp) + train_state = state_dict.pop('_train_state') + self.cur_iter = train_state.get('cur_iter') + self.cur_epoch = train_state.get('cur_epoch') + self.cur_step = train_state.get('cur_step') + self.global_step = train_state.get('global_step') + + # reinitialize dataset and load internal state if available + self._init_state() + self.dataset.load_state_dict(state_dict) + + def _save_state(self, step): + """ + Save the dataset state and current training counters to JSON. + The saved file layout: + /step_/dataset_state_rank.json + """ + state_path = f"{self.state_dir}/step_{step}/dataset_state_rank{self.local_rank}.json" + os.makedirs(os.path.dirname(state_path), exist_ok=True) + state_dict = self.dataset.state_dict() + state_dict['_train_state'] = { + 'cur_iter': self.cur_iter, + 'cur_epoch': self.cur_epoch, + 'cur_step': self.cur_step, + 'global_step': self.global_step, + } + flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC + file = os.open(state_path, flags, FILE_PERMISSION) + with os.fdopen(file, 'w', encoding='utf-8') as fp: + json.dump(state_dict, fp, indent=2) + logger.info(f'Save dataset state to {state_path}.') + + def _init_state(self): + """Deepcopy the source and create a fresh iterator for iteration.""" + logger.info( + "Initialize the iterator state, which may occur either at the start of training " + "or when the iterator is exhausted.") + self.dataset = deepcopy(self.source) + self.dataset_iter = iter(self.dataset) + + def _query_packed_samples(self): + """ + Build a packed sequence by concatenating multiple samples until `seq_length` tokens are reached. + """ + seq_length = int(self.pack_config.get('seq_length')) + pack_iter = itertools.tee(self.dataset_iter, 1)[0] + input_ids = [] + labels = [] + actual_seq_len = [0] + + while len(input_ids) < seq_length: + try: + candidate = next(pack_iter) + except StopIteration: + self._init_state() + break + + cur_input_ids = candidate.get('input_ids') + cur_labels = candidate.get('labels') + cur_seq_length = len(cur_input_ids) + actual_seq_len[-1] + if cur_seq_length > seq_length: + # cannot accept candidate with overflow; stop packing + break + + # accept candidate: append content and advance real iterator + input_ids.extend(cur_input_ids) + labels.extend(cur_labels) + actual_seq_len.append(cur_seq_length) + next(self.dataset_iter) + + # padding if necessary + pad_length = seq_length - len(input_ids) + if pad_length > 0: + input_ids.extend([0] * pad_length) + labels.extend([-100] * pad_length) + del pack_iter + + data = _get_packed_data( + input_ids, + labels, + actual_seq_len, + self.create_compressed_eod_mask, + self.compressed_eod_mask_length, + ) + return data + + def __len__(self): + """Return total number of samples in the dataset.""" + return self.size + + @staticmethod + def _get_pack_config(config): + """ + Extract the configuration dictionary for `PackingHandler`. + """ + if isinstance(config.process, list): + for sub_process in config.process: + if sub_process.get('type') == 'PackingHandler': + return deepcopy(sub_process) + return {} + + +def _get_packed_data( + tokens, + labels, + actual_seq_len, + create_compressed_eod_mask: bool = False, + compressed_eod_mask_length: int = None, +): + """ + Convert packed sequence components into final arrays required by training. + + Args: + tokens (list or np.ndarray): flattened token ids for concatenated subsequences (length == seq_length). + labels (list or np.ndarray): flattened labels for concatenated subsequences (length == seq_length). + actual_seq_len (list or np.ndarray): cumulative end positions for subsequences, starting from 0. + Example: [0, 5, 12] means first subseq length 5, second subseq length 7 (cumulative 12). + create_compressed_eod_mask (bool): if True, return `attention_mask` as compressed EOD mask array. + compressed_eod_mask_length (int): maximum number of subsequences to store in compressed mask. + + Returns: + tuple(tokens_np, labels_np, loss_mask_np, position_ids_np, attention_mask_np_or_actual_seq_len) + All returned arrays are numpy arrays with dtype `int32`. + """ + tokens = np.array(tokens) + labels = np.array(labels) + actual_seq_len = np.array(actual_seq_len) + + seq_length = len(tokens) + + # loss mask + loss_mask = labels != -100 + + # position ids and attention mask + position_ids = [] + if create_compressed_eod_mask: + attention_mask = None + else: + attention_mask = np.expand_dims(np.tril(np.ones((seq_length, seq_length))), axis=0) + pre_seq = 0 + for seq in actual_seq_len: + sub_pos = np.arange(seq - pre_seq, dtype=np.float32) + position_ids.append(sub_pos) + pre_seq = seq + if attention_mask is not None: + attention_mask[0, seq:, : seq] = 0 + + position_ids.append(np.arange(seq_length - actual_seq_len[-1], dtype=np.float32)) + position_ids = np.concatenate(position_ids) + + if create_compressed_eod_mask: + if compressed_eod_mask_length < len(actual_seq_len): + raise ValueError( + f"The actual_seq_len: {len(actual_seq_len)} in the dataset exceeds the " + f"compressed_eod_mask_length: {compressed_eod_mask_length}, please check data or " + f"increase the compressed_eod_mask_length.") + + actual_seq_len = np.pad( + actual_seq_len, (0, compressed_eod_mask_length - len(actual_seq_len)), + mode='constant', + constant_values=seq_length) + attention_mask = actual_seq_len + else: + # reverse attention mask + attention_mask = attention_mask < 0.5 + + return ( + tokens.astype(np.int32), + labels.astype(np.int32), + loss_mask.astype(np.int32), + position_ids.astype(np.int32), + attention_mask.astype(np.int32) + ) diff --git a/mindformers/dataset/dataloader/mock_dataloader.py b/mindformers/dataset/dataloader/mock_dataloader.py index 5a99c7848..87b9f2816 100644 --- a/mindformers/dataset/dataloader/mock_dataloader.py +++ b/mindformers/dataset/dataloader/mock_dataloader.py @@ -78,7 +78,7 @@ class BaseMockDataLoader: Returns: list: A list of tensors corresponding to each column, in the order of mock_columns. """ - return [getattr(self, col) for col in self.mock_columns] + return tuple(getattr(self, col) for col in self.mock_columns) def __len__(self): """ diff --git a/mindformers/trainer/base_trainer.py b/mindformers/trainer/base_trainer.py index 4da99944e..e6b778331 100644 --- a/mindformers/trainer/base_trainer.py +++ b/mindformers/trainer/base_trainer.py @@ -68,6 +68,7 @@ from mindformers.modules.seq_pipe import SequenceSplit from mindformers.utils.load_checkpoint_utils import get_load_path_after_hf_convert from mindformers.checkpoint.checkpoint import load_checkpoint, CommonInfo from mindformers.checkpoint.utils import compile_model +from mindformers.dataset.dataloader.hf_dataloader import _resume_hf_iterable_dataset from ..core.config_args import ConfigArguments from .training_args import TrainingArguments from .utils import ( @@ -1010,7 +1011,11 @@ class BaseTrainer: self._check_input_sliced_sig(config, f"{dataloader_type} with packing") config.train_dataset.data_loader.create_attention_mask = True - # Must use broadcast opt level > 0 if broadcast is enabled + if dataloader_config.streaming and isinstance(config.callbacks, list): + for callback in config.callbacks: + if callback.get('type') == 'CheckpointMonitor': + dataloader_config.save_step = int(callback.get('save_checkpoint_steps')) + if dataloader_config.get('use_broadcast_data', True): self._set_dataset_broadcast_opt_level(config) @@ -1197,6 +1202,9 @@ class BaseTrainer: logger.info("dataset skip %d steps.", data_skip_steps) else: logger.info("ignore dataset skip.") + if config.resume_training: + resume_step = config.runner_config.initial_step + _resume_hf_iterable_dataset(dataset, resume_step) self.set_train_dataset(dataset) check_runner_config(config, dataset) diff --git a/tests/st/test_ut/test_dataset/test_dataloader/test_hf_dataloader/test_hf_dataloader_streaming.py b/tests/st/test_ut/test_dataset/test_dataloader/test_hf_dataloader/test_hf_dataloader_streaming.py new file mode 100644 index 000000000..740ca3502 --- /dev/null +++ b/tests/st/test_ut/test_dataset/test_dataloader/test_hf_dataloader/test_hf_dataloader_streaming.py @@ -0,0 +1,137 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""test hf dataloader streaming""" + +import os +import json +from unittest.mock import patch +import pytest + +from mindformers.tools.register.config import MindFormerConfig +from mindformers.dataset.dataloader.hf_dataloader import HFDataLoader +from .test_hf_dataloader import MockTokenizer + +WORK_DIR = os.path.dirname(os.path.abspath(__file__)) +JSON_DATASET_PATH = os.path.join(WORK_DIR, 'alpaca.json') + +DATASET_CONFIG = { + "data_loader": { + "type": "HFDataLoader", + "load_func": "load_dataset", + "path": "", + "data_files": "", + "streaming": True, + "size": 6, + "dataset_state_dir": f"{WORK_DIR}/saved_state", + "handler": None, + "use_broadcast_data": False, + "create_compressed_eod_mask": False, + "compressed_eod_mask_length": 128, + "shuffle": False, + } +} +GLOBAL_CONFIG = MindFormerConfig(**DATASET_CONFIG) + + +def generate_json(): + """generate alpaca samples and save to json""" + num_samples = 6 + sample = [{ + "instruction": "Explain why the following fraction is equivalent to 1/4", + "input": "4/16", + "output": "The fraction 4/16 is equivalent to 1/4 because both fractions represent the same value. " + "A fraction can be simplified by dividing both the numerator and the denominator by a common factor. " + "In this case, 4 is a common factor of both the numerator and the denominator of 4/16. " + "When we divide both by 4, we get 4/4 = 1 and 16/4 = 4, so the simplified fraction is 1/4. " + "Alternatively, we can think of this in terms of multiplication. For example, if we multiply " + "the numerator and denominator of the fraction 1/4 by 4, we get (1x4)/(4x4), or 4/16. Since " + "both fractions can be derived from the other through multiplication or division by the same number, " + "they represent the same value and are equivalent." + }] * (num_samples // 2) + sample += [{ + "instruction": "Give three tips for staying healthy.", + "input": "", + "output": "1. Eat a balanced and nutritious diet: Make sure your meals are inclusive of a variety of fruits " + "and vegetables, lean protein, whole grains, and healthy fats. This helps to provide your body " + "with the essential nutrients to function at its best and can help prevent chronic diseases." + "\n\n2. Engage in regular physical activity: Exercise is crucial for maintaining strong bones, " + "muscles, and cardiovascular health. Aim for at least 150 minutes of moderate aerobic exercise " + "or 75 minutes of vigorous exercise each week.\n\n3. Get enough sleep: Getting enough quality " + "sleep is crucial for physical and mental well-being. It helps to regulate mood, improve cognitive " + "function, and supports healthy growth and immune function. Aim for 7-9 hours of sleep each night." + }] * (num_samples // 2) + with open(JSON_DATASET_PATH, 'w', encoding='utf-8') as fp: + json.dump(sample, fp, indent=2) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_load_streaming_resume(): + """ + Feature: HFDataLoader load dataset with streaming=True + Description: HFDataLoader save and resume in streaming mode + Expectation: success + """ + generate_json() + GLOBAL_CONFIG.data_loader.path = 'json' + GLOBAL_CONFIG.data_loader.data_files = JSON_DATASET_PATH + GLOBAL_CONFIG.data_loader.save_step = 4 + dataloader = HFDataLoader(**GLOBAL_CONFIG.data_loader, python_multiprocessing=False) + + resume_target = None + for idx, sample in enumerate(dataloader.source): + resume_target = sample + if idx >= 3: + break + + GLOBAL_CONFIG.data_loader.resume_step = 4 + dataloader = HFDataLoader(**GLOBAL_CONFIG.data_loader, python_multiprocessing=False) + dataloader.set_init_step(4) + resume_sample = dataloader.source[1] + + assert resume_sample[-1] == resume_target[-1] + + +@patch('mindformers.dataset.handler.base_handler.BaseInstructDataHandler.build_tokenizer', + return_value=MockTokenizer()) +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_load_streaming_pack(mock_tokenizer): + """ + Feature: HFDataLoader load dataset with streaming=True + Description: HFDataLoader pack samples in streaming mode + Expectation: success + """ + _ = mock_tokenizer + generate_json() + GLOBAL_CONFIG.data_loader.path = 'json' + GLOBAL_CONFIG.data_loader.data_files = JSON_DATASET_PATH + GLOBAL_CONFIG.data_loader.handler = [ + {"type": "AlpacaInstructDataHandler", "seq_length": 4096, "padding": False}, + {"type": "PackingHandler", "pack_strategy": "pack", "seq_length": 4096} + ] + dataloader = HFDataLoader(**GLOBAL_CONFIG.data_loader, python_multiprocessing=False) + + target_sample = dataloader.source[0] + assert len(target_sample) == 5 # ['input_ids', 'labels', 'loss_mask', 'position_ids', 'attention_mask'] + assert target_sample[-1].shape == (1, 4096, 4096) + + GLOBAL_CONFIG.data_loader.create_compressed_eod_mask = True + dataloader = HFDataLoader(**GLOBAL_CONFIG.data_loader, python_multiprocessing=False) + + target_sample = dataloader.source[0] + assert target_sample[-1].shape == (GLOBAL_CONFIG.data_loader.compressed_eod_mask_length,) -- Gitee