diff --git a/README.en.md b/README.en.md index 615844aa9846e1750c7b230cfbaa437872459a6b..f66d4af0d5cf3305927aea94587466175737ae47 100644 --- a/README.en.md +++ b/README.en.md @@ -1,7 +1,11 @@ -# 0.2B small Chinese chat language model +
+ +# A Small Chinese Chat Language Model: ChatLM-Chinese-0.2B + [中文](./README.md) | English + +
# 1. 👋Introduction -*阅读中文文档 [中文](README.md).* Today's large language models tend to have large parameters, and consumer-grade computers are slow to do simple inference, let alone train a model from scratch. The goal of this project is to organize the training process of generative language models, including data cleaning, tokenizer training, model pre-training, SFT instruction fine-tuning, RLHF optimization, etc. @@ -11,7 +15,7 @@ ChatLM-mini-Chinese is a small Chinese chat model with only 0.2B (added shared w - Use the `Huggingface` NLP framework, including `transformers`, `accelerate`, `trl`, `peft`, etc. - Self-implemented `trainer`, supporting pre-training and SFT fine-tuning on a single machine with a single card or with multiple cards on a single machine. It supports stopping at any position during training and continuing training at any position. - Pre-training: Integrated into end-to-end `Text-to-Text` pre-training, non-`mask` mask prediction pre-training. - - Open source all data cleaning, dataset construction, dataset loading optimization and other processes; + - Open source all data cleaning (such as normalization, MinHash-based document deduplication, etc.), dataset construction, dataset loading optimization and other processes; - tokenizer multi-process word frequency statistics, supports tokenizer training of `sentencepiece` and `huggingface tokenizers`; - Pre-training supports checkpoint at any step, and training can be continued from the breakpoint; - Streaming loading of large datasets (GB level), supporting buffer data shuffling, does not use memory or hard disk as cache, effectively reducing memory and disk usage. configuring `batch_size=1, max_len=320`, supporting pre-training on a machine with at least 16GB RAM + 4GB GPU memory; @@ -27,6 +31,11 @@ ChatLM-mini-Chinese is a small Chinese chat model with only 0.2B (added shared w 🟢**Latest Update** + 2024-01-07 +- Added MinHash-based document deduplication to the data cleaning process (in this project it effectively deduplicates dataset samples), preventing the model from regurgitating training data at inference time after encountering the same data many times.
+- Added the `DropDatasetDuplicate` class to deduplicate documents in large datasets (a usage sketch is shown below).
+ +
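A minimal usage sketch of the new class (the class is added to `utils/functions.py` in this update; the sample sentences below are made up purely for illustration):

```python
from utils.functions import DropDatasetDuplicate

docs = [
    "How do I learn Python?",
    "What is the capital of France?",
    "How do I learn Python??",   # near-duplicate of the first entry
]

deduper = DropDatasetDuplicate(threshold=0.85, num_perm=256)

# feed every document together with its index
for i, doc in enumerate(docs):
    deduper.add_doc(index=i, doc=doc)

# indexes of entries that duplicate an earlier document
drop_ids = deduper.get_duplicate_indexs()
clean_docs = [d for i, d in enumerate(docs) if i not in drop_ids]
```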
2023-12-29 - Updated the model code (weights are NOT changed); you can directly use `AutoModelForSeq2SeqLM.from_pretrained(...)` to load the model.
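For example, a minimal loading sketch (the checkpoint name `charent/ChatLM-mini-Chinese` is assumed to be the published Hugging Face repository of this project; a local checkpoint directory works the same way):

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model_id = 'charent/ChatLM-mini-Chinese'  # assumed repo id, or a local path
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(model_id)

inputs = tokenizer('你好,最近过得怎么样?', return_tensors='pt')
outputs = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```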
@@ -160,6 +169,8 @@ Apple是一家专注于设计和用户体验的公司,其产品在设计上注 ## 3.2 from clone code repository start +**Important:** the model of this project is a `TextToText` model. In the `prompt`, `response` and other fields of the pre-training, SFT and RLHF stages, be sure to append the `[EOS]` end-of-sentence mark. + ### 3.2.1 Clone repository ```bash git clone --depth 1 https://github.com/charent/ChatLM-mini-Chinese.git diff --git a/README.md b/README.md index 7036ed9f1e99591d50ca7a5348e1178005fe1fda..941e12ff18ec133a260742bce1672108cf1334a5 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,13 @@ -# 中文对话0.2B小模型 ChatLM-Chinese-0.2B +
+ +# 中文对话0.2B小模型 ChatLM-Chinese-0.2B + +中文 | [English](./README.en.md) + +
+ # 一、👋介绍 -*Read this in [English](README.en.md).* 现在的大语言模型的参数往往较大,消费级电脑单纯做推理都比较慢,更别说想自己从头开始训练一个模型了。本项目的目标是整理生成式语言模型的训练流程,包括数据清洗、tokenizer训练、模型预训练、SFT指令微调、RLHF优化等。 ChatLM-mini-Chinese为中文对话小模型,模型参数只有0.2B(算共享权重约210M),可以在最低4GB显存的机器进行预训练(`batch_size=1`,`fp16`或者` bf16`),`float16`加载、推理最少只需要512MB显存。 @@ -11,7 +17,7 @@ ChatLM-mini-Chinese为中文对话小模型,模型参数只有0.2B(算共享 - 使用`Huggingface`NLP框架,包括`transformers`、`accelerate`、`trl`、`peft`等。 - 自实现`trainer`,支持单机单卡、单机多卡进行预训练、SFT微调。训练过程中支持在任意位置停止,及在任意位置继续训练。 - 预训练:整合为端到端的`Text-to-Text`预训练,非`mask`掩码预测预训练。 - - 开源所有数据清洗、数据集构造、数据集加载优化等流程; + - 开源所有数据清洗(如规范化、基于MinHash的文档去重等)、数据集构造、数据集加载优化等流程; - tokenizer多进程词频统计,支持`sentencepiece`、`huggingface tokenizers`的tokenizer训练; - 预训练支持任意位置断点,可从断点处继续训练; - 大数据集(GB级别)流式加载、支持缓冲区数据打乱,不利用内存、硬盘作为缓存,有效减少内存、磁盘占用。配置`batch_size=1, max_len=320`下,最低支持在16GB内存+4GB显存的机器上进行预训练; @@ -27,6 +33,12 @@ ChatLM-mini-Chinese为中文对话小模型,模型参数只有0.2B(算共享 🟢**最近更新** +
+ 2024-01-07 +- 添加数据清洗过程中基于MinHash实现的文档去重(在本项目中其实是对数据集样本的去重),防止模型多次遇到重复数据后,在推理时吐出训练数据,用法示例见下。
+- 添加`DropDatasetDuplicate`类实现对大数据集的文档去重。
+
+
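一个最小的使用示意(假设已在项目根目录下合并得到 `data/my_dataset.parquet`;`remove_dataset_duplicate_rows` 为本次新增的去重入口,内部基于 `DropDatasetDuplicate` 实现):

```python
from utils.raw_data_process import remove_dataset_duplicate_rows

# 读取 data/my_dataset.parquet,对 prompt + response 拼接后的文本做 MinHash 文档去重,
# 去重后的数据写入新的 parquet 文件
remove_dataset_duplicate_rows(groups_cnt=50000)
```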
2023-12-29 - 更新模型代码(权重不变),可以直接使用`AutoModelForSeq2SeqLM.from_pretrained(...)`加载模型使用。
@@ -165,6 +177,8 @@ Apple是一家专注于设计和用户体验的公司,其产品在设计上注 ## 3.2 从克隆仓库代码开始 +**注意**:本项目模型为`TextToText`模型,在预训练阶段、SFT阶段、RLHF阶段的`prompt`、`response`等字段,请务必加上`[EOS]`句子结束标记。 + ### 3.2.1 克隆项目: ```bash git clone --depth 1 https://github.com/charent/ChatLM-mini-Chinese.git diff --git a/model/chat_model.py b/model/chat_model.py index 4c38e2eafee41b84f39482d7bc005aef0381637e..62bdc8beaa709b82a0aaab66a103ac72ecea7600 100644 --- a/model/chat_model.py +++ b/model/chat_model.py @@ -57,7 +57,7 @@ class TextToTextModel(T5ForConditionalGeneration): generation_config.num_beams = 1 generation_config.do_sample = True generation_config.top_k = 50 - generation_config.temperature = 0.98 # 越低概率越趋向于均匀分布 + generation_config.temperature = 0.98 # 温度越低,分布越尖锐(高概率token优势越大);越高(>1),越趋向于均匀分布 generation_config.top_p = 0.80 generation_config.no_repeat_ngram_size = 4 elif search_type == 'contrastive': diff --git a/requirements.txt b/requirements.txt index d0eb3cfc711e4e3530b779f157932fc5b3e1ef13..dc1c6bd2ad03c47f6cf8fc032ad7d3093891d55c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ accelerate==0.25.0 colorlog==6.8.0 datasets==2.15.0 +datasketch==1.6.4 fastapi==0.105.0 fastparquet==2023.10.1 fire==0.5.0 diff --git a/utils/functions.py b/utils/functions.py index 9ad4a9faaccc549749206f45b98b6270d938dcf1..5b84f30ea87e57e045f830964bdeedcda28110fc 100644 --- a/utils/functions.py +++ b/utils/functions.py @@ -5,6 +5,11 @@ from transformers import T5Config import ctypes import os import platform +import re + +# from nltk import ngrams +from datasketch import MinHash, MinHashLSH +from collections import defaultdict from nltk.translate.bleu_score import sentence_bleu import numpy as np @@ -16,6 +21,64 @@ from config import T5ModelConfig END_PUN = set(".。!!))》}】??\"”") +# 只保留中文、英文、数字和下划线,去掉标点符号 +NON_CHAR = re.compile("[^\u4E00-\u9FA5A-Za-z0-9_]") + +def _get_doc_mini_hash(doc: list[str] | str, num_perm: int) -> MinHash: + ''' + 获取一段文本的MinHash + ''' + mini_hash = MinHash(num_perm=num_perm) + for s in doc: + mini_hash.update(s.encode('utf-8')) + return mini_hash + +class DropDatasetDuplicate: + + def __init__(self, threshold: float=0.85, num_perm: int=256) -> None: + ''' + 获取一个数据集中所有重复(相似度超过threshold)的index,输入为:list[str],一个str元素为一段文本(doc) + 如输入: [a, b, c, d, c, d, e] 返回:{4, 5} (后面两个 c, d 的index) + ''' + self.similar_index_cluster = defaultdict(set) + self.data_lsh = MinHashLSH(threshold=threshold, num_perm=num_perm) + self.num_perm = num_perm + + def add_doc(self, index: object, doc: str) -> None: + ''' + 添加文档 + index: 文档的索引 + doc: 文档本身 + ''' + + # 只保留中文、英文、数字和下划线,去掉标点符号 + doc = ''.join(NON_CHAR.split(doc)) + # doc = [''.join(t) for t in list(ngrams(doc, 3))] + + doc_hash = _get_doc_mini_hash(doc, self.num_perm) + close_duplicates = self.data_lsh.query(doc_hash) + + self.data_lsh.insert(index, doc_hash) + + # 所有相似的doc在similar_index_cluster中的key都是最早出现的idx + # 如:data中索引 index 2, 7, 8, 9, 10, 12 是相似的,则在similar_index_cluster中表现为 {2: {7, 8, 9, 10, 12}} + if len(close_duplicates) > 0: + min_idx = min(close_duplicates) + self.similar_index_cluster[min_idx].add(index) + + def get_duplicate_indexs(self) -> set: + ''' + 返回所有的重复文档索引 + ''' + similar_index_cluster = self.similar_index_cluster + need_to_remove_idx = set() + + for key_idx in similar_index_cluster.keys(): + need_to_remove_idx |= similar_index_cluster[key_idx] + + return need_to_remove_idx + + def get_T5_config(config: T5ModelConfig, 
vocab_size: int, decoder_start_token_id: int=0, eos_token_id: int=1) -> T5Config: ''' 用户配置转换为T5Config diff --git a/utils/raw_data_process.py b/utils/raw_data_process.py index d27fa52a206efeb555dc30095cead6acabad1cd8..3aaeb21bafe55206f24abe16de6edee6411ba722 100644 --- a/utils/raw_data_process.py +++ b/utils/raw_data_process.py @@ -21,7 +21,7 @@ sys.path.extend(['.','..']) from logger import Logger from config import PROJECT_ROOT -from utils.functions import get_path_of_suffix_files +from utils.functions import get_path_of_suffix_files, DropDatasetDuplicate log = Logger('data_process', save2file=True, file_name=PROJECT_ROOT + '/logs/raw_data_process.log') @@ -765,6 +765,60 @@ def merge_dataset_as_single_file(groups_cnt: int=50000, max_len: int=512, min_le log.info("merge into file: {}, 全部数据共{}行,清洗后剩余{}行".format(save_file, all_cnt, keep_cnt), save_to_file=True) + +def remove_dataset_duplicate_rows(groups_cnt: int=50000) -> None: + ''' + 使用MinHash删除数据集中重复的部分 + ''' + from_parquet_files = PROJECT_ROOT + '/data/my_dataset.parquet' + + save_file = PROJECT_ROOT + '/data/my_dataset_no_duplicates.parquet' + + # 后续append写入,存在文件先删除 + if exists(save_file): + assert delete_file(save_file) + + cur_rows = [] + all_cnt, keep_cnt = 0, 0 + row_index = -1 + drop_dataset_duplicate = DropDatasetDuplicate(threshold=0.85, num_perm=256) + + parquet_table = pq.read_table(from_parquet_files) + all_cnt = parquet_table.num_rows + + # 先顺序遍历获取哪些行是重复的 + for prompt, response in progress.track(zip(parquet_table['prompt'], parquet_table['response']), total=parquet_table.num_rows): + row_index += 1 + + doc = f"{prompt.as_py()}{response.as_py()}" + drop_dataset_duplicate.add_doc(index=row_index, doc=doc) + + row_index = -1 + need_to_drop_indexs = drop_dataset_duplicate.get_duplicate_indexs() + + # 再顺序遍历一遍,重复的行不添加到新的数据集 + for prompt, response in progress.track(zip(parquet_table['prompt'], parquet_table['response']), total=parquet_table.num_rows): + row_index += 1 # 不管有没有跳过行, row_index都必须+1 + + # 重复的行跳过 + if row_index in need_to_drop_indexs: + continue + + cur_rows.append({'prompt': prompt.as_py(), 'response': response.as_py()}) + keep_cnt += 1 + + if len(cur_rows) >= groups_cnt: + df = pd.DataFrame(cur_rows) + write_single_parquet_file(save_file, df) + cur_rows = [] + + # 处理末尾部分 + if len(cur_rows) > 0: + df = pd.DataFrame(cur_rows) + write_single_parquet_file(save_file, df) + + log.info("save to file: {}, 全部数据共{}行,文档去重后剩余{}行".format(save_file, all_cnt, keep_cnt), save_to_file=True) + def shuffle_parquet_dataset(parquet_file: str, shuffle_file: str, seed: int=23333, groups_cnt: int=65536) -> None: ''' 打乱一个parquet文件数据集 @@ -1128,6 +1182,9 @@ if __name__ == '__main__': # merge # merge_dataset_as_single_file(groups_cnt=50000, min_len=3, max_len=512, cut_max_len=True) + + + remove_dataset_duplicate_rows(groups_cnt=50000) # # shuffle # shuffle_parquet_dataset(