diff --git a/torch_npu/_inductor/__init__.py b/torch_npu/_inductor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a6779c7c602bfbfd9854d06e9bc8e211a9082846 --- /dev/null +++ b/torch_npu/_inductor/__init__.py @@ -0,0 +1,111 @@ +import os + +import torch +from torch._dynamo.device_interface import register_interface_for_device, get_interface_for_device +from torch._inductor import lowering as inductor_lowering +from torch._inductor.choices import InductorChoices +from torch._inductor.codegen.common import register_backend_for_device, register_device_op_overrides +from torch._inductor.runtime import autotune_cache +from torch_npu.npu import device_count +from torch_npu.utils._dynamo_device import NpuInterface, current_device, set_device +from torch_npu.utils._inductor import NPUDeviceOpOverrides + +from . import config as npu_config +from . import codegen +from .npu_fusion_attention_graph import register_fa_pass +from .config import aggresive_autotune, num_vector_core, set_compile_threads, disable_comprehensive_padding +from .config import log as npulog +from .decomposition import _register_npu_inductor_decompositons +from .lowering import make_reduction, npu_make_fallback +from .npu_choices import should_use_persistent_reduction +from .npu_device import NewNPUDeviceOpOverrides +from .runtime import _load_cached_autotuning +from .utils import get_current_raw_stream, patch_is_gpu, patch_has_triton +from .codecache import patch_aot_code_compiler_compile, patch_cache_base_get_system + +set_compile_threads() +disable_comprehensive_padding() + + +def _inductor_register_backend_for_device(): + from .codegen.scheduling import NPUTritonScheduling + from .codegen.wrapper import NPUWrapperCodeGen + from .codegen.cpp_wrapper import CppWrapperNpu + register_backend_for_device('npu', NPUTritonScheduling, NPUWrapperCodeGen, CppWrapperNpu) + + +_inductor_register_backend_for_device() + + +def _inductor_register_device_op_overrides(): + register_device_op_overrides('npu', NewNPUDeviceOpOverrides()) + + +_inductor_register_device_op_overrides() + +device = get_interface_for_device("npu") + +inductor_lowering.make_reduction = make_reduction +inductor_lowering.make_fallback = npu_make_fallback + + +def patch_torch_for_aoti(): + from .graph import patch_codegen_with_cpp_wrapper + from .cpp_builder import patch_get_cpp_torch_device_options + from .codegen.cpp_utils import patch_device_to_aten + from .utils import patch_is_same_tensor + from .fx_passes.joint_graph import patch_constant_fold_uniform_value + from .ir import patch_fallback_kernel_codegen + + patch_codegen_with_cpp_wrapper() + patch_get_cpp_torch_device_options() + patch_device_to_aten() + patch_is_same_tensor() + patch_constant_fold_uniform_value() + patch_fallback_kernel_codegen() + + patch_aot_code_compiler_compile() + + + +if os.environ.get("DISABLE_AOTI_PATCH", "0") != "1": + patch_torch_for_aoti() + + +if npu_config.dump_fx_graph: + from .codegen.ir_fx import _patch_npu_inductor_ir + + _patch_npu_inductor_ir() + +if npu_config.dump_fx_graph: + from .lowering_fx import _register_npu_inductor_fallbacks +else: + from .lowering import _register_npu_inductor_fallbacks + +_register_npu_inductor_fallbacks() +_register_npu_inductor_decompositons() + + +# register fx_pass should be put behind of _register_npu_inductor_decompositons +def _replace_benchmark_all_configs(): + from torch._inductor.triton_heuristics import CachingAutotuner + from .npu_triton_heuristics import benchmark_all_configs + 
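+    # The assignment below monkey-patches Inductor's CachingAutotuner so that
+    # config benchmarking runs through the NPU-specific implementation imported above.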
CachingAutotuner.benchmark_all_configs = benchmark_all_configs
+
+
+if aggresive_autotune:
+    _replace_benchmark_all_configs()
+    import os
+
+    os.environ["TRITON_BENCH_METHOD"] = "npu"
+
+InductorChoices.should_use_persistent_reduction = should_use_persistent_reduction
+autotune_cache._load_cached_autotuning = _load_cached_autotuning
+
+register_fa_pass()
+patch_cache_base_get_system()
+patch_is_gpu()
+patch_has_triton()
+
+
diff --git a/torch_npu/_inductor/codecache.py b/torch_npu/_inductor/codecache.py
new file mode 100644
index 0000000000000000000000000000000000000000..1efec225f25adfc124285f8f489446adeb9bf4d7
--- /dev/null
+++ b/torch_npu/_inductor/codecache.py
@@ -0,0 +1,130 @@
+import os
+import contextlib
+import hashlib
+import json
+from typing import (
+    Any,
+    Callable,
+    cast,
+    Dict,
+    Generator,
+    List,
+    NoReturn,
+    Optional,
+    Sequence,
+    Tuple,
+    TYPE_CHECKING,
+    TypeVar,
+    Union,
+)
+
+import torch
+from torch._inductor import config
+from torch._inductor.codecache import CacheBase, get_lock_dir, LOCK_TIMEOUT
+from torch._inductor.graph import GraphLowering
+import torch_npu
+from torch_npu.utils._error_code import ErrCode, pta_error
+
+empty_json = "{}"
+
+
+@contextlib.contextmanager
+def lock_context(key):
+    from filelock import FileLock
+    lock_dir = get_lock_dir()
+    lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
+    with lock:
+        yield
+
+
+def patch_cache_base_get_system():
+    # Patch CacheBase.get_system with an NPU-aware version that also records CANN info.
+    @staticmethod
+    def get_system():
+        try:
+            from triton.compiler.compiler import triton_key
+
+            # Use triton_key instead of triton.__version__ as the version
+            # is not updated with each code change
+            triton_version = triton_key()
+        except ModuleNotFoundError:
+            triton_version = None
+
+        try:
+            system: Dict[str, Any] = {
+                "device": {"name": None},
+                "version": {
+                    "triton": triton_version,
+                },
+            }
+            device_properties = torch_npu.npu.get_device_properties(
+                torch_npu.npu.current_device()
+            )
+            if torch.version.cann is not None:
+                system["device"]["name"] = device_properties.name
+                system["version"]["cann"] = torch.version.cann
+            elif torch.version.cuda is not None:
+                system["device"]["name"] = device_properties.name
+                system["version"]["cuda"] = torch.version.cuda
+            else:
+                system["device"]["name"] = device_properties.gcnArchName
+                system["version"]["hip"] = torch.version.hip
+        except (AssertionError, RuntimeError):
+            # If the device is not installed, none of the above config is relevant.
+            system = {}
+
+        system["hash"] = hashlib.sha256(
+            json.dumps(system, sort_keys=True).encode("utf-8")
+        ).hexdigest()
+
+        return system
+
+    CacheBase.get_system = get_system
+
+
+def patch_aot_code_compiler_compile():
+    # In v2.6.0, AOTI has a bug when initializing oss_proxy_executor with the default op_json,
+    # which cannot be skipped, so here we create a new npu op_json
+    # and clear the content of the default op_json.
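+    # Sketch of the workaround below: keep the original AotCodeCompiler.compile as
+    # `src_compile`, wrap it with `compile_npu`, write the serialized extern-kernel
+    # nodes to "<so_basename>_npu.json" next to the generated .so, and overwrite the
+    # default "<so_basename>.json" with an empty JSON object so the stock
+    # oss_proxy_executor initialization effectively becomes a no-op.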
+    from torch._inductor.codecache import AotCodeCompiler
+    AotCodeCompiler.src_compile = AotCodeCompiler.compile
+
+    @classmethod
+    def compile_npu(
+        cls,
+        graph: GraphLowering,
+        source_code: str,
+        serialized_extern_kernel_nodes: Optional[str],
+        device_type: str,
+        additional_files: List[str],
+    ) -> Union[List[str], str]:
+        result = cls.src_compile(
+            graph, source_code, serialized_extern_kernel_nodes,
+            device_type, additional_files
+        )
+        generated_files = additional_files
+        if not config.aot_inductor.package:
+            return result
+
+        output_so = [r for r in result if r.endswith(".so")]
+        if len(output_so) > 1:
+            raise RuntimeError("Could not generate NPU op JSON because there is "
+                               f"more than one .so in the generated files: {result}" + pta_error(ErrCode.INTERNAL))
+        output_so = output_so[0]
+        key = os.path.basename(output_so).replace(".", "_")
+        dir_basename = os.path.splitext(output_so)[0]
+        with lock_context(key):
+            if serialized_extern_kernel_nodes:
+                extern_kernel_nodes_json = dir_basename + "_npu.json"
+                with open(extern_kernel_nodes_json, "w") as f:
+                    f.write(serialized_extern_kernel_nodes)
+                generated_files.append(extern_kernel_nodes_json)
+
+            if serialized_extern_kernel_nodes:
+                source_json_file = dir_basename + ".json"
+                with open(source_json_file, "w") as f:
+                    f.write(empty_json)
+        return generated_files
+    AotCodeCompiler.compile = compile_npu
\ No newline at end of file
diff --git a/torch_npu/_inductor/codegen/__init__.py b/torch_npu/_inductor/codegen/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ec8bf6a9514dd02c5cb40b8e7c18afb42b3198e
--- /dev/null
+++ b/torch_npu/_inductor/codegen/__init__.py
@@ -0,0 +1,35 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved.
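+# This module patches several Inductor classes in place at import time
+# (Reduction, LoopBody, TritonScheduling, SIMDKernel, SizeVarAllocator)
+# with the NPU-aware implementations imported below.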
+ + +from torch._inductor import sizevars +from torch._inductor.codegen.simd import SIMDKernel +from torch._inductor.codegen.triton import TritonKernel +from torch._inductor.codegen.triton import TritonScheduling +from torch._inductor.ir import Reduction, LoopBody +from torch_npu._inductor.codegen._sizevars import simplify +from torch_npu._inductor.codegen.ir import (num_splits, loopbody__call__, transform_dims_in_indexing, + substituted_dims_in_indexing) +from torch_npu._inductor.codegen.scheduling import create_tiling +from torch_npu._inductor.codegen.triton import group_fn, select_index_dtype +from torch_npu._inductor.codegen.triton import is_compatible + +from ..config import log as npulog + + +Reduction.num_splits = num_splits +setattr(LoopBody, 'transform_dims_in_indexing', transform_dims_in_indexing) +setattr(LoopBody, 'substituted_dims_in_indexing', substituted_dims_in_indexing) + +LoopBody.__call__ = loopbody__call__ +# need to enable this to speedup attn_cp_test +# triton scheduling +TritonScheduling.group_fn = group_fn +TritonScheduling.select_index_dtype = select_index_dtype +TritonScheduling.create_tiling = create_tiling +# triton kernel +setattr(SIMDKernel, 'is_compatible', is_compatible) + +# util +sizevars.SizeVarAllocator.simplify = simplify diff --git a/torch_npu/_inductor/codegen/_sizevars.py b/torch_npu/_inductor/codegen/_sizevars.py new file mode 100644 index 0000000000000000000000000000000000000000..f2947420502897d2c1f4ffb819cdd810b86fbe2c --- /dev/null +++ b/torch_npu/_inductor/codegen/_sizevars.py @@ -0,0 +1,9 @@ +import sympy +from sympy import Expr +from torch._inductor.utils import sympy_subs + + +def simplify(self, expr: Expr): + if isinstance(expr, (tuple, list)): + return [sympy.expand(s).xreplace(self.replacements) for s in expr] + return sympy.expand(expr).xreplace(self.replacements) diff --git a/torch_npu/_inductor/codegen/cpp_utils.py b/torch_npu/_inductor/codegen/cpp_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7a9b0887e7f914a3c42b5746257ab20973c15755 --- /dev/null +++ b/torch_npu/_inductor/codegen/cpp_utils.py @@ -0,0 +1,6 @@ +import torch_npu + + +def patch_device_to_aten(): + from torch._inductor import codegen + codegen.cpp_utils.DEVICE_TO_ATEN["npu"] = "at::kPrivateUse1" diff --git a/torch_npu/_inductor/codegen/cpp_wrapper.py b/torch_npu/_inductor/codegen/cpp_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..9a16cfabf8af4321b5b3ebfff157a2093c03803f --- /dev/null +++ b/torch_npu/_inductor/codegen/cpp_wrapper.py @@ -0,0 +1,896 @@ +import functools +import os +import sys +from itertools import chain, count, zip_longest +from typing import Any, Callable, List, Optional, Tuple, TYPE_CHECKING, Union +import sympy +import torch +from torch import dtype as torch_dtype +from torch._inductor import config +from torch._inductor.codecache import CudaKernelParamCache +from torch._inductor.codecache import get_cpp_wrapper_cubin_path_name +from torch._inductor.codegen.aoti_hipify_utils import maybe_hipify_code_wrapper +from torch._inductor.codegen.common import get_device_op_overrides +from torch._inductor.codegen.cpp_utils import cexpr, DTYPE_TO_CPP, DEVICE_TO_ATEN +from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu +from torch._inductor.codegen.multi_kernel import MultiKernelCall +from torch._inductor.codegen.wrapper import PythonWrapperCodegen, SymbolicCallArg +from torch._inductor.ir import IRNode, TensorBox +from torch._inductor.runtime.runtime_utils import dynamo_timed +from 
torch._inductor.utils import DeferredLineBase +from torch._inductor.virtualized import V +from torch._inductor.utils import _align, ALIGN_BYTES + +from .. import config as npu_config +from ..config import npu_block as NPU_ALIGN_BYTES + +if TYPE_CHECKING: + from torch._inductor.graph import GraphLowering + + +def checkIfTrue(value, msg): + if not value: + raise RuntimeError(msg) + return True + + +class DeferredNpuKernelLine(DeferredLineBase): + """ + When using cpp wrapper, NPU kernel load and launch needs to wait for Triton kernels + to be tuned and stored as cubin files, so use a deferred line to backfill those information + """ + + def __init__( + self, + kernel_name: str, + line_template: str, + keys: Tuple[str, ...], + additional_files: List[str], + ): + super().__init__(line_template) + checkIfTrue(not isinstance(line_template, DeferredLineBase), "line template can not be DeferredLineBase") + self.additional_files = additional_files + self.kernel_name = kernel_name + self.line_template = line_template + self.keys = keys + + def __call__(self): + if self.kernel_name.startswith("multi_kernel_"): + # MultiKernel will select one kernel after running the autotune block + self.kernel_name = MultiKernelCall.lookup_choice(self.kernel_name) + params = CudaKernelParamCache.get(self.kernel_name) + checkIfTrue(params is not None, f"{self.kernel_name} not found in CudaKernelParamCache") + + for key in self.keys: + checkIfTrue(key in params, f"{key} not found in CudaKernelParamCache[{self.kernel_name}]") + + if key == get_cpp_wrapper_cubin_path_name(): + checkIfTrue(os.path.exists(params[key]), f"{params[key]} does not exist") + self.additional_files.append(params[key]) + + return self.line_template % tuple(params[key] for key in self.keys) + + def _new_line(self, line): + return DeferredNpuKernelLine( + self.kernel_name, line, self.keys, self.additional_files + ) + + +class DeferredNpuDefaultGrid: + """ + A container for the default grid, which may be used by DeferredNpuGridLine + """ + + def __init__( + self, + kernel_name: str, + grid, + grid_callable: Optional[Callable[..., Any]] = None, + **grid_extra_kwargs, + ): + self.kernel_name = kernel_name + self.grid = grid + self.grid_callable = grid_callable + self.grid_extra_kwargs = grid_extra_kwargs + + def __iter__(self): + # DeferredNpuDefaultGrid can be passed to the base class, PythonWrapperCodegen, + # to generate the autotune code block, and thus we need this iterator + return iter(self.grid) + + def _process_grid(self, grid: Union[List[Any], Tuple[Any, ...]]): + if isinstance(grid, (list, tuple)): + return [self._process_grid(e) for e in grid] + else: + return grid.inner_expr if isinstance(grid, SymbolicCallArg) else grid + + def __call__(self): + if self.kernel_name.startswith("multi_kernel_"): + # MultiKernel will select one kernel after running the autotune block + self.kernel_name = MultiKernelCall.lookup_choice(self.kernel_name) + + grid = self.grid + checkIfTrue(isinstance(grid, (list, tuple)), f"expected {grid=} to be a list") + + grid = self._process_grid(grid) + + checkIfTrue(self.grid_callable is not None, "grid_callable can't be None") + + if not self.grid_extra_kwargs: + grid_fn = self.grid_callable(*grid) + else: + grid_fn = self.grid_callable(*grid, **self.grid_extra_kwargs) + + params = CudaKernelParamCache.get(self.kernel_name) + checkIfTrue(params is not None, f"{self.kernel_name} not found in CudaKernelParamCache") + + return grid_fn(params["meta"]) + + +class DeferredNpuGridLine(DeferredLineBase): + """ + When using cpp 
wrapper, NPU kernel load and launch needs to wait for Triton kernels + to be tuned and stored as cubin files, so use a deferred line to backfill those information + """ + + def __init__( + self, + kernel_name: str, + grid_var: str, + grid, + autotune_configs, + ): + super().__init__("") + self.kernel_name = kernel_name + self.grid_var = grid_var + self.grid = grid + self.autotune_configs = autotune_configs + + def __call__(self): + if self.kernel_name.startswith("multi_kernel_"): + # MultiKernel will select one kernel after running the autotune block + self.kernel_name = MultiKernelCall.lookup_choice(self.kernel_name) + + params = CudaKernelParamCache.get(self.kernel_name) + + checkIfTrue(params is not None, f"{self.kernel_name} not found in CudaKernelParamCache") + + if self.autotune_configs is not None: + # This indicates the Triton kernel is a user-defined one. + grid = None + if len(self.grid) == 1: + grid = self.grid[0] + else: + for i, c in enumerate(self.autotune_configs): + if all(arg == params["meta"][key] for key, arg in c.kwargs.items()): + grid = self.grid[i] + break + checkIfTrue(grid is not None, "grid can not be None") + grid_args_str = ", ".join( + [cexpr(V.graph.sizevars.simplify(item)) for item in grid] + ) + else: + launch_grid = (params['grid_x'], params['grid_y'], params['grid_z']) + grid_args_str = ", ".join( + [cexpr(item) for item in launch_grid] + ) + + return f"\n Grid {self.grid_var} = Grid({grid_args_str});\n" + + def _new_line(self, line): + return DeferredNpuGridLine( + self.kernel_name, self.grid_var, self.grid, self.autotune_configs + ) + + +class CppWrapperNpu(CppWrapperCpu): + """ + Generates cpp wrapper for running on NPU and calls CUDA kernels + """ + + def __init__(self) -> None: + self.device = 'npu' + self.device_codegen = get_device_op_overrides(self.device) + super().__init__() + self.grid_id = count() + self.visited_raii_handle = set() + self.visited_handle_for_kernel_id = dict() + + @staticmethod + def create( + is_subgraph: bool, subgraph_name: str, parent_wrapper: PythonWrapperCodegen + ): + # comment at CppWrapperCpu `codegen_subgraph` function. + return CppWrapperNpu() + + def super_write_header_rewrite(self): + """Copied from CppWrapperCpu to: + (1) change __file__ path for cpython, so that we can use aoti_runtime in current path. + (2) rewrite include path of aoti header file. + """ + if V.graph.is_const_graph: + # We do not write header for constant graph, it will be written by main module. + return + + if V.graph.aot_mode: + self.header.splice( + """ + #include + #include + """ + ) + with open( + os.path.join(os.path.dirname(__file__), "aoti_runtime", "interface.cpp") + ) as f: + self.header.splice(f.read()) + else: + self.header.splice( + """ + import torch + from torch._inductor.codecache import CppWrapperCodeCache + + cpp_wrapper_src = ( + ''' + #include + namespace py = pybind11; + + class RAIIPyObject { + public: + RAIIPyObject() : obj_(nullptr) {} + RAIIPyObject(PyObject* obj) : obj_(obj) {} + ~RAIIPyObject() { + Py_XDECREF(obj_); + } + RAIIPyObject& operator=(const RAIIPyObject& other) { + if (this != &other) { + Py_XDECREF(obj_); + obj_ = other.obj_; + Py_XINCREF(obj_); + } + return *this; + } + operator PyObject*() { + return obj_; + } + PyObject* get() { + return obj_; + } + private: + PyObject* obj_; + }; + + #include + #include + using namespace torch::aot_inductor; + """ + ) + + self.header.splice( + f""" + #include + #include + #include + // Here comment c_shim_npu.h because npu doesn't implement it. 
+ // #include + + #include + typedef at::Half half; + typedef at::BFloat16 bfloat16; + + // Round up to the nearest multiple of {ALIGN_BYTES} + [[maybe_unused]] static int64_t align(int64_t nbytes) {{ + return (nbytes + {ALIGN_BYTES} - 1) & -{ALIGN_BYTES}; + }} + """ + ) + extend_aoti_c_shim_include = ( + f"torch/csrc/inductor/aoti_torch/generated/extend/c_shim_{self.device}.h" + ) + extend_aoti_c_shim_path = os.path.join( + os.path.dirname(torch.__file__), + "include", + extend_aoti_c_shim_include, + ) + if os.path.exists(extend_aoti_c_shim_path): + self.header.splice(f"#include <{extend_aoti_c_shim_include}>") + + enable_kernel_profile = config.cpp.enable_kernel_profile and sys.platform in [ + "linux", + "win32", + ] + if config.profiler_mark_wrapper_call or enable_kernel_profile: + # No C shim for profiling APIs, assuming profiling is a debugging feature which + # does not provide any ABI compatibility promise. + self.header.splice("#include ") + + def write_header(self): + if V.graph.is_const_graph: + # We do not write header for constant graph, it will be written by main module. + return + + self.super_write_header_rewrite() + self.header.splice("#include ") + self.header.splice("#include ") + self.header.splice(self.device_codegen.abi_compatible_header()) + self.header.splice( + maybe_hipify_code_wrapper(self.device_codegen.kernel_driver()) + ) + self.header.splice("#include ") + self.header.splice("#include ") + if npu_config.aot_inductor.debug_kernel: + self.header.splice("#include ") + + def write_get_raw_stream(self, device_idx: int, graph=None) -> str: + name = f"stream{device_idx}" + self.writeline( + maybe_hipify_code_wrapper( + f"{self.device_codegen.cpp_stream_type()} {name};" + ) + ) + self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK({self.device_codegen.aoti_get_stream()}({device_idx}, (void**)&{name}));" + ) + return name + + def codegen_inputs(self): + # See Note: [Input Alignment handling in Inductor] + # + # JIT Inductor does not guard on input alignment. It relies on copy_misaligned_inputs to + # copy misaligned inputs to aligned buffers. For AOTInductor, we expect users to use it + # as non-Python deployment for its best performance, so implicitly copying misaligned inputs + # to aligned buffers is going to bring a surprising performance hit. Instead, we check input + # alignment and throw an error if any input is misaligned. 
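+        # For each input listed in V.graph.inputs_to_check, the splice below emits a
+        # C++ guard of the form
+        #     if ((long(bufN.data_ptr()) & (NPU_ALIGN_BYTES - 1)) != 0) { throw ...; }
+        # so misaligned AOTI inputs fail fast instead of being silently copied.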
+ if V.graph.aot_mode and V.graph.inputs_to_check: + for idx in V.graph.inputs_to_check: + input_name = V.graph.graph_input_names[idx] + checkIfTrue(input_name in V.graph.graph_inputs, f"{input_name} not found in graph inputs") + + value = V.graph.graph_inputs[input_name] + checkIfTrue(isinstance(value, TensorBox), + f"{input_name} is expected to be tensor but found as {type(value)}") + + self.prefix.splice( + f""" + if ((long({input_name}.data_ptr()) & ({NPU_ALIGN_BYTES} -1)) != 0) {{ + throw std::runtime_error("{input_name} is not aligned to {NPU_ALIGN_BYTES} bytes"); + }} + """ + ) + + super().codegen_inputs() + + def define_kernel( + self, + kernel_name: str, + kernel_body: str, + metadata: Optional[str] = None, + gpu=True, + ): + if gpu: + if config.triton.autotune_at_compile_time: + # Call PythonWrapperCodegen to create the autotune code block + PythonWrapperCodegen.define_kernel( + self, kernel_name, kernel_body, metadata, gpu + ) + else: + return CppWrapperCpu.define_kernel( + self, kernel_name, kernel_body, metadata, gpu + ) + + def generate(self, is_inference): + with dynamo_timed("CppWrapperNpu.generate", log_pt2_compile_event=True): + self.prefix.writeline("\n") + if not V.graph.aot_mode: + for kernel in chain( + sorted(self.src_to_kernel.values()), + sorted( + [entry[0] for entry in self.user_defined_kernel_cache.values()] + ), + ): + self.prefix.writeline( + maybe_hipify_code_wrapper( + f"static {self.device_codegen.cpp_kernel_type()} {kernel} = nullptr;" + ) + ) + self.prefix.writeline("\n") + return super().generate(is_inference) + + def generate_user_defined_triton_kernel( + self, + kernel_name: str, + raw_args: List[Any], + grid: List[Any], + configs, + triton_meta, + constexprs, + ): + if ( + config.triton.autotune_at_compile_time + and kernel_name not in self.kernel_autotune_names + ): + # Call PythonWrapperCodegen to create the autotune code block + PythonWrapperCodegen.generate_user_defined_triton_kernel( + self, + kernel_name, + raw_args, + grid, + configs, + triton_meta, + constexprs, + ) + + # in C++ wrapper, we don't pass constexpr args, as they don't + # get added as parameters to the PTX code compiled from the + # user-defined Triton kernel (only non-constexpr args do) + raw_args = [raw_arg for i, raw_arg in enumerate(raw_args) if i not in constexprs] + args = [self.val_to_arg_str(v) for v in raw_args] + arg_types = [ + arg.get_dtype() if isinstance(arg, IRNode) else type(arg) + for arg in raw_args + ] + + # Call self.generate_kernel_call to generate the real kernel call in cpp + self.generate_kernel_call( + kernel_name, + args, + arg_types=arg_types, + raw_args=raw_args, + grid=grid, + gpu=True, + triton=True, + triton_meta=triton_meta, + autotune_configs=configs, + ) + + @functools.lru_cache(None) # noqa: B019 + def generate_load_kernel_once( + self, + kernel_name: str, + device_index, + graph: "GraphLowering", # for per-graph caching + ): + keys = (get_cpp_wrapper_cubin_path_name(), "mangled_name", "shared_mem") + kernel_var_name = f"kernels.{kernel_name}" if V.graph.aot_mode else kernel_name + self.writeline(f"if ({kernel_var_name} == nullptr) {{") + deferred_gpu_kernel_line = DeferredNpuKernelLine( + kernel_name, + " " + kernel_var_name + r' = loadKernel("%s", "%s", %s, this->cubin_dir_);', + keys, + self.additional_files, + ) + self.writeline(deferred_gpu_kernel_line) + self.writeline("}") + return kernel_var_name + + def codegen_tensor_item_npu( + self, dtype: torch.dtype, tensor: str, scalar: str, indented_buffer=None + ): + dtype_str = 
str(dtype).split(".")[-1] + writer = indented_buffer or self + + if dtype == torch.float16 or dtype == torch.bfloat16: + scalar_tmp = f"{scalar}_tmp" + writer.writeline(f"{DTYPE_TO_CPP[dtype]} {scalar_tmp};") + writer.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_{dtype_str}({tensor}, &{scalar_tmp}));" + ) + writer.writeline(f"float {scalar} = float({scalar_tmp});") + struct_data = f'float {scalar} __attribute__((aligned(4)));' + arg_data = f'static_cast({scalar})' + else: + writer.writeline(f"{DTYPE_TO_CPP[dtype]} {scalar};") + writer.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_{dtype_str}({tensor}, &{scalar}));" + ) + struct_data = f'{DTYPE_TO_CPP[dtype]} {scalar} __attribute__((aligned(sizeof({DTYPE_TO_CPP[dtype]} ))));' + arg_data = f'static_cast<{DTYPE_TO_CPP[dtype]}>({scalar})' + + return struct_data, arg_data + + def codegen_device(self, device): + if device.type not in DEVICE_TO_ATEN: + raise RuntimeError(device.type + "not found in DEVICE_TO_ATEN") + device_str = DEVICE_TO_ATEN[device.type][5:].lower() # remove "at::k" + if device_str == "privateuse1": + device_str = "npu" + self.used_cached_devices.add(device_str) + return f"cached_torch_device_type_{device_str}, {device.index if device.index else 0}" + + def write_wrapper_decl(self): + super().write_wrapper_decl() + with self.prefix.indent(): + if not V.graph.aot_mode: + return + dump_path = npu_config.aot_inductor.dump_path_cpp + if npu_config.aot_inductor.debug_kernel: + self.prefix.splice( + f""" + auto dump_path = std::filesystem::current_path() / "{dump_path}"; + if (!std::filesystem::exists(dump_path)) {{ + std::filesystem::create_directory(dump_path); + }} + """ + ) + + self.prefix.splice( + """ + auto tensor_handle_to_tensor_pointer = [](AtenTensorHandle handle) { + return reinterpret_cast(handle); + }; + """ + ) + + def generate_debug_str(self, args, kernel_name, kernel_id, mark): + if not npu_config.aot_inductor.debug_kernel: + return "" + if kernel_id not in self.visited_handle_for_kernel_id: + self.visited_handle_for_kernel_id[kernel_id] = set() + + def get_tensor_from_handle(h, t): + if h in self.visited_handle_for_kernel_id[kernel_id]: + return "" + self.visited_handle_for_kernel_id[kernel_id].add(h) + return f" auto {t} = *tensor_handle_to_tensor_pointer({h});\n" + + # Only dump tensor args, e.g, ['buf2', '8L', '4L'] => ['buf2'] + tensor_args = [arg for arg in args if not arg[0].isdigit()] + + tensor_args_h = [f"{arg}_h" for arg in tensor_args] + tensor_args_t = [f"{arg}_t" for arg in tensor_args] + handle_tensor_str = "".join([ + get_tensor_from_handle(h, t) for h, t in zip(tensor_args_h, tensor_args_t) + ]) + + dump_path = npu_config.aot_inductor.dump_path_cpp + return f""" + c10_npu::npuSynchronizeDevice(); + \n{handle_tensor_str} + std::vector arg_{mark}{{{", ".join(tensor_args_t)}}}; + torch::save(arg_{mark}, "{dump_path}/{kernel_id}_{kernel_name}_{mark}.pt"); + """ + + def generate_launch_call( + self, + call_args, + arg_types, + arg_signatures, + kernel_id, + grid_var, + kernel_name + ): + kernel_val_name = f"kernels.{kernel_name}" if V.graph.aot_mode else kernel_name + new_args: list[str] = [] + + # Add more cases for other types as needed + signature2dtype = { + "i1": "int32_t", + "i8": "int8_t", + "i16": "int16_t", + "i32": "int32_t", + "i64": "int64_t", + "u32": "uint32_t", + "u64": "uint64_t", + "fp16": "float", + "bf16": "float", + "fp32": "float", + "f32": "float", + "fp64": "double", + } + + + struct_def_body = '' + struct_arg_body = '' + + def process_args(arg, arg_type, 
arg_signature=None): + var_name = f"var_{next(self.arg_var_id)}" + # ignore nvTmaDesc, as host-side TMA descriptors need + # to be passed to the compiled Triton kernel by value + if isinstance(arg_type, torch_dtype) and arg_signature != "nvTmaDesc": + if arg.endswith(".item()"): # scalar + # Need to declare a scalar in this case + arg = arg[:-7] + struct_data, arg_data = self.codegen_tensor_item_npu( + arg_type, + arg, + var_name, + ) + else: + # void* + device_ptr_type = self.device_codegen.cpp_device_ptr() + self.writeline( + maybe_hipify_code_wrapper( + f"{device_ptr_type} {var_name} = reinterpret_cast<{device_ptr_type}>({arg}.data_ptr());" + ) + ) + if npu_config.aot_inductor.debug_kernel: + if arg not in self.visited_raii_handle: + self.writeline( + f"AtenTensorHandle {arg}_h = {arg}.get();" + ) + self.visited_raii_handle.add(arg) + struct_data = f'void* {var_name} __attribute__((aligned(8)));' + arg_data = f'static_cast({var_name})' + + elif arg_type in (sympy.Integer, int): + # int + self.writeline(f"int {var_name} = {cexpr(arg)};") + struct_data = f'int {var_name} __attribute__((aligned(4)));' + arg_data = f'static_cast({var_name})' + + elif arg_type in (sympy.Float, float): + # float + self.writeline(f"float {var_name} = {cexpr(arg)};") + struct_data = f'float {var_name} __attribute__((aligned(4)));' + arg_data = f'static_cast({var_name})' + + # For symbolic call arguments, examine the arg signatures from triton meta + # to explicitly cast to the right type + # Reason: `auto` can infer unexpected type against kernel input signature. + elif ( + isinstance(arg_type, type(SymbolicCallArg)) + and arg_signature is not None + and arg_signature in signature2dtype.keys() + ): + # or scalar symbolic type,currently only support scalar symbolic type + self.writeline( + f"{signature2dtype[arg_signature]} {var_name} = {cexpr(arg)};" + ) + struct_data = f'{signature2dtype[arg_signature]} {var_name} __attribute__((aligned(sizeof({signature2dtype[arg_signature]}))));' + arg_data = f'static_cast<{signature2dtype[arg_signature]}>({var_name})' + else: + raise TypeError("Infer arg_type to cpp failed!") + + nonlocal struct_def_body + nonlocal struct_arg_body + struct_def_body += struct_data + ' ' + struct_arg_body += arg_data + ', ' + + for arg, arg_type, arg_signature in zip_longest( + call_args, arg_types, arg_signatures + ): + process_args(arg, arg_type, arg_signature) + + debug_str_before_kernel = self.generate_debug_str(call_args, kernel_name, kernel_id, "before") + debug_str_after_kernel = self.generate_debug_str(call_args, kernel_name, kernel_id, "after") + + launch_str = f""" + auto launch_call_{kernel_id} = [=]() {{ + int32_t grid_x = {grid_var}.grid_x; + int32_t grid_y = {grid_var}.grid_y; + int32_t grid_z = {grid_var}.grid_z; + rtError_t ret; + void* ffts_addr = NULL; + uint32_t ffts_len; + ret = rtGetC2cCtrlAddr((uint64_t*)&ffts_addr, &ffts_len); + if (ret != RT_ERROR_NONE) return ret; + void* workspace_addr = NULL; + + struct __attribute__((packed)) {{ + void* ffts_addr __attribute__((aligned(8))); + void* workspace_addr __attribute__((aligned(8))); + {struct_def_body} + int32_t grid_x __attribute__((aligned(4))); + int32_t grid_y __attribute__((aligned(4))); + int32_t grid_z __attribute__((aligned(4))); + }} kernel_args = {{ + static_cast(ffts_addr), + static_cast(workspace_addr), + {struct_arg_body} + static_cast(grid_x), + static_cast(grid_y), + static_cast(grid_z) + }}; + + uint32_t block_num = grid_x * grid_y * grid_z; + auto arg_ptr = static_cast(&kernel_args); + auto arg_size = 
sizeof(kernel_args); + {debug_str_before_kernel} + ret = rtKernelLaunch({kernel_val_name}, block_num, arg_ptr, arg_size, NULL, stream); + {debug_str_after_kernel} + if (ret != RT_ERROR_NONE) return ret; + return ret; + }}; + """ + return f"launch_call_{kernel_id}", launch_str + + def generate_default_grid( + self, + kernel_name: str, + grid_args: List[Any], + gpu: bool = True, + grid_callable: Optional[Callable[..., Any]] = None, + **grid_extra_kwargs, + ): + """ + Generate grid configs for launching a CUDA kernel using the grid + function from triton_heuristics. Because its computation needs + to read kernel config after autotune, it is done in a deferred way + using DeferredNpuDefaultGrid. + """ + checkIfTrue(gpu, "CppWrapperNpu.generate_default_grid does not support non-NPU") + return DeferredNpuDefaultGrid( + kernel_name, grid_args, grid_callable, **grid_extra_kwargs + ) + + def generate_kernel_call_npu( + self, + kernel_name: str, + call_args, + grid=None, + device_index=None, + npu=True, + triton=True, + arg_types=None, + raw_args=None, + grid_fn: str = "grid", + triton_meta=None, + autotune_configs=None, + grid_extra_kwargs="", + ): + if ( + config.triton.autotune_at_compile_time + and kernel_name not in self.kernel_autotune_names + ): + # Call PythonWrapperCodegen to create the autotune code block + PythonWrapperCodegen.generate_kernel_call( + self, + kernel_name, + call_args, + grid, + device_index, + npu, + triton, + arg_types, + raw_args, + grid_fn, + triton_meta, + autotune_configs, + grid_extra_kwargs, + ) + + if device_index is None: + current_device = V.graph.get_current_device_or_throw() + device_index = current_device.index + + stream = ( + "stream" + if V.graph.aot_mode + else self.write_get_raw_stream(device_index, V.graph) + ) + + if triton: + device_index, call_args = self.prepare_triton_kernel_call( + device_index, call_args + ) + _ = self.generate_load_kernel_once(kernel_name, device_index, V.graph) + + # args with value 1 are added into equal_to_1 and constants + # in triton_meta (in the Python codegen) which makes them + # inlined in the PTX and compiled CUBIN + arg_signatures = [] + if ( + triton_meta is not None + and triton_meta.get("configs") + and triton_meta.get("signature") + ): + equal_to_1 = triton_meta["configs"][0].equal_to_1 + call_args = [ + arg + for i, arg in enumerate(call_args) + if i not in equal_to_1 + ] + arg_types = [t for i, t in enumerate(arg_types) if i not in equal_to_1] + # extract the arg signatures from triton_meta + arg_signatures = triton_meta["signature"].values() + arg_signatures = [ + v + for i, v in enumerate(arg_signatures) + if i not in equal_to_1 + ] + + current_kernel_id = next(self.kernel_callsite_id) + current_grid_id = next(self.grid_id) + + # gen grids + grid_var = f"{kernel_name}_grid_{current_grid_id}" + self.writeline( + DeferredNpuGridLine(kernel_name, grid_var, grid, autotune_configs) + ) + + call, call_args_str = self.generate_launch_call( + call_args, arg_types, arg_signatures, current_kernel_id, grid_var, kernel_name + ) + self.writeline(f"{call_args_str}") + + # add debug printer code for all triton kernel related calls + debug_printer_manager = V.graph.wrapper_code.debug_printer + debug_printer_manager.set_printer_args( + call_args, kernel_name, arg_types, None + ) + with debug_printer_manager: + self.writeline(f"if ({grid_var}.is_non_zero()) {{") + self.writeline( + DeferredNpuKernelLine( + kernel_name, + r" launchKernel({}, {});".format( \ + call, + f'"{kernel_name}"', + ), + (), + self.additional_files, + ), + 
) + + self.writeline("}\n") + else: + casted = [] + for arg_type, arg in zip(arg_types, call_args): + new_arg = arg + if arg_type.endswith("*") and arg != "nullptr": + new_arg = f"{arg}.data_ptr()" + casted.append(f"({arg_type}){new_arg}") + call_args_str = ", ".join(casted) + self.writeline(f"kernels.{kernel_name}({call_args_str}, {stream});") + + def generate_kernel_call( + self, + kernel_name: str, + call_args, + grid=None, + device_index=None, + gpu=True, + triton=True, + arg_types=None, + raw_args=None, + grid_fn: str = "grid", + triton_meta=None, + autotune_configs=None, + grid_extra_kwargs="", + ): + """ + Override the default value of argument 'gpu' to True here. + generate_kernel_call can still be called with gpu=False because of + a mix of cpu kernels and gpu kernels. + """ + + """ + To fit with NPU: we write a new function 'generate_kernel_call_npu + and make a new parameter called 'npu', which always equals to 'gpu', + because 'gpu' parameter means 'not cpu' in upper logic + """ + + if not gpu: + # Even in CppWrapperNpu, we may see cpp kernels + return CppWrapperCpu.generate_kernel_call( + self, + kernel_name, + call_args, + grid, + device_index, + gpu, + triton, + arg_types, + raw_args, + grid_fn, + triton_meta, + autotune_configs, + grid_extra_kwargs, + ) + + self.generate_kernel_call_npu( + kernel_name, + call_args, + grid, + device_index, + gpu, + triton, + arg_types, + raw_args, + grid_fn, + triton_meta, + autotune_configs, + grid_extra_kwargs, + ) + + def make_zero_buffer(self, name): + return f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_zero_({name}.get()));" diff --git a/torch_npu/_inductor/codegen/ir.py b/torch_npu/_inductor/codegen/ir.py new file mode 100644 index 0000000000000000000000000000000000000000..b288ad8ae58b404652a958adf8a6993c7f35adbe --- /dev/null +++ b/torch_npu/_inductor/codegen/ir.py @@ -0,0 +1,199 @@ +from typing import List, Tuple, Dict, Any, Optional +import itertools +import sympy +from torch._inductor.ir import (ReductionHint, IRNode, ModularIndexing, FloorDiv) +from torch._inductor.utils import sympy_subs, sympy_index_symbol +from torch._inductor.virtualized import V +from torch_npu._inductor.codegen.triton import NPUIndexTritonKernel + +from ..config import log + + +# NPU doesn't need to support ReductionHint.OUTER, and persistent reduction +def num_splits( + device, + dst_dtype, + src_dtype, + inner_fn, + ranges, + reduction_ranges, + reduction_type, + reduction_numel, + input_node: Optional[IRNode] = None, +): + return ReductionHint.DEFAULT, 1 + + +def detect_flattened_dims(kernel, index): + new_vars = {} + if not isinstance(index, (sympy.core.add.Add, ModularIndexing, FloorDiv)): + return new_vars + + def detect_flattened_axis(expr): + def init_new_vars(var, length): + if var not in new_vars: + new_vars[var] = {length: [None, None]} + if length not in new_vars[var]: + new_vars[var][length] = [None, None] + + if isinstance(expr, ModularIndexing): + var, divisor, length = expr.args + init_new_vars(var, length) + new_vars[var][length][1] = (expr, divisor, length) + elif isinstance(expr, FloorDiv): + var, divisor = expr.args + init_new_vars(var, divisor) + # over than 1 node_schedule, var may be deleted in kernel.range_tree_nodes + # it shoule be find in range_tree_nodes_removed dict + if (var in kernel.range_tree_nodes): + numel = kernel.range_tree_nodes[var].length + else: + numel = kernel.range_tree_nodes_removed[var].length + + length = expr.eval(numel, divisor) + new_vars[var][divisor][0] = (expr, divisor, length) + + else: + for x in expr.args: 
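+                # Recurse into sub-expressions so nested ModularIndexing/FloorDiv
+                # terms are also collected.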
+ detect_flattened_axis(x) + + # add + if isinstance(index, sympy.core.add.Add): + for x in index.args: + detect_flattened_axis(x) + elif isinstance(index, (ModularIndexing, FloorDiv)): + detect_flattened_axis(index) + else: + pass + + # make sure FloorDiv, MouldarIndexing must be in-pair + for var, divisors in new_vars.items(): + if var in kernel.range_tree_nodes: + parent_axis = kernel.range_tree_nodes[var] + else: + parent_axis = kernel.range_tree_nodes_removed[var] + for divisor, pair in divisors.items(): + if not pair[0] and not pair[1]: + pass + # FloorDiv not inplace + elif not pair[0]: + _, _, length = pair[1] + expr = FloorDiv(var, length) + new_vars[var][divisor][0] = (expr, length, parent_axis.length // length) + # ModularIndexing not inplace + elif not pair[1]: + expr = ModularIndexing(var, 1, divisor) + new_vars[var][divisor][1] = (expr, 1, divisor) + else: + pass + + return new_vars + + +def rebuild_flattened_dims(indexing): + def rebuild_flattened_dim(key, index, old_node, flatten_dim): + for _, pair in flatten_dim.items(): + new_var_expr = sympy.Integer(0) + origin_axis_length = 0 + pair_is_valid = True + # don't create duplicated axis, e.g. y1:1024, y1 % 1024 is duplicated + expr, divisor, length = pair[1] + if not old_node.parent.duplicated_check(divisor, length): + if expr not in V.kernel.expr_substituted: + V.kernel.expr_substituted[expr] = old_node.symbol() + break + + for axis in pair: + expr, divisor, length = axis + # 3. try to rebuild the axis in kernel + new_node = old_node.parent.lookup(divisor, length) + + # 4. substitute div/mod expression in indexing + index = index.subs(expr, new_node.symbol()) + indexing[key] = index + if isinstance(expr, FloorDiv): + new_var_expr = new_var_expr + new_node.symbol() * divisor + origin_axis_length = divisor * length + elif isinstance(expr, ModularIndexing): + new_var_expr = new_var_expr + new_node.symbol() + V.kernel.expr_substituted[expr] = new_node.symbol() + + if var not in V.kernel.range_tree_nodes_substituted: + V.kernel.range_tree_nodes_substituted[var] = [] + V.kernel.range_tree_nodes_substituted[var].append((origin_axis_length, new_var_expr)) + + def find_index_in_substitute(index, kernel): + return any([index.find(key) for key in kernel.expr_substituted.keys()]) + + kernel = V.kernel + for key, index in indexing.items(): + # 1. try to find out flattened axis from indexing + flatten_dims = detect_flattened_dims(kernel, index) + # 2. 
try to rebuild these flattened dims + for var, flatten_dim in flatten_dims.items(): + if (var in kernel.range_tree_nodes): + old_node = kernel.range_tree_nodes[var] + else: + old_node = kernel.range_tree_nodes_removed[var] + + rebuild_flattened_dim(key, index, old_node, flatten_dim) + + if find_index_in_substitute(index, kernel): + new_index = sympy_subs(index, kernel.expr_substituted) + indexing[key] = new_index + + +def substituted_dims_in_indexing(self, indexing, kernel, range_tree_nodes_substituted): + substituted = False + for var, candidates in range_tree_nodes_substituted.items(): + if not (len(candidates) > 0): + raise RuntimeError("assert len(candidates) > 0, candidates") + exprs = sorted(candidates, reverse=True, key=lambda x: x[0]) + # the best candidate is with the longest numel + numel = exprs[0][0] + expr = exprs[0][1] + node = kernel.range_tree_nodes[var] + if node.length != numel: + log.debug("sub nodes (expr%s, numel:%d) can not substitute parent node(%s:%d)", + expr, numel, node.symbol(), node.length) + continue + for key, index in indexing.items(): + if var in index.free_symbols: + index = index.subs(var, expr) + indexing[key] = index + substituted = True + + return substituted + + +def generate_body_indexing(body, indices): + index = list(itertools.chain.from_iterable(indices)) + if not (len(index) == len(body.var_ranges)): + raise RuntimeError("assert len(index) == len(body.var_ranges), (index, body.var_ranges)") + if not (all(v not in body.var_ranges for v in index)): + raise RuntimeError("assert all(v not in body.var_ranges for v in index)") + + replacements = dict(zip(body.var_ranges.keys(), index)) + indexing_map = dict(zip(index, body.var_ranges.keys())) + setattr(body, 'indexing_map', indexing_map) + body.indexing = { + name: sympy_subs(expr, replacements) + for name, expr in body.indexing_exprs.items() + } + + +def transform_dims_in_indexing(self, indices): + if self.indexing is None: + generate_body_indexing(self, indices) + + if V.kernel is not None and isinstance(V.kernel, NPUIndexTritonKernel): + rebuild_flattened_dims(self.indexing) + + +# select tiling axis, recover missing dimensions, +def loopbody__call__(self, *indices): + if self.indexing is None: + generate_body_indexing(self, indices) + result = self.root_block() + self.indexing = None + return result diff --git a/torch_npu/_inductor/codegen/ir_fx.py b/torch_npu/_inductor/codegen/ir_fx.py new file mode 100644 index 0000000000000000000000000000000000000000..6b768760a7594c9b12d034c92942cb33e9f4b914 --- /dev/null +++ b/torch_npu/_inductor/codegen/ir_fx.py @@ -0,0 +1,864 @@ +import traceback +import typing +from typing import ( + Any, + Callable, + List, + Optional, + Union +) +from typing import Optional +from unittest.mock import patch +import sympy +import torch +from sympy import Expr +from torch._inductor import config +from torch._inductor import ir +from torch._inductor.virtualized import ops, V +from torch.utils._ordered_set import OrderedSet + +from ..lowering_fx import ( + fetch_graphs, + merge_traced_graphs, + node_id, + clone, + create_fake_input, + subtract_graph +) + + +def _patch_loops_get_name(self): + return self.node_name + + +def _patch_loops_get_traced_graph(self): + return self.traced_graph + + +@classmethod +def _patch_loops_create(cls, *args, **kwargs): + origin_node = kwargs.pop("origin_node", None) + traced_graph = kwargs.pop("traced_graph", None) + node_name = kwargs.pop("node_name", None) + tb = kwargs.pop("traceback", None) + r = cls(*args, **kwargs) + # Need to explicitly set 
origin_node here to propagate it down. + # todo(chilli): I think it would be better for IRNode to directly set + # origin_node + r._post_init_setattr("origin_node", origin_node) + r._post_init_setattr("traceback", tb or r.traceback) + r._post_init_setattr("traced_graph", traced_graph) + r._post_init_setattr("node_name", node_name) + return ir.TensorBox.create(r) + + +def _patch_pointwise_constant_to_device(self, device, traced_graph=None, node_name=None): + """Move this to a given device. Requires that all reads are to constants.""" + loader = self.make_loader() + loader = patch.object(ir.ConstantBuffer, "override_device", device)(loader) + + r = ir.Pointwise(device, self.dtype, loader, self.ranges) + r._post_init_setattr("traced_graph", traced_graph) + r._post_init_setattr("node_name", node_name) + return r + + +@classmethod +def _patch_reduction_create( + cls, + device: torch.device, + dst_dtype: torch.dtype, + src_dtype: torch.dtype, + inner_fn: Callable[..., Any], + ranges: ir.Sequence[Expr], + reduction_ranges: ir.Sequence[Expr], + reduction_type: str, + reduction_hint: ir.ReductionHint = ir.ReductionHint.DEFAULT, + input_node: Optional[ir.IRNode] = None, + traced_graph=None, + node_name: str = None +) -> ir.TensorBox: + reduction_numel = V.graph.sizevars.simplify(ir.sympy_product(reduction_ranges)) + + if reduction_numel == 0: + # N.B. This is a hack to generate the literal of the given type + # Ideally, we should be fixing `def constant` in triton.py + # but it breaks due to hardcoded dtypes in other places + def py_cnst(val: object) -> Union[bool, float, int]: + if dst_dtype == torch.bool: + return bool(val) + elif dst_dtype.is_floating_point: + if not isinstance(val, typing.SupportsFloat): + raise RuntimeError("assert val must support float conversion") + return float(val) + else: + if not isinstance(val, typing.SupportsInt): + raise RuntimeError("assert val must support int conversion") + return int(val) + + rtypes_to_inits = { + "sum": py_cnst(0), + "xor_sum": py_cnst(0), + "prod": py_cnst(1), + "any": py_cnst(0), + # "all" is desugared to `!any(!val)` + } + + if reduction_type not in rtypes_to_inits: + raise RuntimeError(f"assert {reduction_type} not supported for zero-dimension tensors!") + + def const_fn(index: int) -> ir.OpsValue: + return ops.constant(rtypes_to_inits[reduction_type], dst_dtype) + + return ir.Pointwise.create( + device=device, + dtype=src_dtype, + inner_fn=const_fn, + ranges=list(ranges), + traced_graph=traced_graph, + node_name=node_name + ) + + if reduction_numel == 1: + # this reduction is actually a pointwise op + if reduction_type in ("argmin", "argmax"): + + def fn(index: int) -> ir.OpsValue: + return ops.constant(0, dst_dtype) + + else: + + def fn(index: int) -> ir.OpsValue: + reduction_index = [sympy.S.Zero for _ in reduction_ranges] + return inner_fn(index, reduction_index) + + return ir.Pointwise.create( + device=device, dtype=dst_dtype, inner_fn=fn, ranges=ranges + ) + + if ( + isinstance(reduction_numel, ir.Integer) + and V.graph.sizevars.size_hint(reduction_numel) + < config.unroll_reductions_threshold + and (ir.sympy_product(ranges) != 1 or ir.is_gpu(device.type)) + ): + # NB: This works around pytorch issues 140457 + # since turning reductions into pointwise ops can exacerbate this problem + return ir.Pointwise.create( + device=device, + dtype=dst_dtype, + inner_fn=cls._unroll_reduction_fn( + inner_fn, reduction_ranges, reduction_type, src_dtype + ), + ranges=ranges, + traced_graph=traced_graph, + node_name=node_name + ) + + # triton doesn't 
support reduce to single element well, so break it up + hint, split = cls.num_splits( + device, + dst_dtype, + src_dtype, + inner_fn, + ranges, + reduction_ranges, + reduction_type, + reduction_numel, + input_node, + ) + # intermediate reduction in split can contain complex indexing, + # and num_splits will fail to correctly set the hint + # reuse the passed hint if available + if reduction_hint == ir.ReductionHint.DEFAULT: + reduction_hint = hint + if split == -1: + if input_node is None: + raise RuntimeError("assert input_node cannot be None") + new_ranges, new_reduction_ranges = ir.extract_input_node_reduction_ranges( + input_node + ) + if new_ranges is None: + raise RuntimeError("assert new_ranges cannot be None") + if new_reduction_ranges is None: + raise RuntimeError("assert new_reduction_ranges cannot be None") + return cls.create_multilayer_existing_ranges( + device, + dst_dtype, + src_dtype, + inner_fn, + ranges, + reduction_ranges, + new_ranges, + new_reduction_ranges, + reduction_type, + reduction_hint, + ) + elif split > 1: + # triton doesn't support reduce to single element well, so break it up + return cls.create_multilayer( + device, + dst_dtype, + src_dtype, + inner_fn, + ranges, + reduction_ranges, + reduction_type, + split, + reduction_hint, + ) + + r = ir.Reduction( + device=device, + dtype=dst_dtype, + inner_fn=inner_fn, + ranges=ranges, + reduction_ranges=reduction_ranges, + reduction_type=reduction_type, + src_dtype=src_dtype, + reduction_hint=reduction_hint, + ) + r._post_init_setattr("traced_graph", traced_graph) + r._post_init_setattr("node_name", node_name) + + return ir.TensorBox.create(r) + + +def _patch_baseview_get_traced_graph(self): + if hasattr(self, 'traced_graph') and self.traced_graph is not None: + return self.traced_graph + return self.data.get_traced_graph() + + +def _patch_base_view_get_reads(self): + with patch.object(ir.FlexibleLayout, "allow_indexing", True): + r = ir.extract_read_writes( + self.make_loader(), + self.get_size(), + ).reads + for md in r: + if md.index.has(ir.ModularIndexing): + if md.index.has(ir.FloorDiv): + self.realize() + return r + else: + for m in md.index.find(ir.ModularIndexing): + for arg in m.args: + if arg.has(ir.ModularIndexing): + self.realize() + return r + return r + + +def has_buffer(inp): + if not hasattr(inp, 'data'): + return False + if isinstance(inp.data, ir.Buffer): + return True + return has_buffer(inp.data) + + +def get_buffer(inp): + if isinstance(inp.data, ir.Buffer): + return inp.data + return get_buffer(inp.data) + + +def _patch_baseview_realize(self): + if hasattr(self, 'traced_graph') and self.traced_graph is not None: + r = self.data.realize() + buffer = get_buffer(self) + if isinstance(buffer, (ir.MultiOutput, ir.InputBuffer, ir.ConcatKernel)): + return r + traced_graph = buffer.data.get_traced_graph() + buf_name = buffer.get_name() + new_traced_graph, placeholder = subtract_graph(self.traced_graph, traced_graph, node_name=buf_name) + if placeholder is not None: + placeholder.name = buf_name + device = buffer.get_device() + dtype = buffer.get_dtype() + size = buffer.get_size() + stride = buffer.get_stride() + fake_input = create_fake_input(size, stride, device, dtype) + placeholder.meta['val'] = fake_input + self._post_init_setattr("traced_graph", new_traced_graph) + return r + else: + return self.data.realize() + + +def _patch_baseview_realize_hint(self): + if hasattr(self, 'traced_graph') and self.traced_graph is not None: + r = self.data.realize_hint() + if not has_buffer(self): + return r + 
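+        # Same bookkeeping as _patch_baseview_realize above: subtract the realized
+        # buffer's traced graph from this view's graph and re-attach the remainder.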
buffer = get_buffer(self) + if isinstance(buffer, (ir.MultiOutput, ir.InputBuffer, ir.ConcatKernel)): + return r + traced_graph = buffer.data.get_traced_graph() + buf_name = buffer.get_name() + new_traced_graph, placeholder = subtract_graph(self.traced_graph, traced_graph, node_name=buf_name) + if placeholder is not None: + placeholder.name = buf_name + device = buffer.get_device() + dtype = buffer.get_dtype() + size = buffer.get_size() + stride = buffer.get_stride() + fake_input = create_fake_input(size, stride, device, dtype) + placeholder.meta['val'] = fake_input + self._post_init_setattr("traced_graph", new_traced_graph) + return r + else: + return self.data.realize_hint() + + +def _patch_mark_reuse(self, users): + if isinstance(self.data, ir.StorageBox): + if self.data.should_realize_on_reuse(users): + if hasattr(self, 'traced_graph') and self.traced_graph is not None: + r = self.data.realize() + buffer = get_buffer(self) + if isinstance(buffer, (ir.MultiOutput, ir.InputBuffer, ir.ConcatKernel)): + return r + traced_graph = buffer.data.get_traced_graph() + buf_name = buffer.get_name() + new_traced_graph, placeholder = subtract_graph(self.traced_graph, traced_graph, node_name=buf_name) + if placeholder is not None: + placeholder.name = buf_name + device = buffer.get_device() + dtype = buffer.get_dtype() + size = buffer.get_size() + stride = buffer.get_stride() + fake_input = create_fake_input(size, stride, device, dtype) + placeholder.meta['val'] = fake_input + self._post_init_setattr("traced_graph", new_traced_graph) + return r + else: + return self.data.realize() + else: + return self.data.mark_reuse(users) + + +@classmethod +def _patch_expandview_create(cls, x, new_size, traced_graph=None, node_name=None): + new_size = cls._normalize_size(x, new_size) + + if ir.is_storage_and_layout(x): + storage, old_layout = ir.as_storage_and_layout(x) + skip = len(new_size) - len(old_layout.size) + if skip < 0: + raise RuntimeError(f"assert Internal error: skip must be non-negative, got {skip}") + new_stride = [sympy.Integer(0)] * skip + for stride, size in zip(old_layout.stride, old_layout.size): + new_stride.append( + stride + if not V.graph.sizevars.shape_env.evaluate_expr( + sympy.Eq(size, 1), size_oblivious=True + ) + else sympy.Integer(0) + ) + new_layout = ir.FixedLayout( + old_layout.device, + old_layout.dtype, + list(new_size), + new_stride, + old_layout.offset, + ) + + r = ir.ReinterpretView(data=storage, layout=new_layout) + r._post_init_setattr("traced_graph", traced_graph) + r._post_init_setattr("node_name", node_name) + return r + + r = ir.ExpandView(data=x, size=new_size) + r._post_init_setattr("traced_graph", traced_graph) + r._post_init_setattr("node_name", node_name) + + return r + + +@classmethod +def _patch_permuteview_create(cls, x, dims, traced_graph=None, node_name=None): + dims = cls._map_neg_dims(dims) + if OrderedSet(dims) != OrderedSet(range(len(dims))): + raise RuntimeError("assert OrderedSet(dims) != OrderedSet(range(len(dims)))") + if ir.is_storage_and_layout(x): + storage, old_layout = ir.as_storage_and_layout(x) + new_layout = ir.FixedLayout( + old_layout.device, + old_layout.dtype, + [old_layout.size[i] for i in dims], + [old_layout.stride[i] for i in dims], + old_layout.offset, + ) + r = ir.ReinterpretView(data=storage, layout=new_layout) + r._post_init_setattr("traced_graph", traced_graph) + r._post_init_setattr("node_name", node_name) + return r + + r = ir.PermuteView(data=x, dims=dims) + r._post_init_setattr("traced_graph", traced_graph) + 
r._post_init_setattr("node_name", node_name) + return r + + +@classmethod +def _patch_view_create(cls, x, new_size, traced_graph=None, node_name=None): + if not isinstance(new_size, (tuple, list)): + raise RuntimeError("assert new_size must be tuple, list, or tuple") + old_size, new_size = cls.resolve_negative_size(x.get_size(), new_size) + # Skip pointless views + if V.graph.sizevars.statically_known_list_equals(old_size, new_size): + return x + + unbacked_symbols_in_sizes = False + if ( + len(ir.free_unbacked_symbols(old_size)) > 0 + or len(ir.free_unbacked_symbols(new_size)) > 0 + ): + unbacked_symbols_in_sizes = True + + if 0 in new_size: + + def fake_reindex(index): + return tuple([0] * len(old_size)) + + r = cls(x, list(new_size), fake_reindex) + r._post_init_setattr("traced_graph", traced_graph) + r._post_init_setattr("node_name", node_name) + return r + + # next: a new class for FixedTransferLayout that output layout is constrained by input layout + elif (ir.is_contiguous_storage_and_layout( + x) or unbacked_symbols_in_sizes): # and not isinstance(x.data, ir.ReinterpretView): + if unbacked_symbols_in_sizes and (not ir.is_contiguous_storage_and_layout(x)): + # realize x; otherwise, the dynamic_reshape_indexer below will fail + # due to the size_hint's inability to process unbacked SymInts + x = ir.ExternKernel.realize_input(x) + + storage, old_layout = ir.as_storage_and_layout(x, want_contiguous=True) + new_layout = ir.FixedLayout( + old_layout.device, + old_layout.dtype, + new_size, + ir.FlexibleLayout.contiguous_strides(new_size), + old_layout.offset, + ) + + r = ir.ReinterpretView(data=storage, layout=new_layout) + r._post_init_setattr("traced_graph", traced_graph) + r._post_init_setattr("node_name", node_name) + return r + + reindex = cls.dynamic_reshape_indexer(old_size, new_size) + + r = cls(data=x, size=list(new_size), reindex=reindex) + r._post_init_setattr("traced_graph", traced_graph) + r._post_init_setattr("node_name", node_name) + return r + + +@classmethod +def _patch_sliceview_create(cls, x, dim, start, end, step=1, clamp=True, traced_graph=None, + node_name=None): # next: crm, clamp=True + step = sympy.expand(step) + if not (isinstance(step, sympy.Expr) or step > 0): + raise RuntimeError("assert step must be a sympy.Expr or a positive number") + try: + if start == 0 and end >= 2 ** 63 - 1 and step == 1: + return x + except TypeError: + pass + sizevars = V.graph.sizevars + new_size = list(x.get_size()) + + if clamp: + start, end = cls.normalize_start_end(x, dim, start, end) + + new_size[dim] = ir.FloorDiv(end - start + (step - 1), step) + + if ir.is_storage_and_layout(x): + # Fast path + storage, old_layout = ir.as_storage_and_layout(x) + new_stride = list(old_layout.stride) + new_stride[dim] = new_stride[dim] * step + new_layout = ir.FixedLayout( + old_layout.device, + old_layout.dtype, + new_size, + new_stride, + old_layout.offset + old_layout.stride[dim] * start, + ) + r = ir.ReinterpretView(data=storage, layout=new_layout) + r._post_init_setattr("traced_graph", traced_graph) + r._post_init_setattr("node_name", node_name) + return r + + def reindex(index): + if len(index) != len(new_size): + raise RuntimeError(f"assert wrong ndim {index} {new_size}") + index = list(index) + index[dim] = index[dim] * step + start + return index + + # redirect to a generic view + r = ir.SliceView(data=x, size=new_size, reindex=reindex) + r._post_init_setattr("traced_graph", traced_graph) + r._post_init_setattr("node_name", node_name) + return r + + +def 
_patch_buffer_get_traced_graph(self): + return self.traced_graph + + +@classmethod +def _patch_concatkernel_create(cls, inputs, dim): + device = inputs[0].get_device() + dtype = inputs[0].get_dtype() + new_size = list(inputs[0].get_size()) + offsets_start = [0] + offsets_end = [new_size[dim]] + if not (0 <= dim < len(new_size)): + raise RuntimeError(f"assert dim ({dim}) must be between 0 and {len(new_size) - 1}") + for i in range(1, len(inputs)): + input_size = inputs[i].get_size() + offsets_start.append(new_size[dim]) + if len(input_size) != len(new_size): + raise RuntimeError( + f"assert input_size and new_size is not same. Got {len(input_size)} vs {len(new_size)}") + if inputs[i].get_dtype() != dtype: + raise RuntimeError(f"assert Expected dtype {dtype}, but got {inputs[i].get_dtype()}") + if inputs[i].get_device() != device: + raise RuntimeError(f"assert Expected device {device}, but got {inputs[i].get_device()}") + + for j in range(len(new_size)): + if j == dim: + new_size[j] = new_size[j] + input_size[j] + else: + new_size[j] = V.graph.sizevars.guard_equals( + new_size[j], input_size[j] + ) + offsets_end.append(new_size[dim]) + + output_stride = ir.FlexibleLayout.contiguous_strides(new_size) + # If any of the inputs is in CL format, use CL format for the output + for i in range(len(inputs)): + x = inputs[i] + if ir.is_storage_and_layout(x): + layout = x.get_layout() + if ( + isinstance(layout, ir.FixedLayout) + and layout.is_channels_last_contiguous(layout.size, layout.stride) + ): + # use CL stride for the output + output_stride = ir.make_channels_last_strides_for(new_size) + break + + any_input_is_storage_and_layout = any(ir.is_storage_and_layout(x) for x in inputs) + fx_node_args = V.graph.current_node.args[0] + if not isinstance(fx_node_args, list): + raise RuntimeError("assert fx_node_args must be a list") + # If any of the inputs has meta tensor and the meta tensor is in CL format, use CL format for the output + if any_input_is_storage_and_layout is False and any( + "val" in arg.meta + and ( + arg.meta["val"].is_contiguous(memory_format=torch.channels_last) + or arg.meta["val"].is_contiguous(memory_format=torch.channels_last_3d) + ) + for arg in fx_node_args + ): + output_stride = ir.make_channels_last_strides_for(new_size) + + concat_kernel = ir.ConcatKernel( + name=None, + layout=ir.FixedLayout( + device=device, + dtype=dtype, + size=new_size, + stride=output_stride, + ), + inputs=[], + ) + + kernel = ir.StorageBox(concat_kernel) + op_names = [] + for i in range(len(inputs)): + input_buffer = cls.realize_into( + inputs[i], + ir.SliceView.create( + kernel, dim, offsets_start[i], offsets_end[i], clamp=False + ), + ) + concat_kernel.inputs.append(input_buffer) + + if isinstance(inputs[i].data, ir.BaseView): + input_unwrapped = inputs[i].data.unwrap_view() + else: + input_unwrapped = inputs[i].data + + if ( + input_unwrapped.is_input_buffer() + and ir.is_gpu(inputs[i].get_device().type) + and not ir.is_dynamic(input_buffer) + ): + op_names.append(input_buffer.get_operation_name()) + + if len(op_names) > 1 and V.graph.has_feature(device, ir.BackendFeature.FOREACH): + V.graph.register_operation_list(op_names) + + cat_inputs = [ir.TensorBox(ir.StorageBox(inp)) for inp in concat_kernel.inputs] + input_graphs = fetch_graphs([cat_inputs]) + node_name = f'cat_{next(node_id)}' + new_graph = merge_traced_graphs(input_graphs, torch.ops.aten.cat, node_name, dim=dim) + + concat_kernel._post_init_setattr("name", V.graph.register_buffer(concat_kernel)) + 
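+    # Attach the merged traced graph and the synthesized `cat_<id>` node name to the
+    # ConcatKernel once its input storages have been unwrapped, so graph dumping can
+    # recover the concatenation as a single aten.cat node.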
concat_kernel._post_init_setattr("inputs", cls.unwrap_storage(concat_kernel.inputs)) + concat_kernel._post_init_setattr("traced_graph", new_graph) + concat_kernel._post_init_setattr("node_name", node_name) + + return kernel + + +def _patch_concatkernel_get_traced_graph(self): + return self.traced_graph + + +@classmethod +def _patch_concatkernel_realize_into(cls, src, dst): + # Attempt to turn this into a ReinterpretView rather than assert. + # This has concessions around layout, as as_storage_and_layout + # can cause us to go from flexible to fixed layout. + if not isinstance(dst, ir.ReinterpretView): + if ir.is_storage_and_layout(dst): + storage, layout = ir.as_storage_and_layout(dst) + dst = ir.ReinterpretView(data=storage, layout=layout) + if not isinstance(dst, ir.ReinterpretView): + raise RuntimeError(f"assert Expected dst to be an instance of ir.ReinterpretView. Got: {dst}") + if isinstance(src, ir.TensorBox): + # unwrap a TensorBox + return cls.realize_into(src.data, dst) + if isinstance(src, ir.StorageBox): + src.realize() + # ExternKernelAlloc has specific requirements for output layout, should create a copy + if not hasattr(src.data, "layout"): + raise RuntimeError("assert src.data has no attribute 'layout'") + if cls.can_realize_into_without_copy(src): + src.data.layout = ir.NonOwningLayout(dst) + return src.data + pw = clone(src, memory_format=torch.contiguous_format) + return cls.realize_into(pw, dst) + + +def _patch_externkernel_copy_input(x): + traced_graph = x.get_traced_graph() + node_name = x.get_name() + if traced_graph is None: + traced_graph = fetch_graphs([x])[0] + node_name = f'getitem_{next(node_id)}' + + pw = ir.Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=x.make_loader(), + ranges=x.get_size(), + origin_node=x.get_origin_node(), + traceback=x.get_traceback(), + traced_graph=traced_graph, + node_name=node_name + ) + pw.realize() + return pw + + +@classmethod +def _patch_externkernel_convert_to_reinterpret_view(cls, x): + """ + In order to pass this to an extern kernel we need a + ReinterpretView not a View. This allows us to avoid some + unneeded copies. + """ + if not isinstance(x, ir.BaseView): + raise RuntimeError(f"assert Expected type {ir.BaseView}, got {type(x)}") + if isinstance(x, ir.ReinterpretView): + return x + + # NOTE: Don't use extract_read_writes here as it fails when + # make_loader() inlines the computation + x_unwrap_view = x.unwrap_view() + buf = V.graph.get_buffer(x_unwrap_view.get_name()) + if buf is None: + raise RuntimeError("assert buf cannot be None") + x_unwrap_view_fx_node = buf.get_origin_node() + # Prefer channels last format according to how the format is set from eager. 
+ if ( + x_unwrap_view_fx_node is not None + and "val" in x_unwrap_view_fx_node.meta + and isinstance(x_unwrap_view.layout, ir.FlexibleLayout) + and ( + x_unwrap_view_fx_node.meta["val"].is_contiguous( + memory_format=torch.channels_last + ) + or x_unwrap_view_fx_node.meta["val"].is_contiguous( + memory_format=torch.channels_last_3d + ) + ) + ): + x_unwrap_view.freeze_layout_with_same_order( + ir.make_channels_last_strides_for(x_unwrap_view.get_size()) + ) + else: + x_unwrap_view.freeze_layout() + + index_args, var_ranges = ir.dependencies.index_vars_squeeze( + x.get_size(), prefix="r" + ) + range_vars = index_args[0] + index = x.make_indexer()(range_vars) + + index = V.graph.sizevars.simplify_with_ranges(index, var_ranges) + strides = V.graph.sizevars.stride_vars(index, range_vars) + offset = V.graph.sizevars.offset_var(index, range_vars) + expected = ir.sympy_dot(range_vars, strides) + offset + + if index != expected: + ir.log.debug( + "convert_to_reinterpret_view failed: stride=%s offset=%s index=%s", + strides, + offset, + index, + ) + raise NotImplementedError + + r = ir.ReinterpretView( + data=x.data, + layout=ir.FixedLayout( + device=x.get_device(), + dtype=x.get_dtype(), + size=x.get_size(), + stride=strides, + offset=offset, + ), + ) + r._post_init_setattr("traced_graph", x.get_traced_graph()) + r._post_init_setattr("node_name", x.get_name()) + return r + + +@classmethod +def _patch_devicecopy_create(cls, x, device, non_blocking, traced_graph=None, node_name=None): + if ( + not x.is_extern() + and all(r in V.graph.constants for r in x.get_read_names()) + and not config.aot_inductor.use_runtime_constant_folding + ): + return x.constant_to_device(device) + + V.graph.add_device_info(device) + V.graph.add_device_info(x.get_device()) + + ir.developer_warning("DeviceCopy in input program") + constant_args = (non_blocking,) + r = ir.DeviceCopy( + ir.FlexibleLayout( + device=device, + dtype=x.get_dtype(), + size=x.get_size(), + ), + [cls.realize_input(x)], + constant_args, + ) + r._post_init_setattr("traced_graph", traced_graph) + r._post_init_setattr("node_name", node_name) + return r + + +def _patch_devicecopy_get_traced_graph(self): + return self.traced_graph + + +def _patch_multioutput_get_traced_graph(self): + return None + + +ir.MultiOutput.get_traced_graph = _patch_multioutput_get_traced_graph + + +def _patch_mutablebox_get_name(self): + return self.data.get_name() + + +def _patch_mutablebox_get_traced_graph(self): + return self.data.get_traced_graph() + + +@classmethod +def _patch_mutationlayout_realize_into(cls, src, dst, unsafe_alias=False): + dst.realize() + # NOTE: We must realize users of `dst` before we realize `src`, since + # realization order determines scheduling order. Otherwise, src's + # mutation would be scheduled before the existing users of dst! + V.graph.mark_buffer_mutated(dst.get_name()) + + if isinstance(src, ir.TensorBox): + src = src.data + + # We copy the contents of src into dst. In most cases this should + # be fused into a single kernel by the scheduler. + # NOTE: We cannot change src's layout to mutate dst directly as this + # would alias src to dst, which is not correct as further s to + # dst would effect users of src. However if there are no more users of + # dst, we can alias src to dst. 
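+    # Unless this is an explicitly unsafe alias, src is rebuilt below as a Pointwise copy
+    # whose traced graph records aten.copy(dst, src), so the mutation stays visible when
+    # the FX graph is dumped.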
+ src.realize_hint() + + if not unsafe_alias: + input_graphs = fetch_graphs([dst, src]) + node_name = f'copy__{next(node_id)}' + new_graph = merge_traced_graphs(input_graphs, torch.ops.aten.copy, node_name) + + src = ir.Pointwise.create( + device=src.get_device(), + dtype=src.get_dtype(), + inner_fn=src.make_loader(), + ranges=[ + V.graph.sizevars.guard_equals(a, b) + for a, b in zip(src.get_size(), dst.get_size()) + ], + traced_graph=new_graph, + node_name=node_name, + ).data + + src.realize() + if not isinstance(src.data.layout, ir.FlexibleLayout): + raise RuntimeError("assert src.data.layout should be isinstance if ir.FlexibleLayout") + src.data.layout = ir.MutationLayoutSHOULDREMOVE(dst) + return src.data + + +def _patch_npu_inductor_ir(): + ir.Reduction.create = _patch_reduction_create + ir.BaseView.get_traced_graph = _patch_baseview_get_traced_graph + ir.BaseView.get_reads = _patch_base_view_get_reads + ir.BaseView.realize = _patch_baseview_realize + ir.BaseView.realize_hint = _patch_baseview_realize_hint + ir.BaseView.mark_reuse = _patch_mark_reuse + ir.ExpandView.create = _patch_expandview_create + ir.PermuteView.create = _patch_permuteview_create + ir.View.create = _patch_view_create + ir.SliceView.create = _patch_sliceview_create + ir.Buffer.traced_graph = None + ir.Buffer.get_traced_graph = _patch_buffer_get_traced_graph + ir.ConcatKernel.create = _patch_concatkernel_create + ir.ConcatKernel.get_traced_graph = _patch_concatkernel_get_traced_graph + ir.ConcatKernel.realize_into = _patch_concatkernel_realize_into + ir.ExternKernel.copy_input = _patch_externkernel_copy_input + ir.ExternKernel.convert_to_reinterpret_view = _patch_externkernel_convert_to_reinterpret_view + ir.DeviceCopy.create = _patch_devicecopy_create + ir.DeviceCopy.get_traced_graph = _patch_devicecopy_get_traced_graph + ir.MutableBox.get_name = _patch_mutablebox_get_name + ir.MutableBox.get_traced_graph = _patch_mutablebox_get_traced_graph + ir.Loops.get_name = _patch_loops_get_name + ir.Loops.get_traced_graph = _patch_loops_get_traced_graph + ir.Loops.create = _patch_loops_create + ir.Pointwise.constant_to_device = _patch_pointwise_constant_to_device + ir.MutationLayoutSHOULDREMOVE.realize_into = _patch_mutationlayout_realize_into diff --git a/torch_npu/_inductor/codegen/kernel_analysis.py b/torch_npu/_inductor/codegen/kernel_analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..30d048940a0715be24fe3e93eea5a624a452a06d --- /dev/null +++ b/torch_npu/_inductor/codegen/kernel_analysis.py @@ -0,0 +1,305 @@ +from typing import List, Tuple +import sympy +from torch._inductor import ir +from torch._inductor.scheduler import SchedulerNode +from torch._inductor.utils import sympy_index_symbol +from torch._inductor.virtualized import V + + +class IndexAnalysis: + def __init__(self, kernel, raw_index, is_store_index=False, is_index_expr=False): + self.index = raw_index.subs(V.graph.sizevars.var_to_val) + self.kernel = kernel + self.tiling_axis = [x.symbol() for x in self.kernel.tiling_axis] + self.stride_list = None # stride list [1,2,4,24] + self.reshape_sizes = [] # [RBLOCK, 1, 1, XBLOCK_SUB] + self.broadcast_sizes = [] # [RBLOCK, XBLOCK_SUB] + self.permute_shape = [] # [0,2,1,3] + self.var_replacements = {} # r2 ->r2_0, etc + self.var_directions = {} # r2_0 -> [None,:,None] + self.similar = None # (r,x,z,y) + self.need_permute = False + self.need_broadcast = False + self.need_reshape = False + self.gold = kernel.golden_var_list # tuple([x.symbol() for x in reversed(kernel.tiling_axis)]) + 
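+        # Decompose the flattened index into (variable, stride) pairs sorted by stride.
+        # Hypothetical example: an index x0 + 32*x1 + 1024*r2 with tiling axes (x0, x1, r2)
+        # gives var_list = (x0, x1, r2) and stride_list = (1, 32, 1024); non-tiling
+        # variables are dropped from both lists.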
self.var_stride = [ + (key, coeff) + for key, coeff in self.index.as_coefficients_dict().items() + if not isinstance(key, sympy.Integer) + ] + # sort by stride + self.var_stride.sort(key=lambda x: x[1]) + # only contains tiing axis var + self.var_list = tuple([x[0] for x in self.var_stride if x[0] in self.tiling_axis]) + self.stride_list = tuple([x[1] for x in self.var_stride if x[0] in self.tiling_axis]) + self.is_store_index = is_store_index + self.is_index_expr = is_index_expr + + def get_most_similar_shape(self): + matched_dims = 0 + self.similar = None + for value in self.kernel.index_analysis.keys(): + if len(value) != len(self.gold): + continue + i = 0 + while i < len(self.var_list): + if value[i] == self.var_list[i]: + i = i + 1 + else: + break + + if i > matched_dims: + matched_dims = i + self.similar = value + return self.similar + + @classmethod + def same_var_list(cls, var1, var2): + if len(var1) != len(var2): + return False + for i, v in enumerate(var1): + if v != var2[i]: + return False + return True + + def shrink_permute_shape(self, permute_shape): + diff = len(self.gold) - len(self.kernel.tiling_axis) + new_shape = [x for x in permute_shape if x - diff >= 0] + return new_shape + + def analyze_permute_shape(self): + if self.is_index_expr: + return + if self.gold == self.similar: + self.need_permute = False + return + + similar = tuple(reversed(self.similar)) + gold = tuple(reversed(self.gold)) + self.permute_shape = [None] * len(gold) + + if self.is_store_index: + for i, x in enumerate(similar): + if x != gold[i]: + index = gold.index(x) + self.permute_shape[i] = index + self.need_permute = True + else: + self.permute_shape[i] = i + return + + for i, x in enumerate(gold): + if x != similar[i]: + index = similar.index(x) + self.permute_shape[i] = index + self.need_permute = True + else: + self.permute_shape[i] = i + + def analyze_broadcast_sizes(self): + if not self.need_reshape: + self.need_broadcast = False + return + self.need_broadcast = True + reversed_similar = reversed(self.similar) + similar = [x for x in reversed_similar] + self.broadcast_sizes = ["1"] * len(similar) + for i, x in enumerate(similar): + self.broadcast_sizes[i] = f"{x.name.upper()}BLOCK_SUB" + + def analyze_reshape_sizes(self): + if all(x in self.var_list for x in self.tiling_axis): + self.need_reshape = False + return + self.need_reshape = True + reversed_similar = reversed(self.similar) + similar = [x for x in reversed_similar] + var_list = [x for x in reversed(self.var_list)] + self.reshape_sizes = ["1"] * len(similar) + for _, x in enumerate(var_list): + index = similar.index(x) + self.reshape_sizes[index] = f"{x.name.upper()}BLOCK_SUB" + + def analyze_var_direction(self): + if self.var_list == self.gold: + return + var_list = self.var_list if len(self.var_list) == len(self.gold) else self.similar + if var_list == self.gold: + return + if not var_list: + return + var_list = list(reversed(var_list)) + gold = list(tuple(reversed(self.gold))) + if len(var_list) != len(gold): + raise RuntimeError("assert var_list and gold must have same length") + var_list = [x for x in var_list if x in self.kernel.tiling_axis] + gold = [x for x in gold if x in self.kernel.tiling_axis] + for i, x in enumerate(gold): + index = var_list.index(x) + if (index == i): + continue + new_var = sympy_index_symbol(f"{x}") if self.is_index_expr else sympy_index_symbol(f"{x}_{index}") + if new_var in self.var_replacements: + continue + direction = ["None"] * len(gold) + direction[index] = ":" + direction_str = 
f"[{','.join(direction)}]" + self.var_replacements[x] = new_var + self.var_directions[new_var] = direction_str + self.kernel.range_tree_nodes[x].var_directions[new_var] = direction_str + + def analyze_index(self): + if isinstance(self.index, sympy.Integer): + return + if not self.kernel.golden_var_list: + self.kernel.select_golden_varlist() + self.gold = self.kernel.golden_var_list + + if self.gold is None: + raise RuntimeError("assert gold must not be None") + if len(self.gold) != len(self.tiling_axis): + raise RuntimeError("assert gold must have same length as tiling_axis") + + def all_tiling_in_var_list(): + return all([x in self.var_list for x in self.tiling_axis]) + # 2 analyze permute shape for full_dim_len index + + if all_tiling_in_var_list(): + self.similar = self.var_list + self.analyze_permute_shape() + if self.var_list not in self.kernel.index_analysis: + self.kernel.index_analysis[self.var_list] = self + # 3. analyze reshape and broadcast sizes + else: + pass + + # 4 analyze var direction + self.analyze_var_direction() + + def generate_statement(self): + statement = "" + if self.need_reshape: + reshape_sizes = f"[{','.join(self.reshape_sizes)}]" + statement = f".reshape({reshape_sizes})" + if self.need_broadcast: + broadcast_sizes = f"[{','.join(self.broadcast_sizes)}]" + statement = f"{statement}.broadcast_to({broadcast_sizes})" + if self.need_permute: + statement = f"{statement}.permute({self.permute_shape})" + return statement + + +class ReductionAnalysis: + def __init__(self, kernel): + self.kernel = kernel + self.reduction = None + self.reduced_dim = None + if self.numof_reduction_axis() > 1: + self.kernel.persistent_reduction = True + self.reduced_dim = 0 + return + + reduction = self.kernel.find_reduction_node() + if reduction is None or not isinstance(reduction, ir.Reduction): + raise RuntimeError("failed to get one reduction node") + if not hasattr(reduction, "reduced_idx"): + raise RuntimeError("reduction node doesn't have attr reduced_idx") + self.reduction = reduction + self.reduced_dim = self.analyze_reduction_dim() + + def is_higher_order_reduction(self): + return self.dim < len(self.kernel.tiling_axis) - 1 + + def is_1d_reduction(self): + return self.kernel.numels["r"] > 1 and len(self.kernel.numels) == 1 + + def get_reduce_dim_reshape(self, reduce_axis): + if self.is_1d_reduction(): + shape_str = f"[{reduce_axis.name.upper()}BLOCK_SUB]" + else: + shape = ["1"] * len(self.kernel.tiling_axis) + shape[self.reduced_dim] = f"{reduce_axis.name.upper()}BLOCK_SUB" + shape_str = f"[{','.join(shape)}]" + return shape_str + + def dense_size_list(self) -> List[str]: + sizes = [f"{x.name.upper()}BLOCK_SUB" for x in self.kernel.tiling_axis] + if self.numof_reduction_axis() > 1: + return sizes + + reduce_axis = self.kernel.tiling_axis[-1] + sizes.pop(-1) + sizes.insert(self.reduced_dim, f"{reduce_axis.name.upper()}BLOCK_SUB") + return sizes + + def dense_size_str(self): + sizes = self.dense_size_list() + if self.numof_reduction_axis() > 1: + return f"[{'* '.join(sizes)}]" + return f"[{', '.join(sizes)}]" + + def numof_reduction_axis(self): + return self.kernel.numof_reduction_axis() + + def reduction_axis_list(self): + return self.kernel.reduction_axis_list() + + def analyze_reduction_dim(self): + + if self.numof_reduction_axis() > 1: + self.kernel.persistent_reduction = True + self.reduced_dim = 0 + return 0 + + if not self.kernel.golden_var_list: + self.kernel.select_golden_varlist() + if self.kernel.golden_var_list is None: + raise RuntimeError("assert 
self.kernel.golden_var_list is not None") + + dim = -1 + for i, x in enumerate(reversed(self.kernel.golden_var_list)): + if x.name[0] == 'r': + dim = i + break + return dim + + def analyze_reduction_dim1(self): + if self.numof_reduction_axis() > 1: + self.kernel.persistent_reduction = True + self.reduced_dim = 0 + return 0 + reduction = self.reduction + # kept = [0,1,3], reduced = [2] + for i, x in enumerate(reduction.reduced_idx): + if reduction.reduction_ranges[i] <= 1: + continue + reduced_idx = x + break + # the index (in reduction.ranges) of low_dims + low_dims = [i for i, x in enumerate(reduction.kept_idx) if x > reduced_idx] + if not low_dims: + return len(self.kernel.tiling_axis) - 1 + elif len(low_dims) == len(reduction.kept_idx): + return 0 + # reduction dim when low_dims are not meraged + dim = len(reduction.kept_idx) - len(low_dims) + + tiling_axis = self.kernel.tiling_axis[:-1] + merged = 1 + j = len(tiling_axis) - 1 + # remove all low_dims from tiling_axis + # all axis before ahead of j are high-orders + # then following is reduced dim + ranges = [x for x in reduction.ranges if x > 1] + for i in reversed(low_dims): + len_axis = tiling_axis[j].length + len_reduction = ranges[i] * merged + if len_reduction < len_axis: + merged = merged * len_reduction + elif len_reduction == len_axis: + j = j - 1 + merged = 1 + else: + raise RuntimeError(f"assert should not reach here low_dims({i})={len_reduction}, axis[{j}]=len)") + dim = j + 1 + return dim diff --git a/torch_npu/_inductor/codegen/npu_kernel_features.py b/torch_npu/_inductor/codegen/npu_kernel_features.py new file mode 100644 index 0000000000000000000000000000000000000000..2e85b17c3d9c46d4303e53ccdeedd46a0853a0a5 --- /dev/null +++ b/torch_npu/_inductor/codegen/npu_kernel_features.py @@ -0,0 +1,104 @@ +import functools +from typing import Iterable +from typing import Iterable +from typing import Tuple, List +import sympy +import torch +from torch._inductor.codegen.simd import SIMDScheduling +from torch._inductor.codegen.simd_kernel_features import SIMDKernelFeatures, NodeScheduleEntry +from torch._inductor.utils import cache_on_self +from torch._inductor.virtualized import V +from torch.utils._ordered_set import OrderedSet + + +class NumelList(Tuple): + + @staticmethod + def calc_numels(other): + if isinstance(other, Iterable): + numel = NumelList.calc_numels(other) + return numel + elif isinstance(other, NumelList): + return other.numels() + else: + return other + + def numels(self): + numel = functools.reduce(lambda a, b: a * b, self, 1) + return numel + + def __eq__(self, other): + numel = self.numels() + numel2 = other.numels() if isinstance(other, NumelList) else other + return numel == numel2 + + def __le__(self, other): + numel = self.numels() + numel2 = other.numels() if isinstance(other, NumelList) else other + return numel <= numel2 + + def __lt__(self, other): + numel = self.numels() + numel2 = other.numels() if isinstance(other, NumelList) else other + return numel < numel2 + + def __ge__(self, other): + numel = self.numels() + numel2 = other.numels() if isinstance(other, NumelList) else other + return numel >= numel2 + + def __gt__(self, other): + numel = self.numels() + numel2 = other.numels() if isinstance(other, NumelList) else other + return numel > numel2 + + def __mod__(self, other): + numel = self.numels() + numel2 = other.numels() if isinstance(other, NumelList) else other + return numel % numel2 + + def __truediv__(self, other): + numel = self.numels() + numel2 = other.numels() if isinstance(other, 
NumelList) else other + return numel / numel2 + + def __floordiv__(self, other): + numel = self.numels() + numel2 = other.numels() if isinstance(other, NumelList) else other + return numel // numel2 + + def __mul__(self, other): + numel = self.numels() + numel2 = other.numels() if isinstance(other, NumelList) else other + return numel * numel2 + + def __rmul__(self, other): + numel = self.numels() + numel2 = other.numels() if isinstance(other, NumelList) else other + return numel * numel2 + + def __add__(self, other): + numel = self.numels() + numel2 = other.numels() if isinstance(other, NumelList) else other + return numel + numel2 + + def __radd__(self, other): + numel = self.numels() + numel2 = other.numels() if isinstance(other, NumelList) else other + return numel + numel2 + + def __hash__(self): + return super(NumelList, self).__hash__() + + +class NPUKernelFeatures(SIMDKernelFeatures): + def __init__( + self, + node_schedule: List[NodeScheduleEntry], + numel: sympy.Expr, + reduction_numel: sympy.Expr = sympy.S.One, + ): + super().__init__(node_schedule, numel, reduction_numel) + self.numel = NumelList(self.numel) if isinstance(self.numel, Iterable) else self.numel + self.reduction_numel = NumelList(self.reduction_numel) if isinstance(self.reduction_numel, + Iterable) else self.reduction_numel diff --git a/torch_npu/_inductor/codegen/scheduling.py b/torch_npu/_inductor/codegen/scheduling.py new file mode 100644 index 0000000000000000000000000000000000000000..8a77ac9a16ceb49c0208b16df0933faba35420ee --- /dev/null +++ b/torch_npu/_inductor/codegen/scheduling.py @@ -0,0 +1,595 @@ +import collections +import contextlib +import itertools +import functools +import os +from typing import Dict, Sequence, List, Iterable, Any, Union +import sympy +import torch +from torch._dynamo.utils import counters +from torch._inductor import scheduler, metrics +from torch._inductor.codecache import code_hash +from torch._inductor.codegen.multi_kernel import MultiKernel +from torch._inductor.codegen.simd import DisableReduction, EnableReduction, SIMDKernelFeatures, SIMDKernel +from torch._inductor.codegen.simd import schedule_log, scheduler, WhyNoFuse +from torch._inductor.codegen.triton import (TritonScheduling, log, config) +from torch._inductor.codegen.triton import ( + TritonScheduling, + config, + schedule_log, + get_fused_kernel_name, + get_kernel_category_by_source_code, + Placeholder, + get_kernel_metadata, + get_path, + IndentedBuffer +) +from torch._inductor.utils import sympy_index_symbol, ModularIndexing, FloorDiv, sympy_product +from torch._inductor.virtualized import V +from torch.fx.immutable_collections import immutable_dict +from torch._inductor.dependencies import MemoryDep, StarDep, WeakDep +from torch.utils._ordered_set import OrderedSet +from torch._inductor.codegen.simd import CandidateTiling + +from .triton import NPUIndexTritonKernel, flatten +from .kernel_analysis import ReductionAnalysis +from .npu_kernel_features import NumelList, NPUKernelFeatures +from .split_tiling import SplitTiling +from .triton import NPUIndexTritonKernel +from .. 
import config as npu_config +from ..lowering_fx import ( + create_fx_from_snodes_by_traced_graph, + create_compile_kwargs, + generate_fx_graph_code, + dump_fx_graph_code +) + +from ..config import log + + +def flatten_groups(nums): + res = [] + for i in nums: + if isinstance(i, Iterable): + for x in i: + res.append(x) + else: + res.append(i) + return res + + +@classmethod +def create_tiling( + cls, pw_tiling: Sequence[sympy.Expr], reduction_tiling: Sequence[sympy.Expr] +) -> Dict[str, sympy.Expr]: + """ + Create a tiling dict from pointwise and reduction splits. + """ + + pw_tiling = flatten_groups(pw_tiling) + pw_prefixes = ["w", "v", "t", "z", "y", "x"][-len(pw_tiling):] + if len(reduction_tiling) == 0: + reduction_prefixes = [] + else: + reduction_tiling = flatten_groups(reduction_tiling) + reduction_tiling = [NumelList(reduction_tiling).numels()] + reduction_prefixes = ["r"][: len(reduction_tiling)] + tiling = immutable_dict( + list(zip(pw_prefixes, pw_tiling)) + + list(zip(reduction_prefixes, reduction_tiling))) + return tiling + + +class NPUTritonScheduling(TritonScheduling): + def __init__(self, input_scheduler): + super().__init__(input_scheduler) + self.kernel_type = NPUIndexTritonKernel + + def create_kernel_choices( + self, kernel_features: SIMDKernelFeatures, kernel_args, kernel_kwargs + ) -> List[SIMDKernel]: + + return [ + self.kernel_type( + *kernel_args, + **kernel_kwargs, + ) + ] + + # transform indexing before call codegen_node_schedule_with_kernel + def codegen_node_schedule(self, kernel_features: SIMDKernelFeatures, nodes): + node_schedule = kernel_features.node_schedule + tiling = self.select_tiling( + node_schedule, kernel_features.numel, kernel_features.reduction_numel + ) + + kernels = self.create_kernel_choices( + kernel_features, [tiling], {"features": kernel_features} + ) + kernel = kernels[0] + setattr(kernel, "node_schedule", node_schedule) + self.decide_codegen_dims_in_kernel(node_schedule, kernel) + + for kernel in kernels: + self.codegen_node_schedule_with_kernel(node_schedule, kernel) + + MultiKernel.merge_workspaces_inplace(kernels) + for kernel in kernels: + with V.set_kernel_handler(kernel): + src_code = kernel.codegen_kernel() + + V.graph.removed_buffers |= kernel.removed_buffers + V.graph.inplaced_to_remove |= kernel.inplaced_to_remove + + traced_graph_hash = None + if npu_config.dump_fx_graph: + if not npu_config.traced_fx_graph_cache: + npu_config.traced_fx_graph_cache = os.path.join(os.getenv("TORCHINDUCTOR_CACHE_DIR"), + 'traced_fx_graph_cache') + os.makedirs(npu_config.traced_fx_graph_cache, exist_ok=True) + traced_graph, fx_call_args, fx_args, compile_kwargs = create_fx_from_snodes_by_traced_graph(nodes) + if traced_graph is None: + log.warning(f"For nodes {nodes}, could not gen fx graph while dump-graph.") + else: + traced_graph_hash = code_hash(traced_graph.print_readable(print_output=False) + torch.__version__) + + kernel_name, src_code = self.define_kernel(src_code, node_schedule, kernel, traced_graph_hash) + + kernel.kernel_name = kernel_name + kernel.code_hash = code_hash(src_code) + del kernel + + final_kernel: Union[SIMDKernel, MultiKernel] + if len(kernels) > 1: + final_kernel = MultiKernel(kernels) + else: + (final_kernel,) = kernels + + with V.set_kernel_handler(final_kernel): + for node in kernel_features.scheduler_nodes(): + node.mark_run() + + self.codegen_comment(node_schedule) + final_kernel.call_kernel(final_kernel.kernel_name) + + if npu_config.dump_fx_graph and traced_graph is not None: + new_compile_kwargs = 
create_compile_kwargs(final_kernel, fx_call_args, fx_args) + if new_compile_kwargs: + compile_kwargs |= new_compile_kwargs + fx_dump_path = os.path.join(npu_config.traced_fx_graph_cache, traced_graph_hash) + os.makedirs(fx_dump_path, exist_ok=True) + fx_code = generate_fx_graph_code(traced_graph.code, src_code, kernel_name, compile_kwargs) + dump_fx_graph_code(fx_code, fx_dump_path, traced_graph_hash) + + if config.nan_asserts: + final_kernel.codegen_nan_check() + if config.warn_mix_layout: + final_kernel.warn_mix_layout(kernels[0].kernel_name) + + V.graph.removed_buffers |= final_kernel.removed_buffers + V.graph.inplaced_to_remove |= final_kernel.inplaced_to_remove + + if ( + V.graph.wrapper_code.supports_intermediate_hooks + and config.generate_intermediate_hooks + ): + # Not every node in the schedule will actually be live on output; + # we can't check dead buffers. + live_outs = kernels[0].args.live_output_buffers() + for node in kernel_features.scheduler_nodes(): + name = node.get_name() + if name not in live_outs: + continue + if node.node is None: + raise RuntimeError("assert node.node is not None") + + origin_node = node.node.get_origin_node() + if origin_node is not None: + counters["inductor"]["intermediate_hooks"] += 1 + V.graph.wrapper_code.writeline( + f"run_intermediate_hooks({origin_node.name!r}, {name})" + ) + + self.scheduler.free_buffers() + + def define_kernel(self, src_code, node_schedule, kernel, traced_graph_hash: str): + wrapper = V.graph.wrapper_code + if (src_code, traced_graph_hash) in wrapper.src_to_kernel: + kernel_name = wrapper.src_to_kernel[(src_code, traced_graph_hash)] + if npu_config.dump_fx_graph: + src_code = src_code.replace(str(Placeholder.DESCRIPTIVE_NAME), kernel_name) + subs_name = kernel_name if config.triton.unique_kernel_names else "triton_" + src_code = src_code.replace(str(Placeholder.KERNEL_NAME), subs_name) + if traced_graph_hash: + src_code = src_code.replace('TRACED_GRAPH_HASH', traced_graph_hash) + src_code = src_code.replace('TRACED_GRAPH_DIR', npu_config.traced_fx_graph_cache) + else: + fused_name = ( + get_fused_kernel_name(node_schedule, config.triton.descriptive_names) + if config.triton.descriptive_names + else "" + ) + kernel_category = get_kernel_category_by_source_code(src_code)[:3] + kernel_name = "_".join( + ["triton", kernel_category, fused_name, wrapper.next_kernel_suffix()] + ) + # use the original src_code as the key + wrapper.src_to_kernel[(src_code, traced_graph_hash)] = kernel_name + subs_name = kernel_name if config.triton.unique_kernel_names else "triton_" + + # DESCRIPTIVE_NAME is used for profiling purposes; it shows the full kernel name + # even when unique_kernel_names is turned off. Meanwhile, KERNEL_NAME is sometimes set + # to "triton_" to maximize caching opportunities (when unique_kernel_names = False). 
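+            # Besides the stock DESCRIPTIVE_NAME / KERNEL_NAME substitutions, the NPU build
+            # also embeds TRACED_GRAPH_HASH and TRACED_GRAPH_DIR so a compiled kernel can
+            # locate its dumped FX graph under npu_config.traced_fx_graph_cache.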
+ src_code = src_code.replace(str(Placeholder.DESCRIPTIVE_NAME), kernel_name) + src_code = src_code.replace(str(Placeholder.KERNEL_NAME), subs_name) + if traced_graph_hash: + src_code = src_code.replace('TRACED_GRAPH_HASH', traced_graph_hash) + src_code = src_code.replace('TRACED_GRAPH_DIR', npu_config.traced_fx_graph_cache) + + src_code = src_code.replace("#pragma CMT", "#") + + basename, _, kernel_path = get_path(code_hash(src_code.strip()), "py") + + compile_wrapper = IndentedBuffer() + compile_wrapper.writeline(f"async_compile.triton({subs_name!r}, '''") + compile_wrapper.splice(src_code, strip=True) + current_device = V.graph.get_current_device_or_throw() + compile_wrapper.writeline(f"''', device_str='{current_device.type}')") + + metadata_comment = f"# kernel path: {kernel_path}" + origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper) + metadata_comment += "\n" + origins + "\n" + detailed_origins + # Extra debug message for npu. + snode_str = "" + snodes = [node for node in node_schedule if node not in (DisableReduction, EnableReduction)] + snode_str = f"\n# SchedulerNodes: {snodes}" + metadata_comment += snode_str + "\n" + if npu_config.dump_fx_graph: + from ..lowering_fx import snodes_to_fx + gm = snodes_to_fx.get(str(snodes), "") + gm_str = "\n# Graph Module str:\n" + gm_str += "\n".join([f"# {line}" for line in gm.split("\n")]) + metadata_comment += gm_str + "\n" + + wrapper.define_kernel( + kernel_name, compile_wrapper.getvalue(), metadata_comment + ) + + # log kernel metadata for offline analysis. + # E.g. one can find all unaligned inner reduction and check if + # padding helps with the perf kernel by kernel. + if metrics.is_metric_table_enabled("kernel_metadata"): + metrics.log_kernel_metadata(kernel_name, kernel_path, src_code) + + return kernel_name, src_code + + def codegen_node( + self, node: Union[scheduler.FusedSchedulerNode, scheduler.SchedulerNode] + ): + """ + Given a set of pre-fused nodes, generate a Triton kernel. + """ + + nodes: List[scheduler.SchedulerNode] = node.get_nodes() # type: ignore[assignment] + _, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group + + node_schedule = self.generate_node_schedule(nodes, numel, rnumel) + schedule_log.debug("Schedule:\n %s", node_schedule) + + return self.codegen_node_schedule( + NPUKernelFeatures(node_schedule, numel, rnumel), nodes + ) + + def can_fuse(self, node1, node2): + """ + Hook called by Scheduler to determine if the Triton backend + can fuse node1 and node2. These nodes might already be + FusedSchedulerNodes. 
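+        Reductions fuse only when their (numel, rnumel) pairs match; pointwise nodes must
+        additionally produce compatible tilings when tiling_prevents_pointwise_fusion is set.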
+ """ + if isinstance(node1, scheduler.ForeachKernelSchedulerNode) or isinstance( + node2, scheduler.ForeachKernelSchedulerNode + ): + return scheduler.ForeachKernelSchedulerNode.can_fuse(node1, node2) + + _, (numel1, rnumel1) = node1.group + _, (numel2, rnumel2) = node2.group + why = WhyNoFuse(node1, node2) + + if node1.is_split_scan() and not node2.is_split_scan(): + if node2.is_reduction(): + why("Split scan cannot fuse with reductions") + elif node2.is_split_scan() and not node1.is_split_scan(): + if node1.is_reduction(): + why("Split scan cannot fuse with reductions") + + if node1.is_reduction() and node2.is_reduction(): + reduction_can_fuse = numel1 == numel2 and rnumel1 == rnumel2 + if not reduction_can_fuse: + why( + "numel/rnumel mismatch (reduce) (%s, %s), (%s, %s)", + numel1, + numel2, + rnumel1, + rnumel2, + ) + return reduction_can_fuse + + if not node1.is_reduction() and not node2.is_reduction(): + if not (numel1 == numel2 and rnumel1 == rnumel2): + if not node2.is_template(): + why( + "numel/rnumel mismatch (non-reduce) (%s, %s), (%s, %s)", + numel1, + numel2, + rnumel1, + rnumel2, + ) + return False + else: + # prologue fusion input sizes differ from output group + # fuse so long as this node matches the group of existing prologue nodes + for node in node2.get_nodes(): + # dont need to check epilogue nodes for prologue fusion, break after template + if node.is_template(): + break + # we would have already restricted prologue from fusing if it had multiple + # uses, so it must be fusing into this node + if not node.used_buffer_names() & node1.get_buffer_names(): + continue + _, (pro_numel, pro_rnumel) = node.group + if not (numel1 == pro_numel and rnumel1 == pro_rnumel): + why( + "numel/rnumel mismatch prologue mismatch (%s, %s), (%s, %s)", + numel1, + pro_numel, + rnumel1, + pro_rnumel, + ) + return False + + for n in (node1, node2): + if n.is_template(): + return True + + # check for a bad combined tiling + tiling1 = self.select_tiling(node1.get_nodes(), numel1, rnumel1) + tiling2 = self.select_tiling(node2.get_nodes(), numel1, rnumel1) + tiling3 = self.select_tiling( + node1.get_nodes() + node2.get_nodes(), numel1, rnumel1 + ) + if config.triton.tiling_prevents_pointwise_fusion: + cond = True + if len(tiling1) > 2: + if len(tiling2) > 2: + cond = tiling1 == tiling2 == tiling3 + else: + cond = tiling1 == tiling3 + elif len(tiling2) > 2: + cond = tiling2 == tiling3 + if not cond: + why( + "tiling mismatch (%s, %s, %s)", + tiling1, + tiling2, + tiling3, + ) + return False + + return True + + if not node1.is_reduction() and node2.is_reduction(): + if not (rnumel1 == 1 and rnumel2 != 1): + raise AssertionError + if numel1 == numel2 * rnumel2: + if not all( + SIMDKernel.is_compatible((numel2, rnumel2), n.get_ranges()) + for n in node1.get_nodes() + ): + why("nodes numel/rnumel incompatibility") + return False + if ( + config.triton.tiling_prevents_reduction_fusion + and not node1.is_template() + ): + is_reduction_tiling_valid = tuple( + self.select_tiling(node1.get_nodes(), numel1).values() + ) in ( + (numel1, 1), + (numel2, rnumel2, 1), + numel1, + ) + if not is_reduction_tiling_valid: + why("invalid tiling for reduction") + return is_reduction_tiling_valid + return True + + if numel1 != numel2: + why("nodes numel incompatibility") + return numel1 == numel2 + + if not (node1.is_reduction() and not node2.is_reduction()): + raise AssertionError + # swap args to hit the case above + return self.can_fuse_horizontal(node2, node1) + + can_fuse_vertical = can_fuse + 
can_fuse_horizontal = can_fuse + + def decide_codegen_dims_in_kernel(self, node_schedule, kernel): + def current_reduction_nodes(nodes): + return itertools.takewhile(lambda n: n is not DisableReduction, nodes) + + with kernel: + # 1. transform dims: create new dims to substitute floor_divide and modular expression + stack = contextlib.ExitStack() + for _, node in enumerate(node_schedule): + if node is DisableReduction: + stack.enter_context(kernel.disable_reduction()) + elif node is EnableReduction: + stack.close() + else: + index_vars = kernel.split_and_set_ranges(node.get_ranges()) + node._body.transform_dims_in_indexing(index_vars) + # 2. go through range_tree_nodes to findout, to find one axis could be substituted by others + self.additional_nodes_to_be_subs(kernel, kernel.range_tree_nodes_substituted) + # 3.do the substitution on all indexing + for node in node_schedule: + if node in (EnableReduction, DisableReduction): + continue + indexing = node._body.indexing + node._body.substituted_dims_in_indexing(indexing, kernel, kernel.range_tree_nodes_substituted) + + # 4.remove the substituted dims from kernel + for var, _ in kernel.range_tree_nodes_substituted.items(): + if (var in kernel.range_tree_nodes): + root = kernel.range_tree_nodes[var].parent + root.remove_entry(var) + # select split and tiling axis + split_tiling = SplitTiling(kernel) + split_tiling.select_split_tiling_axis() + kernel.load_store_indexing = split_tiling.indexing + # ReductionAnalysis depends on kernel.load_store_indexing + if kernel.inside_reduction: + kernel.reduce_analysis = ReductionAnalysis(kernel) + + def additional_nodes_to_be_subs(self, kernel, node_to_be_substituted): + for node in kernel.range_tree_nodes.values(): + if node.expr != sympy_index_symbol(f"{node.parent.prefix}index") \ + or len(node.parent.var_ranges) == 1 \ + or node.symbol() in node_to_be_substituted: + continue + numel = sympy.Integer(1) + new_var_expr = sympy.Integer(0) + for k, s in node.parent.var_ranges.items(): + if k == node.symbol(): + continue + numel = numel * s + sub_node = kernel.range_tree_nodes[k] + new_var_expr = new_var_expr + sub_node.symbol() * sub_node.divisor + + if numel == node.length: + node_to_be_substituted[node.symbol()] = [(node.length, new_var_expr)] + else: + log.warning("sub nodes (expr%s, numel:%d) can not make up parent node(%s:%d)", + new_var_expr, numel, node.symbol(), node.length) + + @classmethod + @functools.lru_cache(32) + def candidate_tilings(cls, node, numel, reduction_numel) -> list[CandidateTiling]: + """ + The main difference from gpu is default tiling, npu needs non-collapse ranges. 
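+        Each memory dependency with a unit-stride dimension proposes a split right after
+        that dimension; candidates are scored by the number of non-broadcast elements they
+        touch, with contiguous writes and good tile sizes weighted higher.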
+ """ + is_pointwise = reduction_numel == 1 + + def assert_true(cond, msg=""): + if not cond: + raise AssertionError(msg) + + def tile_ranges(is_pointwise: bool, ranges, rw) -> list[CandidateTiling]: + assert_true(len(rw.range_vars) == len(ranges), f"{rw.range_vars=} {ranges=}") + + dep_sources = [rw.reads, rw.writes] + assert_true(all( + isinstance(dep, (MemoryDep, StarDep)) + for dep in itertools.chain.from_iterable(dep_sources) + )) + deps = [ + dep + for dep in itertools.chain.from_iterable(dep_sources) + if dep.name not in V.graph.removed_buffers + and isinstance(dep, MemoryDep) + ] + write_names = OrderedSet([dep.name for dep in rw.writes]) + + def collapse_ranges(ranges: Sequence[sympy.Expr]) -> sympy.Expr: + return V.graph.sizevars.simplify(sympy_product(ranges)) + + tilings = [ + CandidateTiling( + tiling=cls.create_partial_tiling( + ranges, is_pointwise + ), + name="none", + score=0, + ) + ] + + for dep in deps: + strides = V.graph.sizevars.stride_hints(dep.index, rw.range_vars) + assert_true(len(strides) == len(ranges)) + try: + split = strides.index(1) + 1 + if split == len(ranges): + continue + if all(s == 0 for s in strides[split:]): + continue + + except ValueError: + continue + + tiled_groups = ( + collapse_ranges(ranges[:split]), + collapse_ranges(ranges[split:]), + ) + + # score by number of elements + score = V.graph.sizevars.size_hint( + sympy_product( + size for size, stride in zip(ranges, strides) if stride != 0 + ) + ) + if dep.name in write_names: + # ngimel said contiguous writes is more important than reads + score *= 2 + if CandidateTiling.is_good_size(tiled_groups[0]): + score *= 2 + if CandidateTiling.is_good_size(tiled_groups[1]): + score *= 2 + + if ( + V.graph.sizevars.size_hint( + score - sympy_product(itertools.chain(ranges, reduction_ranges)) + ) + >= 0 + ): + tilings.append( + CandidateTiling( + tiling=cls.create_partial_tiling( + [ + collapse_ranges(ranges[:split]), + collapse_ranges(ranges[split:]), + ], + reduction_numel, + ), + score=score, + name=dep.name, + ) + ) + + return tilings + + pointwise_ranges, reduction_ranges = node.get_ranges() + if len(pointwise_ranges) <= 1 and len(reduction_ranges) <= 1: + return [] + + # Tile either pointwise or reduction dims. + pointwise_ranges, reduction_ranges = node.get_ranges() + partial_tilings = tile_ranges( + is_pointwise, + pointwise_ranges if is_pointwise else reduction_ranges, + node.pointwise_or_reduction_read_writes(is_pointwise), + ) + + # Fill in the missing ranges. 
+ full_tilings = [ + CandidateTiling( + tiling=cls.complete_partial_tiling( + tiling.tiling, numel, reduction_numel + ), + score=tiling.score, + name=tiling.name, + ) + for tiling in partial_tilings + ] + + return full_tilings \ No newline at end of file diff --git a/torch_npu/_inductor/codegen/split_tiling.py b/torch_npu/_inductor/codegen/split_tiling.py new file mode 100644 index 0000000000000000000000000000000000000000..782cc9f7455cd6d9f4eea63075768fdeb1af0690 --- /dev/null +++ b/torch_npu/_inductor/codegen/split_tiling.py @@ -0,0 +1,283 @@ +from functools import reduce +import sympy as sympy +from torch._inductor.codegen.simd import (EnableReduction, DisableReduction) +from torch._inductor.codegen.triton import TritonKernel +from torch._inductor.loop_body import MemoryUsageType +from torch._inductor.runtime.runtime_utils import next_power_of_2 +from torch._inductor.utils import ModularIndexing, sympy_subs +from torch._inductor.virtualized import V + +from .kernel_analysis import IndexAnalysis +from .triton_utils import get_aligned_numel +from ..config import num_vector_core, log + + +# split and tiling axis selector +class SplitTiling: + def __init__(self, kernel: TritonKernel): + self.kernel = kernel + self.indexing = [] # load and store indexing among all scheduler nodes + kernel.sorted_axis = [x for x in kernel.range_tree_nodes.values()] + kernel.sorted_axis.sort(reverse=True, key=self.key) + for i, dim in enumerate(kernel.sorted_axis): + dim.sorted_order = i + + self.find_lowest_dimension() + self.should_outer_reduce = False + self.possible_need_permute = self.find_possible_permutes() + + def find_possible_permutes(self): + if len(self.kernel.low_dims) <= 1: + return False + var_lists = [] + low_dims = [self.kernel.sorted_axis[x].symbol() for x in self.kernel.low_dims] + for index in self.indexing: + var_stride = [ + (key, coeff) + for key, coeff in index.as_coefficients_dict().items() + if not isinstance(key, sympy.Integer) + ] + var_stride.sort(key=lambda x: x[1]) + var_list = tuple([x[0] for x in var_stride if x[0] in low_dims]) + var_lists.append(var_list) + for i, var_list in enumerate(var_lists): + if len(var_list) < len(low_dims): + continue + for j, other in enumerate(var_lists): + if i == j or len(other) < len(low_dims): + continue + if var_list != other: + return True + return False + + @classmethod + def key(cls, x): + # to be higher than x and y + if x.name[0] == 'w' or x.name[0] == 'v' or x.name[0] == 't': + return "zz" + x.name + # to be lower than floor_dir + elif isinstance(x.expr, ModularIndexing): + return x.name[0] + "0" + x.name[1:] + else: + return x.name + + @classmethod + def total_split_numels(cls, axis_list): + numels = [x.length for x in axis_list] + return reduce(lambda x, y: x * y, numels) if numels else 1 + + # Split 原则1 :先做维度合并,再切分 。通过维度合并降维降低split和tiling轴选择策略的复杂性 。 + # Split 原则2 : 切分轴尽量选择高维度的轴, 这样load/store 能够有比较好的线性度 , + # Split 原则3 : 规约轴和低维轴不应选为切分轴 。但如果高维规约类融合算子,而且高维尺寸非常大( >= 64KB),其他维度不足以支持切分,可以考虑对规约轴切分。 + # Split 原则4 :切分轴的总numel 要超过 aicore总数。切分轴的数量最好不要超过3个(triton 最多支持三维发射), 因此 如果一点要超, 需要维度合并。 + def select_split_axis(self): + self.kernel.split_axis.clear() + + # total numel exceed aicore or total split axis exceed 3 + def meet_stop_condition(): + if self.total_split_numels(self.kernel.split_axis) >= num_vector_core: + return True + if len(self.kernel.split_axis) == 3: + return True + return False + + def select_one_split_axis(not_reduction=True, not_low_dims=True): + for axis in self.kernel.sorted_axis: + if not_reduction and axis.prefix == "r": + 
continue + if not_low_dims and axis.sorted_order in self.kernel.low_dims: + continue + if axis in self.kernel.split_axis: + continue + axis.is_split_axis = True + return axis + return None + + count = 0 + while not meet_stop_condition(): + count += 1 + axis = select_one_split_axis(not_reduction=True, not_low_dims=True) + if axis is not None: + self.kernel.split_axis.append(axis) + continue + axis = select_one_split_axis(not_reduction=True, not_low_dims=False) + if axis is not None: + self.kernel.split_axis.append(axis) + continue + if count > 10: + break + + if not self.kernel.split_axis and self.kernel.sorted_axis: + self.kernel.split_axis.append(self.kernel.sorted_axis[0]) + + self.kernel.split_axis.sort(reverse=True, key=self.key) + for i, x in enumerate(self.kernel.split_axis): + x.split_order = i + + # Tiling 原则1:load / store 中索引表达式的中的低维轴都要成为tiling 轴. + # Tiling 原则2:对于规约算子,规约轴要成为tiling轴。 + # Tiling 原则3: 多维规约, 只有规约轴可以被选择为tiling轴 + # Tiling 原则4: tiling轴 要覆盖 total numel 的 80% + + # two tiling axis might be insufficient when there're 3 or more low-dims in indexing + def select_tiling_axis(self): + self.kernel.tiling_axis.clear() + + # cover the biggest axis and not exceed 3 axis + def meet_stop_condition(): + total_numel = reduce(lambda x, y: x + y, + map(lambda x: x.length, self.kernel.sorted_axis)) if self.kernel.sorted_axis else 1 + tiling_numel = reduce(lambda x, y: x + y, + map(lambda x: x.length, self.kernel.tiling_axis)) if self.kernel.tiling_axis else 1 + if self.kernel.numof_reduction_axis() > 1 and all( + self.kernel.range_tree_nodes[var].is_tiling_axis for var in self.kernel.reduction_axis_list()): + return True + # currently, the maximum dim that triton-ascend support is 2 + max_transpose_dims = 2 + if (self.possible_need_permute or tiling_numel / total_numel >= 0.8) and \ + len(self.kernel.tiling_axis) >= min(max_transpose_dims, len(self.kernel.sorted_axis)): + return True + return False + + def select_tiling(low_dim=True, reduction=True): + for axis in reversed(self.kernel.sorted_axis): + if low_dim and axis.sorted_order in self.kernel.low_dims and axis not in self.kernel.tiling_axis: + axis.is_tiling_axis = True + self.kernel.tiling_axis.append(axis) + if reduction and axis.prefix == 'r' and axis not in self.kernel.tiling_axis: + axis.is_tiling_axis = True + self.kernel.tiling_axis.append(axis) + if low_dim or reduction: + continue + # using principle 4, select one longest + longest = axis # self.find_longest_dimension(check_in_tiling = True) + if longest and longest not in self.kernel.tiling_axis: + self.kernel.tiling_axis.append(longest) + longest.is_tiling_axis = True + if meet_stop_condition(): + break + + select_tiling(low_dim=True, reduction=True) + count = 0 + while not meet_stop_condition(): + select_tiling(low_dim=False, reduction=False) + count += 1 + if count > 10: + break + self.kernel.tiling_axis.sort(reverse=True, key=self.key) + for i, x in enumerate(self.kernel.tiling_axis): + x.tiling_order = i + + def select_split_tiling_axis(self): + self.select_split_axis() + self.select_tiling_axis() + + # the below logic doesn't work when there're two reduction axis, but only one need outer reduction + def should_outer_reduce_me(self, x): + should_outer = self.kernel.is_higher_order_reduction(True) and SplitTiling.great_than(x.length, + 32768) and x.is_loop + if should_outer: + self.should_outer_reduce = True + self.kernel.split_axis = x + self.kernel.split_axis.is_split_axis = True + return should_outer + + def find_longest_dimension(self, check_in_tiling=False): + 
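+        # Return the axis with the largest length; with check_in_tiling=True, axes that are
+        # already tiling axes are skipped.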
longest = None + for axis in self.kernel.sorted_axis: + if (longest is None or axis.length > longest.length) and \ + (not check_in_tiling or axis not in self.kernel.tiling_axis): + longest = axis + return longest + + # return True when x is the low-dim in indexing + def is_lowest_dimension(self, x): + return x.sorted_order in self.kernel.low_dims + + def find_lowest_dimension(self): + def construct_low_dim(): + for index in self.indexing: + coefficients_dict = index.as_coefficients_dict() + for key, value in coefficients_dict.items(): + if not key.free_symbols: + continue + key = list(key.free_symbols)[0] + if key not in self.kernel.range_tree_nodes: + continue + + if value == sympy.Integer(1): + axis = self.kernel.range_tree_nodes[key] + self.kernel.low_dims.add(axis.sorted_order) + + # all read index should be considered + buf_names = [ + node.node.name + for node in self.kernel.node_schedule + if node not in (EnableReduction, DisableReduction) + ] + for node in self.kernel.node_schedule: + if node in (EnableReduction, DisableReduction): + continue + names = [] + + for read in node._body.memory_usage[MemoryUsageType.LOAD]: + name = read.index_name + arg = read.buffer_name + read_is_inptr = False if arg[:3] != 'arg' and arg in buf_names else True + if read_is_inptr: + names.append(name) + for key, index in node._body.indexing.items(): + if key in names and index not in self.indexing: + self.indexing.append(index) + + if self.kernel.inside_reduction: + construct_low_dim() + return + + # for non-reduction, write index should be considered + for node in self.kernel.node_schedule: + if node in (EnableReduction, DisableReduction): + continue + names = [] + for write in node._body.memory_usage[MemoryUsageType.STORE]: + names.append(write.index_name) + for write in node._body.memory_usage[MemoryUsageType.STORE_REDUCTION]: + names.append(write.index_name) + for key, index in node._body.indexing.items(): + if key in names and index not in self.indexing: + self.indexing.append(index) + + construct_low_dim() + + @staticmethod + def convert(x, y): + xnumel = x + ynumel = y + if isinstance(xnumel, (sympy.Symbol, sympy.Expr)) and not isinstance(xnumel, sympy.Integer): + xnumel = xnumel.subs(V.graph.sizevars.var_to_val) + + if isinstance(ynumel, (sympy.Symbol, sympy.Expr)) and not isinstance(ynumel, sympy.Integer): + ynumel = ynumel.subs(V.graph.sizevars.var_to_val) + + if isinstance(xnumel, sympy.Integer) and isinstance(ynumel, int): + ynumel = sympy.Integer(ynumel) + + if isinstance(ynumel, sympy.Integer) and isinstance(xnumel, int): + xnumel = sympy.Integer(xnumel) + + return (xnumel, ynumel) + + @staticmethod + def less_than(x, y): + xnumel, ynumel = SplitTiling.convert(x, y) + return xnumel < ynumel + + @staticmethod + def great_than(x, y): + xnumel, ynumel = SplitTiling.convert(x, y) + return xnumel > ynumel + + @staticmethod + def ge_than(x, y): + xnumel, ynumel = SplitTiling.convert(x, y) + return xnumel >= ynumel diff --git a/torch_npu/_inductor/codegen/tile_generator.py b/torch_npu/_inductor/codegen/tile_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..5ce883552d81ebdd74f06ad8cd90b1a45fcc0797 --- /dev/null +++ b/torch_npu/_inductor/codegen/tile_generator.py @@ -0,0 +1,253 @@ +import copy +import functools +import math +import sys +from torch._inductor.runtime.runtime_utils import next_power_of_2 +from torch._inductor.runtime.triton_heuristics import Config + +from .triton_utils import byte_per_numel +from .. 
import config + + +# generate tiling configs +class TileGenerator: + + def __init__(self, numels, axis_names, tiling_axis, split_axis, low_dims, persistent_reduction, + configs, dtype, dual_reduction=False): + self.numels = numels.copy() + + self.blocks = [x for x in self.numels] + self.candidate_blocks = [] + self.sub_blocks = self.blocks.copy() + self.axis_name = axis_names + self.tiling_axis = tiling_axis + self.split_axis = split_axis + self.low_dims = low_dims + self.configs = configs + self.dtype_bytes = self.get_byte_per_numel(dtype) + self.stop_numel = 1024 // self.dtype_bytes + self.block_name = {} + self.sub_block_name = {} + self.persistent_reduction = persistent_reduction + self.dual_reduction = dual_reduction + for axis, name in enumerate(self.axis_name): + if axis not in tiling_axis and axis not in split_axis: + self.blocks[axis] = 1 + self.sub_blocks[axis] = 1 + continue + if axis in self.split_axis: + self.block_name[axis] = f"{name.upper()}BLOCK" + if axis in self.tiling_axis: + self.sub_block_name[axis] = f"{name.upper()}BLOCK_SUB" + + def calcu_last_split_blocks(self, axis): + splits = 1 + for x in self.split_axis: + if x != axis: + splits = splits * ((self.numels[x] + self.blocks[x] - 1) // self.blocks[x]) + else: + break + + last_splits = config.num_vector_core // splits + last_blocks = (self.numels[axis] + last_splits - 1) // last_splits + return last_blocks + + + def aligned_numel(self, numel): + min_numel = 32 // self.dtype_bytes + if numel <= min_numel: + return numel + aligned = ((numel + min_numel - 1) // min_numel) * min_numel + return aligned + + @classmethod + def get_byte_per_numel(cls, dtype): + if dtype is None: + return 1 + return byte_per_numel[dtype] + + def valid_tile_numel(self, total_numel): + byte_num = self.dtype_bytes + max_numel = 16384 * 4 // byte_num + return total_numel <= max_numel + + def calculate_config_numel(self, cfg): + total_numel = 1 + for axis in self.tiling_axis: + total_numel = total_numel * cfg[self.sub_block_name[axis]] + return total_numel + + def calculate_total_numel(self): + smallest = sys.maxsize + + def calculate_total_numel_candi(blocks): + total_numel = 1 + for axis in self.tiling_axis: + total_numel = total_numel * self.sub_blocks[axis] + return total_numel + + for candi_blocks in self.candidate_blocks: + numel = calculate_total_numel_candi(candi_blocks) + if numel < smallest: + smallest = numel + return smallest + + def fill_config(self, cfg, blocks): + for axis in self.split_axis: + cfg[self.block_name[axis]] = blocks[axis] + for axis in self.tiling_axis: + tiling_numel = self.aligned_numel(self.sub_blocks[axis]) + cfg[self.sub_block_name[axis]] = tiling_numel + + def find_config(self, cfg): + for config_var in self.configs: + if config_var.kwargs == cfg: + return True + return False + + def add_to_configs(self, candi_block): + newcfg = {} + self.fill_config(newcfg, candi_block) + total_numel = self.calculate_config_numel(newcfg) + if self.valid_tile_numel(total_numel) and not self.find_config(newcfg): + self.configs.append(Config(newcfg, num_warps=1, num_stages=1)) + + def descend_one_axis(self, axis, is_split=False): + def calc_total_programs(): + grids = [] + for axis in self.split_axis: + numel = self.numels[axis] + block_size = self.blocks[axis] + programs = (numel + block_size - 1) // block_size + grids.append(programs) + + total_programs = functools.reduce(lambda x, y: x * y, grids) if grids else 1 + return total_programs + + reached_stop_numel = False + slow_decend_split = False + + while True: + total_numel = 
self.stop_numel + 100 + for candi_block in self.candidate_blocks: + self.add_to_configs(candi_block) + + # tile numel reached threshold + total_numel = self.calculate_total_numel() + if total_numel <= self.stop_numel: + self.add_to_configs(self.blocks) + reached_stop_numel = True + break + + numel = self.blocks[axis] if is_split else self.sub_blocks[axis] + if numel == 1: + self.add_to_configs(self.blocks) + break + + if is_split: + if self.persistent_reduction and self.axis_name[axis][0] == "r": + reached_stop_numel = True + break + total_programs = calc_total_programs() + if total_programs > config.num_vector_core: + last_blocks = self.calcu_last_split_blocks(axis) + if last_blocks != self.blocks[axis]: + self.blocks[axis] = last_blocks + self.candidate_blocks.append(tuple(self.blocks)) + break + if total_programs > config.num_vector_core // 2 or self.dual_reduction: + if len(self.candidate_blocks) > 2: + self.candidate_blocks.pop(0) + self.candidate_blocks.append(tuple(self.blocks)) + slow_decend_split = (total_programs > config.num_vector_core // 2) + + if not slow_decend_split: + self.blocks[axis] = numel // 2 + self.sub_blocks[axis] = self.blocks[axis] + else: + step = numel // 4 if numel // 4 > 1 else 1 + self.blocks[axis] = numel - step + self.sub_blocks[axis] = self.blocks[axis] + total_programs = calc_total_programs() + if self.blocks[axis] == 1 and (total_programs > config.num_vector_core // 2 or self.dual_reduction): + self.candidate_blocks.append(tuple(self.blocks)) + else: + if numel >= 32: + self.sub_blocks[axis] = next_power_of_2(numel // 2) + else: # numel >4 and numel < 128 : + numel = self.sub_blocks[axis] + self.sub_blocks[axis] = numel - 1 + return reached_stop_numel + + + def descend_all_low_dims(self): + low_dim_numels = [self.sub_blocks[x] for x in self.low_dims] + if not low_dim_numels: + return + + def descent_all_axis(min_numel): + for axis in self.low_dims: + if self.axis_name[axis][0] == "r" and self.persistent_reduction: + continue + numel = self.sub_blocks[axis] + if numel == 1: + continue + if min_numel > 1 and abs(numel - min_numel) / min_numel < 0.2: + continue + if numel >= 128: + self.sub_blocks[axis] = next_power_of_2(numel // 2) + else: # numel >4 and numel < 128 : + numel = self.sub_blocks[axis] + numel = numel // 2 + self.sub_blocks[axis] = min(self.aligned_numel(numel), next_power_of_2(numel)) + + count = 0 + total_numel = self.calculate_total_numel() + while total_numel > self.stop_numel and count < 100: + count += 1 + total_numel = self.calculate_total_numel() + for candi_block in self.candidate_blocks: + self.add_to_configs(candi_block) + min_numel = min(low_dim_numels) + descent_all_axis(min_numel) + total_numel_2 = self.calculate_total_numel() + if total_numel == total_numel_2: + descent_all_axis(0) + + return total_numel < self.stop_numel + + def descend_split_tiling(self): + + tiling_not_low_dims = [x for x in self.tiling_axis if x not in self.low_dims] + + def descend_split_axis(): + + for axis in self.split_axis: + if self.descend_one_axis(axis, is_split=True): + return True + + total = self.calculate_total_numel() + return total <= self.stop_numel + + def desceond_tiling_not_low_dims(): + for axis in tiling_not_low_dims: + if self.axis_name[axis][0] == "r" and self.persistent_reduction: + continue + if self.descend_one_axis(axis): + return True + total = self.calculate_total_numel() + return total <= self.stop_numel + + + while True: + # descend split axis + if descend_split_axis(): + break + if len(self.candidate_blocks) > 0: + 
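# reuse the first remaining candidate split as the starting sub-block sizes for the tiling descent +                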
self.sub_blocks = list(self.candidate_blocks[0]) + # descend tiling but not low dims + if desceond_tiling_not_low_dims(): + break + # descend low dims, need to descend all axis at the same time + self.descend_all_low_dims() + break diff --git a/torch_npu/_inductor/codegen/triton.py b/torch_npu/_inductor/codegen/triton.py new file mode 100644 index 0000000000000000000000000000000000000000..35c6deaae77d6a89efaa90e5bc6dc200769775a7 --- /dev/null +++ b/torch_npu/_inductor/codegen/triton.py @@ -0,0 +1,1954 @@ +import functools +import itertools +import operator +import os +import re +import textwrap +from enum import Enum +from typing import List, Set, Iterable, Callable, Sequence +from typing import ( + Optional, + Union, + Tuple, + Any, + cast, + Dict +) +import sympy +import torch +from torch._inductor import config, ir +from torch.utils._ordered_set import OrderedSet +from torch._inductor.codegen.common import ( + IndentedBuffer, + SizeArg, + DeferredLine, + ArgName +) +from torch._inductor.codegen.common import free_symbol_is_type +from torch._inductor.codegen.simd import CantSplit, DisableReduction, EnableReduction +from torch._inductor.codegen.triton import ( + IndexingOptions, + triton_reshape, + TritonCSEVariable, +) +from torch._inductor.ops_handler import OpsHandler +from torch._inductor.codegen.triton import ( + TritonKernel, + TritonKernelOverrides, + IterationRangesRoot, + IterationRangesEntry, + CSEVariable, + gen_common_triton_imports, + BlockPtrOptions, + triton_acc_type, + constant_repr, + is_welford_reduction, FixedTritonConfig, + prefix_is_reduction, upcast_acc_dtype, + get_kernel_category_by_source_code, + get_fused_kernel_name +) +from torch._inductor.codegen.triton_utils import config_of, signature_of, signature_to_meta +from torch._inductor.dtype_propagation import DtypePropagationOpsHandler +from torch._inductor.runtime.hints import ReductionHint +from torch._inductor.runtime.runtime_utils import next_power_of_2 +from torch._inductor.scheduler import SchedulerNode +from torch._inductor.utils import ( + Placeholder, + get_bounds_index_expr, + upcast_compute_type, + sympy_product +) +from torch._inductor.utils import sympy_index_symbol, generate_assert +from torch._inductor.utils import sympy_subs +from torch._inductor.virtualized import ( + V, + StoreMode, + ReductionType, + _ops as ops, +) +from torch.utils import _pytree as pytree +from torch.utils._sympy.functions import FloorDiv, Identity, ModularIndexing +from torch.utils._sympy.numbers import int_oo +from torch.utils._sympy.symbol import SymT, symbol_is_type +from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges +from torch._inductor.bounds import ValueRangeAnalysis +from torch._inductor.runtime import triton_heuristics + +from .. import config as inductor_npu_config + +from .kernel_analysis import IndexAnalysis, ReductionAnalysis +from .npu_kernel_features import NumelList +from ..runtime import NPUDeviceProperties +from .. 
import npu_triton_heuristics + + +def flatten(nums): + res = [] + for i in nums: + if isinstance(i, list): + res.extend(flatten(i)) + else: + res.append(i) + return res + + +class NPUTritonKernelOverrides(TritonKernelOverrides): + + @staticmethod + def exp(x): + return f"tl_math.exp({x})" + + @staticmethod + def sqrt(x): + return f"tl_math.sqrt({x})" + + @staticmethod + def tanh(x): + return f"tl_math.tanh({x})" + + @staticmethod + def rsqrt(x): + return f"tl.rsqrt({x})" + + @staticmethod + def floor(x): + return f"tl_math.floor({x})" + + @staticmethod + def erf(x): + return f"tl_math.erf({x})" + + @staticmethod + def ceil(x): + return f"tl_math.ceil({x})" + + @classmethod + def index_expr(cls, expr, dtype): + indexing = V.kernel.indexing(expr, block_ptr=False, is_index_expr=True) + if not isinstance(indexing, IndexingOptions): + raise TypeError(f"not a IndexingOptions : {indexing}") + + # Our sympy expr printing casts to the current kernel index dtype. + # we only respect non int32-int64 dtypes and otherwise use current kernel indexing dtype + index_dtype = torch.int32 if V.kernel.index_dtype == "tl.int32" else torch.int64 + dtype = dtype if dtype not in (torch.int32, torch.int64) else index_dtype + var = V.kernel.cse.generate( + V.kernel.compute, + indexing.index_str, + bounds=get_bounds_index_expr(expr), + dtype=dtype, + ) + + if dtype not in (torch.int32, torch.int64): + var = V.kernel.cse.generate( + V.kernel.compute, + cls.to_dtype(var, dtype), + dtype=upcast_compute_type(dtype), + ) + else: + # We are not always consistent in enforcing that the output of the index expr printing + # results in the indexing dtype. So if we detect that we have an input which might type promote + # to a dtype other than indexing dtype, add a cast. + # Trying to avoid + dtype = index_dtype + for index_var in expr.free_symbols: + if symbol_is_type(index_var, SymT.TMP): + dtype = torch.promote_types( + dtype, V.kernel.cse.varname_map[index_var.name].dtype + ) + + if dtype != index_dtype: + var = V.kernel.cse.generate( + V.kernel.compute, + cls.to_dtype(var, index_dtype), + dtype=index_dtype, + ) + + var.mask_vars = indexing.mask_vars + return var + + +def group_fn(self, sizes): + groups = list() + for s in sizes: + if not s: + groups.append(1) + elif isinstance(s, list): + group = flatten(s) + groups.append(NumelList(tuple(group)) if isinstance(group, list) else group) + else: + groups.append(s) + return tuple(groups) + + +@staticmethod +def select_index_dtype(node_schedule, numel, reduction_numel): + return "tl.int32" + + +class IterationRangesEntryNPUIndex(IterationRangesEntry): + def __init__( + self, + *args, **kwargs): + super().__init__(*args, **kwargs) + self.is_tiling_axis = False + self.is_split_axis = False + self.indexing_code = IndentedBuffer() + self.sorted_order = None + self.tiling_order = None + self.split_order = None + self.var_directions = {} + self.directions = [] + # don't use functools.lru_cache(None), so that previous indexing_code produdec by previous index, + # could be overwritten + self.codegen = self._codegen + + # axis mask + def _codegen_mask(self): + + if self.is_tiling_axis: + BLOCK_NAME = f"{self.name.upper()}BLOCK" + upper = f"min({BLOCK_NAME}+{self.symbol()}_offset, {self.name}_numel)" if self.is_split_axis else f"{self.name}_numel" + line = f"{self.name}_mask = {self.name} < {upper}" + self.writeline(line) + for var in self.var_directions.keys(): + line = f"{var.name}_mask = {var.name} < {upper}" + self.writeline(line) + else: + pass + + def get_axis_direction(self): + + 
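# Note: builds the broadcast subscript for this axis, e.g. "[:,None]" or "[None,:]", from + # its position among the tiling axes in (reversed) golden_var_list; cached in self.directions. +        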
# assume self.golden_var_list is to be correct axis order + + if self.directions: + return f"[{','.join(self.directions)}]" + tiling_axis = [x.symbol() for x in self.kernel.tiling_axis] + + rev_orders = [x for x in self.kernel.golden_var_list if x in tiling_axis] + self.directions = ["None"] * len(tiling_axis) + if len(tiling_axis) != len(rev_orders): + raise RuntimeError(f"assert tiling len={len(tiling_axis)}, not equal to golden varlist len ={len(rev_orders)}") + var_orders = list(reversed(rev_orders)) + index = var_orders.index(self.symbol()) + self.directions[index] = ":" + return f"[{','.join(self.directions)}]" + + # axis var, need to define var with diffent direction + def _codegen(self): + self.indexing_code.clear() + index = None + # for multiple reduce dims, don't need this + if not self.is_tiling_axis: + return self.name + + direction = self.get_axis_direction() + index = f"{self.name} = {self.codegen_index(direction)}" + for var, dir_index in self.var_directions.items(): + line = f"{var.name} = {self.codegen_index(dir_index)}" + self.writeline(line) + + # reduction axis + if self.prefix == 'r': + if V.kernel.inside_reduction and V.kernel.current_node \ + and isinstance(V.kernel.current_node, SchedulerNode) \ + and V.kernel.current_node.node \ + and V.kernel.current_node.node.data \ + and isinstance(V.kernel.current_node.node.data, ir.Reduction): + reduction_type = V.kernel.current_node.node.data.reduction_type + if reduction_type in {"argmax", "argmin"}: + self.writeline(f"{self.parent.prefix}index = " + f"{self.codegen_index(None)}") + if index: + self.writeline(index) + self._codegen_mask() + return self.name + + def writeline(self, line): + self.indexing_code.writeline(line) + + def is_1d_persisent_reduction(self): + return len(V.kernel.tiling_axis) == 1 and V.kernel.persistent_reduction + + def codegen_index(self, direction): + BLOCK_NAME = f"{self.name.upper()}BLOCK" + BLOCK_NAME_SUB = f"{BLOCK_NAME}_SUB" + index = None + if self.prefix == 'r': + if V.kernel.persistent_reduction: + if self.is_1d_persisent_reduction(): + index = f"tl.arange(0, {BLOCK_NAME_SUB})" + else: + index = f"base_{self.name}" + else: + index = f"(loop_{self.name} * {BLOCK_NAME_SUB}) + base_{self.name}" + else: + if self.is_split_axis: + offset = f"{self.symbol()}_offset" + index = f"{offset} + (loop_{self.name} * {BLOCK_NAME_SUB}) + base_{self.name}" + else: + index = f"(loop_{self.name} * {BLOCK_NAME_SUB}) + base_{self.name}" + + if len(V.kernel.tiling_axis) > 1 and direction is not None: + index += direction + + return index + + def codegen_header(self, code): + # generate offset index loop + lines = [] + BLOCK_NAME = f"{self.name.upper()}BLOCK" + BLOCK_NAME_SUB = f"{BLOCK_NAME}_SUB" + + if self.is_1d_persisent_reduction(): + return + + if self.is_split_axis: + lines.append(f"{self.symbol()}_offset = tl.program_id({self.split_order}) * {BLOCK_NAME}") + + if self.is_tiling_axis: + lines.append(f"base_{self.name}= tl.arange(0, {BLOCK_NAME_SUB})") + block = f"{BLOCK_NAME}" if self.is_split_axis else f"{self.symbol()}_numel" + lines.append(f"loops_{self.name} = ({block} + {BLOCK_NAME_SUB} - 1) // {BLOCK_NAME_SUB}") + + else: + pass + + code.writelines(lines) + + def precomputed_args(self): + # for dynamic shapes, find parts of indexing expressions that have to be precomputed + precomputed_args: List[sympy.Expr] = [] + if isinstance(self.expr, (sympy.Symbol, sympy.Integer)): + return precomputed_args + + if not isinstance(self.expr, (FloorDiv, ModularIndexing)): + raise RuntimeError("assert 
isinstance(self.expr, (FloorDiv, ModularIndexing)), type(self.expr)") + for arg in self.expr.args[1:]: + if not isinstance(arg, (sympy.Integer, sympy.Symbol)): + symbols = arg.free_symbols + if len(symbols) > 0 and all( + symbol_is_type(s, SymT.SIZE) for s in symbols + ): + precomputed_args.append(arg) + return precomputed_args + + def __eq__(self, other): + return self.name == other.name + + +class IterationRangesRootNPUIndex(IterationRangesRoot): + def __init__( + self, + name: str, + numel: sympy.Expr, + prefix: str, + index: int, + kernel: TritonKernel, + pid_cache=None, + *, + is_loop: bool, + tensor_dim: Optional[int], + grid_dim: Optional[int], + ): + super().__init__(name, numel, prefix, index, kernel, pid_cache, is_loop=is_loop, tensor_dim=tensor_dim, + grid_dim=grid_dim, has_zdim=False) + + def __repr__(self): + return f"IterationRangesRootNPUIndex({self.name!r}, {self.numel}, ...)" + + def remove_entry(self, name): + if name in self.var_ranges: + del self.var_ranges[name] + if name in self.var_list: + del self.var_list[self.var_list.index(name)] + if name in V.kernel.range_tree_nodes: + V.kernel.range_tree_nodes_removed[name] = V.kernel.range_tree_nodes[name] + del V.kernel.range_tree_nodes[name] + if name in self.nodes: + del self.nodes[name] + + def duplicated_check(self, divisor, length): + """ + Lookup a given RangeTreeEntry, creating it if needed + """ + if V.graph.sizevars.statically_known_equals(divisor * length, self.numel): + expr = FloorDiv(sympy_index_symbol(f"{self.prefix}index"), divisor) + else: + expr = ModularIndexing( + sympy_index_symbol(f"{self.prefix}index"), divisor, length + ) + + return expr not in self.nodes + + def lookup(self, divisor, length): + """ + Lookup a given RangeTreeEntry, creating it if needed + """ + if V.graph.sizevars.statically_known_equals(divisor * length, self.numel): + expr = FloorDiv(sympy_index_symbol(f"{self.prefix}index"), divisor) + else: + expr = ModularIndexing( + sympy_index_symbol(f"{self.prefix}index"), divisor, length + ) + + if expr not in self.nodes: + node = IterationRangesEntryNPUIndex( + f"{self.prefix}{next(V.kernel.iter_vars_count)}", + divisor, + length, + expr, + self, + ) + V.kernel.range_tree_nodes[node.symbol()] = node + self.var_list.append(node.symbol()) + self.var_ranges[node.symbol()] = length + self.nodes[expr] = node + + return self.nodes[expr] + + +@classmethod +def is_compatible( + cls, + groups: Iterable[sympy.Expr], + lengths: Sequence[Sequence[sympy.Expr]], + reduction_numel: sympy.Expr = sympy.S.One +): + # Fill in the reduction numel, in case the node is missing it. 
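+    # Compatibility just means the flattened groups can be re-split into these lengths; +    # _split_iteration_ranges raises CantSplit below when they cannot.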
+ sizevars = V.graph.sizevars + if len(lengths[1]) == 0 and ( + sizevars.statically_known_equals( + sympy_product(groups), + sympy_product(lengths[0]) * reduction_numel, + ) + ): + lengths = (lengths[0], [reduction_numel]) + + try: + groups = flatten(groups) + NPUIndexTritonKernel._split_iteration_ranges(groups, lengths) + return True + except CantSplit: + return False + + +class NPUIndexTritonKernel(TritonKernel): + overrides = NPUTritonKernelOverrides + + def __init__( + self, + tiling: Dict[str, sympy.Expr], + min_elem_per_thread=0, + optimize_mask=True, + fixed_config: Optional[FixedTritonConfig] = None, + **kwargs, ): + + super().__init__(tiling=tiling, + min_elem_per_thread=min_elem_per_thread, + optimize_mask=optimize_mask, + fixed_config=fixed_config, + **kwargs) + self.first_node = True + self.inside_high_order_reduction = False + self.low_dims = set() + self.split_axis = [] + self.tiling_axis = [] + self.range_tree_nodes_removed: Dict[sympy.Symbol, IterationRangesEntry] = {} + self.range_tree_nodes_substituted = {} + self.expr_substituted = {} + self.sorted_axis = [] + self.prefix: IndentedBuffer = IndentedBuffer() + self.index_analysis = {} # var_list -> indexAnalysis + self.golden_var_list = None + self.reduce_analysis = None + self.load_store_indexing = None + + def _get_grid_type(self) -> type[triton_heuristics.GridExpr]: + return npu_triton_heuristics.GridNpu + + def gen_triton_ext_imports(self): + imports = IndentedBuffer() + imports.splice( + """ + from torch._inductor.runtime import triton_helpers + from torch_npu._inductor import npu_triton_heuristics + from torch_npu._inductor import npu_triton_helpers + from torch_npu._inductor.runtime import NPUDeviceProperties + from torch_npu._inductor.npu_triton_helpers import libdevice, math as tl_math + import torch + import torch_npu + """ + ) + return imports.getvalue() + + def patch_triton_hash(self): + # remove this method once the original invocation is fixed + import hashlib + from triton.compiler.compiler import triton_key, make_backend + from triton.runtime.driver import driver + backend = make_backend(driver.active.get_current_target()) + key = f"{triton_key()}-{backend.hash()}" + return hashlib.sha256(key.encode("utf-8")).hexdigest() + + def numof_tiling_axis(self): + return len(self.tiling_axis) + + # do nothing in NpuTritonKernel + def codegen_range_tree(self): + pass + + def initialize_range_tree(self, pid_cache): + self.total_numels = 0 + for k, x in self.numels.items(): + if not isinstance(x, sympy.Integer): + x = x.subs(V.graph.sizevars.var_to_val) + self.numels[k] = x + if x > 1: + self.total_numels += 1 + + no_r_dim = not self.inside_reduction or self.numels["r"] == 1 + prefixes = "wvtzyxr" + active_prefixes = prefixes[-len(self.numels):] + # prefix can not be 's', 'u', 'ps' , 'i', 'z' + # prefix can not be 'p' but can be 'z' since 2.6 + grid_dims = "xyztvw" + if self.no_x_dim: + tensor_dims = "r" + elif no_r_dim: + tensor_dims = "xyztvw" + else: + tensor_dims = "xyztvwr" + tensor_dims = "".join(p for p in tensor_dims if p in active_prefixes) + for i, prefix in enumerate(active_prefixes): + is_reduction = prefix_is_reduction(prefix) + tensor_dim = tensor_dims.find(prefix) if prefix in tensor_dims else None + grid_dim = None if is_reduction else grid_dims.find(prefix) + index = i if grid_dim is None else grid_dim + self.range_trees.append( + IterationRangesRootNPUIndex( + f"{prefix}index", + self.numels[prefix], + prefix, + index, + self, + pid_cache=pid_cache, + is_loop=is_reduction and not 
self.persistent_reduction, + tensor_dim=tensor_dim, + grid_dim=grid_dim + ) + ) + + def codegen_reduction_numels(self, buffer) -> None: + reduction_trees = [tree for tree in self.range_trees if tree.is_reduction] + if len(reduction_trees) > 1: + raise AssertionError("Currently npu don't support multi-reduction ranges trees, e.g, r0, r1.") + + def get_axis_dtype(self, axis): + dtype = None + if axis is None: + return None + for node in self.node_schedule: + if node in (EnableReduction, DisableReduction): + continue + if axis.symbol() in node._body.indexing_map: + dtype = V.graph.get_dtype(node.node.name) + break + if dtype is None: + should_break_all = False + for node in self.node_schedule: + if should_break_all: + break + if node in (EnableReduction, DisableReduction): + continue + for key, _ in node._body.indexing_map.items(): + if key in self.range_tree_nodes: + dim = self.range_tree_nodes[key] + else: + dim = self.range_tree_nodes_removed[key] + + if dim.parent == axis.parent: + dtype = V.graph.get_dtype(node.node.name) + should_break_all = True + break + return dtype + + def create_inductor_meta(self): + mutated_args = set() + for mutation in self.mutations: + if mutation in self.args.input_buffers: + mutated_args.add(self.args.input_buffers[mutation]) + if ( + mutation in self.args.inplace_buffers + and mutation not in V.graph.removed_buffers + and mutation not in self.removed_buffers + ): + mutated_args.add(self.args.inplace_buffers[mutation].inner_name) + if mutation in self.args.output_buffers: + mutated_args.add(self.args.output_buffers[mutation]) + mutated_args = sorted(mutated_args) + tiling_axis = [x.sorted_order for x in self.tiling_axis] + split_axis = [x.sorted_order for x in self.split_axis] + axis_names = [x.name for x in self.sorted_axis] + split_axis_dtype = self.get_axis_dtype(self.split_axis[0]) if self.split_axis else None + inductor_meta = { + "grid_type": self._get_grid_type().__name__, + "autotune_hints": set(self.autotune_hints), + "kernel_name": str(Placeholder.DESCRIPTIVE_NAME), + "mutated_arg_names": mutated_args, + + # Due to breaking change of triton 3.0, the original invocation is broken + "backend_hash": self.patch_triton_hash(), # torch.utils._triton.triton_hash_with_backend(), + "split_axis": split_axis, + "tiling_axis": tiling_axis, + "axis_names": axis_names, + "low_dims": self.low_dims, + "numof_reduction_axis": self.numof_reduction_axis(), + "split_axis_dtype": split_axis_dtype, + "dual_reduction": self.numof_reduction_axis() > 1, + "traced_graph_hash": "TRACED_GRAPH_HASH", + "traced_graph_dir": "TRACED_GRAPH_DIR", + "store_cubin": config.triton.store_cubin, + "force_disable_caches": config.force_disable_caches, + "profile_bandwidth_with_do_bench_using_profiling": config.profile_bandwidth_with_do_bench_using_profiling, + } + return inductor_meta + + # numels sent to autotune configs + def get_size_hints(self): + size_hints = [] + if (len(self.range_tree_nodes.values()) == 0): + return [v for _, v in self.numels.items()] + + for _, node in enumerate(self.sorted_axis): + if isinstance(node.expr, ModularIndexing): + numel_expr = node.length + else: + numel_expr = node.expr.subs({sympy_index_symbol(r.name): r.numel for r in self.range_trees}) + + numel_expr = V.graph.sizevars.symbolic_hint(numel_expr) + + size_hints.append(numel_expr) + return size_hints + + def add_numel_to_call_args(self, name, call_args, arg_types): + for node in self.sorted_axis: + if isinstance(node.expr, ModularIndexing): + numel_expr = node.length + else: + numel_expr = 
node.expr.subs({sympy_index_symbol(r.name): r.numel for r in self.range_trees}) + + if isinstance(numel_expr, (sympy.Integer, sympy.Symbol)): + expr = numel_expr + else: + expr = V.graph.wrapper_code.generate_node_numel_expr(name, node, numel_expr) + call_args.append(expr) + arg_types.append(type(expr)) + + def gen_numel_args(self, signature, triton_meta_signature, argdefs): + for node in self.sorted_axis: + arg_name = f"{node.name}_numel" + if not inductor_npu_config.inductor_static_mode: + sizearg = SizeArg(arg_name, node.length) + signature.append(sizearg) + triton_meta_signature[arg_name] = signature_of( + sizearg, size_dtype=self.index_dtype + ) + argdefs.append(ArgName(arg_name)) + else: + argdefs.append(ArgName(arg_name, is_constexpr=True)) + self.triton_meta["constants"][arg_name] = node.length + + # BLOCK and SUB_BLOCK definitions + def add_autotune_args(self, argdefs): + for axis in self.split_axis: + argdefs.append(ArgName(f"{axis.name.upper()}BLOCK", is_constexpr=True)) + + for axis in self.tiling_axis: + if axis.name[0] == 'r' and self.persistent_reduction: + continue + argdefs.append(ArgName(f"{axis.name.upper()}BLOCK_SUB", is_constexpr=True)) + + def _get_heuristic(self): + if self.persistent_reduction: + if not self.inside_reduction: + raise RuntimeError("assert self.inside_reduction to be true") + return "persistent_reduction_npu_index" + elif self.inside_reduction: + return "reduction_npu_index" + return "pointwise_npu_index" + + def get_kernel_name(self, src_code, node_schedule, kernel): + wrapper = V.graph.wrapper_code + if src_code in wrapper.src_to_kernel: + kernel_name = wrapper.src_to_kernel[src_code] + else: + fused_name = ( + get_fused_kernel_name(node_schedule, config.triton.descriptive_names) + if config.triton.descriptive_names + else "" + ) + kernel_category = get_kernel_category_by_source_code(src_code)[:3] + kernel_name = "_".join( + ["triton", kernel_category, fused_name, wrapper.get_next_kernel_suffix()] + ) + return kernel_name + + # modify triton_meta, inductor_meta , etc. 
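+    # codegen_kernel assembles the final kernel source: extra NPU imports, the +    # npu_triton_heuristics decorator carrying triton_meta (mix_mode="aiv") and inductor_meta, +    # the numel and BLOCK / BLOCK_SUB constexpr arguments, and then the kernel body itself.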
+ def codegen_kernel(self, name=None): + code = IndentedBuffer() + size_hints = self.get_size_hints() + heuristics = self._get_heuristic() + if name is None: + code.splice(gen_common_triton_imports()) + # Note: add extra imports for extensions + code.splice(self.gen_triton_ext_imports()) + + if config.benchmark_kernel: + code.splice(self.imports_for_benchmark_kernel()) + + argdefs, _, signature, _ = self.args.python_argdefs() + + for i, arg in enumerate(signature): + if isinstance(arg, SizeArg): + symbol = cast(sympy.Symbol, arg.expr) + if symbol in V.graph.sizevars.inv_precomputed_replacements: + signature[i] = SizeArg( + arg.name, V.graph.sizevars.inv_precomputed_replacements[symbol] + ) + + triton_meta_signature = signature_to_meta(signature, size_dtype=self.index_dtype, argdefs=argdefs) + + triton_meta = { + "signature": triton_meta_signature, + "device": + NPUDeviceProperties.create( + V.graph.get_current_device_or_throw() + ), + "constants": {}, + # special config for NPU, specify compile target + "mix_mode": "aiv", + } + + inductor_meta = self.create_inductor_meta() + num_gb = None + if config.benchmark_kernel or config.profile_bandwidth: + num_gb = self.estimate_kernel_num_bytes() / 1e9 + inductor_meta["kernel_num_gb"] = num_gb + + self.triton_meta = triton_meta + self.gen_numel_args(signature, triton_meta_signature, argdefs) + + # add in tiling args + self.add_autotune_args(argdefs) + # for scalar codegen + if len(self.range_tree_nodes) == 0: + self.write_scalar() + else: + self.codegen_body() + + for helper in self.helper_functions: + code.writeline("") + code.splice(helper) + + # Note: override original triton_heuristics + if self.inside_reduction: + reduction_hint = self.features.get_reduction_hint() + heuristics_line = f""" + @npu_triton_heuristics.{heuristics}( + size_hints={size_hints}, + reduction_hint={reduction_hint}, + filename=__file__, + triton_meta={triton_meta!r}, + inductor_meta={inductor_meta!r} + ) + @triton.jit + """ + else: + tile_hint = "" + if len(size_hints) == 2: + if len(signature) == 4: # input, output and 2 args + tile_hint = "tile_hint=TileHint.SQUARE," + else: + tile_hint = "tile_hint=TileHint.DEFAULT," + heuristics_line = f""" + @npu_triton_heuristics.{heuristics}( + size_hints={size_hints!r}, {tile_hint} + filename=__file__, + triton_meta={triton_meta!r}, + inductor_meta={inductor_meta!r}, + min_elem_per_thread={self.min_elem_per_thread} + ) + @triton.jit + """ + code.splice(heuristics_line) + code.writeline( + f"def {name or str(Placeholder.KERNEL_NAME)}({', '.join(x.full_name() for x in argdefs)}):" + ) + with code.indent(): + self.codegen_static_numels(code) + for old, new in self.args.aliases(): + code.writeline(f"{old} = {new}") + code.splice(self.body) + + if config.benchmark_kernel: + code.splice(self.codegen_kernel_benchmark(num_gb)) + + return code.getvalue() + + def codegen_static_numels(self, code): + for symbol in self.reduction_axis_list(): + if symbol.name[0] != "r" or not self.persistent_reduction: + continue + + node = self.range_tree_nodes[symbol] + simplified_tree_numel = V.graph.sizevars.simplify(node.length) + if isinstance(simplified_tree_numel, (sympy.Integer, int)): + val = int(simplified_tree_numel) + else: + continue + val = next_power_of_2(val) + code.writeline(f"{node.name.upper()}BLOCK_SUB: tl.constexpr = {val}") + + def lowest_axis_variable(self): + if len(self.tiling_axis) == 0: + return None + return self.tiling_axis[-1] + + def is_isolated_symbol(self, input_str, range_val): + patterns = [r'\b' + 
re.escape(range_val.name) + r'\b'] + for var in range_val.var_directions.keys(): + pattern = r'\b' + re.escape(var.name) + r'\b' + patterns.append(pattern) + + for pattern in patterns: + if re.search(pattern, input_str): + return True + return False + + def find_axis_in_load_store(self, range_val): + if not range_val: + return False + for line in self.loads._lines: + if line.find('tl.load') >= 0 and self.is_isolated_symbol(line, range_val): + return True + for line in self.compute._lines: + if line.find('tl.load') >= 0 and self.is_isolated_symbol(line, range_val): + return True + for line in self.post_loop_store._lines: + if line.find('tl.store') >= 0 and self.is_isolated_symbol(line, range_val): + return True + for line in self.stores._lines: + if isinstance(line, DeferredLine): + line = line.line + if line.find('tl.store') >= 0 and self.is_isolated_symbol(line, range_val): + return True + return False + + def write_scalar(self): + self.body.splice(self.indexing_code) + self.body.splice(self.loads) + self.body.splice(self.compute) + self.body.splice(self.stores) + self.loads.clear() + self.compute.clear() + self.stores.clear() + self.post_loop_store.clear() + self.prefix.clear() + + def codegen_body(self): + if not ( + self.loads + or self.stores + or self.compute + or self.post_loop_store + ): + return + + def write_pointwise(): + self.body.splice(self.indexing_code) + self.body.splice(self.loads) + self.body.splice(self.compute) + self.body.splice(self.stores) + + def codegen_range(index): + def is_1d_reduction(): + return self.numels["r"] > 1 and len(self.numels) == 1 + + def loop_body(index, indexing_code, is_last_axis, do_indent=True): + if do_indent: + self.body.do_indent() + if indexing_code: + self.body.splice(indexing_code) + if is_last_axis: + write_pointwise() + else: + codegen_range(index + 1) + if do_indent: + self.body.do_unindent() + + if index < 0 or index >= len(self.range_tree_nodes): + return + + range_val = self.sorted_axis[index] + numof_tilings = len(self.tiling_axis) + last_tiling = range_val.is_tiling_axis and numof_tilings >= 1 and range_val.tiling_order == len( + self.tiling_axis) - 1 + next_is_dual_reduction_tiling = index == len( + self.sorted_axis) - numof_tilings - 1 and self.numof_reduction_axis() + + is_last_axis = index == len(self.sorted_axis) - 1 + indexing_code = getattr(range_val, "indexing_code") + reduction_1d = is_1d_reduction() + do_indent = False + # do nothing except for writing porintwise + if len(self.loads._lines) == 0 and len(self.stores._lines) == 0: + do_indent = False + indexing_code = None + # tiling axis and last tiling + if range_val.is_tiling_axis and last_tiling: + do_indent = False + need_axis_loop = self.find_axis_in_load_store(range_val) + if not need_axis_loop: + indexing_code = None + if (range_val.prefix != 'r' or not self.persistent_reduction) and need_axis_loop: + self.body.splice(self.prefix) + self.body.writeline(f"for loop_{range_val.name} in range(loops_{range_val.name}):") + do_indent = True + loop_body(index, indexing_code, is_last_axis, do_indent) + self.body.splice(self.post_loop_store) + self.post_loop_store.clear() + + # tiling axis and but not last tiling + elif range_val.is_tiling_axis: + do_indent = False + if len(self.loads._lines) == 0 and len(self.stores._lines) == 0: + do_indent = False + indexing_code = None + if self.numof_reduction_axis() <= 1: + do_indent = True + self.body.writeline(f"for loop_{range_val.name} in range(loops_{range_val.name}):") + loop_body(index, indexing_code, is_last_axis, 
do_indent=do_indent) + + elif not is_last_axis: + do_indent = True + if range_val.is_split_axis: + offset = f"{range_val.name}_offset" + self.body.writeline(f"for {range_val.name} in range({offset}, " + f"min({offset} + {range_val.name.upper()}BLOCK, {range_val.name}_numel)):") + else: + self.body.writeline(f"for {range_val.name} in range({range_val.name}_numel):") + + if not reduction_1d and self.persistent_reduction: + self.body.do_indent() + self.body.splice(self.prefix) + self.prefix.clear() + self.body.do_unindent() + + loop_body(index, indexing_code, is_last_axis, do_indent=do_indent) + else: + write_pointwise() + + if self.first_node: + for node in self.sorted_axis: + node.codegen_header(self.body) + + while True: + if not self.sorted_axis[-1].is_tiling_axis: + x = self.sorted_axis[-1] + self.sorted_axis.pop(-1) + self.sorted_axis.insert(0, x) + else: + break + + if self.first_node: + codegen_range(0) + else: + last_axis_order = self.tiling_axis[-1].sorted_order + if self.persistent_reduction and self.numof_reduction_axis() > 1: + last_axis_order = last_axis_order - self.numof_reduction_axis() + 1 + for _ in range(last_axis_order): + self.body.do_indent() + codegen_range(last_axis_order) + for _ in range(last_axis_order): + self.body.do_unindent() + + self.cse.invalidate(self.outside_loop_vars) + self.loads.clear() + self.compute.clear() + self.stores.clear() + self.post_loop_store.clear() + self.prefix.clear() + self.first_node = False + + # for creat constant tensor, if have two axis, constant=tl.full([1,1]) else tl.full([1]) + def triton_tensor_ndim(self): + if self.numof_reduction_axis() > 1: + return 1 + + return len(self.tiling_axis) + + # indexing.mask_str is None , see varmean_test.py + def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable): + if not self.inside_reduction: + raise RuntimeError("assert self.inside_reduction") + + self.inside_reduction = False + indexing = self.indexing(index, block_ptr=True) + self.inside_reduction = True + var = self.args.output(name) + if isinstance(indexing, BlockPtrOptions): + self.post_loop_store.writeline( + DeferredLine( + name, + self.codegen_block_ptr_store_line( + name, + indexing, + indexing.format(var), + value, + f", boundary_check={indexing.boundary_check()!r}", + ), + ) + ) + else: + if not isinstance(indexing, IndexingOptions): + raise RuntimeError("assert isinstance(indexing, IndexingOptions)") + line = f"tl.store({var} + ({indexing.index_str} ), {value}, {indexing.mask_str})" + if self.numof_reduction_axis() > 1: + line = f"tl.store({var} + ({indexing.index_str} + tl.arange(0,1) ), {value}, {indexing.mask_str})" + self.post_loop_store.writeline( + DeferredLine(name, line) + ) + + # apply new var in case dim are permuted/broadcast + def store( + self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None + ) -> None: + + var = self.args.output(name) + original_index = index + index_analyze = IndexAnalysis(self, index, is_store_index=True) + index_analyze.analyze_index() + indexing = self.indexing(index, dense_indexing=True, block_ptr=mode is None, index_analyze=index_analyze) + index_str = indexing.index_str + value_str = f"{value}" + mask_str = indexing.mask_str + + if index_analyze.need_permute: + value_str = value_str.replace(f"{value}", f"{value}{index_analyze.generate_statement()}") + + advance_block_ptr = None + if isinstance(indexing, BlockPtrOptions): + block_ptr, advance_block_ptr, other = self.codegen_block_ptr( + name, var, indexing + ) + # block_ptr stores don't do implicit 
casting + line = self.codegen_block_ptr_store_line( + name, indexing, block_ptr, value, other + ) + elif mode is None: + line = f"tl.store({var} + ({index_str}), {value_str}, {mask_str})" + if self.numof_reduction_axis() > 1: + line = f"tl.store({var} + ({index_str} + tl.arange(0,1) ), {value_str}, {indexing.mask_str})" + + elif mode == "atomic_add": + line = f"tl.atomic_add({var} + ({index_str}), {value_str}, {indexing.mask_str})" + else: + raise NotImplementedError(f"store mode={mode}") + + self.stores.writeline(DeferredLine(name, line)) + if advance_block_ptr: + self.stores.writeline(advance_block_ptr) + + if not self.inside_reduction: + self.outside_loop_vars.add(value) + + def find_reduction_node(self): + node = self.current_node + if node is not None and isinstance(node, SchedulerNode): + reduction = node.node.data + if reduction is not None and isinstance(reduction, ir.Reduction): + return reduction + + for node in self.node_schedule: + if node in (EnableReduction, DisableReduction): + continue + reduction = node.node.data + if reduction is not None and isinstance(reduction, ir.Reduction): + return reduction + + return None + + # select the golden varlist, from to which to deduce permute, broadcast shape + def select_golden_varlist(self): + longest = None + maximum_length = 0 + self.golden_var_list = None + + def all_tiling_in_var_list(var_list): + return all([x in var_list for x in self.tiling_axis]) + + # all are load indexings, select the longest as gold + for index in self.load_store_indexing: + index = index.subs(V.graph.sizevars.var_to_val) + analyze = IndexAnalysis(self, index) + if len(analyze.var_list) > maximum_length and all_tiling_in_var_list(analyze.var_list): + longest = analyze.var_list + maximum_length = len(longest) + # this may cause problems + if not longest: + self.golden_var_list = tuple([x.symbol() for x in self.tiling_axis]) if self.tiling_axis else [] + else: + self.golden_var_list = tuple([x for x in longest if x in self.tiling_axis]) if self.tiling_axis else [] + if self.golden_var_list is None: + raise RuntimeError("assert self.golden_var_list is None") + + # to generate shape of the tile + + def dense_size_list(self) -> List[str]: + if self.inside_reduction: + if not self.reduce_analysis: + self.reduce_analysis = ReductionAnalysis(self) + return self.reduce_analysis.dense_size_list() + + if not self.golden_var_list: + self.select_golden_varlist() + + golden_var_list = self.golden_var_list if self.golden_var_list else [x.symbol() for x in self.tiling_axis] + if golden_var_list is None: + raise RuntimeError("assert golden_var_list is None") + sizes = [None for _ in golden_var_list] + for i, var in enumerate(reversed(golden_var_list)): + axis = self.range_tree_nodes[var] + sizes[i] = f"{axis.name.upper()}BLOCK_SUB" + return sizes + + def dense_size_str(self): + if self.inside_reduction: + if not self.reduce_analysis: + self.reduce_analysis = ReductionAnalysis(self) + return self.reduce_analysis.dense_size_str() + sizes = self.dense_size_list() + return f"[{', '.join(sizes)}]" + + # and add to shape to value + def reduction_resize(self, value, dim): + ndims = self.triton_tensor_ndim() + if ndims == 1: + return f"triton_helpers.promote_to_tensor({value})" + dense_list = self.dense_size_list() + dense_list[dim] = "1" + expand_str = ", ".join(dense_list) + return f"{value}.reshape({expand_str})" + + # to determine reduction_dim + def reduction_dim(self): + if not self.reduce_analysis: + self.reduce_analysis = ReductionAnalysis(self) + return 
self.reduce_analysis.reduced_dim + + def filter_masks(self, mask_vars): + for node in self.sorted_axis: + if not (node.is_tiling_axis): + mask_vars.discard(f"{node.name}_mask") + + def numof_reduction_axis(self): + root = self.range_trees[-1] + if root is None: + return 0 + + return len(root.var_list) + + def reduction_axis_list(self): + root = self.range_trees[-1] + if root is None: + return [] + return root.var_list + + def reduction( + self, + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: ReductionType, + value: Union[CSEVariable, Tuple[CSEVariable, ...]], + ) -> Union[CSEVariable, Tuple[CSEVariable, ...]]: + if not self.inside_reduction: + raise RuntimeError("assert self.inside_reduction") + masks = {f"{node.symbol()}_mask" for node in self.sorted_axis} + self.filter_masks(masks) + masks = sorted(masks) + if self._load_mask: + masks.append(self._load_mask) + reduction_range_prefix = self.range_trees[-1].prefix + if not self.reduce_analysis: + self.reduce_analysis = ReductionAnalysis(self) + dense_size_str = self.dense_size_str() + + if len(dense_size_str) > 2: + value = self._map_tuple_or_scalar( + lambda v: self.cse.generate( + self.compute, f"tl.reshape({v}, {dense_size_str})", dtype=v.dtype, + ), + value, + + ) + + dim: int + root_op: str + + def final_reduction(value): + module = "tl" # use tl + if reduction_type in {"max", "min"}: + return self.reduction_resize(f"{module}.{reduction_type}({value}, {dim})", dim) + return self.reduction_resize(f"{module}.{reduction_type}({value}, {dim})", dim) + + def final_argreduce(buffer, result_var, value, index): + buffer.splice( + f"""\ + _, {result_var}_tmp = triton_helpers.{root_op}_with_index({value}, {index}, {dim}) + {result_var} = {self.reduction_resize(f'{result_var}_tmp', dim)} + """ + ) + + def get_reduction_axis(): + return list(self.range_tree_nodes.values())[-1] + + cache_key = (src_dtype, reduction_type, value) + if cache_key in self.cse.reduction_cache: + return self.cse.reduction_cache[cache_key] + + dim = self.reduction_dim() + acc_type = triton_acc_type(src_dtype) + torch_acc_type = upcast_acc_dtype(src_dtype) + result_var: Any = self.cse.newvar(dtype=torch_acc_type) + result_var.mask_vars = {var for var in masks if var[0] != "r"} + cond = " & ".join(masks) + + def where_cond(tval, fval): + if not cond: + return tval + return TritonKernelOverrides.where(cond, tval, fval) + + if self.persistent_reduction: + default = ir.Reduction.default_value(reduction_type, src_dtype) + default = self._map_tuple_or_scalar(constant_repr, default) + + def _mask_value(value, default): + return self.cse.generate(self.compute, where_cond(value, default), dtype=value.dtype) + + # masked_value doesn't work dual reduction + if self.numof_reduction_axis() == 1: + if isinstance(value, tuple): + masked_value = [_mask_value(v, d) for v, d in zip(value, default)] + else: + masked_value = _mask_value(value, default) + else: + masked_value = value + + if reduction_type in {"argmax", "argmin", "max", "min"}: + reduce_axis = get_reduction_axis() + broadcast_string: str + reshape_str = self.reduce_analysis.get_reduce_dim_reshape(reduce_axis) + broadcast_string = f"tl.broadcast_to({reduce_axis.symbol()}.reshape({reshape_str}), {masked_value}.shape)" + accumulator_index = str( + self.cse.generate( + self.compute, + broadcast_string, + dtype=torch.int64 + ) + ) + if reduction_type == "argmax" or reduction_type == "argmin": + root_op = {"argmax": "max", "argmin": "min"}[reduction_type] + final_argreduce( + self.compute, result_var, masked_value, 
accumulator_index + ) + elif reduction_type == "max" or reduction_type == "min": + result_var = self.cse.generate( + self.compute, final_reduction(masked_value), dtype=masked_value.dtype, + ) + elif reduction_type == "welford_reduce": + raise RuntimeError("assert False, welford_reduction and is not supported now..") + elif reduction_type == "welford_combine": + raise RuntimeError("assert False, welford_combine and is not supported now..") + else: + result_var = self.cse.generate( + self.compute, final_reduction(masked_value), dtype=masked_value.dtype, + ) + else: + accumulator = self.cse.namedvar(f"_{result_var}", dtype=torch_acc_type) + default = ir.Reduction.default_accumulator(reduction_type, src_dtype) + default = self._map_tuple_or_scalar(constant_repr, default) + if not isinstance(default, tuple): + self.prefix.writeline( + f"{accumulator} = tl.full({self.dense_size_str()}, {default}, {acc_type})" + ) + + if reduction_type in {"argmax", "argmin"}: + accumulator_index = f"_{result_var}_index" + long_max = torch.iinfo(torch.int64).max + self.prefix.writeline( + f"{accumulator_index} = tl.full({self.dense_size_str()}, {long_max}, tl.int64)" + ) + root_op = {"argmax": "max", "argmin": "min"}[reduction_type] + + self.compute.splice( + f"""\ + {accumulator}_next, {accumulator_index}_next = triton_helpers.{root_op}imum_with_index( + {accumulator}, {accumulator_index}, {value}, {reduction_range_prefix}index + ) + {accumulator} = {where_cond(f'{accumulator}_next', accumulator)} + {accumulator_index} = {where_cond(f'{accumulator_index}_next', accumulator_index)} + """ + ) + final_argreduce(self.post_loop_store, result_var, accumulator, accumulator_index) + elif is_welford_reduction(reduction_type): + raise RuntimeError("assert False, welford_reduction and is not supported now..") + else: + combine_fn = ir.get_reduction_combine_fn(reduction_type, src_dtype) + updated = combine_fn(accumulator, value) + self.compute.writeline( + f"{accumulator} = {where_cond(updated, accumulator)}" + ) + + if src_dtype == torch.bool: + accumulator = f"{accumulator}.to(tl.int8)" + result_type = triton_compute_type(dtype) + self.post_loop_store.writeline( + f"{result_var} = {final_reduction(accumulator)}.to({result_type})" + ) + else: + self.post_loop_store.writeline( + f"{result_var} = {final_reduction(accumulator)}" + ) + + self.cse.reduction_cache[cache_key] = result_var + + if isinstance(result_var, tuple): + self.outside_loop_vars |= set(result_var) + else: + self.outside_loop_vars.add(result_var) + + return result_var + + # broadcast, permute handling + def load(self, name: str, index: sympy.Expr): + var = self.args.input(name) + original_index = index + store_cache = self.cse.store_cache + if name in store_cache: + result_var = store_cache[name] + return result_var + + index_analyze = IndexAnalysis(self, index) + index_analyze.analyze_index() + indirect_indexing = self.is_indirect_indexing(index) + indexing = self.indexing(index, block_ptr=True) + has_rindex = indexing.has_rindex() + has_tmpmask = indexing.has_tmpmask() + is_coalesced = any( + i == 1 for i in self.get_strides_of_load(original_index).values() + ) + ep = "" + if ( + (has_tmpmask or has_rindex) + and V.graph.get_dtype(name) != torch.bool + and indexing.has_mask() + ): + other = ", other=0.0" + else: + other = "" + + advance_block_ptr = None + append_broadcast = None + dtype = V.graph.get_dtype(name) + + if V.graph.is_unspec_arg(name): + line = var + else: + if isinstance(indexing, BlockPtrOptions): + block_ptr, advance_block_ptr, other = 
self.codegen_block_ptr( + name, var, indexing, other + ) + line = f"tl.load({block_ptr}{other}{ep})" + # add needed size=1 dimensions + line = triton_reshape( + line, indexing.block_shape, indexing.reshape_suffix + ) + elif isinstance(original_index, sympy.Integer): + line = f"tl.load({var} + tl.arange(0,1) + ({original_index}))" + full_list = ["1"] * (len(self.tiling_axis) if self.tiling_axis else 1) + append_broadcast = f"[{', '.join(full_list)} ]" + else: + index_str = indexing.index_str + mask_str = indexing.mask_str + line = f"tl.load({var} + ({index_str}), {mask_str}{ep}{other})" + + dtype = V.graph.get_dtype(name) + if dtype in (torch.bfloat16,): + line += ".to(tl.float32)" + if dtype == torch.bool and torch.version.hip is None: + line += ".to(tl.int1)" + if has_tmpmask: + # Masked loads must come after the mask is computed + load_buffer = self.compute + elif ( + self.inside_reduction + and self.range_trees[-1].is_loop + and not indirect_indexing + and not has_rindex + ): + # can lift a common load outside of reduction loop + # One exception is when this is an indirect_load. + load_buffer = self.prefix + + else: + load_buffer = self.loads + + result_var = self.cse.generate(load_buffer, line, dtype=dtype) + if not (isinstance(result_var, TritonCSEVariable)): + raise RuntimeError("assert isinstance(result_var, TritonCSEVariable)") + result_var.mask_vars = indexing.mask_vars # type: ignore[assignment] + + if append_broadcast and append_broadcast != '[]': + line = f"tl.broadcast_to({result_var}, {append_broadcast})" + result_var = self.cse.generate(load_buffer, line, dtype=dtype) + # triton can handle broadcast + elif index_analyze.need_permute: + line = f"{result_var}{index_analyze.generate_statement()}" + result_var = self.cse.generate(self.loads, line, dtype=dtype) + + if advance_block_ptr: + load_buffer.writeline(advance_block_ptr) + + if not self.inside_reduction or (not indexing.has_rmask() and not has_rindex): + self.outside_loop_vars.add(result_var) + + return result_var + + # don't call symlify_indexing + def prepare_indexing( + self, + index: sympy.Expr, + index_analyze, + is_index_expr=False + ): + index = sympy_subs(index, V.graph.sizevars.precomputed_replacements) + # if simple replacements didn't get rid of floor/ceil, try full subs + if len(index.atoms(sympy.floor)) or len(index.atoms(sympy.ceiling)): + index = index.subs(V.graph.sizevars.precomputed_replacements) + + if len(index.atoms(sympy.ceiling)): + for a in index.atoms(sympy.ceiling): + # for nested exprs, atoms yields top level first (?) 
+ # so if everything goes fine, lower level replacements will come up empty + symbols = a.free_symbols + if len(symbols) > 0 and all( + symbol_is_type(s, (SymT.SIZE, SymT.PRECOMPUTED_SIZE)) + for s in symbols + ): + replacements = {a: V.graph.sizevars.lookup_precomputed_size(a)} + index = sympy_subs(index, replacements) + + simp_index = index + + simp_index = ( + simp_index if not isinstance(simp_index, Identity) else simp_index.args[0] + ) + + # to generate range.var_directions for permuted axis + index_analyze.analyze_index() + return self.codegen_indexing(simp_index) + + def replace_index_vars(self, index, index_analyze): + + new_index = index + if index_analyze.var_replacements: + new_index = sympy_subs(index, index_analyze.var_replacements) + return new_index + + def index_to_str(self, index: sympy.Expr) -> str: + if isinstance(index, list): + return f"[{', '.join(map(self.index_to_str, index))}]" + index = self.rename_indexing(index) + return self.kexpr(index) # type: ignore[call-arg] + + # 1. only remove the line which asserts index var should be in "xyr" + # 2. don't do simplify_indexing, which combine continuous dims + # 3. removed block_ptr, removed dense mask/broadcast support + # dense_mask_vars should be generated from sorted_axis + # upgraded to torch251 + def indexing( + self, + index: sympy.Expr, + *, + copy_shape=None, + dense_indexing=False, + override_mask=None, + block_ptr=False, + index_analyze=None, + is_index_expr=False + ) -> Union[IndexingOptions, BlockPtrOptions]: + """ + Compute the index and mask to pass to tl.load() or tl.store() + """ + if not index_analyze: + index_analyze = IndexAnalysis(self, index, is_index_expr=is_index_expr) + index_analyze.analyze_index() + + index = self.prepare_indexing(index, index_analyze, is_index_expr) + index_vars = index.free_symbols + has_rindex = False + index = sympy_subs(index, V.graph.sizevars.precomputed_replacements) + # if simple replacements didn't get rid of floor/ceil, try full subs + if len(index.atoms(sympy.floor)) or len(index.atoms(sympy.ceiling)): + index = index.subs(V.graph.sizevars.precomputed_replacements) + if len(index.atoms(sympy.ceiling)): + for a in index.atoms(sympy.ceiling): + # for nested exprs, atoms yields top level first (?) 
+ # so if everything goes fine, lower level replacements will come up empty + symbols = a.free_symbols + if len(symbols) > 0 and all( + s.name.startswith("s") or s.name.startswith("ps") for s in symbols + ): + replacements = {a: V.graph.sizevars.lookup_precomputed_size(a)} + index = sympy_subs(index, replacements) + + # if not self.inside_reduction : + index = self.replace_index_vars(index, index_analyze) + index_vars = index.free_symbols + has_rindex = False + + mask_vars: Set[str] = set() + for var in index_vars: + if not (isinstance(var, sympy.Symbol)): + raise RuntimeError("assert isinstance(var, sympy.Symbol)") + + has_rindex = has_rindex or var.name.startswith("r") + if override_mask: + pass + elif var.name.startswith("tmp"): + # indirect indexing + cse_var = self.cse.varname_map[var.name] + mask_vars.update(cse_var.mask_vars) + elif var.name.startswith(("s", "ps", "i")): + pass + else: + # var is one of xN, yN or rN + mask_vars.add(f"{var.name}_mask") + + expand_str = None + index_str = self.index_to_str(index) + + if isinstance(index, sympy.Integer): + expand_str = f"{copy_shape}.shape" if copy_shape else self.dense_size_str() + if (index != 0): + index_str = f"tl.full({expand_str}, {index_str}, tl.int32)" + else: + index_str = f"tl.arange(0,1)" + return IndexingOptions(index_str, OrderedSet(), expand_str, has_rindex, index) + + if override_mask: + mask_vars = {override_mask} + if self._load_mask: + mask_vars.add(self._load_mask) + self.filter_masks(mask_vars) + return IndexingOptions(index_str, mask_vars, expand_str, has_rindex, index) # type: ignore[arg-type] + + def codegen_indexing(self, expr: sympy.Expr): + expr = V.graph.sizevars.simplify_with_ranges(expr, self.var_ranges()) + for sym in sorted(expr.free_symbols, key=str): + if sym in self.range_tree_nodes: + # if indexing expression is complicated, we precompute it on the host side + # and send the result as a kernel argument + replacements = {} + for ps in self.range_tree_nodes[sym].precomputed_args(): # type: ignore[index] + replacements[ps] = V.graph.sizevars.lookup_precomputed_size(ps) + if len(replacements) > 0: + self.range_tree_nodes[sym].expr = sympy_subs( # type: ignore[index] + self.range_tree_nodes[sym].expr, replacements # type: ignore[index] + ) + self.range_tree_nodes[sym].codegen() # type: ignore[index] + return expr + + # when xindex(16) -> x2:2,x3:8, when new length:16 in , should return (x2,x3) + def split_and_set_ranges(self, lengths: Sequence[Sequence[sympy.Expr]]): + groups = [rt.numel for rt in self.range_trees] + if not self.inside_reduction: + groups[-1] = sympy.S.One + + return self.map_kernel_groups_to_node_sizes(groups, lengths, self.set_ranges) + + # support split multiple ranges (instead of double) from one flatten range, triple-ranges are needed in mamba model + @staticmethod + def _split_iteration_ranges( + groups: Iterable[sympy.Expr], lengths: Sequence[Sequence[sympy.Expr]] + ): + sv = V.graph.sizevars + new_ranges: List[List[sympy.Expr]] = [[] for _ in groups] + remaining = [sv.simplify(g) for g in groups] + for i, group in enumerate(remaining): + if isinstance(group, (list, tuple)): + remaining[i] = NumelList(group).numels() + + var_count = itertools.count() + + def add_range(i, expr): + expr = sv.simplify(expr) + if not sv.statically_known_multiple_of(remaining[i], expr): + raise CantSplit() + # guard on the last item out + remaining[i] = FloorDiv(remaining[i], expr) + new_ranges[i].append(expr) + return next(var_count) + + def make_combined(strides, index_list): + def getter(flat_vars): + 
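# rebuild the original index as a strided sum over the merged flat variables, + # e.g. strides=[s0, s1], index_list=[i, j] gives s0*flat_vars[i] + s1*flat_vars[j] +                    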
expr = sympy.Integer(0) + for stride, index in zip(strides, index_list): + expr = stride * flat_vars[index] + expr + return expr + + return getter + + def size_hints(group): + if isinstance(group, (list, tuple)): + return sv.size_hint(NumelList(group).numels()) + return sv.size_hint(group) + + def add_multiple_range(size, return_getters): + # need to break size in multiple + index_list = [] + stride_list = [] + group = current_group + remained_size = size + # Two checks: + # 1. remaining sizes to be merged + # 2. remained_size is already divided to 1 + while (group < len(remaining) and remaining[group] > 1) and (remained_size > 1): + group_size = remaining[group] + # size should be divisible by group_size + if not sv.statically_known_multiple_of(remained_size, group_size): + raise CantSplit() + index_list.append(add_range(group, group_size)) + remained_size = FloorDiv(remained_size, group_size) + stride_list.append(remained_size) + group = group + 1 + if remained_size != 1: + raise CantSplit() + return_getters.append(make_combined(stride_list, index_list)) + + return_getters_groups = [] + current_group = 0 + + for length_group in lengths: + return_getters = [] + for size in length_group: + if sv.statically_known_equals(size, 1): # type: ignore[arg-type] + return_getters.append(lambda _: sympy.Integer(0)) + continue + + while ( + current_group < len(remaining) + and size_hints(remaining[current_group]) == 1 + ): + # scroll to next group with remaining elements + current_group += 1 + size_hint = sv.size_hint(size) + if size_hint > size_hints(remaining[current_group]): + # add multiple ranges (two or more) to the list, as well as the getter funcs + add_multiple_range(size_hint, return_getters) + else: + return_getters.append( + operator.itemgetter(add_range(current_group, size_hint)) + ) + return_getters_groups.append(return_getters) + + if not (all(V.graph.sizevars.size_hint(s) == 1 for s in remaining)): + raise RuntimeError("assert all(V.graph.sizevars.size_hint(s) == 1 for s in remaining)") + + return new_ranges, return_getters_groups + + # torch260 done + # just to override load method of CSEProxy, however, CSEProxy is an inner which can not be monkey patched, + # we need to override the whole inner class + def __enter__(self): + class CSEProxy: + self.name = "CSEProxy" + vr_analysis = ValueRangeAnalysis() + + @staticmethod + def __getattr__(name: str) -> Callable[..., CSEVariable]: # type: ignore[misc] + def inner(*args, **kwargs): + bounds = CSEProxy._bound_variable(name, *args, **kwargs) + + value = getattr(parent_handler, name)(*args, **kwargs) # type: ignore[has-type] + dtype_handler = DtypePropagationOpsHandler() + + output_idx = 0 + + def do_cse(v): + # cpp backend doesnt set current device + if V.graph.current_device is not None: + device_str = V.graph.get_current_device_or_throw().type + triton_backend = ( + config.cpu_backend == "triton" + if device_str == "cpu" + else config.cuda_backend == "triton" + ) + else: + triton_backend = False + + # only triton backend tracks dtype currently + if triton_backend: + if name == "masked": + output_dtype = value.dtype + else: + output_dtype = getattr( + dtype_handler, + name, + )(*args, **kwargs) + else: + # cpp backend doesnt track dtype yet + output_dtype = None + + csevar = V.kernel.cse.generate( + V.kernel.compute, + v, + bounds=bounds, + dtype=output_dtype, + ) + + nonlocal output_idx + if ( + config.test_configs.runtime_triton_dtype_assert + and triton_backend + ): + from torch._inductor.codegen.triton import triton_type + + # we 
tree_map over the output, so we need to fetch corresponding dtype + if isinstance(output_dtype, (list, tuple)): + output_dtype = output_dtype[output_idx] + + V.kernel.compute.writeline( + f"tl.static_assert({csevar}.dtype == {triton_type(output_dtype)})" + ) + output_idx += 1 + + csevar.update_on_args(name, args, kwargs) + + return csevar + + return pytree.tree_map(do_cse, value) + + return inner + + @staticmethod + def _bound_variable(name, *args, **kwargs): + """ + If the variable comes from an FX node, we forward the bound we have already computed + Else, if the variable when codegen'ing another op, we try to compute its bounds + """ + from torch._inductor.select_algorithm import TritonTemplateKernel + + if isinstance(V.kernel, TritonTemplateKernel): + return ValueRanges.unknown() + + fx_node = V.interpreter.current_node + if fx_node.target == name and self.node_to_bounds is not None: + if not (isinstance(self.node_to_bounds, dict)): + raise RuntimeError("assert isinstance(self.node_to_bounds, dict)") + + return self.node_to_bounds.get(fx_node, ValueRanges.unknown()) + elif config.compute_all_bounds and hasattr(ValueRangeAnalysis, name): + # These create lots of inner strings. We would need to compute the bounds at the ops + # We will also likely not get much from computing VRs on these nodes + if any( + s in fx_node.target + for s in ("set_indirect", "reduction", "scan") + ): + return ValueRanges.unknown() + + # We assume that the inputs come from `ops.` and are not strings. If you want to generate + # intermediary strings, wrap them in CSE variables with properly initialised bounds. + + # If there is no FX bound but we know how to compute one we do so + if (kwargs): + raise RuntimeError("assert not kwargs") + + def arg_to_bound(x): + if isinstance(x, CSEVariable): + return x.bounds + elif isinstance(x, sympy.Expr): + return bound_sympy(x) + else: + return x + + arg_bounds = list(map(arg_to_bound, args)) + return getattr(CSEProxy.vr_analysis, name)(*arg_bounds) + return ValueRanges.unknown() + + @staticmethod + def indirect_indexing( + var: CSEVariable, + size: Union[sympy.Expr, int], + check: bool = True, + wrap_neg=True, + ): + if isinstance(size, int): + size = sympy.Integer(size) + if not (isinstance(size, sympy.Expr)): + raise RuntimeError("assert isinstance(size, sympy.Expr), size") + # Skip CSE since this doesn't return an expression + + if var.bounds.lower < 0: # type: ignore[operator] + if wrap_neg: + stm = ops.add(var, ops.index_expr(size, torch.long)) + # Mixed negative and non-negative + if var.bounds.upper >= 0: # type: ignore[operator] + lt = ops.lt(var, 0) + stm = ops.where(lt, stm, var) + else: + stm = var + + # Propagate bounds as we know how to compute them properly + new_bounds = ValueRanges.unknown() + if var.bounds != ValueRanges.unknown() and isinstance( + size, sympy.Number + ): + # Take the negative part of the bound and add size to it + # Then take union of that and the positive part + # This is a tighter bound than that of a generic ops.where, as we have info on the cond + neg_bounds = var.bounds & ValueRanges(-int_oo, -1) + new_bounds = ValueRanges( + neg_bounds.lower + size, neg_bounds.upper + size + ) + # We don't have a good way of representing the empty range + if var.bounds.upper >= 0: # type: ignore[operator] + pos = var.bounds & ValueRanges(0, int_oo) + new_bounds = new_bounds | pos + + var = self.cse.generate(self.compute, stm, bounds=new_bounds) + + sympy_var = parent_handler.indirect_indexing(var, size, check) + if generate_assert(check): + 
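+                    # Only emit checks that are not already statically provable:
+                    # the lower-bound assert is skipped when var.bounds.lower >= 0
+                    # is known, and the upper-bound assert is skipped when size is
+                    # a concrete sympy.Number that already bounds var from above.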
assert_lower = not (var.bounds.lower >= 0) + # value ranges cannot x < s when x and s are symbols + assert_upper = not isinstance(size, sympy.Number) or not ( + var.bounds.upper < size + ) + self.check_bounds(sympy_var, size, assert_lower, assert_upper) + return sympy_var + + @staticmethod + def check_bounds( + expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool + ): + return self.check_bounds(expr, size, lower, upper) + + @staticmethod + def load(name: str, index: sympy.Expr) -> CSEVariable: + if name in self.cse.invalidated_stores: + # A load from an invalidated store requires us to + # keep the actual buffer around + V.kernel.must_keep_buffers.add(name) + if free_symbol_is_type(index, SymT.TMP): + return self.indirect_load(name, index) + store_cache = self.cse.store_cache + if name in store_cache: + return self.load(name, index) + out = self.load(name, index) + # count load that is not in the store_cache, and also not in the + # cse cache. + if out.use_count == 1: + self.num_load += 1 + return out + + @staticmethod + def _update_store_cache(name: str, value: CSEVariable): + self.cse.store_cache[name] = value + if self.current_node and name in V.graph.name_to_buffer: + buf = self.current_node.get_output(name) + for other_name in buf.get_mutations(): + self.cse.store_cache[other_name] = value + + @staticmethod + def store( + name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None + ) -> None: + self.store_buffer_names.add(name) + if mode is None: + CSEProxy._update_store_cache(name, value) + if name not in V.graph.removed_buffers: + return self.store(name, index, value, mode=mode) + return None # type: ignore[return-value] + + @staticmethod + def store_reduction(name: str, index: sympy.Expr, value: CSEVariable): + self.store_buffer_names.add(name) + CSEProxy._update_store_cache(name, value) + + if name not in V.graph.removed_buffers: + return self.store_reduction(name, index, value) + raise RuntimeError("store_reduction") + + @staticmethod + def reduction( + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: ReductionType, + value: Union[CSEVariable, Tuple[CSEVariable, ...]], + ) -> Union[CSEVariable, Tuple[CSEVariable, ...]]: + self.num_reduction += 1 + return self.reduction(dtype, src_dtype, reduction_type, value) + + @staticmethod + def scan( + dtypes: Tuple[torch.dtype, ...], + combine_fn: Callable[ + [Tuple[CSEVariable, ...], Tuple[CSEVariable, ...]], + Tuple[CSEVariable, ...], + ], + values: Tuple[CSEVariable, ...], + ) -> Tuple[CSEVariable, ...]: + return self.scan(dtypes, combine_fn, values) + + @staticmethod + def sort( + dtypes: Tuple[torch.dtype, ...], + values: Tuple[CSEVariable, ...], + stable: bool, + descending: bool, + ) -> Tuple[CSEVariable, ...]: + return self.sort(dtypes, values, stable, descending) + + @staticmethod + def bucketize( + values: CSEVariable, + boundaries: Tuple[str, sympy.Expr, sympy.Expr, sympy.Expr], + boundary_indices: CSEVariable, + indexing_dtype: torch.dtype, + right: bool, + sorter: Optional[Tuple[str, sympy.Expr]] = None, + sorter_indices: Optional[CSEVariable] = None, + ) -> CSEVariable: + return self.bucketize( + values, + boundaries, + boundary_indices, + indexing_dtype, + right, + sorter, + sorter_indices, + ) + + # Use mypy to check protocol implemented correctly + def _typecheck_CSEProxy(h: CSEProxy) -> OpsHandler[CSEVariable]: + return h + + super().__enter__() + if not (self.overrides): + raise RuntimeError("assert self.overrides") + parent_handler = self.overrides() + 
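+        # Install the rebuilt CSEProxy as the active ops handler for the
+        # lifetime of this kernel context; exit_stack restores the previous
+        # handlers when the kernel context exits.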
self.exit_stack.enter_context(V.set_ops_handler(CSEProxy())) + self.exit_stack.enter_context(V.set_kernel_handler(self)) + return self diff --git a/torch_npu/_inductor/codegen/triton_utils.py b/torch_npu/_inductor/codegen/triton_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1bbaef2a2fd2f6b0c3b2f6b17dbf57ddd569d750 --- /dev/null +++ b/torch_npu/_inductor/codegen/triton_utils.py @@ -0,0 +1,26 @@ +import torch + +# wrapper npu 32 bytes align, get and pass unalign info to triton meta +# then autotune choose tiling param and send them to bishengIR +byte_per_numel = { + torch.float32: 4, # torch.float32 or torch.float + torch.float64: 8, # torch.float64 or torch.double + torch.float16: 2, # torch.float16 or torch.half + torch.bfloat16: 2, # torch.bfloat16 + torch.int32: 4, # torch.int32 or torch.int + torch.int64: 8, # torch.int64 or torch.long + torch.int16: 2, # torch.int16 or torch.short + torch.int8: 1, # torch.int8 + torch.uint8: 1, # torch.uint8 + torch.bool: 1, # torch.bool + torch.complex32: 4, # torch.complex32 (not yet available in PyTorch as of the latest stable release) + torch.complex64: 8, # torch.complex64 + torch.complex128: 16 # torch.complex128 +} + + +def get_aligned_numel(dtype): + if dtype in byte_per_numel: + return 32 // byte_per_numel[dtype] + else: + return 1 diff --git a/torch_npu/_inductor/codegen/wrapper.py b/torch_npu/_inductor/codegen/wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..7e966c554ec3ac087a31d93d6aab7d5a44719d73 --- /dev/null +++ b/torch_npu/_inductor/codegen/wrapper.py @@ -0,0 +1,254 @@ +import os +import copy +from typing import Any, Callable, Optional, TYPE_CHECKING, Union +import hashlib +import sympy + +import torch +from torch._inductor import config +from torch._inductor.codegen.wrapper import PythonWrapperCodegen, SymbolicCallArg, SubgraphPythonWrapperCodegen +from torch._inductor.runtime import triton_heuristics +from torch._inductor.utils import ( + cache_on_self, +) +from torch._inductor.virtualized import V +from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode +from torch.utils._sympy.singleton_int import SingletonInt +from torch._inductor.ir import GraphPartitionSignature + +from torch_npu._inductor import config as npu_config +import torch_npu.npu.aclnn + + +class NPUWrapperCodeGen(PythonWrapperCodegen): + def __init__(self): + super().__init__() + + @staticmethod + def create( + is_subgraph: bool, + subgraph_name: str, + parent_wrapper: PythonWrapperCodegen, + partition_signatures: Optional[GraphPartitionSignature] = None, + ): + if is_subgraph: + return SubgraphPythonWrapperCodegen(subgraph_name, parent_wrapper, partition_signatures) + return NPUWrapperCodeGen() + + def write_header(self) -> None: + super().write_header() + self.imports.splice( + f""" + import torch_npu + """, + strip=True, + ) + + @cache_on_self + def write_triton_header_once(self) -> None: + import_str = f""" + import triton + import triton.language as tl + from {triton_heuristics.__name__} import start_graph, end_graph + import torch_npu + has_initialized = False + """ + if config.triton.autotune_at_compile_time: + self.kernel_autotune_calls.splice(import_str) + self.kernel_autotune_calls.writeline( + V.graph.device_ops.import_get_raw_stream_as("get_raw_stream") + ) + if not V.graph.cpp_wrapper: + self.imports.splice(import_str, strip=True) + self.imports.writeline( + V.graph.device_ops.import_get_raw_stream_as("get_raw_stream") + ) + + # generate numel expr for range_tree_node + def 
generate_node_numel_expr(self, kernel_name: str, node, numel_expr): + expr = f"{kernel_name}_{node.name}_numel" + if (expr, V.graph) not in self.kernel_numel_expr: + # declare expr once in each graph (scope) + self.kernel_numel_expr.add((expr, V.graph)) + self.writeline( + f"{self.declare}{expr} = {self.expr_printer(numel_expr)}{self.ending}" + ) + else: + self.writeline(f"{expr} = {self.expr_printer(numel_expr)}{self.ending}") + # We can get symbolic expressions here, like s0*64 + # It is fine to have them here, but we need to handle them correctly as their own type + # This is tricky to do, so we wrap in a custom type, distinct from scalars, but also from sympy* + # scalars as well. + # This is handled in `generate_args_decl` which has a correct comment of: TODO: only works for + # constant now, need type info. I agree, this needs type info, and while this is not true type info + # it suffices as a type hint for the purposes of producing the correct code for this type. + return SymbolicCallArg(expr, numel_expr) + + # don't free anything + def make_buffer_free(self, buffer): + return "" + + # don't assert + def codegen_input_size_asserts(self) -> None: + pass + + def get_next_kernel_suffix(self) -> str: + iter_val = copy.copy(self._names_iter) + return f"{next(iter_val)}" + + def add_benchmark_harness(self, output): + """ + Override, add aot-inductor debug kernel support. + """ + if not config.benchmark_harness: + return None + + if npu_config.aot_inductor.debug_kernel: + return self.add_npu_repro(output) + + return super().add_benchmark_harness(output) + + def add_npu_repro(self, output): + self.add_repro_func(output) + self.add_benchmark_func(output) + + output.writelines(["", "", 'if __name__ == "__main__":']) + with output.indent(): + # List how to use. Read details in torch_npu/_inductor/config.py. + output.writelines( + [ + "# torch_npu._inductor.config.force_fallback_kernel_id = 'all'", + "# or", + "# torch_npu._inductor.config.force_fallback_kernel_id = [1, 2, 10]", + "torch_npu._inductor.config.aot_inductor.debug_kernel_in_run = True", + "result = benchmark_compiled_module()", + "print(result)", + ] + ) + + def add_repro_func(self, output): + seen_constants = set() + + def add_fake_input(name, shape, stride, device, dtype): + output.writeline( + f"{name} = rand_strided(" + f"{self.codegen_python_shape_tuple(shape)}, " + f"{self.codegen_python_shape_tuple(stride)}, " + f"device='{device}', dtype={dtype})" + ) + + def get_hash(name): + byte = name.encode('utf-8') + sha1 = hashlib.sha1() + sha1.update(byte) + return sha1.hexdigest() + + def save_tensor(tensor, path): + dirname = os.path.dirname(path) + if not os.path.exists(dirname): + os.makedirs(dirname) + torch.save(tensor, path) + + def add_real_tensor(name, tensor): + tensor_dir = npu_config.aot_inductor.repro_tensor_path + if isinstance(tensor, FakeTensor): + raise RuntimeError(f"Could not generate repro func because detected {name} is FakeTensor " + f"when trying to dump it. 
Set repro and debug_kernel false to avoid it.") + hash_name = get_hash(name) + tensor_path = os.path.join(os.getcwd(), tensor_dir, f"{hash_name}.pt") + if name not in seen_constants: + save_tensor(tensor, tensor_path) + seen_constants.add(name) + output.writeline( + f"{name} = torch.load('{tensor_path}')" + ) + + def add_torchbind_input(name, value): + import pickle + + output.writeline(f"{name} = pickle.loads({pickle.dumps(value)!r})") + output.writelines( + ["", "", f"def repro_run({', '.join(V.graph.graph_inputs.keys())}):"] + ) + with output.indent(): + output.splice( + """ + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + """, + strip=True, + ) + for name, value in V.graph.constants.items(): + # all the constants are global variables, that's why we need + # these 'global var_name' lines + output.writeline(f"global {name}") + add_real_tensor(name, value) + + if len(V.graph.torchbind_constants) > 0: + output.writeline("import pickle") + for name, torchbind_obj in V.graph.torchbind_constants.items(): + # all the constants are global variables, that's why we need + # these 'global var_name' lines + output.writeline(f"global {name}") + add_torchbind_input(name, torchbind_obj) + + call_str = f"call([{', '.join(V.graph.graph_inputs.keys())}])" + output.writeline(f"fn = lambda: {call_str}") + output.writeline("return fn()") + + def add_benchmark_func(self, output): + def add_fake_input(name, shape, stride, device, dtype): + output.writeline( + f"{name} = rand_strided(" + f"{self.codegen_python_shape_tuple(shape)}, " + f"{self.codegen_python_shape_tuple(stride)}, " + f"device='{device}', dtype={dtype})" + ) + + def add_expr_input(name, val): + output.writeline(f"{name} = {val}") + + output.writelines( + ["", "", "def benchmark_compiled_module(times=10, repeat=10):"] + ) + with output.indent(): + output.splice( + """ + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + """, + strip=True, + ) + for name, value in V.graph.graph_inputs.items(): + if isinstance(value, sympy.Symbol) and isinstance( + V.graph.sizevars.var_to_val.get(value, None), SingletonInt + ): + continue + if isinstance(value, sympy.Expr): # Don't need to add symbolic + add_expr_input(name, V.graph.sizevars.size_hint(value, fallback=42)) + else: + shape = [ + V.graph.sizevars.size_hint(x, fallback=42) + for x in value.get_size() + ] + stride = [ + V.graph.sizevars.size_hint(x, fallback=42) + for x in value.get_stride() + ] + add_fake_input( + name, + shape, + stride, + value.get_device(), + value.get_dtype(), + ) + + call_str = f"repro_run({', '.join(V.graph.graph_inputs.keys())})" + output.writeline(f"fn = lambda: {call_str}") + output.writeline("return fn()") + + def write_prefix(self) -> None: + super().write_prefix() + + def generate_return(self, output_refs: list[str]) -> None: + super().generate_return(output_refs) \ No newline at end of file diff --git a/torch_npu/_inductor/config.py b/torch_npu/_inductor/config.py new file mode 100644 index 0000000000000000000000000000000000000000..4b0184b1c083245f7235fb39f3a47401d279e6f4 --- /dev/null +++ b/torch_npu/_inductor/config.py @@ -0,0 +1,115 @@ +import logging +import os # noqa: C101 +from typing import Any, Callable, Dict, Optional, TYPE_CHECKING +import torch +from torch._inductor import config +from triton.runtime.driver import driver + +enable_npu_indexing = True + +config.triton.unique_kernel_names = True +# avoid test_opensora_cases_model_16_forward reinterpre_tensor 
issue
+config.allow_buffer_reuse = False
+# inductor debug switch
+config.trace.enabled = True
+
+# NPU hardware params from triton
+target = driver.active.get_current_target()
+device = driver.active.get_current_device()
+prop = driver.active.utils.get_device_properties(device)
+
+num_cube_core = prop["num_aicore"]
+num_vector_core = prop["num_aicore"]
+
+# unit: bytes
+npu_block = 32
+
+
+# For debugging
+class aot_inductor:
+    # If debug_kernel is set, codegen in the python wrapper (output_code.py) and the cpp wrapper (model.pt2)
+    # is modified to dump the fx graph and weights, and a repro func is generated in output_code.py.
+    # Running AOTI and output_code.py then dumps the tensor args before and after each triton kernel,
+    # which can be used to detect which kernel is incorrect.
+    debug_kernel = os.environ.get("AOTI_ASCEND_DEBUG_KERNEL", False)
+
+    # No need to set debug_kernel_in_run manually. It will be set in output_code.py
+    # by codegen if debug_kernel is set.
+    debug_kernel_in_run = False
+
+    # Path used to dump weights for reproduction in AOTI when debug_kernel is set.
+    repro_tensor_path = os.environ.get("AOTI_ASCEND_REPRO_TENSOR_PATH", "aoti_repro_tensors")
+
+    # Path used to dump tensor args before and after each triton kernel during AOTI execution
+    # when debug_kernel is set.
+    dump_path_cpp = os.environ.get("AOTI_ASCEND_DUMP_PATH_CPP", "aoti_dump_cpp")
+
+    # Path used to dump tensor args before and after each triton kernel in output_code.py
+    # when debug_kernel_in_run is set.
+    dump_path_py = os.environ.get("AOTI_DUMP_PATH_PY", "aoti_dump_py")
+
+
+traced_fx_graph_cache = os.environ.get("INDUCTOR_ASCEND_FX_GRAPH_CACHE", None)
+check_accuracy = os.environ.get("INDUCTOR_ASCEND_CHECK_ACCURACY", False)
+auto_fallback = os.environ.get("INDUCTOR_ASCEND_AUTO_FALLBACK", True)
+fallback_warning = os.environ.get("INDUCTOR_ASCEND_FALLBACK_WARNING", False)
+
+# Trace the fx graph during lowering and dump it.
+dump_fx_graph = os.environ.get("INDUCTOR_ASCEND_DUMP_FX_GRAPH", False) \
+    or check_accuracy \
+    or aot_inductor.debug_kernel
+# Specify kernel ids that should be forced to fallback to fx graph calls.
+# Usage: `torch_npu._inductor.config.force_fallback_kernel_id = 'all' `
+# or `torch_npu._inductor.config.force_fallback_kernel_id = [1, 2, 10] `
+# (1) 'all' means try to fallback all kernels to fx graph calls. 
+# (2) [1, 2, 10] means try to fallback kernel like triton_xxx_1, triton_xxx_2 and triton_xxx_10 +force_fallback_kernel_id = [] + +acc_comp_tol = { + torch.float32: {'rtol': 1.3e-6, 'atol': 1e-5}, + torch.float16: {'rtol': 1e-3, 'atol': 1e-5}, + torch.bfloat16: {'rtol': 1.6e-2, 'atol': 1e-5}, + "default": {'rtol': 1.3e-6, 'atol': 1e-5}, +} + +if ("Ascend910B" in target.arch): + num_vector_core = num_cube_core * 2 + +log_level_env = os.getenv('INDUCTOR_ASCEND_LOG_LEVEL', 'WARNING').upper() +log_level_mapping = { + 'DEBUG': logging.DEBUG, + 'INFO': logging.INFO, + 'WARNING': logging.WARNING, + 'ERROR': logging.ERROR, + 'CRITICAL': logging.CRITICAL +} +log_level = log_level_mapping.get(log_level_env.upper(), logging.INFO) +logging.basicConfig( + level=log_level, + format='%(asctime)s - %(levelname)s - %(message)s' +) +log = logging.getLogger(__name__) + +aggresive_autotune = os.getenv("INDUCTOR_ASCEND_AGGRESSIVE_AUTOTUNE", '0').lower() in ('1', 'true') +inductor_static_mode = os.environ.get('INDUCTOR_STATIC_MODE', '0').lower() in ('1', 'yes', 'true') +profile_path = "./profile_result/" + + +def set_compile_threads(): + if "TORCHINDUCTOR_COMPILE_THREADS" in os.environ: + torchinductor_compile_threads = int(os.environ["TORCHINDUCTOR_COMPILE_THREADS"]) + if torchinductor_compile_threads == 1: + return + log.warning(f"TORCHINDUCTOR_COMPILE_THREADS is set to {torchinductor_compile_threads}, " + "but currently only support 1. It will be modified to 1.") + + os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1" + torch._inductor.config.compile_threads = 1 + + def get_env_num_workers(): + return 1 + torch._inductor.select_algorithm.get_env_num_workers = get_env_num_workers + + +def disable_comprehensive_padding(): + torch._inductor.config.comprehensive_padding = False \ No newline at end of file diff --git a/torch_npu/_inductor/cpp_builder.py b/torch_npu/_inductor/cpp_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..a72ea3f07d698c4f22a35cbdf9faa7629e6c81fa --- /dev/null +++ b/torch_npu/_inductor/cpp_builder.py @@ -0,0 +1,120 @@ +import os +from typing import Any, List, Optional, Sequence, Tuple, Union + +import torch +from torch.utils.cpp_extension import _HERE, _TORCH_PATH, TORCH_LIB_PATH + +from torch_npu.utils.cpp_extension import PYTORCH_NPU_INSTALL_PATH +from torch_npu.utils._error_code import ErrCode, pta_error + +if "ASCEND_HOME_PATH" not in os.environ: + def lazy_error(): + raise RuntimeError("Could not find ASCEND_HOME_PATH in env. Please run set_env.sh first." + + pta_error(ErrCode.NOT_FOUND)) + get_ascend_home = lazy_error +else: + def get_ascend_home_from_env(): + return os.environ["ASCEND_HOME_PATH"] + get_ascend_home = get_ascend_home_from_env + +TORCH_LIB_PATH = os.path.join(_TORCH_PATH, 'lib') + + +def include_paths(npu: bool = False) -> List[str]: + """ + Get the includ paths required to build a C++ extension. + + Args: + npu: If 'True', includes NPU-specific include paths. + + Returns: + A list if include path strings. + """ + lib_include = os.path.join(_TORCH_PATH, "include") + paths = [ + lib_include, + # Remove this once torch/torch.h is officially no longer supported for C++ extensions. + os.path.join(lib_include, 'torch', 'csrc', 'api', 'include'), + # Some internal (old) Torch headers don't properly prefix their includes, + # so we need to pass -Itorch/lib/include/TH as well. 
+ os.path.join(lib_include, 'TH'), + os.path.join(lib_include, 'THC') + ] + if npu: + ASCEND_HOME = get_ascend_home() + paths.extend([ + os.path.join(ASCEND_HOME, "include"), + os.path.join(ASCEND_HOME, "include/experiment"), + os.path.join(ASCEND_HOME, "include/experiment/msprof"), + ]) + + paths.append(os.path.join(PYTORCH_NPU_INSTALL_PATH, "include")) + return paths + + +def library_paths(npu: bool = False) -> List[str]: + """ + Get the library paths required to build a C++. + + Args: + npu: If 'True', includes NPU-specific library paths. + + Returns: + A list of library path strings. + """ + # We need to link against libtorch.so + paths = [TORCH_LIB_PATH] + if npu: + if "LIBTORCH_NPU_PATH" in os.environ: + libtorch_npu_path = os.environ["LIBTORCH_NPU_PATH"] + else: + libtorch_npu_path = os.path.join(PYTORCH_NPU_INSTALL_PATH, "lib") + paths.append(libtorch_npu_path) + + ASCEND_HOME = get_ascend_home() + cann_lib_path = os.path.join(ASCEND_HOME, "lib64") + paths.append(cann_lib_path) + + return paths + + +def get_cpp_torch_device_options( + device_type: str, + aot_mode: bool = False, + compile_only: bool = False, +) -> Tuple[List[str], List[str], List[str], List[str], List[str], List[str], List[str]]: + + npu = "npu" == device_type + + definations: List[str] = [] + include_dirs: List[str] = [] + cflags: List[str] = [] + ldflags: List[str] = [] + libraries_dirs: List[str] = [] + libraries: List[str] = [] + passthough_args: List[str] = [] + + include_dirs = include_paths(npu) + libraries_dirs = library_paths(npu) + + if npu: + definations.append("USE_NPU") + libraries += ["torch_npu", "runtime", "ascendcl"] + + # Could not add BUILD_LIBTORCH=ON to definations because it cannot + # process defination include "=" like -DXXX=xx. + passthough_args += ["-DBUILD_LIBTORCH=ON -Wno-unused-function"] + + return ( + definations, + include_dirs, + cflags, + ldflags, + libraries_dirs, + libraries, + passthough_args, + ) + + +def patch_get_cpp_torch_device_options(): + torch._inductor.cpp_builder.get_cpp_torch_device_options = get_cpp_torch_device_options \ No newline at end of file diff --git a/torch_npu/_inductor/decomposition.py b/torch_npu/_inductor/decomposition.py new file mode 100644 index 0000000000000000000000000000000000000000..b9c725f3ff2d762e060349b7f2b25afe04c5ec51 --- /dev/null +++ b/torch_npu/_inductor/decomposition.py @@ -0,0 +1,49 @@ +import torch._ops +from torch._inductor.decomposition import decompositions, pw_cast_for_opmath +from torch._inductor.decomposition import register_decomposition + +from .lowering import _init_set + +aten = torch.ops.aten + +DECOMPOSITION_OVERLOAD_OP = [ + aten._log_softmax, + aten.nll_loss_forward, + # aten.gelu_backward, + # aten.gelu, + aten.nll_loss_backward, + aten._log_softmax_backward_data, + aten.embedding_dense_backward, + aten.addmm, + aten.gelu +] + + +def _register_npu_inductor_decompositons(): + overload_op_set = set() + _init_set(DECOMPOSITION_OVERLOAD_OP, overload_op_set) + + for op in overload_op_set: + if (op in decompositions): + del decompositions[op] + + @register_decomposition([aten.scatter.src]) + @pw_cast_for_opmath + def scatter_src(self, input_tensor, dim, index_tensor, source_tensor): + (XNUMEL, YS) = input_tensor.shape + index_rblock = torch.arange(YS).npu().reshape((1, YS)).repeat((XNUMEL, 1)) + + index_tensor_brd = index_tensor.to(torch.int32).broadcast_to(XNUMEL, YS) + source_tensor_brd = source_tensor.broadcast_to(XNUMEL, YS).to(torch.float32) + scatter1 = torch.where(index_rblock == index_tensor_brd, 1.0, 0.0) * 
source_tensor_brd + return scatter1 + + @register_decomposition([aten.expm1]) + def expm1(x): + tensor = torch.exp(x) - torch.ones_like(x) + return tensor + + @register_decomposition([aten.erfc]) + def erfc(x): + tensor = torch.ones_like(x) - torch.exp(x) + return tensor diff --git a/torch_npu/_inductor/fx_passes/joint_graph.py b/torch_npu/_inductor/fx_passes/joint_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..11210910d13863808c24a583a1cb8f4f83c5afea --- /dev/null +++ b/torch_npu/_inductor/fx_passes/joint_graph.py @@ -0,0 +1,15 @@ +import torch +import torch._inductor.fx_passes.joint_graph as joint_graph + + +def patch_constant_fold_uniform_value(): + # Fix bug in aot_inductor for torch. + # Eliminate dead-nodes to remove extra constants generated by torch.tensor. + src_func = joint_graph.constant_fold_uniform_value + + def new_constant_fold_uniform_value(gm): + src_func(gm) + if isinstance(gm, torch.fx.GraphModule): + gm.graph.eliminate_dead_code() + + joint_graph.constant_fold_uniform_value = new_constant_fold_uniform_value \ No newline at end of file diff --git a/torch_npu/_inductor/graph.py b/torch_npu/_inductor/graph.py new file mode 100644 index 0000000000000000000000000000000000000000..caff8fbc60c1ba44c737658896e9c44095bec474 --- /dev/null +++ b/torch_npu/_inductor/graph.py @@ -0,0 +1,114 @@ +from typing import ( + Any, + List, + Tuple, + Union, +) +import itertools + +import torch +from torch.fx.node import Node +from torch._inductor import config, metrics +from torch._subclasses.fake_tensor import FakeTensor +from torch._dynamo.utils import defake, dynamo_timed +from torch._inductor.virtualized import NullHandler, V + + +def patch_codegen_with_cpp_wrapper(): + def npu_codegen_with_cpp_wrapper(self) -> Tuple[str, List[Tuple[int, Node]]]: + # add "npu" support + if any(device in self.device_types for device in ["cuda", "xpu", "npu"]): + if config.triton.autotune_at_compile_time: + # If autotune_at_compile_time is True, we can do the codegen in one-pass + return self.codegen() + else: + # first pass + self.cpp_wrapper = False + compiled = self.compile_to_module().call + + def materialize( + x: Union[torch.SymInt, torch.SymFloat, torch.Tensor] + ) -> Union[int, float, torch.Tensor]: + if x is None: + return None + elif isinstance(x, (torch.SymInt, torch.SymFloat)): + # Need concrete value to run dynamic shapes and tune the result + return x.node.hint + elif isinstance(x, FakeTensor): + return defake(x) + else: + if not isinstance(x, torch.Tensor): + raise AssertionError("Unknown type when creating real inputs" + str(type(x))) + return x + + tracing_context = torch._guards.TracingContext.try_get() + if tracing_context is not None and not isinstance( + V.real_inputs, NullHandler + ): + if tracing_context.output_strides: + tracing_context.output_strides.clear() + + params_flat = [ + param + for param in tracing_context.params_flat # type: ignore[union-attr] + if param is not None + ] + real_inputs = [ + materialize(x) + for x in itertools.chain(params_flat, V.real_inputs) + ] + else: + # In the backward pass, V.real_inputs is not OrderedSet. + # Generating random inputs based on self.example_inputs sometimes can be problematic, + # e.g. illegal memory access. A comprehensive fix is to autotune in a separate process. 
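+                    # materialize() above turns SymInt/SymFloat values into their
+                    # hints and FakeTensors into real tensors, so the first-pass
+                    # wrapper can actually be executed on these inputs.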
+ real_inputs = [ + materialize(x) # type:ignore[arg-type] + for x in ( + self.example_inputs # type:ignore[union-attr] + if isinstance(V.real_inputs, NullHandler) + else V.real_inputs + ) + ] + + if self.mutated_inputs: + from .compile_fx import clone_preserve_strides + + mutated_input_idxs = [ + idx + for idx, name in enumerate(self.graph_inputs) + if name in self.mutated_inputs + and isinstance(real_inputs[idx], torch.Tensor) + ] + for idx in mutated_input_idxs: + # clone mutated Tensor inputs to avoid mutating them in + # the first pass of the CPP wrapper-based compilation, as + # this will lead to a side effect on the example inputs: + # e.g. if torch.compile(f)(x) if called on input-mutating + # f, the inputs x will be mutated twice in the process: + # once here, and again when running the compiled model; + # this will also lead to a numerically incorrect output + mutated_inp = real_inputs[idx] + if not isinstance(mutated_inp, torch.Tensor): + raise AssertionError + real_inputs[idx] = clone_preserve_strides(mutated_inp) + del mutated_inp + + with torch.utils._python_dispatch._disable_current_modes(): + compiled(real_inputs) + del real_inputs + + # second pass + self.cpp_wrapper = True + self.removed_buffers.clear() + self.removed_operations.clear() + self.inplaced_to_remove.clear() + V.graph.sizevars.precomputed_replacements.clear() + V.graph.sizevars.inv_precomputed_replacements.clear() + metrics.reset() + with config.patch({"triton.autotune_at_compile_time": False}): + return self.codegen() + else: + # cpu + return self.codegen() + from torch._inductor.graph import GraphLowering + GraphLowering.codegen_with_cpp_wrapper = npu_codegen_with_cpp_wrapper \ No newline at end of file diff --git a/torch_npu/_inductor/ir.py b/torch_npu/_inductor/ir.py new file mode 100644 index 0000000000000000000000000000000000000000..77a65303c4f947d45b6670e5f41f20e7ced7498d --- /dev/null +++ b/torch_npu/_inductor/ir.py @@ -0,0 +1,86 @@ +import itertools +import torch +from torch._inductor.virtualized import ops, OpsValue, V +from torch._inductor.ir import log, Layout +from torch._inductor import config + +def patch_fallback_kernel_codegen(): + def codegen_npu(self, wrapper) -> None: # type: ignore[no-untyped-def] + kernel = self.op_overload + if kernel.namespace == "aten": # type: ignore[union-attr] + if not isinstance(kernel, torch._ops.OpOverload): + raise AssertionError(f"kernel should be OpOverload, but got {type(kernel)}") + if V.graph.cpp_wrapper: + # Fallback all npu op to proxy executor and warn when gpu do not. + from torchgen.aoti.fallback_ops import inductor_fallback_ops + self.use_runtime_dispatch = True + if str(kernel) in inductor_fallback_ops: + log.warning( + "%s is using proxy executor as fallback instead of aoti shim.", + kernel, + ) + + elif kernel.namespace == "_quantized": # type: ignore[union-attr] + # Internal Quantized Fallback Ops + assert isinstance(kernel, torch._ops.OpOverload) + elif V.graph.cpp_wrapper: + # For non-aten OpOverload, i.e. custom ops + # If the op is in custom_ops_to_c_shims, generate direct function call + self.use_runtime_dispatch = ( + kernel not in config.aot_inductor.custom_ops_to_c_shims + ) + + # Handle the special case where a complex number is input to a C-shim kernel for + # a scalar input. The torchgen'ed shim API will use type "double", which is + # incompatible with complex numbers, forcing a fallback to runtime dispatch. 
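+        # Illustrative case: a Python complex passed for a Scalar argument
+        # declared as NumberType in the schema cannot be passed through the
+        # C-shim's "double" parameter, so the check below flips
+        # use_runtime_dispatch and the call goes through the proxy executor.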
+ if ( + V.graph.cpp_wrapper + and isinstance(kernel, torch._ops.OpOverload) + and not self.use_runtime_dispatch + ): + + def is_number(t: torch.JitType) -> bool: + if isinstance(t, torch.OptionalType): + return is_number(t.getElementType()) + return isinstance(t, torch.NumberType) + + # Using unflatten_args is a bit of a hack, but all the complex arguments we + # care about are in self.constant_args, and calling unflatten_args puts them + # in the correct order without triggering codegen. + args, kwargs = self.unflatten_args(self.inputs, self.constant_args) + # Append kwarg values to args. ordered_kwargs_for_cpp_kernel is guaranteed + # to be set, since this is an OpOverload kernel. + args_iter = itertools.chain( + args, + ( + self.get_kwargs_value(k, **kwargs) + for k in self.ordered_kwargs_for_cpp_kernel + ), + ) + self.use_runtime_dispatch = any( + isinstance(v, complex) and is_number(a.real_type) + for v, a in zip(args_iter, kernel._schema.arguments) + ) + + self.codegen_comment(wrapper) + if self.use_runtime_dispatch: + exported_args = self.export_extern_kernel_node() + wrapper.generate_fallback_kernel_with_runtime_lookup( + self.get_name(), + self.python_kernel_name, + lambda: [*self.codegen_args(), *self.codegen_kwargs()], + self.op_overload, + exported_args, + # NOTE: [special handling of all_reduce_coalesced_'s return value] + self.outputs if self.outputs else self.mutation_outputs, + ) + else: + wrapper.generate_fallback_kernel(self) + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + self.codegen_alignment_asserts(wrapper) + + self.codegen_unbacked_symbol_defs(wrapper) + + from torch._inductor.ir import FallbackKernel + FallbackKernel.codegen = codegen_npu diff --git a/torch_npu/_inductor/lowering.py b/torch_npu/_inductor/lowering.py new file mode 100644 index 0000000000000000000000000000000000000000..2b47e091af8f49d3662f4c613a97b505f3e9266b --- /dev/null +++ b/torch_npu/_inductor/lowering.py @@ -0,0 +1,269 @@ +import sympy +import torch._ops +from torch._inductor import ir +from torch._inductor import lowering +from torch._inductor.decomposition import decompositions, pw_cast_for_opmath +from torch._inductor.ir import ExpandView, TensorBox, ops_wrapper +from torch._inductor.ir import Reduction +from torch._inductor.lowering import sum_ +from torch._inductor.utils import sympy_product +from torch._prims_common import ( + is_boolean_dtype, + is_integer_dtype, + get_computation_dtype, +) +from torch._inductor.lowering import ( + lowerings, + make_fallback, + register_lowering, + to_dtype, + fallback_cumsum, + _validate_reduction_axis, + div, + squeeze, + square, + sub, + fallback_handler, + is_boolean_type, + logical_and, + make_pointwise, + _make_reduction_inner, + _validate_reduction_axis, + add_needs_realized_inputs, + add_layout_constraint +) +import torch_npu +from torch_npu import npu_dtype_cast, _npu_dtype_cast +from .lowering_op_list import GENERATE_LIST, GENERATE_LIST2, FALLBACK_LIST, LOWERING_OVERLOAD_OP + + +def npu_make_fallback(op, layout_constraint=None, warn=True, override_decomp=False): + if op in decompositions and not override_decomp: + raise RuntimeError(f"both a fallback and a decomp for same op: {op}") + + def register_fallback(op_overload): + add_needs_realized_inputs(op_overload) + if layout_constraint is not None: + add_layout_constraint(op_overload, layout_constraint) + return register_lowering(op_overload, type_promotion_kind=None)( + fallback_handler(op_overload) + ) + + if isinstance(op, torch._ops.OpOverloadPacket): + for ol 
in op.overloads():
+            op_overload = getattr(op, ol)
+            register_fallback(op_overload)
+    elif isinstance(op, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)):
+        register_fallback(op)
+    else:
+        raise RuntimeError(f"Unsupported fallback {op} with type {type(op)}")
+
+make_fallback = npu_make_fallback
+
+
+def make_reduction(reduction_type: str, override_return_dtype=None):
+    def inner(x, axis=None, keepdims=False, *, dtype=None):
+        kwargs = _make_reduction_inner(
+            x,
+            axis=axis,
+            keepdims=keepdims,
+            dtype=dtype,
+            override_return_dtype=override_return_dtype,
+        )
+        result = Reduction.create(reduction_type=reduction_type, input_node=x, **kwargs)
+        if isinstance(
+            result.data.data, Reduction
+        ):  # Only realize if reduction isn't unrolled
+            size = x.get_size()
+            axis = set(_validate_reduction_axis(x, axis))
+            kept_idx = []
+            reduced_idx = []
+            for i in range(len(size)):
+                if i in axis:
+                    reduced_idx.append(i)
+                else:
+                    kept_idx.append(i)
+
+            object.__setattr__(result.data.data, "kept_idx", kept_idx)
+            object.__setattr__(result.data.data, "reduced_idx", reduced_idx)
+
+        result.realize()
+        return result
+
+    return inner
+
+lowering.make_reduction = make_reduction
+
+aten = torch.ops.aten
+tr_c10d = torch.ops.tr_c10d
+prims = torch.ops.prims
+
+
+def _init_set(input_list, output_set):
+    for fn in input_list:
+        output_set.add(fn)
+        if isinstance(fn, torch._ops.OpOverloadPacket):
+            for overload in fn.overloads():
+                other_fn = getattr(fn, overload)
+                output_set.add(other_fn)
+
+
+def _register_npu_inductor_fallbacks():
+    gen_set = set()
+    _init_set(GENERATE_LIST, gen_set)
+    overload_op_set = set()
+    _init_set(LOWERING_OVERLOAD_OP, overload_op_set)
+
+    # Fall back ops that are not in the whitelist
+    for op in lowerings:
+        if op not in decompositions and op not in gen_set:
+            if isinstance(op, torch._ops.OpOverloadPacket) or \
+                    isinstance(op, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)):
+                flag = False
+                for gens in GENERATE_LIST2:
+                    if str(op).find(gens) != -1:
+                        flag = True
+                if flag:
+                    continue
+                else:
+                    make_fallback(op)
+                    FALLBACK_LIST.append(op)
+    # Delete ops that need to be overloaded from the lowerings table
+    for op in overload_op_set:
+        if op in lowerings:
+            del lowerings[op]
+
+    # register the reductions using the custom make_reduction
+    reduce_amax = register_lowering(aten.amax)(make_reduction("max"))
+    reduce_amin = register_lowering(aten.amin)(make_reduction("min"))
+    reduce_argmax = register_lowering(aten.argmax)(
+        make_reduction("argmax", override_return_dtype=torch.int64)
+    )
+    reduce_argmin = register_lowering(aten.argmin)(
+        make_reduction("argmin", override_return_dtype=torch.int64)
+    )
+
+
+    @register_lowering(aten.max, type_promotion_kind=None)
+    def reduce_max(x, dim=None, keepdim=False):
+        if dim is not None:
+            return (
+                reduce_amax(x, axis=dim, keepdims=keepdim),
+                reduce_argmax(x, axis=dim, keepdims=keepdim),
+            )
+
+        return reduce_amax(x, axis=None, keepdims=keepdim)
+
+    @register_lowering(aten.min, type_promotion_kind=None)
+    def reduce_min(x, dim=None, keepdim=False):
+        if dim is not None:
+            return (
+                reduce_amin(x, axis=dim, keepdims=keepdim),
+                reduce_argmin(x, axis=dim, keepdims=keepdim),
+            )
+
+        return reduce_amin(x, axis=None, keepdims=keepdim)
+
+    @register_lowering(aten.mean)
+    def mean(x, axis=None, keepdim=False, *, dtype=None):
+        if dtype is not None:
+            x = to_dtype(x, dtype)
+        size = x.get_size()
+        axis = _validate_reduction_axis(x, axis)
+        # compute in higher-precision until end of mean lowering
+        output_dtype = x.get_dtype()
+        if output_dtype in (torch.float16, torch.bfloat16):
+            x = to_dtype(x, 
torch.float) + sum_result = sum_(x, axis, keepdim) + denom = sympy_product(size[i] for i in axis) + denom = ir.IndexingConstant(index=denom, dtype=x.get_dtype(), device=x.get_device()) + denom = ExpandView.create(denom, list(sum_result.get_size())) + return to_dtype(div(sum_result, denom), output_dtype) + + @register_lowering(aten.cumsum) + def cumsum(x, axis=None, dtype=None): + if ( + is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype()) + ) and dtype is None: + # torch.int64->torch.int32 + dtype = torch.int32 + if len(x.get_size()) == 0: + if axis not in [0, -1]: + raise ValueError("axis must be 0 or -1") + dtype = dtype or x.get_dtype() + return to_dtype(x, dtype, copy=True) + return fallback_cumsum(x, dim=axis, dtype=dtype) + + @register_lowering(npu_dtype_cast, type_promotion_kind=None) + def _convert_npu_type(x: TensorBox, dtype: torch.dtype): + return to_dtype(x, dtype, copy=True) + + @register_lowering(_npu_dtype_cast, type_promotion_kind=None) + def _convert__npu_type(x: TensorBox, dtype: torch.dtype): + return to_dtype(x, dtype, copy=True) + + def var_mean_sum_(x, axis, correction, keepdim, return_mean): + if correction is None: + correction = 1 + + size = x.get_size() + axis = _validate_reduction_axis(x, axis) + x_mean = mean(x, axis, keepdim=True) + if return_mean: + x_mean.realize() + + diffs = square(sub(x, x_mean)) + sum_result = sum_(diffs, axis, keepdim) + denom = sympy_product(size[i] for i in axis) + if correction: + denom = sympy.Max(denom - correction, 0) + denom = ir.IndexingConstant(index=denom, dtype=x.get_dtype(), device=x.get_device()) + denom = ExpandView.create(denom, list(sum_result.get_size())) + x_var = div(sum_result, denom) + if not return_mean: + return (x_var,) + + x_mean = x_mean if keepdim else squeeze(x_mean, axis) + return x_var, x_mean + + def var_mean_helper_(x, *, axis, correction, keepdim, return_mean): + out_dtype = x.get_dtype() + compute_dtype = get_computation_dtype(out_dtype) + x = to_dtype(x, compute_dtype, copy=False) + kwargs = dict( + x=x, + axis=axis, + correction=correction, + keepdim=keepdim, + return_mean=return_mean, + ) + output = ( + var_mean_sum_(**kwargs) + ) + output = tuple(to_dtype(x, out_dtype, copy=False) for x in output) + return output[0] if not return_mean else output + + @register_lowering(aten.var_mean) + def var_mean(x, axis=None, *, correction=None, keepdim=False): + return var_mean_helper_( + x, axis=axis, correction=correction, keepdim=keepdim, return_mean=True + ) + + @register_lowering([aten.var, prims.var]) + def var_(x, axis=None, *, correction=None, keepdim=False): + return var_mean_helper_( + x, axis=axis, correction=correction, keepdim=keepdim, return_mean=False + ) + + @register_lowering(aten.embedding, type_promotion_kind=None) + def embedding(weight, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False): + return fallback_handler(aten.embedding.default)(weight, indices, padding_idx=-1, scale_grad_by_freq=False, + sparse=False) + + @register_lowering(aten.cat) + def cat(inputs, dim=0): + return fallback_handler(aten.cat.default)(inputs, dim) + + make_fallback(aten._log_softmax) + make_fallback(aten.gather) + make_fallback(aten.nll_loss_forward) diff --git a/torch_npu/_inductor/lowering_fx.py b/torch_npu/_inductor/lowering_fx.py new file mode 100644 index 0000000000000000000000000000000000000000..f863f50e7a2a71d54239732cc7c7606a20fcd4de --- /dev/null +++ b/torch_npu/_inductor/lowering_fx.py @@ -0,0 +1,2314 @@ +import functools +import itertools +import os +import math +import 
textwrap +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Set, + Tuple, + Union, +) +import sympy +import torch._ops +import torch._ops +from sympy.core import Expr, Integer, Symbol +from torch._inductor import ir +from torch._inductor import ir +from torch._inductor import lowering +from torch._inductor import lowering +from torch._inductor import scheduler +from torch._inductor import scheduler +from torch._inductor.decomposition import decompositions +from torch._inductor.decomposition import decompositions, pw_cast_for_opmath +from torch._inductor.fx_passes.post_grad import view_to_reshape +from torch._inductor.ir import ( + ExpandView, + IndexingConstant, + is_triton, + ops_wrapper, + PermuteView, + Pointwise, + Reduction, + SqueezeView, + TensorBox, + IRNode, + validate_ir, + View, +) +from torch._inductor.ir import ExpandView, TensorBox +from torch._inductor.ir import ExpandView, TensorBox +from torch._inductor.ir import Reduction +from torch._inductor.ir import Reduction +from torch._inductor.utils import ModularIndexing, FloorDiv +from torch._inductor.utils import ( + decode_device, + sympy_product, +) +from torch._inductor.utils import sympy_product +from torch._inductor.utils import sympy_product +from torch._inductor.virtualized import ops, V +from torch._prims_common import ( + canonicalize_dims, + check, + dtype_to_type, + ELEMENTWISE_TYPE_PROMOTION_KIND, + get_computation_dtype, + is_boolean_dtype, + is_float_dtype, + is_integer_dtype, + Number, +) +from torch._prims_common import ( + is_boolean_dtype, + is_integer_dtype, + get_computation_dtype, +) +from torch._prims_common import ( + is_boolean_dtype, + is_integer_dtype, + get_computation_dtype, +) +from torch.fx.experimental.proxy_tensor import make_fx +from torch.utils._sympy.functions import ( + FloorDiv, + Identity, + ModularIndexing, +) +from .config import log +from .lowering_op_list import GENERATE_LIST, GENERATE_LIST2, FALLBACK_LIST, LOWERING_OVERLOAD_OP + +aten = torch.ops.aten +tr_c10d = torch.ops.tr_c10d +prims = torch.ops.prims +npu = torch.ops.npu + + +def _init_set(input_list, output_set): + for fn in input_list: + output_set.add(fn) + if isinstance(fn, torch._ops.OpOverloadPacket): + for overload in fn.overloads(): + other_fn = getattr(fn, overload) + output_set.add(other_fn) + + +LOWERING_OVERLOAD_OP = list(set(GENERATE_LIST) | set(LOWERING_OVERLOAD_OP)) + +fn_to_aten_fn = {} +node_id = itertools.count(0) +snodes_to_fx = {} + + +def register_fn_to_aten_fn(fn, aten_fn=None): + if fn not in fn_to_aten_fn: + fn_to_aten_fn[fn] = aten_fn + return fn + + +def register_to_aten(aten_fn=None): + def decorator(fn): + if fn not in fn_to_aten_fn: + fn_to_aten_fn[fn] = aten_fn + return fn + + return decorator + + +reduction_type_to_aten_fn = { + "sum": aten.sum, + "prod": aten.prod, + "xor_sum": prims.xor_sum, + "any": aten.any, + "max": aten.amax, + "min": aten.amin, + "argmax": aten.argmax, + "argmin": aten.argmin +} + +operator_to_string = { + '+': 'a', + '-': 'sub', + '*': 'm', + '/': 'd', + '(': 'l', + ')': 'r', + '.': 'p', +} + +string_to_operator = {v: k for k, v in operator_to_string.items()} + + +def map_operators_to_strings(expr_str: str): + expr_str = expr_str.replace(' ', '') + for op, string in operator_to_string.items(): + expr_str = expr_str.replace(op, string) + return '_' + expr_str + + +def map_strings_to_operators(expr_str: str): + for op, string in string_to_operator.items(): + expr_str = expr_str.replace(op, string) + return expr_str[1:] + + +class TracedGraph: + def 
__init__(self): + self.graph = torch.fx.Graph() + self.last_node: Optional[torch.fx.Node] = None + self.sym_nodes: Dict[str, torch.fx.Node] = {} + + def __str__(self): + return str(self.graph) + + def get_placeholder_names(self): + placeholder_names = set() + for node in self.graph.nodes: + if node.op == 'placeholder' and node.name not in self.sym_nodes: + placeholder_names.add(node.name) + return placeholder_names + + __repr__ = __str__ + + +def create_fake_input(size, stride, device, dtype): + size = [V.graph.sizevars.shape_env.create_symintnode(s, hint=None) \ + if isinstance(s, Expr) and not isinstance(s, Integer) else s for s in size] + stride = [V.graph.sizevars.shape_env.create_symintnode(s, hint=None) \ + if isinstance(s, Expr) and not isinstance(s, Integer) else s for s in stride] + with V.graph.fake_mode: + fake_input = torch.empty_strided(size, stride, device=device, dtype=dtype) + return fake_input + + +def create_sym_inputs(traced_graph: TracedGraph, size: List[Expr]): + for s in size: + if isinstance(s, (List, Tuple)): + create_sym_inputs(traced_graph, s) + continue + if isinstance(s, Expr) and not isinstance(s, Integer): + s_name = str(s) + if not isinstance(s, Symbol): + s_name = map_operators_to_strings(s_name) + if s_name in traced_graph.sym_nodes: + continue + new_node = traced_graph.graph.placeholder(s_name) + new_node.meta['val'] = V.graph.sizevars.shape_env.create_symintnode(s, hint=None) + traced_graph.sym_nodes.update({s_name: new_node}) + + +def process_ir_constant(inp: ExpandView) -> Union[TracedGraph, int, float]: + skip = False + if isinstance(inp.data, IndexingConstant): + dtype = inp.data.dtype + inp = inp.data.index + # convert to original dtype. + if dtype in [torch.float32, torch.float16, torch.bfloat16]: + # sympy inputs + if isinstance(inp, Expr) and not isinstance(inp, sympy.core.numbers.Number): + traced_graph = TracedGraph() + create_sym_inputs(traced_graph, [inp]) + s_name = str(inp) + if not isinstance(inp, Symbol): + s_name = map_operators_to_strings(str(inp)) + traced_graph.last_node = traced_graph.sym_nodes[s_name] + inp = traced_graph + else: + inp = float(inp) + elif isinstance(inp.data, ir.Constant): + dtype = inp.data.dtype + inp = inp.data.value + else: + skip = True + return inp, skip + + +def fetch_graphs(inputs: Optional[List[TensorBox]]): + if isinstance(inputs, (TensorBox, ir.StorageBox, ir.View, sympy.Symbol, ir.Constant)): + inputs = [inputs] + input_graphs = [] + for inp in inputs: + if isinstance(inp, List): + input_graphs.append(fetch_graphs(inp)) + continue + if not isinstance(inp, ( + TensorBox, ir.StorageBox, ir.View, ir.ReinterpretView, ir.PermuteView, ir.SliceView, ir.ExpandView)): + input_graphs.append(inp) + continue + if isinstance(inp, ExpandView): + inp, skip = process_ir_constant(inp) + if not skip: + input_graphs.append(inp) + continue + name = inp.get_name() + traced_graph = inp.get_traced_graph() + if traced_graph is not None: + input_graphs.append(traced_graph) + continue + traced_graph = TracedGraph() + device = inp.get_device() + dtype = inp.get_dtype() + size = inp.get_size() + stride = inp.get_stride() + new_node = traced_graph.graph.placeholder(name) + fake_input = create_fake_input(size, stride, device, dtype) + new_node.meta['val'] = fake_input + traced_graph.last_node = new_node + input_graphs.append(traced_graph) + return input_graphs + + +def merge_traced_graphs(input_graphs: List[TracedGraph], origin_fn, node_name, **kwargs): + new_graph = TracedGraph() + exist_nodes: Dict[str, torch.fx.Node] = {} + + def 
merge_graph(input_graphs: List[TracedGraph]): + for input_graph in input_graphs: + if isinstance(input_graph, List): + merge_graph(input_graph) + continue + if not isinstance(input_graph, TracedGraph): + continue + for node in input_graph.graph.nodes: + if node.name in exist_nodes: + continue + new_node = new_graph.graph.node_copy(node, lambda n: exist_nodes[n.name]) + exist_nodes[node.name] = new_node + if node.name in input_graph.sym_nodes: + new_graph.sym_nodes.update({node.name: new_node}) + + def parse_args(input_graphs, exist_nodes): + args = [] + for input_graph in input_graphs: + if isinstance(input_graph, TracedGraph): + args.append(exist_nodes[input_graph.last_node.name]) + elif isinstance(input_graph, (List, Tuple)): + args.append(parse_args(input_graph, exist_nodes)) + else: + if isinstance(input_graph, Expr) and not isinstance(input_graph, Integer): + if not isinstance(input_graph, Symbol): + input_graph = map_operators_to_strings(str(input_graph)) + args.append(new_graph.sym_nodes[str(input_graph)]) + else: + args.append(input_graph) + return args + + num_args = len(input_graphs) + + for k, v in kwargs.items(): + if isinstance(v, Expr) and not isinstance(v, Integer): + traced_graph = TracedGraph() + create_sym_inputs(traced_graph, [v]) + s_name = str(v) + if not isinstance(v, Symbol): + s_name = map_operators_to_strings(str(v)) + traced_graph.last_node = traced_graph.sym_nodes[s_name] + kwargs[k] = traced_graph.sym_nodes[s_name] + input_graphs.append(traced_graph) + merge_graph(input_graphs) + input_graphs = input_graphs[:num_args] + # if inputs do not have any valid graphs, like full/iota + create_sym_inputs(new_graph, input_graphs) + args = parse_args(input_graphs, exist_nodes) + with new_graph.graph.inserting_after(new_graph.last_node): + new_node = new_graph.graph.call_function(origin_fn, args=tuple(args), kwargs=kwargs) + new_node.name = node_name + new_graph.last_node = new_node + return new_graph + + +def merge_fx_graphs(traced_graphs: List[TracedGraph]): + new_graph = TracedGraph() + exist_nodes: Dict[str, torch.fx.Node] = {} + last_nodes = [] + + def merge_graph(input_graphs: List[TracedGraph]): + for input_graph in input_graphs: + if isinstance(input_graph, List): + merge_graph(input_graph) + continue + if not isinstance(input_graph, TracedGraph): + continue + for node in input_graph.graph.nodes: + if node.name in exist_nodes: + continue + new_node = new_graph.graph.node_copy(node, lambda n: exist_nodes[n.name]) + exist_nodes[node.name] = new_node + last_nodes.append(exist_nodes[input_graph.last_node.name]) + + merge_graph(traced_graphs) + new_graph.last_node = last_nodes + return new_graph + + +def subtract_graph(graph1: TracedGraph, graph2: TracedGraph, node_name=None) -> Tuple[TracedGraph, torch.fx.Node]: + new_graph = TracedGraph() + last_node2 = graph2.last_node + graph1_node_names = {node.name for node in graph1.graph.nodes} + graph2_node_names = {node.name for node in graph2.graph.nodes} + placeholder = None + exist_nodes: Dict[str, torch.fx.Node] = {} + if node_name not in graph1_node_names: + placeholder = new_graph.graph.placeholder(last_node2.name if node_name is None else node_name) + exist_nodes[last_node2.name] = placeholder + for node in graph1.graph.nodes: + if node.name in graph2_node_names and node.name not in graph1.sym_nodes: + continue + new_node = new_graph.graph.node_copy(node, lambda n: exist_nodes[n.name]) + exist_nodes[node.name] = new_node + new_graph.last_node = exist_nodes[graph1.last_node.name] + new_graph.sym_nodes = graph1.sym_nodes + 
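+    # new_graph is graph1 with the nodes it shares with graph2 removed
+    # (symbolic-size placeholders are kept); references to graph2's last node
+    # are rewired to `placeholder`, which stays None when graph1 already
+    # defines a node named `node_name`.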
return new_graph, placeholder + + +def get_last_node(gm: torch.fx.GraphModule): + last_node = None + for node in gm.graph.nodes: + last_node = node + return last_node + + +def tensor_info(tensor): + if isinstance(tensor, (list, tuple)): + infos = ", ".join(tensor_info(t) for t in tensor) + return f"[{infos}]" + if not isinstance(tensor, torch.Tensor): + return str(tensor) + info = str(tensor) + info = info[:-1] + info += f", strides={tensor.stride()})" + return info + + +def create_fx_from_snodes_by_traced_graph(snodes: List[scheduler.SchedulerNode]): + fx_call_inputs = [] + try: + for snode in snodes: + snode.node.data.traced_graph.last_node.name = snode.node.get_name() + except Exception as e: + log.warning(f"Could not rebuild fx graph for {snodes}, reason: {e}") + return None, None, None, None + + if len(snodes) == 1: + traced_graph = snodes[0].node.data.traced_graph + else: + traced_graph = merge_fx_graphs([snode.node.data.traced_graph for snode in snodes]) + fx_inputs = [] + for node in traced_graph.graph.nodes: + if node.op == 'placeholder': + fx_call_inputs.append(node.target) + fx_inputs.append(node.meta['val']) + non_contiguous_indices = {} + non_contiguous_indices["inputs"] = [ + i + for i, inp in enumerate(fx_inputs) + if torch.is_tensor(inp) and not inp.is_contiguous() + ] + num_inputs = len(fx_call_inputs) + fx_call_outputs = [] + for snode in snodes: + if snode.has_aliasing_or_mutation(): + for buf in snode.get_outputs(): + if len(buf.get_mutations()): + fx_call_outputs.extend(buf.get_mutations()) + elif len(buf.get_aliases()): + fx_call_outputs.append(buf.get_name()) + elif snode.node.get_name() not in (V.graph.removed_buffers | V.graph.inplaced_to_remove): + fx_call_outputs.append(snode.node.get_name()) + num_outputs = len(fx_call_outputs) + outputs = traced_graph.last_node if isinstance(traced_graph.last_node, List) \ + else [traced_graph.last_node] + outputs = [ + output + for output in outputs + if output.name not in (V.graph.removed_buffers | V.graph.inplaced_to_remove) + ] + fx_call_args = fx_call_inputs + fx_call_outputs + traced_graph.graph.output(tuple(outputs)) + traced_graph.graph.lint() + orig_module = torch.nn.Module() + gm = torch.fx.GraphModule(orig_module, traced_graph.graph) + gm.recompile() + + def runnable_gm(*args): + return torch.fx.Interpreter(gm).run(*args) + + with V.graph.fake_mode: + gm = make_fx(runnable_gm)(*fx_inputs) + view_to_reshape(gm) + last_node = get_last_node(gm) + fx_output_nodes = last_node.args[0] + fx_outputs = [node.meta['val'] for node in fx_output_nodes] + non_contiguous_indices["outputs"] = [ + i + num_inputs + for i, call_output in enumerate(fx_call_outputs) + if not V.graph.try_get_buffer(call_output).layout.is_contiguous() + ] + fx_args = fx_inputs + fx_outputs + snodes_to_fx[str(snodes)] = f"{gm}\n inputs: {tensor_info(fx_inputs)}\n outputs: {tensor_info(fx_outputs)}\n" + + return gm, fx_call_args, fx_args, { + "num_inputs": num_inputs, + "num_outputs": num_outputs, + "non_contiguous_indices": non_contiguous_indices, + } + + +def create_compile_kwargs(final_kernel, fx_call_args, fx_args): + _, kernel_call_args, _, arg_types = final_kernel.args.python_argdefs() + for idx, call_arg in enumerate(fx_call_args): + if call_arg in final_kernel.args.inplace_buffers: + fx_call_args[idx] = final_kernel.args.inplace_buffers[call_arg].other_names[-1] + fx_arg_shapes = [fx_arg.shape for fx_arg in fx_args if isinstance(fx_arg, torch.Tensor)] + + if set(kernel_call_args) != set(fx_call_args): + return None + 
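+    # kernel_call_args and fx_call_args reference the same buffers, possibly in
+    # a different order; call_args_mapping below records, for each fx arg, its
+    # position in the kernel call so the generated repro code can reorder the
+    # dumped kernel args back into fx-graph order.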
final_kernel.add_numel_to_call_args(final_kernel.kernel_name, kernel_call_args, arg_types) + + index_map = {element: idx for idx, element in enumerate(kernel_call_args)} + call_args_mapping = [index_map[element] for element in fx_call_args] + + mismatch_indices_shapes = {} + + for i in range(len(fx_call_args)): + mismatch_indices_shapes[i] = fx_arg_shapes[i] + + return { + "call_args_mapping": call_args_mapping, + "mismatch_indices_shapes": mismatch_indices_shapes, + } + + +def generate_fx_graph_code(code, kernel_code, kernel_name, compile_kwargs): + code = textwrap.indent(code, ' ') + code_template = f""" +import os +import torch +from torch._inductor.compile_fx import clone_preserve_strides +from torch._dynamo.testing import rand_strided +from torch import device + +import math +import random +import os +import tempfile +from math import inf, nan +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +from torch._inductor.codegen.multi_kernel import MultiKernelCall +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch_npu._inductor import get_current_raw_stream as get_raw_stream +from torch_npu._inductor import config as npu_config + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool + +file_path = os.path.abspath(__file__) +dir_path = os.path.dirname(file_path) + + +class GraphModule(torch.nn.Module): + def __init__(self): + super().__init__() +{code} +model = GraphModule().npu() +call_args_mapping = {compile_kwargs['call_args_mapping']} +num_inputs = {compile_kwargs['num_inputs']} +num_outputs = {compile_kwargs['num_outputs']} +non_contiguous_indices = {compile_kwargs['non_contiguous_indices']} +mismatch_indices_shapes = {compile_kwargs['mismatch_indices_shapes']} + +def run(): + async_compile = AsyncCompile() + {kernel_name} = async_compile.triton('{kernel_name}', ''' +{kernel_code} + ''', device_str='npu') + + async_compile.wait(globals()) + del async_compile + + stream0 = get_raw_stream(0) + + + args = torch.load(os.path.join(dir_path, "data.pth")) + + call_inputs_indices = call_args_mapping[:num_inputs] + call_outputs_indices = call_args_mapping[num_inputs:] + + args = [arg.npu() if isinstance(arg, torch.Tensor) else arg for arg in args] + + fx_args = [] + for idx in call_args_mapping: + arg = args[idx] + if isinstance(arg, torch.Tensor): + fx_arg = clone_preserve_strides(arg).float() if arg.dtype == torch.bfloat16 else clone_preserve_strides(arg) + fx_args.append(fx_arg) + + fx_inputs = [fx_args[idx].contiguous() if idx in non_contiguous_indices['inputs'] else fx_args[idx] for idx in range(num_inputs)] + if len(mismatch_indices_shapes): + for ind, shape in mismatch_indices_shapes.items(): + if ind >= num_inputs: + break + fx_inputs[ind] = fx_inputs[ind].reshape(shape) + model_outputs = model.forward(*fx_inputs) + 
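+    # Stash the eager reference outputs into the output slots of fx_args so they can be
+    # compared against the kernel's results further down.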
for idx, (out1, out2) in enumerate(zip(model_outputs, fx_args[num_inputs:(num_inputs + num_outputs)])): + out1 = out1.reshape(out2.shape) + if idx in non_contiguous_indices['outputs']: + out2.copy_(out1) + else: + out2.data = out1.data + + {kernel_name}.run(*args, stream=stream0) + + for actual, expected in zip([args[i] for i in call_outputs_indices], fx_args[num_inputs:]): + if actual.dtype != expected.dtype: + expected = expected.to(actual.dtype) + acc_comp_tol = npu_config.acc_comp_tol.get(actual.dtype, npu_config.acc_comp_tol['default']) + rtol = acc_comp_tol['rtol'] + atol = acc_comp_tol['atol'] + try: + torch.testing.assert_close(actual, expected, rtol=rtol, atol=atol, equal_nan=False) + except Exception as e: + print(e) + +if __name__ == "__main__": + run() +""" + return code_template + + +def dump_fx_graph_code(code, dump_path, traced_graph_hash): + py_path = os.path.join(dump_path, traced_graph_hash + '.py') + with open(py_path, 'w') as f: + f.write(code) + + +def clone(x, *, memory_format=None): + # TODO(jansel): memory format + input_graphs = fetch_graphs(x) + node_name = f'clone_{next(node_id)}' + new_graph = merge_traced_graphs(input_graphs, aten.clone, node_name) + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=x.make_loader(), + ranges=list(x.get_size()), + traced_graph=new_graph, + node_name=node_name + ) + + +def _register_npu_inductor_fallbacks(): + gen_set = set() + _init_set(GENERATE_LIST, gen_set) + overload_op_set = set() + _init_set(LOWERING_OVERLOAD_OP, overload_op_set) + + # 把不在白名单的op fallback + for op in lowering.lowerings: + if op not in decompositions and op not in gen_set: + if isinstance(op, torch._ops.OpOverloadPacket) or \ + isinstance(op, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)): + flag = False + for gens in GENERATE_LIST2: + if str(op).find(gens) != -1: + flag = True + if flag: + continue + else: + lowering.make_fallback(op) + FALLBACK_LIST.append(op) + + # 把需要overload的op在lowering里删除 + for op in overload_op_set: + if op in lowering.lowerings: + del lowering.lowerings[op] + + def transform_args( + args: List[Any], + kwargs: Dict[str, Any], + broadcast: bool, + type_promotion_kind: Optional[ELEMENTWISE_TYPE_PROMOTION_KIND], + convert_input_to_bool: bool, + ) -> Tuple[List[Any], Dict[str, Any]]: + args_indices = [i for i, x in enumerate(args) if isinstance(x, TensorBox)] + kwargs_indices = [k for k, v in kwargs.items() if isinstance(v, TensorBox)] + # check that there's something to transform + if not args_indices and not kwargs_indices: + return args, kwargs + + if type_promotion_kind or convert_input_to_bool: + if convert_input_to_bool: + dtype = torch.bool + else: + # this is a crude approximation for promoting args + promoting_args = [ + a + for a in args + if isinstance(a, (Number, sympy.Basic)) or hasattr(a, "dtype") + ] + # only consider tensor kwargs for promotion, for now + promoting_args.extend(a for a in kwargs.values() if hasattr(a, "dtype")) + dtype = lowering.get_promoted_dtype( + *promoting_args, type_promotion_kind=type_promotion_kind # type: ignore[arg-type] + ) + + device = ( + args[args_indices[0]] if args_indices else kwargs[kwargs_indices[0]] + ).get_device() + + # sometimes args are an immutable list so we can't mutate them + def promote(arg): + if isinstance(arg, TensorBox): + return to_dtype(arg, dtype) + elif isinstance(arg, ir.Constant): + return ir.Constant(value=arg.value, dtype=dtype, device=device) + else: + return arg + + args = [promote(a) for a in args] + kwargs = {k: 
promote(v) for k, v in kwargs.items()} + + if broadcast: + broadcasted = broadcast_tensors( + *list( + itertools.chain( + (args[i] for i in args_indices), + (kwargs[k] for k in kwargs_indices), + ) + ) + ) + size = list(broadcasted[0].get_size()) + + for i, x in zip(args_indices, broadcasted[: len(args_indices)]): + args[i] = x + for k, x in zip(kwargs_indices, broadcasted[len(args_indices):]): + kwargs[k] = x + + for i in range(len(args)): + if isinstance(args[i], ir.Constant): + args[i] = ExpandView.create(args[i], size) + for k in kwargs: + if isinstance(kwargs[k], ir.Constant): + kwargs[k] = ExpandView.create(kwargs[k], size) + + return args, kwargs + + def _register_lowering( + aten_fn, decomp_fn, broadcast, type_promotion_kind, convert_input_to_bool + ): + + """ + Add a lowering to lowerings dict + + Arguments: + aten_fn: torch.ops.aten.* fn we are lowering + decomp_fn: alternate implementation on our IR + broadcast: True to apply broadcasting to tensor inputs + type_promotion_kind: kind of type promotion applied to tensor inputs, `None` means no type promotion + convert_input_to_bool: some logical ops require inputs are converted to bool + """ + + @functools.wraps(decomp_fn) + def wrapped(*args, **kwargs): + args: List[Any] = list(args) + kwargs: Dict[str, Any] = dict(kwargs) + unpacked = False + # maybe we need to use pytrees here + if len(args) == 1 and isinstance(args[0], (list, tuple)): + unpacked = True + args = list(args[0]) + + if not all( + (fn in lowering.fallbacks or lowering.in_namespace(fn, "_c10d_functional")) for fn in aten_fn + ): + # explicitly assert for "out=" ops for better error messages + if any(x == "out" for x in kwargs.keys()): + raise RuntimeError("assert out= ops aren't yet supported") + + args, kwargs = transform_args( + args, kwargs, broadcast, type_promotion_kind, convert_input_to_bool + ) + + if unpacked: + args = [args] + + out = decomp_fn(*args, **kwargs) + validate_ir(out) + + return out + + aten_fn = lowering.get_overloads(aten_fn) + + lowering.lowerings.update(dict.fromkeys(aten_fn, wrapped)) + return wrapped + + def register_lowering( + aten_fn, + broadcast=False, + type_promotion_kind=lowering.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + convert_input_to_bool=False, + ): + + """ + Shim to support decorator syntax. 
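+        Used as a decorator factory. Illustrative sketch only (abs is actually
+        registered through register_pointwise further below):
+
+            @register_lowering(aten.abs)
+            def abs_lowering(x):
+                fn = ops_wrapper("abs")
+                register_fn_to_aten_fn(fn, aten.abs)
+                return make_pointwise(fn)(x)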
+ """ + return functools.partial( + _register_lowering, + aten_fn, + broadcast=broadcast, + type_promotion_kind=type_promotion_kind, + convert_input_to_bool=convert_input_to_bool, + ) + + def _make_reduction_inner(x, *, axis, keepdims, dtype, override_return_dtype): + if dtype is not None: + x = to_dtype(x, dtype) + size = x.get_size() + axis = set(lowering._validate_reduction_axis(x, axis)) + + kept_sizes = [] + kept_idx = [] + reduced_sizes = [] + reduced_idx = [] + for i in range(len(size)): + if i in axis: + reduced_idx.append(i) + reduced_sizes.append(size[i]) + else: + kept_idx.append(i) + kept_sizes.append(size[i]) + + def loader(index, reduction_index): + if len(reduction_index) != len(reduced_idx): + raise RuntimeError("assert reduction index length mismatch") + if keepdims: + if len(index) != len(size): + raise RuntimeError("assert index size length mismatch") + index = [index[i] for i in kept_idx] + if len(index) != len(kept_idx): + raise RuntimeError("assert index kept_idx length mismatch") + new_index = [None] * (len(index) + len(reduction_index)) + for idx, var in itertools.chain( + zip(kept_idx, index), zip(reduced_idx, reduction_index) + ): + new_index[idx] = var + return inner_loader(new_index) + + if keepdims: + new_size = list(size) + for i in reduced_idx: + new_size[i] = sympy.S.One + else: + new_size = kept_sizes + + inner_loader = x.make_loader() + return dict( + device=x.get_device(), + dst_dtype=override_return_dtype or x.get_dtype(), + src_dtype=x.get_dtype(), + inner_fn=loader, + ranges=new_size, + reduction_ranges=reduced_sizes, + ) + + def make_reduction(reduction_type: str, override_return_dtype=None): + def inner(x, axis=None, keepdims=False, *, dtype=None): + kwargs = _make_reduction_inner( + x, + axis=axis, + keepdims=keepdims, + dtype=dtype, + override_return_dtype=override_return_dtype, + ) + node_name = f'reduction_{next(node_id)}' + input_graphs = fetch_graphs([x, axis if axis is not None else list(range(len(x.get_size())))]) + new_graph = merge_traced_graphs(input_graphs, reduction_type_to_aten_fn[reduction_type], + node_name, keepdim=keepdims) + + result = Reduction.create(reduction_type=reduction_type, + input_node=x, + node_name=node_name, + traced_graph=new_graph, + **kwargs) + if isinstance( + result.data.data, Reduction + ): + # Only realize if reduction isn't unrolled + size = x.get_size() + axis = set(lowering._validate_reduction_axis(x, axis)) + kept_idx = [] + reduced_idx = [] + for i in range(len(size)): + if i in axis: + reduced_idx.append(i) + else: + kept_idx.append(i) + + object.__setattr__(result.data.data, "kept_idx", kept_idx) + object.__setattr__(result.data.data, "reduced_idx", reduced_idx) + + result.realize() + return result + + return inner + + lowering.make_reduction = make_reduction + + def to_dtype(x: TensorBox, dtype: torch.dtype, copy=False): + src_dtype = x.get_dtype() + if src_dtype == dtype: + return clone(x) if copy else x + + def _to_dtype(x): + return ops.to_dtype(x, dtype, src_dtype=src_dtype) + + register_fn_to_aten_fn(_to_dtype, aten.to.dtype) + return make_pointwise(_to_dtype, override_return_dtype=dtype, dtype=dtype)(x) + + @register_lowering(prims.convert_element_type, type_promotion_kind=None) + def _convert_element_type(x: TensorBox, dtype: torch.dtype): + if dtype.is_complex or x.get_dtype().is_complex: + if x.get_size(): + # Decompose since aa aten fallback is more friendly for c++ codegen. + # This decomposition doesn't work for empty tensor, which needs more investigation. 
+ dst = empty_like(x, dtype=dtype) + ir.InplaceCopyFallback.create(dst, x) + return dst + else: + return lowering.fallback_handler( + prims.convert_element_type.default, add_to_fallback_set=False + )(x, dtype) + return to_dtype(x, dtype, copy=True) + + def register_pointwise( + aten_fn, + name=None, + broadcast=True, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + convert_input_to_bool=False, + override_return_dtype=None, + override_fn_when_input_bool=None, + allow_alpha=False, + use_libdevice_for_f64=False, + triton_fallback=None, + ): + """A pointwise function that maps ops.{name} to inputs""" + name = name or aten_fn.__name__ + fn = ops_wrapper(name) + if use_libdevice_for_f64: + fn_libdevice = ops_wrapper("libdevice_" + name) + lowering.register_op_dtype_propagation_rules( + "libdevice_" + name, type_promotion_kind, override_return_dtype + ) + + lowering.register_op_dtype_propagation_rules( + name, type_promotion_kind, override_return_dtype + ) + + if override_fn_when_input_bool is not None: + override_fn_when_input_bool = ops_wrapper(override_fn_when_input_bool) + + fn = register_fn_to_aten_fn(fn, aten_fn) + + fn = make_pointwise( + fn, + override_return_dtype=override_return_dtype, + override_fn_when_input_bool=override_fn_when_input_bool, + override_fn_when_gpu_float64=fn_libdevice if use_libdevice_for_f64 else None, + # type: ignore[possibly-undefined] + allow_alpha=allow_alpha, + triton_fallback=triton_fallback, + ) + fn = register_lowering( + aten_fn, + broadcast=broadcast, + type_promotion_kind=type_promotion_kind, + convert_input_to_bool=convert_input_to_bool, + )(fn) + + if hasattr(prims, name): + register_lowering( + getattr(prims, name), + type_promotion_kind=None, + convert_input_to_bool=convert_input_to_bool, + )(fn) + return fn + + def make_pointwise( + fn, + override_return_dtype=None, + override_device=None, + override_fn_when_input_bool=None, + override_fn_when_gpu_float64=None, + allow_alpha=False, + triton_fallback=None, + **kwargs + ): + def inner(*inputs: TensorBox, alpha=None): + if triton_fallback is not None and any( + isinstance(inp, IRNode) and is_triton(inp) for inp in inputs + ): + # not implemented + if allow_alpha: + raise RuntimeError("assert allow_alpha is not allowed") + return triton_fallback(*inputs) + + inputs = lowering.promote_constants(inputs, override_return_dtype) + if allow_alpha: + if alpha is not None and alpha != 1: + inputs = list(inputs) + inputs[-1] = mul(inputs[-1], alpha) + else: + if alpha is not None: + raise RuntimeError("assert alpha is not None") + loaders = [x.make_loader() for x in inputs] + ranges = inputs[0].get_size() + dtype = override_return_dtype or inputs[0].get_dtype() + is_gpu_device = lowering.is_gpu(decode_device(inputs[0].get_device()).type) + + for other in inputs[1:]: + if not (isinstance(other, ir.BaseConstant) or len(ranges) == len(other.get_size())): + raise RuntimeError(f"assert ndim mismatch {fn} {ranges} {other.get_size()}") + + # in tracing, we will annotate pointwise nodes that correspond to the output of + # a pointwise node that would have been run in eager. intermediary pointwise nodes + # during decompositions are not annotated. 
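+            # When the flag below is set (via node metadata), every load and the final
+            # result are cast down to fp16/bf16 and back up so rounding matches eager mode.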
+ emulate_precision_casts = ( + V.graph is not None + and getattr(V.graph, "current_node", None) is not None + and V.graph.current_node.meta is not None + and V.graph.current_node.meta.get("low_precision_pointwise_barrier", False) + and dtype in (torch.bfloat16, torch.float16) + ) + + def inner_fn(index): + if len(index) != len(ranges): + raise RuntimeError(f"assert wrong ndim {index} {ranges}") + if dtype == torch.bool and override_fn_when_input_bool is not None: + return override_fn_when_input_bool(*[load(index) for load in loaders]) + elif ( + override_fn_when_gpu_float64 + and is_gpu_device + and dtype == torch.float64 + ): + return override_fn_when_gpu_float64(*[load(index) for load in loaders]) + else: + inputs_loaded = [] + for load in loaders: + out = load(index) + if emulate_precision_casts: + downcast = ops.to_dtype(out, dtype, use_compute_types=False) + out = ops.to_dtype(downcast, dtype) + inputs_loaded.append(out) + + out = fn(*inputs_loaded) + if emulate_precision_casts: + # fp16/bf16 kernels are computed in fp32. Casting down to fp16/bf16 here, + # then upcasting again, to emulate casts that eager would do. + downcast = ops.to_dtype(out, dtype, use_compute_types=False) + return ops.to_dtype(downcast, dtype) + return out + + if not override_device: + device = None + for i in inputs: + if lowering.is_gpu(i.get_device().type): + device = i.get_device() + break + if not device: + device = inputs[0].get_device() + + device = override_device or device + + input_graphs = fetch_graphs(inputs) + node_name = f'pointwise_{next(node_id)}' + origin_fn = fn_to_aten_fn[fn] + new_graph = merge_traced_graphs(input_graphs, origin_fn, node_name, **kwargs) + + return Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=ranges, + node_name=node_name, + traced_graph=new_graph, + ) + + return inner + + @register_lowering(aten.where, broadcast=False, type_promotion_kind=None) + def where(cond, a, b): + def fn(*args): + return ops.where(*args) + + if isinstance(a, (float, int)): + a = lowering.constant_like(a)(b) + if isinstance(b, (float, int)): + b = lowering.constant_like(b)(a) + + args = [cond, a, b] + dtype = lowering.get_promoted_dtype( + args[1], args[2], type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ) + indices = [i for i, x in enumerate(args) if isinstance(x, TensorBox)] + for i, x in zip(indices, broadcast_tensors(*[args[i] for i in indices])): + args[i] = x + for i in range(len(args)): + if isinstance(args[i], ir.Constant): + args[i] = ExpandView.create(args[i], list(args[indices[0]].get_size())) + register_fn_to_aten_fn(fn, aten.where) + return make_pointwise(fn, override_return_dtype=dtype)( + args[0], to_dtype(args[1], dtype), to_dtype(args[2], dtype) + ) + + @register_lowering(aten.broadcast_tensors, broadcast=False, type_promotion_kind=None) + def broadcast_tensors(*inputs): + if len(inputs) == 1 and isinstance(inputs[0], (list, tuple)): + return broadcast_tensors(*inputs[0]) + target: List[sympy.Expr] = functools.reduce( + lowering.broadcast_symbolic_shapes, [x.get_size() for x in inputs], [] + ) + outputs = [] + for x in inputs: + sizes = x.get_size() + if len(sizes) != len(target) or any( + ( + ( + V.graph.sizevars.shape_env.evaluate_expr( + sympy.Eq(a, 1), size_oblivious=True + ) + and not V.graph.sizevars.shape_env.evaluate_expr( + sympy.Eq(b, 1), size_oblivious=True + ) + ) + or ( + not V.graph.sizevars.shape_env.evaluate_expr( + sympy.Eq(a, 1), size_oblivious=True + ) + and V.graph.sizevars.shape_env.evaluate_expr( + sympy.Eq(b, 1), 
size_oblivious=True + ) + ) + ) + for a, b in zip(sizes, target) + ): + x = expand(x, target) + outputs.append(x) + return outputs + + @register_lowering(aten.squeeze, type_promotion_kind=None) + def squeeze(x, dim=None): + if not isinstance(x, TensorBox): + raise RuntimeError("assert x should be instance of TensorBox") + + if dim is None: + return TensorBox(SqueezeView.create(x.data)) + + dim = ( + V.graph.sizevars.evaluate_static_shape(dim) + if isinstance(dim, (int, sympy.Expr)) + else tuple(V.graph.sizevars.evaluate_static_shape(d) for d in dim) + ) + dim = canonicalize_dims(len(x.get_size()), dim) # type: ignore[call-overload] + dims = set((dim,) if not isinstance(dim, tuple) else dim) + + new_shape = [] + for d, s in enumerate(x.get_size()): + if not ( + d in dims + and V.graph.sizevars.evaluate_expr(sympy.Eq(s, 1, size_oblivious=True)) + ): + new_shape.append(s) + + # squeeze does nothing if the size isn't 1 + return view(x, new_shape) if new_shape != x.get_size() else x + + @register_lowering([aten.squeeze_]) + def squeeze_(x, dim=None): + val = squeeze(x, dim) + if not isinstance(x, TensorBox): + raise RuntimeError("assert x should be instance of TensorBox") + if not isinstance(val, TensorBox): + raise RuntimeError("assert val should be instance of TensorBox") + x.data = val.data + return x + + @register_lowering(aten.isinf) + def isinf(x): + if lowering.is_integer_type(x): + return full_like(x, False, dtype=torch.bool) + fn = ops_wrapper("isinf") + register_fn_to_aten_fn(fn, aten.isinf) + return make_pointwise(fn, override_return_dtype=torch.bool)(x) + + @register_lowering(aten.isnan) + def isnan(x): + if lowering.is_integer_type(x): + return full_like(x, False, dtype=torch.bool) + fn = ops_wrapper("isnan") + register_fn_to_aten_fn(fn, aten.isnan) + return make_pointwise(fn, override_return_dtype=torch.bool)(x) + + @register_lowering(aten.ceil) + def ceil(x): + if lowering.is_integer_type(x): + return clone(x) + fn = ops_wrapper("ceil") + register_fn_to_aten_fn(fn, aten.ceil) + return make_pointwise(fn)(x) + + @register_lowering(aten.floor) + def floor(x): + if lowering.is_integer_type(x): + return clone(x) + fn = ops_wrapper("floor") + register_fn_to_aten_fn(fn, aten.floor) + return make_pointwise(fn)(x) + + @register_lowering(aten.round.default) + def round(x): + if lowering.is_integer_type(x): + return clone(x) + else: + fn = ops_wrapper("round") + register_fn_to_aten_fn(fn, aten.round) + return make_pointwise(fn)(x) + + @register_lowering(aten.trunc) + def trunc(x): + if lowering.is_integer_type(x): + return clone(x) + fn = ops_wrapper("trunc") + register_fn_to_aten_fn(fn, aten.trunc) + return make_pointwise(fn)(x) + + @register_lowering(aten.expand, type_promotion_kind=None) + def expand(x, sizes): + from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols + + (x,) = lowering.promote_constants([x]) + if isinstance(x, ir.BaseConstant): + return ExpandView.create(x, tuple(sizes)) + if not isinstance(x, TensorBox): + raise RuntimeError("assert x should be instance of TensorBox") + if not isinstance(sizes, (list, tuple)): + raise RuntimeError("assert x should be instance of (list, tuple)") + if tuple(x.get_size()) == tuple(sizes): + return x + + if not free_unbacked_symbols(x.get_size()): + x_size_product = V.graph.sizevars.size_hint(sympy_product(x.get_size())) + # It would be better to realize the input if any of its sizes + # are unbacked, because typically the size will be non-zero. 
However, + # this cannot be done directly as below as we'll choke on the size_hint + # here + if x_size_product > 0 and not free_unbacked_symbols(sizes): + # maybe realize input before broadcasting it + x.mark_reuse( + V.graph.sizevars.size_hint(sympy_product(sizes)) // x_size_product + ) + input_graphs = fetch_graphs([x.data, tuple(sizes)]) + node_name = f'expand_{next(node_id)}' + new_graph = merge_traced_graphs(input_graphs, aten.expand, node_name) + return TensorBox(ExpandView.create(x.data, tuple(sizes), traced_graph=new_graph, node_name=node_name)) + + @register_lowering(aten.expand_as, type_promotion_kind=None) + def expand_as(x, y): + return expand(x, y.get_size()) + + @register_lowering(aten.repeat) + def repeat(x, repeats): + input_graphs = fetch_graphs([x, repeats]) + node_name = f'repeat_{next(node_id)}' + new_graph = merge_traced_graphs(input_graphs, aten.repeat, node_name) + old_size = list(x.get_size()) + if len(repeats) > len(old_size): + old_size = [sympy.S.One] * (len(repeats) - len(old_size)) + old_size + x = view(x, list(old_size)) + if len(repeats) != len(x.get_size()): + raise RuntimeError("assert repeat should have same size as x.size") + + new_size = list(x.get_size()) + + zero_tensor = False + for i in range(len(repeats)): + if repeats[i] == 0: + zero_tensor = True + new_size[i] = new_size[i] * repeats[i] + + if zero_tensor: + return empty(new_size, dtype=x.get_dtype(), device=x.get_device()) + if all((a == 1 or b == 1) for a, b in zip(repeats, old_size)): + return clone(expand(x, new_size)) + + x_loader: Callable[[Any], Any] + + def inner_fn(index): + if len(index) != len(repeats): + raise RuntimeError("assert repeat should have same length as repeats") + index = list(index) + for i in range(len(repeats)): + if repeats[i] != 1: + if old_size[i] == 1: + index[i] = sympy.S.Zero + else: + index[i] = ModularIndexing(index[i], 1, old_size[i]) + return x_loader(index) + + old_size_product = V.graph.sizevars.size_hint(sympy_product(old_size)) + if old_size_product > 0: + # maybe realize the input + x.mark_reuse( + V.graph.sizevars.size_hint(sympy_product(new_size)) // old_size_product + ) + + x_loader = x.make_loader() + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=inner_fn, + ranges=list(new_size), + traced_graph=new_graph, + node_name=node_name + ) + + @register_lowering(aten._unsafe_view, type_promotion_kind=None) + @register_lowering(aten.view, type_promotion_kind=None) + @register_lowering(aten.reshape, type_promotion_kind=None) + def view(x, sizes): + if not isinstance(x, TensorBox): + raise RuntimeError("assert x should be instance of TensorBox") + if not isinstance(sizes, (list, tuple)): + raise RuntimeError("assert sizes should be instance of (list, tuple)") + input_graphs = fetch_graphs([x.data, sizes]) + node_name = f'view_{next(node_id)}' + new_graph = merge_traced_graphs(input_graphs, aten.reshape, node_name) + return TensorBox(View.create(x.data, sizes, traced_graph=new_graph, node_name=node_name)) + + @register_lowering(aten.permute, type_promotion_kind=None) + def permute(x, dims): + if not isinstance(x, TensorBox): + raise RuntimeError("assert x should be instance of TensorBox") + if not isinstance(dims, (list, tuple)): + raise RuntimeError("assert dims should be instance of (list, tuple)") + input_graphs = fetch_graphs([x.data, dims]) + node_name = f'permute_{next(node_id)}' + new_graph = merge_traced_graphs(input_graphs, aten.permute, node_name) + return TensorBox(PermuteView.create(x.data, tuple(dims), 
traced_graph=new_graph, node_name=node_name)) + + @register_lowering(aten.slice, type_promotion_kind=None) + def slice_(x, dim=0, start=0, end=2 ** 63, step=1, clamp=True): + if not isinstance(x, TensorBox): + raise RuntimeError("assert x should be instance of TensorBox") + dim = _validate_dim(x, dim, 0) + input_graphs = fetch_graphs([x.data]) + node_name = f'slice_{next(node_id)}' + new_graph = merge_traced_graphs(input_graphs, aten.slice, node_name, dim=dim, start=start, end=end, step=step) + + return TensorBox( + ir.SliceView.create(x.data, dim, start, end, step, traced_graph=new_graph, node_name=node_name)) + + @register_lowering(aten.select, type_promotion_kind=None) + def select(x, dim, idx): + idx = View.handle_negative_index(idx, x.get_size()[dim]) + return squeeze(slice_(x, dim, idx, idx + 1), dim) + + @register_lowering(aten.split, type_promotion_kind=None) + def split(x, sizes, dim=0): + dim = _validate_dim(x, dim, 0) + sizes_ = sizes + + # If sizes is an integer (or a SymInt), we turn it into a list of sizes + # by computing what the actual size of each chunk should be. + if not isinstance(sizes, (list, tuple)): + x_size = x.get_size()[dim] + chunks = V.graph.sizevars.evaluate_static_shape( + FloorDiv(x_size + sizes - 1, sizes) + ) + sizes_ = [sizes] * chunks + # The last chunk might have a smaller size than the rest. + sizes_[-1] = x_size - (chunks - 1) * sizes + + # From this point, we assume that the sum of the sizes of all chunks + # equals the size of the base tensor. + result = [] + start = 0 + for size in sizes_: + end = start + size + # No need for clamping here, since we compute the exact + # start and end values. + result.append(slice_(x, dim, start, end, clamp=False)) + start = end + return result + + @register_lowering(aten.split_with_sizes, type_promotion_kind=None) + def split_with_sizes(x, sizes, dim=0): + return split(x, sizes, dim) + + @register_lowering(aten.unbind, type_promotion_kind=None) + def unbind(x, dim=0): + dim = _validate_dim(x, dim, 0) + x_size = V.graph.sizevars.evaluate_static_shape(x.get_size()[dim]) + result = [select(x, dim, i) for i in range(x_size)] + return result + + @register_lowering(aten.unsqueeze, type_promotion_kind=None) + def unsqueeze(x, dim): + dim = _validate_dim(x, dim, 1) + new_shape = list(x.get_size()) + new_shape.insert(dim, sympy.S.One) + return view(x, new_shape) + + @register_lowering(aten.unsqueeze_, type_promotion_kind=None) + def unsqueeze_(x, dim): + val = unsqueeze(x, dim) + if not isinstance(x, TensorBox): + raise RuntimeError("assert x should be instance of TensorBox") + if not isinstance(val, TensorBox): + raise RuntimeError("assert val should be instance of TensorBox") + x.data = val.data + return x + + def _validate_dim(x, dim, offset=0): + dim = V.graph.sizevars.shape_env.evaluate_expr(sympy.sympify(dim)) + ndim = len(x.get_size()) + if dim < 0: + dim += ndim + offset + if not (0 <= dim < ndim + offset): + raise RuntimeError(f"assert dim {dim} is out of bounds. 
Expected: 0 <= dim < {ndim + offset}") + return dim + + @register_lowering(aten.copy, type_promotion_kind=None) + def copy(self, src, non_blocking=False): + x = src + if self.get_device() != src.get_device(): + x = lowering.to_device(x, self.get_device()) + if self.get_dtype() != src.get_dtype(): + x = to_dtype(x, self.get_dtype()) + + if self.get_size() != src.get_size(): + out = expand(x, self.get_size()) + return clone(out) + return clone(x) + + @register_lowering(prims.iota) + def iota( + length, + *, + start, + step, + dtype, + device, + requires_grad, + ): + def fn(index): + return ops.index_expr(step * index[0] + start, dtype=dtype) + + node_name = f'iota_{next(node_id)}' + new_graph = merge_traced_graphs([length], prims.iota, node_name, \ + start=start, step=step, \ + dtype=dtype, device=device, \ + requires_grad=requires_grad) + return Pointwise.create( + device=decode_device(device), + dtype=dtype, + inner_fn=fn, + ranges=[length], + traced_graph=new_graph, + node_name=node_name + ) + + @register_lowering(aten.select_scatter, type_promotion_kind=None) + def select_scatter(x, src, dim: int, index: int): + if x.get_dtype() != src.get_dtype(): + raise RuntimeError(f"assert Expected dtype {src.get_dtype()}, but got {x.get_dtype()}") + input_graphs = fetch_graphs([x, src, dim, index]) + node_name = f'select_scatter_{next(node_id)}' + new_graph = merge_traced_graphs(input_graphs, aten.select_scatter, node_name) + x_loader = x.make_loader() + dim = _validate_dim(x, dim, 0) + if V.graph.sizevars.evaluate_expr(sympy.Lt(index, 0)): + index = index + x.get_size()[dim] + V.graph.sizevars.guard_leq(0, index) # type: ignore[arg-type] + V.graph.sizevars.guard_lt(index, x.get_size()[dim]) # type: ignore[arg-type] + src = expand(unsqueeze(src, dim), x.get_size()) + src_loader = src.make_loader() + + def inner_fn(idx): + return ops.where( + ops.eq( + ops.index_expr(idx[dim], torch.int32), + ops.index_expr(index, torch.int32), + ), + src_loader(idx), + x_loader(idx), + ) + + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=inner_fn, + ranges=list(x.get_size()), + traced_graph=new_graph, + node_name=node_name + ) + + @register_lowering(aten.slice_scatter, type_promotion_kind=None) + def slice_scatter(x, src, dim=0, start=None, end=None, step=1): + if x.get_dtype() != src.get_dtype(): + raise RuntimeError(f"assert Expected dtype {src.get_dtype()}, but got {x.get_dtype()}") + input_graphs = fetch_graphs([x, src]) + node_name = f'slice_scatter_{next(node_id)}' + new_graph = merge_traced_graphs(input_graphs, aten.slice_scatter, node_name, \ + dim=dim, + start=start, + end=end, + step=step) + x_loader = x.make_loader() + dim = _validate_dim(x, dim, 0) + dim_size = x.get_size()[dim] + + start, end = ir.SliceView.normalize_start_end(x, dim, start, end) + + src_size = list(x.get_size()) + src_size[dim] = FloorDiv(end - start + (step - 1), step) + src = expand(src, src_size) + src_loader = src.make_loader() + + def inner_fn(idx): + if start == 0 and end == dim_size and step == 1: + # selecting every element is the same as just src.clone() + return src_loader(idx) + + idx_dim = ops.index_expr(idx[dim], torch.int64) + src_idx = list(idx) + src_idx[dim] = FloorDiv(idx[dim] - start, step) + + mask = [] + if start != 0: + mask.append( + ops.ge( + idx_dim, + ops.index_expr(sympy.expand(start), torch.int64), + ) + ) + if end != dim_size: + mask.append( + ops.lt( + idx_dim, + ops.index_expr(sympy.expand(end), torch.int64), + ) + ) + if step != 1: + mask.append( + ops.eq( + 
ops.index_expr( + ModularIndexing(idx[dim] - start, 1, step), torch.int64 + ), + ops.constant(0, torch.int64), + ) + ) + if not mask: + raise RuntimeError("assert mask cannot be empty") + mask = functools.reduce(ops.and_, mask) + src_val = ops.masked( + mask, + lambda: src_loader(src_idx), + 0 if lowering.is_integer_type(x) else 0.0, + ) + return ops.where( + mask, + src_val, + x_loader(idx), + ) + + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=inner_fn, + ranges=list(x.get_size()), + traced_graph=new_graph, + node_name=node_name + ) + + @register_lowering([torch.tensor, aten.scalar_tensor]) + def tensor(data, *, dtype=None, device=None, layout=None, pin_memory=False): + lowering.assert_nyi(layout in (None, torch.strided), f"layout={layout}") + lowering.assert_nyi(not pin_memory, "pin_memory") + input_graphs = fetch_graphs([data]) + node_name = f'tensor_{next(node_id)}' + new_graph = merge_traced_graphs(input_graphs, aten.scalar_tensor, node_name, \ + dtype=dtype, + device='npu', + layout=layout, + pin_memory=False) + if isinstance(lowering._unwrap(data), int): + dtype = dtype or torch.int64 + else: + dtype = dtype or torch.get_default_dtype() + + ranges: List[sympy.Expr] = [] + + if isinstance(data, sympy.Basic): + + def inner_fn(index): + return ops.index_expr(data, dtype) + + elif isinstance(data, (float, int)): + + def inner_fn(index): + return ops.constant(data, dtype) + + elif len(data) == 0 or isinstance(data[0], (float, int)) and len(data) <= 8: + # inline small tensors + ranges.append(sympy.Integer(len(data))) + + def inner_fn(index): + def binary_search(start, end): + if start >= end: + raise RuntimeError(f"assert start ({start}) must be less than end ({end})") + if end - start == 1: + return ops.constant(data[start], dtype) + mid = (end - start) // 2 + start + return ops.where( + ops.lt( + ops.index_expr(index[0], torch.int64), + ops.constant(mid, torch.int64), + ), + binary_search(start, mid), + binary_search(mid, end), + ) + + if len(data) == 0: + return ops.constant(0, dtype) + return binary_search(0, len(data)) + + else: + return V.graph.add_tensor_constant( + torch.tensor(data, dtype=dtype, device=device) + ) + + return Pointwise.create( + device=decode_device(device), + dtype=dtype, + inner_fn=inner_fn, + ranges=ranges, + traced_graph=new_graph, + node_name=node_name + ) + + def tensor_constructor(fill_value): + # torch.zeros, torch.ones, etc + def inner( + *size, + names=None, + dtype=None, + device=None, + layout=None, + pin_memory=False, + memory_format=None, + ): + lowering.assert_nyi(names is None, "named tensors") + lowering.assert_nyi(layout in (None, torch.strided), f"layout={layout}") + lowering.assert_nyi(not pin_memory, "pin_memory") + device = decode_device(device) + dtype = dtype or torch.get_default_dtype() + if len(size) == 1 and isinstance(size[0], (list, tuple, torch.Size)): + size = tuple(size[0]) + # See pytorch issues 118102 + # All sizes at lowering time should be sympy.Symbol, not SymInt! 
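+            # The check below enforces that: any torch.SymInt that slipped through aborts
+            # the lowering with an explicit error.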
+ for s in size: + if isinstance(s, torch.SymInt): + raise RuntimeError("assert s must not be of type torch.SymInt") + size = [sympy.expand(s) for s in size] + return _full(fill_value, device, dtype, size) + + return inner + + def _full(fill_value, device, dtype, size): + value = fill_value + if not isinstance(fill_value, (int, float)) and hasattr(value, "value"): + value = value.value + + if isinstance(value, (int, float)): + + def inner_fn(index): + return ops.constant(value, dtype) + + elif isinstance(value, sympy.Basic): + + def inner_fn(index): + return ops.index_expr(value, dtype) + + else: + if len(value.get_size()) != 0: + raise RuntimeError("assert value should be equal to 0") + value_loader = value.make_loader() + + def inner_fn(index): + return value_loader([]) + + node_name = f'full_{next(node_id)}' + new_graph = merge_traced_graphs([size, fill_value], aten.full.default, node_name, \ + device='npu', dtype=dtype, layout=torch.strided, pin_memory=False) + + return Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=list(size), + traced_graph=new_graph, + node_name=node_name + ) + + @register_lowering(aten.empty_strided) + def empty_strided( + size, stride, *, dtype=None, layout=None, device=None, pin_memory=None + ): + if not isinstance(size, (list, tuple)): + raise RuntimeError(f"assert Expected list or tuple") + if not isinstance(stride, (list, tuple)): + raise RuntimeError(f"assert Expected list or tuple or None") + lowering.assert_nyi(not pin_memory, "pin_memory") + lowering.assert_nyi(layout in (None, torch.strided), f"layout={layout}") + dtype = lowering.decode_dtype(dtype) or torch.get_default_dtype() + device = device or torch.tensor(0.0).device + device = decode_device(device) + pointwise = _full(fill_value=0, device=device, dtype=dtype, size=size) + pointwise.realize() + buffer = pointwise.data.data + # explicitly set ranges to zeros in order to make a NopKernelSchedulerNode + buffer.data = lowering.dataclasses.replace(buffer.data, ranges=[0] * len(size)) + if not isinstance(buffer, ir.ComputedBuffer): + raise RuntimeError(f"assert Expected ir.ComputedBuffer") + size = [sympy.expand(s) for s in size] + stride = ( + [sympy.expand(s) for s in stride] + if stride + else ir.FlexibleLayout.contiguous_strides(size) + ) + buffer.layout = ir.FixedLayout( + device=device, + dtype=dtype, + size=size, + stride=stride, + ) + return pointwise + + @register_lowering([torch.empty, aten.empty]) + def empty( + *size, + names=None, + dtype=None, + layout=None, + device=None, + pin_memory=None, + memory_format=None, + ): + lowering.assert_nyi(names is None, "named tensors") + device = decode_device(device) + if len(size) == 1 and isinstance(size[0], (list, tuple, torch.Size)): + size = tuple(size[0]) + return empty_strided( + size, None, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory + ) + + @register_lowering([torch.full, aten.full]) + def full(size, fill_value, **kwargs): + if kwargs.get("dtype") is None: + raise RuntimeError("assert kwargs dtype should be handled by decomposition") + return tensor_constructor(fill_value)(size, **kwargs) + + register_lowering(aten.clone)(clone) + + @register_lowering(aten.constant_pad_nd, type_promotion_kind=None) + def constant_pad_nd(x, padding, fill_value=0): + if (len(padding) % 2) != 0: + raise RuntimeError("assert len(padding) must % 2=0") + + input_graphs = fetch_graphs([x, padding]) + node_name = f'constand_pad_nd_{next(node_id)}' + new_graph = merge_traced_graphs(input_graphs, aten.constant_pad_nd, 
node_name, value=fill_value) + + if all(p == 0 for p in padding): + return clone(x) + + sizes = x.get_size() + + bounds = list(reversed(list(zip(padding[::2], padding[1::2])))) + n = len(sizes) - len(bounds) + + # if padding is a complicated expression, hoist it + bounds_precomp: List[Tuple[sympy.Symbol, Any]] = [] + for low, high in bounds: + bounds_precomp.append((V.graph.sizevars.lookup_precomputed_size(low), high)) # type: ignore[arg-type] + + output_size = list(sizes[:n]) + mask_sizes = [] + for (low, high), size in zip(bounds, sizes[n:]): + mask_sizes.append(size) + output_size.append(sympy.expand(size + low + high)) + if len(output_size) != len(sizes): + raise RuntimeError("assert len(output_size) must equal to len(sizes)") + fill_value = dtype_to_type(x.get_dtype())(fill_value) + + def mask(index): + mask = [] + for idx, (low, high), length in zip(index[n:], bounds, mask_sizes): + if low != 0: + mask.append(lowering.range_mask_low(idx, 0)) + if high != 0: + mask.append(lowering.range_mask_high(idx, length)) + mask = functools.reduce(ops.and_, mask) + return ops.masked(mask, lambda: x_loader(index), fill_value) + + def offset_fn(index): + new_index = list(index[:n]) + for idx, (low, high) in zip(index[n:], bounds_precomp): + new_index.append(idx - low) + if len(new_index) != len(index): + raise RuntimeError("assert len(new_index) must equal len(index)") + return mask(new_index) + + x_loader = x.make_loader() + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=offset_fn, + ranges=output_size, + traced_graph=new_graph, + node_name=node_name + ) + + @make_pointwise + @register_to_aten(aten_fn=aten.pow) + def pow_native(a, b): + return ops.pow(a, b) + + @register_lowering(aten.pow, broadcast=True) + def pow(a, b): + if isinstance(b, float) and b == int(b): + return pow(a, int(b)) + elif isinstance(b, float) and b == 0.5: + return sqrt(a) + elif isinstance(b, int) and b == 1: + return clone(a) + + input_graphs = fetch_graphs([a, b]) + node_name = f'pointwise_{next(node_id)}' + new_graph = merge_traced_graphs(input_graphs, aten.pow, node_name) + + # Type promotion ensures all tensor arguments have the same type + dtype = next(x.get_dtype() for x in (a, b) if isinstance(x, ir.TensorBox)) + is_integer_pow = is_integer_dtype(dtype) + + # Optimize away small fixed powers, or for integers avoid falling back to ATen + embed_exponent = isinstance(b, int) and ( + -32 < b < 32 or (is_integer_pow and b >= 0) + ) + if embed_exponent: + loader = a.make_loader() + + def fn(idx): + return lowering.pow_recursive(loader(idx), b, a.get_dtype()) + + return Pointwise.create( + device=a.get_device(), + dtype=a.get_dtype(), + inner_fn=fn, + ranges=a.get_size(), + node_name=node_name, + traced_graph=new_graph, + ) + + if isinstance(a, Number): + if a == 1: + return full_like(b, 1) + if a == 2 and is_float_dtype(b.get_dtype()): + return exp2(b) + + if is_integer_pow: + # ops.pow doesn't work for integers + if isinstance(a, Number): + return lowering.fallback_pow_scalar(a, b) + elif isinstance(b, Number): + return lowering.fallback_pow_tensor_scalar(a, b) + else: + return lowering.fallback_pow_tensor_tensor(a, b) + + return pow_native(a, b) + + def mutate_to(changed, val, unsafe_alias=False): + if isinstance(changed, TensorBox): + changed_data = changed.data + else: + changed_data = changed + if isinstance(val, TensorBox): + val = val.data + + if not isinstance(val, ir.StorageBox): + # introduce a copy to handle views + input_graphs = fetch_graphs([changed, val]) + node_name = 
f'copy__{next(node_id)}' + new_graph = merge_traced_graphs(input_graphs, aten.copy_, node_name) + val = Pointwise.create( + device=changed.get_device(), + dtype=changed.get_dtype(), + inner_fn=val.make_loader(), + ranges=changed.get_size(), + traced_graph=new_graph, + node_name=node_name + ).data + if not isinstance(val, ir.StorageBox): + raise RuntimeError("assert val should be instance of ir.StorageBox") + + if isinstance(changed_data, ir.StorageBox) and not ( + changed_data.is_input_buffer() + # In AOTI, module parameters and buffers are not lifted as graph inputs + or changed_data.is_module_buffer() + or isinstance(changed_data.data, ir.NopKernel) + ): + # Fast path, just swing the data pointer + val.realize() + changed_data.data = val.data + return changed + + ir.MutationLayoutSHOULDREMOVE.realize_into( + val, changed_data, unsafe_alias=unsafe_alias + ) + return changed + + empty_like = register_lowering(aten.empty_like)(lowering.create_tensor_like(empty)) + ones_like = lowering.create_tensor_like(tensor_constructor(1)) + zeros_like = lowering.create_tensor_like(tensor_constructor(0)) + + @register_lowering(aten.full_like, type_promotion_kind=None) + def full_like(x, fill_value, **kwargs): + return lowering.create_tensor_like(tensor_constructor(fill_value))(x, **kwargs) + + @register_lowering(aten.fill_) + def fill_(x, fill_value): + return mutate_to(x, full_like(x, fill_value)) + + @register_lowering(aten.copy_, type_promotion_kind=None) + def copy_(dst, src, non_blocking=False): + if dst is src: + # dst.copy_(dst) can happen from the reinplacing pass + return dst + src = lowering.to_device(src, dst.get_device()) + src = to_dtype(src, dst.get_dtype()) + src = expand(src, dst.get_size()) + return mutate_to(dst, src) + + @make_pointwise + def floordiv(a, b): + return ops.floordiv(a, b) + + @make_pointwise + def truncdiv(a, b): + return ops.truncdiv(a, b) + + @register_lowering(aten.div, broadcast=True) + def div_mode(a, b, rounding_mode=None): + both_integer = lowering.is_integer_type(a) and lowering.is_integer_type(b) + both_boolean = lowering.is_boolean_type(a) and lowering.is_boolean_type(b) + + # floordiv and truncdiv need special handling for integer tensors on Triton, + # see the discussion at openai triton issues 605 + if rounding_mode == "floor": + if both_boolean: + raise RuntimeError("assert floordiv operands cannot be boolean at the same time") + return floordiv(a, b) if both_integer else floor(div(a, b)) + if rounding_mode == "trunc": + if both_boolean: + raise RuntimeError("assert truncdiv operands can not be boolean at the same time") + return truncdiv(a, b) if both_integer else trunc(div(a, b)) + return div(a, b) + + @register_lowering([aten.mul], broadcast=True) + def mul(a, b): + both_bool = lowering.is_boolean_type(a) and lowering.is_boolean_type(b) + if both_bool: + return logical_and(a, b) + else: + fn = ops_wrapper(aten.mul.__name__) + fn = register_fn_to_aten_fn(fn, aten.mul) + return make_pointwise(fn)(a, b) + + @register_lowering([aten.reciprocal], broadcast=True, ) + def reciprocal(a): + return div(1.0, a) + + @register_lowering([prims.div], broadcast=True) + def div_prim(a, b): + is_integral = all(lowering.is_boolean_type(x) or lowering.is_integer_type(x) for x in [a, b]) + + if is_integral: + return truncdiv(a, b) + + if (divisor := lowering.get_constant_value(b)) is not None: + # Replace divide by constant with multiply by reciprocal + if divisor.value == 0: + reciprocal = math.copysign(float("inf"), divisor.value) + else: + reciprocal = 1.0 / divisor.value 
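+            # A zero constant divisor was mapped to a signed infinity above, so the
+            # multiply-by-reciprocal path still matches true division for x / 0.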
+ return mul(a, reciprocal) + + def fn(*args): + return ops.truediv(*args) + + fn = register_fn_to_aten_fn(fn, aten.div) + return make_pointwise(fn)(a, b) + + @register_lowering( + [aten.true_divide, aten.div.Tensor], + broadcast=True, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + ) + def div(a, b): + a, b = lowering.promote_constants( + (a, b), type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ) + return div_prim(a, b) + + @register_lowering(aten.rsqrt) + def rsqrt(x): + dtype = x.get_dtype() + if is_integer_dtype(dtype) or is_boolean_dtype(dtype): + x = to_dtype(x, torch.get_default_dtype()) + + def _rsqrt(x): + return ops.rsqrt(x) + + register_fn_to_aten_fn(_rsqrt, aten.rsqrt) + return make_pointwise(_rsqrt)(x) + + @register_lowering(aten.prod) + def prod(x, axis=None, keepdims=False, *, dtype=None): + if ( + is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype()) + ) and dtype is None: + dtype = torch.int64 + + fn = make_reduction("prod", override_return_dtype=dtype) + return fn(x, axis, keepdims, dtype=dtype) + + @register_lowering(aten.any) + def reduce_any(x, dim=None, keepdim=False): + x = to_dtype(x, torch.bool) + return make_reduction("any")(x, axis=dim, keepdims=keepdim) + + @register_lowering(aten.max, type_promotion_kind=None) + def reduce_max(x, dim=None, keepdim=False): + if dim is not None: + return ( + reduce_amax(x, axis=dim, keepdims=keepdim), + reduce_argmax(x, axis=dim, keepdims=keepdim), + ) + + return reduce_amax(x, axis=None, keepdims=keepdim) + + @register_lowering(aten.min, type_promotion_kind=None) + def reduce_min(x, dim=None, keepdim=False): + if dim is not None: + return ( + reduce_amin(x, axis=dim, keepdims=keepdim), + reduce_argmin(x, axis=dim, keepdims=keepdim), + ) + + return reduce_amin(x, axis=None, keepdims=keepdim) + + register_lowering(prims.xor_sum)(make_reduction("xor_sum")) + reduce_amax = register_lowering(aten.amax)(make_reduction("max")) + reduce_amin = register_lowering(aten.amin)(make_reduction("min")) + reduce_argmax = register_lowering(aten.argmax)( + make_reduction("argmax", override_return_dtype=torch.int64) + ) + reduce_argmin = register_lowering(aten.argmin)( + make_reduction("argmin", override_return_dtype=torch.int64) + ) + + add = register_pointwise( + aten.add, allow_alpha=True, override_fn_when_input_bool="logical_or" + ) + + def register_pointwise_numeric(op, name=None, triton_fallback=None): + return register_pointwise( + op, + name=name, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + triton_fallback=triton_fallback, + ) + + def register_pointwise_numeric_ldf64(op): + return register_pointwise( + op, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + use_libdevice_for_f64=True, + ) + + def register_inplace(aten_op, outplace_op): + @register_lowering(aten_op, type_promotion_kind=None) + def fn(*args, **kwargs): + result = outplace_op(*args, **kwargs) + result = to_dtype(result, args[0].get_dtype()) + return mutate_to(args[0], result) + + return fn + + rsqrt = register_pointwise_numeric(aten.rsqrt) + exp = register_pointwise_numeric_ldf64(aten.exp) + exp2 = register_pointwise_numeric(aten.exp2) + expm1 = register_pointwise_numeric(aten.expm1) + relu = register_pointwise(aten.relu) + sigmoid = register_pointwise_numeric_ldf64(aten.sigmoid) + sqrt = register_pointwise_numeric_ldf64(aten.sqrt) + square = register_pointwise(aten.square) + sub = register_pointwise(aten.sub, allow_alpha=True) + register_pointwise_numeric_ldf64(aten.cos) + 
register_pointwise_numeric_ldf64(aten.sin) + abs_val = register_pointwise(aten.abs) + bitwise_and = register_pointwise(aten.bitwise_and) + bitwise_left_shift = register_pointwise(aten.bitwise_left_shift) + bitwise_not = register_pointwise( + aten.bitwise_not, override_fn_when_input_bool="logical_not" + ) + bitwise_or = register_pointwise(aten.bitwise_or) + bitwise_right_shift = register_pointwise(aten.bitwise_right_shift) + bitwise_xor = register_pointwise(aten.bitwise_xor) + register_pointwise_numeric(aten.lgamma) + erf = register_pointwise_numeric(aten.erf) + register_lowering( + aten.special_erf, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + )(erf) + + register_pointwise_numeric(aten.log1p) + register_pointwise_numeric(aten.tan) + register_pointwise_numeric(aten.tanh) + register_pointwise_numeric_ldf64(aten.log) + logical_and = register_pointwise( + aten.logical_and, + type_promotion_kind=None, + convert_input_to_bool=True, + override_return_dtype=torch.bool, + ) + logical_not = register_pointwise( + aten.logical_not, + type_promotion_kind=None, + convert_input_to_bool=True, + override_return_dtype=torch.bool, + ) + logical_or = register_pointwise( + aten.logical_or, + type_promotion_kind=None, + convert_input_to_bool=True, + override_return_dtype=torch.bool, + ) + logical_xor = register_pointwise( + aten.logical_xor, + type_promotion_kind=None, + convert_input_to_bool=True, + override_return_dtype=torch.bool, + ) + maximum = register_pointwise(aten.maximum) + minimum = register_pointwise(aten.minimum) + clamp_min = register_pointwise(aten.clamp_min, name='maximum') + clamp_max = register_pointwise(aten.clamp_max, name='minimum') + neg = register_pointwise(aten.neg) + abs_val1 = register_pointwise(aten.abs) + register_pointwise(aten.remainder) + sign = register_pointwise(aten.sign, override_fn_when_input_bool="identity") + register_pointwise(aten.ceil) + register_pointwise(aten.signbit, override_return_dtype=torch.bool) + + register_lowering(aten._neg_view)(neg) + + register_pointwise(aten.le, override_return_dtype=torch.bool) + register_pointwise(aten.lt, override_return_dtype=torch.bool) + register_pointwise(aten.ge, override_return_dtype=torch.bool) + gt = register_pointwise(aten.gt, override_return_dtype=torch.bool) + register_pointwise(aten.eq, override_return_dtype=torch.bool) + register_pointwise(aten.ne, override_return_dtype=torch.bool) + + register_pointwise_numeric(aten.cosh) + register_pointwise_numeric(aten.sinh) + register_pointwise_numeric(aten.acos) + register_pointwise_numeric(aten.acosh) + register_pointwise_numeric(aten.asin) + register_pointwise_numeric(aten.asinh) + register_pointwise_numeric(aten.atan2) + register_pointwise_numeric(aten.atan) + register_pointwise_numeric(aten.atanh) + register_pointwise_numeric(aten.copysign) + register_pointwise_numeric(aten.erfc) + register_pointwise_numeric(aten.erfinv) + register_pointwise_numeric(aten.hypot) + register_pointwise_numeric(aten.log10) + register_pointwise_numeric(aten.log2) + register_pointwise_numeric(aten.nextafter) + + register_inplace(aten.add_, add) + register_inplace(aten.bitwise_and_, bitwise_and) + register_inplace(aten.bitwise_left_shift_, bitwise_left_shift) + register_inplace(aten.bitwise_not_, bitwise_not) + register_inplace(aten.bitwise_or_, bitwise_or) + register_inplace(aten.bitwise_right_shift_, bitwise_right_shift) + register_inplace(aten.bitwise_xor_, bitwise_xor) + register_inplace(aten.mul_, mul) + register_inplace(aten.div_.Tensor, div) + 
register_inplace(aten.div_.Tensor_mode, div_mode) + register_inplace(aten.logical_and_, logical_and) + register_inplace(aten.logical_not_, logical_not) + register_inplace(aten.logical_or_, logical_or) + register_inplace(aten.logical_xor_, logical_xor) + register_inplace(aten.sub_, sub) + register_inplace(aten.relu_, relu) + register_inplace(aten.sigmoid_, sigmoid) + + register_lowering(aten.__and__)(bitwise_and) + register_lowering(aten.__lshift__)(bitwise_left_shift) + register_lowering(aten.__or__)(bitwise_or) + register_lowering(aten.__rshift__)(bitwise_right_shift) + register_lowering(aten.__xor__)(bitwise_xor) + + register_inplace(aten.__iand__, aten.__and__) + register_inplace(aten.__ilshift__, aten.__lshift__) + register_inplace(aten.__ior__, aten.__or__) + register_inplace(aten.__irshift__, aten.__rshift__) + register_inplace(aten.__ixor__, aten.__xor__) + + ########################################################################## + + @register_lowering([aten.sum, prims.sum]) + def sum_(x, axis=None, keepdims=False, *, dtype=None): + if ( + is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype()) + ) and dtype is None: + dtype = torch.int64 + + fn = make_reduction("sum", override_return_dtype=dtype) + return fn(x, axis, keepdims, dtype=dtype) + + + @register_lowering(aten.mean) + def mean(x, axis=None, keepdim=False, *, dtype=None): + if dtype is not None: + x = to_dtype(x, dtype) + size = x.get_size() + axis = lowering._validate_reduction_axis(x, axis) + # compute in higher-precision until end of mean lowering + output_dtype = x.get_dtype() + if output_dtype in (torch.float16, torch.bfloat16): + x = to_dtype(x, torch.float) + sum_result = sum_(x, axis, keepdim) + denom = sympy_product(size[i] for i in axis) + denom = ir.IndexingConstant(index=denom, dtype=x.get_dtype(), device=x.get_device()) + denom = ExpandView.create(denom, list(sum_result.get_size())) + return to_dtype(div(sum_result, denom), output_dtype) + + @register_lowering(aten.cumsum) + def cumsum(x, axis=None, dtype=None): + if ( + is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype()) + ) and dtype is None: + # torch.int64->torch.int32 + dtype = torch.int32 + if len(x.get_size()) == 0: + if axis not in [0, -1]: + raise ValueError("axis must be 0 or -1") + dtype = dtype or x.get_dtype() + return to_dtype(x, dtype, copy=True) + return lowering.fallback_cumsum(x, dim=axis, dtype=dtype) + + @register_lowering(npu.npu_dtype_cast, type_promotion_kind=None) + def _convert_npu_type(x: TensorBox, dtype: torch.dtype): + return to_dtype(x, dtype, copy=True) + + @register_lowering(npu._npu_dtype_cast, type_promotion_kind=None) + def _convert__npu_type(x: TensorBox, dtype: torch.dtype): + return to_dtype(x, dtype, copy=True) + + def var_mean_sum_(x, axis, correction, keepdim, return_mean): + if correction is None: + correction = 1 + + size = x.get_size() + axis = lowering._validate_reduction_axis(x, axis) + x_mean = mean(x, axis, keepdim=True) + if return_mean: + x_mean.realize() + + diffs = square(sub(x, x_mean)) + sum_result = sum_(diffs, axis, keepdim) + denom = sympy_product(size[i] for i in axis) + if correction: + denom = sympy.Max(denom - correction, 0) + denom = ir.IndexingConstant(index=denom, dtype=x.get_dtype(), device=x.get_device()) + denom = ExpandView.create(denom, list(sum_result.get_size())) + x_var = div(sum_result, denom) + if not return_mean: + return (x_var,) + + x_mean = x_mean if keepdim else squeeze(x_mean, axis) + return x_var, x_mean + + def var_mean_helper_(x, *, axis, correction, 
keepdim, return_mean): + out_dtype = x.get_dtype() + compute_dtype = get_computation_dtype(out_dtype) + x = to_dtype(x, compute_dtype, copy=False) + kwargs = dict( + x=x, + axis=axis, + correction=correction, + keepdim=keepdim, + return_mean=return_mean, + ) + output = ( + var_mean_sum_(**kwargs) + ) + output = tuple(to_dtype(x, out_dtype, copy=False) for x in output) + return output[0] if not return_mean else output + + @register_lowering(aten.var_mean) + def var_mean(x, axis=None, *, correction=None, keepdim=False): + return var_mean_helper_( + x, axis=axis, correction=correction, keepdim=keepdim, return_mean=True + ) + + @register_lowering([aten.var, prims.var]) + def var_(x, axis=None, *, correction=None, keepdim=False): + return var_mean_helper_( + x, axis=axis, correction=correction, keepdim=keepdim, return_mean=False + ) + + @register_lowering(aten.embedding, type_promotion_kind=None) + def embedding(weight, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False): + return lowering.fallback_handler(aten.embedding.default)(weight, indices, padding_idx=-1, + scale_grad_by_freq=False, + sparse=False) + + @register_lowering(aten.cat) + def cat(inputs, dim=0): + return lowering.fallback_handler(aten.cat.default)(inputs, dim) + + lowering.make_fallback(aten._log_softmax) + lowering.make_fallback(aten.gather) + lowering.make_fallback(aten.nll_loss_forward) diff --git a/torch_npu/_inductor/lowering_op_list.py b/torch_npu/_inductor/lowering_op_list.py new file mode 100644 index 0000000000000000000000000000000000000000..db9c427e60d69e95e39e9a7b83396198831d6070 --- /dev/null +++ b/torch_npu/_inductor/lowering_op_list.py @@ -0,0 +1,107 @@ +import torch +from torch_npu import npu_dtype_cast, _npu_dtype_cast + +aten = torch.ops.aten +tr_c10d = torch.ops.tr_c10d +prims = torch.ops.prims + +GENERATE_LIST = [ + prims.iota, + aten.full, + aten.mul, + aten.add, + aten.sub, + aten.div, + aten.exp, + aten.maximum, + aten.sum, + aten.select, + aten.unsqueeze, + aten.repeat, + aten.clone, + aten.reshape, + aten.where, + aten.lt, + aten.minimum, + aten.gt, + aten.le, + aten.ceil, + aten.floor, + aten.rsqrt, + aten.abs, + aten.log, + aten.bitwise_xor, + aten.amax, + # backward + prims.convert_element_type, + aten.min, + aten.max, + aten.erf, + aten.argmax, + aten.argmin, + aten.clamp_min, + aten.slice, + aten.neg, + aten.cat, + aten.arange, + aten.expand, + aten.eq, + aten.where, + aten.scalar_tensor, + aten.ge, + aten.permute, + aten.sqrt, + aten.relu, + aten.clamp, + aten.clamp_max, + aten.mean, + npu_dtype_cast, + _npu_dtype_cast, + aten.select_scatter, + aten.slice_scatter, + prims.broadcast_in_dim, + prims.maximum, + aten.ne, + aten.sigmoid, + aten.sign, + aten.logical_and, + aten.logical_or, + aten.logical_not, + aten.pow, + aten.gelu, + aten.tanh, + aten.isnan, + aten.bitwise_and, + aten.squeeze, + aten.copy, + aten.reciprocal +] + +GENERATE_LIST2 = [ + "foreach" +] + +FALLBACK_LIST = [] + +# Delete these op in lowering list and then update lowering list with new lowering, +# otherwise, it will not use npu overload lowering. 
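+#
+# As an illustration only (hypothetical sketch, not part of this module), the
+# replacement step driven by this list amounts to something like:
+#
+#     from torch._inductor.lowering import lowerings
+#     for op in LOWERING_OVERLOAD_OP:
+#         for name in op.overloads():
+#             lowerings.pop(getattr(op, name), None)   # drop the stock lowering
+#     _register_npu_inductor_fallbacks()               # then register the NPU versions
+#
+# For example, aten.mean appears both here and in GENERATE_LIST: its stock lowering
+# has to be dropped before the NPU-specific one (see lowering.py) can take effect.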
+LOWERING_OVERLOAD_OP = [ + aten.cumsum, + aten.mean, + aten.max, + aten.min, + aten.amin, + aten.amax, + aten.argmax, + aten.argmin, + + aten.var_mean, + aten.var, + + aten.embedding, + aten.split, + aten.split_with_sizes, + aten.nll_loss_forward, + aten.gather, + aten.cat, +] diff --git a/torch_npu/_inductor/npu_choices.py b/torch_npu/_inductor/npu_choices.py new file mode 100644 index 0000000000000000000000000000000000000000..438399e4b6764e581dfb10ac872e8de2a01ca0b4 --- /dev/null +++ b/torch_npu/_inductor/npu_choices.py @@ -0,0 +1,33 @@ +import typing +from typing import Any, Dict, List, Type, TYPE_CHECKING +import sympy +from torch._inductor import config +from torch._inductor.codegen.simd_kernel_features import SIMDKernelFeatures +from torch._inductor.codegen.triton import TritonKernel +from torch._inductor.runtime.hints import ReductionHint +from torch._inductor.virtualized import V + + +@staticmethod +def should_use_persistent_reduction( + features: SIMDKernelFeatures, cooperative_reduction: bool +) -> bool: + """ + Heuristic to decide if a persistent reduction should be used. + """ + if not config.triton.persistent_reductions: + return False + threshold = { + ReductionHint.INNER: 1024, + ReductionHint.DEFAULT: 1024 + }.get(features.get_reduction_hint(), 64) + if cooperative_reduction: + # The RSPLIT of cooperative reductions means each thread block is operating on fewer elements + try: + threshold *= 32 // min(V.graph.sizevars.size_hint(features.numel), 32) + except ValueError: + pass # unbacked symint + + if config.triton.multi_kernel: + threshold *= 16 + return V.graph.sizevars.statically_known_leq(features.reduction_numel, threshold) # type: ignore[arg-types] diff --git a/torch_npu/_inductor/npu_device.py b/torch_npu/_inductor/npu_device.py new file mode 100644 index 0000000000000000000000000000000000000000..ef5bf7b4d58ee5765a3a1c480ce8857887f7cf3e --- /dev/null +++ b/torch_npu/_inductor/npu_device.py @@ -0,0 +1,208 @@ +import torch +from torch_npu.npu import device_count +from torch_npu.utils._dynamo_device import NpuInterface, current_device, set_device +from torch_npu.utils._inductor import NPUDeviceOpOverrides +from . 
import config as npu_config + + +## Override original inductor device overrides in torch_npu +class NewNPUDeviceOpOverrides(NPUDeviceOpOverrides): + def import_get_raw_stream_as(self, name): + return f"from torch_npu._inductor import get_current_raw_stream as {name}" + + def set_device(self, device_idx): + return f"torch.npu.set_device({device_idx})" + + def synchronize(self): + return """ + stream = torch.npu.current_stream() + stream.synchronize() + """ + + def device_guard(self, device_idx): + return f"torch.npu.utils.device({device_idx})" + + def cpp_aoti_device_guard(self): + raise NotImplementedError + + def cpp_aoti_stream_guard(self): + return "AOTICudaStreamGuard" + + def kernel_driver(self): + source_code = """ + namespace { + + struct Grid { + Grid(uint32_t x, uint32_t y, uint32_t z) + : grid_x(x), grid_y(y), grid_z(z) {} + uint32_t grid_x; + uint32_t grid_y; + uint32_t grid_z; + + bool is_non_zero() { + return grid_x > 0 && grid_y > 0 && grid_z > 0; + } + }; + + } // anonymous namespace + + extern "C" { + typedef int (* callback)(unsigned int type, void* data, unsigned int len); + extern int MsprofReportApi(unsigned int agingFlag, const MsprofApi *api); + extern unsigned long int MsprofSysCycleTime(); + extern int MsprofRegisterCallback(unsigned int moduleId, callback handle); + static unsigned int __MsprofFlagL0 = 0; + static unsigned int __MsprofFlagL1 = 0; + + int ProfCtrlHandle(unsigned int CtrlType, void* CtrlData, unsigned int DataLen) { + if ((CtrlData == nullptr) || (DataLen == 0U)) { + return 1; + } + + if (CtrlType == 1) { + MsprofCommandHandle* handle = (MsprofCommandHandle *)(CtrlData); + if (handle->type >= 6) // 6 is not used here + return 1; + if (handle->type == 1) { // init - 0 , start - 1 + __MsprofFlagL0 = ((0x00000800ULL & handle->profSwitch) == 0x00000800ULL) ? 1 : 0; + __MsprofFlagL1 = ((0x00000002ULL & handle->profSwitch) == 0x00000002ULL) ? 
1 : 0; + } + } + return 0; + } + } + """ + + load_code = """ + static std::unordered_map registered_names; + static std::unordered_map> func_stubs; + + static inline void * loadKernel( + std::string filePath, + const std::string &&nameFuncMode, + uint32_t sharedMemBytes, + const std::optional &cubinDir = std::nullopt) { + if (cubinDir) { + std::filesystem::path p1{*cubinDir}; + std::filesystem::path p2{filePath}; + filePath = (p1 / p2.filename()).string(); + } + std::string funcName; + std::string kernel_mode_str; + size_t spacePos = nameFuncMode.find(' '); + if (spacePos != std::string::npos) { + kernel_mode_str = nameFuncMode.substr(spacePos + 1); + funcName = nameFuncMode.substr(0, spacePos); + } else { + throw std::runtime_error(std::string("Parse kernel name failed, expect " + "'kernel_name kernel_mode', bug got: ") + nameFuncMode); + } + + std::ifstream file(std::string(filePath), std::ios::binary | std::ios::ate); + if (!file.is_open()) { + throw std::runtime_error(std::string("open npubin failed")); + } + + std::streamsize data_size = file.tellg(); + + file.seekg(0, std::ios::beg); + char* buffer = new char[data_size]; + if (!file.read(buffer, data_size)) { + throw std::runtime_error(std::string("read npubin failed")); + } + + rtError_t rtRet; + + rtDevBinary_t devbin; + devbin.data = buffer; + devbin.length = data_size; + const std::string kernel_mode{kernel_mode_str}; + if (kernel_mode == "aiv") { + devbin.magic = RT_DEV_BINARY_MAGIC_ELF_AIVEC; + } else { + devbin.magic = RT_DEV_BINARY_MAGIC_ELF; + } + devbin.version = 0; + + int device = 0; + rtRet = rtSetDevice(device); + if (rtRet != RT_ERROR_NONE) { + throw std::runtime_error(std::string("rtSetDevice failed, 0x") + std::to_string(rtRet)); + } + + void *devbinHandle = NULL; + rtRet = rtDevBinaryRegister(&devbin, &devbinHandle); + if (rtRet != RT_ERROR_NONE) { + throw std::runtime_error(std::string("rtDevBinaryRegister failed, 0x") + std::to_string(rtRet)); + } + + const char* name = funcName.c_str(); + + std::string stubName(name); + stubName += "_" + std::to_string(registered_names[name]); + registered_names[name]++; + auto registered = func_stubs.emplace(stubName, std::make_unique(0)); + void *func_stub_handle = registered.first->second.get(); + rtRet = rtFunctionRegister(devbinHandle, func_stub_handle, stubName.c_str(), + (void *)name, 0); + if (rtRet != RT_ERROR_NONE) { + throw std::runtime_error(std::string("rtFunctionRegister failed, stubName = ") + stubName + + std::string(" , 0x") + std::to_string(rtRet)); + } + + return func_stub_handle; + } + """ + + # Could not use OpCommand when debug_kernel, because we want to + # use torch::save, which will cause dead lock in child thread. 
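+        # The two launchKernel() variants below differ only in how the launch is issued:
+        # the debug_kernel build invokes the captured launch_call std::function directly
+        # on the calling thread (so torch::save can be used safely), while the default
+        # build hands it to at_npu::native::OpCommand::SetCustomHandler so the launch is
+        # enqueued through the regular NPU op dispatch path.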
+ launch_code = """ + static inline void launchKernel( + std::function launch_call, + std::string&& kernel_name) { + launch_call(); + } + """ if npu_config.aot_inductor.debug_kernel else """ + static inline void launchKernel( + std::function launch_call, + std::string&& kernel_name) { + at_npu::native::OpCommand cmd; + cmd.Name(kernel_name.c_str()) + .SetCustomHandler(launch_call) + .Run(); + } + """ + extra_code = "" + source_codes = source_code + load_code + launch_code + extra_code + return source_codes + + def abi_compatible_header(self): + return """ + #include + #include + #include + #include + #include + #include + #include + #include + + #include + #include + #include + #include + #include + #include "experiment/runtime/runtime/rt.h" + """ + + def cpp_stream_type(self): + return "aclrtStream" + + def aoti_get_stream(self): + return "aoti_torch_get_current_cuda_stream" + + def cpp_kernel_type(self): + return "void *" + + def cpp_device_ptr(self): + return "void*" diff --git a/torch_npu/_inductor/npu_fusion_attention_graph.py b/torch_npu/_inductor/npu_fusion_attention_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..4242ba5b8edd0ffc504f027bccfd2165233547bd --- /dev/null +++ b/torch_npu/_inductor/npu_fusion_attention_graph.py @@ -0,0 +1,253 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. +import functools +import sympy +import torch +import torch.nn.functional as F +from torch.autograd import Function +from torch.library import Library, impl +import torch_npu + +npu_def = Library("npu_graph", "DEF") +npu_lib = Library("npu_graph", "IMPL", "PrivateUse1") +meta_lib = Library("npu_graph", "IMPL", "Meta") + +npu_def.define( + "npu_fa(Tensor query, Tensor key, Tensor value, int head_num, str input_layout, Tensor? pse=None, Tensor? padding_mask=None, Tensor? atten_mask=None, float scale=1., float keep_prob=1., int pre_tockens=2147483647, int next_tockens=2147483647, int inner_precise=0, int[]? prefix=None, int[]? actual_seq_qlen=None, int[]? actual_seq_kvlen=None, int sparse_mode=0, bool gen_mask_parallel=True, bool sync=False) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)") +npu_def.define( + "npu_fa_backward(Tensor query, Tensor key, Tensor value, Tensor dy, int head_num, str input_layout, *, Tensor? pse=None, Tensor? padding_mask=None, Tensor? atten_mask=None, Tensor? softmax_max=None, Tensor? softmax_sum=None, Tensor? softmax_in=None, Tensor? attention_in=None, float scale_value=1., float keep_prob=1., int pre_tockens=2147483647, int next_tockens=2147483647, int inner_precise=0, Tensor? seed=None, Tensor? offset=None, Tensor? numels=None, int[]? prefix=None, int[]? actual_seq_qlen=None, int[]? 
actual_seq_kvlen=None, int sparse_mode=0, bool gen_mask_parallel=True, bool sync=False) -> (Tensor, Tensor, Tensor, Tensor)")
+
+
+@impl(npu_lib, "npu_fa")
+def npu_fa(*args, **kwargs):
+    if len(args) > 8:
+        args = list(args)
+        # for scale
+        try:
+            args[8] = 1.0 / args[8]
+        except ZeroDivisionError:  # guard against a zero scale
+            args[8] = 1.0 / (args[8] + 1e-6)
+    r1, r2, r3, r4, seed, offset, numel = torch_npu.npu_fusion_attention(*args, **kwargs)
+    r2.requires_grad = False
+    r3.requires_grad = False
+    r4.requires_grad = False
+    return r1, r2, r3, r4, torch.tensor([seed], requires_grad=False), torch.tensor([offset],
+                                                                                   requires_grad=False), torch.tensor(
+        [numel], requires_grad=False)
+
+
+@impl(npu_lib, "npu_fa_backward")
+def npu_fa_backward(*args, **kwargs):
+    if 'scale_value' in kwargs:
+        kwargs['scale_value'] = 1.0 / kwargs['scale_value']
+    return torch_npu.npu_fusion_attention_grad(*args, **kwargs)
+
+
+@impl(meta_lib, "npu_fa")
+def npu_fa(query, key, value, head_num, input_layout, pse=None, padding_mask=None,
+           atten_mask=None, scale=1.0, keep_prob=1.0, pre_tockens=2147483647, next_tockens=2147483647,
+           inner_precise=0, prefix=None, actual_seq_qlen=None, actual_seq_kvlen=None, sparse_mode=0,
+           gen_mask_parallel=True, sync=False):
+    B = query.size(0)
+    N = head_num
+    S1 = query.size(2)
+    S2 = key.size(2)
+
+    if input_layout == "BSH":
+        B = query.size(0)
+        S1 = query.size(1)
+        S2 = key.size(1)
+
+    if input_layout == "SBH":
+        B = query.size(1)
+        S1 = query.size(0)
+        S2 = key.size(0)
+
+    attention_score = torch.empty_like(query, dtype=query.dtype, device='meta').contiguous()
+    softmax_max = torch.empty([B, head_num, S1, 8], dtype=torch.float32, device='meta')
+    softmax_sum = torch.empty([B, head_num, S1, 8], dtype=torch.float32, device='meta')
+    softmax_out = torch.empty([0], dtype=query.dtype, device='meta')
+    return (torch.empty_like(attention_score),
+            torch.empty_like(softmax_max),
+            torch.empty_like(softmax_sum),
+            torch.empty_like(softmax_out),
+            torch.tensor([0], device='meta', requires_grad=False),
+            torch.tensor([0], device='meta', requires_grad=False),
+            torch.tensor([0], device='meta', requires_grad=False))
+
+
+@impl(meta_lib, "npu_fa_backward")
+def npu_fa_backward(query, key, value, dy, head_num, input_layout, *, pse=None, padding_mask=None, atten_mask=None,
+                    softmax_max=None, softmax_sum=None, softmax_in=None, attention_in=None, scale_value=1.0,
+                    keep_prob=1.0, pre_tockens=2147483647, next_tockens=2147483647, inner_precise=0, seed=0, offset=0,
+                    numels=0, prefix=None, actual_seq_qlen=None, actual_seq_kvlen=None, sparse_mode=0,
+                    gen_mask_parallel=True, sync=False):
+    dq = torch.empty_like(query, dtype=query.dtype, device='meta').contiguous()
+    dk = torch.empty_like(key, dtype=query.dtype, device='meta').contiguous()
+    dv = torch.empty_like(value, dtype=query.dtype, device='meta').contiguous()
+    dpse = torch.empty([0], dtype=query.dtype, device='meta').contiguous()
+    return (torch.empty_like(dq), torch.empty_like(dk), torch.empty_like(dv), torch.empty_like(dpse) if pse else None)
+
+
+class NpuGraphAttentionFunction(Function):
+    @staticmethod
+    def forward(ctx, query, key, value, head_num, input_layout, pse=None, padding_mask=None, atten_mask=None, scale=1.0,
+                keep_prob=1.0, pre_tockens=2147483647, next_tockens=2147483647, inner_precise=0, prefix=None,
+                actual_seq_qlen=None, actual_seq_kvlen=None, sparse_mode=0, gen_mask_parallel=True, sync=False):
+        # Forward pass logic.
+        # Assumes a function implementing the forward pass is available as `npu_fusion_attention_forward`.
+        result0, result1, result2, result3, result4, result5, result6 = torch.ops.npu_graph.npu_fa(
+            query, key, value, head_num, input_layout, pse=pse, padding_mask=padding_mask, atten_mask=atten_mask,
+            scale=scale, keep_prob=keep_prob, pre_tockens=pre_tockens, next_tockens=next_tockens,
+            inner_precise=inner_precise, prefix=prefix, actual_seq_qlen=actual_seq_qlen,
+            actual_seq_kvlen=actual_seq_kvlen, sparse_mode=sparse_mode, gen_mask_parallel=gen_mask_parallel, sync=sync
+        )
+        # Save intermediate results for use in the backward pass.
+        ctx.save_for_backward(query, key, value, pse, padding_mask, atten_mask, result1, result2, result3, result0,
+                              result4, result5, result6)
+        ctx.head_num = head_num
+        ctx.input_layout = input_layout
+        ctx.scale = scale
+        ctx.keep_prob = keep_prob
+        ctx.pre_tockens = pre_tockens
+        ctx.next_tockens = next_tockens
+        ctx.inner_precise = inner_precise
+        ctx.prefix = prefix
+        ctx.actual_seq_qlen = actual_seq_qlen
+        ctx.actual_seq_kvlen = actual_seq_kvlen
+        ctx.sparse_mode = sparse_mode
+        ctx.gen_mask_parallel = gen_mask_parallel
+        ctx.sync = sync
+
+        return result0, result1, result2, result3, result4, result5, result6
+
+    @staticmethod
+    def backward(ctx, grad_result0, grad_result1, grad_result2, grad_result3, grad_result4, grad_result5, grad_result6):
+        # Retrieve the saved intermediate results.
+        query, key, value, pse, padding_mask, atten_mask, result1, result2, result3, result0, result4, result5, result6 = ctx.saved_tensors
+        # Backward pass logic.
+        # Assumes a function implementing the backward pass is available as `npu_fusion_attention_backward`.
+        grad_query, grad_key, grad_value, grad_pse = torch.ops.npu_graph.npu_fa_backward(
+            query, key, value, grad_result0, ctx.head_num, ctx.input_layout, pse=pse, padding_mask=padding_mask,
+            atten_mask=atten_mask, softmax_max=result1, softmax_sum=result2, softmax_in=result3, attention_in=result0,
+            scale_value=ctx.scale, keep_prob=ctx.keep_prob, pre_tockens=ctx.pre_tockens, next_tockens=ctx.next_tockens,
+            inner_precise=ctx.inner_precise, seed=result4, offset=result5, numels=result6, prefix=ctx.prefix,
+            actual_seq_qlen=ctx.actual_seq_qlen, actual_seq_kvlen=ctx.actual_seq_kvlen, sparse_mode=ctx.sparse_mode,
+            gen_mask_parallel=ctx.gen_mask_parallel, sync=ctx.sync
+        )
+        return (
+            grad_query, grad_key, grad_value, None, None, grad_pse, None, None, None, None, None, None, None, None, None,
+            None, None, None, None, None, None, None, None, None, None, None)
+
+
+def npu_fusion_attention_graph(query, key, value, head_num, input_layout, pse=None, padding_mask=None,
+                               atten_mask=None, scale=1.0, keep_prob=1.0, pre_tockens=2147483647,
+                               next_tockens=2147483647,
+                               inner_precise=0, prefix=None, actual_seq_qlen=None, actual_seq_kvlen=None, sparse_mode=0,
+                               gen_mask_parallel=True, sync=False):
+    return NpuGraphAttentionFunction.apply(query, key, value, head_num, input_layout, pse, padding_mask,
+                                           atten_mask, scale, keep_prob, pre_tockens, next_tockens,
+                                           inner_precise, prefix, actual_seq_qlen, actual_seq_kvlen, sparse_mode,
+                                           gen_mask_parallel, sync)
+
+
+torch_npu.npu_fusion_attention_graph = npu_fusion_attention_graph
+
+
+def register_fa_pass():
+    TOKEN_MAX = 2147483647
+    from torch._inductor.pattern_matcher import register_replacement, fwd_only, joint_fwd_bwd
+    from torch._inductor.fx_passes.joint_graph import patterns
+    from torch._dynamo.utils import counters
+    from torch._inductor.fx_passes.fuse_attention import partialize_and_update_signature
+
+    def _npu_fusion_attention_graph_pattern_1(query, key, value, inv_scale_factor, dropout_p):
+        q = query.permute(0, 2, 1, 3)
+        k = key.permute(0, 2, 1, 3)
+        v = value.permute(0, 2, 1, 3)
+        return torch.nn.functional.dropout(
+            torch.matmul(q, k.transpose(-2, -1)).div(inv_scale_factor).softmax(dim=-1), +
p=dropout_p, + ).matmul(v) + + def _npu_fusion_attention_graph_replacement_1(query, key, value, inv_scale_factor, dropout_p): + counters["inductor"]["fuse_attention"] += 1 + head_num = query.size(2) + input_layout = "BNSD" + return torch_npu.npu_fusion_attention_graph( + query.transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 2), + head_num, + input_layout, + None, + atten_mask=None, + scale=inv_scale_factor, + keep_prob=1.0 - dropout_p, + )[0] + + def _get_sfdp_patterns(): + device = 'npu' + g_inp = functools.partial( + torch.empty, (2, 4, 8, 16), device=device, requires_grad=True + ) + c_inp = functools.partial(torch.tensor, 2.0, device=device) + d = {"dropout_p": 0.113377} + candidates = [] + for dtype in [torch.float]: + g = functools.partial(g_inp, dtype=dtype) + c = functools.partial(c_inp, dtype=dtype) + candidates.append(( + _npu_fusion_attention_graph_pattern_1, + _npu_fusion_attention_graph_replacement_1, + [g(), g(), g(), c()], + d, + )) + + for pattern, replacement, args, workaround in candidates: + # gets serialized to a python file and does not require tracing at runtime. + if not isinstance(workaround, dict): + raise ValueError("workaround not dict") + name = pattern.__name__ + + if dtype != torch.float: + name += "_half" + + if args[0].size(0) == 1: + name += "_bs1" + + training_name = name + "_training" + yield training_name, { + "search_fn": pattern, + "replace_fn": replacement, + "example_inputs": args, + "trace_fn": joint_fwd_bwd, + "pass_dicts": patterns, + "scalar_workaround": workaround, + } + + if workaround: + if not (len(workaround) == 1 and "dropout_p" in workaround): + raise ValueError("not (len(workaround) == 1 and dropout_p in workaround)") + # functools.partial insufficient because we look at signature downstream + pattern = partialize_and_update_signature(pattern, dropout_p=0.0) + replacement = partialize_and_update_signature( + replacement, dropout_p=0.0 + ) + workaround = {} + + inference_name = name + "_inference" + yield inference_name, { + "search_fn": pattern, + "replace_fn": replacement, + "example_inputs": args, + "trace_fn": fwd_only, + "pass_dicts": patterns, + "scalar_workaround": workaround, + } + + for _, register_replacement_kwargs in _get_sfdp_patterns(): + register_replacement( + **register_replacement_kwargs, + ) + diff --git a/torch_npu/_inductor/npu_static_kernel.py b/torch_npu/_inductor/npu_static_kernel.py new file mode 100644 index 0000000000000000000000000000000000000000..5d6b7fe4d86bebb93d1641b8ff6341ce88224999 --- /dev/null +++ b/torch_npu/_inductor/npu_static_kernel.py @@ -0,0 +1,151 @@ +__all__ = [] + +import os +import subprocess +import datetime +from pathlib import Path + +import torch_npu +from .config import log + +_uninstall_path = None + + +def safe_resolve_output_dir(build_dir: str): + base_dir = Path.cwd().resolve() + if build_dir is not None: + if "\x00" in build_dir: + raise ValueError("build_dir contains null byte") + + candidate = Path(build_dir) + if ".." 
in candidate.parts: + raise ValueError("build_dir must not contain '..'") + + script_dir = candidate if candidate.is_absolute() else base_dir / candidate + + cur = Path(script_dir.anchor) + for part in script_dir.parts[1:]: + cur = cur / part + if cur.exists() and cur.is_symlink(): + raise ValueError(f"symlink detected in path: {cur}") + + try: + script_dir = script_dir.resolve(strict=False) + except Exception as e: + raise ValueError(f"cannot resolve path {script_dir}: {e}") + else: + script_dir = base_dir + + timestamp = f"{datetime.datetime.now().strftime('%Y%m%d%H%M%S%f')}_{os.getpid()}" + result_root = script_dir / f"{timestamp}_kernel_aot_optimization_build_outputs" + + try: + result_root.mkdir(exist_ok=True) + except (PermissionError, OSError) as e: + raise RuntimeError(f"failed to create output directory {result_root}: {e}") from e + + return result_root + + +class AclopDumpContext: + def __init__(self, save_path: str): + self.save_path = save_path + + def __enter__(self): + torch_npu._C._aclop_start_dmp(self.save_path) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + torch_npu._C._aclop_stop_dump() + + +def save_uninstall_info(filename: str): + global _uninstall_path + latest = Path(os.environ["ASCEND_HOME_PATH"]) + root = latest.parent + pattern = f"*/opp/static_kernel/ai_core/{filename}/uninstall.sh" + match = next(root.glob(pattern), None) + if match is None: + _uninstall_path = None + log.debug(f"can not find uninstall path, pattern: {pattern}") + else: + _uninstall_path = str(match) + + +class StaticKernelCompiler: + def __init__(self, build_dir=None): + self.result_root = safe_resolve_output_dir(build_dir) + log.debug(f"StaticKernelCompiler initialized. Build directory: {self.result_root}") + + def __enter__(self): + log.info(f"Starting operator dump...") + torch_npu._C._aclop_start_dump(str(self.result_root)) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + torch_npu._C._aclop_stop_dump() + log.info(f"Stopping operator dump.") + + if exc_type: + log.error(f"An exception occurred during model execution: {exc_val}") + log.info(f"Skipping static kernel compilation due to the error.") + return + + log.info(f"Starting static kernel compilation process...") + debug_dirs = [d for d in self.result_root.iterdir() if d.is_dir() and d.name.endswith("_debug")] + if not debug_dirs: + log.error(f"Can not find json of ops, skipping op_compiler.") + return + + debug_dir = max(debug_dirs, key=lambda d: d.stat().st_mtime) + json_files = list(debug_dir.glob("*.json")) + if not json_files: + log.error(f"No json files in {debug_dir}, skipping op_compiler.") + return + + cmd = [ + "op_compiler", + "-p", str(debug_dir), + "-v", torch_npu.npu.get_device_name(), + "-l", "info", + "-j", "4", + "-o", str(self.result_root), + ] + try: + log.debug(f"Executing op_compiler command: {' '.join(cmd)}") + res = subprocess.run(cmd, check=True, capture_output=True, text=True) + log.debug(f"op_compiler execution successful, msg: {res.stdout}") + except subprocess.CalledProcessError as e: + log.error(f"op_compiler execution failed, msg: {e.stderr}") + return + + for run_pkg in self.result_root.glob("*.run"): + filename = run_pkg.name + try: + log.info(f"Installing static kernel package: {filename}") + result = subprocess.run([str(run_pkg)], check=True, capture_output=True, text=True) + log.info(f"{filename} install successful, msg: {result.stdout}") + save_uninstall_info(filename[:-4]) + torch_npu.npu._aclnn_reselect_static_kernel() + except subprocess.CalledProcessError as e: 
+ log.error(f" {filename} install failed, msg: {e.stderr}") + + +def uninstall_static_kernel(): + global _uninstall_path + if not _uninstall_path: + log.debug(f"uninstall_path is none, skip uninstall static kernel") + return + + try: + result = subprocess.run( + [_uninstall_path], + check=True, + capture_output=True, + text=True, + ) + log.debug(f"{_uninstall_path} uninstall success, msg: \n{result.stdout}") + except subprocess.CalledProcessError as e: + log.error(f"{_uninstall_path} uninstall failed, msg: \n{e.stderr}") + finally: + _uninstall_path = None \ No newline at end of file diff --git a/torch_npu/_inductor/npu_triton_helpers.py b/torch_npu/_inductor/npu_triton_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..5140a2911a04928ea1e0d0e2eb72972a13e9cd01 --- /dev/null +++ b/torch_npu/_inductor/npu_triton_helpers.py @@ -0,0 +1,22 @@ +import triton +import triton.language as tl + +import triton.language.extra.ascend.libdevice as libdevice +from torch._inductor.runtime import triton_helpers + +libdevice = tl.extra.ascend.libdevice +math = tl.math + + +@triton.jit +def maximum(a, b): + return tl.maximum(a, b) + + +@triton.jit +def minimum(a, b): + return tl.minimum(a, b) + + +triton_helpers.maximum = maximum +triton_helpers.minimum = minimum diff --git a/torch_npu/_inductor/npu_triton_heuristics.py b/torch_npu/_inductor/npu_triton_heuristics.py new file mode 100644 index 0000000000000000000000000000000000000000..0cd52fe110d6f7f50291f147f94253262bbd8813 --- /dev/null +++ b/torch_npu/_inductor/npu_triton_heuristics.py @@ -0,0 +1,1324 @@ +# This file is based on triton_heuristics with heuristics designed for NPU +import copy +import functools +from functools import lru_cache +import hashlib +import importlib +import json +import logging +import dataclasses +import os +import re +import sys +import time +import shutil +import hashlib +import csv +import uuid +from itertools import count +from typing import Any, Callable, Literal, Optional, TYPE_CHECKING, Union, List +from contextlib import contextmanager +import torch +from torch._logging import warning_once +import triton +from torch._dynamo.utils import dynamo_timed +from torch._inductor import config +from torch._inductor.compile_fx import clone_preserve_strides +from torch._inductor.runtime.autotune_cache import AutotuneCache +from torch._inductor.runtime.benchmarking import benchmarker +from torch._inductor.runtime.runtime_utils import ( + create_bandwidth_info_str, + get_num_bytes, + +) +from torch._inductor.utils import triton_version_uses_attrs_dict +from torch.utils._ordered_set import OrderedSet +from torch._inductor.runtime.triton_heuristics import ( + CachingAutotuner, + HeuristicType, + unique_configs, + hash_configs, + Config, + ASTSource, + _find_names, + get_first_attr, + collected_calls, + _dump_launch_params, + builtins, + NoTritonConfigsError, + TritonCompileResult, + GridExpr, + config_to_dict +) +from torch._inductor.runtime.runtime_utils import triton_hash_to_path_key +from triton.compiler import CompiledKernel +from torch._inductor.triton_bundler import TritonBundler + +try: + from triton.backends.compiler import GPUTarget + from triton.runtime.autotuner import OutOfResources + import torch.autograd.profiler as autograd_profiler +except ImportError: + GPUTarget = None + OutOfResources = None + autograd_profiler = None + +import torch_npu +from torch_npu.utils._error_code import ErrCode, pta_error + +from .codegen.split_tiling import SplitTiling +from .utils import get_current_raw_stream 
+from .codegen.tile_generator import TileGenerator +from .codegen.triton_utils import get_aligned_numel +from .config import aggresive_autotune +from .config import log +from . import config as npu_config + +kernel_idx = count() + + +@contextmanager +def create_profiler(torch_path): + experimental_config = torch_npu.profiler._ExperimentalConfig( + aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization, + profiler_level=torch_npu.profiler.ProfilerLevel.Level0, ) + profile_path = torch_path + with torch_npu.profiler.profile( + activities=[torch_npu.profiler.ProfilerActivity.NPU], + record_shapes=False, + profile_memory=False, + with_stack=False, + schedule=torch_npu.profiler.schedule(wait=0, warmup=1, active=1, repeat=1, skip_first=1), + on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(profile_path), + experimental_config=experimental_config) as prof: + yield prof + + +def delete_file_base(base_path): + if os.path.exists(base_path): + shutil.rmtree(base_path) + + +def read_device_time(torch_path, triton_only=True, return_list=True): + for root, _, files in os.walk(torch_path): + for file in files: + if file != 'kernel_details.csv': + continue + target_file = os.path.join(root, file) + with open(target_file, newline='') as csvfile: + durations = [] + reader = csv.DictReader(csvfile) + for row_read in reader: + durations.append(float(row_read['Duration(us)'])) + if return_list: + return durations + ret = sum(durations) + return ret + raise RuntimeError(f"Could not find kernel_details.csv from dir {torch_path}") + + +def _summarize_statistics(times, quantiles, return_mode): + if quantiles is not None: + ret = torch.quantile(times, torch.tensor(quantiles, dtype=torch.float)).tolist() + if len(ret) == 1: + ret = ret[0] + return ret + if return_mode == "all": + return times.tolist() + return getattr(torch, return_mode)(times).item() + + +def do_bench_using_profiling_npu(fn, warmup=2, rep=10, grad_to_none=None, quantiles=None, return_mode="mean"): + if return_mode not in ["min", "max", "mean", "median", "all"]: + raise RuntimeError("return_mode must be one of 'min', 'max', 'mean', 'median', 'all'") + + stream = torch.npu.current_stream() + stream.synchronize() + + # Warm-up + for _ in range(warmup): + fn() + stream.synchronize() + + random_uuid = uuid.uuid4().hex + md5_hash = hashlib.md5(random_uuid.encode()).hexdigest() + torch_path = os.path.join(os.getcwd(), "profile_result", f"triton_{md5_hash}") + with create_profiler(torch_path) as prof: + stream.synchronize() + for _ in range(rep + 10): + fn() + prof.step() + stream.synchronize() + times = read_device_time(torch_path, triton_only=False, return_list=True) + delete_file_base(torch_path) + return _summarize_statistics(torch.tensor(times), quantiles, return_mode) + + +@dataclasses.dataclass +class GridNpu(GridExpr): + numels: List[str] = None + + def generate(self, meta: dict[str, int]) -> None: + numel_args = [] + split_axis = meta.get("split_axis", None) + split_blocks = meta.get("split_blocks", None) + if split_axis is None or split_blocks is None: + raise RuntimeError(f"Could not get split_axis or split_blocks from meta {meta}.") + + def grid_fn(i): + if i >= len(split_axis): + return "1" + axis = split_axis[i] + block = split_blocks[i] + return f"({self.numels[axis]} + {block} - 1) // {block}" + self.x_grid = grid_fn(0) + self.y_grid = grid_fn(1) + self.z_grid = grid_fn(2) + + +class GridExprNpu(GridExpr): + @staticmethod + def from_meta_and_set_numel( + inductor_meta: dict[str, Any], + cfg: Union[Config, dict[str, int]], + 
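+            # numels: the kernel's "*_numel" argument names; GridNpu.generate indexes into
+            # this list using the split_axis / split_blocks recorded in each config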
numels: List[str], + mode: Literal["python", "cpp"] = "python", + ) -> GridExpr: + grid_cls = globals()[inductor_meta["grid_type"]] + if not issubclass(grid_cls, GridNpu): + raise AssertionError(f"grid_type in inductor_meta must be subclass of GridNpu" + f"but got {inductor_meta['grid_type']}") + grid = grid_cls(inductor_meta=inductor_meta, mode=mode, numels=numels) + if isinstance(cfg, Config): + cfg = config_to_dict(cfg) + grid.generate(cfg) + return grid + + +class TritonCompileResultNpu(TritonCompileResult): + def make_launcher(self): + cfg = self.config + compile_meta = self.compile_meta + binary = self.kernel + fn = binary.src.fn + binary._init_handles() + + known_constants = OrderedSet( + arg for i, arg in enumerate(fn.arg_names) if i in fn.constexprs + ) + none_args = OrderedSet( + k + for k, v in compile_meta["constants"].items() + if v is None and k not in known_constants + ) + none_args = none_args.difference(OrderedSet(compile_meta["signature"].keys())) + + if triton_version_uses_attrs_dict(): + call_args = fn.arg_names + def_args = fn.arg_names + if ( + "num_warps" in compile_meta["constants"] + or "num_stages" in compile_meta["constants"] + ): + # num_warps/num_stages are special implicit args that are not in the signature + # see test_triton_kernel_special_params + def_args = [ + arg for arg in def_args if arg not in ("num_warps", "num_stages") + ] + repl = { + k: str(compile_meta["constants"].get(k)) + for k in ("num_warps", "num_stages") + } + call_args = [repl.get(arg, arg) for arg in call_args] + else: + call_args = [ + arg + for i, arg in enumerate(fn.arg_names) + if i not in fn.constexprs and arg not in none_args + ] + cfg_dict = config_to_dict(cfg) + def_args = [ + name + for name in fn.arg_names + if name not in cfg_dict and name not in none_args + ] + + binary_shared = ( + binary.shared if hasattr(binary, "shared") else binary.metadata.shared + ) + + scope = { + "grid_meta": cfg.kwargs, + "bin": binary, + "launch_enter_hook": binary.__class__.launch_enter_hook, + "launch_exit_hook": binary.__class__.launch_exit_hook, + "metadata": ( + binary.packed_metadata + if hasattr(binary, "packed_metadata") + else binary.metadata + ), + "shared": binary_shared, + "num_warps": ( + binary.num_warps + if hasattr(binary, "num_warps") + else binary.metadata.num_warps + ), + "cta_args": ( + ( + binary.num_ctas, + *get_first_attr(binary, "cluster_dims", "clusterDims"), + ) + if hasattr(binary, "num_ctas") + else ( + (binary.metadata.num_ctas, *binary.metadata.cluster_dims) + if hasattr(binary, "metadata") + else () + ) + ), + "function": get_first_attr(binary, "function", "cu_function"), + "runner": get_first_attr(binary, "run", "c_wrapper"), + "log": log, + } + + if not hasattr(binary, "launch_metadata"): + # launch args before CompiledKernel.launch_metadata is added. 
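+            # (older Triton launchers take grid, num_warps, cta_args and shared memory
+            # positionally; newer ones instead receive the object built by
+            # bin.launch_metadata in the else-branch below)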
+ # TODO(jansel): delete this branch in mid-2025 + runner_args = [ + "grid_0", + "grid_1", + "grid_2", + "num_warps", + "*cta_args", + "shared", + "stream", + "function", + "launch_enter_hook", + "launch_exit_hook", + "metadata", + *call_args, + ] + else: + if binary.__class__.launch_enter_hook: + launch_metadata = f"bin.launch_metadata((grid_0, grid_1, grid_2), stream, {', '.join(call_args)})" + else: + launch_metadata = "None" + runner_args = [ + "grid_0", + "grid_1", + "grid_2", + "stream", + "function", + "metadata", + launch_metadata, + "launch_enter_hook", + "launch_exit_hook", + *call_args, + ] + + if "extra_launcher_args" in self.inductor_meta: + def_args = [*def_args, *self.inductor_meta["extra_launcher_args"]] + + numels = [ + arg + for arg in fn.arg_names + if "_numel" in arg + ] + grid = GridExprNpu.from_meta_and_set_numel(self.inductor_meta, cfg, numels) + # grid.prefix is usually empty, grid.x_grid is something like `-(xnumel//-1024)` + lines = [ + f"def launcher({', '.join(def_args)}, stream):", + *[f" {line}" for line in grid.prefix], + f" grid_0 = {grid.x_grid}", + f" grid_1 = {grid.y_grid}", + f" grid_2 = {grid.z_grid}", + f" log.debug(", + f" f'[Runtime] Launch KERNEL {fn.fn.__name__} with ' ", + f" f'grid {{grid_0, grid_1, grid_2}} and cfg {{grid_meta}}]'", + f" )", + f" runner({', '.join(runner_args)})", + ] + exec("\n".join(lines), scope) + + launcher = scope["launcher"] + launcher.config = cfg + launcher.n_regs = getattr(binary, "n_regs", None) + launcher.n_spills = getattr(binary, "n_spills", None) + launcher.shared = binary_shared + launcher.store_cubin = self.inductor_meta.get("store_cubin", False) + # store this global variable to avoid the high overhead of reading it when calling run + if launcher.store_cubin: + launcher.fn = fn + launcher.bin = binary + if triton_version_uses_attrs_dict(): + # arg filtering wasn't done above + cfg_dict = config_to_dict(cfg) + def_args = [x for x in def_args if x not in cfg_dict] + call_args = [ + x + for x in call_args + if compile_meta["signature"].get(x, "constexpr") != "constexpr" + and x not in none_args + ] + launcher.def_args = def_args + launcher.call_args = call_args + return launcher + + +class NPUCachingAutotuner(CachingAutotuner): + def __init__( + self, + fn, + triton_meta, # passed directly to triton + configs, + save_cache_hook, + mutated_arg_names: List[str], # see [Note: clone mutated buffers] + optimize_mem, + heuristic_type, + size_hints=None, + inductor_meta=None, # metadata not relevant to triton + custom_kernel=False, # whether the kernel is inductor-generated or custom + filename: Optional[str] = None, + reset_to_zero_arg_names: Optional[List[str]] = None, + ): + super().__init__(fn, triton_meta, configs, save_cache_hook, mutated_arg_names, optimize_mem, heuristic_type, + size_hints, inductor_meta, custom_kernel, filename, reset_to_zero_arg_names) + + self.exceptions = [] + self.fn_name = None + + @staticmethod + def api_accuracy_checker(expected, actual, kernel_name, dump_path): + from msprobe.core.common.const import CompareConst + from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import BENCHMARK_COMPARE_SUPPORT_LIST + from msprobe.pytorch.api_accuracy_checker.triton_adapter.get_compare_result import get_compare_result + from msprobe.pytorch.api_accuracy_checker.triton_adapter.precision_compare import precision_compare + from msprobe.pytorch.api_accuracy_checker.triton_adapter.common.compare_utils import \ + convert_compare_column_to_row, print_check_details + from 
msprobe.pytorch.api_accuracy_checker.triton_adapter.precision_standard.triton_standard_register import \ + exist_in_precision_standard + + dtype = actual.dtype + + # only float use precision standard + if exist_in_precision_standard(kernel_name): + if str(dtype) in BENCHMARK_COMPARE_SUPPORT_LIST: + compare_column = precision_compare(kernel_name, expected, actual, dtype) # calc metrics + compare_row = convert_compare_column_to_row(compare_column, kernel_name) + status = get_compare_result(compare_row, kernel_name) # get compare results + if status == CompareConst.ERROR: + log.warning(f'CHECK ACCURACY FAILED! kernel: {kernel_name}, Dump Path: {dump_path}') + print_check_details(compare_column, kernel_name) + actual.copy_(expected) + checked_by_msprobe = True + else: + log.warning(f'The data type {dtype} is not supported for new precision standard. ' + f'Check accuracy by tolerance method.') + checked_by_msprobe = False + else: + log.warning(f'kernel_name {kernel_name} does not in new precision standard. ' + f'Check accuracy by tolerance method.') + checked_by_msprobe = False + return checked_by_msprobe + + def precompile( + self, + warm_cache_only=False, + reload_kernel: Optional[Callable[[], CachingAutotuner]] = None, + static_triton_bundle_key: Optional[str] = None, + ): + if warm_cache_only: + self._precompile_worker() + return + with self.lock: + # Helper function for reloading a kernel generated in a worker + # in the parent class. Normally we don't need to reload the kernel + # in the parent process, but in certain cases (coordesc tuning, dynamic_scale_rblock), + # we need to actually run compilation on the parent process + if reload_kernel is not None: + self._reload_kernel = reload_kernel + self._precompile_worker() + self._make_launchers() + + def _precompile_worker(self): + if self.compile_results: + for result in self.compile_results: + TritonBundler.put( + triton_hash_to_path_key(result.kernel.hash), + self.triton_meta.get("device", 0), + ) + return + if self.launchers: + raise AssertionError("Before _precompile_worker, launchers must bt empty") + + if not self.configs: + raise NoTritonConfigsError("No triton configs are available") + + compile_results = [] + exc = None + exc_stack = "" + for c in self.configs: + try: + compile_results.append(self._precompile_config(c)) + except Exception as e: + import traceback + exc_stack = traceback.format_exc() + exc = e + if len(compile_results) == 0: + raise NoTritonConfigsError( + f"No valid triton configs. 
{type(exc).__name__}: {exc} \nStack trace:{exc_stack}" + ) + self.compile_results = compile_results + self.configs = None + + def _precompile_config(self, cfg: Config) -> TritonCompileResultNpu: + """Ahead of time compile a given autotuner config.""" + compile_meta = copy.deepcopy(self.triton_meta) + cfg_kwargs = cfg.kwargs + for k, v in cfg_kwargs.items(): + if k not in self.fn.arg_names: + continue + compile_meta["constants"][k] = v + + for i in self.fn.constexprs: + arg_name = self.fn.arg_names[i] + if arg_name not in compile_meta["constants"] and ( + arg_name == "num_warps" or arg_name == "num_stages" + ): + compile_meta["constants"][arg_name] = getattr(cfg, arg_name) + compile_meta["num_warps"] = cfg.num_warps + compile_meta["num_stages"] = cfg.num_stages + compile_meta["debug"] = ( + os.getenv("INDUCTOR_ASCEND_DEBUG", 'false').lower() in ('true', '1') + and self.inductor_meta.get("assert_indirect_indexing", True) + and not self.inductor_meta.get("is_hip", False) + ) + + # device type will be "hip" rather than "cuda" here + compile_meta["device_type"] = self.device_props.type + compile_meta["cc"] = self.device_props.cc + + if not ASTSource: + raise RuntimeError("Installed triton version too old, please upgrade") + + compile_args = ( + ASTSource( + self.fn, + compile_meta["signature"], + compile_meta["constants"], + ), + ) + + cc_warp_size = 32 + target = GPUTarget( + compile_meta["device_type"], + compile_meta["cc"], + cc_warp_size, + ) + + options = { + "num_warps": compile_meta["num_warps"], + "num_stages": compile_meta["num_stages"], + "debug": compile_meta["debug"] + } + compile_kwargs = { + "target": target, + "options": options, + } + + try: + binary = triton.compile(*compile_args, **compile_kwargs) + except Exception: + log.debug( + "Triton compilation failed: %s\n%s\nmetadata: %s", + self.inductor_meta.get("kernel_name", "triton_"), + self.fn.src, + compile_meta, + ) + raise + return TritonCompileResultNpu(binary, cfg, compile_meta, self.inductor_meta) + + def _make_launchers(self): + if len(self.launchers) == len(self.compile_results): + return + + from torch._dynamo.device_interface import DeviceGuard + + device_interface = self.get_device_interface() + + # load binary to the correct device + with DeviceGuard(device_interface, self.triton_meta["device"]): + # need to initialize context + device_interface.synchronize(device_interface.current_device()) + launchers = [] + exc = None + exc_stack = "" + for result in self.compile_results: + try: + launchers.append(result.make_launcher()) + except Exception as e: + import traceback + exc_stack = traceback.format_exc() + exc = e + + if len(launchers) == 0: + raise RuntimeError(f"No valid triton configs. 
{type(exc).__name__}: {exc}\n" + f"Stack trace: {exc_stack}") + self.launchers = launchers + + def save_gpu_kernel(self, input_stream, input_launcher): + self.save_npu_kernel(input_stream, input_launcher) + + def save_npu_kernel(self, input_stream, input_launcher): + key = self.inductor_meta.get("kernel_name", None) # unique kernel name + + if key is None: + raise RuntimeError("assert key is not None, kernel_name can not be None") + params = { + "mangled_name": ( + input_launcher.bin.metadata.name + if hasattr(input_launcher.bin.metadata, "name") + else input_launcher.bin.metadata["name"] + ), + "num_warps": ( + input_launcher.bin.num_warps + if hasattr(input_launcher.bin, "num_warps") + else input_launcher.bin.metadata.num_warps + ), + "shared_mem": ( + input_launcher.bin.shared + if hasattr(input_launcher.bin, "shared") + else input_launcher.bin.metadata.shared + ), + "stream": input_stream, + # User defined triton kernels will have arbitrary kwarg names + "meta": input_launcher.config.kwargs, + } + from torch._inductor.codecache import CudaKernelParamCache + + bin_type = "npubin" + binary = input_launcher.bin.asm[bin_type] # npubin type = npubin + CudaKernelParamCache.set(key, params, binary, bin_type='cubin') # CudaKernelParam + + self.cuda_kernel_saved = True + + # bench method is called by torch, grid can not be modified + def bench(self, launcher, *args, with_profiler=False, **kwargs): + """Measure the performance of a given launcher""" + + if not self.custom_kernel and launcher.n_spills > self.inductor_meta.get( + "spill_threshold", 16 + ): + return float("inf") + + device_interface = self.get_device_interface() + stream = device_interface.get_raw_stream(device_interface.current_device()) + + def kernel_call(): + cloned_args, cloned_kwargs = self.clone_args(*args, **kwargs) + launcher( + *cloned_args, + **cloned_kwargs, + stream=stream, + ) + + if self.inductor_meta.get("profile_bandwidth_with_do_bench_using_profiling", False): + return do_bench_using_profiling_npu(kernel_call, rep=1) + + return benchmarker.benchmark_gpu(kernel_call, rep=1) + + + def autotune_to_one_config(self, *args, **kwargs): + """Do the actual autotuning""" + start_time = time.time_ns() + timings = self.benchmark_all_configs(*args, **kwargs) + benchmark_time_taken_ns = time.time_ns() - start_time + self.launchers = [builtins.min(timings, key=timings.get)] + self.autotune_time_taken_ns = ( + self.precompile_time_taken_ns + benchmark_time_taken_ns + ) + if self.save_cache_hook: + self.save_cache_hook(self.launchers[0].config, self.autotune_time_taken_ns) + + @lru_cache(None) + def get_fx_graph_dump_path(self): + traced_graph_hash = self.inductor_meta.get("traced_graph_hash") + dump_dir = self.inductor_meta.get("traced_graph_dir", "") + dump_path = os.path.join(dump_dir, traced_graph_hash) + if dump_dir == "" or not os.path.exists(dump_path): + return None + return dump_path + + def get_fx_graph_call(self, auto_fallback=False): + kernel_name = self.inductor_meta.get("kernel_name", "triton_") + traced_graph_hash = self.inductor_meta.get("traced_graph_hash") + dump_dir = self.inductor_meta.get("traced_graph_dir", "") + dump_path = os.path.join(dump_dir, traced_graph_hash) + if dump_dir == "" or not os.path.exists(dump_path): + return None, None, None, None + sys.path.append(dump_path) + fx_module = importlib.import_module(traced_graph_hash) + sys.path.remove(dump_path) + + model = fx_module.model + num_inputs = fx_module.num_inputs + num_outputs = fx_module.num_outputs + non_contiguous_indices = 
fx_module.non_contiguous_indices + mismatch_indices_shapes = fx_module.mismatch_indices_shapes + + def fx_graph_call(*fx_args): + fx_inputs = [fx_args[idx].contiguous() if idx in non_contiguous_indices['inputs'] else \ + fx_args[idx] for idx in range(num_inputs)] + if len(mismatch_indices_shapes): + for ind, shape in mismatch_indices_shapes.items(): + if ind >= num_inputs: + break + fx_inputs[ind] = fx_inputs[ind].reshape(shape) + model_outputs = model.forward(*fx_inputs) + for idx, (out1, out2) in enumerate(zip(model_outputs, fx_args[num_inputs:(num_inputs + num_outputs)])): + out1 = out1.reshape(out2.shape) + if idx in non_contiguous_indices['outputs']: + out2.copy_(out1) + else: + out2.data = out1.data + + def fallback_call(*args): + fx_args = [args[idx] for idx in fx_module.call_args_mapping] + return fx_graph_call(*fx_args) + + if auto_fallback: + return fallback_call, kernel_name, None, None + return fx_graph_call, kernel_name, dump_path, fx_module + + def data_dump(self, *args, dump_path=None): + dump_path = self.get_fx_graph_dump_path() if dump_path is None else dump_path + if dump_path is None: + log.warning(f"data dump for kernel {self.get_fn_name()} failed, no valid dump_path is supplied.") + return False + data_dump_path = os.path.join(dump_path, 'data.pth') + torch.save(args, data_dump_path) + return True + + def get_fn_name(self): + if self.fn_name is not None: + return self.fn_name + try: + self.fn_name = self.fn.fn.__name__ + except AttributeError: + self.fn_name = "unknown" + return self.fn_name + + def fallback_to_fx(self, *args, launcher, stream, **kwargs): + """ + Try to fallback kernel to fx graph call according to kernel id. + """ + def should_fallback(): + fallback_id = npu_config.force_fallback_kernel_id + if fallback_id != "all" and not isinstance(fallback_id, list): + raise RuntimeError("torch_npu._inductor.config.aot_inductor.force_fallback_kernel_id " + "should be set to 'all' or List, e.g, [1, 2, 10]." 
+ pta_error(ErrCode.VALUE)) + + if isinstance(fallback_id, list): + kernel_name = self.get_fn_name() + try: + kernel_id = int(kernel_name.split("_")[-1]) + except ValueError: + kernel_id = -1 + if kernel_id not in fallback_id: + return False + return True + + if not should_fallback(): + return None + + fx_graph_call, _, _, fx_module = self.get_fx_graph_call() + if not fx_graph_call: + return None + + call_outputs_indices = fx_module.call_args_mapping[fx_module.num_inputs:] + fx_args = [] + for idx in fx_module.call_args_mapping: + arg = args[idx] + if isinstance(arg, torch.Tensor): + fx_arg = clone_preserve_strides(arg).float() if arg.dtype == torch.bfloat16 else clone_preserve_strides( + arg) + fx_args.append(fx_arg) + + fx_graph_call(*fx_args) + for actual, expected in zip([args[i] for i in call_outputs_indices], fx_args[fx_module.num_inputs:]): + if actual.dtype != expected.dtype: + expected = expected.to(actual.dtype) + actual.copy_(expected) + for arg in fx_args: + del arg + return True + + + def check_accuracy(self, *args, launcher, grid, stream, **kwargs): + fx_graph_call, kernel_name, dump_path, fx_module = self.get_fx_graph_call() + if not fx_graph_call: + return None + call_outputs_indices = fx_module.call_args_mapping[fx_module.num_inputs:] + + fx_args = [] + for idx in fx_module.call_args_mapping: + arg = args[idx] + if isinstance(arg, torch.Tensor): + fx_arg = clone_preserve_strides(arg).float() if arg.dtype == torch.bfloat16 else clone_preserve_strides( + arg) + fx_args.append(fx_arg) + + fx_graph_call(*fx_args) + + launcher( + *args, + **kwargs, + stream=stream, + ) + + try: + import msprobe + has_msprobe = True + except ImportError: + has_msprobe = False + warning_once(log, "msprobe import failed, please check. " + "It may be due to missing dependencies or other factors. " + "Check accuracy by tolerance method.") + for actual, expected in zip([args[i] for i in call_outputs_indices], fx_args[fx_module.num_inputs:]): + if actual.dtype != expected.dtype: + expected = expected.to(actual.dtype) + checked_by_msprobe = False + if has_msprobe: + checked_by_msprobe = self.api_accuracy_checker(expected, actual, kernel_name, dump_path) + if not has_msprobe or not checked_by_msprobe: + acc_comp_tol = npu_config.acc_comp_tol.get(actual.dtype, npu_config.acc_comp_tol['default']) + rtol = acc_comp_tol['rtol'] + atol = acc_comp_tol['atol'] + + matches = torch.isclose( + actual, expected, rtol=rtol, atol=atol, equal_nan=False + ) + if not matches.all(): + abs_diff = torch.abs(actual - expected) + rel_diff = abs_diff / torch.abs(expected) + rel_diff.masked_fill_(matches, 0) + log.warning(f"CHECK ACCURACY FAILED! Greatest Relative Difference: {rel_diff.max().item()}, " + f"Kernel Name: {kernel_name}, Dump Path: {dump_path}") + actual.copy_(expected) + del matches + for arg in fx_args: + del arg + return True + + def debug_kernel_in_run(self, *args, launcher, stream, **kwargs): + ''' + Save tensors for kernel args and outputs before and after kernel execute. + These tensors can be load and compared with tensors dumped by aot-inductor cpp runtime. 
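+        Files are written to npu_config.aot_inductor.dump_path_py as
+        "<idx>_<fn_name>_before.pt" / "<idx>_<fn_name>_after.pt" pairs, with an
+        explicit torch.npu.synchronize() before each save.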
+ ''' + dump_path = npu_config.aot_inductor.dump_path_py + if not os.path.exists(dump_path): + os.makedirs(dump_path) + + idx = next(kernel_idx) + fn_name = self.get_fn_name() + dump_args = [arg for arg in args if isinstance(arg, torch.Tensor)] + torch.npu.synchronize() + torch.save(dump_args, f"{dump_path}/{idx}_{fn_name}_before.pt") + + result = super().run(*args, stream=stream, **kwargs) + + torch.npu.synchronize() + torch.save(dump_args, f"{dump_path}/{idx}_{fn_name}_after.pt") + return result + + def maybe_run_debug(self, *args, grid_, stream, launcher, **kwargs): + kernel_name = self.get_fn_name() + log.info(f"Try to run debug mode for kernel {kernel_name}.") + if npu_config.dump_fx_graph: + _ = self.data_dump(*args) + + if npu_config.check_accuracy: + if self.check_accuracy(*args, launcher=launcher, grid=grid_, stream=stream, **kwargs): + return "check_accuracy" + elif npu_config.force_fallback_kernel_id: + fallback_result = self.fallback_to_fx(*args, launcher=launcher, grid_=grid_, stream=stream, **kwargs) + if fallback_result is not None: + log.debug(f"fallback kernel {self.get_fn_name()} to fx graph call.") + return "force_fallback_kernel_id" + else: + log.warning(f"kernel {self.get_fn_name()} could not fallback to fx.") + elif npu_config.aot_inductor.debug_kernel_in_run: + _ = self.debug_kernel_in_run(*args, launcher=launcher, grid_=grid_, stream=stream, **kwargs) + return "debug_kernel_in_run" + + log.info(f"No debug mode is activated for kernel {kernel_name}.") + return None + + def run( + self, *args, stream, benchmark_run=False, **kwargs + ): # type:ignore[override] + if self.triton_interpret: + args, grid = self._interpret_args_grid(args, self.configs[0]) + copied_kwargs = copy.copy(self.configs[0].kwargs) + copied_kwargs.pop('split_axis', None) + copied_kwargs.pop('split_blocks', None) + + return self.fn[grid]( + *args, + **kwargs, + **copied_kwargs, + ) + + if hasattr(self.launchers[0], "fallback"): + return self.launchers[0]( + *args, + **kwargs, + ) + + if len(self.launchers) != 1: + if len(self.launchers) == 0: + start_time = time.time_ns() + self.precompile() + self.precompile_time_taken_ns = time.time_ns() - start_time + if len(self.launchers) > 1: + self.autotune_to_one_config(*args, **kwargs) + + if not getattr( + self.launchers[0].config, "found_by_coordesc", False + ) and self.inductor_meta.get("coordinate_descent_tuning", False): + self.launchers = [ + self.coordinate_descent_tuning( + self.launchers[0], *args, **kwargs + ) + ] + + (launcher, ) = self.launchers + if launcher.store_cubin and (not benchmark_run or not self.cuda_kernel_saved): + self.save_gpu_kernel(stream, launcher) + + if self.dump_launch_params: + _dump_launch_params(args, kwargs, launcher, self.fn.__name__) + + _, grid = self._interpret_args_grid(args, launcher.config) + debug_mode = self.maybe_run_debug(*args, grid_=grid, stream=stream, launcher=launcher, **kwargs) + if debug_mode: + log.info(f"Kernel {self.get_fn_name()} goes into {debug_mode} and return.") + return + + # it is faster than entering and exiting a context manager, even if the context + # manager is a nullcontext. 
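+        # When the autograd profiler is active, the launch is wrapped in
+        # _RecordFunctionFast so the Triton kernel shows up as a named event
+        # (with kernel file, hash and stream metadata) in the collected trace;
+        # otherwise the launcher is called directly to avoid that overhead.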
+ if autograd_profiler._is_profiler_enabled: + with torch._C._profiler._RecordFunctionFast( + self.inductor_meta.get("kernel_name", "triton kernel"), + args, + { + "kernel_file": (self.filename or ""), + "kernel_hash": self.kernel_hash, + "kernel_backend": "triton", + "stream": stream, + }, + ): + return launcher( + *args, + **kwargs, + stream=stream, + ) + else: + return launcher( + *args, + **kwargs, + stream=stream, + ) + + def _interpret_args_grid( + self, args: tuple[Any, ...], cfg: Config + ) -> tuple[tuple[Any, ...], tuple[int, int, int]]: + + numels = [ + arg + for arg in self.fn.arg_names + if "_numel" in arg + ] + grid = GridExprNpu.from_meta_and_set_numel(self.inductor_meta, cfg, numels).eval_slow( + dict( + zip( + [ + *self.fn.arg_names, + *self.inductor_meta.get("extra_launcher_args", ()), + ], + args, + ) + ) + ) + if self.inductor_meta.get("extra_launcher_args"): + args = args[: -len(self.inductor_meta["extra_launcher_args"])] + return args, grid + + +class NPUDebugAutotuner(NPUCachingAutotuner): + def __init__(self, *args, regex_filter="", **kwargs): + self.regex_filter = regex_filter + super().__init__(*args, **kwargs) + self.cached = None + + def run(self, *args, input_grid, stream): + possible_names = _find_names(self) + kernel_name = f"{max(possible_names, key=len)}" + if not re.match(self.regex_filter, kernel_name): + return + super().run(*args, grid=input_grid, stream=stream) + (launcher,) = self.launchers + + if self.cached is None: + ms = self.bench(launcher, *args, input_grid=input_grid) + num_in_out_ptrs = len( + [ + arg_name + for arg_name in self.fn.arg_names + if arg_name.startswith("in_out_ptr") + ] + ) + num_gb = get_num_bytes(*args, num_in_out_args=num_in_out_ptrs) / 1e9 + gb_per_s = num_gb / (ms / 1e3) + self.cached = (ms, num_gb, gb_per_s, kernel_name) + else: + ms, num_gb, gb_per_s, kernel_name = self.cached + collected_calls.append((ms, num_gb, gb_per_s, kernel_name)) + print( + create_bandwidth_info_str(ms, num_gb, gb_per_s, suffix=f" \t {kernel_name}") + ) + + +def cached_autotune( + size_hints: Optional[List[int]], + configs: List[Config], + triton_meta, + heuristic_type, + filename=None, + inductor_meta=None, + custom_kernel=False, +): + """ + A copy of triton.autotune that calls our subclass. Our subclass + has additional debugging, error handling, and on-disk caching. 
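+    When several configs are supplied (or coordinate descent tuning is enabled)
+    and caching is not force-disabled, the best config found by autotuning is
+    persisted through AutotuneCache so later runs can reuse it without
+    re-benchmarking.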
+ """ + configs = unique_configs(configs) + if not (len(configs) == 1 or filename): + raise RuntimeError("assert len(configs) == 1 or filename") + + inductor_meta = {} if inductor_meta is None else inductor_meta + + disabled = inductor_meta.get("force_disable_caches", False) + + # on disk caching logic and/or remote caching + autotune_cache = None + if ( + not disabled + and filename is not None + and (len(configs) > 1 or inductor_meta.get("coordinate_descent_tuning")) + and not os.environ.get("TRITON_INTERPRET", "0") == "1" + ): + configs_hash = hash_configs(configs) + + autotune_cache = AutotuneCache.create(inductor_meta, filename, configs_hash) + if autotune_cache: + best_config = autotune_cache.read_best(inductor_meta, configs) + if best_config: + configs = [best_config] + else: + if disabled: + log.debug("autotune caching is disabled by config.force_disable_caches") + + mutated_arg_names = inductor_meta.pop("mutated_arg_names", ()) + optimize_mem = inductor_meta.pop("optimize_mem", True) + + if "restore_value" in triton_meta: + mutated_arg_names += triton_meta.pop("restore_value") + + reset_to_zero_arg_names: List[str] = [] + if "reset_to_zero" in triton_meta: + reset_to_zero_arg_names.extend(triton_meta.pop("reset_to_zero")) + + def decorator(fn): + + if inductor_meta.get("profile_bandwidth"): + return NPUDebugAutotuner( + fn, + triton_meta=triton_meta, + inductor_meta=inductor_meta, + regex_filter=inductor_meta["profile_bandwidth_regex"], + with_profiler=inductor_meta[ + "profile_bandwidth_with_do_bench_using_profiling" + ], + configs=configs, + save_cache_hook=autotune_cache and autotune_cache.save, + mutated_arg_names=mutated_arg_names, + reset_to_zero_arg_names=reset_to_zero_arg_names, + optimize_mem=optimize_mem, + heuristic_type=heuristic_type, + size_hints=size_hints, + custom_kernel=custom_kernel, + filename=filename, + with_bandwidth_info=True, + ) + return NPUCachingAutotuner( + fn, + triton_meta=triton_meta, + inductor_meta=inductor_meta, + configs=configs, + save_cache_hook=autotune_cache and autotune_cache.save, + mutated_arg_names=mutated_arg_names, + reset_to_zero_arg_names=reset_to_zero_arg_names, + optimize_mem=optimize_mem, + heuristic_type=heuristic_type, + size_hints=size_hints, + custom_kernel=custom_kernel, + filename=filename, + ) + + return decorator + + +# split:sizeof split, xblock:axis1 length, rblock:axis2 length +def triton_config_npu_index( + size_hints, + inductor_meta, + triton_meta=None, + reduction=False, + persistent_reduction=False, + +) -> List[Config]: + num_warps = 1 + num_stages = 1 + configs = [] + split_axis = inductor_meta["split_axis"] + tiling_axis = inductor_meta["tiling_axis"] + low_dims = inductor_meta["low_dims"] + split_axis_dtype = inductor_meta["split_axis_dtype"] + axis_names = inductor_meta["axis_names"] + dual_reduction = inductor_meta["dual_reduction"] + + tile_generator = TileGenerator(size_hints, axis_names, tiling_axis, split_axis, low_dims, + persistent_reduction=persistent_reduction, configs=configs, + dtype=split_axis_dtype, dual_reduction=dual_reduction) + + tile_generator.descend_split_tiling() + + if not configs: + cfg = {} + for x in split_axis: + cfg[f"{axis_names[x].upper()}BLOCK"] = size_hints[x] + if not cfg: + cfg["dummy"] = 1 + tmp = Config(cfg, num_warps=num_warps, num_stages=num_stages) + configs.append(tmp) + + for cfg in configs: + split_blocks = [None for x in split_axis] + for i, axis in enumerate(split_axis): + name = axis_names[axis] + block_name = f"{name.upper()}BLOCK" + split_blocks[i] = 
cfg.kwargs[block_name] + cfg.kwargs["split_axis"] = tuple(split_axis) + cfg.kwargs["split_blocks"] = tuple(split_blocks) + + return configs + + +def pointwise_npu_index( + size_hints, + triton_meta, + tile_hint=None, + filename=None, + min_elem_per_thread=0, + inductor_meta=None, +): + inductor_meta = {} if inductor_meta is None else inductor_meta + triton_config_with_settings = functools.partial( + triton_config_npu_index + ) + return cached_autotune( + size_hints, + triton_config_with_settings(size_hints, inductor_meta=inductor_meta), + triton_meta=triton_meta, + inductor_meta=inductor_meta, + heuristic_type=HeuristicType.POINTWISE, + filename=filename, + ) + + +def reduction_npu_index( + size_hints, + reduction_hint=False, + triton_meta=None, + filename=None, + inductor_meta=None, +): + """args to @triton.heuristics()""" + inductor_meta = {} if inductor_meta is None else inductor_meta + inductor_meta["reduction_hint"] = reduction_hint + if triton_meta is None: + raise RuntimeError("assert triton_meta is not None") + + contiguous_config = triton_config_npu_index(size_hints, inductor_meta=inductor_meta, reduction=True) + return cached_autotune( + size_hints, + [ + *contiguous_config, + ], + triton_meta=triton_meta, + inductor_meta=inductor_meta, + filename=filename, + heuristic_type=HeuristicType.REDUCTION, + ) + + +def persistent_reduction_npu_index( + size_hints, + reduction_hint=False, + triton_meta=None, + filename=None, + inductor_meta=None, +): + inductor_meta = {} if inductor_meta is None else inductor_meta + inductor_meta["reduction_hint"] = reduction_hint + configs = triton_config_npu_index(size_hints, inductor_meta=inductor_meta, reduction=True, + persistent_reduction=True) + + return cached_autotune( + size_hints, + configs, + triton_meta=triton_meta, + inductor_meta=inductor_meta, + filename=filename, + heuristic_type=HeuristicType.PERSISTENT_REDUCTION, + ) + + +def foreach(triton_meta, num_warps, filename=None, inductor_meta=None): + """ + Compile a triton foreach kernel + """ + return cached_autotune( + None, + [triton.Config({}, num_stages=1, num_warps=num_warps)], + triton_meta=triton_meta, + inductor_meta=inductor_meta, + heuristic_type=HeuristicType.TEMPLATE, + filename=filename, + ) + + +@dynamo_timed +def benchmark_all_configs(self, *args, input_grid, **kwargs): + print(f"candidate launcher count = {len(self.launchers)}") + + tilling_kernel_list = [] + + def kernel_call(launcher): + def call_kernel(): + if launcher.config.pre_hook is not None: + launcher.config.pre_hook( + {**dict(zip(self.arg_names, args)), **launcher.config.kwargs} + ) + cloned_args, cloned_kwargs = self.clone_args(*args, **kwargs) + launcher( + *cloned_args, + **cloned_kwargs, + grid=input_grid, + stream=stream, + ) + + return call_kernel + + for launcher in self.launchers: + if not self.custom_kernel and launcher.n_spills > config.triton.spill_threshold: + return float("inf") + + stream = self.gpu_device.get_raw_stream( # type: ignore[call-arg] + self.gpu_device.current_device() + ) + tilling_kernel_list.append(kernel_call(launcher)) + + def do_batch_benchmark(tilling_kernel_list): + + def delete_file(base_path): + if os.path.exists(base_path): + shutil.rmtree(base_path) + + stream = torch.npu.current_stream() + experimental_config = torch_npu.profiler._ExperimentalConfig( + aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization, + profiler_level=torch_npu.profiler.ProfilerLevel.Level1, + l2_cache=False, + data_simplification=False + ) + + random_uuid = uuid.uuid4().hex + md5_hash = 
hashlib.md5(random_uuid.encode()).hexdigest() + + from torch_npu._inductor.config import profile_path + + torch_path = profile_path + md5_hash + rep = 1 + with torch_npu.profiler.profile( + activities=[ + torch_npu.profiler.ProfilerActivity.NPU + ], + schedule=torch_npu.profiler.schedule(wait=0, warmup=1, active=rep, repeat=1, skip_first=1), + on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(torch_path), + record_shapes=False, + profile_memory=False, + with_stack=False, + with_flops=False, + with_modules=False, + experimental_config=experimental_config) as prof: + stream.synchronize() + for _ in range(rep + 3): + for fn in tilling_kernel_list: + fn() + prof.step() + stream.synchronize() + + import pandas as pd + for root, _, files in os.walk(torch_path): + for file in files: + if file != 'kernel_details.csv': + continue + target_file = os.path.join(root, file) + df = pd.read_csv(target_file) + triton_rows = df[df['Name'].str.startswith('triton', na=False)] + ret = triton_rows['Duration(us)'].astype(float).tolist() + delete_file(torch_path) + return ret + + delete_file(torch_path) + return [] + + try: + timinglist = do_batch_benchmark(tilling_kernel_list) + if not len(timinglist) == len(self.launchers): + raise RuntimeError("not len(timinglist) == len(self.launchers)") + timings = {launcher: timing for launcher, timing in zip(self.launchers, timinglist)} + except Exception as e: + print("some cases in batch benchmark has error! Logging Exception as:") + print(e) + print("switched to single bench...") + timings = { + launcher: self.bench(launcher, *args, **kwargs) + for launcher in self.launchers + } + + for k, v in timings.items(): + self.coordesc_tuner.cache_benchmark_result(k.config, v) + + if log.isEnabledFor(logging.DEBUG): + for k, v in timings.items(): + log.debug( + "%s: %f, nreg %d, nspill %d, #shared-mem %s", + k.config, + v, + k.n_regs, + k.n_spills, + k.shared, + ) + return timings diff --git a/torch_npu/_inductor/runtime.py b/torch_npu/_inductor/runtime.py new file mode 100644 index 0000000000000000000000000000000000000000..9d7716a200818624e83dc478e0e8520782f5fc72 --- /dev/null +++ b/torch_npu/_inductor/runtime.py @@ -0,0 +1,70 @@ +import functools +from typing import List, Dict +from typing import Optional +from torch._inductor.remote_cache import JsonDataTy +from torch._inductor.runtime.hints import DeviceProperties +from torch.utils._triton import has_triton, has_triton_package + +from .config import num_vector_core + +if has_triton_package(): + from triton import Config + + +# overload this to avoid autotune after best_config already generated +def _load_cached_autotuning( + best_config: Dict[str, JsonDataTy], + configs_hash: str, + configs: List[Config], + inductor_meta: Dict, +) -> Optional[Config]: + if best_config is None: + return None + if best_config.pop("configs_hash", None) != configs_hash: + return None + # Remove time taken for comparison + best_config.pop("time_taken_ms", None) + + # if inductor_meta.get("coordinate_descent_tuning") : + num_warps = best_config.pop("num_warps") + num_stages = best_config.pop("num_stages") + triton_config = Config(best_config, num_warps=num_warps, num_stages=num_stages) + triton_config.found_by_coordesc = True + return triton_config + + +class NPUDeviceProperties(DeviceProperties): + + @classmethod + @functools.lru_cache(None) + def create(cls, device) -> DeviceProperties: + import torch + from torch._dynamo.device_interface import get_interface_for_device + + device_type = device.type + + if torch.version.hip and 
device_type == "cuda": + device_type = "hip" + + device_interface = get_interface_for_device(device) + props = device_interface.get_device_properties(device) + + try: + multi_processor_count = num_vector_core + except AttributeError: + if device_type == "xpu": + multi_processor_count = props.gpu_subslice_count + else: + raise + return cls( + type=device_type, + index=device.index, + multi_processor_count=multi_processor_count, + cc=device_interface.get_compute_capability(device), + major=getattr(props, "major", None), + regs_per_multiprocessor=getattr(props, "regs_per_multiprocessor", None), + max_threads_per_multi_processor=getattr( + props, "max_threads_per_multi_processor", None + ), + warp_size=getattr(props, "warp_size", 32 if device_type != "cpu" else None), + ) diff --git a/torch_npu/_inductor/utils.py b/torch_npu/_inductor/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..095f1f69cf2bff023c2ee492c726940d801c75c5 --- /dev/null +++ b/torch_npu/_inductor/utils.py @@ -0,0 +1,78 @@ +import functools +import torch +import torch_npu + + +# Not good implementation, but no other way +def get_current_raw_stream(device): + return torch.npu.current_stream(device).npu_stream + + +def patch_is_same_tensor(): + from torch._subclasses.fake_tensor import FakeTensor + + def is_same_tensor(data: torch.Tensor, value: torch.Tensor): + if isinstance(data, FakeTensor) or isinstance(value, FakeTensor): + return False + return ( + not data.is_mkldnn + and data.size() == value.size() + and data.stride() == value.stride() + and data.dtype == value.dtype + and data.device == value.device + and data.untyped_storage().data_ptr() == value.untyped_storage().data_ptr() + and data.storage_offset() == value.storage_offset() + ) + + from torch._inductor import utils, graph + utils.is_same_tensor = is_same_tensor + # We need to do extra-patch because of code like `from xxx import is_same_tensor` + graph.is_same_tensor = is_same_tensor + + +def patch_is_gpu(): + from torch._inductor.utils import GPU_TYPES + GPU_TYPES.append('npu') + + +def patch_has_triton(): + from torch.utils._triton import has_triton_package + + @functools.lru_cache(None) + def has_triton() -> bool: + if not has_triton_package(): + return False + + from torch._dynamo.device_interface import get_interface_for_device + + def cuda_extra_check(device_interface): + return True + + def cpu_extra_check(device_interface): + import triton.backends + + return "cpu" in triton.backends.backends + + def _return_true(device_interface): + return True + + triton_supported_devices = { + "cuda": cuda_extra_check, + "xpu": _return_true, + "cpu": cpu_extra_check, + "npu": _return_true + } + + def is_device_compatible_with_triton(): + for device, extra_check in triton_supported_devices.items(): + device_interface = get_interface_for_device(device) + if device_interface.is_available() and extra_check(device_interface): + return True + return False + + return is_device_compatible_with_triton() + + torch.utils._triton.has_triton = has_triton + torch._inductor.scheduler.has_triton = has_triton + + diff --git a/torch_npu/utils/_dynamo.py b/torch_npu/utils/_dynamo.py index 5915b8ed9c96a8b2d46a1cc94a31de505dcc24a4..7b0ec31552a5e5b0bb900f0840e477905a23a573 100644 --- a/torch_npu/utils/_dynamo.py +++ b/torch_npu/utils/_dynamo.py @@ -1,3 +1,4 @@ +import sys import inspect from typing import Dict, List @@ -106,6 +107,62 @@ def TensorVariable_call_method(self, tx, name, args, kwargs): return TensorVariable.call_method_raw(self, tx, name, args, kwargs) 
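+# Lazy-registration helper (descriptive note): the class below defers the
+# `from torch_npu import _inductor` import, with all of its registration side
+# effects, until an inductor-compiled function is actually run. It also treats a
+# manual `import torch_npu._inductor` (visible in sys.modules) as already
+# initialized, and lets callers disable or re-enable the automatic registration.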
+class _InductorNpuRegistry: + _disabled_register = False + _has_inited = False + + @classmethod + def register_inductor_npu(cls): + if cls.has_initialized() or cls._disabled_register: + return + from torch_npu import _inductor + cls._has_inited = True + + @classmethod + def disable_register(cls): + cls._disabled_register = True + + @classmethod + def enable_register(cls): + cls._disabled_register = False + + @classmethod + def has_initialized(cls): + if cls._has_inited: + return True + # Maybe initialized by call `import torch_npu._inductor` manually. + if 'torch_npu._inductor' in sys.modules: + cls._has_inited = True + return cls._has_inited + + +def is_inductor_npu_initialized(): + return _InductorNpuRegistry.has_initialized() + + +def disable_register_inductor_npu(): + _InductorNpuRegistry.disable_register() + + +def enable_register_inductor_npu(): + _InductorNpuRegistry.enable_register() + + +def register_inductor_npu(): + _InductorNpuRegistry.register_inductor_npu() + + +def patch_inductor_wrapper(): + from torch import _TorchCompileInductorWrapper + src_call = _TorchCompileInductorWrapper.__call__ + + def new_call(self, model_, inputs_): + register_inductor_npu() + return src_call(self, model_, inputs_) + + _TorchCompileInductorWrapper.__call__ = new_call + + def patch_dynamo_optimize(): src_optimize = optimize @@ -137,4 +194,5 @@ def add_dynamo_methods(): TensorVariable.call_method_raw = TensorVariable.call_method TensorVariable.call_method = TensorVariable_call_method patch_dynamo_optimize() + patch_inductor_wrapper() diff --git a/torch_npu/utils/_dynamo_device.py b/torch_npu/utils/_dynamo_device.py index 43bc29d8979a2c5743d85d926b31a11e80354a9a..f1e53f4c2ea24fab2d9283985937e9c5e9be9114 100644 --- a/torch_npu/utils/_dynamo_device.py +++ b/torch_npu/utils/_dynamo_device.py @@ -65,11 +65,26 @@ class NpuInterface(DeviceInterface): @staticmethod def get_compute_capability(device=None): - r"""Query the minor and major data of device. Cann does not - have a corresponding concept and is not supported. By default, it returns None + r"""Different from cuda, only return the chip model here. """ - return None + return torch.npu.get_device_name(device) + + @staticmethod + def exchange_device(device: int) -> int: + curr_device = current_device() + set_device(device) + return curr_device + + @staticmethod + def maybe_exchange_device(device: int) -> int: + return device + + @staticmethod + def is_bf16_supported(including_emulation: bool = False): + return True def _dynamo_register_interface_for_device(): register_interface_for_device("npu", NpuInterface) + for i in range(32): + register_interface_for_device(f"npu:{i}", NpuInterface)
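# --- Illustrative usage sketch (not part of the patch above) ---
# A minimal sketch, assuming torch_npu is installed and an NPU device is present:
# with the patched _TorchCompileInductorWrapper.__call__, the first call to an
# inductor-compiled function is expected to trigger register_inductor_npu() lazily;
# disable_register_inductor_npu() can be called beforehand to opt out.
# The function and tensor below are hypothetical examples.
import torch
import torch_npu


def _toy(x):
    # simple pointwise graph, enough to exercise the NPU inductor backend
    return torch.nn.functional.relu(x) * 2


compiled = torch.compile(_toy, backend="inductor")
out = compiled(torch.randn(8, 8, device="npu"))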