diff --git a/mindformers/dataset/dataloader/datareaders.py b/mindformers/dataset/dataloader/datareaders.py
index 1f95cd2716ec6767c3656156f32d74f2b6187ca3..83e79b1c70fffff38e08b05c5e922ef0de082ffe 100644
--- a/mindformers/dataset/dataloader/datareaders.py
+++ b/mindformers/dataset/dataloader/datareaders.py
@@ -37,44 +37,44 @@ def cmrc2018_reader(path):
     return dict(prompts=prompts, answers=answers)
 
 
+def wikitext_clean(string):
+    """Clean a raw wikitext string."""
+    # contractions
+    string = string.replace("s '", "s'")
+    string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string)
+    # number separators
+    string = string.replace(" @-@ ", "-")
+    string = string.replace(" @,@ ", ",")
+    string = string.replace(" @.@ ", ".")
+    # punctuation
+    string = string.replace(" : ", ": ")
+    string = string.replace(" ; ", "; ")
+    string = string.replace(" . ", ". ")
+    string = string.replace(" ! ", "! ")
+    string = string.replace(" ? ", "? ")
+    string = string.replace(" , ", ", ")
+    # double brackets
+    string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string)
+    string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string)
+    string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string)
+    string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string)
+    string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string)
+    # miscellaneous
+    string = string.replace("= = = =", "====")
+    string = string.replace("= = =", "===")
+    string = string.replace("= =", "==")
+    string = string.replace(" " + chr(176) + " ", chr(176))
+    string = string.replace(" \n", "\n")
+    string = string.replace("\n ", "\n")
+    string = string.replace(" N ", " 1 ")
+    string = string.replace(" 's", "'s")
+    return string
+
+
 def wikitext_reader(path):
     """Reading wikitext datasets. Returns a list of many sentences."""
     path = os.path.realpath(path)
 
-    def wikitext_clean(string):
-        """ string clean """
-        # contractions
-        string = string.replace("s '", "s'")
-        string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string)
-        # number separators
-        string = string.replace(" @-@ ", "-")
-        string = string.replace(" @,@ ", ",")
-        string = string.replace(" @.@ ", ".")
-        # punctuation
-        string = string.replace(" : ", ": ")
-        string = string.replace(" ; ", "; ")
-        string = string.replace(" . ", ". ")
-        string = string.replace(" .", ".")
-        string = string.replace(" ! ", "! ")
-        string = string.replace(" ? ", "? ")
-        string = string.replace(" , ", ", ")
-        # double brackets
-        string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string)
-        string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string)
-        string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string)
-        string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string)
-        string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string)
-        # miscellaneous
-        string = string.replace("= = = =", "====")
-        string = string.replace("= = =", "===")
-        string = string.replace("= =", "==")
-        string = string.replace(" " + chr(176) + " ", chr(176))
-        string = string.replace(" \n", "\n")
-        string = string.replace("\n ", "\n")
-        string = string.replace(" N ", " 1 ")
-        string = string.replace(" 's", "'s")
-        return string
-
     def preprocess_data(input_file):
         """preprocess data."""
         dataset_valid = []
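Reviewer note: the promoted cleaner is a pure string transform, so it can be sanity-checked in isolation; a minimal check of the shared helper (the sample text below is illustrative, not from the patch):

```python
# Illustrative check of the consolidated cleaner (not part of the patch).
from mindformers.dataset.dataloader.datareaders import wikitext_clean

raw = "the game 's cost was 3 @,@ 500 ; it ran at 60 @.@ 0 fps . \n"
print(repr(wikitext_clean(raw)))
# -> "the game's cost was 3,500; it ran at 60.0 fps.\n"
```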
") - string = string.replace(" , ", ", ") - # double brackets - string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string) - string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string) - string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string) - string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string) - string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string) - # miscellaneous - string = string.replace("= = = =", "====") - string = string.replace("= = =", "===") - string = string.replace("= =", "==") - string = string.replace(" " + chr(176) + " ", chr(176)) - string = string.replace(" \n", "\n") - string = string.replace("\n ", "\n") - string = string.replace(" N ", " 1 ") - string = string.replace(" 's", "'s") - return string - def preprocess_data(input_file): """preprocess data.""" dataset_valid = [] diff --git a/mindformers/generation/parallel_decoding_mcore.py b/mindformers/generation/parallel_decoding_mcore.py index e64421d67f12804a3b62990ca25a78c6bfd28436..9f4591d9876e8f2f2ed3d1d8d1b3f7663ab724c7 100644 --- a/mindformers/generation/parallel_decoding_mcore.py +++ b/mindformers/generation/parallel_decoding_mcore.py @@ -35,7 +35,6 @@ def la_pre_process(input_ids, slot_mapping, **model_kwargs): for q_seq_len in q_seq_lens: input_ids_list += input_ids[start: start + q_seq_len] start += max_q_seq_lens - input_ids = np.array(input_ids_list, dtype=np.int32) if len(slot_mapping) != sum_q_seq_lens: slot_mapping = slot_mapping.tolist() slot_mapping_list = list() diff --git a/mindformers/models/base_config.py b/mindformers/models/base_config.py index b84a0d2280d760cc5a22350a74e3bcb1f616276d..1fdc84d7b81816f2091b4835e8493a9b11d66915 100644 --- a/mindformers/models/base_config.py +++ b/mindformers/models/base_config.py @@ -18,12 +18,11 @@ BaseConfig class, which is all model configs' base class """ import os -import shutil - import yaml from mindformers.tools.check_rules import check_yaml_depth_before_loading from mindformers.tools.utils import FILE_PERMISSION from mindformers.models.utils import DEFAULT_CHECKPOINT_SAVE_FOLDER +from mindformers.models.auto.utils import set_default_yaml_file from ..mindformer_book import MindFormerBook from ..mindformer_book import print_path_or_list from ..tools.logger import logger @@ -147,21 +146,8 @@ class BaseConfig(dict): yaml_file = os.path.join(checkpoint_path, yaml_name + ".yaml") - def get_default_yaml_file(model_name): - default_yaml_file = "" - for model_dict in MindFormerBook.get_trainer_support_task_list().values(): - if model_name in model_dict: - default_yaml_file = model_dict.get(model_name) - break - return default_yaml_file - - if not os.path.exists(yaml_file): - default_yaml_file = get_default_yaml_file(yaml_name) - if os.path.realpath(default_yaml_file) and os.path.exists(default_yaml_file): - shutil.copy(default_yaml_file, yaml_file) - logger.info("default yaml config in %s is used.", yaml_file) - else: - raise FileNotFoundError(f'default yaml file path must be correct, but get {default_yaml_file}') + set_default_yaml_file(yaml_name, yaml_file) + config_args = MindFormerConfig(yaml_file) config_args.model.model_config.update(**kwargs) config = build_model_config(config_args.model.model_config) diff --git a/mindformers/models/base_processor.py b/mindformers/models/base_processor.py index 07300000d79238ff229c32793f5a6f5c766cffe9..480f63162019f21465bf3ab68d7d38f1ea27eea0 100644 --- a/mindformers/models/base_processor.py +++ b/mindformers/models/base_processor.py @@ -17,12 +17,12 @@ BaseProcessor """ import os -import shutil - import yaml from 
diff --git a/mindformers/models/base_processor.py b/mindformers/models/base_processor.py
index 07300000d79238ff229c32793f5a6f5c766cffe9..480f63162019f21465bf3ab68d7d38f1ea27eea0 100644
--- a/mindformers/models/base_processor.py
+++ b/mindformers/models/base_processor.py
@@ -17,12 +17,12 @@ BaseProcessor
 """
 import os
-import shutil
-
 import yaml
 
 from mindformers.tools.check_rules import check_yaml_depth_before_loading
+from mindformers.models.image_processing_utils import BaseImageProcessor
 from mindformers.tools.utils import FILE_PERMISSION
 from mindformers.models.utils import DEFAULT_CHECKPOINT_SAVE_FOLDER
+from mindformers.models.auto.utils import set_default_yaml_file
 from ..mindformer_book import print_path_or_list, MindFormerBook
 from .build_processor import build_processor
 from .tokenization_utils_base import PreTrainedTokenizerBase
@@ -256,21 +256,7 @@ class BaseProcessor:
 
         yaml_file = os.path.join(checkpoint_path, yaml_name + ".yaml")
 
-        def get_default_yaml_file(model_name):
-            default_yaml_file = ""
-            for model_dict in MindFormerBook.get_trainer_support_task_list().values():
-                if model_name in model_dict:
-                    default_yaml_file = model_dict.get(model_name)
-                    break
-            return default_yaml_file
-
-        if not os.path.exists(yaml_file):
-            default_yaml_file = get_default_yaml_file(yaml_name)
-            if os.path.realpath(default_yaml_file) and os.path.exists(default_yaml_file):
-                shutil.copy(default_yaml_file, yaml_file)
-                logger.info("default yaml config in %s is used.", yaml_file)
-            else:
-                raise FileNotFoundError(f'default yaml file path must be correct, but get {default_yaml_file}')
+        set_default_yaml_file(yaml_name, yaml_file)
 
         config_args = MindFormerConfig(yaml_file)
diff --git a/mindformers/models/configuration_utils.py b/mindformers/models/configuration_utils.py
index 1feb598839102cd98fc090896ff5b88f7618a621..65854e8146020b52e35406c0e4174139a525d7a3 100644
--- a/mindformers/models/configuration_utils.py
+++ b/mindformers/models/configuration_utils.py
@@ -18,7 +18,6 @@
 """Configuration base class and utilities."""
 
 import os
-import shutil
 import re
 import json
 import copy
@@ -35,6 +34,7 @@ from mindformers.tools.check_rules import check_yaml_depth_before_loading
 from mindformers.tools.utils import FILE_PERMISSION
 from mindformers.models.build_config import build_model_config, get_model_config
 from mindformers.models.utils import CONFIG_NAME, ms_type_to_str, DEFAULT_CHECKPOINT_SAVE_FOLDER
+from mindformers.models.auto.utils import set_default_yaml_file
 from mindformers.mindformer_book import MindFormerBook, print_path_or_list
 from mindformers.tools import (
     PushToHubMixin,
@@ -414,23 +414,8 @@ class PretrainedConfig(PushToHubMixin):
 
         yaml_file = os.path.join(checkpoint_path, yaml_name + ".yaml")
 
-        def get_default_yaml_file(model_name):
-            default_yaml_file = ""
-            for model_dict in MindFormerBook.get_trainer_support_task_list().values():
-                if model_name in model_dict:
-                    default_yaml_file = model_dict.get(model_name)
-                    break
-            return default_yaml_file
-
-        if not os.path.exists(yaml_file):
-            default_yaml_file = get_default_yaml_file(yaml_name)
-            if os.path.realpath(default_yaml_file) and os.path.exists(default_yaml_file):
-                shutil.copy(default_yaml_file, yaml_file)
-                logger.info("default yaml config in %s is used.", yaml_file)
-            else:
-                raise FileNotFoundError(
-                    f'default yaml file path must be correct, but get {default_yaml_file}'
-                )
+        set_default_yaml_file(yaml_name, yaml_file)
+
         config_args = MindFormerConfig(yaml_file)
         use_legacy = config_args.get_value("use_legacy", True)
         config_args.model.model_config.update(**kwargs)
diff --git a/mindformers/models/processing_utils.py b/mindformers/models/processing_utils.py
index d3943690dc174a59daedea840241cdf2d8be226c..2e01e8f6195b6e6c5f319c7bdf49a5be3f2986e5 100644
--- a/mindformers/models/processing_utils.py
+++ b/mindformers/models/processing_utils.py
@@ -18,7 +18,6 @@
 """
 
 import os
-import shutil
 from pathlib import Path
 from typing import Optional, Union
 import json
@@ -26,6 +25,7 @@ import yaml
 from mindformers.tools.check_rules import check_yaml_depth_before_loading
 from mindformers.tools.utils import FILE_PERMISSION
 from mindformers.models.utils import DEFAULT_CHECKPOINT_SAVE_FOLDER
+from mindformers.models.auto.utils import set_default_yaml_file
 from ..mindformer_book import print_path_or_list, MindFormerBook
 from .build_processor import build_processor
 from .tokenization_utils import PreTrainedTokenizer
@@ -231,21 +231,7 @@ class ProcessorMixin(PushToHubMixin):
 
         yaml_file = os.path.join(checkpoint_path, yaml_name + ".yaml")
 
-        def get_default_yaml_file(model_name):
-            default_yaml_file = ""
-            for model_dict in MindFormerBook.get_trainer_support_task_list().values():
-                if model_name in model_dict:
-                    default_yaml_file = model_dict.get(model_name)
-                    break
-            return default_yaml_file
-
-        if not os.path.exists(yaml_file):
-            default_yaml_file = get_default_yaml_file(yaml_name)
-            if os.path.realpath(default_yaml_file) and os.path.exists(default_yaml_file):
-                shutil.copy(default_yaml_file, yaml_file)
-                logger.info("default yaml config in %s is used.", yaml_file)
-            else:
-                raise FileNotFoundError(f'default yaml file path must be correct, but get {default_yaml_file}')
+        set_default_yaml_file(yaml_name, yaml_file)
 
         config_args = MindFormerConfig(yaml_file)
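The four call sites above now delegate to `set_default_yaml_file` from `mindformers/models/auto/utils.py`, which this patch does not show. Assuming the helper centralizes exactly the logic deleted here, a sketch would be:

```python
# Hypothetical sketch of the consolidated helper; the real body lives in
# mindformers/models/auto/utils.py, which this patch does not include.
import os
import shutil

from mindformers.mindformer_book import MindFormerBook
from mindformers.tools.logger import logger


def set_default_yaml_file(yaml_name, yaml_file):
    """Copy the default task yaml into place when `yaml_file` is missing."""
    default_yaml_file = ""
    for model_dict in MindFormerBook.get_trainer_support_task_list().values():
        if yaml_name in model_dict:
            default_yaml_file = model_dict.get(yaml_name)
            break
    if os.path.exists(yaml_file):
        return
    if os.path.realpath(default_yaml_file) and os.path.exists(default_yaml_file):
        shutil.copy(default_yaml_file, yaml_file)
        logger.info("default yaml config in %s is used.", yaml_file)
    else:
        raise FileNotFoundError(f'default yaml file path must be correct, but get {default_yaml_file}')
```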
diff --git a/mindformers/modules/local_block_sparse_attention.py b/mindformers/modules/local_block_sparse_attention.py
index 42535a0c7bda482fbacf5409c29c07ffe25fcc54..259307241aa1642159d7e9f9c2a223df8082373e 100644
--- a/mindformers/modules/local_block_sparse_attention.py
+++ b/mindformers/modules/local_block_sparse_attention.py
@@ -17,8 +17,6 @@ A Local Block Sparse Attention.
 """
 from __future__ import absolute_import
 
-from functools import wraps, partial
-import inspect
 import math
 
 import numpy as np
@@ -34,6 +32,7 @@ except ImportError:
     import mindspore._checkparam as Validator
 
 from mindformers.modules.transformer.op_parallel_config import default_dpmp_config, OpParallelConfig
+from mindformers.modules.layers import _args_type_validator_check, _valid_value_checks, _valid_type_checks
 
 __all__ = ["LocalBlockSparseAttention"]
 
@@ -42,59 +41,6 @@ kv_index = None
 mask_index = None
 
 
-def _args_type_validator_check(*type_args, **type_kwargs):
-    """Check whether input data type is correct."""
-
-    def type_check(func):
-        sig = inspect.signature(func)
-        bound_types = sig.bind_partial(*type_args, **type_kwargs).arguments
-
-        @wraps(func)
-        def wrapper(*args, **kwargs):
-            nonlocal bound_types
-            bound_values = sig.bind(*args, **kwargs)
-
-            argument_dict = bound_values.arguments
-            if "kwargs" in bound_types:
-                bound_types = bound_types["kwargs"]
-            if "kwargs" in argument_dict:
-                argument_dict = argument_dict["kwargs"]
-            for name, value in argument_dict.items():
-                if name in bound_types:
-                    bound_types[name](value, name)
-            return func(*args, **kwargs)
-
-        return wrapper
-
-    return type_check
-
-
-def _valid_type_checks(types, class_name):
-    """types should be a list of types, this function check if the type is in the valid dtypes"""
-    def validator_check_func(value, name):
-        # The args of Validator.check_type_name is (arg_name, arg_type, valid_types, prim_name)
-        # as the input of _args_type_validator_check is fixed, so we need to manually change the input order
-        partial_check = partial(Validator.check_type_name,
-                                valid_types=types,
-                                prim_name=class_name)
-        return partial_check(name, type(value))
-
-    return validator_check_func
-
-
-def _valid_value_checks(types, class_name):
-    """the value should be a list of types, this function check if the value is in the valid dtypes"""
-    def validator_check_func(value, name):
-        # The args of Validator.check_type_name is (arg_name, arg_type, valid_types, prim_name)
-        # as the input of _args_type_validator_check is fixed, so we need to manually change the input order
-        partial_check = partial(Validator.check_type_name,
-                                valid_types=types,
-                                prim_name=class_name)
-        return partial_check(name, value)
-
-    return validator_check_func
-
-
 class InitGatherIndex:
     """
     A self-defined function to init the index of gather operation
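The validator helpers are now imported from `mindformers.modules.layers` instead of being redefined locally. An illustrative use of the trio, based on the semantics of the deleted copies (the `Toy` class and its arguments are hypothetical; `LocalBlockSparseAttention.__init__` is decorated in the same style):

```python
# Hypothetical usage sketch of the relocated validator helpers.
import mindspore.common.dtype as mstype
from mindformers.modules.layers import (_args_type_validator_check,
                                        _valid_type_checks, _valid_value_checks)


class Toy:
    # block_size must be an int; compute_dtype must be one of the listed dtypes.
    @_args_type_validator_check(block_size=_valid_type_checks([int], "Toy"),
                                compute_dtype=_valid_value_checks(
                                    [mstype.float16, mstype.float32], "Toy"))
    def __init__(self, block_size, compute_dtype=mstype.float16):
        self.block_size = block_size
        self.compute_dtype = compute_dtype


Toy(block_size=64)      # passes validation
Toy(block_size="64")    # raises TypeError from Validator.check_type_name
```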
diff --git a/mindformers/parallel_core/training_graph/transformer/moe/shared_experts.py b/mindformers/parallel_core/training_graph/transformer/moe/shared_experts.py
index 4d48a1883f8e867ff3e33f8a194bd3d06feb5373..e00c7f27e09c930341aca554cd3a390d5ed15328 100644
--- a/mindformers/parallel_core/training_graph/transformer/moe/shared_experts.py
+++ b/mindformers/parallel_core/training_graph/transformer/moe/shared_experts.py
@@ -29,7 +29,7 @@ from mindformers.parallel_core.training_graph.device_matrix import layout
 
 
 class SharedExpertMLP(MLP):
-    r"""
+    """
     Implementation of a shared expert feedforward block that inherits from MLP.
 
     This module extends the standard MLP to support shared expert logic, typically used in MoE settings.
@@ -110,7 +110,7 @@ class SharedExpertMLP(MLP):
 
 
 class SharedExpertMLPInterleaved(MLPInterleaved):
-    r"""
+    """
     Implementation of a shared expert feedforward block that inherits from MLP.
 
     This module extends the standard MLP to support shared expert logic, typically used in MoE settings.
diff --git a/mindformers/tools/dataset_preprocess/llama/llama_preprocess.py b/mindformers/tools/dataset_preprocess/llama/llama_preprocess.py
index 0f68b2c1113e66f73dcaf8fe475546665a78eb03..b2ef062a390faea3a5c1d740456dc4b0af3d48f9 100644
--- a/mindformers/tools/dataset_preprocess/llama/llama_preprocess.py
+++ b/mindformers/tools/dataset_preprocess/llama/llama_preprocess.py
@@ -19,13 +19,13 @@ transform wikitext-2, wikitext-103, lambada, openwebtext dataset to mindrecord.
 import argparse
 import json
 import os
-import re
 
 import numpy as np
 from mindspore.mindrecord import FileWriter
 from tqdm import tqdm
 
 from mindformers.dataset.dataloader.training_dataloader import TrainingDataset
+from mindformers.dataset.dataloader.datareaders import wikitext_clean
 from mindformers.models.llama.llama_tokenizer import LlamaTokenizer
 from mindformers.tools import logger
 
@@ -40,55 +40,6 @@ def chunks(lst, n):
         yield lst[i:i + n]
 
 
-def package_file(it, n):
-    """ package multiple files"""
-    stop = False
-    while not stop:
-        batch = []
-        for _ in range(n):
-            try:
-                batch.append(next(it))
-            except StopIteration:
-                stop = True
-        if not batch:
-            break
-        yield batch
-
-
-def clean_wikitext(string):
-    """ cleaning wikitext dataset"""
-    # contractions
-    string = string.replace("s '", "s'")
-    string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string)
-    # number separators
-    string = string.replace(" @-@ ", "-")
-    string = string.replace(" @,@ ", ",")
-    string = string.replace(" @.@ ", ".")
-    # punctuation
-    string = string.replace(" : ", ": ")
-    string = string.replace(" ; ", "; ")
-    string = string.replace(" . ", ". ")
-    string = string.replace(" ! ", "! ")
-    string = string.replace(" ? ", "? ")
-    string = string.replace(" , ", ", ")
-    # double brackets
-    string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string)
-    string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string)
-    string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string)
-    string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string)
-    string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string)
-    # miscellaneous
-    string = string.replace("= = = =", "====")
-    string = string.replace("= = =", "===")
-    string = string.replace("= =", "==")
-    string = string.replace(" " + chr(176) + " ", chr(176))
-    string = string.replace(" \n", "\n")
-    string = string.replace("\n ", "\n")
-    string = string.replace(" N ", " 1 ")
-    string = string.replace(" 's", "'s")
-    return string
-
-
 def preprocess(sources, tokenizer, seq_length):
     """conversation preprocess."""
     conv = get_default_conv_template("vicuna").copy()
@@ -203,7 +154,7 @@ def tokenize_wiki(tokenizer, file_path, seq_length, repeat):
             pbar.update(1)
             pbar.set_description("Processing text")
 
-            paras = clean_wikitext(raw_text).split("\n\n")
+            paras = wikitext_clean(raw_text).split("\n\n")
 
             pbar.update(1)
             pbar.set_description("Tokenizing text")
@@ -273,6 +224,7 @@ def tokenize_qa(tokenizer, file_path, seq_length):
     for i, _ in enumerate(dataset_cls):
         yield dataset_cls[i]
 
+
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--dataset_type', type=str, default='wiki')
") - string = string.replace(" , ", ", ") - # double brackets - string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string) - string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string) - string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string) - string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string) - string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string) - # miscellaneous - string = string.replace("= = = =", "====") - string = string.replace("= = =", "===") - string = string.replace("= =", "==") - string = string.replace(" " + chr(176) + " ", chr(176)) - string = string.replace(" \n", "\n") - string = string.replace("\n ", "\n") - string = string.replace(" N ", " 1 ") - string = string.replace(" 's", "'s") - return string - - def preprocess(sources, tokenizer, seq_length): """conversation preprocess.""" conv = get_default_conv_template("vicuna").copy() @@ -203,7 +154,7 @@ def tokenize_wiki(tokenizer, file_path, seq_length, repeat): pbar.update(1) pbar.set_description("Processing text") - paras = clean_wikitext(raw_text).split("\n\n") + paras = wikitext_clean(raw_text).split("\n\n") pbar.update(1) pbar.set_description("Tokenizing text") @@ -273,6 +224,7 @@ def tokenize_qa(tokenizer, file_path, seq_length): for i, _ in enumerate(dataset_cls): yield dataset_cls[i] + if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--dataset_type', type=str, default='wiki') diff --git a/mindformers/tools/transform_ckpt.py b/mindformers/tools/transform_ckpt.py index dee98978cb706a0883a178aef411ab2cd7d13690..905b8568d09aa926c9302007a361fa25ae0c689b 100644 --- a/mindformers/tools/transform_ckpt.py +++ b/mindformers/tools/transform_ckpt.py @@ -52,26 +52,14 @@ def get_strategy(startegy_path, rank_id=None): return None + if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('--src_ckpt_strategy', - default="", - help='path of src ckpt strategy') - parser.add_argument('--dst_ckpt_strategy', - default="", - help='path of dst ckpt strategy') - parser.add_argument('--src_ckpt_dir', - default="", - type=str, - help='path of src ckpt') - parser.add_argument('--dst_ckpt_dir', - default="", - type=str, - help='path where to save dst ckpt') - parser.add_argument('--prefix', - default='checkpoint_', - type=str, - help='prefix of transformed checkpoint') + parser.add_argument('--src_ckpt_strategy', default="", help='path of src ckpt strategy') + parser.add_argument('--dst_ckpt_strategy', default="", help='path of dst ckpt strategy') + parser.add_argument('--src_ckpt_dir', default="", type=str, help='path of src ckpt') + parser.add_argument('--dst_ckpt_dir', default="", type=str, help='path where to save dst ckpt') + parser.add_argument('--prefix', default='checkpoint_', type=str, help='prefix of transformed checkpoint') args = parser.parse_args() src_ckpt_strategy = get_strategy(args.src_ckpt_strategy) diff --git a/mindformers/tools/transform_ckpt_lora.py b/mindformers/tools/transform_ckpt_lora.py index 6f5f3ca0b08af435db376202f2d8eafb8c94f32f..de8296a796089624d066388ae799e3838f2857d7 100644 --- a/mindformers/tools/transform_ckpt_lora.py +++ b/mindformers/tools/transform_ckpt_lora.py @@ -21,73 +21,23 @@ import mindspore as ms from mindspore import Parameter, Tensor import mindspore.ops as P from mindformers.tools.logger import logger - - -def get_strategy(startegy_path, rank_id=None): - """Merge strategy if strategy path is dir - - Args: - startegy_path (str): The path of stategy. - rank_id (int): The rank id of device. 
diff --git a/mindformers/trainer/utils.py b/mindformers/trainer/utils.py
index 284bac2697b1a04314ba3ee448feaddd357a2317..971dfcff3e39d1de5b111e49d7ee1fad20117157 100644
--- a/mindformers/trainer/utils.py
+++ b/mindformers/trainer/utils.py
@@ -31,8 +31,8 @@ from mindspore.communication.comm_func import barrier
 
 from mindformers.tools.logger import logger
 from mindformers.tools.utils import get_real_rank
-from mindformers.utils.load_checkpoint_utils import CkptFormat, load_checkpoint_with_safetensors
-from mindformers.checkpoint.utils import compile_model
+from mindformers.utils.load_checkpoint_utils import \
+    CkptFormat, load_checkpoint_with_safetensors, compile_model, get_last_checkpoint
 from mindformers.tools.register import MindFormerConfig
 from mindformers.tools.utils import (
     replace_rank_id_in_ckpt_name,
@@ -42,7 +42,7 @@ from mindformers.tools.utils import (
     format_path,
     is_main_rank
 )
-from mindformers.tools.ckpt_transform import TransformCkpt
+from mindformers.tools.ckpt_transform import TransformCkpt, check_rank_folders, check_ckpt_file_exist
 from mindformers.models.base_model import BaseModel
 from mindformers.models.modeling_utils import PreTrainedModel
 from mindformers.version_control import need_nz
@@ -520,22 +520,6 @@ def check_checkpoint_config_valid(config):
             f"config.load_ckpt_format only support for 'ckpt' or 'safetensors', but got {config.load_ckpt_format}.")
 
 
-def check_rank_folders(path, rank_id):
-    """check if the folders in path are correct"""
-    folder_name = "rank_{}".format(rank_id)
-    if not os.path.exists(os.path.join(path, folder_name)):
-        return False
-    return True
-
-
-def check_ckpt_file_exist(path):
-    """check if the files in path endswith .ckpt"""
-    for file_name in os.listdir(path):
-        if file_name.endswith('.ckpt'):
-            return True
-    return False
-
-
 def check_path_include_total_ckpt(path):
     """check if the input path is total, not split."""
     if path is None:
@@ -755,23 +739,3 @@ def load_ckpt(config, network, optimizer=None, model=None, future=None):
     if optimizer:
         not_load_optim_params = load_param_into_net(optimizer, checkpoint_dict)
         logger.info("Optimizer parameters are not loaded: %s", not_load_optim_params)
-
-
-def get_last_checkpoint(checkpoint_dir, ckpt_format='ckpt'):
-    """get last checkpoint for resuming or finetune."""
-    if not os.path.isdir(checkpoint_dir):
-        raise NotADirectoryError(
-            f"{checkpoint_dir} is not a real directory,"
-            f"When distributed loads are sliced weights,"
-            f"load_checkpoint should be a checkpoint directory containing the directory of rank_{{0-*}},"
-            f"The directory structure is as follows: **checkpoint_root_dir/rank_{{0-*}}/**.{ckpt_format}")
-    output_checkpoint_path = [
-        checkpoint
-        for checkpoint in os.listdir(checkpoint_dir)
-        if checkpoint.endswith(f'.{ckpt_format}')
-    ]
-    if not output_checkpoint_path:
-        return None
-    output_checkpoint_path = sorted(output_checkpoint_path,
-                                    key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x)))
-    return os.path.join(checkpoint_dir, output_checkpoint_path[-1])
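The helpers deleted here (`get_last_checkpoint`, `check_rank_folders`, `check_ckpt_file_exist`) are presumably moved to the modules they are now imported from. Callers keep the same behavior; e.g. picking the newest checkpoint in a rank folder via the new import path (the path below is hypothetical):

```python
# Illustrative call through the new import path; same signature as the
# function removed from trainer/utils.py.
from mindformers.utils.load_checkpoint_utils import get_last_checkpoint

last_ckpt = get_last_checkpoint("./output/checkpoint/rank_0", ckpt_format="safetensors")
```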
diff --git a/research/deepseek3/deepseek3_preprocess.py b/research/deepseek3/deepseek3_preprocess.py
index be9d860893740ade4aa70857af2c771e4445dabf..0e9cf803e60bee590eecda846597f289eddd61c7 100644
--- a/research/deepseek3/deepseek3_preprocess.py
+++ b/research/deepseek3/deepseek3_preprocess.py
@@ -19,7 +19,6 @@ transform dataset to mindrecord.
 import argparse
 import json
 import os
-import re
 import pathlib
 
 import numpy as np
@@ -27,6 +26,7 @@ from mindspore.mindrecord import FileWriter
 
 from deepseek3_conversation import get_default_conv_template
 from mindformers.models.llama import LlamaTokenizerFast
+from mindformers.dataset.dataloader.datareaders import wikitext_clean
 
 IGNORE_TOKEN_ID = -100
 
@@ -96,40 +96,6 @@ def chunks(lst, n):
         yield lst[i:i + n]
 
 
-def clean_wikitext(string):
-    """ cleaning wikitext dataset"""
-    # contractions
-    string = string.replace("s '", "s'")
-    string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string)
-    # number separators
-    string = string.replace(" @-@ ", "-")
-    string = string.replace(" @,@ ", ",")
-    string = string.replace(" @.@ ", ".")
-    # punctuation
-    string = string.replace(" : ", ": ")
-    string = string.replace(" ; ", "; ")
-    string = string.replace(" . ", ". ")
-    string = string.replace(" ! ", "! ")
-    string = string.replace(" ? ", "? ")
-    string = string.replace(" , ", ", ")
-    # double brackets
-    string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string)
-    string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string)
-    string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string)
-    string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string)
-    string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string)
-    # miscellaneous
-    string = string.replace("= = = =", "====")
-    string = string.replace("= = =", "===")
-    string = string.replace("= =", "==")
-    string = string.replace(" " + chr(176) + " ", chr(176))
-    string = string.replace(" \n", "\n")
-    string = string.replace("\n ", "\n")
-    string = string.replace(" N ", " 1 ")
-    string = string.replace(" 's", "'s")
-    return string
-
-
 def preprocess(sources, tokenizer, seq_length):
     """conversation preprocess."""
     conv = get_default_conv_template("vicuna").copy()
@@ -215,7 +181,7 @@ def tokenize_wiki(tokenizer, file_path, seq_length, repeat):
     """tokenize wikitext-2/wikitext-103 dataset"""
     content = []
     with open(file_path, 'r', encoding='utf-8') as f:
-        for para in clean_wikitext(f.read()).split("\n\n"):
+        for para in wikitext_clean(f.read()).split("\n\n"):
             if para and para.strip().startswith('=') is False:
                 content += tokenizer(para)['input_ids']
     content_out = []
") - string = string.replace(" , ", ", ") - # double brackets - string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string) - string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string) - string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string) - string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string) - string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string) - # miscellaneous - string = string.replace("= = = =", "====") - string = string.replace("= = =", "===") - string = string.replace("= =", "==") - string = string.replace(" " + chr(176) + " ", chr(176)) - string = string.replace(" \n", "\n") - string = string.replace("\n ", "\n") - string = string.replace(" N ", " 1 ") - string = string.replace(" 's", "'s") - return string - - def preprocess(sources, tokenizer, seq_length): """conversation preprocess.""" conv = get_default_conv_template("vicuna").copy() @@ -215,7 +181,7 @@ def tokenize_wiki(tokenizer, file_path, seq_length, repeat): """tokenize wikitext-2/wikitext-103 dataset""" content = [] with open(file_path, 'r', encoding='utf-8') as f: - for para in clean_wikitext(f.read()).split("\n\n"): + for para in wikitext_clean(f.read()).split("\n\n"): if para and para.strip().startswith('=') is False: content += tokenizer(para)['input_ids'] content_out = [] diff --git a/research/deepseek3/wikitext_to_bin.py b/research/deepseek3/wikitext_to_bin.py index 7c7badd05681a650bc28c12d369850c222f265ef..0fb8bbbbe47f84f2ed1fef2252d0e63fd1609f37 100644 --- a/research/deepseek3/wikitext_to_bin.py +++ b/research/deepseek3/wikitext_to_bin.py @@ -10,7 +10,6 @@ import argparse import math import json import os -import re import sys import multiprocessing import numpy as np @@ -27,50 +26,17 @@ from mindformers.dataset.blended_datasets.indexed_dataset import IndexedDatasetB from mindformers.models import build_tokenizer from mindformers.models.tokenization_utils import AddedToken from mindformers.models.llama.llama_tokenizer_fast import LlamaTokenizerFast +from mindformers.dataset.dataloader.datareaders import wikitext_clean sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))) -def clean_wikitext(string): - """cleaning wikitext dataset""" - # contractions - string = string.replace("s '", "s'") - string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string) - # number separators - string = string.replace(" @-@ ", "-") - string = string.replace(" @,@ ", ",") - string = string.replace(" @.@ ", ".") - # punctuation - string = string.replace(" : ", ": ") - string = string.replace(" ; ", "; ") - string = string.replace(" . ", ". ") - string = string.replace(" ! ", "! ") - string = string.replace(" ? ", "? 
") - string = string.replace(" , ", ", ") - # double brackets - string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string) - string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string) - string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string) - string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string) - string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string) - # miscellaneous - string = string.replace("= = = =", "====") - string = string.replace("= = =", "===") - string = string.replace("= =", "==") - string = string.replace(" " + chr(176) + " ", chr(176)) - string = string.replace(" \n", "\n") - string = string.replace("\n ", "\n") - string = string.replace(" N ", " 1 ") - string = string.replace(" 's", "'s") - return string - - def gen_wiki_json(input_file, output_file): """generate wikitext-2/wikitext-103 json""" data_idx = 0 out = open(output_file, 'w', encoding='utf-8') with open(input_file, 'r', encoding='utf-8') as f: - for para in clean_wikitext(f.read()).split("\n\n"): + for para in wikitext_clean(f.read()).split("\n\n"): content = {} if para and para.strip().startswith('=') is False: print(data_idx) diff --git a/research/llama3_1/llama3_1_preprocess.py b/research/llama3_1/llama3_1_preprocess.py index e1a43dca025163888263d83ebfd27c1f49801ad4..d201f46ed8b7b9055954dad9bb11537127ed151e 100644 --- a/research/llama3_1/llama3_1_preprocess.py +++ b/research/llama3_1/llama3_1_preprocess.py @@ -19,14 +19,13 @@ transform wikitext-2, wikitext-103, lambada, openwebtext dataset to mindrecord. import argparse import json import os -import re import numpy as np - +from mindspore.mindrecord import FileWriter +from mindformers.tools import logger +from mindformers.dataset.dataloader.datareaders import wikitext_clean from llama3_1_tokenizer import Llama3Tokenizer from llama3_1_conversation import get_default_conv_template -from mindspore.mindrecord import FileWriter -from mindformers.tools import logger IGNORE_TOKEN_ID = -100 @@ -37,55 +36,6 @@ def chunks(lst, n): yield lst[i:i + n] -def package_file(it, n): - """ package multiple files""" - stop = False - while not stop: - batch = [] - for _ in range(n): - try: - batch.append(next(it)) - except StopIteration: - stop = True - if not batch: - break - yield batch - - -def clean_wikitext(string): - """ cleaning wikitext dataset""" - # contractions - string = string.replace("s '", "s'") - string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string) - # number separators - string = string.replace(" @-@ ", "-") - string = string.replace(" @,@ ", ",") - string = string.replace(" @.@ ", ".") - # punctuation - string = string.replace(" : ", ": ") - string = string.replace(" ; ", "; ") - string = string.replace(" . ", ". ") - string = string.replace(" ! ", "! ") - string = string.replace(" ? ", "? 
") - string = string.replace(" , ", ", ") - # double brackets - string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string) - string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string) - string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string) - string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string) - string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string) - # miscellaneous - string = string.replace("= = = =", "====") - string = string.replace("= = =", "===") - string = string.replace("= =", "==") - string = string.replace(" " + chr(176) + " ", chr(176)) - string = string.replace(" \n", "\n") - string = string.replace("\n ", "\n") - string = string.replace(" N ", " 1 ") - string = string.replace(" 's", "'s") - return string - - def preprocess(sources, tokenizer, seq_length): """conversation preprocess.""" conv = get_default_conv_template("vicuna").copy() @@ -186,7 +136,7 @@ def tokenize_wiki(tokenizer, file_path, seq_length, repeat): """tokenize wikitext-2/wikitext-103 dataset""" content = [] with open(file_path, 'r', encoding='utf-8') as f: - for para in clean_wikitext(f.read()).split("\n\n"): + for para in wikitext_clean(f.read()).split("\n\n"): if para and para.strip().startswith('=') is False: content += tokenizer(para)['input_ids'] content_out = [] diff --git a/research/qwen2/qwen2_preprocess.py b/research/qwen2/qwen2_preprocess.py index 22f9b3347c0d8594aba92bdc2a47424cb569fd19..a09e10de4d1c3e2305cebf59b071812a7c3a0da1 100644 --- a/research/qwen2/qwen2_preprocess.py +++ b/research/qwen2/qwen2_preprocess.py @@ -19,12 +19,11 @@ transform dataset to mindrecord. import argparse import json import os -import re import numpy as np from mindspore.mindrecord import FileWriter - +from mindformers.dataset.dataloader.datareaders import wikitext_clean from research.qwen2.qwen2_tokenizer import Qwen2Tokenizer IGNORE_TOKEN_ID = -100 @@ -36,40 +35,6 @@ def chunks(lst, n): yield lst[i:i + n] -def clean_wikitext(string): - """ cleaning wikitext dataset""" - # contractions - string = string.replace("s '", "s'") - string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string) - # number separators - string = string.replace(" @-@ ", "-") - string = string.replace(" @,@ ", ",") - string = string.replace(" @.@ ", ".") - # punctuation - string = string.replace(" : ", ": ") - string = string.replace(" ; ", "; ") - string = string.replace(" . ", ". ") - string = string.replace(" ! ", "! ") - string = string.replace(" ? ", "? 
") - string = string.replace(" , ", ", ") - # double brackets - string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string) - string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string) - string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string) - string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string) - string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string) - # miscellaneous - string = string.replace("= = = =", "====") - string = string.replace("= = =", "===") - string = string.replace("= =", "==") - string = string.replace(" " + chr(176) + " ", chr(176)) - string = string.replace(" \n", "\n") - string = string.replace("\n ", "\n") - string = string.replace(" N ", " 1 ") - string = string.replace(" 's", "'s") - return string - - def preprocess(messages, tokenizer, seq_length): """Preprocesses the data for supervised fine-tuning.""" @@ -98,7 +63,7 @@ def tokenize_wiki(tokenizer, file_path, seq_length, repeat): """tokenize wikitext-2/wikitext-103 dataset""" content = [] with open(file_path, 'r', encoding='utf-8') as f: - for para in clean_wikitext(f.read()).split("\n\n"): + for para in wikitext_clean(f.read()).split("\n\n"): if para and para.strip().startswith('=') is False: content += tokenizer(para)['input_ids'] content_out = []