diff --git "a/ACL_PyTorch/docs/01.\345\234\250\347\272\277\346\216\250\347\220\206.md" "b/ACL_PyTorch/docs/ONNX/01.\345\234\250\347\272\277\346\216\250\347\220\206.md" similarity index 100% rename from "ACL_PyTorch/docs/01.\345\234\250\347\272\277\346\216\250\347\220\206.md" rename to "ACL_PyTorch/docs/ONNX/01.\345\234\250\347\272\277\346\216\250\347\220\206.md" diff --git "a/ACL_PyTorch/docs/02.ONNX\347\232\204\345\257\274\345\207\272.md" "b/ACL_PyTorch/docs/ONNX/02.ONNX\347\232\204\345\257\274\345\207\272.md" similarity index 100% rename from "ACL_PyTorch/docs/02.ONNX\347\232\204\345\257\274\345\207\272.md" rename to "ACL_PyTorch/docs/ONNX/02.ONNX\347\232\204\345\257\274\345\207\272.md" diff --git "a/ACL_PyTorch/docs/03.ONNX\350\275\254OM.md" "b/ACL_PyTorch/docs/ONNX/03.ONNX\350\275\254OM.md" similarity index 100% rename from "ACL_PyTorch/docs/03.ONNX\350\275\254OM.md" rename to "ACL_PyTorch/docs/ONNX/03.ONNX\350\275\254OM.md" diff --git "a/ACL_PyTorch/docs/04.\347\246\273\347\272\277\346\216\250\347\220\206.md" "b/ACL_PyTorch/docs/ONNX/04.\347\246\273\347\272\277\346\216\250\347\220\206.md" similarity index 100% rename from "ACL_PyTorch/docs/04.\347\246\273\347\272\277\346\216\250\347\220\206.md" rename to "ACL_PyTorch/docs/ONNX/04.\347\246\273\347\272\277\346\216\250\347\220\206.md" diff --git "a/ACL_PyTorch/docs/05.01.ResNet50 \344\275\277\347\224\250\351\235\231\346\200\201 AIPP \347\232\204\346\241\210\344\276\213.md" "b/ACL_PyTorch/docs/ONNX/05.01.ResNet50 \344\275\277\347\224\250\351\235\231\346\200\201 AIPP \347\232\204\346\241\210\344\276\213.md" similarity index 100% rename from "ACL_PyTorch/docs/05.01.ResNet50 \344\275\277\347\224\250\351\235\231\346\200\201 AIPP \347\232\204\346\241\210\344\276\213.md" rename to "ACL_PyTorch/docs/ONNX/05.01.ResNet50 \344\275\277\347\224\250\351\235\231\346\200\201 AIPP \347\232\204\346\241\210\344\276\213.md" diff --git "a/ACL_PyTorch/docs/05.02.GLIP\343\200\201CLIP \346\200\247\350\203\275\344\274\230\345\214\226\346\241\210\344\276\213.md" "b/ACL_PyTorch/docs/ONNX/05.02.GLIP\343\200\201CLIP \346\200\247\350\203\275\344\274\230\345\214\226\346\241\210\344\276\213.md" similarity index 100% rename from "ACL_PyTorch/docs/05.02.GLIP\343\200\201CLIP \346\200\247\350\203\275\344\274\230\345\214\226\346\241\210\344\276\213.md" rename to "ACL_PyTorch/docs/ONNX/05.02.GLIP\343\200\201CLIP \346\200\247\350\203\275\344\274\230\345\214\226\346\241\210\344\276\213.md" diff --git "a/ACL_PyTorch/docs/05.\346\200\247\350\203\275\344\274\230\345\214\226.md" "b/ACL_PyTorch/docs/ONNX/05.\346\200\247\350\203\275\344\274\230\345\214\226.md" similarity index 100% rename from "ACL_PyTorch/docs/05.\346\200\247\350\203\275\344\274\230\345\214\226.md" rename to "ACL_PyTorch/docs/ONNX/05.\346\200\247\350\203\275\344\274\230\345\214\226.md" diff --git a/ACL_PyTorch/docs/README.md b/ACL_PyTorch/docs/ONNX/README.md similarity index 100% rename from ACL_PyTorch/docs/README.md rename to ACL_PyTorch/docs/ONNX/README.md diff --git a/ACL_PyTorch/docs/images/ATC.png b/ACL_PyTorch/docs/ONNX/images/ATC.png similarity index 100% rename from ACL_PyTorch/docs/images/ATC.png rename to ACL_PyTorch/docs/ONNX/images/ATC.png diff --git a/ACL_PyTorch/docs/images/Attention.png b/ACL_PyTorch/docs/ONNX/images/Attention.png similarity index 100% rename from ACL_PyTorch/docs/images/Attention.png rename to ACL_PyTorch/docs/ONNX/images/Attention.png diff --git a/ACL_PyTorch/docs/images/ConvTranspose+Add.png b/ACL_PyTorch/docs/ONNX/images/ConvTranspose+Add.png similarity index 100% rename from 
ACL_PyTorch/docs/images/ConvTranspose+Add.png rename to ACL_PyTorch/docs/ONNX/images/ConvTranspose+Add.png diff --git a/ACL_PyTorch/docs/images/ConvTranspose+Add_after.png b/ACL_PyTorch/docs/ONNX/images/ConvTranspose+Add_after.png similarity index 100% rename from ACL_PyTorch/docs/images/ConvTranspose+Add_after.png rename to ACL_PyTorch/docs/ONNX/images/ConvTranspose+Add_after.png diff --git a/ACL_PyTorch/docs/images/ConvTranspose+Add_before.png b/ACL_PyTorch/docs/ONNX/images/ConvTranspose+Add_before.png similarity index 100% rename from ACL_PyTorch/docs/images/ConvTranspose+Add_before.png rename to ACL_PyTorch/docs/ONNX/images/ConvTranspose+Add_before.png diff --git a/ACL_PyTorch/docs/images/DeformableConv_after.png b/ACL_PyTorch/docs/ONNX/images/DeformableConv_after.png similarity index 100% rename from ACL_PyTorch/docs/images/DeformableConv_after.png rename to ACL_PyTorch/docs/ONNX/images/DeformableConv_after.png diff --git a/ACL_PyTorch/docs/images/DeformableConv_before.png b/ACL_PyTorch/docs/ONNX/images/DeformableConv_before.png similarity index 100% rename from ACL_PyTorch/docs/images/DeformableConv_before.png rename to ACL_PyTorch/docs/ONNX/images/DeformableConv_before.png diff --git a/ACL_PyTorch/docs/images/DeformableConv_om.png b/ACL_PyTorch/docs/ONNX/images/DeformableConv_om.png similarity index 100% rename from ACL_PyTorch/docs/images/DeformableConv_om.png rename to ACL_PyTorch/docs/ONNX/images/DeformableConv_om.png diff --git a/ACL_PyTorch/docs/images/FAfp32.png b/ACL_PyTorch/docs/ONNX/images/FAfp32.png similarity index 100% rename from ACL_PyTorch/docs/images/FAfp32.png rename to ACL_PyTorch/docs/ONNX/images/FAfp32.png diff --git a/ACL_PyTorch/docs/images/ILSVRC2012_val_00006083.jpeg b/ACL_PyTorch/docs/ONNX/images/ILSVRC2012_val_00006083.jpeg similarity index 100% rename from ACL_PyTorch/docs/images/ILSVRC2012_val_00006083.jpeg rename to ACL_PyTorch/docs/ONNX/images/ILSVRC2012_val_00006083.jpeg diff --git a/ACL_PyTorch/docs/images/LayerNorm_after.png b/ACL_PyTorch/docs/ONNX/images/LayerNorm_after.png similarity index 100% rename from ACL_PyTorch/docs/images/LayerNorm_after.png rename to ACL_PyTorch/docs/ONNX/images/LayerNorm_after.png diff --git a/ACL_PyTorch/docs/images/LayerNorm_before.png b/ACL_PyTorch/docs/ONNX/images/LayerNorm_before.png similarity index 100% rename from ACL_PyTorch/docs/images/LayerNorm_before.png rename to ACL_PyTorch/docs/ONNX/images/LayerNorm_before.png diff --git a/ACL_PyTorch/docs/images/NMS.png b/ACL_PyTorch/docs/ONNX/images/NMS.png similarity index 100% rename from ACL_PyTorch/docs/images/NMS.png rename to ACL_PyTorch/docs/ONNX/images/NMS.png diff --git a/ACL_PyTorch/docs/images/NMS_after.png b/ACL_PyTorch/docs/ONNX/images/NMS_after.png similarity index 100% rename from ACL_PyTorch/docs/images/NMS_after.png rename to ACL_PyTorch/docs/ONNX/images/NMS_after.png diff --git a/ACL_PyTorch/docs/images/NMS_before.png b/ACL_PyTorch/docs/ONNX/images/NMS_before.png similarity index 100% rename from ACL_PyTorch/docs/images/NMS_before.png rename to ACL_PyTorch/docs/ONNX/images/NMS_before.png diff --git a/ACL_PyTorch/docs/images/Reshape.png b/ACL_PyTorch/docs/ONNX/images/Reshape.png similarity index 100% rename from ACL_PyTorch/docs/images/Reshape.png rename to ACL_PyTorch/docs/ONNX/images/Reshape.png diff --git a/ACL_PyTorch/docs/images/Resize_after.png b/ACL_PyTorch/docs/ONNX/images/Resize_after.png similarity index 100% rename from ACL_PyTorch/docs/images/Resize_after.png rename to ACL_PyTorch/docs/ONNX/images/Resize_after.png diff --git 
diff --git a/ACL_PyTorch/docs/images/Resize_before.png b/ACL_PyTorch/docs/ONNX/images/Resize_before.png
similarity index 100%
rename from ACL_PyTorch/docs/images/Resize_before.png
rename to ACL_PyTorch/docs/ONNX/images/Resize_before.png
diff --git a/ACL_PyTorch/docs/images/Resize_scales_after.png b/ACL_PyTorch/docs/ONNX/images/Resize_scales_after.png
similarity index 100%
rename from ACL_PyTorch/docs/images/Resize_scales_after.png
rename to ACL_PyTorch/docs/ONNX/images/Resize_scales_after.png
diff --git a/ACL_PyTorch/docs/images/Resize_scales_before.png b/ACL_PyTorch/docs/ONNX/images/Resize_scales_before.png
similarity index 100%
rename from ACL_PyTorch/docs/images/Resize_scales_before.png
rename to ACL_PyTorch/docs/ONNX/images/Resize_scales_before.png
diff --git a/ACL_PyTorch/docs/images/Split_after.png b/ACL_PyTorch/docs/ONNX/images/Split_after.png
similarity index 100%
rename from ACL_PyTorch/docs/images/Split_after.png
rename to ACL_PyTorch/docs/ONNX/images/Split_after.png
diff --git a/ACL_PyTorch/docs/images/Split_before.png b/ACL_PyTorch/docs/ONNX/images/Split_before.png
similarity index 100%
rename from ACL_PyTorch/docs/images/Split_before.png
rename to ACL_PyTorch/docs/ONNX/images/Split_before.png
diff --git a/ACL_PyTorch/docs/images/Swin.png b/ACL_PyTorch/docs/ONNX/images/Swin.png
similarity index 100%
rename from ACL_PyTorch/docs/images/Swin.png
rename to ACL_PyTorch/docs/ONNX/images/Swin.png
diff --git a/ACL_PyTorch/docs/images/pad_after.png b/ACL_PyTorch/docs/ONNX/images/pad_after.png
similarity index 100%
rename from ACL_PyTorch/docs/images/pad_after.png
rename to ACL_PyTorch/docs/ONNX/images/pad_after.png
diff --git a/ACL_PyTorch/docs/images/pad_before.png b/ACL_PyTorch/docs/ONNX/images/pad_before.png
similarity index 100%
rename from ACL_PyTorch/docs/images/pad_before.png
rename to ACL_PyTorch/docs/ONNX/images/pad_before.png
diff --git "a/ACL_PyTorch/docs/torchair/images/torch_npu优化流程.png" "b/ACL_PyTorch/docs/torchair/images/torch_npu优化流程.png"
new file mode 100644
index 0000000000000000000000000000000000000000..5877bacad423ef748c32772a25a7c02388862d65
Binary files /dev/null and "b/ACL_PyTorch/docs/torchair/images/torch_npu优化流程.png" differ
diff --git "a/ACL_PyTorch/docs/torchair/torch_npu/torch_npu迁移适配.md" "b/ACL_PyTorch/docs/torchair/torch_npu/torch_npu迁移适配.md"
new file mode 100644
index 0000000000000000000000000000000000000000..088753cc645df4eff7765e1d7c3600bc19f424ab
--- /dev/null
+++ "b/ACL_PyTorch/docs/torchair/torch_npu/torch_npu迁移适配.md"
@@ -0,0 +1,43 @@
1. Import the NPU-related libraries

```python
import torch
import torch_npu
```

2. Specify the NPU as the execution device

* The .to(device) approach

  Before migration:

  ```python
  device = torch.device('cuda:{}'.format(local_rank))
  model.to(device)
  data.to(device)
  ```

  After migration:

  ```python
  device = torch.device('npu:{}'.format(local_rank))
  model.to(device)
  data.to(device)
  ```

* The set_device approach

  Before migration:

  ```python
  torch.cuda.set_device(local_rank)
  ```

  After migration:

  ```python
  torch_npu.npu.set_device(local_rank)
  ```

  For more torch_npu migration guidance, see the [PyTorch manual model migration guide](https://www.hiascend.com/document/detail/zh/Pytorch/60RC1/ptmoddevg/trainingmigrguide/PT_LMTMOG_0018.html).

3. Run the model and verify that it still runs correctly after migration.
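Putting the steps together, a minimal end-to-end check might look like the following sketch, assuming torch_npu is installed; the toy `torch.nn.Linear` model and `local_rank = 0` are placeholders for the real network and rank:

```python
import torch
import torch_npu  # registers the 'npu' device type with PyTorch

local_rank = 0
torch_npu.npu.set_device(local_rank)
device = torch.device('npu:{}'.format(local_rank))

model = torch.nn.Linear(16, 4).to(device)  # stand-in for the real network
data = torch.randn(8, 16).to(device)

with torch.no_grad():
    out = model(data)
print(out.shape, out.device)  # expect torch.Size([8, 4]) on npu:0
```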
diff --git "a/ACL_PyTorch/docs/torchair/torch_npu/常用优化点.md" "b/ACL_PyTorch/docs/torchair/torch_npu/常用优化点.md"
new file mode 100644
index 0000000000000000000000000000000000000000..8c9a6201a409a3579dbc2b14c1413b0a976e1d52
--- /dev/null
+++ "b/ACL_PyTorch/docs/torchair/torch_npu/常用优化点.md"
@@ -0,0 +1,277 @@
# Basic optimization workflow

Model performance depends on many components, the algorithm included, so the key to optimization is first locating the current bottleneck and then applying targeted fixes. The workflow is as follows:

![image](../images/torch_npu优化流程.png)

1. Following the [performance tuning tools introduction](https://www.hiascend.com/document/detail/zh/Pytorch/710/ptmoddevg/trainingmigrguide/performance_tuning_0014.html), pick the appropriate profiling tool, collect performance data, break the time down, and identify the module whose performance needs improving.
2. Once the bottleneck module is clear, narrow the problem down to dispatch, computation, or another stage, and apply the corresponding optimization.

# Common optimization points

## Moving CPU operations to the NPU

This step is a **key optimization**!

* Some operators or data types cannot run on the NPU. Replace them with equivalent NPU operators, or with logically equivalent code, so that the section can run on the NPU.

  Example:

  ```python
  # Speech models often use the Fourier transform, whose output is complex-valued. The NPU does not
  # support computing on complex types, so use torch.view_as_real() to extract the real and
  # imaginary parts and compute on those instead.
  ''' original code
  audio = torch.from_numpy(audio)  # CPU tensor
  window = torch.hann_window(N_FFT)  # CPU tensor
  stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)  # stft is complex
  magnitudes = stft[..., :-1].abs() ** 2  # the NPU cannot apply abs() to complex values
  '''

  # Replace with:
  audio = torch.from_numpy(audio).npu()  # move the tensor to the NPU
  window = torch.hann_window(N_FFT).to(audio.device)
  stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
  stft_real_img = torch.view_as_real(stft)[..., :-1, :]  # split out the real and imaginary parts
  real_part = stft_real_img[..., 0]
  img_part = stft_real_img[..., 1]
  magnitudes = real_part ** 2 + img_part ** 2  # mathematically equivalent to complex .abs() ** 2
  ```

* Create tensors on the NPU instead of moving them over from the CPU.

  Example:
  ```python
  ''' original code
  waveforms = torch.empty(0)
  feats_pad = torch.empty(0)
  feats_lens = torch.empty(0)
  '''
  # Change to:
  waveforms = torch.empty(0, device=kwargs["device"])
  feats_pad = torch.empty(0, device=kwargs["device"])
  feats_lens = torch.empty(0, device=kwargs["device"])
  ```

## Fused-operator replacement

Replacing several operators with one mathematically equivalent fused operator removes redundant computation and reduces the number of kernel launches, improving performance.

### RotaryMul

Transformer architectures usually apply positional encoding, and Rotary Position Embedding is a common choice. It normally sits inside the attention module and is applied right after the query and key projections are computed.

torch_npu interface:
```python
torch_npu.npu_rotary_mul(x, r1, r2)
```

Parameters:
* x: q or k; must be 4-dimensional, typically [B, N, S, D], [B, S, N, D], or [S, B, N, D].
* r1: the cos values; must be 4-dimensional, typically [1, 1, S, D], [1, S, 1, D], or [S, 1, 1, D].
* r2: the sin values; must be 4-dimensional, typically [1, 1, S, D], [1, S, 1, D], or [S, 1, 1, D].

Example:
```python
'''original interface
def rotate_half(x):
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
'''

# Replace with:
def apply_fused_rotary_pos_emb(q, k, cos, sin, offset: int = 0):
    return torch_npu.npu_rotary_mul(q, cos, sin), torch_npu.npu_rotary_mul(k, cos, sin)
```

**Usage constraints**
The operator currently only supports cases where r1 and r2 can be broadcast to the shape of x, and the last dimension D must be a multiple of 128.
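Before swapping the fused operator into a real model, it can be worth checking numerical equivalence against the small-operator implementation on random data. A minimal sketch under the constraints above (the shapes are arbitrary except that D = 128):

```python
import torch
import torch_npu

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

B, N, S, D = 1, 2, 4, 128  # D must be a multiple of 128
q = torch.randn(B, N, S, D, dtype=torch.float16).npu()
cos = torch.randn(1, 1, S, D, dtype=torch.float16).npu()  # broadcastable to q's shape
sin = torch.randn(1, 1, S, D, dtype=torch.float16).npu()

ref = (q * cos) + (rotate_half(q) * sin)       # small-operator reference
fused = torch_npu.npu_rotary_mul(q, cos, sin)  # fused operator
print(torch.allclose(ref, fused, rtol=1e-3, atol=1e-3))
```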
### FlashAttention operators

If the original code computes attention scores with small operators, the NPU FlashAttention fused operators can improve performance: use [torch_npu.npu_prompt_flash_attention](https://www.hiascend.com/document/detail/zh/Pytorch/710/apiref/torchnpuCustomsapi/context/torch_npu-npu_prompt_flash_attention.md) in the prefill phase and [torch_npu.npu_incre_flash_attention](https://www.hiascend.com/document/detail/zh/Pytorch/710/apiref/torchnpuCustomsapi/context/torch_npu-npu_incre_flash_attention.md) in the decode phase.

Small-operator sample:
```python
scores = torch.matmul(q, k.transpose(-2, -1))
scores = scores.masked_fill(mask, min_value)
attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
x = torch.matmul(attn, v)
```

Rewrite the attention module's forward method:
```python
# prefill uses PFA
...
n_ctx = q.shape[1]
...
if n_ctx > 1:
    mask = mask.to(torch.bool) if mask is not None else None
    sparse_mode = 1 if mask is not None else 0
    attn = torch_npu.npu_prompt_flash_attention(
        q.contiguous(),
        k.contiguous(),
        v.contiguous(),
        num_heads=self.n_head,
        input_layout="BNSD",
        scale_value=1 / math.sqrt(D),  # D is the head dimension
        atten_mask=mask[:n_ctx, :n_ctx] if mask is not None else None,
        sparse_mode=sparse_mode
    )
# decode uses IFA
else:
    attn = torch_npu.npu_incre_flash_attention(
        q.contiguous(),
        k.contiguous(),
        v.contiguous(),
        num_heads=self.n_head,
        input_layout="BNSD",
        scale_value=1 / math.sqrt(D),
        atten_mask=None,
        actual_seq_lengths=actual_seq_len,
        kv_padding_size=kv_padding_size
    )
```

For other fused operators, see: [fused operators](https://www.hiascend.com/document/detail/zh/Pytorch/710/ptmoddevg/trainingmigrguide/performance_tuning_0031.html)

## QKV fusion

Idea: in models with an attention module, the three q, k, v matrix multiplications can be replaced by a single one, making fuller use of the NPU's compute capacity.

Implementation:

```python
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # new fused qkv matrix
        self.qkv = nn.Linear(self.hidden_size, self.num_heads * self.head_dim +
                             2 * self.num_key_value_heads * self.head_dim, bias=False)
        ...

    def forward(self, hidden_states):
        '''original qkv linear layers
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)
        '''

        # Replace with:
        qkv_states = self.qkv(hidden_states)
        query_states, key_states, value_states = qkv_states.split(
            [self.q_hidden_size, self.kv_hidden_size, self.kv_hidden_size], dim=2)

    # Weight fusion: copy the original q, k, v linear weights into the fused qkv matrix
    def merge_qkv_weight(self, tp_size=1):
        def _to_parameter(data):
            return nn.Parameter(data, requires_grad=False)

        qw_size = self.model.layers[0].self_attn.q_proj.weight.shape
        kw_size = self.model.layers[0].self_attn.k_proj.weight.shape
        vw_size = self.model.layers[0].self_attn.v_proj.weight.shape

        q_sliced_size = qw_size[0] // tp_size
        k_sliced_size = kw_size[0] // tp_size
        v_sliced_size = vw_size[0] // tp_size

        for i in range(len(self.model.layers)):
            qw = self.model.layers[i].self_attn.q_proj.weight
            kw = self.model.layers[i].self_attn.k_proj.weight
            vw = self.model.layers[i].self_attn.v_proj.weight

            weight_list = []
            for j in range(tp_size):
                sliced_qw = qw[j * q_sliced_size: (j + 1) * q_sliced_size, :]
                sliced_kw = kw[j * k_sliced_size: (j + 1) * k_sliced_size, :]
                sliced_vw = vw[j * v_sliced_size: (j + 1) * v_sliced_size, :]
                weight_list.append(_to_parameter(torch.cat([sliced_qw, sliced_kw, sliced_vw], axis=0)))

            if len(weight_list) == 1:
                self.model.layers[i].self_attn.qkv.weight = weight_list[0]
            else:
                self.model.layers[i].self_attn.qkv.weight = _to_parameter(torch.cat(weight_list, axis=0))
```
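The weight layout is the only subtle part. Here is a quick pure-PyTorch sanity check that a fused projection built this way reproduces the separate q/k/v projections; the sizes are illustrative (a GQA-style setup where kv is smaller than q) and tp_size is 1:

```python
import torch
from torch import nn

hidden_size, q_out, kv_out = 32, 24, 8
q_proj = nn.Linear(hidden_size, q_out, bias=False)
k_proj = nn.Linear(hidden_size, kv_out, bias=False)
v_proj = nn.Linear(hidden_size, kv_out, bias=False)

# Build the fused projection by concatenating the three weights along dim 0.
qkv = nn.Linear(hidden_size, q_out + 2 * kv_out, bias=False)
qkv.weight = nn.Parameter(torch.cat([q_proj.weight, k_proj.weight, v_proj.weight], dim=0),
                          requires_grad=False)

x = torch.randn(2, 5, hidden_size)
q, k, v = qkv(x).split([q_out, kv_out, kv_out], dim=2)
print(torch.allclose(q, q_proj(x), atol=1e-6),
      torch.allclose(k, k_proj(x), atol=1e-6),
      torch.allclose(v, v_proj(x), atol=1e-6))  # expect True True True
```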
## Fixing the KV-cache size

Why: if the KV cache is a model input that is updated by concatenation inside the model and then returned, every update re-allocates memory and copies data, which costs performance.

How: pre-allocate one fixed-size KV-cache tensor based on the maximum sequence length, then update the cache in place at the given positions with the scatter_update_ operator.

Using the transformers LLaMA source as an example:

```python
# transformers/models/llama/modeling_llama.py
# Logic added to LlamaForCausalLM's prepare_inputs_for_generation function:
# fix the size of the kv cache, used for the kv-cache updates of both the prefill (full)
# graph and the decode (incremental) graph
batch_size, seq_length = input_ids.shape
use_dtype = self.model.torch_dtype
if past_key_values is None:
    kv_shape = (
        batch_size, self.model.max_position_embeddings, self.model.num_key_value_heads // self.world_size,
        self.model.hidden_size // self.model.num_attention_heads)
    past_key_values = ()
    for i in range(self.model.num_hidden_layers):
        k_cache = torch.zeros(kv_shape, dtype=use_dtype, device=input_ids.device)
        v_cache = torch.zeros(kv_shape, dtype=use_dtype, device=input_ids.device)
        past_key_values += ((k_cache, v_cache),)
```

Changes to the KV update:
```python
# Update the kv cache at the given positions. position_ids updates from seq_len position 0
# when the prefill graph runs, and from position seq_len when the decode graph runs.
tmp_ids = updated_kv_positions.reshape(-1)
# format BSND, 1 means seq_len dim index
torch_npu.scatter_update_(past_key_value[0], tmp_ids, key_states, 1)
torch_npu.scatter_update_(past_key_value[1], tmp_ids, value_states, 1)

key_states1 = past_key_value[0] if q_len == 1 else key_states
value_states1 = past_key_value[1] if q_len == 1 else value_states

past_key_value = past_key_value if use_cache else None
```

Other changes required by the shape change once the KV cache is fixed:
```python
# prepare_inputs_for_generation also needs to create the attention_mask and the kv-position tensor,
# mainly because prefill and decode require different attention_mask shapes, and scatter_update
# needs explicit positions at which to update the kv cache.
past_key_values_length = 0
if seq_length > 1:
    if attention_mask is None:
        attention_mask = torch.ones((batch_size, seq_length), dtype=torch.bool, device=input_ids.device)
    self.padding_mask = torch.zeros(batch_size, self.model.max_position_embeddings, device=input_ids.device)
    self.prompt_length = seq_length
    self.updated_kv_positions = torch.zeros(batch_size, dtype=position_ids.dtype, device=position_ids.device)
else:
    bsz, src_len = attention_mask.size()
    padding_mask = self.padding_mask
    padding_mask[:, :src_len] = attention_mask
    attention_mask = padding_mask
    past_key_values_length = self.model.max_position_embeddings
    self.prompt_length += 1
    self.updated_kv_positions = torch.ones(position_ids.shape, dtype=position_ids.dtype,
                                           device=position_ids.device) * (self.prompt_length - 1)

attention_mask = self.model._prepare_decoder_attention_mask(
    attention_mask, (batch_size, seq_length), past_key_values[0][0], past_key_values_length
)
```

## Affinity operator replacement

### The IndexPut operator

When the code assigns through a boolean index, replace the assignment with a multiplication to avoid many random accesses to small pieces of data.

Example:
```python
target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
masked_target = target.clone() - vocab_start_index
masked_target[target_mask] = 0
```
Replace with:
```python
target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
masked_target = target.clone() - vocab_start_index
masked_target *= ~target_mask
```
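Both forms produce identical results, since multiplying by the negated mask zeroes exactly the positions the index assignment would overwrite. A quick check on random data (runs on the CPU as written):

```python
import torch

vocab_start_index, vocab_end_index = 100, 200
target = torch.randint(0, 300, (4, 8))
target_mask = (target < vocab_start_index) | (target >= vocab_end_index)

# index-assignment version (IndexPut)
masked_a = target.clone() - vocab_start_index
masked_a[target_mask] = 0

# multiplication version
masked_b = target.clone() - vocab_start_index
masked_b *= ~target_mask

print(torch.equal(masked_a, masked_b))  # True
```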
For other affinity-operator replacements, see: [affinity operator replacement](https://www.hiascend.com/document/detail/zh/Pytorch/710/ptmoddevg/trainingmigrguide/performance_tuning_0042.html)

diff --git "a/ACL_PyTorch/docs/torchair/图优化/1.前置分析及约束条件.md" "b/ACL_PyTorch/docs/torchair/图优化/1.前置分析及约束条件.md"
new file mode 100644
index 0000000000000000000000000000000000000000..cfc48adbc9d839e02809100e9f7defe14264acee
--- /dev/null
+++ "b/ACL_PyTorch/docs/torchair/图优化/1.前置分析及约束条件.md"
@@ -0,0 +1,10 @@
## 1. Where TorchAir applies

TorchAir (Torch Ascend Intermediate Representation) plugs into PyTorch's Dynamo feature: it converts PyTorch FX (functionalized) graphs into Ascend IR, compiles and optimizes the computation graph through the Graph Engine, and dispatches it to Ascend hardware for execution. Graph capture mainly targets host-bound workloads; if your bottleneck lies elsewhere, apply the targeted optimizations in [torch_npu common optimization points](../torch_npu/常用优化点.md) instead.

* A host-bound workload shows up in profiling as a low compute share and a high free share under "Overlap Analysis". See [host-bound problem analysis](https://www.hiascend.com/document/detail/zh/mindstudio/81RC1/practicalcases/GeneralPerformanceIssue/toolsample6_053.html).
* Every operator captured into the graph must have a converter. The [ATen API support list](https://www.hiascend.com/document/detail/zh/Pytorch/710/modthirdparty/torchairuseguide/torchair_00040.html) enumerates the ATen APIs that can be captured; each matches the capability of the corresponding eager-mode ATen API. If your model uses an ATen API that is not on the list, its support may be incomplete and you need to extend the converter yourself; for the concrete steps see [converter completion](https://gitee.com/ascend/torchair/blob/7.1.0/CONTRIBUTING.md#converter%E8%A1%A5%E9%BD%90).
* Dynamic shapes must stay bounded: TorchAir currently supports up to 10 shape gears, and more degrade performance. Gearing suits variation along one dimension, such as batch size.

## 2. Constraints

* PyTorch graph mode supports single-process and multi-process execution; each process may use only one NPU card, never several.

diff --git "a/ACL_PyTorch/docs/torchair/图优化/2.迁移适配.md" "b/ACL_PyTorch/docs/torchair/图优化/2.迁移适配.md"
new file mode 100644
index 0000000000000000000000000000000000000000..1bd31e49e8ec9357be81a69dc732239a6612e354
--- /dev/null
+++ "b/ACL_PyTorch/docs/torchair/图优化/2.迁移适配.md"
@@ -0,0 +1,281 @@
## Basic usage

1. Import the TorchAir library

   ```python
   import torchair
   ```

2. Configure the compile backend

   ```python
   config = torchair.CompilerConfig()
   npu_backend = torchair.get_npu_backend(compiler_config=config)
   # create the model instance
   model = MyModel()
   # compile the whole model
   model = torch.compile(model, backend=npu_backend)
   # or compile only submodules of the model
   model.encoder = torch.compile(model.encoder, backend=npu_backend)
   model.decoder = torch.compile(model.decoder, backend=npu_backend)
   # or compile a single method of the model
   model.forward = torch.compile(model.forward, backend=npu_backend)
   ```
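A minimal runnable sketch of this flow, assuming an NPU environment and using a toy model in place of `MyModel`:

```python
import torch
import torch_npu
import torchair

class TinyModel(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.gelu(x @ x.transpose(-1, -2))

config = torchair.CompilerConfig()
npu_backend = torchair.get_npu_backend(compiler_config=config)

model = TinyModel().npu()
model = torch.compile(model, backend=npu_backend)

x = torch.randn(4, 16, 16).npu()
print(model(x).shape)  # first call triggers graph capture and compilation; later calls reuse the graph
```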
## Hugging Face model example

When migrating from the sample code on a Hugging Face model card, do not compile the whole pipeline directly; compile the model inside the pipeline.

Example:

```python
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
model_id = "openai/whisper-large-v3"
model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id)
# compile the model (npu_backend as configured in the basic usage above)
model = torch.compile(model, backend=npu_backend)
processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
    "automatic-speech-recognition",
    model=model,  # pass the compiled model into the pipeline
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    torch_dtype=torch_dtype,
    device=device,
)
# run inference
result = pipe(sample)
```

## Models assembled from multiple submodules

Some models are assembled from several submodules, and the actual inference call may never go through the model's own forward function, reaching the submodules through other interfaces instead. In that case, find the submodules actually invoked at inference time by debugging or reading the code.

Example:

```python
# Chinese CLIP model
import torch
from cn_clip.clip import load_from_name
device = torch.device('npu:0')
model, preprocess = load_from_name('ViT-B-16', device=device)

# inference interface
with torch.no_grad():
    logits_per_image, logits_per_text = model.get_similarity(image, text)
```

Printing the model directly shows this structure:

```
CLIP(
  (visual): VisualTransformer(
    ...
  )
  (bert): BertModel(
    ...
  )
)
```

The API is implemented as follows:

```python
...
def encode_image(self, image, mask_ratio=0):
    if isinstance(self.visual, ModifiedResNet):
        # calls the visual submodule
        return self.visual(image.type(self.dtype))
    return self.visual(image.type(self.dtype), mask_ratio)

def encode_text(self, text):
    ...
    # calls the bert submodule
    x = self.bert(text, attention_mask=attn_mask)[0].type(self.dtype)
    return x[:, 0, :] @ self.text_projection

def get_similarity(self, image, text):
    # process the image
    image_features = self.encode_image(image)
    # process the text
    text_features = self.encode_text(text)

    # normalized features
    image_features = image_features / image_features.norm(dim=1, keepdim=True)
    text_features = text_features / text_features.norm(dim=1, keepdim=True)

    # cosine similarity as logits
    logit_scale = self.logit_scale.exp()
    logits_per_image = logit_scale * image_features @ text_features.t()
    logits_per_text = logits_per_image.t()

    return logits_per_image, logits_per_text
```

The model structure and the implementation of the inference interface show that CLIP actually calls the forward of the visual and bert submodules, so those two are the targets to compile:

```python
config = torchair.CompilerConfig()
npu_backend = torchair.get_npu_backend(compiler_config=config)
model, preprocess = load_from_name('ViT-B-16', device=device)
model.visual = torch.compile(model.visual, backend=npu_backend)
model.bert = torch.compile(model.bert, backend=npu_backend)
```
## Common compilation options

### 1. The torchair.CompilerConfig class

The experimental_config attribute of torchair.CompilerConfig enables several performance features:

* experimental_config.frozen_parameter: in inference scenarios, marks tensors whose addresses stay constant during model execution as [Parameter type](https://docs.pytorch.org/docs/2.1/generated/torch.nn.parameter.Parameter.html), shortening graph dispatch time and improving dispatch performance. Recommended: True.
* experimental_config.tiling_schedule_optimize: in static-shape scenarios, enables tiling schedule optimization, so that tiling computation runs directly on the device side, improving static-shape model performance. Recommended: True for static scenarios.

Usage example:

```python
config = torchair.CompilerConfig()
config.experimental_config.frozen_parameter = True
config.experimental_config.tiling_schedule_optimize = True
npu_backend = torchair.get_npu_backend(compiler_config=config)

model = torch.compile(model, backend=npu_backend)
```

### 2. torch.compile()

The native interface provided by PyTorch; see the [torch.compile documentation](https://docs.pytorch.org/docs/stable/generated/torch.compile.html) for details.

```python
torch.compile(
    model: Callable[[_InputT], _RetT],
    *,
    fullgraph: bool = False,
    dynamic: Optional[bool] = None,
    backend: Union[str, Callable] = 'inductor',
    disable: bool = False
) -> Callable[[_InputT], _RetT]
```

Parameters:

|Parameter|Description|
|----|----|
|model|The model or function to capture; required.|
|fullgraph|bool, optional. Whether to optimize as one whole graph.<br>- False (default): automatically find the optimizable regions, breaking the graph at unsupported operators<br>- True: capture and optimize the whole graph, raising an exception if a graph break occurs|
|dynamic|bool or None, optional. Whether to use dynamic-shape tracing.<br>- None (default): automatically detect whether the graph is dynamic<br>- False: run as a static graph<br>- True: run as a dynamic graph|
|backend|Backend selection; the default "inductor" is not currently supported on Ascend NPUs. Ascend NPU graph capture has a single backend, obtained through the torchair.get_npu_backend interface; required.|
|disable|bool, optional. Whether to switch torch.compile off.<br>- False (default): torch.compile enabled<br>- True: torch.compile disabled, running in single-operator (eager) mode|

Recommendations:

* fullgraph: a whole graph outperforms a broken one; recommended True. If some computation cannot be captured, try replacing operators or moving the code.
* dynamic: a static graph outperforms a dynamic one; recommended False when the inputs allow it. A static graph requires fixed input shapes and no dynamic shapes during computation.

Usage example:

```python
config = torchair.CompilerConfig()
config.experimental_config.frozen_parameter = True
config.experimental_config.tiling_schedule_optimize = True
npu_backend = torchair.get_npu_backend(compiler_config=config)

model = torch.compile(model, dynamic=False, fullgraph=True, backend=npu_backend)
```

## Solving graph-capture problems

### 1. Splitting the prefill and decode phases

Because prefill and decode use different FlashAttention operators, the two phases must be captured as separate graphs.

Original model code:

```python
import torch, torch_npu, torchair
class Decoder(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, ...):
        ...
        return ...
```

Give prefill and decode their own methods and decide in forward which one to run:

```python
import torch, torch_npu, torchair
class Decoder(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, ...):
        if x.size(1) > 1:
            ...
            return self.prefill(x, ...)
        ...
        return self.decode(x, kv_cache, ...)

    def prefill(self, x, ...):
        ...
        return ...

    def decode(self, x, kv_cache, ...):
        ...
        return ...

decoder = Decoder()
# compile the prefill and decode methods
config = torchair.CompilerConfig()
config.experimental_config.frozen_parameter = True
config.experimental_config.tiling_schedule_optimize = True
npu_backend = torchair.get_npu_backend(compiler_config=config)
decoder.prefill = torch.compile(decoder.prefill, dynamic=False, fullgraph=True, backend=npu_backend)
# decode usually runs autoregressively with varying input, so dynamic is set to True
decoder.decode = torch.compile(decoder.decode, dynamic=True, fullgraph=True, backend=npu_backend)
```

### 2. Moving unsupported operators out of the compiled region

Some data types or operations cannot be captured by TorchAir and fail to compile; move that code out of the function being optimized. For instance, speech models often apply torch.stft() to the audio input. TorchAir does not support that operation, so relocating the call lets the body of the function still capture with fullgraph=True.

Example, before the change:

```python
class Model(nn.Module):
    ...
    def decode(self, x):
        spec = torch.stft(x, N_FFT, HOP_LENGTH, window, return_complex=True)
        ...
        return ...

    def forward(self, x):
        ...
        out = self.decode(x)
        ...
        return ...
```

After the change:

```python
class Model(nn.Module):
    ...
    def decode(self, x, spec):
        # spec = torch.stft(x, N_FFT, HOP_LENGTH, window, return_complex=True)
        ...
        return ...

    def forward(self, x):
        spec = torch.stft(x, N_FFT, HOP_LENGTH, window, return_complex=True)  # stft moved out of decode
        stft_real_img = torch.view_as_real(spec)  # the NPU does not support complex types
        out = self.decode(x, stft_real_img)
        ...
        return ...

model = Model()
model.decode = torch.compile(model.decode, fullgraph=True, backend=npu_backend)  # npu_backend as configured earlier
```

### 3. Registering and implementing an operator converter

Once a custom operator is registered with PyTorch, it can be called in eager mode but not yet in graph mode. The operator must additionally be registered with TorchAir and given a converter function that translates its ATen IR into GE IR, so that it can be captured into the graph on the NPU. For the concrete steps, see [converter completion](https://gitee.com/ascend/torchair/blob/7.1.0/CONTRIBUTING.md#converter%E8%A1%A5%E9%BD%90).
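For orientation, a converter registration has roughly the shape below. This is only a sketch modeled on the examples in the contributing guide linked above: the import paths (`torchair.ge_concrete_graph.*`), the decorator name, and the GE op library should all be checked against your torchair version, and `torch.ops.aten.relu.default` merely stands in for the custom operator being adapted.

```python
import torch
# NOTE: module paths follow the torchair contributing guide and may differ between versions
from torchair.ge_concrete_graph.fx2ge_converter import register_fx_node_ge_converter
from torchair.ge_concrete_graph import ge_apis as ge

# Register a converter for one ATen overload; torch.ops.aten.relu.default is a placeholder
# for the custom operator you actually need to support.
@register_fx_node_ge_converter(torch.ops.aten.relu.default)
def convert_aten_relu(self, meta_outputs=None):
    # Map the ATen node onto the equivalent GE IR operator.
    return ge.Relu(self)
```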
diff --git "a/ACL_PyTorch/docs/torchair/图优化/3.功能精度问题定位.md" "b/ACL_PyTorch/docs/torchair/图优化/3.功能精度问题定位.md"
new file mode 100644
index 0000000000000000000000000000000000000000..c02c8b66b9ea259245165f05a5d71604054975fc
--- /dev/null
+++ "b/ACL_PyTorch/docs/torchair/图优化/3.功能精度问题定位.md"
@@ -0,0 +1,72 @@
## Locating functional problems

Unsupported operators in the model cause TorchAir graph capture to fail; the supported operators are listed in the [ATen API support list](https://www.hiascend.com/document/detail/zh/Pytorch/710/modthirdparty/torchairuseguide/torchair_00040.html).

### 1. Dynamo graph-capture Python logs

How to enable:

```python
import logging
torch._logging.set_logs(dynamo=logging.DEBUG, aot=logging.DEBUG, output_code=True, graph_code=True, recompiles=True)
```

Used mainly to pinpoint problems when native Dynamo raises errors. Focus on the FX graph and guard information: the logs trace how Dynamo builds the FX graph, which reveals where Dynamo fails or what it does not support.

### 2. TorchAir Python logs

How to enable:
```python
import logging
torchair.logger.setLevel(logging.DEBUG)
```

When a converter fails internally or a TorchAir configuration option is rejected, this log shows TorchAir's converter conversion flow, the configuration options, and so on.

### 3. TorchAir C++ logs

How to enable:
```bash
export TNG_LOG_LEVEL=0
```

Used mainly to check whether the Ascend IR graph is configured, and to chase memory problems. This log shows whether the Ascend IR graph compiled successfully and which options were passed in, as well as the input/output memory allocations.

## Locating accuracy problems

### 1. TorchAir Python logs

How to enable:
```python
import logging
torchair.logger.setLevel(logging.DEBUG)
```

Besides converter failures and configuration errors, this log also records the input/output shapes of individual operators, which helps when locating accuracy problems.

### 2. Tensor print

In graph mode, the device-side outputs of intermediate graph nodes cannot be inspected with a plain print; use the npu_print API instead:

```python
import torch, torch_npu, torchair
from torch import nn

class Model(nn.Module):
    ...
    def forward(self, x):
        ...
        y = x ** 2
        torchair.ops.npu_print("y=", y, summarize_size=2)
        ...

config = torchair.CompilerConfig()
npu_backend = torchair.get_npu_backend(compiler_config=config)
model = Model()
model = torch.compile(model, backend=npu_backend)
out = model(x)
```

API description: `def npu_print(*args, summarize_size=3)`

* *args: accepts torch.Tensor inputs and the basic Python types str, bool, float, and int
* summarize_size: how many entries to show at the head and tail of each axis when printing a tensor

Interface constraints:

* args must contain at least one Tensor input
* summarize_size must be a positive integer, or -1 to print everything
* the interface prints asynchronously; the outputs of multiple npu_print calls follow the execution order of the printed nodes
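Beyond printing individual tensors, a useful first step on any accuracy problem is comparing graph-mode outputs against eager mode on identical inputs. A minimal sketch, assuming an NPU environment and using a toy model:

```python
import torch
import torch_npu
import torchair

config = torchair.CompilerConfig()
npu_backend = torchair.get_npu_backend(compiler_config=config)

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.GELU()).npu()
x = torch.randn(2, 16).npu()

with torch.no_grad():
    eager_out = model(x)                                  # eager-mode reference
    compiled = torch.compile(model, backend=npu_backend)
    graph_out = compiled(x)                               # graph-mode result

# A large deviation points at a converter or fused-operator problem.
print(torch.max(torch.abs(eager_out - graph_out)))
print(torch.allclose(eager_out, graph_out, rtol=1e-3, atol=1e-3))
```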
diff --git "a/ACL_PyTorch/docs/torchair/图优化/4.高阶特性.md" "b/ACL_PyTorch/docs/torchair/图优化/4.高阶特性.md"
new file mode 100644
index 0000000000000000000000000000000000000000..8378240f83c2176db42884af472f77ecfc9d66fe
--- /dev/null
+++ "b/ACL_PyTorch/docs/torchair/图优化/4.高阶特性.md"
@@ -0,0 +1,187 @@
## 1. Dynamic shape gears

When the model input is not one fixed shape but can be restricted to a few shapes (typically the batch dimension varies over a fixed set of values), a dynamic graph cannot reach optimal performance, while a static graph recompiles every time it meets a new shape. Dynamic shape gears give static-graph performance while compiling only once.

Example:

```python
import torch
import torch_npu
import torchair

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self, x1, x2):
        return x1 + x2

input1 = torch.ones(2, 2).npu()
input2 = torch.ones(2, 2).npu()
config = torchair.CompilerConfig()
# zip policy: gears pair up positionally; supports the two input combinations (2, 2)(2, 2) and (4, 2)(4, 2)
config.inference_config.dynamic_gears_merge_policy = "zip"  # default
# product policy: Cartesian combinations; supports (2, 2)(2, 2), (2, 2)(4, 2), (4, 2)(2, 2) and (4, 2)(4, 2)
# config.inference_config.dynamic_gears_merge_policy = "product"
npu_backend = torchair.get_npu_backend(compiler_config=config)
model = Model().npu()
# the whole graph must be compiled
model = torch.compile(model, fullgraph=True, backend=npu_backend)

# set the gears
torchair.inference.set_dim_gears(input1, dim_gears={0: [2, 4]})
torchair.inference.set_dim_gears(input2, dim_gears={0: [2, 4]})

# first compilation + run, shapes (2, 2), (2, 2)
print(model(input1, input2))

# run again with shapes (4, 2), (4, 2); they fall within the gears, so no recompilation is triggered
input1 = torch.ones(4, 2).npu()
input2 = torch.ones(4, 2).npu()
print(model(input1, input2))
```

## 2. torch._dynamo.mark_static()

A PyTorch interface that marks an input's shape as static. If all inputs are fixed with this interface and the model contains no internal dynamic shapes, you still obtain a static graph even when torch.compile's dynamic parameter is True. If the model's input shapes are fixed but its internals contain dynamic shapes, this lets GE generate partially static subgraphs during compilation, improving performance.

Example:

```python
from typing import Optional

import torch
from torch import nn, Tensor
import torch_npu
import torchair

class Decoder(nn.Module):
    ...

    def decode(
        self,
        x: Tensor,
        xa: Tensor,
        positional_embedding: Tensor,
        kv_cache: Optional[dict] = None,
        updated_kv_positions: Optional[torch.LongTensor] = None,
        actual_seq_len: Optional[list] = None,
        kv_padding_size: Optional[torch.LongTensor] = None
    ):
        ...
        return ...

    def compute_logits(self, x, xa):
        ...
        torch._dynamo.mark_static(x)
        torch._dynamo.mark_static(xa)
        torch._dynamo.mark_static(positional_embedding)
        for i in range(n_layer):
            torch._dynamo.mark_static(self.kv_cache[i]['attn']["key"])
            torch._dynamo.mark_static(self.kv_cache[i]['attn']["value"])
            torch._dynamo.mark_static(self.kv_cache[i]['cross_attn']["key"])
            torch._dynamo.mark_static(self.kv_cache[i]['cross_attn']["value"])
        torch._dynamo.mark_static(kv_padding_size)

        return self.decode(x, xa, positional_embedding, self.kv_cache,
                           actual_seq_len=[actual_seq_len], kv_padding_size=kv_padding_size,
                           updated_kv_positions=updated_kv_positions)

config = torchair.CompilerConfig()
npu_backend = torchair.get_npu_backend(compiler_config=config)
decoder = Decoder().npu()
decoder.decode = torch.compile(decoder.decode, dynamic=True, fullgraph=True, backend=npu_backend)
```
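A compact self-contained illustration of the same idea, assuming an NPU environment (the function and the shapes are arbitrary):

```python
import torch
import torch._dynamo
import torch_npu
import torchair

config = torchair.CompilerConfig()
npu_backend = torchair.get_npu_backend(compiler_config=config)

def scale(x, factor):
    return x * factor

compiled = torch.compile(scale, dynamic=True, fullgraph=True, backend=npu_backend)

x = torch.randn(4, 8).npu()
factor = torch.randn(1).npu()
torch._dynamo.mark_static(x)       # mark both input shapes as static
torch._dynamo.mark_static(factor)
print(compiled(x, factor).shape)   # even with dynamic=True, GE can treat the captured graph as static
```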
## 3. Compilation caching

With torch.compile, every program launch recompiles, and compilation usually takes a while, which makes debugging inconvenient. The cache_compile interface persists the first compilation result to disk and speeds up graph-mode startup.

Ordinarily with torch.compile, the first run goes through Dynamo compilation, Torch guards, Ascend IR graph compilation, and so on, and every later run still evaluates the guard functions to decide whether recompilation is needed. With the compilation cache, the first run loads the cache directly, and subsequent runs skip the guard functions.

Cache compilation interface:

torchair.inference.cache_compile

```python
def cache_compile(
    func,  # the method to cache-compile; only methods of a module are supported
    *,
    config: Optional[CompilerConfig] = None,
    dynamic: bool = True,  # whether to trace dynamically according to the inputs
    cache_dir: Optional[str] = None,  # cache root directory, default .torchair_cache
    global_rank: Optional[int] = None,  # rank in distributed runs, default torch.distributed.get_rank()
    tp_rank: Optional[int] = None,  # the given tp rank
    pp_rank: Optional[int] = None,  # the given pp rank
    custom_decompositions: Optional[dict] = None,  # user-defined decompose strategy
    ge_cache: bool = False,  # whether to enable the GE cache
    **kwargs
) -> Callable:
```

### Usage constraints

* func may trigger only one Dynamo trace; if func is recompiled, the cache is abandoned.
* A function that is traced several times (its guards fail) needs an extra layer of function wrapping for the cache to take effect.
* func must be a method of a module instance, and the method must not be wrapped by other decorators.
* func must capture as a whole graph, i.e. it must support fullgraph.
* Only inference is supported; a func containing backward computation cannot be cached.

**Ascend IR compilation cache**

* Besides the Dynamo compilation time, the Ascend IR graph compilation time can also be cached, mainly through the ge_cache parameter of cache_compile, to further speed up graph-mode startup.
* By default ge_cache=False (disabled); because the cache is affected by CANN version changes, enable it manually as your situation allows.
* The cached compilation artifacts share the cache path of the wrapped func.

### Usage example

Original model code:

```python
import torch, torch_npu, torchair
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @torch.inference_mode()
    def forward(self, x, y):
        return x + y
```

Adaptation steps:

* Extract the body of forward into a _forward function and cache-compile _forward.
* During initialization, store the compiled function as an attribute of the module.
* Have forward run the compiled function, i.e. the attribute set during initialization.

```python
import torch, torch_npu, torchair
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.cached_forward = torchair.inference.cache_compile(self._forward)

    def forward(self, x, y):
        # call the compiled function
        return self.cached_forward(x, y)

    def _forward(self, x, y):
        return x + y
```

In models with a transformer structure, forward commonly uses different operators or computation logic in the prefill and decode phases; add one function per scenario:

```python
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        config = torchair.CompilerConfig()
        self.cached_prefill = torchair.inference.cache_compile(self.prefill, config=config)
        self.cached_decode = torchair.inference.cache_compile(self.decode, config=config)

    def forward(self, x, ...):
        if x.size(1) > 1:
            return self.cached_prefill(x, ...)
        return self.cached_decode(x, ...)

    def prefill(self, x, ...):
        return ...

    def decode(self, x, kv_cache):
        return ...
```
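Putting the pieces together, a minimal runnable sketch, assuming an NPU environment; the `cache_dir` name is arbitrary and only the documented parameters above are used:

```python
import torch
import torch_npu
import torchair

class Adder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.cached_forward = torchair.inference.cache_compile(
            self._forward, dynamic=False, cache_dir="./my_torchair_cache")

    def forward(self, x, y):
        return self.cached_forward(x, y)

    def _forward(self, x, y):
        return x + y

model = Adder().npu()
x = torch.ones(2, 2).npu()
print(model(x, x))  # the first launch compiles and writes the cache; later launches load it and start faster
```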