# npu_inductor_mlir2_2.6.0 **Repository Path**: handsomelemon/npu_inductor_mlir2_2.6.0 ## Basic Information - **Project Name**: npu_inductor_mlir2_2.6.0 - **Description**: npu_inductor_mlir2_2.6.0 - **Primary Language**: Unknown - **License**: Not specified - **Default Branch**: master - **Homepage**: None - **GVP Project**: No ## Statistics - **Stars**: 0 - **Forks**: 2 - **Created**: 2025-10-23 - **Last Updated**: 2025-10-23 ## Categories & Tags **Categories**: Uncategorized **Tags**: None ## README # 简介 本项目开发了名为npu_inductor_mlir(Integrating PyTorch with MLIR for NPU acceleration.)的扩展库,支持用户基于PyTorch框架和torch_npu插件在昇腾NPU上使用图模式进行训练和推理。 npu_inductor_mlir继承自PyTorch框架[Dynamo模式](https://pytorch.org/docs/stable/torch.compiler_deepdive.html),将PyTorch的[FX图](https://pytorch.org/docs/stable/fx.html)由npu_inductor_mlir翻译成Linalg/Hfusion Dialect,经Bisheng编译器编译执行。 # 编译及安装 ## 安装torch-mlir `torch_mlir`安装社区版本,执行`bash build_torch_mlir.sh`或参考`build_torch_mlir.sh`中的步骤进行安装 ## 依赖pytorch和torch_npu版本 支持`pytorch==2.6.0`以及`torch_npu==2.6.0` ## 环境变量设置 ```python export PATH=/path/to/ccec_compiler/bin:$PATH ``` ## Q&A Q:如果遇到`libtinfo.so.5: cannot open shared object file: No such file or directory`这种问题,如何解决? A: `apt install libtinfo5`或者`yum install libtinfo5` # 快速上手 基础例子: ```python import copy import torch import torch_npu import torch.nn as nn import time os.environ['TORCHINDUCTOR_MAX_AUTOTUNE'] = '1' from torch_npu._inductor.ascend_npu_ir.ascend_npu_ir.npu.utils import logger import logging logger.setLevel(logging.INFO) from torch._inductor import config config.debug = True config.trace.enabled = True config.triton.unique_kernel_names = True class Model(nn.Module): def __init__(self): super().__init__() def forward(self, x, y, z): return z + (y + x * z) * (x + y) + 0.22111 model = Model() x = torch.randn([32, 32], device='npu') y = torch.randn([32, 32], device='npu') z = torch.randn([32, 32], device='npu') model_compiled = torch.compile(model, mode='max-autotune-no-cudagraphs') print(model_compiled(x, y, z)) print(model(x, y, z)) ``` - 若不报错,说明成功了 - 执行过程中会看到中间过程产生的FX图,linalg MLIR图 # 特性支持 - 加一行代码接入TorchInductor 在代码中开头加`import npu_inductor_mlir.npu.npu_inductor_plugin`,就会拦截原来inductor->triton的路径,走inductor->torch-mlir的路径 - 支持融合算子编译及执行fallback至CPU backend执行,用于验证FX Graph融合算子的子图抓取和torch-mlir转换的功能和精度,当前已跑通LLaMA3整网。 设置`npu_inductor_mlir_config.fallback_to_torch_mlir_cpu_backend = True`,走torch_mlir自带的cpu backend,False则走BiShengIR的backend - 支持手动控制部分算子fallback至Aten 当前默认开启fallback `aten.mm`和`aten.bmm`,在`python/npu_inductor_mlir/config.py`下可以手动控制 - 继承TorchInductor的debug能力 添加如下代码 ```py from torch._inductor import config config.debug = True config.trace.enabled = True config.triton.unique_kernel_names = True ``` 执行完成后,在当前目录产生`torch_compile_debug`目录,包含Codegen产生的中间结果 - 支持auto_fallback,若融合算子编译失败,则自动fallback到fx图 设置`npu_inductor_mlir_config.auto_fallback = True`开启 # bishengir ``` bishengir-compile -enable-hfusion-compile=true --enable-bin-relocation=0 -block-dim=48 --enable-auto-multi-buffer=false --enable-ops-reorder=false --hfusion-max-buffer-count-tuning=0 -enable-tuning-mode=true --enable-static-bare-ptr=false --enable-symbol-analysis=true mlir_fused_add_mul_0_named_op.mlir -o mlir_fused_add_mul_0_16_False_False ``` #Todo List - [x] 支持更多API组合corner case - [x] 支持融合算子入task queue - [x] 支持更多中间结果dump - [x] 支持基于trace,将Inductor scheduler Node转化为fx图 - [x] 支持与triton解耦 - [x] 支持auto fallback,融合编译失败时,自动fallback到fx图 - [x] 支持动态shape下,Inductor scheduler Node转化为fx图 - [ ] 支持`TORCHINDUCTOR_COMPILE_THREADS > 1` …