深圳云天励飞技术股份有限公司/tyllm

flash_mla.py · 18.44 KB
Chen Min committed 4 months ago · Initial commit
import tvm
from tvm.script.low_level import tir as L
def create_mla_kernel(batch_size, max_k_seq_len, num_head, num_hidden_nope, num_hidden_pe, q_seq_len):
    lanes = 8
    head_per_tile = 8
    assert num_head % head_per_tile == 0
    head_tiles = num_head // head_per_tile
    k_block_size = 64
    q_block_size = 64
    max_num_blocks_per_seq = 1024
    assert num_hidden_nope % lanes == 0
    assert num_hidden_pe % lanes == 0
    hidden_outer_nope = num_hidden_nope // lanes
    hidden_outer_pe = num_hidden_pe // lanes
    hidden_outer = hidden_outer_nope + hidden_outer_pe
    softmax_scale = 1.0 / ((num_hidden_nope + num_hidden_pe) ** 0.5)
    q_seq_len = L.dimvar("q_seq_len") if q_seq_len is None else int(q_seq_len)
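
    # create_mla_kernel returns one of the two hardware kernels defined below:
    # a prefill variant (tiled over q_seq_len, with a causal mask) and a decode
    # variant (a single query step). Both stream the paged KV cache through
    # double-buffered device memory and overlap load, q@k matmul, online
    # softmax, and weighted-V accumulation in a four-stage software pipeline.
    # softmax_scale above is the standard 1/sqrt(d) attention factor taken over
    # the full head dimension (nope + pe parts).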
    @L.low_level
    def mla_prefill_compact_pipe_K(
        Q: L.Buffer[(batch_size, num_head, hidden_outer, q_seq_len, lanes), L.DATATYPE.FLOAT16],
        KV_cache: L.Buffer[(batch_size, hidden_outer, max_k_seq_len, lanes), L.DATATYPE.FLOAT16],
        block_tables: L.Buffer[(batch_size, max_num_blocks_per_seq), L.int32],
        cache_pos: L.Buffer[(batch_size,), L.int32],
        Output: L.Buffer[(batch_size, num_head, hidden_outer, q_seq_len, lanes), L.DATATYPE.FLOAT16],
    ):
        Q_dm = L.alloc_buffer([head_per_tile, hidden_outer, q_block_size, lanes], L.DATATYPE.FLOAT16, L.DATASCOPE.DM)
        cache_dm = L.alloc_buffer([2, hidden_outer, k_block_size, lanes], L.DATATYPE.FLOAT16, L.DATASCOPE.DM)
        tables_dm = L.alloc_buffer([batch_size, max_num_blocks_per_seq], L.DATATYPE.INT32, L.DATASCOPE.DM)
        cache_pos_dm = L.alloc_buffer([batch_size,], L.DATATYPE.INT32, L.DATASCOPE.DM)
        QK_weight = L.alloc_buffer([2, head_per_tile, k_block_size // lanes, q_block_size, lanes], L.DATATYPE.FLOAT16, L.DATASCOPE.DM)
        weighted_V = L.alloc_buffer([2, head_per_tile, hidden_outer, q_block_size, lanes], L.DATATYPE.FLOAT16, L.DATASCOPE.DM)
        Output_dm = L.alloc_buffer([2, head_per_tile, hidden_outer, q_block_size, lanes], L.DATATYPE.FLOAT16, L.DATASCOPE.DM)
        partial_max = L.alloc_buffer([2, head_per_tile, q_block_size], L.DATATYPE.FLOAT16, L.DATASCOPE.DM)
        partial_scale = L.alloc_buffer([2, head_per_tile, q_block_size], L.DATATYPE.FLOAT16, L.DATASCOPE.DM)
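        # The leading dimension of 2 on the working buffers provides ping-pong
        # versions (indexed by k_outer % 2) so consecutive pipeline rounds can
        # overlap without clobbering each other's data.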
        with L.core.cu():
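            # The running row-max starts at -inf so the first KV tile's max wins.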
            L.cu.start_eidma_async_const(partial_max[:, :, :], L.float16('-inf'))
            L.cu.wait_done_eidma()
            L.cu.start_eidma_async(cache_pos_dm[:], cache_pos[:])
            L.cu.wait_done_eidma()
        q_tiles = (q_seq_len + q_block_size - 1) // q_block_size
        for batch_idx in range(batch_size):
            with L.core.cu():
                L.cu.start_eidma_async(tables_dm[:, :], block_tables[batch_idx, :])
                L.cu.wait_done_eidma()
            for head_outer in range(head_tiles):
                head_start = head_outer * head_per_tile
                head_end = head_start + head_per_tile
                for q_outer in range(q_tiles):
                    q_start = q_outer * q_block_size
                    q_end = min(q_start + q_block_size, q_seq_len)
                    q_tile_len = q_end - q_start
                    L.cu.start_eidma_async(
                        Q_dm[:, :, :q_tile_len, :],
                        Q[batch_idx, head_start:head_end, :, q_start:q_end, :]
                    )
                    L.cu.wait_done_eidma()
                    k_seq_len = min(cache_pos_dm[batch_idx], max_k_seq_len)
                    k_valid_seq_len = min(q_end, k_seq_len)
                    k_tiles = (k_valid_seq_len + k_block_size - 1) // k_block_size
                    for k_outer in range(k_tiles + 3):
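                        # k_tiles + 3 rounds drain the four-stage pipeline: in
                        # round k_outer, stage s operates on KV tile k_outer - s.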
                        # pipeline phase0: trigger external data loads
                        if k_outer < k_tiles:
                            with L.core.cu():
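                                # Paged KV cache: block_tables maps the logical tile
                                # index to a physical block of k_block_size tokens
                                # inside KV_cache.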
                                kv_load_start = tables_dm[batch_idx][k_outer] * k_block_size
                                kv_load_end = min(k_seq_len, kv_load_start + k_block_size)
                                kv_tile_len = kv_load_end - kv_load_start
                                kv_load_version = k_outer % 2
                                L.cu.start_eidma_async(
                                    cache_dm[kv_load_version, :, :kv_tile_len, :],
                                    KV_cache[batch_idx, :, kv_load_start:kv_load_end, :]
                                )
                        # pipeline phase1: merged q@k cube for both nope and pe
                        if 1 <= k_outer < k_tiles + 1:
                            with L.core.cu():
                                qk_version = (k_outer - 1) % 2
                                L.cu.matmul(
                                    QK_weight[qk_version, :, :, :, :],
                                    Q_dm[:, :, :, :],
                                    cache_dm[qk_version, :, :, :],
                                    transpose_a=True,
                                    transpose_b=False,
                                )
                        # pipeline phase2: online softmax step1
                        if 2 <= k_outer < k_tiles + 2:
                            with L.core.vcu([0, 1, 2]):
                                cur_version = (k_outer - 2) % 2
                                kv_load_start = tables_dm[batch_idx][k_outer - 2] * k_block_size
                                kv_load_end = min(k_valid_seq_len, kv_load_start + k_block_size)
                                kv_tile_len = kv_load_end - kv_load_start
                                # compute global partial max
                                L.vcu.reduce_max(
                                    partial_max[cur_version, :, :],
                                    QK_weight[cur_version, :, :, :, :],
                                    init=partial_max[1 - cur_version, :, :],
                                    axes=[2, 4],
                                    max_reduce_length=kv_tile_len,
                                )
                                # compute exp rescale factors for old max vs. new max
                                L.vcu.subtract(
                                    partial_scale[cur_version, :, :],
                                    partial_max[1 - cur_version, :, :],
                                    partial_max[cur_version, :, :],
                                )
                                L.vcu.exp(
                                    partial_scale[cur_version, :, :],
                                    partial_scale[cur_version, :, :],
                                )
                                # build and apply the causal mask
                                mask = L.alloc_buffer(
                                    [head_per_tile, k_block_size // lanes, q_block_size, lanes],
                                    L.DATATYPE.FLOAT16, L.DATASCOPE.VM
                                )
                                L.vcu.fill_partial_casual_mask(
                                    mask[:, :, :, :],
                                    q_start, q_end, q_block_size,
                                    kv_load_start, kv_load_end, k_block_size,
                                    axes=[1, 2, 3],
                                )
                                L.vcu.fadd(
                                    QK_weight[cur_version, :, :, :, :],
                                    QK_weight[cur_version, :, :, :, :],
                                    mask[:, :, :, :]
                                )
                                # softmax
                                L.vcu.softmax_with_precompute_max(
                                    QK_weight[cur_version, :, :, :, :],
                                    QK_weight[cur_version, :, :, :, :],
                                    precompute_max=partial_max[cur_version, :, :],
                                    axes=[2, 4],
                                    scale=softmax_scale,
                                    max_reduce_length=kv_tile_len,
                                )
                                L.vcu.unblock(L.cu)
                        # pipeline phase3: weight@v_nope cube and rescale
                        if 3 <= k_outer < k_tiles + 3:
                            weighted_v_version = (k_outer - 3) % 2
                            with L.core.cu():
                                L.cu.matmul(
                                    weighted_V[weighted_v_version, :, :, :, :],
                                    QK_weight[weighted_v_version, :, :, :, :],
                                    cache_dm[weighted_v_version, :hidden_outer_nope, :, :],
                                    transpose_a=True,
                                    transpose_b=True,
                                )
                                L.odma.unblock(L.vcu3)
                            with L.core.vcu([3]):
                                L.vcu.wait_on(L.odma)
                                partial_scale_broadcast_view = partial_scale.view([2, head_per_tile, 1, q_block_size, 1])
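                                # Rescale-and-accumulate: fold this tile's weight@V into
                                # the running output, with the previous accumulator
                                # rescaled by the exp(m_old - m_new) factor from phase2.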
                                L.vcu.fmadd(
                                    Output_dm[weighted_v_version, :, :, :, :],
                                    weighted_V[weighted_v_version, :, :, :, :],
                                    partial_scale_broadcast_view[weighted_v_version, :, :, :, :],
                                    Output_dm[1 - weighted_v_version, :, :, :, :],
                                )
                                L.vcu.unblock(L.cu)
                        # per-round synchronization
                        with L.core.cu():
                            if k_outer < k_tiles:
                                L.cu.wait_done_eidma()
                                L.cu.wait_done_eidma()
                            if 1 <= k_outer < k_tiles + 1:
                                L.cu.wait_on(L.odma)
                            if 2 <= k_outer < k_tiles + 2:
                                L.cu.wait_on(L.vcu0, L.vcu1, L.vcu2)
                            if 3 <= k_outer < k_tiles + 3:
                                L.cu.wait_on(L.vcu3)
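                    # The last phase-3 round wrote version (k_tiles - 1) % 2,
                    # i.e. 1 - k_tiles % 2, which holds the final accumulator
                    # for this (head tile, q tile) pair.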
                    with L.core.cu():
                        L.cu.start_eodma_async(
                            Output[batch_idx, head_start:head_end, :, q_start:q_end, :],
                            Output_dm[1 - k_tiles % 2, :, :, :q_tile_len, :]
                        )
                        L.cu.wait_done_eodma()
    @L.low_level
    def mla_decode_compact_pipe_K(
        Q: L.Buffer[(batch_size, num_head, hidden_outer, 1, lanes), L.DATATYPE.FLOAT16],
        KV_cache: L.Buffer[(batch_size, hidden_outer, max_k_seq_len, lanes), L.DATATYPE.FLOAT16],
        block_tables: L.Buffer[(batch_size, max_num_blocks_per_seq), L.int32],
        cache_pos: L.Buffer[(batch_size,), L.int32],
        Output: L.Buffer[(batch_size, num_head, hidden_outer, 1, lanes), L.DATATYPE.FLOAT16],
    ):
        Q_dm = L.alloc_buffer([head_per_tile, hidden_outer, 1, lanes], L.DATATYPE.FLOAT16, L.DATASCOPE.DM)
        cache_dm = L.alloc_buffer([2, hidden_outer, k_block_size, lanes], L.DATATYPE.FLOAT16, L.DATASCOPE.DM)
        tables_dm = L.alloc_buffer([batch_size, max_num_blocks_per_seq], L.DATATYPE.INT32, L.DATASCOPE.DM)
        cache_pos_dm = L.alloc_buffer([batch_size,], L.DATATYPE.INT32, L.DATASCOPE.DM)
        QK_weight = L.alloc_buffer([2, head_per_tile, k_block_size // lanes, 1, lanes], L.DATATYPE.FLOAT16, L.DATASCOPE.DM)
        weighted_V = L.alloc_buffer([2, head_per_tile, hidden_outer, 1, lanes], L.DATATYPE.FLOAT16, L.DATASCOPE.DM)
        Output_dm = L.alloc_buffer([2, head_per_tile, hidden_outer, 1, lanes], L.DATATYPE.FLOAT16, L.DATASCOPE.DM)
        partial_max = L.alloc_buffer([2, head_per_tile, 1], L.DATATYPE.FLOAT16, L.DATASCOPE.DM)
        partial_scale = L.alloc_buffer([2, head_per_tile, 1], L.DATATYPE.FLOAT16, L.DATASCOPE.DM)
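        # Decode is the q_seq_len == 1 specialization: the q-block dimension
        # collapses to 1 and no causal mask is needed, since the single new
        # query may attend to every cached position.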
        with L.core.cu():
            L.cu.start_eidma_async_const(partial_max[:, :, :], L.float16('-inf'))
            L.cu.wait_done_eidma()
            L.cu.start_eidma_async(cache_pos_dm[:], cache_pos[:])
            L.cu.wait_done_eidma()
        for batch_idx in range(batch_size):
            with L.core.cu():
                L.cu.start_eidma_async(tables_dm[:, :], block_tables[batch_idx, :])
                L.cu.wait_done_eidma()
            k_valid_seq_len = min(cache_pos_dm[batch_idx], max_k_seq_len)
            for head_outer in range(head_tiles):
                head_start = head_outer * head_per_tile
                head_end = head_start + head_per_tile
                L.cu.start_eidma_async(
                    Q_dm[:, :, 0, :],
                    Q[batch_idx, head_start:head_end, :, 0, :]
                )
                L.cu.wait_done_eidma()
                k_tiles = (k_valid_seq_len + k_block_size - 1) // k_block_size
                for k_outer in range(k_tiles + 3):
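                    # Same four-stage pipeline as the prefill kernel, minus the
                    # causal-mask step in phase2.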
                    # pipeline phase0: trigger external data loads
                    if k_outer < k_tiles:
                        with L.core.cu():
                            kv_load_start = tables_dm[batch_idx][k_outer] * k_block_size
                            kv_load_end = min(k_valid_seq_len, kv_load_start + k_block_size)
                            kv_tile_len = kv_load_end - kv_load_start
                            kv_load_version = k_outer % 2
                            L.cu.start_eidma_async(
                                cache_dm[kv_load_version, :, :kv_tile_len, :],
                                KV_cache[batch_idx, :, kv_load_start:kv_load_end, :]
                            )
                    # pipeline phase1: merged q@k cube for both nope and pe
                    if 1 <= k_outer < k_tiles + 1:
                        with L.core.cu():
                            qk_version = (k_outer - 1) % 2
                            L.cu.matmul(
                                QK_weight[qk_version, :, :, :, :],
                                Q_dm[:, :, :, :],
                                cache_dm[qk_version, :, :, :],
                                transpose_a=True,
                                transpose_b=False,
                            )
                    # pipeline phase2: online softmax step1
                    if 2 <= k_outer < k_tiles + 2:
                        with L.core.vcu([0, 1, 2]):
                            cur_version = (k_outer - 2) % 2
                            kv_load_start = tables_dm[batch_idx][k_outer - 2] * k_block_size
                            kv_load_end = min(k_valid_seq_len, kv_load_start + k_block_size)
                            kv_tile_len = kv_load_end - kv_load_start
                            # compute global partial max
                            L.vcu.reduce_max(
                                partial_max[cur_version, :, :],
                                QK_weight[cur_version, :, :, :, :],
                                init=partial_max[1 - cur_version, :, :],
                                axes=[2, 4],
                                max_reduce_length=kv_tile_len,
                            )
                            # compute exp rescale factors for old max vs. new max
                            L.vcu.subtract(
                                partial_scale[cur_version, :, :],
                                partial_max[1 - cur_version, :, :],
                                partial_max[cur_version, :, :],
                            )
                            L.vcu.exp(
                                partial_scale[cur_version, :, :],
                                partial_scale[cur_version, :, :],
                            )
                            # softmax
                            L.vcu.softmax_with_precompute_max(
                                QK_weight[cur_version, :, :, :, :],
                                QK_weight[cur_version, :, :, :, :],
                                precompute_max=partial_max[cur_version, :, :],
                                axes=[2, 4],
                                scale=softmax_scale,
                                max_reduce_length=kv_tile_len,
                            )
                            L.vcu.unblock(L.cu)
                    # pipeline phase3: weight@v_nope cube and rescale
                    if 3 <= k_outer < k_tiles + 3:
                        weighted_v_version = (k_outer - 3) % 2
                        with L.core.cu():
                            L.cu.matmul(
                                weighted_V[weighted_v_version, :, :, :, :],
                                QK_weight[weighted_v_version, :, :, :, :],
                                cache_dm[weighted_v_version, :hidden_outer_nope, :, :],
                                transpose_a=True,
                                transpose_b=True,
                            )
                            L.odma.unblock(L.vcu3)
                        with L.core.vcu([3]):
                            L.vcu.wait_on(L.odma)
                            partial_scale_broadcast_view = partial_scale.view([2, head_per_tile, 1, 1, 1])
                            L.vcu.fmadd(
                                Output_dm[weighted_v_version, :, :, :, :],
                                weighted_V[weighted_v_version, :, :, :, :],
                                partial_scale_broadcast_view[weighted_v_version, :, :, :, :],
                                Output_dm[1 - weighted_v_version, :, :, :, :],
                            )
                            L.vcu.unblock(L.cu)
                    # per-round synchronization
                    with L.core.cu():
                        if k_outer < k_tiles:
                            L.cu.wait_done_eidma()
                            L.cu.wait_done_eidma()
                        if 1 <= k_outer < k_tiles + 1:
                            L.cu.wait_on(L.odma)
                        if 2 <= k_outer < k_tiles + 2:
                            L.cu.wait_on(L.vcu0, L.vcu1, L.vcu2)
                        if 3 <= k_outer < k_tiles + 3:
                            L.cu.wait_on(L.vcu3)
                with L.core.cu():
                    L.cu.start_eodma_async(
                        Output[batch_idx, head_start:head_end, :, 0, :],
                        Output_dm[1 - k_tiles % 2, :, :, 0, :]
                    )
                    L.cu.wait_done_eodma()
    if q_seq_len == 1:
        return mla_decode_compact_pipe_K
    else:
        return mla_prefill_compact_pipe_K
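

# Illustrative usage (not part of the original file; the shape values below are
# hypothetical): q_seq_len == 1 selects the decode kernel, any other value
# selects the prefill kernel, and None makes q_seq_len a dynamic dimvar.
if __name__ == "__main__":
    # Hypothetical MLA configuration: 16 heads, 512-dim nope part, 64-dim pe part.
    decode_kernel = create_mla_kernel(
        batch_size=4, max_k_seq_len=4096, num_head=16,
        num_hidden_nope=512, num_hidden_pe=64, q_seq_len=1,
    )
    prefill_kernel = create_mla_kernel(
        batch_size=4, max_k_seq_len=4096, num_head=16,
        num_hidden_nope=512, num_hidden_pe=64, q_seq_len=256,
    )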