An LLM inference engine built on MNN, supporting mainstream open-source LLM models. The functionality is divided into two parts: model export (llmexport) and model inference.
llmexport
llmexport is a tool that exports LLM models to the ONNX and MNN formats.
Clone the LLM project you want to export to your local environment, for example, Qwen2-0.5B-Instruct:
git clone https://www.modelscope.cn/qwen/Qwen2-0.5B-Instruct.git
Run llmexport.py to export the model:
cd ./transformers/llm/export
# Export the model, tokenizer, embedding, and the corresponding MNN model
python llmexport.py \
--path /path/to/Qwen2-0.5B-Instruct \
--export mnn
Exported Artifacts
The exported files include:
config.json: Runtime configuration file; it can be modified manually.
embeddings_bf16.bin: Binary file containing the embedding weights, used during inference.
llm.mnn: The MNN model file, used during inference.
llm.mnn.json: JSON file corresponding to the MNN model, used for applying LoRA or GPTQ quantized weights.
llm.mnn.weight: MNN model weights, used during inference.
llm.onnx: ONNX model file without weights, not used during inference.
llm_config.json: Model configuration file, used during inference.
The directory structure is as follows:
.
└── model
    ├── config.json
    ├── embeddings_bf16.bin
    ├── llm.mnn
    ├── llm.mnn.json
    ├── llm.mnn.weight
    ├── onnx/
    │   ├── llm.onnx
    │   └── llm.onnx.data
    ├── llm_config.json
    └── tokenizer.txt
Direct Conversion to MNN Model
Use --export mnn to convert directly to an MNN model. Note that you must either install pymnn or specify the path to the MNNConvert tool with the --mnnconvert option; at least one of the two is required. If pymnn is not installed and no path is given via --mnnconvert, llmexport.py will look for the MNNConvert binary in the directory "../../../build/"; make sure MNNConvert exists there. This method currently supports exporting 4-bit and 8-bit models.
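For example, assuming MNN has been built under ../../../build/ (paths are illustrative), the exporter can be pointed at that MNNConvert binary explicitly:
# Illustrative: use a locally built MNNConvert instead of pymnn
python llmexport.py \
    --path /path/to/Qwen2-0.5B-Instruct \
    --export mnn \
    --mnnconvert ../../../build/MNNConvert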
If you run into problems converting directly to an MNN model, or you need other quantization bit depths (e.g., 5-bit or 6-bit), first export an ONNX model with --export onnx, then convert the ONNX model to an MNN model with the MNNConvert tool using the following command:
./MNNConvert --modelFile ../transformers/llm/export/model/onnx/llm.onnx --MNNModel llm.mnn --keepInputFormat --weightQuantBits=4 --weightQuantBlock=128 -f ONNX --transformerFuse=1 --allowCustomOp --saveExternalData
Other useful export options (see the example after the parameter listing below):
--test $query: test the exported model with a query and return the LLM's response.
--lora_path: directory containing LoRA weights to merge into the model before export.
--quant_bit and --quant_block: the quantization bit depth and the block size for quantization.
--lm_quant_bit: the quantization bit depth for the lm_head layer's weights. If not specified, the value of --quant_bit is used.
The full parameter list:
usage: llmexport.py [-h] --path PATH [--type TYPE] [--lora_path LORA_PATH] [--dst_path DST_PATH] [--test TEST] [--export EXPORT]
[--quant_bit QUANT_BIT] [--quant_block QUANT_BLOCK] [--lm_quant_bit LM_QUANT_BIT]
[--mnnconvert MNNCONVERT]
llm_exporter
options:
-h, --help show this help message and exit
--path PATH path(`str` or `os.PathLike`):
Can be either:
- A string, the *model id* of a pretrained model like `THUDM/chatglm-6b`. [TODO]
- A path to a *directory* clone from repo like `../chatglm-6b`.
--type TYPE type(`str`, *optional*):
The pretrain llm model type.
--lora_path LORA_PATH
lora path, defaut is `None` mean not apply lora.
--dst_path DST_PATH export onnx/mnn model to path, defaut is `./model`.
--test TEST test model inference with query `TEST`.
--export EXPORT export model to an onnx/mnn model.
--quant_bit QUANT_BIT
mnn quant bit, 4 or 8, default is 4.
--quant_block QUANT_BLOCK
mnn quant block, default is 0 mean channle-wise.
--lm_quant_bit LM_QUANT_BIT
mnn lm_head quant bit, 4 or 8, default is `quant_bit`.
--mnnconvert MNNCONVERT
local mnnconvert path, if invalid, using pymnn.
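As an illustration (the model path and values are placeholders), an export that combines several of these options might look like:
# Illustrative: 4-bit block-wise weight quantization (block size 128),
# an 8-bit lm_head, and a quick test query against the exported model
python llmexport.py \
    --path /path/to/Qwen2-0.5B-Instruct \
    --export mnn \
    --quant_bit 4 \
    --quant_block 128 \
    --lm_quant_bit 8 \
    --test "Hello"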
Compile from Source
During the standard compilation process, add the required compilation macros:
Base LLM support: -DMNN_LOW_MEMORY=true -DMNN_CPU_WEIGHT_DEQUANT_GEMM=true -DMNN_BUILD_LLM=true -DMNN_SUPPORT_TRANSFORMER_FUSE=true
Vision support: -DLLM_SUPPORT_VISION=true -DMNN_BUILD_OPENCV=true -DMNN_IMGCODECS=true
Audio support: -DLLM_SUPPORT_AUDIO=true -DMNN_BUILD_AUDIO=true
For macOS/Linux:
mkdir build
cd build
cmake ../ -DMNN_LOW_MEMORY=true -DMNN_CPU_WEIGHT_DEQUANT_GEMM=true -DMNN_BUILD_LLM=true -DMNN_SUPPORT_TRANSFORMER_FUSE=true
make -j16
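To also enable vision and audio support, append the corresponding macros to the same cmake invocation, for example:
cmake ../ -DMNN_LOW_MEMORY=true -DMNN_CPU_WEIGHT_DEQUANT_GEMM=true -DMNN_BUILD_LLM=true -DMNN_SUPPORT_TRANSFORMER_FUSE=true \
    -DLLM_SUPPORT_VISION=true -DMNN_BUILD_OPENCV=true -DMNN_IMGCODECS=true \
    -DLLM_SUPPORT_AUDIO=true -DMNN_BUILD_AUDIO=true
make -j16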
For the x86 architecture, additionally enable the MNN_AVX512 macro:
mkdir build
cd build
cmake ../ -DMNN_LOW_MEMORY=true -DMNN_CPU_WEIGHT_DEQUANT_GEMM=true -DMNN_BUILD_LLM=true -DMNN_SUPPORT_TRANSFORMER_FUSE=true -DMNN_AVX512=true
make -j16
For Android, additionally add the MNN_ARM82 and MNN_OPENCL macros:
cd project/android
mkdir build_64
../build_64.sh "-DMNN_LOW_MEMORY=true -DMNN_CPU_WEIGHT_DEQUANT_GEMM=true -DMNN_BUILD_LLM=true -DMNN_SUPPORT_TRANSFORMER_FUSE=true -DMNN_ARM82=true -DMNN_OPENCL=true -DMNN_USE_LOGCAT=true"
For iOS:
sh package_scripts/ios/buildiOS.sh "-DMNN_ARM82=true -DMNN_LOW_MEMORY=true -DMNN_SUPPORT_TRANSFORMER_FUSE=true -DMNN_BUILD_LLM=true -DMNN_CPU_WEIGHT_DEQUANT_GEMM=true"
For the Web platform, set up the environment as described at https://mnn-docs.readthedocs.io/en/latest/compile/engine.html#web, then build the static libraries libMNN.a, libMNN_Express.a, and libllm.a:
mkdir buildweb
emcmake cmake .. -DCMAKE_BUILD_TYPE=Release -DCMAKE_CXX_FLAGS="-msimd128 -msse4.1" -DMNN_FORBID_MULTI_THREAD=ON -DMNN_USE_THREAD_POOL=OFF -DMNN_USE_SSE=ON -DMNN_LOW_MEMORY=true -DMNN_CPU_WEIGHT_DEQUANT_GEMM=true -DMNN_BUILD_LLM=true -DMNN_SUPPORT_TRANSFORMER_FUSE=true
make -j16
Then link the demo with emcc:
emcc ../transformers/llm/engine/llm_demo.cpp -DCMAKE_CXX_FLAGS="-msimd128 -msse4.1" -I ../include -I ../transformers/llm/engine/include libMNN.a libllm.a express/libMNN_Express.a --preload-file ~/qwen2.0_1.5b/ -s ALLOW_MEMORY_GROWTH=1 -o llm_demo.js
To test the compiled demo, use the following command:
node llm_demo.js ~/qwen2.0_1.5b/config.json ~/qwen2.0_1.5b/prompt.txt
Place all the exported files required for model inference into the same folder, and add a config.json file describing the model name and inference parameters. The directory structure should look as follows:
.
└── model_dir
    ├── config.json
    ├── embeddings_bf16.bin
    ├── llm_config.json
    ├── llm.mnn
    ├── llm.mnn.weight
    └── tokenizer.txt
The configuration file supports the following options:
Model File Information
base_dir: Directory from which model files are loaded. Defaults to the directory of config.json, or the model directory.
llm_config: Path to llm_config.json, resolved as base_dir + llm_config. Defaults to base_dir + 'config.json'.
llm_model: Path to llm.mnn, resolved as base_dir + llm_model. Defaults to base_dir + 'llm.mnn'.
llm_weight: Path to llm.mnn.weight, resolved as base_dir + llm_weight. Defaults to base_dir + 'llm.mnn.weight'.
block_model: For segmented models, the path to block_{idx}.mnn, resolved as base_dir + block_model. Defaults to base_dir + 'block_{idx}.mnn'.
lm_model: For segmented models, the path to lm.mnn, resolved as base_dir + lm_model. Defaults to base_dir + 'lm.mnn'.
embedding_model: If the embedding uses a model, its path is base_dir + embedding_model. Defaults to base_dir + 'embedding.mnn'.
embedding_file: If the embedding uses a binary file, its path is base_dir + embedding_file. Defaults to base_dir + 'embeddings_bf16.bin'.
tokenizer_file: Path to tokenizer.txt, resolved as base_dir + tokenizer_file. Defaults to base_dir + 'tokenizer.txt'.
visual_model: If using a VL model, the path to the visual model is base_dir + visual_model. Defaults to base_dir + 'visual.mnn'.
Inference Configuration
max_new_tokens: Maximum number of tokens to generate. Defaults to 512.
reuse_kv: Whether to reuse the kv cache in multi-turn dialogues. Defaults to false.
quant_qkv: Whether to quantize query, key and value in the CPU attention operator; options are 0, 1, 2, 3, 4. Defaults to 0:
0: Neither key nor value is quantized.
1: Asymmetric 8-bit quantization is used to store key.
2: The fp8 format is used to quantize and store value.
3: Asymmetric 8-bit quantization for key and fp8 for value.
4: Quantize key and value while using asymmetric 8-bit quantization for query and int8 matrix multiplication for Q*K.
use_mmap: Whether to use mmap to write weights to disk when memory is insufficient, avoiding overflow. Defaults to false. For mobile devices, it is recommended to set this to true.
kvcache_mmap: Whether to use mmap for the KV cache to write it to disk when memory is insufficient, avoiding overflow. Defaults to false.
tmp_path: The directory used for writing to disk when mmap-related features are enabled. On iOS, for example:
NSString *tempDirectory = NSTemporaryDirectory();
llm->set_config("{\"tmp_path\":\"" + std::string([tempDirectory UTF8String]) + "\"}");
Hardware Configuration
backend_type: Hardware backend used for inference. Defaults to "cpu". "opencl" is supported for the Android GPU, and "metal" for the macOS/iOS GPU.
thread_num: Number of hardware threads used for CPU inference. Defaults to 4. For OpenCL inference, use 68.
precision: Precision strategy used for inference. Defaults to "low", preferring fp16.
memory: Memory strategy used for inference. Defaults to "low", enabling runtime quantization.
Sampler Configuration
sampler_type: The sampler to use. The 8 basic sampler types greedy, temperature, topK, topP, minP, tfs, typical and penalty are supported, plus mixed (when set to mixed, the samplers listed in mixed_samplers are executed one by one sequentially). Defaults to greedy, but mixed and temperature are suggested for output diversity, and penalty is suggested to avoid repetitive output.
mixed_samplers: Effective when sampler_type is mixed. Defaults to ["topK", "tfs", "typical", "topP", "min_p", "temperature"], which means the logits will be sampled by these strategies sequentially, one after another.
temperature: Temperature value used by the temperature, topP, minP, tfsZ and typical strategies. Defaults to 1.0.
topK: Number of top K candidates for the topK sampler type. Defaults to 40.
topP: Top P value for the topP sampler type. Defaults to 0.9.
minP: Min P value for the minP sampler type. Defaults to 0.1.
tfsZ: Z value for the tfs sampler type. Defaults to 1.0.
typical: p value for the typical sampler type. Defaults to 1.0.
penalty: Penalty applied to the logits for the penalty sampler type. Defaults to 0.0 (no penalty).
n_gram: Maximum n-gram size tracked by the penalty sampler type. Defaults to 8.
ngram_factor: Extra penalty for repeated n-grams in the penalty sampler type. Defaults to 1.0 (no extra penalty).
penalty_sampler: Sampling strategy used for the final step of the penalty sampler type; can be "greedy" or "temperature". Defaults to "greedy".
config.json
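A minimal config.json, using only keys and default values documented above, might look like the following sketch (written as a shell heredoc you can paste to create the file; adjust the values to your model and hardware):
# Illustrative config.json using documented keys and default values
cat > model_dir/config.json <<'EOF'
{
    "llm_model": "llm.mnn",
    "llm_weight": "llm.mnn.weight",
    "backend_type": "cpu",
    "thread_num": 4,
    "precision": "low",
    "memory": "low"
}
EOF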
llm_config.json
{
    "hidden_size": 1536,
    "layer_nums": 28,
    "attention_mask": "float",
    "key_value_shape": [
        2,
        1,
        0,
        2,
        128
    ],
    "prompt_template": "<|im_start|>user\n%s<|im_end|>\n<|im_start|>assistant\n",
    "is_visual": false,
    "is_single": true
}
The usage of llm_demo is as follows:
# Using config.json
## Interactive Chat
./llm_demo model_dir/config.json
## Replying to each line in the prompt
./llm_demo model_dir/config.json prompt.txt
# Without config.json, using default configuration
## Interactive Chat
./llm_demo model_dir/llm.mnn
## Replying to each line in the prompt
./llm_demo model_dir/llm.mnn prompt.txt
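A prompt file simply contains one query per line. For example (the queries here are illustrative):
# Write two queries, one per line, then let llm_demo answer each of them
printf 'Hello\nPlease introduce yourself.\n' > prompt.txt
./llm_demo model_dir/config.json prompt.txt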
For vision-language models, embed the image input in the prompt using an <img> tag:
<img>https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg</img>Describe the content of the image.
Specify the image size:
<img><hw>280, 420</hw>https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg</img>Describe the content of the image.
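A vision query can be placed in a prompt file in the same way; the sketch below assumes a vision-capable model has been exported to vl_model_dir (a hypothetical directory):
# Illustrative: ask a vision-language model about an image via a prompt file
echo '<img>https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg</img>Describe the content of the image.' > vl_prompt.txt
./llm_demo vl_model_dir/config.json vl_prompt.txt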
For audio models, embed the audio input in the prompt using an <audio> tag:
<audio>https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/translate_to_chinese.wav</audio>Describe the content of the audio.
Using GPTQ Weights
To use GPTQ weights, specify the path to the Qwen2.5-0.5B-Instruct-GPTQ-Int4 model with the --gptq_path PATH option when exporting the Qwen2.5-0.5B-Instruct model. Use the following command:
# Export the GPTQ-quantized model
python llmexport.py --path /path/to/Qwen2.5-0.5B-Instruct --gptq_path /path/to/Qwen2.5-0.5B-Instruct-GPTQ-Int4 --export mnn
Using LoRA Weights
LoRA weights can be used in two ways:
1. Merge the LoRA weights into the original model.
2. Export the LoRA model separately.
The first approach is faster and simpler, but does not support switching LoRA weights at runtime. The second approach adds a small amount of memory and compute overhead, but is more flexible: LoRA weights can be switched at runtime, which makes it suitable for multi-LoRA scenarios.
To merge LoRA weights into the original model, specify the --lora_path PATH
parameter during model export. By default, the model is exported with the merged weights. Use the following command:
# Export the model with merged LoRA weights
python llmexport.py --path /path/to/Qwen2.5-0.5B-Instruct --lora_path /path/to/lora --export mnn
Using the merged LoRA model is exactly the same as using the original model.
To export LoRA as a separate model, supporting runtime switching, specify the --lora_path PATH parameter and include the --lora_split flag during model export. Use the following command:
python llmexport.py --path /path/to/Qwen2.5-0.5B-Instruct --lora_path /path/to/lora --lora_split --export mnn
After export, in addition to the original model files, a new lora.mnn file will be added to the folder; this is the LoRA model file.
Using the LoRA Model
To load and run the LoRA model directly, create a lora.json configuration file, similar to running a merged LoRA model:
{
"llm_model": "lora.mnn",
"llm_weight": "base.mnn.weight",
}
// Create and load the base model
std::unique_ptr<Llm> llm(Llm::createLLM(config_path));
llm->load();
// Use the same object to switch between multiple LoRA models (they cannot be used concurrently)
{
    // Add the `lora_1` model on top of the base model; its index is `lora_1_idx`
    size_t lora_1_idx = llm->apply_lora("lora_1.mnn");
    llm->response("Hello lora1"); // Infer using the `lora_1` model
    // Add the `lora_2` model and use it
    size_t lora_2_idx = llm->apply_lora("lora_2.mnn");
    llm->response("Hello lora2"); // Infer using the `lora_2` model
    // Select `lora_1` as the current model using its index
    llm->select_module(lora_1_idx);
    llm->response("Hello lora1"); // Infer using the `lora_1` model
    // Release the loaded LoRA models
    llm->release_module(lora_1_idx);
    llm->release_module(lora_2_idx);
    // Select and use the base model
    llm->select_module(0);
    llm->response("Hello base"); // Infer using the base model
}
// Use multiple objects to load and run multiple LoRA models concurrently
{
    std::mutex creat_mutex;
    auto chat = [&](const std::string& lora_name) {
        MNN::BackendConfig bnConfig;
        auto newExe = Executor::newExecutor(MNN_FORWARD_CPU, bnConfig, 1);
        ExecutorScope scope(newExe);
        Llm* current_llm = nullptr;
        {
            std::lock_guard<std::mutex> guard(creat_mutex);
            current_llm = llm->create_lora(lora_name);
        }
        current_llm->response("Hello");
    };
    std::thread thread1(chat, "lora_1.mnn");
    std::thread thread2(chat, "lora_2.mnn");
    thread1.join();
    thread2.join();
}