This is an inference engine for the RWKV language model, implemented in pure WebGPU.
Note that web-rwkv is only an inference engine. At its core it provides a run function that takes in prompt tokens and returns logits, and a softmax function that turns logits into predicted next-token probabilities; both are executed on the GPU.
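For intuition, here is a minimal CPU-side sketch of what that softmax step computes; it is just the math, not the crate's GPU implementation:

```rust
/// Illustrative only: turn raw logits into a probability distribution over the
/// next token, i.e. the transformation the engine's softmax performs on GPU.
fn softmax(logits: &[f32]) -> Vec<f32> {
    // subtract the maximum logit for numerical stability
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exp: Vec<f32> = logits.iter().map(|&x| (x - max).exp()).collect();
    let sum: f32 = exp.iter().sum();
    exp.into_iter().map(|x| x / sum).collect()
}
```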
It does not provide higher-level inference pipelines or serving features. Check the web-rwkv-axum project if you want some fancy inference pipelines, including Classifier-Free Guidance (CFG), Backus–Naur Form (BNF) guidance, and more.

For devs: check the slides for technical details about the architecture, history, and optimizations.
Before running the examples, convert your model with convert_safetensors.py (see below) and put the .st model under assets/models. Then build the examples:
$ cargo build --release --examples
The test generates 500 tokens and measures the time cost.
$ cargo run --release --example gen
To chat with the model, run
$ cargo run --release --example chat
In this demo, type + to retry the last round's generation; type - to exit.
To specify the location of your safetensors model, use
$ cargo run --release --example chat -- --model /path/to/model
To load custom prompts for chat, use
$ cargo run --release --example chat -- --prompt /path/to/prompt
See assets/prompt.json
for details.
To specify layer quantization, use --quant <LAYERS>
or --quant-nf4 <LAYERS>
to quantize the first <LAYERS>
layers. For example, use
$ cargo run --release --example chat -- --quant 32
to quantize all 32 layers.
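For instance, to combine a custom model path with NF4 quantization (both flags are documented above; treating them as composable is an assumption, so check the example's --help output if this errors):
$ cargo run --release --example chat -- --model /path/to/model --quant-nf4 32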
This demo showcases the simultaneous generation of 4 batches of text with various lengths.
$ cargo run --release --example batch
The inspector demo is a guide to an advanced usage called hooks. Hooks allow users to inject arbitrary tensor ops into the model's inference process, fetching and modifying the contents of the runtime buffers, the state, and even the model parameters. Hooks enable certain third-party implementations like dynamic LoRA, control net, and so on.
All versions of models implement serde::ser::Serialize and serde::de::DeserializeSeed<'de>, which means that one can save a quantized or LoRA-merged model to a file and load it back afterwards.
To use in your own Rust project, simply add web-rwkv = "0.10"
as a dependency in your Cargo.toml
.
Check the examples for how to create the environment and the tokenizer, and how to run the model.
Since v0.7 there is a runtime feature for the crate. When enabled, applications can use the infrastructure of the asynchronous runtime API.
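A sketch of what enabling it could look like in Cargo.toml (the feature name runtime is taken from the text above; whether it is already enabled by default in your version is worth checking):

```toml
[dependencies]
# enable the asynchronous runtime API of web-rwkv
web-rwkv = { version = "0.10", features = ["runtime"] }
```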
In general, a runtime is an asynchronous task driven by tokio. It allows the CPU and GPU to work in parallel, maximizing the utilization of GPU computing resources.
Check examples starting with rt
for more information, and compare the generation speed with their non-rt
counterparts.
Since version v0.2.4, the engine supports batched inference, i.e., inference of a batch of prompts (with different lengths) in parallel.
This is achieved by a modified WKV
kernel.
When building the model, the user specifies token_chunk_size
(default: 32, but for powerful GPUs this could be much higher), which is the maximum number of tokens the engine could process in one run
call.
After creating the model, the user creates a ModelState
with num_batch
specified.
This means that there are num_batch
slots that could consume the inputs in parallel.
Before calling run()
, the user fills each slot with some tokens as prompt.
If a slot is empty, no inference will be run for it.
After calling run(), some (but not necessarily all) input tokens are consumed, and logits appear in the corresponding returned slots if the inference of that slot finishes during this run.
Since only token_chunk_size tokens are processed during each run() call, it may happen that no logits appear in the results at all.
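To make the chunking behaviour concrete, here is a rough CPU-side simulation of that bookkeeping; this is not the library's code, and the real kernel may split the token budget across slots differently:

```rust
/// Conceptual model only: each run consumes at most `token_chunk_size` prompt
/// tokens in total, spread over the non-empty batch slots, and a slot yields
/// logits only once all of its tokens have been consumed.
fn simulate_run(slots: &mut [Vec<u32>], token_chunk_size: usize) -> Vec<bool> {
    let mut budget = token_chunk_size;
    let mut got_logits = vec![false; slots.len()];
    for (i, slot) in slots.iter_mut().enumerate() {
        if slot.is_empty() || budget == 0 {
            continue; // empty slot (or exhausted budget): nothing happens here
        }
        let take = slot.len().min(budget);
        slot.drain(..take); // these tokens are consumed by this run
        budget -= take;
        // logits appear for this slot only when its whole prompt has been fed
        got_logits[i] = slot.is_empty();
    }
    got_logits
}
```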
Hooks are a very powerful tool for customizing the model's inference process. The library provides the Model::run_with_hooks function, which takes a HookMap as a parameter.

HookMap is essentially a hashmap from Model::Hook to functions. A Model::Hook defines a certain place where the hook function can be injected; a model generally has dozens of hooking points. A hook function has the signature Fn(&Frame) -> Result<TensorOp, TensorError>, in which you can create tensor ops that read or write any of the tensors available there. Frame contains all GPU buffers accessible during inference, including the state and all runtime buffers.

An example that reads out every layer's output during inference:
```rust
let info = model.info();

#[derive(Debug, Clone)]
struct Buffer(TensorGpu<f32, ReadWrite>);

// create a buffer to store each layer's output
let buffer = Buffer(context.tensor_init([info.num_emb, info.num_layer, 1, 1]));

let mut hooks = HookMap::default();
for layer in 0..info.num_layer {
    // cloning a buffer doesn't actually clone its internal data; use `deep_clone()` to clone to a new buffer
    let buffer = buffer.clone();
    hooks.insert(
        v6::Hook::PostFfn(layer),
        Box::new(
            move |frame: &v6::Frame<_>| -> Result<TensorOp, TensorError> {
                // figure out how many tokens this run has
                let shape = frame.buffer.ffn_x.shape();
                let num_token = shape[1];
                // "steal" the layer's output (activation), and put it into our buffer
                TensorOp::blit(
                    frame.buffer.ffn_x.view(.., num_token - 1, .., ..)?,
                    buffer.0.view(.., layer, .., ..)?,
                )
            },
        ),
    );
}

let bundle = v6::Bundle::<f16>::new_with_hooks(model, 1, hooks);
let runtime = TokioRuntime::new(bundle).await;
let (input, output) = runtime.infer(input).await?;

// now the data is available in `buffer`; we can read it back
let data = buffer.0.back().await.to_vec();
```
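As a follow-up, one way to slice a single layer's vector out of that flattened readback; this assumes the first shape dimension (num_emb) is the contiguous one, which should be verified against TensorGpu's documented layout:

```rust
// hypothetical helper: pick layer `layer`'s activation out of the flat readback
fn layer_output(data: &[f32], num_emb: usize, layer: usize) -> &[f32] {
    &data[layer * num_emb..(layer + 1) * num_emb]
}
```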
If you are building from source, you must download the model and put it in assets/models before running.
You can now download the converted models here.
You may download the official RWKV World series models from HuggingFace, and convert them via the provided convert_safetensors.py.
$ python assets/scripts/convert_safetensors.py --input /path/to/model.pth --output /path/to/model.st
If you don't have Python installed or don't want to use it, there is a pure Rust converter.
You can clone that repo and run
$ cd /path/to/web-rwkv-converter
$ cargo run --release --example converter -- --input /path/to/model.pth --output /path/to/model.st
"thread 'main' panicked at 'called Result::unwrap()
on an Err
value: HeaderTooLarge'"
Your model is broken, mainly because you cloned the repo but did not set up git-lfs. Please download the model manually and overwrite the one in assets/models.
"thread 'main' panicked at 'Error in Queue::submit: parent device is lost'"
Your GPU is not responding. Maybe you are running a model that is just too big for your device. If the model doesn't fit into your VRAM, the driver has to constantly swap and transfer model parameters, which makes inference roughly 10x slower. Try quantizing your model first.
Source: link.
The default toolchain installed on Windows by Rustup
is the x86_64-pc-windows-msvc
toolchain. This toolchain does not include Rust-specific formatters for LLDB, as it is assumed that users will primarily use WinDbg or Microsoft Visual Studio's debugger for this target.
If you prefer to use CodeLLDB
for debugging, you have two options:
1. Use the x86_64-pc-windows-gnu toolchain to compile your Rust project. This option ensures full LLDB visualization support for Rust types.
2. Keep the x86_64-pc-windows-msvc toolchain but use the LLDB formatters from x86_64-pc-windows-gnu. To use this option, install the x86_64-pc-windows-gnu toolchain via rustup toolchain add x86_64-pc-windows-gnu. Then configure CodeLLDB to load its formatters by adding the following entry to your workspace configuration: "lldb.script": { "lang.rust.toolchain": "x86_64-pc-windows-gnu" }
Note that this setup is less ideal due to differences in the debug information layout emitted by the Rust compiler for enum data types when targeting MSVC, which means enums may not be visualized correctly. However, LLDB formatters will work for standard collections like strings and vectors.