From 2afad4f450344ad13135de23a282d283d1e7331a Mon Sep 17 00:00:00 2001
From: fengliangjun
Date: Mon, 10 Mar 2025 09:31:23 +0800
Subject: [PATCH] add build dataset

---
 LICENSE                                      |   2 +-
 mindspeed_rl/__init__.py                     |  11 +-
 mindspeed_rl/datasets/__init__.py            |   3 +-
 mindspeed_rl/datasets/build_dataset.py       | 178 +++++++++++++++++++
 mindspeed_rl/datasets/dataloader.py          |   2 +-
 mindspeed_rl/datasets/formatter.py           |   2 +-
 mindspeed_rl/datasets/instruction_dataset.py |   2 +-
 mindspeed_rl/datasets/templates.py           |   5 +-
 mindspeed_rl/utils/loggers.py                |   1 +
 mindspeed_rl/utils/metrics.py                |   1 +
 tests/test_tools/dist_test.py                |   1 +
 tests/ut/datasets/__init__.py                |   2 +
 tests/ut/datasets/test_build_dataset.py      |  97 ++++++++++
 tests/ut/utils/test_logger.py                |   3 +-
 14 files changed, 297 insertions(+), 13 deletions(-)
 create mode 100644 mindspeed_rl/datasets/build_dataset.py
 create mode 100644 tests/ut/datasets/test_build_dataset.py

diff --git a/LICENSE b/LICENSE
index 1a43e3b5..b6b4894d 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,6 +1,6 @@
 MIT License
 
-Copyright (c) 2025 fengliangjun
+Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved.
 
 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal
diff --git a/mindspeed_rl/__init__.py b/mindspeed_rl/__init__.py
index 4b9003db..5e6bad19 100644
--- a/mindspeed_rl/__init__.py
+++ b/mindspeed_rl/__init__.py
@@ -2,9 +2,14 @@
 # Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved.
 
 from .config_cls import MegatronConfig
-from .utils import get_tokenizer, Metric, Loggers, WandbLogger
-from .datasets import DataLoader, InstructionDataset
+from .datasets import DataLoader, InstructionDataset, build_train_valid_test_datasets
+from .utils import (get_tokenizer, Metric, Loggers, WandbLogger,
+                    get_batch_metrices_mean, get_tune_attention_mask,
+                    num_floating_point_operations)
 
 __all__ = ['MegatronConfig',
            'get_tokenizer', 'Metric', 'Loggers', 'WandbLogger',
-           'DataLoader', 'InstructionDataset']
+           'DataLoader', 'InstructionDataset',
+           'build_train_valid_test_datasets',
+           'get_batch_metrices_mean', 'get_tune_attention_mask',
+           'num_floating_point_operations']
diff --git a/mindspeed_rl/datasets/__init__.py b/mindspeed_rl/datasets/__init__.py
index 6dfb2a6b..cee1de4c 100644
--- a/mindspeed_rl/datasets/__init__.py
+++ b/mindspeed_rl/datasets/__init__.py
@@ -3,5 +3,6 @@
 
 from .instruction_dataset import InstructionDataset
 from .dataloader import DataLoader
+from .build_dataset import build_train_valid_test_datasets
 
-__all__ = ['InstructionDataset', 'DataLoader']
+__all__ = ['InstructionDataset', 'DataLoader', 'build_train_valid_test_datasets']
diff --git a/mindspeed_rl/datasets/build_dataset.py b/mindspeed_rl/datasets/build_dataset.py
new file mode 100644
index 00000000..52a90dc3
--- /dev/null
+++ b/mindspeed_rl/datasets/build_dataset.py
@@ -0,0 +1,178 @@
+# coding=utf-8
+# Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved.
+
+from typing import Optional, Any, Tuple
+
+import numpy as np
+
+from mindspeed_rl.datasets.indexed_dataset import get_packed_indexed_dataset
+from mindspeed_rl.utils import Loggers
+
+
+logger = Loggers('build_train_valid_test_datasets')
+
+
+def build_train_valid_test_datasets(
+        data_prefix: str,
+        splits_string: str,
+        seq_length: int,
+        train_valid_test_num_samples: Tuple,
+        dataset_cls: Any,
+        tokenizer: Any,
+        parallel_state: Any,
+        full_shuffle_index: bool = False,
+        get_item_in_order: bool = False,
+        reset_position_ids: bool = False,
+        prompt_type: Optional[str] = None,
+        prompt_type_path: Optional[str] = None,
+        seed: int = 42,
+        extra_param: Optional[Any] = None
+):
+    """
+    Build train, valid, and test datasets.
+
+    Args:
+        data_prefix: path and prefix of the data files
+        splits_string: split ratios for train/valid/test data, e.g. '80,10,10'
+        seq_length: sequence length used for training
+        train_valid_test_num_samples: a triplet of sample counts, e.g. (3840, 5120, 1280)
+        dataset_cls: a dataset class implemented on top of BaseDataset
+        tokenizer: the tokenizer returned by get_tokenizer
+        parallel_state: the Megatron parallel state
+        full_shuffle_index: fully shuffle all sample indices
+        get_item_in_order: fetch items in order instead of using the shuffled index
+        reset_position_ids: reset position ids (support for TND training)
+        prompt_type: model-specific prompt template used for instruction tuning
+        prompt_type_path: the path to templates.json
+        seed: random seed
+        extra_param: extra parameters passed through to the dataset
+    """
+
+    logger.info(' > datasets target sizes (minimum size):')
+    logger.info('    train:      {}'.format(train_valid_test_num_samples[0]))
+    logger.info('    validation: {}'.format(train_valid_test_num_samples[1]))
+    logger.info('    test:       {}'.format(train_valid_test_num_samples[2]))
+
+
+    # Only a single dataset is supported.
+    all_train_datasets, all_valid_datasets, all_test_datasets = _build_train_valid_test_datasets(
+        data_prefix=data_prefix,
+        splits_string=splits_string,
+        seq_length=seq_length,
+        train_valid_test_num_samples=train_valid_test_num_samples,
+        tokenizer=tokenizer,
+        dataset_cls=dataset_cls,
+        parallel_state=parallel_state,
+        full_shuffle_index=full_shuffle_index,
+        get_item_in_order=get_item_in_order,
+        reset_position_ids=reset_position_ids,
+        prompt_type=prompt_type,
+        prompt_type_path=prompt_type_path,
+        seed=seed,
+        extra_param=extra_param
+    )
+
+    return all_train_datasets, all_valid_datasets, all_test_datasets
+
+
+def _build_train_valid_test_datasets(
+        data_prefix,
+        splits_string,
+        seq_length: int,
+        train_valid_test_num_samples,
+        tokenizer=None,
+        dataset_cls=None,
+        parallel_state=None,
+        full_shuffle_index=None,
+        get_item_in_order=None,
+        reset_position_ids=None,
+        prompt_type=None,
+        prompt_type_path=None,
+        seed=None,
+        extra_param=None
+):
+    """Build train, valid, and test datasets."""
+
+    # A dataset class is required; there is no default.
+    if dataset_cls is None:
+        raise ValueError("dataset_cls must be provided.")
+
+    # Target indexed dataset.
+    packed_indexed_dataset = get_packed_indexed_dataset(data_prefix=data_prefix)
+
+    total_num_of_documents = len(list(packed_indexed_dataset.datasets.values())[0])
+    splits = get_train_valid_test_split_(splits_string, total_num_of_documents)
+    logger.info(' > dataset split:')
+
+    logger.info("    train:      document indices in [{}, {}) total"
+                " of {} documents".format(splits[0], splits[1], splits[1] - splits[0]))
+    logger.info("    validation: document indices in [{}, {}) total"
+                " of {} documents".format(splits[1], splits[2], splits[2] - splits[1]))
+    logger.info("    test:       document indices in [{}, {}) total"
+                " of {} documents".format(splits[2], splits[3], splits[3] - splits[2]))
+
+    def build_dataset(index, name):
+        dataset = None
+        if splits[index + 1] > splits[index]:
+            documents = np.arange(start=splits[index], stop=splits[index + 1], dtype=np.int32)
+            # Instantiate the dataset dynamically from the provided dataset_cls.
+            dataset = dataset_cls(
+                parallel_state=parallel_state,
+                dataset_type='LLM',
+                data_prefix=data_prefix,
+                is_packed_data=True,
+                tokenizer=tokenizer,
+                seq_length=seq_length,
+                num_samples=train_valid_test_num_samples[index],
+                name=name,
+                documents=documents,
+                seed=seed,
+                full_shuffle_index=full_shuffle_index,
+                get_item_in_order=get_item_in_order,
+                reset_position_ids=reset_position_ids,
+                prompt_type=prompt_type,
+                prompt_type_path=prompt_type_path,
+                extra_param=extra_param
+            )
+
+        return dataset
+
+    train_dataset = build_dataset(0, 'train')
+    valid_dataset = build_dataset(1, 'valid')
+    test_dataset = build_dataset(2, 'test')
+
+    return train_dataset, valid_dataset, test_dataset
+
+
+def get_train_valid_test_split_(splits_string, size):
+    """Get dataset splits from a comma- or '/'-separated string list."""
+
+    splits = []
+    if splits_string.find(',') != -1:
+        splits = [float(s) for s in splits_string.split(',')]
+    elif splits_string.find('/') != -1:
+        splits = [float(s) for s in splits_string.split('/')]
+    else:
+        splits = [float(splits_string)]
+    while len(splits) < 3:
+        splits.append(0.)
+    splits = splits[:3]
+    splits_sum = sum(splits)
+
+    if splits_sum <= 0.0:
+        raise ValueError("splits_sum {} should be larger than 0".format(splits_sum))
+
+    splits = [split / splits_sum for split in splits]
+    splits_index = [0]
+    for index, split in enumerate(splits):
+        splits_index.append(splits_index[index] +
+                            int(round(split * float(size))))
+    diff = splits_index[-1] - size
+    for index in range(1, len(splits_index)):
+        splits_index[index] -= diff
+
+    if len(splits_index) != 4:
+        raise ValueError("the length of splits_index {} should be 4".format(len(splits_index)))
+    if splits_index[-1] != size:
+        raise ValueError("splits_index[-1] {} and size {} are supposed to be equal".format(splits_index[-1], size))
+    return splits_index
diff --git a/mindspeed_rl/datasets/dataloader.py b/mindspeed_rl/datasets/dataloader.py
index f4bbe4d7..8dc1e6b6 100644
--- a/mindspeed_rl/datasets/dataloader.py
+++ b/mindspeed_rl/datasets/dataloader.py
@@ -1,5 +1,5 @@
 # coding=utf-8
-# Copyright (c) 2025, HUAWEN. All rights rI CORPORATIOeserved.
+# Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved.
 
 import torch
 from transformers import DataCollatorForSeq2Seq
diff --git a/mindspeed_rl/datasets/formatter.py b/mindspeed_rl/datasets/formatter.py
index 44ce55f2..8be60ffe 100644
--- a/mindspeed_rl/datasets/formatter.py
+++ b/mindspeed_rl/datasets/formatter.py
@@ -1,5 +1,5 @@
 # coding=utf-8
-# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved.
+# Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved.
 
 import json
 import re
diff --git a/mindspeed_rl/datasets/instruction_dataset.py b/mindspeed_rl/datasets/instruction_dataset.py
index 6e8ba62f..029c5880 100644
--- a/mindspeed_rl/datasets/instruction_dataset.py
+++ b/mindspeed_rl/datasets/instruction_dataset.py
@@ -1,5 +1,5 @@
 # coding=utf-8
-# Copyright (c) 2025, HUAWEN. All rights rI CORPORATIOeserved.
+# Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved.
 
 import warnings
 from typing import Optional, Callable, List, Any
diff --git a/mindspeed_rl/datasets/templates.py b/mindspeed_rl/datasets/templates.py
index 6acaced2..d2105807 100644
--- a/mindspeed_rl/datasets/templates.py
+++ b/mindspeed_rl/datasets/templates.py
@@ -1,10 +1,9 @@
 # coding=utf-8
-# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved.
+# Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved.
 
 import os
 import re
 import json
-import logging
 from pathlib import Path
 from dataclasses import dataclass
 from enum import Enum, unique
@@ -17,8 +16,6 @@
 if TYPE_CHECKING:
     from .formatter import Formatter
 
-logger = logging.getLogger(__name__)
-
 cur_file_dir = Path(__file__).absolute().parent
 
 TEMPLATES_DIR = os.path.join(cur_file_dir.parent, "configs/templates.json")
diff --git a/mindspeed_rl/utils/loggers.py b/mindspeed_rl/utils/loggers.py
index da707bca..50eeaa4a 100644
--- a/mindspeed_rl/utils/loggers.py
+++ b/mindspeed_rl/utils/loggers.py
@@ -1,3 +1,4 @@
+# coding=utf-8
 # Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved.
 import os
 import logging
diff --git a/mindspeed_rl/utils/metrics.py b/mindspeed_rl/utils/metrics.py
index 638b7523..2b2d0892 100644
--- a/mindspeed_rl/utils/metrics.py
+++ b/mindspeed_rl/utils/metrics.py
@@ -1,3 +1,4 @@
+# coding=utf-8
 # Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved.
 
 from abc import ABC, abstractmethod
diff --git a/tests/test_tools/dist_test.py b/tests/test_tools/dist_test.py
index 1cc65970..6e5b065b 100644
--- a/tests/test_tools/dist_test.py
+++ b/tests/test_tools/dist_test.py
@@ -1,3 +1,4 @@
+# coding=utf-8
 # Copyright (c) Microsoft Corporation.
 #
 # This source code is licensed under the Apache license found in the
diff --git a/tests/ut/datasets/__init__.py b/tests/ut/datasets/__init__.py
index e69de29b..7f65ae5c 100644
--- a/tests/ut/datasets/__init__.py
+++ b/tests/ut/datasets/__init__.py
@@ -0,0 +1,2 @@
+# coding=utf-8
+# Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved.
\ No newline at end of file
diff --git a/tests/ut/datasets/test_build_dataset.py b/tests/ut/datasets/test_build_dataset.py
new file mode 100644
index 00000000..496bfa83
--- /dev/null
+++ b/tests/ut/datasets/test_build_dataset.py
@@ -0,0 +1,97 @@
+# coding=utf-8
+# Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved.
+
+from megatron.core import parallel_state
+
+from mindspeed_rl import get_tokenizer, build_train_valid_test_datasets
+from mindspeed_rl import InstructionDataset, DataLoader
+
+from tests.test_tools.dist_test import DistributedTest
+
+
+class TestBuildTrainValidTestDataset(DistributedTest):
+    world_size = 1
+
+    def test_build_nonpack_dataset(self):
+        tokenizer_directory = '/data/models/llama2-7b'
+        non_pack_data_prefix = '/data/datasets/nonpack/alpaca'
+
+        hf_tokenizer = get_tokenizer(tokenizer_directory)
+
+        train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
+            data_prefix=non_pack_data_prefix,
+            splits_string='80,10,10',
+            seq_length=1024,
+            train_valid_test_num_samples=(3840, 5120, 1280),
+            dataset_cls=InstructionDataset,
+            tokenizer=hf_tokenizer,
+            parallel_state=parallel_state,
+            full_shuffle_index=False,
+            get_item_in_order=False,
+            reset_position_ids=False,
+            prompt_type='llama2',
+            prompt_type_path='./configs/templates.json',
+            seed=42,
+            extra_param=None
+        )
+
+        train_dl = DataLoader(
+            dataset=train_ds,
+            parallel_state=parallel_state,
+            tokenizer=None,
+            num_workers=2,
+            tokenizer_padding_side='right',
+            pad_to_multiple_of=8,
+            variable_seq_lengths=False,
+            num_nextn_predict_layers=0,
+            micro_batch_size=1,
+            comsumed_samples=0,
+            seed=1234
+        )
+
+        for item in train_dl:
+            assert item['input_ids'][0][-1] == 2, "build nonpack input_ids failed!"
+            assert item['labels'][0][-2] == -100, "build nonpack labels failed!"
+            break
+
+    def test_build_pack_dataset(self):
+        tokenizer_directory = '/data/models/llama2-7b'
+        pack_data_prefix = '/data/datasets/pack/alpaca'
+
+        hf_tokenizer = get_tokenizer(tokenizer_directory)
+
+        train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
+            data_prefix=pack_data_prefix,
+            splits_string='100,0,0',
+            seq_length=1024,
+            train_valid_test_num_samples=(3840, 5120, 1280),
+            dataset_cls=InstructionDataset,
+            tokenizer=hf_tokenizer,
+            parallel_state=parallel_state,
+            full_shuffle_index=False,
+            get_item_in_order=False,
+            reset_position_ids=True,
+            prompt_type='llama2',
+            prompt_type_path='./configs/templates.json',
+            seed=42,
+            extra_param=None
+        )
+
+        train_dl = DataLoader(
+            dataset=train_ds,
+            parallel_state=parallel_state,
+            tokenizer=None,
+            num_workers=2,
+            tokenizer_padding_side='right',
+            pad_to_multiple_of=8,
+            variable_seq_lengths=False,
+            num_nextn_predict_layers=0,
+            micro_batch_size=1,
+            comsumed_samples=0,
+            seed=1234
+        )
+
+        for item in train_dl:
+            assert item['input_ids'][0][-3] == 1364, "build packed input_ids failed!"
+            assert item['labels'][0][-1] == 387, "build packed labels failed!"
+            break
diff --git a/tests/ut/utils/test_logger.py b/tests/ut/utils/test_logger.py
index 03eacbf3..2c7f4382 100644
--- a/tests/ut/utils/test_logger.py
+++ b/tests/ut/utils/test_logger.py
@@ -1,4 +1,5 @@
-# Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved.
+# coding=utf-8
+# Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved.
 
 import unittest
 import logging
-- 
Gitee
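
For reference, a minimal sketch of the split arithmetic that get_train_valid_test_split_ in this patch performs. The inputs (splits_string '80,10,10' and a corpus of 1000 documents) are illustrative assumptions, not values taken from the patch:

    # Assumed example inputs: splits_string='80,10,10', size=1000.
    splits = [80.0, 10.0, 10.0]                  # parsed from '80,10,10'
    splits = [s / sum(splits) for s in splits]   # normalized to [0.8, 0.1, 0.1]

    splits_index = [0]
    for split in splits:
        splits_index.append(splits_index[-1] + int(round(split * 1000.0)))
    # splits_index == [0, 800, 900, 1000]: train covers documents [0, 800),
    # valid [800, 900), test [900, 1000). In the real function, any rounding
    # drift is subtracted from each non-zero boundary so that the last index
    # always equals the corpus size.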