diff --git a/python/level1_single_api/9_amct/amct_pytorch/mxfp4_quantization/README_CN.md b/python/level1_single_api/9_amct/amct_pytorch/mxfp4_quantization/README_CN.md
new file mode 100644
index 0000000000000000000000000000000000000000..b9302933f493bf5b061a68c16b35c316c3561504
--- /dev/null
+++ b/python/level1_single_api/9_amct/amct_pytorch/mxfp4_quantization/README_CN.md
@@ -0,0 +1,49 @@
+# MXFP4 Quantization
+
+## 1 Prerequisites for MXFP4 Quantization
+
+### 1.1 Install Dependencies
+
+The packages this sample depends on are listed in [requirement.txt](requirement.txt).
+
+### 1.2 Prepare the Model and Dataset
+
+This sample uses the Llama2-7b model and the c4/realnewslike dataset as an example; please download them yourself.
+
+### 1.3 Simple Quantization Configuration
+
+The file ./quant_conf/quant.cfg is a user-defined simple quantization configuration. Its fields are described below:
+
+| Field | Type | Description | Default | Valid values | Notes |
+|:--| :-: | :-- | :-: | :-: | :-: |
+|batch_num|uint32|Number of batches used for quantization|1|/|/|
+|weight_only_config.weight_compress_only|bool|Whether to perform weight-only quantization|/|True/False|MXFP4 quantization currently supports weight-only quantization, so this must be set to True|
+|weight_only_config.wts_type|enum|Weight data type after quantization|MXFP4_E2M1|MXFP4_E2M1/HIFLOAT8/FLOAT8_E4M3FN|/|
+|weight_only_config.weight_granularity|enum|Weight quantization granularity|/|PER_TENSOR/PER_CHANNEL/PER_GROUP|MXFP4_E2M1 only supports PER_GROUP|
+|weight_only_config.round_mode|enum|Rounding mode|/|HYBRID/ROUND/RINT|MXFP4_E2M1 only supports RINT|
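+
+For reference, a minimal quant.cfg for MXFP4 weight-only quantization might look like the sketch below. This is only a sketch: the field names follow the table above and the prototxt written by src/run_llama7b_quantization.py, and the values are illustrative; adjust them to your model.
+
+```
+batch_num: 1
+weight_only_config: {
+    weight_compress_only: True
+    wts_type: MXFP4_E2M1
+    weight_granularity: PER_GROUP
+    round_mode: RINT
+}
+```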
+
+## 2 MXFP4 Quantization Example
+
+### 2.1 Calling via the API
+
+**step 1.** Run the following command in the current directory to obtain the quantization factor record file. Adjust the model and dataset paths in the sample program to your environment:
+
+`CUDA_VISIBLE_DEVICES=0,1 python3 src/run_llama7b_quantization.py`
+
+After inference succeeds, the quantization log file ./amct_log/amct_pytorch.log and the ./outputs folder are generated in the current directory. The folder contains:
+
+- config.json: the quantization configuration file, describing how each layer of the model is quantized.
+- record.txt: the quantization factor record file.
+
+**step 2.** Refer to the read_kv_cache_factors function in src/reference_function.py to read the quantization factors from the record file, and to the do_quant and do_antiquant functions for quantization and dequantization (an illustrative sketch of this per-group quantize/dequantize arithmetic follows the note below). After modifying the model accordingly, you can run model inference.
+
+> If a quantization configuration file or quantization factor record file already exists in the outputs directory, rerunning the sample program will overwrite any existing file that has the same name as a newly generated one.
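+
+The sketch below illustrates the arithmetic behind MXFP4_E2M1 PER_GROUP quantization and dequantization. It is an illustration only, assuming the usual MX layout of one shared power-of-two scale per group of 32 weights; the function name mxfp4_fake_quant is made up here, and the actual do_quant/do_antiquant implementations in src/reference_function.py and the factors stored in record.txt may differ (group size, scale encoding, tie-breaking of RINT).
+
+```python
+import torch
+
+# Magnitudes representable by FP4 E2M1 (the sign is handled separately).
+E2M1_GRID = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])
+GROUP_SIZE = 32   # assumed MX group size: one shared scale per 32 elements
+E2M1_EMAX = 2     # largest E2M1 exponent (6.0 = 1.5 * 2**2)
+
+
+def mxfp4_fake_quant(weight: torch.Tensor) -> torch.Tensor:
+    """Quantize a tensor to MXFP4 per group, then dequantize it again."""
+    # For simplicity this sketch assumes weight.numel() is a multiple of GROUP_SIZE.
+    orig_shape = weight.shape
+    groups = weight.reshape(-1, GROUP_SIZE)
+
+    # Shared power-of-two scale per group (stored as E8M0 in the real format).
+    max_abs = groups.abs().amax(dim=-1, keepdim=True).clamp_min(1e-12)
+    shared_exp = torch.floor(torch.log2(max_abs)) - E2M1_EMAX
+    scale = torch.pow(2.0, shared_exp)
+
+    # Round every scaled element to the nearest representable E2M1 value
+    # (nearest-value rounding; real RINT additionally breaks ties to even).
+    scaled = groups / scale
+    sign = torch.sign(scaled)
+    idx = torch.argmin((scaled.abs().unsqueeze(-1) - E2M1_GRID).abs(), dim=-1)
+    quantized = sign * E2M1_GRID[idx]
+
+    # Dequantize ("antiquant"): multiply back by the shared scale.
+    return (quantized * scale).reshape(orig_shape)
+
+
+if __name__ == '__main__':
+    w = torch.randn(4, 128)
+    w_dq = mxfp4_fake_quant(w)
+    print('max abs error:', (w - w_dq).abs().max().item())
+```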
+
+### 2.2 Calling via Single Operators
+
+If you need to support more operator types, or you have defined custom operations, you can build the graph with single operators, run quantization calibration, and output the quantization factor record file.
+
+**step 1.** Modify the model by referring to src/quant_calibration_op_demo.py, then run model inference to obtain the quantization factor record file.
+
+**step 2.** Modify the model as described in step 2 of section 2.1, then run model inference. When performing this step, comment out the code added in step 1 so that the record file is not overwritten.
+
diff --git a/python/level1_single_api/9_amct/amct_pytorch/mxfp4_quantization/requirement.txt b/python/level1_single_api/9_amct/amct_pytorch/mxfp4_quantization/requirement.txt
new file mode 100644
index 0000000000000000000000000000000000000000..643825ed941e406e83807c292bc2fa7e1ce463ac
--- /dev/null
+++ b/python/level1_single_api/9_amct/amct_pytorch/mxfp4_quantization/requirement.txt
@@ -0,0 +1,8 @@
+torch==2.1.0
+transformers==4.40.0
+accelerate==0.30.1
+datasets==2.19.1
+sentencepiece==0.2.0
+onnx==1.10.0
+numpy==1.23.5
+protobuf==3.13.0
\ No newline at end of file
diff --git a/python/level1_single_api/9_amct/amct_pytorch/mxfp4_quantization/src/run_llama7b_infer.py b/python/level1_single_api/9_amct/amct_pytorch/mxfp4_quantization/src/run_llama7b_infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..dac80a7185e18dd8ec9135a1cb72ce3b36e530cf
--- /dev/null
+++ b/python/level1_single_api/9_amct/amct_pytorch/mxfp4_quantization/src/run_llama7b_infer.py
@@ -0,0 +1,119 @@
+import os
+import time
+import argparse
+import tqdm
+import torch
+import torch.nn as nn
+from transformers import AutoTokenizer, AutoConfig
+from accelerate import infer_auto_device_map, dispatch_model
+from accelerate.utils.modeling import get_balanced_memory
+
+from data_utils import get_loaders
+from model_utils import get_llama2
+import amct_pytorch as amct
+
+
+def build_model_and_enc(model, model_path, gpu_num, map_type='manual'):
+    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
+    if "mpt" in config.__class__.__name__.lower():
+        enc = AutoTokenizer.from_pretrained(
+            config.tokenizer_name, trust_remote_code=True
+        )
+    else:
+        enc = AutoTokenizer.from_pretrained(
+            model_path, use_fast=False, trust_remote_code=True
+        )
+
+    # Move the model to GPU (as much as possible) for LM evaluation.
+    # max_memory = ['0:16GiB', '1:16GiB', '2:16GiB', 'cpu:30GiB']; '0' means the first GPU that you specify.
+    # Using the full 16GiB is not recommended: reserve some space for other tensors created during computation.
+    # Please see the recommended memory allocation in the Word file.
+    # Adjust max_memory according to the real situation.
+    if map_type == 'manual':
+        max_memory = []
+        for i in range(gpu_num):
+            max_memory.append(f'{i}:12GiB')
+        max_memory.append('cpu:80GiB')
+        print('Max_memory allocation: \n', max_memory)
+
+        max_memory = [v.split(":") for v in (max_memory or [])]
+        max_memory = {(int(k) if k.isdigit() else k): v for k, v in max_memory}
+    kwargs = {
+        "max_memory": get_balanced_memory(
+            model, max_memory if len(max_memory) > 0 else None
+        )
+    }
+    model.tie_weights()
+    device_map = infer_auto_device_map(
+        model,
+        # TODO: can we remove this?
+        no_split_module_classes=[
+            "OPTDecoderLayer",
+            "LlamaDecoderLayer",
+            "BloomBlock",
+            "MPTBlock",
+            "DecoderLayer",
+        ],
+        **kwargs,
+    )
+    model = dispatch_model(model, device_map=device_map,
+                           offload_dir=os.path.join(model_path, 'offload_dir'))
+
+    return model, enc
+
+
+if __name__ == '__main__':
+    model, model_path = get_llama2('7b')
+    model = model.eval()
+    gpus = os.getenv('CUDA_VISIBLE_DEVICES')
+    if gpus == '' or gpus == '-1' or gpus is None:
+        gpu_num = 0
+    else:
+        gpu_num = len(gpus.split(','))
+    model, enc = build_model_and_enc(model, model_path, gpu_num)
+    # Load the evaluation dataset.
+    testenc = get_loaders(dataset_name='wikitext2',
+                          enc=enc,
+                          seqlen=model.seqlen)
+
+    testenc = testenc.input_ids.to(model.device)
+
+    CUR_DIR = os.path.split(os.path.realpath(__file__))[0]
+    temp_dir = os.path.join(CUR_DIR, 'output')
+    os.makedirs(temp_dir, exist_ok=True)
+    proto_path = os.path.join(temp_dir, 'temp_proto.cfg')
+    config_file = os.path.join(temp_dir, 'config.json')
+    record_file = os.path.join(temp_dir, 'record.txt')
+
+    # Build a fake-quantized model from the recorded quantization factors.
+    fake_quant_model = amct.save_post_quant_model(record_file, model, mode='fakequant')
+    nsamples = testenc.numel() // model.seqlen
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+    # Start inference and compute perplexity.
+    nlls = []
+    test_start_time = time.time()
+    for i in tqdm.tqdm(range(nsamples), desc="evaluating..."):
+        batch = testenc[:, (i * model.seqlen) : ((i + 1) * model.seqlen)].to(
+            model.device
+        )
+        with torch.no_grad():
+            lm_logits = fake_quant_model(batch).logits
+        shift_logits = lm_logits[:, :-1, :].contiguous().float()
+        shift_labels = testenc[:, (i * model.seqlen) : ((i + 1) * model.seqlen)][:, 1:]
+        loss_fct = nn.CrossEntropyLoss()
+        loss = loss_fct(
+            shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
+        )
+        neg_log_likelihood = loss.float() * model.seqlen
+        nlls.append(neg_log_likelihood)
+    test_end_time = time.time()
+
+    ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen))
+
+    total_time = test_end_time - test_start_time
+    print('Test time taken: ', total_time // 60, 'min ', total_time % 60, 's')
+    print('Score: ', ppl.item())
+
diff --git a/python/level1_single_api/9_amct/amct_pytorch/mxfp4_quantization/src/run_llama7b_quantization.py b/python/level1_single_api/9_amct/amct_pytorch/mxfp4_quantization/src/run_llama7b_quantization.py
new file mode 100644
index 0000000000000000000000000000000000000000..2801b36f4b20e845f65dca471e9612890c299f92
--- /dev/null
+++ b/python/level1_single_api/9_amct/amct_pytorch/mxfp4_quantization/src/run_llama7b_quantization.py
@@ -0,0 +1,118 @@
+import os
+import time
+import argparse
+import tqdm
+import torch
+import torch.nn as nn
+from transformers import AutoTokenizer, AutoConfig
+from accelerate import infer_auto_device_map, dispatch_model
+from accelerate.utils.modeling import get_balanced_memory
+
+from data_utils import get_loaders
+from model_utils import get_llama2
+import amct_pytorch as amct
+
+
+def build_model_and_enc(model, model_path, gpu_num, map_type='manual'):
+    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
+    if "mpt" in config.__class__.__name__.lower():
+        enc = AutoTokenizer.from_pretrained(
+            config.tokenizer_name, trust_remote_code=True
+        )
+    else:
+        enc = AutoTokenizer.from_pretrained(
+            model_path, use_fast=False, trust_remote_code=True
+        )
+
+    # Move the model to GPU (as much as possible) for LM evaluation.
+    # max_memory = ['0:16GiB', '1:16GiB', '2:16GiB', 'cpu:30GiB']; '0' means the first GPU that you specify.
+    # Using the full 16GiB is not recommended: reserve some space for other tensors created during computation.
+    # Please see the recommended memory allocation in the Word file.
+    # Adjust max_memory according to the real situation.
+    if map_type == 'manual':
+        max_memory = []
+        for i in range(gpu_num):
+            max_memory.append(f'{i}:12GiB')
+        max_memory.append('cpu:80GiB')
+        print('Max_memory allocation: \n', max_memory)
+
+        max_memory = [v.split(":") for v in (max_memory or [])]
+        max_memory = {(int(k) if k.isdigit() else k): v for k, v in max_memory}
+    kwargs = {
+        "max_memory": get_balanced_memory(
+            model, max_memory if len(max_memory) > 0 else None
+        )
+    }
+    model.tie_weights()
+    device_map = infer_auto_device_map(
+        model,
+        # TODO: can we remove this?
+        no_split_module_classes=[
+            "OPTDecoderLayer",
+            "LlamaDecoderLayer",
+            "BloomBlock",
+            "MPTBlock",
+            "DecoderLayer",
+        ],
+        **kwargs,
+    )
+    model = dispatch_model(model, device_map=device_map,
+                           offload_dir=os.path.join(model_path, 'offload_dir'))
+
+    return model, enc
+
+
+if __name__ == '__main__':
+    model, model_path = get_llama2('7b')
+    model = model.eval()
+    gpus = os.getenv('CUDA_VISIBLE_DEVICES')
+    if gpus == '' or gpus == '-1' or gpus is None:
+        gpu_num = 0
+    else:
+        gpu_num = len(gpus.split(','))
+    model, enc = build_model_and_enc(model, model_path, gpu_num)
+    # Load the calibration dataset.
+    testenc = get_loaders(dataset_name='wikitext2',
+                          enc=enc,
+                          seqlen=model.seqlen)
+
+    testenc = testenc.input_ids.to(model.device)
+
+    CUR_DIR = os.path.split(os.path.realpath(__file__))[0]
+    temp_dir = os.path.join(CUR_DIR, 'output')
+    os.makedirs(temp_dir, exist_ok=True)
+    proto_path = os.path.join(temp_dir, 'temp_proto.cfg')
+    config_file = os.path.join(temp_dir, 'config.json')
+    record_file = os.path.join(temp_dir, 'record.txt')
+    # Write a simple quantization configuration in prototxt form.
+    with open(proto_path, 'w') as f:
+        proto_txt = '''batch_num: 4
+        weight_only_config: {
+        weight_compress_only: True
+        wts_type: MXFP4_E2M1
+        awq_quantize: {
+        grids_num: 20
+        }
+        }'''
+        f.write(proto_txt)
+
+    test_start_time = time.time()
+    amct.create_post_quant_config(config_file,
+                                  model,
+                                  config_defination=proto_path)
+
+    post_quant_model = amct.create_post_quant_model(config_file,
+                                                    record_file,
+                                                    model)
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+    # Run inference on calibration data to collect the quantization factors.
+    batch_num = 3
+    model.config.use_cache = False
+    for i in tqdm.tqdm(range(1), desc="getting quantize factors..."):
+        batch = testenc[:, (i * post_quant_model.seqlen) : ((i + batch_num) * post_quant_model.seqlen)].to(post_quant_model.device)
+        with torch.no_grad():
+            post_quant_model(batch)
+    test_end_time = time.time()
+    total_time = test_end_time - test_start_time
+    print('Calibration time taken: ', total_time // 60, 'min ', total_time % 60, 's')
\ No newline at end of file