diff --git "a/ACL_PyTorch/docs/01.\345\234\250\347\272\277\346\216\250\347\220\206.md" "b/ACL_PyTorch/docs/ONNX/01.\345\234\250\347\272\277\346\216\250\347\220\206.md"
similarity index 100%
rename from "ACL_PyTorch/docs/01.\345\234\250\347\272\277\346\216\250\347\220\206.md"
rename to "ACL_PyTorch/docs/ONNX/01.\345\234\250\347\272\277\346\216\250\347\220\206.md"
diff --git "a/ACL_PyTorch/docs/02.ONNX\347\232\204\345\257\274\345\207\272.md" "b/ACL_PyTorch/docs/ONNX/02.ONNX\347\232\204\345\257\274\345\207\272.md"
similarity index 100%
rename from "ACL_PyTorch/docs/02.ONNX\347\232\204\345\257\274\345\207\272.md"
rename to "ACL_PyTorch/docs/ONNX/02.ONNX\347\232\204\345\257\274\345\207\272.md"
diff --git "a/ACL_PyTorch/docs/03.ONNX\350\275\254OM.md" "b/ACL_PyTorch/docs/ONNX/03.ONNX\350\275\254OM.md"
similarity index 100%
rename from "ACL_PyTorch/docs/03.ONNX\350\275\254OM.md"
rename to "ACL_PyTorch/docs/ONNX/03.ONNX\350\275\254OM.md"
diff --git "a/ACL_PyTorch/docs/04.\347\246\273\347\272\277\346\216\250\347\220\206.md" "b/ACL_PyTorch/docs/ONNX/04.\347\246\273\347\272\277\346\216\250\347\220\206.md"
similarity index 100%
rename from "ACL_PyTorch/docs/04.\347\246\273\347\272\277\346\216\250\347\220\206.md"
rename to "ACL_PyTorch/docs/ONNX/04.\347\246\273\347\272\277\346\216\250\347\220\206.md"
diff --git "a/ACL_PyTorch/docs/05.01.ResNet50 \344\275\277\347\224\250\351\235\231\346\200\201 AIPP \347\232\204\346\241\210\344\276\213.md" "b/ACL_PyTorch/docs/ONNX/05.01.ResNet50 \344\275\277\347\224\250\351\235\231\346\200\201 AIPP \347\232\204\346\241\210\344\276\213.md"
similarity index 100%
rename from "ACL_PyTorch/docs/05.01.ResNet50 \344\275\277\347\224\250\351\235\231\346\200\201 AIPP \347\232\204\346\241\210\344\276\213.md"
rename to "ACL_PyTorch/docs/ONNX/05.01.ResNet50 \344\275\277\347\224\250\351\235\231\346\200\201 AIPP \347\232\204\346\241\210\344\276\213.md"
diff --git "a/ACL_PyTorch/docs/05.02.GLIP\343\200\201CLIP \346\200\247\350\203\275\344\274\230\345\214\226\346\241\210\344\276\213.md" "b/ACL_PyTorch/docs/ONNX/05.02.GLIP\343\200\201CLIP \346\200\247\350\203\275\344\274\230\345\214\226\346\241\210\344\276\213.md"
similarity index 100%
rename from "ACL_PyTorch/docs/05.02.GLIP\343\200\201CLIP \346\200\247\350\203\275\344\274\230\345\214\226\346\241\210\344\276\213.md"
rename to "ACL_PyTorch/docs/ONNX/05.02.GLIP\343\200\201CLIP \346\200\247\350\203\275\344\274\230\345\214\226\346\241\210\344\276\213.md"
diff --git "a/ACL_PyTorch/docs/05.\346\200\247\350\203\275\344\274\230\345\214\226.md" "b/ACL_PyTorch/docs/ONNX/05.\346\200\247\350\203\275\344\274\230\345\214\226.md"
similarity index 100%
rename from "ACL_PyTorch/docs/05.\346\200\247\350\203\275\344\274\230\345\214\226.md"
rename to "ACL_PyTorch/docs/ONNX/05.\346\200\247\350\203\275\344\274\230\345\214\226.md"
diff --git a/ACL_PyTorch/docs/README.md b/ACL_PyTorch/docs/ONNX/README.md
similarity index 100%
rename from ACL_PyTorch/docs/README.md
rename to ACL_PyTorch/docs/ONNX/README.md
diff --git a/ACL_PyTorch/docs/images/ATC.png b/ACL_PyTorch/docs/ONNX/images/ATC.png
similarity index 100%
rename from ACL_PyTorch/docs/images/ATC.png
rename to ACL_PyTorch/docs/ONNX/images/ATC.png
diff --git a/ACL_PyTorch/docs/images/Attention.png b/ACL_PyTorch/docs/ONNX/images/Attention.png
similarity index 100%
rename from ACL_PyTorch/docs/images/Attention.png
rename to ACL_PyTorch/docs/ONNX/images/Attention.png
diff --git a/ACL_PyTorch/docs/images/ConvTranspose+Add.png b/ACL_PyTorch/docs/ONNX/images/ConvTranspose+Add.png
similarity index 100%
rename from ACL_PyTorch/docs/images/ConvTranspose+Add.png
rename to ACL_PyTorch/docs/ONNX/images/ConvTranspose+Add.png
diff --git a/ACL_PyTorch/docs/images/ConvTranspose+Add_after.png b/ACL_PyTorch/docs/ONNX/images/ConvTranspose+Add_after.png
similarity index 100%
rename from ACL_PyTorch/docs/images/ConvTranspose+Add_after.png
rename to ACL_PyTorch/docs/ONNX/images/ConvTranspose+Add_after.png
diff --git a/ACL_PyTorch/docs/images/ConvTranspose+Add_before.png b/ACL_PyTorch/docs/ONNX/images/ConvTranspose+Add_before.png
similarity index 100%
rename from ACL_PyTorch/docs/images/ConvTranspose+Add_before.png
rename to ACL_PyTorch/docs/ONNX/images/ConvTranspose+Add_before.png
diff --git a/ACL_PyTorch/docs/images/DeformableConv_after.png b/ACL_PyTorch/docs/ONNX/images/DeformableConv_after.png
similarity index 100%
rename from ACL_PyTorch/docs/images/DeformableConv_after.png
rename to ACL_PyTorch/docs/ONNX/images/DeformableConv_after.png
diff --git a/ACL_PyTorch/docs/images/DeformableConv_before.png b/ACL_PyTorch/docs/ONNX/images/DeformableConv_before.png
similarity index 100%
rename from ACL_PyTorch/docs/images/DeformableConv_before.png
rename to ACL_PyTorch/docs/ONNX/images/DeformableConv_before.png
diff --git a/ACL_PyTorch/docs/images/DeformableConv_om.png b/ACL_PyTorch/docs/ONNX/images/DeformableConv_om.png
similarity index 100%
rename from ACL_PyTorch/docs/images/DeformableConv_om.png
rename to ACL_PyTorch/docs/ONNX/images/DeformableConv_om.png
diff --git a/ACL_PyTorch/docs/images/FAfp32.png b/ACL_PyTorch/docs/ONNX/images/FAfp32.png
similarity index 100%
rename from ACL_PyTorch/docs/images/FAfp32.png
rename to ACL_PyTorch/docs/ONNX/images/FAfp32.png
diff --git a/ACL_PyTorch/docs/images/ILSVRC2012_val_00006083.jpeg b/ACL_PyTorch/docs/ONNX/images/ILSVRC2012_val_00006083.jpeg
similarity index 100%
rename from ACL_PyTorch/docs/images/ILSVRC2012_val_00006083.jpeg
rename to ACL_PyTorch/docs/ONNX/images/ILSVRC2012_val_00006083.jpeg
diff --git a/ACL_PyTorch/docs/images/LayerNorm_after.png b/ACL_PyTorch/docs/ONNX/images/LayerNorm_after.png
similarity index 100%
rename from ACL_PyTorch/docs/images/LayerNorm_after.png
rename to ACL_PyTorch/docs/ONNX/images/LayerNorm_after.png
diff --git a/ACL_PyTorch/docs/images/LayerNorm_before.png b/ACL_PyTorch/docs/ONNX/images/LayerNorm_before.png
similarity index 100%
rename from ACL_PyTorch/docs/images/LayerNorm_before.png
rename to ACL_PyTorch/docs/ONNX/images/LayerNorm_before.png
diff --git a/ACL_PyTorch/docs/images/NMS.png b/ACL_PyTorch/docs/ONNX/images/NMS.png
similarity index 100%
rename from ACL_PyTorch/docs/images/NMS.png
rename to ACL_PyTorch/docs/ONNX/images/NMS.png
diff --git a/ACL_PyTorch/docs/images/NMS_after.png b/ACL_PyTorch/docs/ONNX/images/NMS_after.png
similarity index 100%
rename from ACL_PyTorch/docs/images/NMS_after.png
rename to ACL_PyTorch/docs/ONNX/images/NMS_after.png
diff --git a/ACL_PyTorch/docs/images/NMS_before.png b/ACL_PyTorch/docs/ONNX/images/NMS_before.png
similarity index 100%
rename from ACL_PyTorch/docs/images/NMS_before.png
rename to ACL_PyTorch/docs/ONNX/images/NMS_before.png
diff --git a/ACL_PyTorch/docs/images/Reshape.png b/ACL_PyTorch/docs/ONNX/images/Reshape.png
similarity index 100%
rename from ACL_PyTorch/docs/images/Reshape.png
rename to ACL_PyTorch/docs/ONNX/images/Reshape.png
diff --git a/ACL_PyTorch/docs/images/Resize_after.png b/ACL_PyTorch/docs/ONNX/images/Resize_after.png
similarity index 100%
rename from ACL_PyTorch/docs/images/Resize_after.png
rename to ACL_PyTorch/docs/ONNX/images/Resize_after.png
diff --git a/ACL_PyTorch/docs/images/Resize_before.png b/ACL_PyTorch/docs/ONNX/images/Resize_before.png
similarity index 100%
rename from ACL_PyTorch/docs/images/Resize_before.png
rename to ACL_PyTorch/docs/ONNX/images/Resize_before.png
diff --git a/ACL_PyTorch/docs/images/Resize_scales_after.png b/ACL_PyTorch/docs/ONNX/images/Resize_scales_after.png
similarity index 100%
rename from ACL_PyTorch/docs/images/Resize_scales_after.png
rename to ACL_PyTorch/docs/ONNX/images/Resize_scales_after.png
diff --git a/ACL_PyTorch/docs/images/Resize_scales_before.png b/ACL_PyTorch/docs/ONNX/images/Resize_scales_before.png
similarity index 100%
rename from ACL_PyTorch/docs/images/Resize_scales_before.png
rename to ACL_PyTorch/docs/ONNX/images/Resize_scales_before.png
diff --git a/ACL_PyTorch/docs/images/Split_after.png b/ACL_PyTorch/docs/ONNX/images/Split_after.png
similarity index 100%
rename from ACL_PyTorch/docs/images/Split_after.png
rename to ACL_PyTorch/docs/ONNX/images/Split_after.png
diff --git a/ACL_PyTorch/docs/images/Split_before.png b/ACL_PyTorch/docs/ONNX/images/Split_before.png
similarity index 100%
rename from ACL_PyTorch/docs/images/Split_before.png
rename to ACL_PyTorch/docs/ONNX/images/Split_before.png
diff --git a/ACL_PyTorch/docs/images/Swin.png b/ACL_PyTorch/docs/ONNX/images/Swin.png
similarity index 100%
rename from ACL_PyTorch/docs/images/Swin.png
rename to ACL_PyTorch/docs/ONNX/images/Swin.png
diff --git a/ACL_PyTorch/docs/images/pad_after.png b/ACL_PyTorch/docs/ONNX/images/pad_after.png
similarity index 100%
rename from ACL_PyTorch/docs/images/pad_after.png
rename to ACL_PyTorch/docs/ONNX/images/pad_after.png
diff --git a/ACL_PyTorch/docs/images/pad_before.png b/ACL_PyTorch/docs/ONNX/images/pad_before.png
similarity index 100%
rename from ACL_PyTorch/docs/images/pad_before.png
rename to ACL_PyTorch/docs/ONNX/images/pad_before.png
diff --git "a/ACL_PyTorch/docs/torchair/images/torch_npu\344\274\230\345\214\226\346\265\201\347\250\213.png" "b/ACL_PyTorch/docs/torchair/images/torch_npu\344\274\230\345\214\226\346\265\201\347\250\213.png"
new file mode 100644
index 0000000000000000000000000000000000000000..5877bacad423ef748c32772a25a7c02388862d65
Binary files /dev/null and "b/ACL_PyTorch/docs/torchair/images/torch_npu\344\274\230\345\214\226\346\265\201\347\250\213.png" differ
diff --git "a/ACL_PyTorch/docs/torchair/torch_npu/torch_npu\350\277\201\347\247\273\351\200\202\351\205\215.md" "b/ACL_PyTorch/docs/torchair/torch_npu/torch_npu\350\277\201\347\247\273\351\200\202\351\205\215.md"
new file mode 100644
index 0000000000000000000000000000000000000000..088753cc645df4eff7765e1d7c3600bc19f424ab
--- /dev/null
+++ "b/ACL_PyTorch/docs/torchair/torch_npu/torch_npu\350\277\201\347\247\273\351\200\202\351\205\215.md"
@@ -0,0 +1,43 @@
+1. Import the NPU-related libraries
+
+```python
+import torch
+import torch_npu
+```
+
+2. Set the NPU as the execution device
+* The `.to(device)` approach
+
+    Before migration:
+
+ ```python
+    device = torch.device('cuda:{}'.format(local_rank))
+ model.to(device)
+ data.to(device)
+ ```
+
+    After migration:
+
+ ```python
+ device = torch.device('npu:{}'.format(local_rank))
+ model.to(device)
+ data.to(device)
+ ```
+
+* The `set_device` approach
+
+    Before migration:
+
+ ```python
+ torch.cuda.set_device(local_rank)
+ ```
+
+    After migration:
+
+ ```python
+ torch_npu.npu.set_device(local_rank)
+ ```
+
+    For more torch_npu migration guidance, see the [PyTorch manual model migration guide](https://www.hiascend.com/document/detail/zh/Pytorch/60RC1/ptmoddevg/trainingmigrguide/PT_LMTMOG_0018.html).
+
+3. Run the model and verify that it still produces correct results after migration. A minimal sanity check is sketched below.
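+
+A minimal sanity check (illustrative; the tensor shape is arbitrary) to confirm the NPU is usable before running the full model:
+
+```python
+import torch
+import torch_npu
+
+# Confirm torch_npu can see a device
+print(torch_npu.npu.is_available())
+
+# Run a trivial op on the NPU and bring the result back to the CPU
+x = torch.ones(2, 2).npu()
+print((x + x).cpu())
+```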
\ No newline at end of file
diff --git "a/ACL_PyTorch/docs/torchair/torch_npu/\345\270\270\347\224\250\344\274\230\345\214\226\347\202\271.md" "b/ACL_PyTorch/docs/torchair/torch_npu/\345\270\270\347\224\250\344\274\230\345\214\226\347\202\271.md"
new file mode 100644
index 0000000000000000000000000000000000000000..8c9a6201a409a3579dbc2b14c1413b0a976e1d52
--- /dev/null
+++ "b/ACL_PyTorch/docs/torchair/torch_npu/\345\270\270\347\224\250\344\274\230\345\214\226\347\202\271.md"
@@ -0,0 +1,277 @@
+# Basic optimization workflow
+Model performance is shaped by many components, including the algorithm itself, so the key to optimization is first locating the current bottleneck and then applying targeted fixes. The workflow is as follows:
+
+![torch_npu优化流程](../images/torch_npu优化流程.png)
+
+1. Refer to the [performance tuning tools overview](https://www.hiascend.com/document/detail/zh/Pytorch/710/ptmoddevg/trainingmigrguide/performance_tuning_0014.html), pick the appropriate profiling tool, collect performance data, and break the time down to find the module whose performance needs improvement.
+2. Once the bottleneck module is clear, narrow the problem down to dispatch (host launch), computation, and so on, and apply the corresponding optimization.
+
+# Common optimization points
+## Migrating CPU operations to the NPU
+This step is a **key optimization point**!
+
+* Some operators or data types cannot run on the NPU; replace them with equivalent NPU operators or logically equivalent code so that the section can run on the NPU.
+
+    Example:
+
+    ```python
+    # Speech models often apply a Fourier transform whose output is complex. The NPU does not support complex
+    # arithmetic, so use torch.view_as_real() to extract the real and imaginary parts and compute on those.
+    ''' Original code
+    audio = torch.from_numpy(audio) # cpu tensor
+    window = torch.hann_window(N_FFT) # cpu tensor
+    stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True) # stft is complex
+    magnitudes = stft[..., :-1].abs() ** 2 # the npu does not support abs() on complex tensors
+    '''
+
+    # Replace with:
+    audio = torch.from_numpy(audio).npu() # move the tensor to the npu
+    window = torch.hann_window(N_FFT).to(audio.device)
+    stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
+    stft_real_img = torch.view_as_real(stft)[..., :-1, :] # split the complex values into real and imaginary parts
+    real_part = stft_real_img[..., 0]
+    img_part = stft_real_img[..., 1]
+    magnitudes = real_part ** 2 + img_part ** 2 # mathematically equivalent to complex .abs() ** 2
+    ```
+
+* Move CPU tensors onto the NPU
+
+    Example:
+ ```python
+    ''' Original code
+ waveforms = torch.empty(0)
+ feats_pad = torch.empty(0)
+ feats_lens = torch.empty(0)
+ '''
+    # Change to:
+ waveforms = torch.empty(0, device=kwargs["device"])
+ feats_pad = torch.empty(0, device=kwargs["device"])
+ feats_lens = torch.empty(0, device=kwargs["device"])
+ ```
+
+## Fused-operator substitution
+By substituting mathematically equivalent forms, several operators are fused into a single operator, which removes redundant computation and reduces the number of kernel launches, improving performance.
+
+### RotaryMul
+Transformer architectures usually apply positional encoding, and Rotary Position Embedding (RoPE) is a common choice. It typically sits inside the attention module: after the query and key projections are computed, the rotary encoding is applied to them.
+
+torch_npu API:
+```python
+torch_npu.npu_rotary_mul(x, r1, r2)
+```
+
+Parameters:
+* x: the q or k tensor; must be 4-D, typically [B, N, S, D], [B, S, N, D], or [S, B, N, D].
+* r1: the cos values; must be 4-D, typically [1, 1, S, D], [1, S, 1, D], or [S, 1, 1, D].
+* r2: the sin values; must be 4-D, typically [1, 1, S, D], [1, S, 1, D], or [S, 1, 1, D].
+
+Example:
+```python
+'''Original implementation
+def rotate_half(x):
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2:]
+    return torch.cat((-x2, x1), dim=-1)
+
+def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+'''
+
+# Replace with:
+def apply_fused_rotary_pos_emb(q, k, cos, sin, offset: int = 0):
+ return torch_npu.npu_rotary_mul(q, cos, sin), torch_npu.npu_rotary_mul(k, cos, sin)
+```
+
+**Usage constraints**
+Currently the operator only supports cases where r1 and r2 broadcast to the shape of x, and the last dimension D of the shape must be a multiple of 128. A quick shape check is sketched below.
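+
+A minimal sketch (illustrative; `check_rotary_shapes` is a hypothetical helper, and the shape names follow the parameter description above) that validates these constraints before calling the fused operator:
+
+```python
+def check_rotary_shapes(x, r1, r2):
+    # x is 4-D, e.g. [B, N, S, D]; r1/r2 must broadcast to it
+    assert x.dim() == 4 and r1.dim() == 4 and r2.dim() == 4
+    assert x.shape[-1] % 128 == 0, "last dim D must be a multiple of 128"
+    for r in (r1, r2):
+        # each axis of r must match x or be 1 (broadcastable)
+        assert all(rs in (1, xs) for rs, xs in zip(r.shape, x.shape))
+```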
+
+### FlashAttention operators
+If the original code computes the attention score with small discrete operators, the NPU flash-attention fused operators can improve performance: use [torch_npu.npu_prompt_flash_attention](https://www.hiascend.com/document/detail/zh/Pytorch/710/apiref/torchnpuCustomsapi/context/torch_npu-npu_prompt_flash_attention.md) in the prefill phase and [torch_npu.npu_incre_flash_attention](https://www.hiascend.com/document/detail/zh/Pytorch/710/apiref/torchnpuCustomsapi/context/torch_npu-npu_incre_flash_attention.md) in the decode phase.
+
+Small-operator version:
+```python
+scores = torch.matmul(q, k.transpose(-2, -1))
+scores = scores.masked_fill(mask, min_value)
+attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
+x = torch.matmul(attn, v)
+```
+
+Rewrite the attention module's forward method:
+```python
+# Prefill: use PFA
+...
+n_ctx = q.shape[1]
+...
+if n_ctx > 1:
+    mask = mask.to(torch.bool) if mask is not None else None
+    sparse_mode = 1 if mask is not None else 0
+ attn = torch_npu.npu_prompt_flash_attention(
+ q.contiguous(),
+ k.contiguous(),
+ v.contiguous(),
+ num_heads=self.n_head,
+ input_layout="BNSD",
+ scale_value=1 / math.sqrt(D),
+ atten_mask=mask[:n_ctx, :n_ctx] if mask is not None else None,
+ sparse_mode=sparse_mode
+ )
+# Decode: use IFA
+else:
+ attn = torch_npu.npu_incre_flash_attention(
+ q.contiguous(),
+ k.contiguous(),
+ v.contiguous(),
+ num_heads=self.n_head,
+ input_layout="BNSD",
+ scale_value=1 / math.sqrt(D),
+ atten_mask=None,
+ actual_seq_lengths=actual_seq_len,
+ kv_padding_size=kv_padding_size
+ )
+```
+
+For other fused operators, see [fused operators](https://www.hiascend.com/document/detail/zh/Pytorch/710/ptmoddevg/trainingmigrguide/performance_tuning_0031.html).
+
+## QKV fusion
+
+Idea: in models with an attention module, the three q, k, v matrix multiplications can be replaced with a single one, making fuller use of the NPU's compute capacity.
+
+Implementation:
+
+```python
+class Model(nn.Module):
+ def __init__(self):
+ super().__init__()
+        # new fused qkv projection
+ self.qkv = nn.Linear(self.hidden_size, self.num_heads * self.head_dim +
+ 2 * self.num_key_value_heads * self.head_dim, bias=False)
+ ...
+
+ def forward(self, hidden_states):
+        '''Original q, k, v linear layers
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+ '''
+
+        # Replace with
+ qkv_states = self.qkv(hidden_states)
+ query_states, key_states, value_states = qkv_states.split(
+ [self.q_hidden_size, self.kv_hidden_size, self.kv_hidden_size], dim=2)
+
+    # Weight fusion: copy the original q, k, v linear weights into the fused qkv matrix
+ def merge_qkv_weight(self, tp_size=1):
+ def _to_parameter(data):
+ return nn.Parameter(data, requires_grad=False)
+
+ qw_size = self.model.layers[0].self_attn.q_proj.weight.shape
+ kw_size = self.model.layers[0].self_attn.k_proj.weight.shape
+ vw_size = self.model.layers[0].self_attn.v_proj.weight.shape
+
+ q_sliced_size = qw_size[0] // tp_size
+ k_sliced_size = kw_size[0] // tp_size
+ v_sliced_size = vw_size[0] // tp_size
+
+ for i in range(len(self.model.layers)):
+ qw = self.model.layers[i].self_attn.q_proj.weight
+ kw = self.model.layers[i].self_attn.k_proj.weight
+ vw = self.model.layers[i].self_attn.v_proj.weight
+
+ weight_list = []
+ for j in range(tp_size):
+ sliced_qw = qw[j * q_sliced_size: (j + 1) * q_sliced_size, :]
+ sliced_kw = kw[j * k_sliced_size: (j + 1) * k_sliced_size, :]
+ sliced_vw = vw[j * v_sliced_size: (j + 1) * v_sliced_size, :]
+ weight_list.append(_to_parameter(torch.cat([sliced_qw, sliced_kw, sliced_vw], axis=0)))
+
+ if len(weight_list) == 1:
+ self.model.layers[i].self_attn.qkv.weight = weight_list[0]
+ else:
+ self.model.layers[i].self_attn.qkv.weight = _to_parameter(torch.cat(weight_list, axis=0))
+```
+
+
+
+## Fixed-size kv cache
+
+Why: if the kv cache is a model input that is updated by cat-concatenation and returned as a new tensor, every update pays for fresh memory allocation and copies.
+
+How: pre-allocate a fixed-size kv cache tensor sized to the maximum sequence length, then update it in place at the target positions with the scatter_update_ operator.
+
+Using the llama source in transformers as an example:
+
+```python
+# transformers/models/llama/modeling_llama.py
+# Add this logic to LlamaForCausalLM.prepare_inputs_for_generation
+# Fix the kv cache size; it serves kv cache updates for both the full (prefill) graph and the incremental (decode) graph
+batch_size, seq_length = input_ids.shape
+use_dtype = self.model.torch_dtype
+if past_key_values is None:
+ kv_shape = (
+ batch_size, self.model.max_position_embeddings, self.model.num_key_value_heads // self.world_size,
+ self.model.hidden_size // self.model.num_attention_heads)
+ past_key_values = ()
+ for i in range(self.model.num_hidden_layers):
+ k_cache = torch.zeros(kv_shape, dtype=use_dtype, device=input_ids.device)
+ v_cache = torch.zeros(kv_shape, dtype=use_dtype, device=input_ids.device)
+ past_key_values += ((k_cache, v_cache),)
+```
+
+Changes for updating kv:
+```python
+# Update the kv cache at the specified positions. position_ids updates from seq_len position 0 when the full (prefill) graph runs, and from position seq_len when the incremental (decode) graph runs
+tmp_ids = updated_kv_positions.reshape(-1)
+# format BSND, 1 means seq_len dim index
+torch_npu.scatter_update_(past_key_value[0], tmp_ids, key_states, 1)
+torch_npu.scatter_update_(past_key_value[1], tmp_ids, value_states, 1)
+
+key_states1 = past_key_value[0] if q_len == 1 else key_states
+value_states1 = past_key_value[1] if q_len == 1 else value_states
+
+past_key_value = past_key_value if use_cache else None
+```
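+
+A self-contained sketch (illustrative; the shapes and per-batch write positions are arbitrary assumptions, mirroring the BSND usage above) of how scatter_update_ writes into a fixed cache along the seq_len axis:
+
+```python
+import torch
+import torch_npu
+
+# Fixed cache: [batch=2, max_seq=16, heads=4, head_dim=128], BSND layout
+kv_cache = torch.zeros(2, 16, 4, 128, dtype=torch.float16).npu()
+# One new kv entry per batch element
+new_kv = torch.ones(2, 1, 4, 128, dtype=torch.float16).npu()
+# Write position in the seq_len axis for each batch element
+positions = torch.tensor([3, 7], dtype=torch.int64).npu()
+torch_npu.scatter_update_(kv_cache, positions, new_kv, 1)  # axis 1 = seq_len
+```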
+
+Other changes required by the shape change after fixing the kv cache:
+```python
+# prepare_inputs_for_generation also needs to create the attention_mask and the kv-position-update tensor,
+# because the prefill and decode flows require different attention_mask shapes, and scatter_update needs explicit write positions
+past_key_values_length = 0
+if seq_length > 1:
+ if attention_mask is None:
+ attention_mask = torch.ones((batch_size, seq_length), dtype=torch.bool, device=input_ids.device)
+ self.padding_mask = torch.zeros(batch_size, self.model.max_position_embeddings, device=input_ids.device)
+ self.prompt_length = seq_length
+ self.updated_kv_positions = torch.zeros(batch_size, dtype=position_ids.dtype, device=position_ids.device)
+else:
+ bsz, src_len = attention_mask.size()
+ padding_mask = self.padding_mask
+ padding_mask[:, :src_len] = attention_mask
+ attention_mask = padding_mask
+ past_key_values_length = self.model.max_position_embeddings
+ self.prompt_length += 1
+ self.updated_kv_positions = torch.ones(position_ids.shape, dtype=position_ids.dtype,
+ device=position_ids.device) * (self.prompt_length - 1)
+
+attention_mask = self.model._prepare_decoder_attention_mask(
+ attention_mask, (batch_size, seq_length), past_key_values[0][0], past_key_values_length
+)
+```
+
+## Affinity operator substitution
+### IndexPut operator
+When code assigns through a boolean index mask, replace the assignment with a multiplication; this avoids many small random memory accesses and improves performance.
+
+示例:
+```python
+target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
+masked_target = target.clone() - vocab_start_index
+masked_target[target_mask] = 0
+```
+Replace with:
+```python
+target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
+masked_target = target.clone() - vocab_start_index
+masked_target *= ~target_mask
+```
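+
+A quick CPU-runnable check (illustrative values) that the two forms are equivalent:
+
+```python
+import torch
+
+target = torch.tensor([3, 7, 12])
+vocab_start_index, vocab_end_index = 5, 10
+mask = (target < vocab_start_index) | (target >= vocab_end_index)
+
+a = target.clone() - vocab_start_index
+a[mask] = 0                                       # index-put form
+b = (target.clone() - vocab_start_index) * ~mask  # multiplication form
+assert torch.equal(a, b)
+```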
+For other affinity operator substitutions, see [affinity operator substitution](https://www.hiascend.com/document/detail/zh/Pytorch/710/ptmoddevg/trainingmigrguide/performance_tuning_0042.html).
\ No newline at end of file
diff --git "a/ACL_PyTorch/docs/torchair/\345\233\276\344\274\230\345\214\226/1.\345\211\215\347\275\256\345\210\206\346\236\220\345\217\212\347\272\246\346\235\237\346\235\241\344\273\266.md" "b/ACL_PyTorch/docs/torchair/\345\233\276\344\274\230\345\214\226/1.\345\211\215\347\275\256\345\210\206\346\236\220\345\217\212\347\272\246\346\235\237\346\235\241\344\273\266.md"
new file mode 100644
index 0000000000000000000000000000000000000000..cfc48adbc9d839e02809100e9f7defe14264acee
--- /dev/null
+++ "b/ACL_PyTorch/docs/torchair/\345\233\276\344\274\230\345\214\226/1.\345\211\215\347\275\256\345\210\206\346\236\220\345\217\212\347\272\246\346\235\237\346\235\241\344\273\266.md"
@@ -0,0 +1,10 @@
+## 1. When TorchAir applies
+TorchAir (Torch Ascend Intermediate Representation) plugs into PyTorch's Dynamo feature: it converts PyTorch FX (functionalized) graphs into Ascend IR, compiles and optimizes the computation graph through the Graph Engine, and dispatches it to Ascend hardware for execution. Graph capture mainly addresses host-bound workloads; if the performance bottleneck lies elsewhere, apply the targeted optimizations in [common torch_npu optimization points](../torch_npu/常用优化点.md) instead.
+* Host-bound behavior shows up in profiling as a low compute share and a high free share in the "Overlap Analysis" view; see [Host-bound problem analysis](https://www.hiascend.com/document/detail/zh/mindstudio/81RC1/practicalcases/GeneralPerformanceIssue/toolsample6_053.html).
+* Every operator captured into the graph must have a converter. The [ATen API support list](https://www.hiascend.com/document/detail/zh/Pytorch/710/modthirdparty/torchairuseguide/torchair_00040.html) enumerates the ATen APIs that can be captured; their capabilities match the eager-mode ATen APIs. If your model uses an ATen API outside this list, its graph-mode support may be incomplete and you need to extend the converter yourself; see [converter completion](https://gitee.com/ascend/torchair/blob/7.1.0/CONTRIBUTING.md#converter%E8%A1%A5%E9%BD%90).
+* Dynamic shapes must be bounded: TorchAir currently supports at most 10 shape gears, and more gears degrade performance. Gearing suits variation in a single dimension, such as batch size.
+
+
+## 2. Constraints
+
+* PyTorch graph mode supports single-process and multi-process execution; each process may use only one NPU card, never several.
\ No newline at end of file
diff --git "a/ACL_PyTorch/docs/torchair/\345\233\276\344\274\230\345\214\226/2.\350\277\201\347\247\273\351\200\202\351\205\215.md" "b/ACL_PyTorch/docs/torchair/\345\233\276\344\274\230\345\214\226/2.\350\277\201\347\247\273\351\200\202\351\205\215.md"
new file mode 100644
index 0000000000000000000000000000000000000000..1bd31e49e8ec9357be81a69dc732239a6612e354
--- /dev/null
+++ "b/ACL_PyTorch/docs/torchair/\345\233\276\344\274\230\345\214\226/2.\350\277\201\347\247\273\351\200\202\351\205\215.md"
@@ -0,0 +1,281 @@
+## Basic usage
+
+1. Import the TorchAir library
+
+ ```python
+ import torchair
+ ```
+
+2. Configure the compile backend
+
+ ```python
+ config = torchair.CompilerConfig()
+ npu_backend = torchair.get_npu_backend(compiler_config=config)
+    # create the model instance
+    model = MyModel()
+    # compile the whole model
+    model = torch.compile(model, backend=npu_backend)
+    # or compile only the model's submodules
+    model.encoder = torch.compile(model.encoder, backend=npu_backend)
+    model.decoder = torch.compile(model.decoder, backend=npu_backend)
+    # or compile a single method of the model
+    model.forward = torch.compile(model.forward, backend=npu_backend)
+ ```
+
+## Hugging Face model example
+
+When migrating from the sample code on a Hugging Face Model Card, do not compile the whole pipeline; compile the model that the pipeline wraps.
+
+Example:
+
+ ```python
+ import torch
+ import torch_npu
+ import torchair
+ from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
+
+ # assumed settings for NPU execution
+ device = "npu:0"
+ torch_dtype = torch.float16
+
+ model_id = "openai/whisper-large-v3"
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id)
+ # compile the model with the NPU backend
+ config = torchair.CompilerConfig()
+ npu_backend = torchair.get_npu_backend(compiler_config=config)
+ model = torch.compile(model, backend=npu_backend)
+ processor = AutoProcessor.from_pretrained(model_id)
+ pipe = pipeline(
+     "automatic-speech-recognition",
+     model=model, # pass the compiled model into the pipeline
+     tokenizer=processor.tokenizer,
+     feature_extractor=processor.feature_extractor,
+     torch_dtype=torch_dtype,
+     device=device,
+ )
+ # run inference; sample is an audio input as in the Model Card example
+ result = pipe(sample)
+ ```
+
+## Models assembled from multiple submodules
+Some models are assembled from several submodules, and inference may never go through the top-level forward; other entry points call the submodules directly. In that case, use debugging or code reading to find the submodules that actually run at inference time.
+
+Example:
+
+ ```python
+ # the Chinese-CLIP model
+ import torch
+ import torch_npu
+ from cn_clip.clip import load_from_name
+ device = torch.device('npu:0')
+ model, preprocess = load_from_name('ViT-B-16', device=device)
+
+ # inference API
+ with torch.no_grad():
+ logits_per_image, logits_per_text = model.get_similarity(image, text)
+ ```
+
+Printing the model directly shows this structure:
+
+ ```
+ CLIP(
+ (visual): VisualTransformer(
+ ...
+ )
+ (bert): BertModel(
+ ...
+ )
+ )
+ ```
+
+The concrete implementations behind this API:
+
+ ```python
+ ...
+ def encode_image(self, image, mask_ratio=0):
+ if isinstance(self.visual, ModifiedResNet):
+         # calls the visual submodule
+ return self.visual(image.type(self.dtype))
+ return self.visual(image.type(self.dtype), mask_ratio)
+
+ def encode_text(self, text):
+ ...
+     # calls the bert submodule
+ x = self.bert(text, attention_mask=attn_mask)[0].type(self.dtype)
+ return x[:, 0, :] @ self.text_projection
+
+ def get_similarity(self, image, text):
+     # encode the image
+ image_features = self.encode_image(image)
+     # encode the text
+ text_features = self.encode_text(text)
+
+ # normalized features
+ image_features = image_features / image_features.norm(dim=1, keepdim=True)
+ text_features = text_features / text_features.norm(dim=1, keepdim=True)
+
+ # cosine similarity as logits
+ logit_scale = self.logit_scale.exp()
+ logits_per_image = logit_scale * image_features @ text_features.t()
+ logits_per_text = logits_per_image.t()
+
+ return logits_per_image, logits_per_text
+ ```
+
+The model structure and API implementations above show that the CLIP model actually calls the forward of the visual and bert submodules, so those two are the targets for torch.compile:
+
+ ```python
+ config = torchair.CompilerConfig()
+ npu_backend = torchair.get_npu_backend(compiler_config=config)
+ model, preprocess = load_from_name('ViT-B-16', device=device)
+ model.visual = torch.compile(model.visual, backend=npu_backend)
+ model.bert = torch.compile(model.bert, backend=npu_backend)
+ ```
+
+## Common compilation options
+### 1. The torchair.CompilerConfig class
+The experimental_config attribute of the torchair.CompilerConfig class exposes several performance-boosting options:
+
+* experimental_config.frozen_parameter: in inference scenarios, marks tensors whose addresses stay fixed during execution as [the Parameter type](https://docs.pytorch.org/docs/2.1/generated/torch.nn.parameter.Parameter.html), shortening graph launch time and improving dispatch performance. Recommended: True.
+* experimental_config.tiling_schedule_optimize: with static shapes, enables tiling-schedule optimization so tiling computation runs directly on the device side, improving static-shape model performance. Recommended for static-shape scenarios: True.
+
+Usage example:
+
+```python
+config = torchair.CompilerConfig()
+config.experimental_config.frozen_parameter = True
+config.experimental_config.tiling_schedule_optimize = True
+npu_backend = torchair.get_npu_backend(compiler_config=config)
+
+model = torch.compile(model, backend=npu_backend)
+```
+
+### 2. torch.compile()
+
+This is PyTorch's native interface; see the [torch.compile documentation](https://docs.pytorch.org/docs/stable/generated/torch.compile.html) for details.
+
+```python
+torch.compile(
+ model: Callable[[_InputT], _RetT],
+ *,
+ fullgraph: bool = False,
+ dynamic: Optional[bool] = None,
+ backend: Union[str, Callable] = 'inductor',
+ disable: bool = False
+) -> Callable[[_InputT], _RetT]
+```
+
+Parameter description:
+
+|Parameter|Description|
+|----|----|
+|model|The model or function to capture into the graph. Required.|
+|fullgraph|bool, optional. Whether to optimize the whole graph.<br>- False (default): automatically finds the optimizable parts and splits the graph at unsupported operators.<br>- True: captures and optimizes the whole graph, raising an exception if a graph break occurs.|
+|dynamic|bool or None, optional. Whether to use dynamic-shape tracing.<br>- None (default): automatically detects whether the graph is dynamic.<br>- False: runs a static graph.<br>- True: runs a dynamic graph.|
+|backend|Backend selection; the default "inductor" is not yet supported on Ascend NPUs. There is only one Ascend NPU graph backend, obtained through torchair.get_npu_backend. Required.|
+|disable|bool, optional. Whether to turn off torch.compile.<br>- False (default): torch.compile enabled.<br>- True: torch.compile disabled; eager single-operator mode is used.|
+
+Parameter recommendations:
+
+* fullgraph: a whole graph outperforms a partial graph, so True is recommended. If some computation cannot be captured, try replacing operators or moving code around to resolve it.
+* dynamic: a static graph outperforms a dynamic graph, so False is recommended. A static graph requires fixed input shapes, with no dynamic shapes arising during computation.
+
+Usage example:
+
+```python
+config = torchair.CompilerConfig()
+config.experimental_config.frozen_parameter = True
+config.experimental_config.tiling_schedule_optimize = True
+npu_backend = torchair.get_npu_backend(compiler_config=config)
+
+model = torch.compile(model, dynamic=False, fullgraph=True, backend=npu_backend)
+```
+
+## Solving graph-capture problems
+
+### 1. Split the prefill and decode phases
+
+Because the prefill and decode phases use different FlashAttention operators, each phase needs its own graph.
+
+Original model code:
+
+```python
+import torch, torch_npu, torchair
+class Decoder(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x, ...):
+ ...
+ return ...
+```
+
+Give prefill and decode their own methods, and dispatch between them inside forward:
+
+```python
+import torch, torch_npu, torchair
+class Decoder(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x, ...):
+ if x.size(1) > 1:
+ ...
+ return self.prefill(x, ...)
+ ...
+ return self.decode(x, kv_cache, ...)
+
+ def prefill(self, x, ...):
+ ...
+ return ...
+
+ def decode(self, x, kv_cache, ...):
+ ...
+ return ...
+
+decoder = Decoder()
+# compile the prefill and decode methods
+config = torchair.CompilerConfig()
+config.experimental_config.frozen_parameter = True
+config.experimental_config.tiling_schedule_optimize = True
+npu_backend = torchair.get_npu_backend(compiler_config=config)
+decoder.prefill = torch.compile(decoder.prefill, dynamic=False, fullgraph=True, backend=npu_backend)
+# decode usually runs autoregressively with varying input, so dynamic is set to True
+decoder.decode = torch.compile(decoder.decode, dynamic=True, fullgraph=True, backend=npu_backend)
+```
+
+### 2. Move unsupported operators out of the compiled region
+
+Some data types and operations cannot be captured by TorchAir and fail to compile; move that code out of the function being optimized. For example, speech models often call torch.stft() to apply a Fourier transform to the audio input. TorchAir does not support it, so relocating the call lets the main body compile with fullgraph=True.
+
+Example:
+
+Before:
+
+```python
+class Model:
+ ...
+ def decode(self, x):
+ spec = torch.stft(x, N_FFT, HOP_LENGTH, window, return_complex=True)
+        ...
+ return ...
+
+ def forward(self, x):
+ ...
+ out = self.decode(x)
+ ...
+ return ...
+```
+
+After:
+
+```python
+class Model:
+ ...
+ def decode(self, x, spec):
+ # spec = torch.stft(x, N_FFT, HOP_LENGTH, window, return_complex=True)
+        ...
+ return ...
+
+ def forward(self, x):
+        spec = torch.stft(x, N_FFT, HOP_LENGTH, window, return_complex=True) # move the stft out of decode
+        stft_real_img = torch.view_as_real(spec) # the npu does not support complex tensors
+ out = self.decode(x, stft_real_img)
+ ...
+ return ...
+
+model = Model()
+model.decode = torch.compile(model.decode, fullgraph=True, backend=npu_backend)  # npu_backend obtained via torchair.get_npu_backend, as earlier
+```
+
+### 3. Registering and implementing an operator converter
+After a custom operator is registered with PyTorch, it can be called in eager mode but not in graph mode. To capture it into the NPU graph, the operator must also be registered with TorchAir together with a converter function that translates its ATen IR to GE IR. For the steps, see [converter completion](https://gitee.com/ascend/torchair/blob/7.1.0/CONTRIBUTING.md#converter%E8%A1%A5%E9%BD%90).
\ No newline at end of file
diff --git "a/ACL_PyTorch/docs/torchair/\345\233\276\344\274\230\345\214\226/3.\345\212\237\350\203\275\347\262\276\345\272\246\351\227\256\351\242\230\345\256\232\344\275\215.md" "b/ACL_PyTorch/docs/torchair/\345\233\276\344\274\230\345\214\226/3.\345\212\237\350\203\275\347\262\276\345\272\246\351\227\256\351\242\230\345\256\232\344\275\215.md"
new file mode 100644
index 0000000000000000000000000000000000000000..c02c8b66b9ea259245165f05a5d71604054975fc
--- /dev/null
+++ "b/ACL_PyTorch/docs/torchair/\345\233\276\344\274\230\345\214\226/3.\345\212\237\350\203\275\347\262\276\345\272\246\351\227\256\351\242\230\345\256\232\344\275\215.md"
@@ -0,0 +1,72 @@
+## Locating functional problems
+Unsupported operators in the model cause TorchAir graph capture to fail; for the supported operators, see the [ATen API support list](https://www.hiascend.com/document/detail/zh/Pytorch/710/modthirdparty/torchairuseguide/torchair_00040.html).
+
+### 1. Dynamo graph-capture Python logs
+
+How to enable:
+
+```python
+import logging
+torch._logging.set_logs(dynamo=logging.DEBUG, aot=logging.DEBUG, output_code=True, graph_code=True, recompiles=True)
+```
+
+Used mainly to pinpoint the problem when native Dynamo raises an error. Focus on the FX graph and guard information: the log traces how Dynamo builds the FX graph, which reveals why Dynamo fails or what is unsupported.
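+
+As a complementary check (not TorchAir-specific; torch._dynamo.explain is a native PyTorch 2.x utility, and example_input stands in for a real model input), you can summarize graph breaks without reading the full debug log:
+
+```python
+import torch
+
+# returns an ExplainOutput listing graph breaks and their reasons
+explanation = torch._dynamo.explain(model)(example_input)
+print(explanation)
+```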
+
+### 2. TorchAir Python logs
+How to enable:
+```python
+import logging
+torchair.logger.setLevel(logging.DEBUG)
+```
+
+When problems such as converter-internal errors or TorchAir configuration errors occur, this log shows TorchAir's converter conversion flow, the configuration items, and so on.
+
+### 3. TorchAir C++ logs
+
+How to enable:
+```bash
+export TNG_LOG_LEVEL=0
+```
+
+Used mainly to determine whether the Ascend IR graph is configured and to diagnose memory problems. This log shows whether the Ascend IR graph compiled successfully, the exact configuration passed in, and the input/output memory allocations.
+
+## Locating accuracy problems
+
+### 1. TorchAir Python logs
+How to enable:
+```python
+import logging
+torchair.logger.setLevel(logging.DEBUG)
+```
+
+Beyond converter-internal and configuration errors, this log also shows each operator's input/output shapes, which helps localize accuracy problems.
+
+### 2. Tensor print
+In graph mode, the device-side outputs of intermediate graph nodes cannot be inspected with an ordinary print. Use the npu_print API instead:
+
+```python
+import torch, torch_npu, torchair
+from torch import nn
+
+class Model(nn.Module):
+ ...
+ def forward(self, x):
+ ...
+ y = x ** 2
+ torchair.ops.npu_print("y=", y, summarize_size=2)
+ ...
+
+model = Model()
+model = torch.compile(model)
+out = model(x)
+```
+
+API: def npu_print(*args, summarize_size=3)
+
+* *args: accepts torch.Tensor plus the basic Python types str, bool, float, and int
+* summarize_size: how many elements to show at the head and tail of each axis when printing a tensor
+
+Constraints:
+
+* args must contain at least one Tensor input
+* summarize_size must be a positive integer, or -1 to print everything
+* printing is asynchronous; the output order of multiple npu_print calls matches the execution order of the printed nodes
\ No newline at end of file
diff --git "a/ACL_PyTorch/docs/torchair/\345\233\276\344\274\230\345\214\226/4.\351\253\230\351\230\266\347\211\271\346\200\247.md" "b/ACL_PyTorch/docs/torchair/\345\233\276\344\274\230\345\214\226/4.\351\253\230\351\230\266\347\211\271\346\200\247.md"
new file mode 100644
index 0000000000000000000000000000000000000000..8378240f83c2176db42884af472f77ecfc9d66fe
--- /dev/null
+++ "b/ACL_PyTorch/docs/torchair/\345\233\276\344\274\230\345\214\226/4.\351\253\230\351\230\266\347\211\271\346\200\247.md"
@@ -0,0 +1,187 @@
+## 1. Dynamic shape gears
+
+When the model input shape is not fixed but takes only a few known values (typically the batch dimension varies over a fixed set), a dynamic graph cannot reach peak performance, while a static graph recompiles for every new shape. Dynamic shape gears give static-graph performance with a single compilation.
+
+Example:
+
+```python
+import torch
+import torch_npu
+import torchair
+
+class Model(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ def forward(self, x1, x2):
+ return x1 + x2
+
+input1 = torch.ones(2, 2).npu()
+input2 = torch.ones(2, 2).npu()
+config = torchair.CompilerConfig()
+# "zip" pairs the gears positionally: supports the two input combinations (2, 2)(2, 2) and (4, 2)(4, 2)
+config.inference_config.dynamic_gears_merge_policy = "zip" # default
+# "product" takes the cross product: supports (2, 2)(2, 2), (2, 2)(4, 2), (4, 2)(2, 2) and (4, 2)(4, 2)
+# config.inference_config.dynamic_gears_merge_policy = "product"
+npu_backend = torchair.get_npu_backend(compiler_config=config)
+model = Model().npu()
+# whole-graph compilation is required
+model = torch.compile(model, fullgraph=True, backend=npu_backend)
+
+# set the gears
+torchair.inference.set_dim_gears(input1, dim_gears={0:[2, 4]})
+torchair.inference.set_dim_gears(input2, dim_gears={0:[2, 4]})
+
+# first call: compiles and runs with shapes (2, 2), (2, 2)
+print(model(input1, input2))
+
+# second call: shapes (4, 2), (4, 2) are in the gear list, so no recompilation is triggered
+input1 = torch.ones(4, 2).npu()
+input2 = torch.ones(4, 2).npu()
+print(model(input1, input2))
+```
+
+## 2. torch._dynamo.mark_static()
+
+A native PyTorch API that marks an input's shape as static. Once all inputs are fixed this way and the model contains no dynamic shapes, you get a static graph even when torch.compile is called with dynamic=True. If the input shapes are fixed but the model has internal dynamic shapes, this marking still lets GE generate some static subgraphs at compile time, improving performance.
+
+Example:
+
+```python
+import torch
+import torch_npu
+from typing import Optional
+from torch import nn, Tensor
+
+class Decoder(nn.Module):
+class Decoder(nn.Module):
+ ...
+
+ def decode(
+ self,
+ x: Tensor,
+ xa: Tensor,
+ positional_embedding: Tensor,
+ kv_cache: Optional[dict] = None,
+ updated_kv_positions: Optional[torch.LongTensor] = None,
+ actual_seq_len: Optional[list] = None,
+ kv_padding_size: Optional[torch.LongTensor] = None
+ ):
+ ...
+ return ...
+
+    def compute_logits(self, x, xa, positional_embedding, updated_kv_positions, actual_seq_len, kv_padding_size):
+ ...
+ torch._dynamo.mark_static(x)
+ torch._dynamo.mark_static(xa)
+ torch._dynamo.mark_static(positional_embedding)
+        for i in range(len(self.kv_cache)):
+ torch._dynamo.mark_static(self.kv_cache[i]['attn']["key"])
+ torch._dynamo.mark_static(self.kv_cache[i]['attn']["value"])
+ torch._dynamo.mark_static(self.kv_cache[i]['cross_attn']["key"])
+ torch._dynamo.mark_static(self.kv_cache[i]['cross_attn']["value"])
+ torch._dynamo.mark_static(kv_padding_size)
+
+ return self.decode(x, xa, positional_embedding, self.kv_cache,
+ actual_seq_len=[actual_seq_len], kv_padding_size=kv_padding_size,
+ updated_kv_positions=updated_kv_positions)
+
+config = torchair.CompilerConfig()
+npu_backend = torchair.get_npu_backend(compiler_config=config)
+decoder = Decoder().npu()
+decoder.decode = torch.compile(decoder.decode, dynamic=True, fullgraph=True, backend=npu_backend)
+```
+
+## 3. Compilation cache
+Every program start with torch.compile recompiles from scratch, and compilation is usually slow, which makes debugging painful. The cache_compile API persists the first compilation's results to disk, speeding up graph-mode startup.
+
+With plain torch.compile, the first run goes through Dynamo compilation, Torch guard construction, Ascend IR graph compilation, and so on, and every later run still evaluates the guard functions to decide whether recompilation is needed. With the compilation cache, the first run loads the cache directly, and later runs skip the guard functions.
+
+Cache-compile API:
+
+torchair.inference.cache_compile
+
+```python
+def cache_compile(
+    func, # the method to cache-compile; only a Module's method is supported
+    *,
+    config: Optional[CompilerConfig] = None,
+    dynamic: bool = True, # whether to trace dynamically according to the inputs
+    cache_dir: Optional[str] = None, # cache root directory; defaults to .torchair_cache
+    global_rank: Optional[int] = None, # rank in distributed runs; defaults to torch.distributed.get_rank()
+    tp_rank: Optional[int] = None, # the tp rank
+    pp_rank: Optional[int] = None, # the pp rank
+    custom_decompositions: Optional[dict] = None, # user-defined decompose strategy
+    ge_cache: bool = False, # whether to enable the GE (Ascend IR graph) cache
+    **kwargs
+) -> Callable:
+```
+
+### Usage constraints
+
+* func may trigger Dynamo tracing only once; if func is recompiled, the cache is abandoned.
+* For functions traced multiple times (guard failures), wrap them in an extra function so the cache can take effect.
+* func must be a method of a module instance, and the method must not be wrapped by other decorators.
+* func must form a whole graph, i.e. it must support full-graph capture.
+* Only inference is supported; a func containing a backward pass cannot be cached.
+
+### Ascend IR compilation cache
+
+* Besides the Dynamo compilation time, the Ascend IR graph compilation time can also be cached, via the ge_cache parameter of cache_compile, to accelerate graph-mode startup further.
+* By default ge_cache=False (disabled); because the cache is sensitive to CANN version changes, enable it manually as appropriate (see the sketch below).
+* The cached compilation artifacts share the cache file path of the wrapped func.
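+
+A minimal sketch (illustrative; it assumes a Module instance `model` with a `_forward` method, as in the usage example below) of enabling the GE cache on top of the Dynamo cache:
+
+```python
+# ge_cache=True additionally caches the Ascend IR graph compilation result
+model.cached_forward = torchair.inference.cache_compile(model._forward, ge_cache=True)
+```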
+
+### Usage example
+
+Original model code:
+
+```python
+import torch, torch_npu, torchair
+class Model(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ @torch.inference_mode()
+ def forward(self, x, y):
+ return x + y
+```
+
+Adaptation steps:
+
+* Extract the body of forward into a _forward function, and cache-compile _forward.
+* In __init__, store the compiled function as an attribute of the Module itself.
+* In forward, call the compiled function, i.e. the attribute set in __init__.
+
+```python
+import torch, torch_npu, torchair
+class Model(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+        self.cached_forward = torchair.inference.cache_compile(self._forward)
+
+ def forward(self, x, y):
+        # call the cached, compiled function
+ return self.cached_forward(x, y)
+
+    def _forward(self, x, y):
+ return x + y
+```
+
+In transformer models, forward commonly uses different operators or logic in the prefill and decode phases; in that case add a dedicated function for each scenario:
+
+```python
+class Model(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ config = torchair.CompilerConfig()
+ self.cached_prefill = torchair.inference.cache_compile(self.prefill, config=config)
+ self.cached_decode = torchair.inference.cache_compile(self.decode, config=config)
+
+ def forward(self, x, ...):
+ if x.size(1) > 1:
+ return self.cached_prefill(x, ...)
+ return self.cached_decode(x, ...)
+
+ def prefill(self, x, ...):
+ return ...
+
+ def decode(self, x, kv_cache):
+ return ...
+```
\ No newline at end of file