This document explains how to build the GPT-NeoX model using TensorRT-LLM and run it on a single GPU, as well as on a single node with multiple GPUs.

The TensorRT-LLM GPT-NeoX implementation can be found in `tensorrt_llm/models/gptneox/model.py`. The TensorRT-LLM GPT-NeoX example code is located in `examples/gptneox`. There is one main file:

- `convert_checkpoint.py` to convert a checkpoint from the HuggingFace (HF) Transformers format to the TensorRT-LLM format.

In addition, there are two shared files in the parent folder `examples` for inference and evaluation:

- `../run.py` to run inference on an input text;
- `../summarize.py` to summarize the articles in the cnn_dailymail dataset.

The example takes HF weights as input and builds the corresponding TensorRT engines. The number of TensorRT engines depends on the number of GPUs used to run inference.
Please install the required packages first:

```bash
pip install -r requirements.txt
```
```bash
# Weights & config
git clone https://huggingface.co/EleutherAI/gpt-neox-20b gptneox_model
```
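Note that the weight files in the `EleutherAI/gpt-neox-20b` repository are stored with Git LFS. If the clone only fetches small pointer files instead of the actual weights, you may need to install and enable Git LFS first (a hedged note; the exact steps depend on your environment):

```bash
# Git LFS is needed to download the real weight files rather than pointer stubs
git lfs install
cd gptneox_model && git lfs pull && cd ..
```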
If you want to use INT8 weight-only quantization, you only need to add the `--use_weight_only` flag.
Convert the HF checkpoint to the TensorRT-LLM format:

```bash
# Single GPU
python3 convert_checkpoint.py --model_dir ./gptneox_model \
        --dtype float16 \
        --output_dir ./gptneox/20B/trt_ckpt/fp16/1-gpu/

# With 2-way Tensor Parallel
python3 convert_checkpoint.py --model_dir ./gptneox_model \
        --dtype float16 \
        --tp_size 2 \
        --workers 2 \
        --output_dir ./gptneox/20B/trt_ckpt/fp16/2-gpu/

# Single GPU with int8 weight only
python3 convert_checkpoint.py --model_dir ./gptneox_model \
        --dtype float16 \
        --use_weight_only \
        --output_dir ./gptneox/20B/trt_ckpt/int8_wo/1-gpu/

# With 2-way Tensor Parallel with int8 weight only
python3 convert_checkpoint.py --model_dir ./gptneox_model \
        --dtype float16 \
        --use_weight_only \
        --tp_size 2 \
        --workers 2 \
        --output_dir ./gptneox/20B/trt_ckpt/int8_wo/2-gpu/
```
Then build TensorRT engines with `trtllm-build`:

```bash
# Single GPU
trtllm-build --checkpoint_dir ./gptneox/20B/trt_ckpt/fp16/1-gpu/ \
        --gemm_plugin float16 \
        --max_batch_size 8 \
        --max_input_len 924 \
        --max_output_len 100 \
        --output_dir ./gptneox/20B/trt_engines/fp16/1-gpu/

# With 2-way Tensor Parallel
trtllm-build --checkpoint_dir ./gptneox/20B/trt_ckpt/fp16/2-gpu/ \
        --gemm_plugin float16 \
        --max_batch_size 8 \
        --max_input_len 924 \
        --max_output_len 100 \
        --workers 2 \
        --output_dir ./gptneox/20B/trt_engines/fp16/2-gpu/

# Single GPU with int8 weight only
trtllm-build --checkpoint_dir ./gptneox/20B/trt_ckpt/int8_wo/1-gpu/ \
        --gemm_plugin float16 \
        --max_batch_size 8 \
        --max_input_len 924 \
        --max_output_len 100 \
        --output_dir ./gptneox/20B/trt_engines/int8_wo/1-gpu/

# With 2-way Tensor Parallel with int8 weight only
trtllm-build --checkpoint_dir ./gptneox/20B/trt_ckpt/int8_wo/2-gpu/ \
        --gemm_plugin float16 \
        --max_batch_size 8 \
        --max_input_len 924 \
        --max_output_len 100 \
        --workers 2 \
        --output_dir ./gptneox/20B/trt_engines/int8_wo/2-gpu/
```
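Once the engines are built, they can be smoke-tested with the shared `../run.py` script mentioned above. A minimal sketch for the single-GPU fp16 engine (the flag names follow the shared example scripts in this repository; the prompt is arbitrary):

```bash
python3 ../run.py --engine_dir ./gptneox/20B/trt_engines/fp16/1-gpu/ \
        --tokenizer_dir gptneox_model \
        --input_text "GPT-NeoX-20B is a 20 billion parameter" \
        --max_output_len 50
```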
The following section describes how to run a TensorRT-LLM GPT-NeoX model to summarize the articles from the
cnn_dailymail dataset. For each summary, the script can compute the
ROUGE scores and use the ROUGE-1
score to validate the implementation.
The script can also perform the same summarization using the HF GPT-NeoX model.
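For reference, ROUGE-1 measures unigram overlap between a generated summary and a reference summary. A minimal sketch of how such a score can be computed with the `rouge_score` package (an assumption for illustration; `../summarize.py` may use a different ROUGE implementation internally):

```python
# pip install rouge-score
from rouge_score import rouge_scorer

# ROUGE-1 compares unigram overlap between a candidate and a reference summary
scorer = rouge_scorer.RougeScorer(["rouge1"], use_stemmer=True)
reference = "The cat sat on the mat."
candidate = "A cat was sitting on the mat."
scores = scorer.score(reference, candidate)
print(scores["rouge1"].fmeasure)  # F1 of unigram overlap, in [0, 1]
```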
Run the summarization:

```bash
# Single GPU
python3 ../summarize.py --engine_dir ./gptneox/20B/trt_engines/fp16/1-gpu/ \
        --test_trt_llm \
        --hf_model_dir gptneox_model \
        --data_type fp16

# With 2-way Tensor Parallel
mpirun -np 2 --oversubscribe --allow-run-as-root \
    python3 ../summarize.py --engine_dir ./gptneox/20B/trt_engines/fp16/2-gpu/ \
        --test_trt_llm \
        --hf_model_dir gptneox_model \
        --data_type fp16

# Single GPU with int8 weight only
python3 ../summarize.py --engine_dir ./gptneox/20B/trt_engines/int8_wo/1-gpu/ \
        --test_trt_llm \
        --hf_model_dir gptneox_model \
        --data_type fp16

# With 2-way Tensor Parallel with int8 weight only
mpirun -np 2 --oversubscribe --allow-run-as-root \
    python3 ../summarize.py --engine_dir ./gptneox/20B/trt_engines/int8_wo/2-gpu/ \
        --test_trt_llm \
        --hf_model_dir gptneox_model \
        --data_type fp16
```
The following section shows how to run GPT-NeoX with weights quantized by GPTQ. First, download the weights and config:

```bash
# Weights & config
sh get_weights.sh
```
In this example, the weights are quantized using GPTQ-for-LLaMa. Note that the `--act-order` parameter, which controls whether to apply the activation-order GPTQ heuristic, is not supported by TensorRT-LLM.
```bash
sh gptq_convert.sh
```
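After quantization finishes, you can optionally sanity-check the resulting safetensors file before conversion. A minimal sketch using the `safetensors` library (the file name matches the conversion commands below; the exact tensor names depend on the quantization script):

```python
# Inspect the GPTQ-quantized checkpoint produced by gptq_convert.sh
from safetensors import safe_open

path = "./gptneox_model/gptneox-20b-4bit-gs128.safetensors"
with safe_open(path, framework="pt") as f:
    for name in list(f.keys())[:8]:  # print the first few tensors
        tensor = f.get_tensor(name)
        print(f"{name}: shape={tuple(tensor.shape)}, dtype={tensor.dtype}")
```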
To apply GPTQ groupwise quantization, additional command-line flags need to be passed to `convert_checkpoint.py`. The `--ammo_quant_ckpt_path` flag specifies the safetensors file produced by the `gptq_convert.sh` script:
```bash
# Single GPU
python3 convert_checkpoint.py --model_dir ./gptneox_model \
        --dtype float16 \
        --use_weight_only \
        --weight_only_precision int4_gptq \
        --ammo_quant_ckpt_path ./gptneox_model/gptneox-20b-4bit-gs128.safetensors \
        --output_dir ./gptneox/20B/trt_ckpt/int4_gptq/1-gpu/

# With 2-way Tensor Parallel
python3 convert_checkpoint.py --model_dir ./gptneox_model \
        --dtype float16 \
        --use_weight_only \
        --weight_only_precision int4_gptq \
        --tp_size 2 \
        --workers 2 \
        --ammo_quant_ckpt_path ./gptneox_model/gptneox-20b-4bit-gs128.safetensors \
        --output_dir ./gptneox/20B/trt_ckpt/int4_gptq/2-gpu/
```
The command to build TensorRT engines from the GPTQ-quantized checkpoint does not change:
```bash
# Single GPU
trtllm-build --checkpoint_dir ./gptneox/20B/trt_ckpt/int4_gptq/1-gpu/ \
        --gemm_plugin float16 \
        --max_batch_size 8 \
        --max_input_len 924 \
        --max_output_len 100 \
        --output_dir ./gptneox/20B/trt_engines/int4_gptq/1-gpu/

# With 2-way Tensor Parallel
trtllm-build --checkpoint_dir ./gptneox/20B/trt_ckpt/int4_gptq/2-gpu/ \
        --gemm_plugin float16 \
        --max_batch_size 8 \
        --max_input_len 924 \
        --max_output_len 100 \
        --workers 2 \
        --output_dir ./gptneox/20B/trt_engines/int4_gptq/2-gpu/
```
The command to run summarization with the GPTQ-quantized model also does not change:
```bash
# Single GPU
python3 ../summarize.py --engine_dir ./gptneox/20B/trt_engines/int4_gptq/1-gpu/ \
        --test_trt_llm \
        --hf_model_dir gptneox_model \
        --data_type fp16

# With 2-way Tensor Parallel
mpirun -np 2 --oversubscribe --allow-run-as-root \
    python3 ../summarize.py --engine_dir ./gptneox/20B/trt_engines/int4_gptq/2-gpu/ \
        --test_trt_llm \
        --hf_model_dir gptneox_model \
        --data_type fp16
```