From 81cbca158e82892869e6abf01a647a004f5e5261 Mon Sep 17 00:00:00 2001 From: chuboning Date: Thu, 5 Jun 2025 17:06:01 +0800 Subject: [PATCH 01/11] Add 910_95 support --- CMakeLists.txt | 5 +- codegen/gen_backend_stubs.py | 2 + codegen/utils.py | 8 +- test/allowlist_for_publicAPI.json | 5 + test/npu/test_tensors.py | 22 + test/torch_npu_schema.json | 15 +- third_party/acl/inc/acl/acl_base.h | 10 + torch_npu/__init__.py | 8 +- torch_npu/contrib/module/linear_quant.py | 81 ++- torch_npu/contrib/module/quant_conv2d.py | 4 +- torch_npu/csrc/InitNpuBindings.cpp | 2 + torch_npu/csrc/aten/common/CopyKernel.cpp | 6 +- .../csrc/aten/common/FormatCastHelper.cpp | 4 +- .../csrc/aten/common/FormatCastKernelNpu.cpp | 158 ++++- .../csrc/aten/common/LocalScalarDenseNpu.cpp | 23 +- torch_npu/csrc/aten/common/NpuFastReshape.cpp | 2 +- torch_npu/csrc/aten/common/ResizeNpu.cpp | 2 +- torch_npu/csrc/aten/common/ToKernelNpu.cpp | 2 +- torch_npu/csrc/aten/npu_native_functions.yaml | 10 +- .../aten/ops/FlattenDenseTensorsKernelNpu.cpp | 3 +- torch_npu/csrc/core/NPUSerialization.cpp | 2 +- .../csrc/core/npu/NPUCachingAllocator.cpp | 12 +- torch_npu/csrc/core/npu/NPUException.cpp | 2 +- torch_npu/csrc/core/npu/NPUException.h | 2 +- torch_npu/csrc/core/npu/NPUFormat.cpp | 2 +- torch_npu/csrc/core/npu/NPUMacros.h | 2 +- torch_npu/csrc/core/npu/NpuVariables.cpp | 41 +- torch_npu/csrc/core/npu/NpuVariables.h | 5 +- .../csrc/core/npu/interface/AclInterface.cpp | 4 +- .../csrc/core/npu/register/OptionRegister.cpp | 12 + torch_npu/csrc/distributed/Init.cpp | 3 +- .../csrc/distributed/ProcessGroupHCCL.cpp | 16 +- torch_npu/csrc/distributed/reducer.cpp | 2 +- .../csrc/distributed/rpc/tensorpipe_agent.cpp | 3 + torch_npu/csrc/framework/FormatHelper.cpp | 4 + torch_npu/csrc/framework/OpCommand.cpp | 8 +- torch_npu/csrc/framework/OpParamMaker.cpp | 1 + .../csrc/framework/StorageDescHelper.cpp | 23 + torch_npu/csrc/framework/StorageDescHelper.h | 4 + .../framework/contiguous/reshapeV2_opt.cpp | 8 + .../csrc/framework/utils/CalcuOpUtil.cpp | 47 +- torch_npu/csrc/framework/utils/CalcuOpUtil.h | 1 + torch_npu/csrc/framework/utils/NpuUtils.cpp | 2 +- .../csrc/framework/utils/OpPreparation.cpp | 15 +- .../csrc/framework/utils/OpPreparation.h | 1 + torch_npu/csrc/npu/DataParallelComm.cpp | 2 +- .../csrc/transformer_engine/CMakeLists.txt | 6 + .../transformer_engine/CastKernelTeOpApi.cpp | 45 ++ torch_npu/csrc/transformer_engine/Init.cpp | 165 +++++ torch_npu/csrc/transformer_engine/Init.h | 85 +++ torch_npu/csrc/transformer_engine/extension.h | 14 + torch_npu/onnx/wrapper_onnx_ops.py | 70 +- torch_npu/utils/hif8_tensor.py | 635 ++++++++++++++++++ 53 files changed, 1521 insertions(+), 95 deletions(-) create mode 100644 torch_npu/csrc/transformer_engine/CMakeLists.txt create mode 100644 torch_npu/csrc/transformer_engine/CastKernelTeOpApi.cpp create mode 100644 torch_npu/csrc/transformer_engine/Init.cpp create mode 100644 torch_npu/csrc/transformer_engine/Init.h create mode 100644 torch_npu/csrc/transformer_engine/extension.h create mode 100644 torch_npu/utils/hif8_tensor.py diff --git a/CMakeLists.txt b/CMakeLists.txt index ad39472ed6..0150bc1aeb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -246,6 +246,7 @@ add_subdirectory(${TORCHNPU_ROOT}/core) add_subdirectory(${TORCHNPU_ROOT}/framework) add_subdirectory(${TORCHNPU_ROOT}/flopcount) add_subdirectory(${TORCHNPU_ROOT}/logging) +add_subdirectory(${TORCHNPU_ROOT}/transformer_engine) if (NOT DEFINED BUILD_LIBTORCH) add_subdirectory(${TORCHNPU_ROOT}/distributed) @@ -272,10 +273,10 
@@ if (DEFINED BUILD_TENSORPIPE) endif() if (DEFINED BUILD_LIBTORCH) - set(CPP_SRCS ${ATEN_SRCS} ${CORE_SRCS} ${OPS_PLUGIN_SRCS} ${FLOP_SRCS} ${FRAMEWORK_SRCS} ${LOGGING_SRCS} ${NPU_CPP_LIBS_SRCS}) + set(CPP_SRCS ${ATEN_SRCS} ${CORE_SRCS} ${OPS_PLUGIN_SRCS} ${FLOP_SRCS} ${TE_SRCS} ${FRAMEWORK_SRCS} ${LOGGING_SRCS} ${NPU_CPP_LIBS_SRCS}) else() # Compile code with pybind11 - set(CPP_SRCS ${ATEN_SRCS} ${CORE_SRCS} ${OPS_PLUGIN_SRCS} ${DIST_SRCS} ${FLOP_SRCS} ${LOGGING_SRCS} ${FRAMEWORK_SRCS} ${NPU_SRCS} ${PROF_SRCS} ${UTILS_SRCS} ${SAN_SRCS}) + set(CPP_SRCS ${ATEN_SRCS} ${CORE_SRCS} ${OPS_PLUGIN_SRCS} ${DIST_SRCS} ${FLOP_SRCS} ${TE_SRCS} ${LOGGING_SRCS} ${FRAMEWORK_SRCS} ${NPU_SRCS} ${PROF_SRCS} ${UTILS_SRCS} ${SAN_SRCS}) endif() add_library(${PLUGIN_NAME} SHARED ${CPP_SRCS}) diff --git a/codegen/gen_backend_stubs.py b/codegen/gen_backend_stubs.py index bdb6c48a13..248704d492 100644 --- a/codegen/gen_backend_stubs.py +++ b/codegen/gen_backend_stubs.py @@ -395,6 +395,8 @@ def gen_dispatcher_registrations( ns_helper = NamespaceHelper(namespace_str="at") native_func_header = """\ #include "torch_npu/csrc/core/npu/NPURecovery.h" +#include "torch_npu/csrc/core/npu/NpuVariables.h" +#include "torch_npu/csrc/core/npu/NPUException.h" #ifndef BUILD_LIBTORCH #include "torch_npu/csrc/profiler/utils.h" #endif diff --git a/codegen/utils.py b/codegen/utils.py index 187f02fc9d..03e8d2f79c 100644 --- a/codegen/utils.py +++ b/codegen/utils.py @@ -401,6 +401,7 @@ const DeviceGuard device_guard(device_or_default(device));""" device_guard = f"const OptionalDeviceGuard device_guard(device_of({device_of}));" op_key = str(f.func.name) + is_ascend910_95_version = "c10_npu::IsAscend910_95Version()" if enable_opplugin(): if op_key in GLOBAL_STRUCTURED_OP_INFO_CACHE: impl_name = f"op_plugin::{GLOBAL_STRUCTURED_OP_INFO_CACHE[op_key]}" @@ -472,12 +473,17 @@ if (C10_UNLIKELY(at_npu::native::env::CheckOpHookEnable())) {{ if (({force_aclnn} || at_npu::native::env::CheckJitDisable()){tensor_check_str}) {{ return {op_api_impl_name}({args_exprs_str}); }} else {{ + if ({is_ascend910_95_version}) {{ + TORCH_CHECK(false, + "Current aclnn operator {impl_name} does not support internal format.", + PTA_ERROR(ErrCode::NOT_SUPPORT)); + }} return {impl_name}({args_exprs_str}); }} """ else: return_code = f"""\ -if (({force_aclnn} || at_npu::native::env::CheckJitDisable())) {{ +if (({is_ascend910_95_version} || {force_aclnn} || at_npu::native::env::CheckJitDisable())) {{ return {op_api_impl_name}({args_exprs_str}); }} else {{ return {impl_name}({args_exprs_str}); diff --git a/test/allowlist_for_publicAPI.json b/test/allowlist_for_publicAPI.json index a80c38aeb1..9621cbdca8 100644 --- a/test/allowlist_for_publicAPI.json +++ b/test/allowlist_for_publicAPI.json @@ -2528,6 +2528,7 @@ "npu_cross_entropy_loss", "npu_format_cast_", "npu_fusion_attention", + "npu_fusion_attention_v2", "npu_get_float_status", "npu_nms_rotated", "npu_random_choice_with_mask", @@ -2539,6 +2540,7 @@ "npu_mla_prolog_v2", "npu_convert_weight_to_int4pack", "npu_ffn", + "npu_fused_matmul", "npu_geglu", "npu_grouped_matmul", "npu_moe_finalize_routing", @@ -2552,6 +2554,9 @@ "npu_scatter_nd_update_", "npu_stride_copy", "npu_gemma_rms_norm", + "npu_dynamic_mx_quant", + "npu_grouped_dynamic_mx_quant", + "npu_dtype_cast", "npu_swiglu", "npu_gelu", "npu_gelu_backward", diff --git a/test/npu/test_tensors.py b/test/npu/test_tensors.py index 237d6a1aee..ff51e258d3 100644 --- a/test/npu/test_tensors.py +++ b/test/npu/test_tensors.py @@ -1,4 +1,5 @@ from copy import deepcopy +import
unittest import numpy as np import torch import torch_npu @@ -22,6 +23,16 @@ types = [ ] +def skipIfUnsupport910_95(): + def skip_dec(func): + def wrapper(self): + if "Ascend910_95" not in torch_npu.npu.get_device_name(): + raise unittest.SkipTest("Device 910_95 condition not satisfied") + return func(self) + return wrapper + return skip_dec + + def get_npu_type(type_name): if isinstance(type_name, type): type_name = '{}.{}'.format(type_name.__module__, type_name.__name__) @@ -383,5 +394,16 @@ class TestViewOps(TestCase): self.assertEqual(tensor.view(3, -1).size(), target) +class TestTensorDtype(TestCase): + @skipIfUnsupport910_95() + def test_fp8(self): + tensor1 = torch.randn([2, 2], dtype=torch.float32).npu() + tensor2 = torch.randn([2, 2], dtype=torch.float32).npu() + tensor_f8e5m2 = tensor1.to(torch.float8_e5m2) + tensor_f8e4m3fn = tensor2.to(torch.float8_e4m3fn) + self.assertEqual(tensor_f8e5m2.dtype, torch.float8_e5m2) + self.assertEqual(tensor_f8e4m3fn.dtype, torch.float8_e4m3fn) + + if __name__ == "__main__": run_tests() diff --git a/test/torch_npu_schema.json b/test/torch_npu_schema.json index 65908c94c1..ea7a738c78 100644 --- a/test/torch_npu_schema.json +++ b/test/torch_npu_schema.json @@ -2588,6 +2588,9 @@ "torch_npu.npu_dynamic_quant_asymmetric": { "signature": "(input_dummy, smooth_scales=None, group_index=None, dst_type=torch.int8)" }, + "torch_npu.npu_dynamic_mx_quant": { "signature": "(*args, **kwargs)" }, "torch_npu.npu_group_quant": { "signature": "(x, scale, group_index, offset=None, dst_dtype=None)" }, @@ -2595,7 +2598,7 @@ "signature": "(*args, **kwargs)" }, "torch_npu.npu_format_cast": { - "signature": "(self, acl_format)" + "signature": "(self, acl_format, customize_dtype=None)" }, "torch_npu.npu_format_cast_": { "signature": "(*args, **kwargs)" }, @@ -2835,16 +2838,16 @@ "signature": "(int[] size, *, ScalarType? dtype=None, Device? device=None) -> Tensor" }, "func: npu_format_cast": { - "signature": "(Tensor self, int acl_format) -> Tensor" + "signature": "(Tensor self, int acl_format, int? customize_dtype=None) -> Tensor" }, "func: npu_format_cast_": { - "signature": "(Tensor(a!) self, Tensor src) -> Tensor(a!)" + "signature": "(Tensor(a!) self, Tensor src, int? customize_dtype=None) -> Tensor(a!)" }, "func: npu_format_cast_.acl_format": { - "signature": "(Tensor(a!) self, int acl_format) -> Tensor(a!)" + "signature": "(Tensor(a!) self, int acl_format, int? customize_dtype=None) -> Tensor(a!)" }, "func: npu_format_cast.Tensor": { - "signature": "(Tensor self, Tensor dst) -> Tensor" + "signature": "(Tensor self, Tensor dst, int? customize_dtype=None) -> Tensor" }, "func: npu_change_data_ptr": { "signature": "(Tensor dst, Tensor src, int index) -> int" }, @@ -2862,7 +2865,7 @@ "signature": "" }, "func: _npu_format_cast": { "signature": "(Tensor self, int acl_format, int?
customize_dtype=None) -> Tensor" }, "torch_npu_public_env: INF_NAN_MODE_ENABLE": { "mode": "std::unordered_map infNanMode = {{0, \"max\"}, {1, \"inf_nan\"}}" } diff --git a/third_party/acl/inc/acl/acl_base.h b/third_party/acl/inc/acl/acl_base.h index cbcf87b0fc..b9c7346d06 100755 --- a/third_party/acl/inc/acl/acl_base.h +++ b/third_party/acl/inc/acl/acl_base.h @@ -164,6 +164,14 @@ typedef enum { ACL_INT4 = 29, ACL_UINT1 = 30, ACL_COMPLEX32 = 33, + ACL_HIFLOAT8 = 34, + ACL_FLOAT8_E5M2 = 35, + ACL_FLOAT8_E4M3FN = 36, + ACL_FLOAT8_E8M0 = 37, + ACL_FLOAT6_E3M2 = 38, + ACL_FLOAT6_E2M3 = 39, + ACL_FLOAT4_E2M1 = 40, + ACL_FLOAT4_E1M2 = 41, } aclDataType; typedef enum { @@ -182,6 +190,8 @@ typedef enum { ACL_FRACTAL_Z_3D = 33, ACL_FORMAT_NC = 35, ACL_FORMAT_NCL = 47, + ACL_FORMAT_FRACTAL_NZ_C0_16 = 50, + ACL_FORMAT_FRACTAL_NZ_C0_32 = 51, } aclFormat; typedef enum { diff --git a/torch_npu/__init__.py b/torch_npu/__init__.py index a5e842a2a0..78790b96d5 100644 --- a/torch_npu/__init__.py +++ b/torch_npu/__init__.py @@ -1,4 +1,4 @@ -__all__ = ["erase_stream", "matmul_checksum"] +__all__ = ["erase_stream", "matmul_checksum", "HiFloat8Tensor"] import os import sys @@ -61,6 +61,7 @@ from torch_npu.utils.exposed_api import public_npu_functions from torch_npu.distributed.checkpoint.checkpoint import _apply_dcp_patch from torch_npu.npu._stream_check import apply_sanitizer_patch from torch_npu.npu.utils import _erase_stream as erase_stream +from torch_npu.utils.hif8_tensor import HiFloat8Tensor from torch_npu.utils._error_code import ErrCode, pta_error, _except_handler from torch_npu.asd.asd import _asd_patch from torch_npu.asd.checksum import _matmul_checksum as matmul_checksum @@ -90,6 +91,11 @@ for name in dir(torch.ops.npu): __all__.append(name) setattr(torch, name, _wrap_torch_error_func(getattr(torch.ops.npu, name))) +for name in dir(torch_npu._C._te.DType): + if name.startswith('__') or name in ['_dir', 'name']: + continue + setattr(torch_npu, name, getattr(torch_npu._C._te.DType, name)) + all_monkey_patches = [ ["nn.functional", npu_functional], ["nn", npu_modules], diff --git a/torch_npu/contrib/module/linear_quant.py b/torch_npu/contrib/module/linear_quant.py index 5b051b8911..52662252c4 100644 --- a/torch_npu/contrib/module/linear_quant.py +++ b/torch_npu/contrib/module/linear_quant.py @@ -36,6 +36,19 @@ class LinearQuant(nn.Module): If :attr:`bias` is ``True``, the values are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where :math:`k = \frac{1}{\text{in\_features}}` + x1_dtype: Only supports torch_npu.hifloat8, torch_npu.float4_e2m1 and torch_npu.float4_e1m2; defaults to None. + x2_dtype: Only supports torch_npu.hifloat8, torch_npu.float4_e2m1 and torch_npu.float4_e1m2; defaults to None. + pertoken_scale_dtype: Only supports torch_npu.float8_e8m0; defaults to None. + scale_dtype: Only supports torch_npu.float8_e8m0; defaults to None. + group_sizes: a list of three ints, [group_size_m, group_size_n, group_size_k]. + group_size_m is the number of input elements that share a single scale along the m dim, default 0. + group_size_n is the number of weight elements that share a single scale along the n dim, default 0. + group_size_k is the number of input or weight elements that share a single scale along the k dim, + default 0. + if any of group_size_m, group_size_n or group_size_k is 0, it is recomputed from the input shape, + e.g. group_size_m = m / scale_m (m % scale_m must be 0).
A4W4 Examples:: >>> x1 = torch.randint(-1, 1, (1, 2), dtype=torch.int32).npu() @@ -60,6 +73,20 @@ class LinearQuant(nn.Module): >>> output = model(x1) >>> print(output.size()) torch.Size(1, 127) + + A8W8 with per-block quantization Examples:: + >>> x1 = torch.randn((2048, 1024), dtype=torch.float32).npu().to(torch.float8_e4m3fn) + >>> x2 = torch.randn((4096, 1024), dtype=torch.float32).npu().to(torch.float8_e4m3fn) + >>> scale = torch.randn((16, 8), dtype=torch.float32).npu() + >>> pertoken_scale = torch.randn((32, 8), dtype=torch.float32).npu() + >>> model = LinearQuant(in_features, out_features, bias=False, group_sizes=[128, 128, 128]) + >>> model = model.npu() + >>> model.weight.data = x2 + >>> model.scale.data = scale + >>> model.pertoken_scale.data = pertoken_scale + >>> output = model(x1) + >>> print(output.size()) + torch.Size(2048, 4096) """ in_features: int out_features: int @@ -69,8 +96,21 @@ class LinearQuant(nn.Module): pertoken_scale: Tensor bias: Tensor - def __init__(self, in_features: int, out_features: int, *, bias: bool = True, offset: bool = False, - pertoken_scale: bool = False, device=None, dtype=None, output_dtype=None) -> None: + def __init__(self, + in_features: int, + out_features: int, + *, + bias: bool = True, + offset: bool = False, + pertoken_scale: bool = False, + device=None, + dtype=None, + output_dtype=None, + x1_dtype=None, + x2_dtype=None, + pertoken_scale_dtype=None, + scale_dtype=None, + group_sizes=None) -> None: super(LinearQuant, self).__init__() self.in_features = in_features @@ -78,6 +118,11 @@ class LinearQuant(nn.Module): self.weight = Parameter(torch.empty((out_features, in_features)), False) self.scale = Parameter(torch.empty(out_features), False) self.output_dtype = output_dtype + self.x1_dtype = x1_dtype + self.x2_dtype = x2_dtype + self.pertoken_scale_dtype = pertoken_scale_dtype + self.scale_dtype = scale_dtype + self.group_sizes = group_sizes if offset: self.offset = Parameter(torch.empty(out_features, dtype=torch.float32), False) else: @@ -97,16 +142,30 @@ class LinearQuant(nn.Module): scale_quant = self.scale first_last_dim = self.weight.dim() - 1 second_last_dim = self.weight.dim() - 2 - if not ((linear_quant_input.dtype == torch.int32 and self.weight.dtype == torch.int32) or - (linear_quant_input.dtype == torch.int8 and self.weight.dtype == torch.int8)): - raise ValueError("input and weight should be both torch.int32 or both torch.int8 datatype, " f"but now input is {linear_quant_input.dtype}, weight is {self.weight.dtype}."
+ ops_error(ErrCode.TYPE)) - is_check_dtype_ok = (self.scale.dtype == torch.float32 and - self.output_dtype not in [torch.bfloat16, torch.int32]) + is_not_int_input = self.weight.dtype not in [torch.int8, torch.int32] + is_check_dtype_ok = (self.scale.dtype == torch.float32 + and (self.output_dtype != torch.bfloat16 or is_not_int_input) + and self.output_dtype != torch.int32) if self.pertoken_scale is None and is_check_dtype_ok: scale_quant = torch_npu.npu_trans_quant_param(self.scale, self.offset) + has_group = (self.group_sizes is not None + and (isinstance(self.group_sizes, list) or isinstance(self.group_sizes, tuple)) + and len(self.group_sizes) == 3 and (self.group_sizes[1] > 1 or self.group_sizes[2] > 1)) + if (scale_quant.dim() > 1 and has_group): + scale_first_last_dim = scale_quant.dim() - 1 + scale_second_last_dim = scale_quant.dim() - 2 + scale_quant = scale_quant.transpose(scale_second_last_dim, scale_first_last_dim) - return torch_npu.npu_quant_matmul(linear_quant_input, self.weight.transpose(second_last_dim, first_last_dim), - scale_quant, offset=self.offset, pertoken_scale=self.pertoken_scale, bias=self.bias, - output_dtype=self.output_dtype) + return torch_npu.npu_quant_matmul(linear_quant_input, + self.weight.transpose(second_last_dim, first_last_dim), + scale_quant, + offset=self.offset, + pertoken_scale=self.pertoken_scale, + bias=self.bias, + output_dtype=self.output_dtype, + x1_dtype=self.x1_dtype, + x2_dtype=self.x2_dtype, + pertoken_scale_dtype=self.pertoken_scale_dtype, + scale_dtype=self.scale_dtype, + group_sizes=self.group_sizes) diff --git a/torch_npu/contrib/module/quant_conv2d.py b/torch_npu/contrib/module/quant_conv2d.py index 1aa59bce43..5bca024785 100644 --- a/torch_npu/contrib/module/quant_conv2d.py +++ b/torch_npu/contrib/module/quant_conv2d.py @@ -108,10 +108,10 @@ class QuantConv2d(nn.Module): self.output_dtype = output_dtype self.weight = \ - Parameter(torch.empty((self.out_channels, self.in_channels, *self.kernel_size), dtype=torch.int8), False) + Parameter(torch.empty((self.out_channels, self.in_channels // self.groups, *self.kernel_size)), False) self.scale = Parameter(torch.empty(self.out_channels, dtype=torch.int64), False) if bias: - self.bias = Parameter(torch.empty(self.out_channels, dtype=torch.int32), False) + self.bias = Parameter(torch.empty(self.out_channels), False) else: self.register_parameter('bias', None) if offset: diff --git a/torch_npu/csrc/InitNpuBindings.cpp b/torch_npu/csrc/InitNpuBindings.cpp index 05ef7980b7..8bc7a11471 100644 --- a/torch_npu/csrc/InitNpuBindings.cpp +++ b/torch_npu/csrc/InitNpuBindings.cpp @@ -15,6 +15,7 @@ #include "torch_npu/csrc/flopcount/Init.h" #include "torch_npu/csrc/logging/Init.h" #include "torch_npu/csrc/npu/Module.h" +#include "torch_npu/csrc/transformer_engine/Init.h" #include "torch_npu/csrc/npu/Stress_detect.h" #include "torch_npu/csrc/utils/TensorType.h" #include "torch_npu/csrc/utils/AutocastMode.h" @@ -167,6 +168,7 @@ PyObject* initModule() AddPyMethodDefs(methods, torch_npu::autocast::autocast_mode_functions()); AddPyMethodDefs(methods, torch_npu::flopcount::flops_count_functions()); AddPyMethodDefs(methods, torch_npu::logging::logging_functions()); + AddPyMethodDefs(methods, torch_npu::te::te_functions()); static struct PyModuleDef torchnpu_module = { PyModuleDef_HEAD_INIT, "torch_npu._C", diff --git a/torch_npu/csrc/aten/common/CopyKernel.cpp b/torch_npu/csrc/aten/common/CopyKernel.cpp index feb0f9b887..03af121f76 100644 --- a/torch_npu/csrc/aten/common/CopyKernel.cpp +++ 
b/torch_npu/csrc/aten/common/CopyKernel.cpp @@ -76,7 +76,7 @@ void copy_d2d_dtype_format(at::Tensor& self, const at::Tensor& src, bool non_blo at::Tensor src_4D = FormatCastHelper::ApplyBaseFormatTensorBy(src); at::Tensor dst_4D = FormatCastHelper::ApplyBaseFormatTensorBy(self); copy_d2d_dtype_baseformat(dst_4D, src_4D, non_blocking); - NPUNativeFunctions::npu_format_cast_(self, dst_4D); + NPUNativeFunctions::npu_format_cast_(self, dst_4D, c10::nullopt); return; } copy_d2d_dtype_baseformat(self, src, non_blocking); @@ -213,7 +213,7 @@ void copy_h2d(at::Tensor& self, const at::Tensor& src, bool non_blocking) if (!FormatHelper::IsBaseFormatType(self)) { at::Tensor dst = OpPreparation::ApplyTensorWithSizes(self.sizes(), self.options()); copy_h2d_baseformat(dst, src, non_blocking, true); - NPUNativeFunctions::npu_format_cast_(self, dst); + NPUNativeFunctions::npu_format_cast_(self, dst, c10::nullopt); return; } copy_h2d_baseformat(self, src, non_blocking); @@ -326,7 +326,7 @@ void copy_d2d_dtype(at::Tensor& self, const at::Tensor& src, bool non_blocking) } at::Tensor dst_4D = FormatCastHelper::ApplyBaseFormatTensorBy(self); copy_d2d_dtype_baseformat(dst_4D, src_4D, non_blocking); - NPUNativeFunctions::npu_format_cast_(self, dst_4D); + NPUNativeFunctions::npu_format_cast_(self, dst_4D, c10::nullopt); return; } copy_d2d_dtype_format(self, src, non_blocking); diff --git a/torch_npu/csrc/aten/common/FormatCastHelper.cpp b/torch_npu/csrc/aten/common/FormatCastHelper.cpp index 2f61a7c782..ade6a42686 100644 --- a/torch_npu/csrc/aten/common/FormatCastHelper.cpp +++ b/torch_npu/csrc/aten/common/FormatCastHelper.cpp @@ -71,13 +71,13 @@ bool FormatCastHelper::format_cast_between_group( at::Tensor FormatCastHelper::ApplyBaseFormatTensorBy(const at::Tensor& src) { auto format = FormatHelper::GetBaseFormat(src); - return custom_ops::npu_format_cast(src, format); + return custom_ops::npu_format_cast(src, format, c10::nullopt); } at::Tensor& FormatCastHelper::CovertSelfToBaseFormat(at::Tensor& src) { auto format = FormatHelper::GetBaseFormat(src); - return NPUNativeFunctions::npu_format_cast_(src, format); + return NPUNativeFunctions::npu_format_cast_(src, format, c10::nullopt); } } // namespace native diff --git a/torch_npu/csrc/aten/common/FormatCastKernelNpu.cpp b/torch_npu/csrc/aten/common/FormatCastKernelNpu.cpp index 0c4000e524..9bdaadd0f7 100644 --- a/torch_npu/csrc/aten/common/FormatCastKernelNpu.cpp +++ b/torch_npu/csrc/aten/common/FormatCastKernelNpu.cpp @@ -1,16 +1,130 @@ #include "torch_npu/csrc/framework/FormatHelper.h" #include "torch_npu/csrc/framework/utils/OpAdapter.h" #include "torch_npu/csrc/framework/utils/NpuStorageOffsetGuard.h" +#include "torch_npu/csrc/framework/StorageDescHelper.h" #include "torch_npu/csrc/aten/common/FormatCastHelper.h" #include "torch_npu/csrc/aten/NPUNativeFunctions.h" #include "torch_npu/csrc/core/NPUBridge.h" #include "torch_npu/csrc/core/NPUStorageImpl.h" +#include "torch_npu/csrc/core/npu/NpuVariables.h" #include "torch_npu/csrc/aten/CustomFunctions.h" +#include "torch_npu/csrc/transformer_engine/Init.h" +#include "third_party/op-plugin/op_plugin/utils/op_api_common.h" namespace at_npu { namespace native { using tensor_list = std::vector; +using GetFormatFunc = int (*)(const aclTensor *, const int, const int, int64_t **, uint64_t *, int *); + +std::tuple> MaybeUseAclnnNpuFormatCast(const at::Tensor& src, + int64_t acl_format, c10::optional customize_dtype) +{ + const static auto GetFormatFuncAddr = GetOpApiFuncAddr("aclnnNpuFormatCastCalculateSizeAndFormat"); + 
const static auto FormatCastFuncAddr = GetOpApiFuncAddr("aclnnNpuFormatCast"); + + const static bool aclnnNpuFormatCastExist = + (GetFormatFuncAddr == nullptr || FormatCastFuncAddr == nullptr) ? false : true; + + GetFormatFunc GetFormat = reinterpret_cast(GetFormatFuncAddr); + int64_t *dstStorageShape = nullptr; + uint64_t dstShapeSize = 0; + int dstFormat; + at::SmallVector outputShape = {}; + aclDataType customizeAcltype = (customize_dtype.has_value()) ? + torch_npu::te::GetAclDataType(customize_dtype.value()) : + at_npu::native::OpPreparation::convert_to_acl_data_type(src.scalar_type()); + + if (c10_npu::IsAscend910_95Version()) { + if (aclnnNpuFormatCastExist) { + auto api_ret = GetFormat(ConvertType(src), acl_format, customizeAcltype, &dstStorageShape, + &dstShapeSize, &dstFormat); + NPU_CHECK_ERROR(api_ret, "aclnnNpuFormatCastCalculateSizeAndFormat"); + for (uint64_t i = 0; i < dstShapeSize; i++) { + outputShape.push_back(dstStorageShape[i]); + } + delete[] dstStorageShape; + return std::make_tuple(true, dstFormat, outputShape); + } + TORCH_CHECK(false, + "aclnnNpuFormatCast does not exist, Ascend910_95 series only support aclnn operators.", + PTA_ERROR(ErrCode::NOT_SUPPORT)); + } + if (at_npu::native::env::CheckJitDisable()) { + if (aclnnNpuFormatCastExist) { + auto api_ret = GetFormat(ConvertType(src), acl_format, customizeAcltype, &dstStorageShape, + &dstShapeSize, &dstFormat); + if (api_ret != 0) { + if (customize_dtype.has_value()) { + NPU_CHECK_ERROR(api_ret, "aclnnNpuFormatCastCalculateSizeAndFormat"); + } + return std::make_tuple(false, dstFormat, outputShape); + } + for (uint64_t i = 0; i < dstShapeSize; i++) { + outputShape.push_back(dstStorageShape[i]); + } + delete[] dstStorageShape; + return std::make_tuple(true, dstFormat, outputShape); + } else { + if (C10_UNLIKELY(customize_dtype.has_value())) { + TORCH_CHECK(false, + "customize_dtype is not supported while aclnnNpuFormatCast does not exist.", + PTA_ERROR(ErrCode::NOT_SUPPORT)); + } + return std::make_tuple(false, dstFormat, outputShape); + } + } else { + if (C10_UNLIKELY(customize_dtype.has_value())) { + TORCH_CHECK(false, + "customize_dtype is not supported while jit_compile=True.", + PTA_ERROR(ErrCode::NOT_SUPPORT)); + } + return std::make_tuple(false, dstFormat, outputShape); + } +} + +at::Tensor create_tensor_with_format_and_shape(c10::IntArrayRef baseSizes, + c10::IntArrayRef storageSizes, + const caffe2::TypeMeta dtype, int64_t acl_format) +{ + c10::Allocator *allocator = c10_npu::NPUCachingAllocator::get(); + int64_t nelements = 1; + for (const auto& num : storageSizes) { + nelements *= num; + } + int64_t size_bytes = nelements * dtype.itemsize(); + c10::intrusive_ptr storage_impl = torch_npu::make_npu_storage_impl( + c10::StorageImpl::use_byte_size_t(), + c10::SymInt(size_bytes), + allocator, + true); + auto tensor = at::detail::make_tensor(storage_impl, dtype); + + if (baseSizes.size() != 1 || baseSizes[0] != 0) { + tensor.unsafeGetTensorImpl()->set_sizes_contiguous(baseSizes); + } + tensor.unsafeGetTensorImpl()->empty_tensor_restride(c10::MemoryFormat::Contiguous); + StorageDescHelper::SetDesc(tensor, baseSizes, storageSizes, tensor.strides(), static_cast(acl_format)); + return tensor; +} + +at::Tensor format_cast_impl_out_npu_aclnn(const at::Tensor& src, + int64_t acl_format, c10::IntArrayRef storageSizes) +{ + auto src_new = src.contiguous(); + auto src_new_desc = torch_npu::NPUBridge::GetNpuStorageImpl(src_new)->npu_desc_; + + at::Tensor dst = create_tensor_with_format_and_shape( + 
src_new_desc.base_sizes_, storageSizes, src.dtype(), acl_format); + + // calculate the output result of the NPU + EXEC_NPU_CMD(aclnnNpuFormatCast, src_new, dst); + + // format cast only change physical layout of base tensor and view tensor's + // metadata remain unchanged + dst.set_(dst.storage(), src_new.storage_offset(), src_new.sizes(), src_new.strides()); + return dst; +} at::Tensor format_cast_impl_out_npu(at::Tensor& dst, const at::Tensor& src) { @@ -36,7 +150,8 @@ at::Tensor format_cast_impl_out_npu(at::Tensor& dst, const at::Tensor& src) } // convert src from src_format to dst_format, write the result into dst(self) -at::Tensor& NPUNativeFunctions::npu_format_cast_(at::Tensor& self, const at::Tensor& src) +at::Tensor& NPUNativeFunctions::npu_format_cast_(at::Tensor& self, const at::Tensor& src, + c10::optional customize_dtype) { torch_npu::utils::torch_check_npu(self); torch_npu::utils::torch_check_npu(src); @@ -47,6 +162,13 @@ at::Tensor& NPUNativeFunctions::npu_format_cast_(at::Tensor& self, const at::Ten return self; } + auto [useAclnn, outFormat, StorageShape] = MaybeUseAclnnNpuFormatCast(self, dst_desc.npu_format_, customize_dtype); + if (useAclnn == true) { + at::Tensor dst = format_cast_impl_out_npu_aclnn(self, outFormat, StorageShape); + self.set_(dst.storage(), dst.storage_offset(), dst.sizes(), dst.strides()); + return self; + } + // calculate the output result of the NPU format_cast_impl_out_npu(self, src); @@ -84,18 +206,20 @@ at::Tensor npu_format_cast_impl( // conver self to dst'format, write the result into new result tensor at::Tensor NPUNativeFunctions::npu_format_cast( const at::Tensor& self, - const at::Tensor& dst) + const at::Tensor& dst, + c10::optional customize_dtype) { torch_npu::utils::torch_check_npu(dst); auto dst_desc = torch_npu::NPUBridge::GetNpuStorageImpl(dst)->npu_desc_; int64_t dst_format = dst_desc.npu_format_; - return custom_ops::npu_format_cast(self, dst_format); + return custom_ops::npu_format_cast(self, dst_format, customize_dtype); } // conver self to acl_format, write the result into self at::Tensor& NPUNativeFunctions::npu_format_cast_( at::Tensor& self, - int64_t acl_format) + int64_t acl_format, + c10::optional customize_dtype) { torch_npu::utils::torch_check_npu(self); auto src_desc = torch_npu::NPUBridge::GetNpuStorageImpl(self)->npu_desc_; @@ -108,6 +232,13 @@ at::Tensor& NPUNativeFunctions::npu_format_cast_( return self; } + auto [useAclnn, outFormat, StorageShape] = MaybeUseAclnnNpuFormatCast(self, acl_format, customize_dtype); + if (useAclnn == true) { + at::Tensor dst = format_cast_impl_out_npu_aclnn(self, outFormat, StorageShape); + self.set_(dst.storage(), dst.storage_offset(), dst.sizes(), dst.strides()); + return self; + } + at::Tensor dst = OpPreparation::ApplyTensorWithFormat( src_desc.base_sizes_, self.options(), acl_format); @@ -128,19 +259,30 @@ int64_t NPUNativeFunctions::get_npu_format(const at::Tensor& self) return src_desc.npu_format_; } -at::Tensor NPUNativeFunctions::_npu_format_cast(const at::Tensor& self, int64_t acl_format) +at::Tensor NPUNativeFunctions::_npu_format_cast(const at::Tensor& self, int64_t acl_format, + c10::optional customize_dtype) { - return npu_format_cast_impl(self, acl_format); + if (FormatHelper::IsBaseFormatType(self) && + FormatHelper::IsBaseFormatType(static_cast(acl_format))) { + FormatCastHelper::format_cast_as_base_format(self, static_cast(acl_format)); + return self; + } + auto [useAclnn, outFormat, StorageShape] = MaybeUseAclnnNpuFormatCast(self, acl_format, customize_dtype); + if 
(useAclnn == false) { + return npu_format_cast_impl(self, acl_format); + } + return format_cast_impl_out_npu_aclnn(self, outFormat, StorageShape); } -at::Tensor NPUNativeFunctions::npu_format_cast(const at::Tensor& self, int64_t acl_format) +at::Tensor NPUNativeFunctions::npu_format_cast(const at::Tensor& self, int64_t acl_format, + c10::optional customize_dtype) { torch_npu::utils::torch_check_npu(self); if (NPUNativeFunctions::get_npu_format(self) == acl_format) { ASCEND_LOGD("no need to do format cast"); return self; } - return custom_ops::_npu_format_cast(self, acl_format); + return custom_ops::_npu_format_cast(self, acl_format, customize_dtype); } } // namespace native diff --git a/torch_npu/csrc/aten/common/LocalScalarDenseNpu.cpp b/torch_npu/csrc/aten/common/LocalScalarDenseNpu.cpp index 775d95cbfa..5f9fcdc03b 100644 --- a/torch_npu/csrc/aten/common/LocalScalarDenseNpu.cpp +++ b/torch_npu/csrc/aten/common/LocalScalarDenseNpu.cpp @@ -10,13 +10,34 @@ namespace at_npu { namespace native { +#define AT_DISPATCH_CASE_ALL_TYPES_AND5( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5,...) \ + AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__) + + +#define AT_DISPATCH_ALL_TYPES_AND5( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5, TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_ALL_TYPES_AND5( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5, __VA_ARGS__)) + + c10::Scalar NPUNativeFunctions::_local_scalar_dense(const at::Tensor& self) { c10::Scalar r; - AT_DISPATCH_ALL_TYPES_AND3( + AT_DISPATCH_ALL_TYPES_AND5( at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, + at::ScalarType::Float8_e5m2, + at::ScalarType::Float8_e4m3fn, self.scalar_type(), "_local_scalar_dense_npu", [&] { diff --git a/torch_npu/csrc/aten/common/NpuFastReshape.cpp b/torch_npu/csrc/aten/common/NpuFastReshape.cpp index c07136aa50..dfd7d69e0a 100644 --- a/torch_npu/csrc/aten/common/NpuFastReshape.cpp +++ b/torch_npu/csrc/aten/common/NpuFastReshape.cpp @@ -32,7 +32,7 @@ void npu_fast_reshape_(at::Tensor& tensor) // refresh matadata to input tensor StorageDescHelper::ReflushDescBySelf(tensor); auto base_format = InferFormat::GuessBaseFormat(tensor.sizes()); - NPUNativeFunctions::npu_format_cast_(tensor, base_format); + NPUNativeFunctions::npu_format_cast_(tensor, base_format, c10::nullopt); } } // namespace native diff --git a/torch_npu/csrc/aten/common/ResizeNpu.cpp b/torch_npu/csrc/aten/common/ResizeNpu.cpp index af49fa1c33..e329ad41e8 100644 --- a/torch_npu/csrc/aten/common/ResizeNpu.cpp +++ b/torch_npu/csrc/aten/common/ResizeNpu.cpp @@ -46,7 +46,7 @@ const at::Tensor& NPUNativeFunctions::resize_( // no need to reflush NpuStorageDesc here. 
at::Tensor temp_self = self; if (!FormatHelper::IsBaseFormatType(self)) { - NPUNativeFunctions::npu_format_cast_(temp_self, FormatHelper::GetBaseFormat(self)); + NPUNativeFunctions::npu_format_cast_(temp_self, FormatHelper::GetBaseFormat(self), c10::nullopt); } auto* self_ = self.unsafeGetTensorImpl(); resize_impl_npu_(self_, size, c10::nullopt); diff --git a/torch_npu/csrc/aten/common/ToKernelNpu.cpp b/torch_npu/csrc/aten/common/ToKernelNpu.cpp index 96e67ff5bb..3d2be6452d 100644 --- a/torch_npu/csrc/aten/common/ToKernelNpu.cpp +++ b/torch_npu/csrc/aten/common/ToKernelNpu.cpp @@ -161,7 +161,7 @@ at::Tensor NPUNativeFunctions::to( "dtype cast repalce with float."); } dtype = (dtype == at::ScalarType::Double) ? at::ScalarType::Float : dtype; - return custom_ops::npu_dtype_cast(self, dtype); + return custom_ops::_npu_dtype_cast(self, dtype); } at::Tensor NPUNativeFunctions::to( diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml index 95bb740db1..3bc305d03f 100644 --- a/torch_npu/csrc/aten/npu_native_functions.yaml +++ b/torch_npu/csrc/aten/npu_native_functions.yaml @@ -62,12 +62,12 @@ custom: - func: npu_change_data_ptr(Tensor dst, Tensor src, int index) -> int device_check: NoCheck - func: get_npu_format(Tensor self) -> int - - func: npu_format_cast.Tensor(Tensor self, Tensor dst) -> Tensor + - func: npu_format_cast.Tensor(Tensor self, Tensor dst, int? customize_dtype=None) -> Tensor device_check: NoCheck exposed: True - - func: npu_format_cast_.acl_format(Tensor(a!) self, int acl_format) -> Tensor(a!) + - func: npu_format_cast_.acl_format(Tensor(a!) self, int acl_format, int? customize_dtype=None) -> Tensor(a!) exposed: True - - func: npu_format_cast_(Tensor(a!) self, Tensor src) -> Tensor(a!) + - func: npu_format_cast_(Tensor(a!) self, Tensor src, int? customize_dtype=None) -> Tensor(a!) device_check: NoCheck exposed: True - func: empty_with_format(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, int acl_format=2, int? base_addr_aligned_kb=None) -> Tensor @@ -82,9 +82,9 @@ custom: - func: copy_memory_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!) device_check: NoCheck - func: get_storage_size(Tensor self) -> int - - func: npu_format_cast(Tensor self, int acl_format) -> Tensor + - func: npu_format_cast(Tensor self, int acl_format, int? customize_dtype=None) -> Tensor exposed: True - - func: _npu_format_cast(Tensor self, int acl_format) -> Tensor + - func: _npu_format_cast(Tensor self, int acl_format, int? customize_dtype=None) -> Tensor - func: empty_with_swapped_memory(int[] size, *, ScalarType? dtype=None, Device? 
device=None) -> Tensor dispatch: CompositeExplicitAutograd: empty_with_swapped_memory diff --git a/torch_npu/csrc/aten/ops/FlattenDenseTensorsKernelNpu.cpp b/torch_npu/csrc/aten/ops/FlattenDenseTensorsKernelNpu.cpp index e2be317874..2f06acbaea 100644 --- a/torch_npu/csrc/aten/ops/FlattenDenseTensorsKernelNpu.cpp +++ b/torch_npu/csrc/aten/ops/FlattenDenseTensorsKernelNpu.cpp @@ -7,7 +7,8 @@ namespace native { at::Tensor NPUNativeFunctions::flatten_dense_tensors(at::TensorList tensors) { static auto cast_back_to_ori_format = [](const at::Tensor& t) { - return custom_ops::npu_format_cast(t, torch_npu::NPUBridge::GetNpuStorageImpl(t)->npu_desc_.origin_format_); + return custom_ops::npu_format_cast(t, torch_npu::NPUBridge::GetNpuStorageImpl(t)->npu_desc_.origin_format_, + c10::nullopt); }; static auto flatten = [](const at::Tensor& t) { return cast_back_to_ori_format(t).contiguous().view({-1}); diff --git a/torch_npu/csrc/core/NPUSerialization.cpp b/torch_npu/csrc/core/NPUSerialization.cpp index 1ae122f342..5cbe642478 100644 --- a/torch_npu/csrc/core/NPUSerialization.cpp +++ b/torch_npu/csrc/core/NPUSerialization.cpp @@ -48,7 +48,7 @@ void npu_info_deserialization(const at::Tensor &t, std::unordered_map(t), format); + at_npu::native::NPUNativeFunctions::npu_format_cast_(const_cast(t), format, c10::nullopt); if (revert_flag) { t.set_requires_grad(true); } diff --git a/torch_npu/csrc/core/npu/NPUCachingAllocator.cpp b/torch_npu/csrc/core/npu/NPUCachingAllocator.cpp index 74afc22031..e1e86a8907 100644 --- a/torch_npu/csrc/core/npu/NPUCachingAllocator.cpp +++ b/torch_npu/csrc/core/npu/NPUCachingAllocator.cpp @@ -726,6 +726,7 @@ BlockState::BlockState(Block *block) SegmentState::SegmentState(Block *head) { + TORCH_INTERNAL_ASSERT(head != nullptr, PTA_ERROR(ErrCode::PTR)); TORCH_INTERNAL_ASSERT(head->prev == nullptr && head->pool != nullptr); is_small = head->pool->is_small; @@ -882,7 +883,7 @@ size_t CachingAllocatorConfig::parseExpandableSegments(const std::vectorsize; auto new_candidate = candidate->next; + if (C10_UNLIKELY(new_candidate == nullptr)) { + return nullptr; + } if (!map_block(new_candidate, std::min(remaining, candidate->next->size), ctx)) { return nullptr; } @@ -2442,7 +2446,11 @@ private: { bool freed_memory = false; for (const auto &name : FreeNPUMemoryCallbacksRegistry()->Keys()) { - freed_memory |= FreeNPUMemoryCallbacksRegistry()->Create(name)->Execute(); + if (FreeNPUMemoryCallbacksRegistry()->Create(name) != nullptr) { + freed_memory |= FreeNPUMemoryCallbacksRegistry()->Create(name)->Execute(); + } else { + TORCH_CHECK(false, "free memory callback get nullptr", PTA_ERROR(ErrCode::PTR)); + } } return freed_memory; } diff --git a/torch_npu/csrc/core/npu/NPUException.cpp b/torch_npu/csrc/core/npu/NPUException.cpp index 034726549b..a91b1d3cac 100644 --- a/torch_npu/csrc/core/npu/NPUException.cpp +++ b/torch_npu/csrc/core/npu/NPUException.cpp @@ -91,7 +91,7 @@ MemUceInfo memUceInfo; std::mutex memUceInfoMutex; -void set_mem_uce_info(MemUceInfo info) +void set_mem_uce_info(MemUceInfo& info) { std::lock_guard lock(memUceInfoMutex); memUceInfo = info; diff --git a/torch_npu/csrc/core/npu/NPUException.h b/torch_npu/csrc/core/npu/NPUException.h index 94e38a5edb..88a77ab810 100644 --- a/torch_npu/csrc/core/npu/NPUException.h +++ b/torch_npu/csrc/core/npu/NPUException.h @@ -259,7 +259,7 @@ bool checkUceErrAndRepair(bool check_error, std::string& err_msg); void record_mem_hbm_ecc_error(); -void set_mem_uce_info(MemUceInfo info); +void set_mem_uce_info(MemUceInfo& info); MemUceInfo 
get_mem_uce_info(); diff --git a/torch_npu/csrc/core/npu/NPUFormat.cpp b/torch_npu/csrc/core/npu/NPUFormat.cpp index b087842cc3..eed97ce703 100644 --- a/torch_npu/csrc/core/npu/NPUFormat.cpp +++ b/torch_npu/csrc/core/npu/NPUFormat.cpp @@ -37,7 +37,7 @@ std::vector get_npu_storage_sizes(const at::Tensor& self) at::Tensor npu_format_cast(const at::Tensor& self, int64_t acl_format) { - return NPUNativeFunctions::npu_format_cast(self, acl_format); + return NPUNativeFunctions::npu_format_cast(self, acl_format, c10::nullopt); } at::Tensor empty_with_format(c10::IntArrayRef sizes, const c10::TensorOptions& options, diff --git a/torch_npu/csrc/core/npu/NPUMacros.h b/torch_npu/csrc/core/npu/NPUMacros.h index 3223c4f325..960dcb97b6 100644 --- a/torch_npu/csrc/core/npu/NPUMacros.h +++ b/torch_npu/csrc/core/npu/NPUMacros.h @@ -29,6 +29,6 @@ #define TORCH_NPU_API C10_NPU_API -#define C10_COMPILE_TIME_MAX_NPUS 16 +#define C10_COMPILE_TIME_MAX_NPUS 32 // A maximum of 8 P2P links can be created on a NPU device #define C10_P2P_ACCESS_MAX_NPUS 8 diff --git a/torch_npu/csrc/core/npu/NpuVariables.cpp b/torch_npu/csrc/core/npu/NpuVariables.cpp index 3fedb9d387..4a222171ea 100644 --- a/torch_npu/csrc/core/npu/NpuVariables.cpp +++ b/torch_npu/csrc/core/npu/NpuVariables.cpp @@ -41,27 +41,35 @@ static std::map socVersionMap = { void SetSocVersion(const char* const socVersion) { - if (socVersion == nullptr || - g_curSocVersion != SocVersion::UnsupportedSocVersion) { - return; - } + if (socVersion == nullptr || + g_curSocVersion != SocVersion::UnsupportedSocVersion) { + return; + } - SocVersion curSocVersion = SocVersion::UnsupportedSocVersion; + SocVersion curSocVersion = SocVersion::UnsupportedSocVersion; + std::string inputVersion = socVersion; + std::string ascend95Version = "Ascend910_95"; - auto const& iter = socVersionMap.find(socVersion); - if (iter != socVersionMap.end()) { - curSocVersion = iter->second; - } else { - std::string unsupported_soc(socVersion); - std::replace(std::begin(unsupported_soc), std::end(unsupported_soc), '_', ' '); - AT_ERROR("Unsupported soc version: ", unsupported_soc); - } + auto const& iter = socVersionMap.find(socVersion); + if (iter != socVersionMap.end()) { + curSocVersion = iter->second; + } else if ((inputVersion.compare(0, ascend95Version.size(), ascend95Version) == 0)) { + curSocVersion = SocVersion::Ascend910_95; + } else { + std::string unsupported_soc(socVersion); + std::replace(std::begin(unsupported_soc), std::end(unsupported_soc), '_', ' '); + AT_ERROR("Unsupported soc version: ", unsupported_soc); + } - g_curSocVersion = curSocVersion; + g_curSocVersion = curSocVersion; } const SocVersion& GetSocVersion() { + if (g_curSocVersion == SocVersion::UnsupportedSocVersion) { + auto soc_name = c10_npu::acl::AclGetSocName(); + SetSocVersion(soc_name); + } return g_curSocVersion; } @@ -95,5 +103,10 @@ bool IsBF16Supported() { return GetSocVersion() >= SocVersion::Ascend910B1; } + +bool IsAscend910_95Version() +{ + return GetSocVersion() == SocVersion::Ascend910_95; +} } // namespace c10_npu diff --git a/torch_npu/csrc/core/npu/NpuVariables.h b/torch_npu/csrc/core/npu/NpuVariables.h index 3119a64515..b5a55c5f69 100644 --- a/torch_npu/csrc/core/npu/NpuVariables.h +++ b/torch_npu/csrc/core/npu/NpuVariables.h @@ -30,7 +30,8 @@ enum class SocVersion { Ascend910_9381, Ascend910_9382, Ascend910_9372, - Ascend910_9362 + Ascend910_9362, + Ascend910_95 = 260 }; void SetSocVersion(const char* const socVersion); @@ -40,6 +41,8 @@ const SocVersion& GetSocVersion(); bool IsSupportInfNan(); 
bool IsBF16Supported(); + +bool IsAscend910_95Version(); } // namespace c10_npu #endif diff --git a/torch_npu/csrc/core/npu/interface/AclInterface.cpp b/torch_npu/csrc/core/npu/interface/AclInterface.cpp index b59e9c85c9..520393355b 100644 --- a/torch_npu/csrc/core/npu/interface/AclInterface.cpp +++ b/torch_npu/csrc/core/npu/interface/AclInterface.cpp @@ -174,6 +174,7 @@ aclError AclrtSetStreamFailureMode(aclrtStream stream, uint64_t mode) { if (stream == nullptr) { // default stream return ACL_ERROR_INVALID_PARAM; } + typedef aclError(*aclrtSetStreamFailureModeFunc)(aclrtStream, uint64_t); static aclrtSetStreamFailureModeFunc func = (aclrtSetStreamFailureModeFunc)GET_FUNC(aclrtSetStreamFailureMode); if (func == nullptr) { @@ -844,7 +845,8 @@ bool IsCaptureSupported() static bool have_load_func = false; static bool default_support_capture = ((GetSocVersion() >= SocVersion::Ascend910B1) && (GetSocVersion() < SocVersion::Ascend310B1)) || - (GetSocVersion() >= SocVersion::Ascend910_9391); + ((GetSocVersion() >= SocVersion::Ascend910_9391) && + (GetSocVersion() < SocVersion::Ascend910_95)); if (default_support_capture && !have_load_func) { have_load_func = true; typedef aclError (*AclmdlRICaptureGetInfo)(aclrtStream, aclmdlRICaptureStatus *, aclmdlRI *); diff --git a/torch_npu/csrc/core/npu/register/OptionRegister.cpp b/torch_npu/csrc/core/npu/register/OptionRegister.cpp index 8f7f17a011..9e0c356a04 100644 --- a/torch_npu/csrc/core/npu/register/OptionRegister.cpp +++ b/torch_npu/csrc/core/npu/register/OptionRegister.cpp @@ -4,6 +4,7 @@ #include "torch_npu/csrc/core/npu/register/OptionRegister.h" #include "torch_npu/csrc/core/npu/sys_ctrl/npu_sys_ctrl.h" #include "torch_npu/csrc/core/npu/npu_log.h" +#include "torch_npu/csrc/core/npu/NpuVariables.h" namespace c10_npu { namespace option { @@ -84,6 +85,17 @@ OptionInterfaceBuilder::OptionInterfaceBuilder(const std::string &name, ::std::u void SetOption(const std::string &key, const std::string &val) { + if (c10_npu::IsAscend910_95Version()) { + if (key == "jitCompile" && val == "enable") { + TORCH_NPU_WARN_ONCE("Ascend910_95 series only support jit_compile=False, ", + "the requested value True is invalid and has been reverted to False."); + } + if (key == "ALLOW_INTERNAL_FORMAT" && val == "enable") { + TORCH_NPU_WARN_ONCE("Ascend910_95 series only support allow_internal_format=False, ", + "the requested value True is invalid and has been reverted to False."); + } + return register_options::OptionRegister::GetInstance()->Set(key, "disable"); + } register_options::OptionRegister::GetInstance()->Set(key, val); } diff --git a/torch_npu/csrc/distributed/Init.cpp b/torch_npu/csrc/distributed/Init.cpp index 1e40130f30..4f9bb5ef62 100644 --- a/torch_npu/csrc/distributed/Init.cpp +++ b/torch_npu/csrc/distributed/Init.cpp @@ -98,7 +98,8 @@ public: inline std::vector cast_tensors(at::TensorList tensors) const { static auto cast_back_to_ori_format = [](const at::Tensor &t) { - return at_npu::native::custom_ops::npu_format_cast(t, torch_npu::NPUBridge::GetNpuStorageImpl(t)->npu_desc_.origin_format_); + return at_npu::native::custom_ops::npu_format_cast(t, + torch_npu::NPUBridge::GetNpuStorageImpl(t)->npu_desc_.origin_format_, c10::nullopt); }; return c10::fmap(tensors, cast_back_to_ori_format); } diff --git a/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp b/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp index 3d91f787bf..a7d61f130a 100644 --- a/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp +++ b/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp @@ -2565,7 
+2565,7 @@ std::vector cast_to_origin_format(const std::vector& inp inputTensors_[index] = tensor; } else { auto origin_format = torch_npu::NPUBridge::GetNpuStorageImpl(tensor)->npu_desc_.origin_format_; - inputTensors_[index] = at_npu::native::custom_ops::npu_format_cast(tensor, origin_format); + inputTensors_[index] = at_npu::native::custom_ops::npu_format_cast(tensor, origin_format, c10::nullopt); } index++; } @@ -3634,7 +3634,7 @@ c10::intrusive_ptr ProcessGroupHCCL::allreduce( [&](std::vector& hcclStreams, c10::intrusive_ptr&) { if (tensors[0].scalar_type() == at::kBool || tensors[0].scalar_type() == at::kByte) { c10_npu::NPUStreamGuard guard(hcclStreams[0]); - tensors_cp[0] = at_npu::native::custom_ops::npu_dtype_cast(tensors[0], at::kInt); + tensors_cp[0] = at_npu::native::custom_ops::_npu_dtype_cast(tensors[0], at::kInt); } }, [&](std::vector& hcclStreams, c10::intrusive_ptr&) { @@ -3812,7 +3812,7 @@ c10::intrusive_ptr ProcessGroupHCCL::allreduce_coalesced( for (const auto i : c10::irange(tensors.size())) { if (tensors[i].scalar_type() == at::kBool || tensors[i].scalar_type() == at::kByte) { c10_npu::NPUStreamGuard guard(hcclStreams[0]); - tensors_cp[i] = at_npu::native::custom_ops::npu_dtype_cast(tensors[i], at::kInt); + tensors_cp[i] = at_npu::native::custom_ops::_npu_dtype_cast(tensors[i], at::kInt); } } }, @@ -3876,7 +3876,7 @@ c10::intrusive_ptr ProcessGroupHCCL::reduce( [&](std::vector& hcclStreams, c10::intrusive_ptr&) { if (tensors[0].scalar_type() == at::kBool || tensors[0].scalar_type() == at::kByte) { c10_npu::NPUStreamGuard guard(hcclStreams[0]); - tensors_cp[0] = at_npu::native::custom_ops::npu_dtype_cast(tensors[0], at::kInt); + tensors_cp[0] = at_npu::native::custom_ops::_npu_dtype_cast(tensors[0], at::kInt); } }, [&](std::vector& hcclStreams, c10::intrusive_ptr&) { @@ -3936,11 +3936,11 @@ c10::intrusive_ptr ProcessGroupHCCL::_reduce_oop( [&](std::vector& hcclStreams, c10::intrusive_ptr&) { if (inputTensors[0].scalar_type() == at::kBool || inputTensors[0].scalar_type() == at::kByte) { c10_npu::NPUStreamGuard guard(hcclStreams[0]); - inputTensors[0] = at_npu::native::custom_ops::npu_dtype_cast(inputTensors[0], at::kInt); + inputTensors[0] = at_npu::native::custom_ops::_npu_dtype_cast(inputTensors[0], at::kInt); } if (outputTensors[0].scalar_type() == at::kBool || outputTensors[0].scalar_type() == at::kByte) { c10_npu::NPUStreamGuard guard(hcclStreams[0]); - outputTensors[0] = at_npu::native::custom_ops::npu_dtype_cast(outputTensors[0], at::kInt); + outputTensors[0] = at_npu::native::custom_ops::_npu_dtype_cast(outputTensors[0], at::kInt); } }, [&](std::vector& hcclStreams, c10::intrusive_ptr&) { @@ -3975,14 +3975,14 @@ at::Tensor ProcessGroupHCCL::byte_alignment(at::Tensor& tensors) const if (num_add != 0) { bool transflag = false; if (inter_tensors.scalar_type() == at::ScalarType::Bool) { - inter_tensors = at_npu::native::custom_ops::npu_dtype_cast(inter_tensors, at::ScalarType::Int); + inter_tensors = at_npu::native::custom_ops::_npu_dtype_cast(inter_tensors, at::ScalarType::Int); transflag = true; } inter_tensors = op_plugin::constant_pad_nd(inter_tensors, {0, num_add}, 0); if (transflag) { - inter_tensors = at_npu::native::custom_ops::npu_dtype_cast(inter_tensors, at::ScalarType::Bool); + inter_tensors = at_npu::native::custom_ops::_npu_dtype_cast(inter_tensors, at::ScalarType::Bool); } } return inter_tensors; diff --git a/torch_npu/csrc/distributed/reducer.cpp b/torch_npu/csrc/distributed/reducer.cpp index da3664149b..f18614f23a 100644 --- 
a/torch_npu/csrc/distributed/reducer.cpp +++ b/torch_npu/csrc/distributed/reducer.cpp @@ -368,7 +368,7 @@ void Reducer::mark_variable_ready_dense(size_t variable_index) if (torch_npu::NPUBridge::GetNpuStorageImpl(grad)->npu_desc_.npu_format_ != torch_npu::NPUBridge::GetNpuStorageImpl(variable)->npu_desc_.npu_format_) { grad = at_npu::native::NPUNativeFunctions::npu_format_cast(grad, - torch_npu::NPUBridge::GetNpuStorageImpl(variable)->npu_desc_.npu_format_); + torch_npu::NPUBridge::GetNpuStorageImpl(variable)->npu_desc_.npu_format_, c10::nullopt); } if (comm_hook_ == nullptr) { if (!grad.requires_grad()) { diff --git a/torch_npu/csrc/distributed/rpc/tensorpipe_agent.cpp b/torch_npu/csrc/distributed/rpc/tensorpipe_agent.cpp index 655082b56f..edebbba53f 100644 --- a/torch_npu/csrc/distributed/rpc/tensorpipe_agent.cpp +++ b/torch_npu/csrc/distributed/rpc/tensorpipe_agent.cpp @@ -423,6 +423,9 @@ void TensorPipeAgent::startImpl() priority = opts_.transports->size() - 1 - (iter - opts_.transports->begin()); } std::unique_ptr reg = TensorPipeTransportRegistry()->Create(key); + if (reg == nullptr || reg->transport == nullptr) { + TORCH_CHECK(false, "TensorPipeTransport get nullptr", DIST_ERROR(ErrCode::PTR)); + } if (!reg->transport->isViable()) { continue; } diff --git a/torch_npu/csrc/framework/FormatHelper.cpp b/torch_npu/csrc/framework/FormatHelper.cpp index 6a92fe5af4..9bd270b8fd 100644 --- a/torch_npu/csrc/framework/FormatHelper.cpp +++ b/torch_npu/csrc/framework/FormatHelper.cpp @@ -52,6 +52,10 @@ std::unordered_map FormatHelper::Initialize {ACL_FORMAT_NDC1HWC0, (FormatInfo){ACL_FORMAT_NDC1HWC0, ACL_FORMAT_NCDHW, InferShapeOfNDC1HWC0, "NDC1HWC0", true}}, {ACL_FRACTAL_Z_3D, (FormatInfo){ACL_FRACTAL_Z_3D, ACL_FORMAT_NCDHW, InferShapeOfFZ3D, "FRACTAL_Z_3D", true}}, + {ACL_FORMAT_FRACTAL_NZ_C0_16, + (FormatInfo){ACL_FORMAT_FRACTAL_NZ_C0_16, ACL_FORMAT_ND, nullptr, "FRACTAL_NZ_C0_16", true}}, + {ACL_FORMAT_FRACTAL_NZ_C0_32, + (FormatInfo){ACL_FORMAT_FRACTAL_NZ_C0_32, ACL_FORMAT_ND, nullptr, "FRACTAL_NZ_C0_32", true}}, }; }; diff --git a/torch_npu/csrc/framework/OpCommand.cpp b/torch_npu/csrc/framework/OpCommand.cpp index 6b98651c51..80af05f94b 100644 --- a/torch_npu/csrc/framework/OpCommand.cpp +++ b/torch_npu/csrc/framework/OpCommand.cpp @@ -24,7 +24,9 @@ static std::unordered_map> floating_limits_m {at::ScalarType::Double, {std::numeric_limits::max(), std::numeric_limits::min()}}, {at::ScalarType::Float, {std::numeric_limits::max(), std::numeric_limits::min()}}, {at::ScalarType::BFloat16, {std::numeric_limits::max(), std::numeric_limits::min()}}, - {at::ScalarType::Half, {65504, -65504}}}; + {at::ScalarType::Half, {65504, -65504}}, + {at::ScalarType::Float8_e5m2, {57345, -57345}}, + {at::ScalarType::Float8_e4m3fn, {449, -449}}}; static std::unordered_map> integral_limits_map{ {at::ScalarType::Long, {std::numeric_limits::max(), std::numeric_limits::min()}}, {at::ScalarType::Int, {std::numeric_limits::max(), std::numeric_limits::min()}}, @@ -274,7 +276,7 @@ OpCommand& OpCommand::AddTensorInput(at::Tensor &tensor, at::ScalarType forceSca { std::tuple res; if (commonType.has_value() && commonType.value() != tensor.scalar_type()) { - tensor = custom_ops::npu_dtype_cast(tensor, commonType.value()); + tensor = custom_ops::_npu_dtype_cast(tensor, commonType.value()); } // as for dim=0, the dtype of tensor can not be `uint16` because of `TBE` if (torch_npu::NPUBridge::GetNpuStorageImplDesc(tensor).storage_sizes_.empty()) { @@ -331,7 +333,7 @@ OpCommand& OpCommand::AddScalarInput(const c10::Scalar& 
input, at::ScalarType ty OpCommand& OpCommand::AddOutput(at::Tensor &output, const string &realType) { if (resultTypeDefined == false && commonType.has_value() && commonType.value() != output.scalar_type()) { - output = custom_ops::npu_dtype_cast(output, commonType.value()); + output = custom_ops::_npu_dtype_cast(output, commonType.value()); } auto res = OpCmdHelper::CovertToAclOutput(output, realType); aclCmd->AddOutput(std::get<0>(res), std::get<1>(res)); diff --git a/torch_npu/csrc/framework/OpParamMaker.cpp b/torch_npu/csrc/framework/OpParamMaker.cpp index f1b9064b6d..aac0ba1814 100644 --- a/torch_npu/csrc/framework/OpParamMaker.cpp +++ b/torch_npu/csrc/framework/OpParamMaker.cpp @@ -574,6 +574,7 @@ void *NewFunc(int caption, int &size) void DeleteFunc(void *ptr) { free(ptr); + ptr = nullptr; } using Func = int (*)(c10_npu::queue::QueueParas *, aclrtStream); diff --git a/torch_npu/csrc/framework/StorageDescHelper.cpp b/torch_npu/csrc/framework/StorageDescHelper.cpp index fecbb86f1f..6f52465d1a 100644 --- a/torch_npu/csrc/framework/StorageDescHelper.cpp +++ b/torch_npu/csrc/framework/StorageDescHelper.cpp @@ -97,6 +97,13 @@ void StorageDescHelper::SetDesc(at::Tensor &dst, const c10::IntArrayRef &size, c torch_npu::NPUBridge::GetNpuStorageImpl(dst)->npu_desc_ = SetDesc(dst.dtype(), size, strides, format); } +void StorageDescHelper::SetDesc(at::Tensor &dst, const c10::IntArrayRef &base_size, + const c10::IntArrayRef &storage_size, const c10::IntArrayRef &strides, aclFormat format) +{ + torch_npu::NPUBridge::GetNpuStorageImpl(dst)->npu_desc_ = + SetDesc(dst.dtype(), base_size, storage_size, strides, format); +} + bool StorageDescHelper::CheckDescInit(const c10::Storage &storage) { return torch_npu::NPUBridge::GetNpuStorageImpl(storage.unsafeGetStorageImpl())->npu_desc_.origin_format_ != @@ -254,6 +261,22 @@ torch_npu::NPUStorageDesc StorageDescHelper::SetDesc(const caffe2::TypeMeta &dty return npu_desc; } +torch_npu::NPUStorageDesc StorageDescHelper::SetDesc(const caffe2::TypeMeta &dtype, const c10::IntArrayRef& base_size, + const c10::IntArrayRef& storage_size, const c10::IntArrayRef& strides, aclFormat format) +{ + struct torch_npu::NPUStorageDesc npu_desc; + npu_desc.data_type_ = dtype; + npu_desc.base_sizes_ = base_size; + npu_desc.base_strides_ = strides; + aclFormat baseFormat; + aclFormat npuFormat; + std::tie(baseFormat, npuFormat) = InferFormat::GuessFormatUnit(base_size, format); + npu_desc.storage_sizes_ = storage_size; + npu_desc.origin_format_ = baseFormat; + npu_desc.npu_format_ = npuFormat; + return npu_desc; +} + int64_t StorageDescHelper::GetMemorySize(const torch_npu::NPUStorageDesc &dst) { const auto &physical_size = FormatHelper::GetStorageSizes(dst); diff --git a/torch_npu/csrc/framework/StorageDescHelper.h b/torch_npu/csrc/framework/StorageDescHelper.h index 5c16ee74e2..37b8933c1a 100644 --- a/torch_npu/csrc/framework/StorageDescHelper.h +++ b/torch_npu/csrc/framework/StorageDescHelper.h @@ -35,6 +35,8 @@ public: static void SetDesc(at::Tensor &dst, const c10::IntArrayRef& size, const c10::IntArrayRef& strides); static void SetDesc(at::Tensor &dst, const c10::IntArrayRef &size, const c10::IntArrayRef &strides, aclFormat format); + static void SetDesc(at::Tensor &dst, const c10::IntArrayRef &base_size, + const c10::IntArrayRef &storage_size, const c10::IntArrayRef &strides, aclFormat format); static bool CheckDescInit(const c10::Storage &storage); // For Serialization to Get and Set NpuStorageDesc @@ -63,6 +65,8 @@ private: const c10::IntArrayRef& strides); static 
torch_npu::NPUStorageDesc SetDesc(const caffe2::TypeMeta &dtype, const c10::IntArrayRef& size, const c10::IntArrayRef& strides, aclFormat format); + static torch_npu::NPUStorageDesc SetDesc(const caffe2::TypeMeta &dtype, const c10::IntArrayRef& base_size, + const c10::IntArrayRef& storage_size, const c10::IntArrayRef& strides, aclFormat format); }; } // namespace native diff --git a/torch_npu/csrc/framework/contiguous/reshapeV2_opt.cpp b/torch_npu/csrc/framework/contiguous/reshapeV2_opt.cpp index c2abf7f4b2..ee90387910 100644 --- a/torch_npu/csrc/framework/contiguous/reshapeV2_opt.cpp +++ b/torch_npu/csrc/framework/contiguous/reshapeV2_opt.cpp @@ -70,6 +70,14 @@ private: ResetDataPtr(src, self, static_cast(src.storage().data_ptr().get())); return true; + case at::ScalarType::Float8_e5m2: + ResetDataPtr(src, self, + static_cast(src.storage().data_ptr().get())); + return true; + case at::ScalarType::Float8_e4m3fn: + ResetDataPtr(src, self, + static_cast(src.storage().data_ptr().get())); + return true; default: // Turn to conducting d2dCopyAsync for other dtypes. return false; diff --git a/torch_npu/csrc/framework/utils/CalcuOpUtil.cpp b/torch_npu/csrc/framework/utils/CalcuOpUtil.cpp index c2bb14ca66..453a7082da 100644 --- a/torch_npu/csrc/framework/utils/CalcuOpUtil.cpp +++ b/torch_npu/csrc/framework/utils/CalcuOpUtil.cpp @@ -52,8 +52,8 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(ENUM_PAIR_FUNC) _(at::ScalarType::Bits4x2, ACL_DT_UNDEFINED) \ _(at::ScalarType::Bits8, ACL_DT_UNDEFINED) \ _(at::ScalarType::Bits16, ACL_DT_UNDEFINED) \ - _(at::ScalarType::Float8_e5m2, ACL_DT_UNDEFINED) \ - _(at::ScalarType::Float8_e4m3fn, ACL_DT_UNDEFINED) \ + _(at::ScalarType::Float8_e5m2, ACL_FLOAT8_E5M2) \ + _(at::ScalarType::Float8_e4m3fn, ACL_FLOAT8_E4M3FN) \ _(at::ScalarType::Undefined, ACL_DT_UNDEFINED) \ _(at::ScalarType::NumOptions, ACL_DT_UNDEFINED) @@ -74,6 +74,37 @@ AT_ALL_SCALAR_TYPE_AND_ACL_DATATYPE_PAIR(ENUM_PAIR_FUNC) static std::map STRING_SCALAR_TYPE_TO_ACL_TYPE_MAP = { {"uint16", ACL_UINT16}, {"uint8", ACL_UINT8}, {"uint64", ACL_UINT64}, {"string", ACL_STRING}}; +// at::ScalarType::UInt16/UInt32/UInt64 will be supported after v2.1.0 +static std::unordered_map + ACL_TYPE_TO_SCALAR_TYPE_MAP = {{ACL_DT_UNDEFINED, at::ScalarType::Undefined}, + {ACL_FLOAT, at::ScalarType::Float}, + {ACL_FLOAT16, at::ScalarType::Half}, + {ACL_INT8, at::ScalarType::Char}, + {ACL_INT32, at::ScalarType::Int}, + {ACL_UINT8, at::ScalarType::Byte}, + {ACL_INT16, at::ScalarType::Short}, + {ACL_UINT16, at::ScalarType::Undefined}, + {ACL_UINT32, at::ScalarType::Undefined}, + {ACL_INT64, at::ScalarType::Long}, + {ACL_UINT64, at::ScalarType::Undefined}, + {ACL_DOUBLE, at::ScalarType::Double}, + {ACL_BOOL, at::ScalarType::Bool}, + {ACL_STRING, at::ScalarType::Undefined}, + {ACL_COMPLEX64, at::ScalarType::ComplexFloat}, + {ACL_COMPLEX128, at::ScalarType::ComplexDouble}, + {ACL_BF16, at::ScalarType::BFloat16}, + {ACL_INT4, at::ScalarType::Undefined}, + {ACL_UINT1, at::ScalarType::Undefined}, + {ACL_COMPLEX32, at::ScalarType::ComplexHalf}, + {ACL_HIFLOAT8, at::ScalarType::Byte}, + {ACL_FLOAT8_E5M2, at::ScalarType::Float8_e5m2}, + {ACL_FLOAT8_E4M3FN, at::ScalarType::Float8_e4m3fn}, + {ACL_FLOAT8_E8M0, at::ScalarType::Byte}, + {ACL_FLOAT6_E3M2, at::ScalarType::Byte}, + {ACL_FLOAT6_E2M3, at::ScalarType::Byte}, + {ACL_FLOAT4_E2M1, at::ScalarType::Byte}, + {ACL_FLOAT4_E1M2, at::ScalarType::Byte}}; + aclError AclrtMemcpyAsyncParamCheck( void *dst, size_t destMax, const void *src, size_t count, aclrtMemcpyKind kind, aclrtStream 
stream) { @@ -291,5 +322,17 @@ int8_t CalcuOpUtil::GetCubeMathType(bool allowHf32) return iter->second; } +at::ScalarType CalcuOpUtil::ConvertToScalarType(const aclDataType data_type) +{ + auto iter = ACL_TYPE_TO_SCALAR_TYPE_MAP.find(data_type); + if (iter == ACL_TYPE_TO_SCALAR_TYPE_MAP.end()) { + TORCH_CHECK(false, + std::string("aclDataType:") + std::to_string(data_type) + " has not been supported", + OPS_ERROR(ErrCode::NOT_SUPPORT)) + } + + return iter->second; +} + } // namespace native } // namespace at_npu diff --git a/torch_npu/csrc/framework/utils/CalcuOpUtil.h b/torch_npu/csrc/framework/utils/CalcuOpUtil.h index b06ab06f90..9a4a802443 100644 --- a/torch_npu/csrc/framework/utils/CalcuOpUtil.h +++ b/torch_npu/csrc/framework/utils/CalcuOpUtil.h @@ -86,6 +86,7 @@ public: static int64_t GetTensorNpuFormat(const at::Tensor &tensor); static c10::SmallVector ConvertIntArrayRefToSmallVector(c10::IntArrayRef intArray); static int8_t GetCubeMathType(bool allowHf32); + static at::ScalarType ConvertToScalarType(const aclDataType data_type); }; } // namespace native diff --git a/torch_npu/csrc/framework/utils/NpuUtils.cpp b/torch_npu/csrc/framework/utils/NpuUtils.cpp index e54c951dc1..6e51d36b6d 100644 --- a/torch_npu/csrc/framework/utils/NpuUtils.cpp +++ b/torch_npu/csrc/framework/utils/NpuUtils.cpp @@ -129,7 +129,7 @@ at::Tensor metadata_convert_match(const at::Tensor &src, bool numelEq) // NCHW will generate a temporary tensor, which always monopolizes its own // storage. if (numelEq && (!FormatHelper::IsBaseFormatType(src))) { - at::Tensor tempTensor = custom_ops::npu_format_cast(src, FormatHelper::GetBaseFormat(src)); + at::Tensor tempTensor = custom_ops::npu_format_cast(src, FormatHelper::GetBaseFormat(src), c10::nullopt); custom_ops::npu_reshape_out(tempTensor, tempTensor.sizes(), true, tempTensor); NpuUtils::RefreshFormat(tempTensor); return tempTensor; diff --git a/torch_npu/csrc/framework/utils/OpPreparation.cpp b/torch_npu/csrc/framework/utils/OpPreparation.cpp index 20f357c654..bdaeaacee6 100644 --- a/torch_npu/csrc/framework/utils/OpPreparation.cpp +++ b/torch_npu/csrc/framework/utils/OpPreparation.cpp @@ -102,6 +102,11 @@ aclDataType OpPreparation::convert_to_acl_data_type(const at::ScalarType &data_t return CalcuOpUtil::ConvertToAclDataType(data_type, realDataType); } +at::ScalarType OpPreparation::convert_to_scalar_type(const aclDataType data_type) +{ + return CalcuOpUtil::ConvertToScalarType(data_type); +} + at::Tensor OpPreparation::copy_scalar_to_device(const c10::Scalar &cpu_scalar, at::ScalarType scalar_data_type) { return CalcuOpUtil::CopyScalarToDevice(cpu_scalar, scalar_data_type); @@ -222,14 +227,14 @@ void OpPreparation::check_memory(const std::initializer_list &inputs at::Tensor OpPreparation::cast_to_ori_format(const at::Tensor &tensor) { auto &tensor_desc = torch_npu::NPUBridge::GetNpuStorageImpl(tensor)->npu_desc_; - auto ret = custom_ops::npu_format_cast(tensor, tensor_desc.origin_format_); + auto ret = custom_ops::npu_format_cast(tensor, tensor_desc.origin_format_, c10::nullopt); return ret; } at::Tensor &OpPreparation::cast_to_ori_format(at::Tensor &tensor) { auto &tensor_desc = torch_npu::NPUBridge::GetNpuStorageImpl(tensor)->npu_desc_; - NPUNativeFunctions::npu_format_cast_(tensor, tensor_desc.origin_format_); + NPUNativeFunctions::npu_format_cast_(tensor, tensor_desc.origin_format_, c10::nullopt); return tensor; } @@ -356,21 +361,21 @@ void OpPreparation::CheckOut(const std::initializer_list &input, if (CalcuOpUtil::GetTensorNpuFormat(output) != format) { 
TORCH_CHECK(!is_read_write, "can not cast format when output is input", OPS_ERROR(ErrCode::NOT_SUPPORT)); - NPUNativeFunctions::npu_format_cast_(output, format); + NPUNativeFunctions::npu_format_cast_(output, format, c10::nullopt); } } at::Tensor OpPreparation::CastBackToOriFormat(const at::Tensor &tensor) { auto &tensor_desc = torch_npu::NPUBridge::GetNpuStorageImpl(tensor)->npu_desc_; - auto ret = custom_ops::npu_format_cast(tensor, tensor_desc.origin_format_); + auto ret = custom_ops::npu_format_cast(tensor, tensor_desc.origin_format_, c10::nullopt); return ret; } at::Tensor &OpPreparation::CastBackToOriFormat(at::Tensor &tensor) { auto &tensor_desc = torch_npu::NPUBridge::GetNpuStorageImpl(tensor)->npu_desc_; - NPUNativeFunctions::npu_format_cast_(tensor, tensor_desc.origin_format_); + NPUNativeFunctions::npu_format_cast_(tensor, tensor_desc.origin_format_, c10::nullopt); return tensor; } diff --git a/torch_npu/csrc/framework/utils/OpPreparation.h b/torch_npu/csrc/framework/utils/OpPreparation.h index 74ac303898..e87a910112 100644 --- a/torch_npu/csrc/framework/utils/OpPreparation.h +++ b/torch_npu/csrc/framework/utils/OpPreparation.h @@ -22,6 +22,7 @@ public: // From CalcuOpUtil part static aclDataType convert_to_acl_data_type(const at::ScalarType &data_type); static aclDataType convert_to_acl_data_type(const at::ScalarType &data_type, const std::string &realDataType); + static at::ScalarType convert_to_scalar_type(const aclDataType data_type); static at::Tensor copy_scalar_to_device(const c10::Scalar &cpu_scalar, at::ScalarType scalar_data_type); static at::Tensor copy_scalar_to_device(const c10::Scalar &cpu_scalar, at::ScalarType scalar_data_type, const c10::Device device); diff --git a/torch_npu/csrc/npu/DataParallelComm.cpp b/torch_npu/csrc/npu/DataParallelComm.cpp index db0d3efabe..c744e1e1ba 100644 --- a/torch_npu/csrc/npu/DataParallelComm.cpp +++ b/torch_npu/csrc/npu/DataParallelComm.cpp @@ -137,7 +137,7 @@ void check_inputs(TensorList inputs, TensorList outputs, int input_multiplier, i { // need to check len(inputs) == len(outputs) size_t len = inputs.size(); - if (len <= 0) { + if (len == 0) { throw std::runtime_error("input sequence can't be empty" + PTA_ERROR(ErrCode::PARAM)); } diff --git a/torch_npu/csrc/transformer_engine/CMakeLists.txt b/torch_npu/csrc/transformer_engine/CMakeLists.txt new file mode 100644 index 0000000000..b5350d4957 --- /dev/null +++ b/torch_npu/csrc/transformer_engine/CMakeLists.txt @@ -0,0 +1,6 @@ +FILE(GLOB _TE_SRCS *.cpp) + +LIST(APPEND TE_SRCS ${_TE_SRCS}) + +# Pass to parent +set(TE_SRCS ${TE_SRCS} PARENT_SCOPE) diff --git a/torch_npu/csrc/transformer_engine/CastKernelTeOpApi.cpp b/torch_npu/csrc/transformer_engine/CastKernelTeOpApi.cpp new file mode 100644 index 0000000000..afcfd6f7af --- /dev/null +++ b/torch_npu/csrc/transformer_engine/CastKernelTeOpApi.cpp @@ -0,0 +1,45 @@ +#include "torch_npu/csrc/transformer_engine/extension.h" +#include "op_plugin/AclOpsInterface.h" +#include "op_plugin/OpApiInterface.h" +#include "op_plugin/utils/op_api_common.h" + + +namespace torch_npu { +namespace te { + +at::Tensor cast_to_fp8(const at::Tensor &input, int otype) +{ + auto output = at::empty_like(input, torch_npu::te::GetATenDType(otype)); + + if (input.numel() == 0) { + return output; + } + + aclDataType out_acltype = torch_npu::te::GetAclDataType(otype); + TensorWrapper out_wrapper = {output, out_acltype}; + EXEC_NPU_CMD(aclnnCast, input, out_acltype, out_wrapper); + + return output; +} + +void cast_to_fp8_noalloc(const at::Tensor &input, at::Tensor 
output, int otype)
+{
+    aclDataType out_acltype = torch_npu::te::GetAclDataType(otype);
+    TensorWrapper out_wrapper = {output, out_acltype};
+    EXEC_NPU_CMD(aclnnCast, input, out_acltype, out_wrapper);
+    return;
+}
+
+at::Tensor cast_from_fp8(const at::Tensor &input, int itype, int otype)
+{
+    aclDataType input_acltype = torch_npu::te::GetAclDataType(itype);
+    aclDataType out_acltype = torch_npu::te::GetAclDataType(otype);
+    auto output = at::empty_like(input, torch_npu::te::GetATenDType(otype));
+    TensorWrapper input_wrapper = {input, input_acltype};
+    TensorWrapper out_wrapper = {output, out_acltype};
+    EXEC_NPU_CMD(aclnnCast, input_wrapper, out_acltype, out_wrapper);
+
+    return output;
+}
+}
+}
diff --git a/torch_npu/csrc/transformer_engine/Init.cpp b/torch_npu/csrc/transformer_engine/Init.cpp
new file mode 100644
index 0000000000..bea7de1f36
--- /dev/null
+++ b/torch_npu/csrc/transformer_engine/Init.cpp
@@ -0,0 +1,165 @@
+#include "torch_npu/csrc/transformer_engine/Init.h"
+#ifndef BUILD_LIBTORCH
+#include
+#include
+#endif
+#include "torch_npu/csrc/transformer_engine/extension.h"
+
+
+namespace torch_npu {
+namespace te {
+struct DTypeConstants {
+    static const int float32_value;
+    static const int float16_value;
+    static const int int8_value;
+    static const int int32_value;
+    static const int uint8_value;
+    static const int int16_value;
+    static const int uint16_value;
+    static const int uint32_value;
+    static const int int64_value;
+    static const int uint64_value;
+    static const int float64_value;
+    static const int bool_value;
+    static const int string_value;
+    static const int complex64_value;
+    static const int complex128_value;
+    static const int bfloat16_value;
+    static const int int4_value;
+    static const int uint1_value;
+    static const int complex32_value;
+    static const int hifloat8_value;
+    static const int float8_e5m2_value;
+    static const int float8_e4m3fn_value;
+    static const int float8_e8m0_value;
+    static const int float6_e3m2_value;
+    static const int float6_e2m3_value;
+    static const int float4_e2m1_value;
+    static const int float4_e1m2_value;
+};
+
+const int DTypeConstants::float32_value = static_cast<int>(DType::TE_FLOAT);
+const int DTypeConstants::float16_value = static_cast<int>(DType::TE_FLOAT16);
+const int DTypeConstants::int8_value = static_cast<int>(DType::TE_INT8);
+const int DTypeConstants::int32_value = static_cast<int>(DType::TE_INT32);
+const int DTypeConstants::uint8_value = static_cast<int>(DType::TE_UINT8);
+const int DTypeConstants::int16_value = static_cast<int>(DType::TE_INT16);
+const int DTypeConstants::uint16_value = static_cast<int>(DType::TE_UINT16);
+const int DTypeConstants::uint32_value = static_cast<int>(DType::TE_UINT32);
+const int DTypeConstants::int64_value = static_cast<int>(DType::TE_INT64);
+const int DTypeConstants::uint64_value = static_cast<int>(DType::TE_UINT64);
+const int DTypeConstants::float64_value = static_cast<int>(DType::TE_DOUBLE);
+const int DTypeConstants::bool_value = static_cast<int>(DType::TE_BOOL);
+const int DTypeConstants::string_value = static_cast<int>(DType::TE_STRING);
+const int DTypeConstants::complex64_value = static_cast<int>(DType::TE_COMPLEX64);
+const int DTypeConstants::complex128_value = static_cast<int>(DType::TE_COMPLEX128);
+const int DTypeConstants::bfloat16_value = static_cast<int>(DType::TE_BF16);
+const int DTypeConstants::int4_value = static_cast<int>(DType::TE_INT4);
+const int DTypeConstants::uint1_value = static_cast<int>(DType::TE_UINT1);
+const int DTypeConstants::complex32_value = static_cast<int>(DType::TE_COMPLEX32);
+const int DTypeConstants::hifloat8_value = static_cast<int>(DType::TE_HIFLOAT8);
+const int DTypeConstants::float8_e5m2_value = static_cast<int>(DType::TE_FLOAT8_E5M2);
+const int DTypeConstants::float8_e4m3fn_value = static_cast<int>(DType::TE_FLOAT8_E4M3FN);
+const int DTypeConstants::float8_e8m0_value = static_cast<int>(DType::TE_FLOAT8_E8M0);
+const int DTypeConstants::float6_e3m2_value = static_cast<int>(DType::TE_FLOAT6_E3M2);
+const int DTypeConstants::float6_e2m3_value = static_cast<int>(DType::TE_FLOAT6_E2M3);
+const int DTypeConstants::float4_e2m1_value = static_cast<int>(DType::TE_FLOAT4_E2M1);
+const int DTypeConstants::float4_e1m2_value = static_cast<int>(DType::TE_FLOAT4_E1M2);
+
+#ifndef BUILD_LIBTORCH
+PyObject* te_initExtension(PyObject*, PyObject *)
+{
+    auto torch_npu_C_module = THPObjectPtr(PyImport_ImportModule("torch_npu._C"));
+    if (!torch_npu_C_module) {
+        return nullptr;
+    }
+    auto torch_npu_C_m = py::handle(torch_npu_C_module).cast<py::module>();
+    auto m = torch_npu_C_m.def_submodule("_te", "_te bindings");
+
+    py::class_<DTypeConstants>(m, "DType")
+        .def_readonly_static("float32", &DTypeConstants::float32_value)
+        .def_readonly_static("float16", &DTypeConstants::float16_value)
+        .def_readonly_static("int8", &DTypeConstants::int8_value)
+        .def_readonly_static("int32", &DTypeConstants::int32_value)
+        .def_readonly_static("uint8", &DTypeConstants::uint8_value)
+        .def_readonly_static("int16", &DTypeConstants::int16_value)
+        .def_readonly_static("uint16", &DTypeConstants::uint16_value)
+        .def_readonly_static("uint32", &DTypeConstants::uint32_value)
+        .def_readonly_static("int64", &DTypeConstants::int64_value)
+        .def_readonly_static("uint64", &DTypeConstants::uint64_value)
+        .def_readonly_static("float64", &DTypeConstants::float64_value)
+        .def_readonly_static("bool", &DTypeConstants::bool_value)
+        .def_readonly_static("string", &DTypeConstants::string_value)
+        .def_readonly_static("complex64", &DTypeConstants::complex64_value)
+        .def_readonly_static("complex128", &DTypeConstants::complex128_value)
+        .def_readonly_static("bfloat16", &DTypeConstants::bfloat16_value)
+        .def_readonly_static("int4", &DTypeConstants::int4_value)
+        .def_readonly_static("uint1", &DTypeConstants::uint1_value)
+        .def_readonly_static("complex32", &DTypeConstants::complex32_value)
+        .def_readonly_static("hifloat8", &DTypeConstants::hifloat8_value)
+        .def_readonly_static("float8_e5m2", &DTypeConstants::float8_e5m2_value)
+        .def_readonly_static("float8_e4m3fn", &DTypeConstants::float8_e4m3fn_value)
+        .def_readonly_static("float8_e8m0", &DTypeConstants::float8_e8m0_value)
+        .def_readonly_static("float6_e3m2", &DTypeConstants::float6_e3m2_value)
+        .def_readonly_static("float6_e2m3", &DTypeConstants::float6_e2m3_value)
+        .def_readonly_static("float4_e2m1", &DTypeConstants::float4_e2m1_value)
+        .def_readonly_static("float4_e1m2", &DTypeConstants::float4_e1m2_value);
+
+    m.def("cast_to_fp8", &cast_to_fp8, "Cast to FP8", py::call_guard<py::gil_scoped_release>());
+    m.def("cast_to_fp8_noalloc", &cast_to_fp8_noalloc, "Cast to FP8",
+          py::call_guard<py::gil_scoped_release>());
+    m.def("cast_from_fp8", &cast_from_fp8, "Cast from FP8", py::call_guard<py::gil_scoped_release>());
+
+    Py_RETURN_NONE;
+}
+
+static PyMethodDef NPUTeMethods[] = { // NOLINT
+    {"_te_init", te_initExtension, METH_NOARGS, nullptr},
+    {nullptr, nullptr, 0, nullptr}
+};
+#endif
+
+const std::string TeDataTypeToString(int64_t dType)
+{
+    const std::map<DType, std::string>
+        TE_TYPE_TO_STRING_MAP = {
+            {DType::TE_FLOAT, "torch_npu.float32"},
+            {DType::TE_FLOAT16, "torch_npu.float16"},
+            {DType::TE_INT8, "torch_npu.int8"},
+            {DType::TE_INT32, "torch_npu.int32"},
+            {DType::TE_UINT8, "torch_npu.uint8"},
+            {DType::TE_INT16,
"torch_npu.int16"}, + {DType::TE_UINT16, "torch_npu.uint16"}, + {DType::TE_UINT32, "torch_npu.uint32"}, + {DType::TE_INT64, "torch_npu.int64"}, + {DType::TE_UINT64, "torch_npu.uint64"}, + {DType::TE_DOUBLE, "torch_npu.float64"}, + {DType::TE_BOOL, "torch_npu.bool"}, + {DType::TE_STRING, "torch_npu.string"}, + {DType::TE_COMPLEX64, "torch_npu.complex64"}, + {DType::TE_COMPLEX128, "torch_npu.complex128"}, + {DType::TE_BF16, "torch_npu.bfloat16"}, + {DType::TE_INT4, "torch_npu.int4"}, + {DType::TE_UINT1, "torch_npu.uint1"}, + {DType::TE_COMPLEX32, "torch_npu.complex32"}, + {DType::TE_HIFLOAT8, "torch_npu.hifloat8"}, + {DType::TE_FLOAT8_E5M2, "torch_npu.float8_e5m2"}, + {DType::TE_FLOAT8_E4M3FN, "torch_npu.float8_e4m3fn"}, + {DType::TE_FLOAT8_E8M0, "torch_npu.float8_e8m0"}, + {DType::TE_FLOAT6_E3M2, "torch_npu.float6_e3m2"}, + {DType::TE_FLOAT6_E2M3, "torch_npu.float6_e2m3"}, + {DType::TE_FLOAT4_E2M1, "torch_npu.float4_e2m1"}, + {DType::TE_FLOAT4_E1M2, "torch_npu.float4_e1m2"}}; + + const auto iter = TE_TYPE_TO_STRING_MAP.find(static_cast(dType)); + return iter != TE_TYPE_TO_STRING_MAP.end() ? iter->second : "Unknown dtype"; +} + +#ifndef BUILD_LIBTORCH +PyMethodDef* te_functions() +{ + return NPUTeMethods; +} +#endif +} +} diff --git a/torch_npu/csrc/transformer_engine/Init.h b/torch_npu/csrc/transformer_engine/Init.h new file mode 100644 index 0000000000..01b697c34b --- /dev/null +++ b/torch_npu/csrc/transformer_engine/Init.h @@ -0,0 +1,85 @@ +#pragma once + +#include +#ifndef BUILD_LIBTORCH +#include +#endif +#include "torch_npu/csrc/core/npu/NPUMacros.h" +#include "torch_npu/csrc/core/npu/NPUException.h" +#include "torch_npu/csrc/framework/utils/OpPreparation.h" +#include "third_party/acl/inc/acl/acl_base.h" + +namespace torch_npu { +namespace te { +const int g_teToAclOffset = 256; + +#define TE_ENUM_OFFSET(new_name, old_name) new_name = static_cast(old_name) + g_teToAclOffset, + +#ifndef BUILD_LIBTORCH +TORCH_NPU_API PyMethodDef* te_functions(); +#endif + +enum class DType { + TE_DT_UNDEFINED = -1, + TE_ENUM_OFFSET(TE_FLOAT, ACL_FLOAT) + TE_ENUM_OFFSET(TE_FLOAT16, ACL_FLOAT16) + TE_ENUM_OFFSET(TE_INT8, ACL_INT8) + TE_ENUM_OFFSET(TE_INT32, ACL_INT32) + TE_ENUM_OFFSET(TE_UINT8, ACL_UINT8) + TE_ENUM_OFFSET(TE_INT16, ACL_INT16) + TE_ENUM_OFFSET(TE_UINT16, ACL_UINT16) + TE_ENUM_OFFSET(TE_UINT32, ACL_UINT32) + TE_ENUM_OFFSET(TE_INT64, ACL_INT64) + TE_ENUM_OFFSET(TE_UINT64, ACL_UINT64) + TE_ENUM_OFFSET(TE_DOUBLE, ACL_DOUBLE) + TE_ENUM_OFFSET(TE_BOOL, ACL_BOOL) + TE_ENUM_OFFSET(TE_STRING, ACL_STRING) + TE_ENUM_OFFSET(TE_COMPLEX64, ACL_COMPLEX64) + TE_ENUM_OFFSET(TE_COMPLEX128, ACL_COMPLEX128) + TE_ENUM_OFFSET(TE_BF16, ACL_BF16) + TE_ENUM_OFFSET(TE_INT4, ACL_INT4) + TE_ENUM_OFFSET(TE_UINT1, ACL_UINT1) + TE_ENUM_OFFSET(TE_COMPLEX32, ACL_COMPLEX32) + TE_ENUM_OFFSET(TE_HIFLOAT8, ACL_HIFLOAT8) + TE_ENUM_OFFSET(TE_FLOAT8_E5M2, ACL_FLOAT8_E5M2) + TE_ENUM_OFFSET(TE_FLOAT8_E4M3FN, ACL_FLOAT8_E4M3FN) + TE_ENUM_OFFSET(TE_FLOAT8_E8M0, ACL_FLOAT8_E8M0) + TE_ENUM_OFFSET(TE_FLOAT6_E3M2, ACL_FLOAT6_E3M2) + TE_ENUM_OFFSET(TE_FLOAT6_E2M3, ACL_FLOAT6_E2M3) + TE_ENUM_OFFSET(TE_FLOAT4_E2M1, ACL_FLOAT4_E2M1) + TE_ENUM_OFFSET(TE_FLOAT4_E1M2, ACL_FLOAT4_E1M2) +}; + +inline bool IsTEDType(int64_t t) +{ + if (t >= g_teToAclOffset) { + return true; + } + return false; +} + +// Both torch_npu::te::DType and ScalarType are supported +inline aclDataType GetAclDataType(int64_t t) +{ + if (t >= g_teToAclOffset) { + return static_cast(t - g_teToAclOffset); + } + return at_npu::native::OpPreparation::convert_to_acl_data_type( + 
static_cast(t)); +} + +inline aclDataType GetAclDataType(DType t) +{ + return static_cast(static_cast(t) - g_teToAclOffset); +} + +inline at::ScalarType GetATenDType(int64_t t) +{ + aclDataType aclType = GetAclDataType(t); + return at_npu::native::OpPreparation::convert_to_scalar_type(aclType); +} + +const std::string TeDataTypeToString(int64_t dType); + +} // namespace te +} // namespace torch_npu diff --git a/torch_npu/csrc/transformer_engine/extension.h b/torch_npu/csrc/transformer_engine/extension.h new file mode 100644 index 0000000000..2adb9bc3bf --- /dev/null +++ b/torch_npu/csrc/transformer_engine/extension.h @@ -0,0 +1,14 @@ +#pragma once + +#include +#include "torch_npu/csrc/transformer_engine/Init.h" + +namespace torch_npu { +namespace te { +at::Tensor cast_to_fp8(const at::Tensor &input, int otype); + +void cast_to_fp8_noalloc(const at::Tensor &input, at::Tensor output, int otype); + +at::Tensor cast_from_fp8(const at::Tensor &input, int itype, int otype); +} +} diff --git a/torch_npu/onnx/wrapper_onnx_ops.py b/torch_npu/onnx/wrapper_onnx_ops.py index a261d17859..d69a875a38 100644 --- a/torch_npu/onnx/wrapper_onnx_ops.py +++ b/torch_npu/onnx/wrapper_onnx_ops.py @@ -244,8 +244,8 @@ class _NPUFormatCastOP(torch.autograd.Function): return torch.ops.npu.npu_format_cast(*args, **kwargs) @staticmethod - def symbolic(g, self: Tensor, acl_format: int): - return g.op("npu::NPUFormatCast", self, acl_format_i=acl_format) + def symbolic(g, self: Tensor, acl_format: int, customize_dtype: int = None): + return g.op("npu::NPUFormatCast", self, acl_format_i=acl_format, customize_dtype_i=customize_dtype) class _NPUSoftmaxCrossEntropyWithLogitsOP(torch.autograd.Function): @@ -730,6 +730,45 @@ class _NPUFusedInferAttentionScoreOP(torch.autograd.Function): key_antiquant_mode, value_antiquant_mode) +class _NPUFusedInferAttentionScoreOP(torch.autograd.Function): + + @staticmethod + def forward(ctx, *args, **kwargs): + return torch.ops.npu.fused_infer_attention_score(*args, **kwargs) + + @staticmethod + def symbolic(g, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, + pse_shift: Optional[Tensor], atten_mask: Optional[Tensor], + actual_seq_lengths: Optional[Tensor], + actual_seq_lengths_kv: Optional[Tensor], + dequant_scale1: Optional[Tensor], quant_scale1: Optional[Tensor], + dequant_scale2: Optional[Tensor], quant_scale2: Optional[Tensor], + quant_offset2: Optional[Tensor], antiquant_scale: Optional[Tensor], + antiquant_offset: Optional[Tensor], block_table: Optional[Tensor], + query_padding_size: Optional[Tensor], kv_padding_size: Optional[Tensor], + key_antiquant_scale: Optional[Tensor], key_antiquant_offset: Optional[Tensor], + value_antiquant_scale: Optional[Tensor], value_antiquant_offset: Optional[Tensor], + key_shared_prefix: Optional[Tensor], value_shared_prefix: Optional[Tensor], + actual_shared_prefix_len: Optional[Tensor], + query_rope: Optional[Tensor], + key_rope: Optional[Tensor], + num_heads: int = 1, scale: float = 1.0, + pre_tokens: int = 2147483647, next_tokens: int = 2147483647, + input_layout: str = "BSH", num_key_value_heads: int = 0, + sparse_mode: int = 0, inner_precise: int = 0, block_size: int = 0, + antiquant_mode: int = 0, softmax_lse_flag: bool = False, + key_antiquant_mode: int = 0, value_antiquant_mode: int = 0): + return g.op("npu::NPUFusedInferAttentionScoreOP", self, query, key, value, + pse_shift, atten_mask, actual_seq_lengths, actual_seq_lengths_kv, + dequant_scale1, quant_scale1, dequant_scale2, quant_scale2, quant_offset2, + antiquant_scale, 
antiquant_offset, block_table, query_padding_size, kv_padding_size, + key_antiquant_scale, key_antiquant_offset, value_antiquant_scale, value_antiquant_offset, + key_shared_prefix, value_shared_prefix, actual_shared_prefix_len, query_rope, key_rope, + num_heads, scale, pre_tokens, next_tokens, input_layout, num_key_value_heads, + sparse_mode, inner_precise, block_size, antiquant_mode, softmax_lse_flag, + key_antiquant_mode, value_antiquant_mode) + + class _NPUMaskedSoftmaxWithRelPosBiasOP(torch.autograd.Function): @staticmethod @@ -1097,8 +1136,8 @@ def _wrapper_npu_deformable_conv2d(inputs, weight, offset, bias, kernel_size, st padding, dilation, groups, deformable_groups, modulated) -def _wrapper_npu_format_cast(self, acl_format): - return _NPUFormatCastOP.apply(self, acl_format) +def _wrapper_npu_format_cast(self, acl_format, customize_dtype=None): + return _NPUFormatCastOP.apply(self, acl_format, customize_dtype) def _wrapper_npu_softmax_cross_entropy_with_logits(self, labels): @@ -1316,6 +1355,29 @@ def _wrapper_npu_fused_infer_attention_score(self, query, key, value, pse_shift, key_antiquant_mode, value_antiquant_mode) +def _wrapper_npu_fused_infer_attention_score(self, query, key, value, pse_shift, atten_mask, actual_seq_lengths, actual_seq_lengths_kv, + dequant_scale1, quant_scale1, dequant_scale2, quant_scale2, quant_offset2, + antiquant_scale, + antiquant_offset, block_table, query_padding_size, kv_padding_size, + num_heads, scale, pre_tokens, next_tokens, input_layout, + key_antiquant_scale, key_antiquant_offset, value_antiquant_scale, value_antiquant_offset, + key_shared_prefix, value_shared_prefix, actual_shared_prefix_len, query_rope, key_rope, + num_key_value_heads, + sparse_mode, inner_precise, block_size, antiquant_mode, softmax_lse_flag, + key_antiquant_mode, value_antiquant_mode): + return _NPUFusedInferAttentionScoreOP.apply(self, query, key, value, pse_shift, atten_mask, actual_seq_lengths, + actual_seq_lengths_kv, + dequant_scale1, quant_scale1, dequant_scale2, quant_scale2, + quant_offset2, antiquant_scale, + antiquant_offset, block_table, query_padding_size, kv_padding_size, + key_antiquant_scale, key_antiquant_offset, value_antiquant_scale, value_antiquant_offset, + key_shared_prefix, value_shared_prefix, actual_shared_prefix_len, query_rope, key_rope, + num_heads, scale, pre_tokens, next_tokens, input_layout, + num_key_value_heads, + sparse_mode, inner_precise, block_size, antiquant_mode, softmax_lse_flag, + key_antiquant_mode, value_antiquant_mode) + + def _wrapper_npu_mm_all_reduce_base(x1, x2, hcom, reduce_op, bias, antiquant_scale, antiquant_offset, x3, dequant_scale, pertoken_scale, comm_quant_scale_1, comm_quant_scale_2, antiquant_group_size, comm_turn): diff --git a/torch_npu/utils/hif8_tensor.py b/torch_npu/utils/hif8_tensor.py new file mode 100644 index 0000000000..f8dc2d4742 --- /dev/null +++ b/torch_npu/utils/hif8_tensor.py @@ -0,0 +1,635 @@ +"""Tensor class with HIF8 data""" +from __future__ import annotations + +__all__ = ["HiFloat8Tensor"] + +from typing import Any, Dict, Optional, Tuple, Union + +import torch +from torch.utils._pytree import tree_map +import torch_npu +from torch_npu.utils._error_code import ErrCode, pta_error + + +# init transformer engine +torch_npu._C._te_init() + +tex = torch_npu._C._te +aten = torch.ops.aten + +NPU_TE_DType = { + torch.uint8: tex.DType.uint8, + torch.int32: tex.DType.int32, + torch.float32: tex.DType.float32, + torch.half: tex.DType.float16, + torch.bfloat16: tex.DType.bfloat16, +} + + +class 
_FromHiFloat8Func(torch.autograd.Function): + """Cast from HIF8 to other dtype""" + + @staticmethod + def forward( + _ctx: torch.autograd.function.FunctionCtx, # unused + tensor: HiFloat8Tensor, + dtype: Optional[torch.dtype] = None, + ) -> torch.Tensor: + if dtype is None: + dtype = tensor.dtype + data = tensor._data.contiguous().view(1, -1).detach() + out = tex.cast_from_fp8( + data, + tex.DType.hifloat8, + NPU_TE_DType[dtype], + ) + out = out.view(tensor.size()) + return out + + @staticmethod + def backward( + _ctx: torch.autograd.function.FunctionCtx, # unused + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + # Assume that we want gradients in full precision + return grad, None + + +class _ToHiFloat8Func(torch.autograd.Function): + """Cast to HIF8 from other dtype""" + + @staticmethod + def forward( + _ctx: torch.autograd.function.FunctionCtx, # unused + tensor: torch.Tensor, + scale: Optional[torch.Tensor] = None + ) -> HiFloat8Tensor: + + # Check input tensor TODO + tensor = tensor.contiguous().npu().detach() + if tensor.dtype not in (torch.float32, torch.bfloat16, torch.float16): + tensor = tensor.float() + + # Check scale + if not isinstance(scale, torch.Tensor): + if scale is None: + scale = 1 + scale = torch.full( + [1], + scale, + dtype=torch.float32, + device=tensor.device, + ) + if scale.numel() != 1: + raise ValueError( + "Attempted to initialize HiFloat8Tensor with invalid scale tensor" + + pta_error(ErrCode.VALUE) + ) + scale = scale.to(device=tensor.device, dtype=torch.float32) + + # Cast data to HIF8 + data = tex.cast_to_fp8( + tensor.view(1, -1), + tex.DType.hifloat8, + ) + data = data.view(tensor.size()) + + # Construct HIF8 tensor + return HiFloat8Tensor( + data=data, + fp8_scale=scale, + dtype=tensor.dtype, + ) + + @staticmethod + def backward( + _ctx: torch.autograd.function.FunctionCtx, # unused + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + # Assume that we want gradients in full precision + return grad, None, None + + +class _IdentityFunc(torch.autograd.Function): + """Identity function + + If constructor keyword-arguments are provided, then construct a + new HiFloat8Tensor using the provided tensor's attributes. + + """ + + @staticmethod + def forward( + ctx, + tensor: HiFloat8Tensor, + init_kwargs: Optional[Dict[str, Any]] = None, + ) -> torch.Tensor: + + # Return input tensor if constructor kwargs are not provided + ctx.input_dtype = tensor.dtype + if init_kwargs is None: + return tensor + + # Construct new tensor if constructor kwargs are provided + default_kwargs = dict( + data=tensor._data, + fp8_scale=tensor._scale, + dtype=tensor.dtype, + ) + for key, val in default_kwargs.items(): + if key not in init_kwargs: + init_kwargs[key] = val + return HiFloat8Tensor(**init_kwargs) + + @staticmethod + def backward(ctx, grad): + return grad.to(ctx.input_dtype), None + + +class _ViewFunc(torch.autograd.Function): + """View function + + View the HiFloat8Tensor using the provided shape. 
+ + """ + + @staticmethod + def forward( + ctx, + tensor: torch.Tensor, + shape: Tuple[int] = None, + ) -> torch.Tensor: + + # Return input tensor if shape is not provided + ctx.shape = tensor.shape + if shape is None: + return tensor + + # Construct new tensor if shape is provided + if isinstance(tensor, HiFloat8Tensor): + return HiFloat8Tensor.make_like( + tensor, + data=tensor._data.view(*shape), + ) + return tensor.view(*shape) + + @staticmethod + def backward( + ctx, + grad: torch.Tensor, + ) -> Tuple[Union[torch.Tensor, None], ...]: + + if isinstance(grad, HiFloat8Tensor): + dgrad = HiFloat8Tensor.make_like( + grad, + data=grad._data.view(ctx.shape), + ) + return dgrad, None + return grad.view(ctx.shape), None + + +class _ReshapeFunc(torch.autograd.Function): + """Reshape function + + Reshape the HiFloat8Tensor using the provided shape. + + """ + + @staticmethod + def forward( + ctx, + tensor: torch.Tensor, + shape: Tuple[int] = None, + ) -> torch.Tensor: + + # Return input tensor if shape is not provided + ctx.shape = tensor.shape + if shape is None: + return tensor + + # Construct new tensor if shape is provided + if isinstance(tensor, HiFloat8Tensor): + return HiFloat8Tensor.make_like( + tensor, + data=tensor._data.reshape(*shape), + ) + return tensor.reshape(*shape) + + @staticmethod + def backward( + ctx, + grad: torch.Tensor, + ) -> Tuple[Union[torch.Tensor, None], ...]: + + if isinstance(grad, HiFloat8Tensor): + dgrad = HiFloat8Tensor.make_like( + grad, + data=grad._data.reshape(ctx.shape), + ) + return dgrad, None + return grad.reshape(ctx.shape), None + + +class _TransposeFunc(torch.autograd.Function): + """Transpose function + + Transpose the HiFloat8Tensor. + + """ + + @staticmethod + def forward(ctx, tensor, dim0, dim1): + ctx.save_for_backward(dim0, dim1) + if isinstance(tensor, HiFloat8Tensor): + return HiFloat8Tensor.make_like( + tensor, + data=tensor._data.transpose(dim0, dim1), + ) + return tensor.transpose(dim0, dim1) + + @staticmethod + def backward(ctx, grad): + dim0, dim1 = ctx.saved_tensors + if isinstance(grad, HiFloat8Tensor): + dgrad = HiFloat8Tensor.make_like( + grad, + data=grad._data.transpose(dim0, dim1), + ) + return dgrad, None + return grad.transpose(dim0, dim1), None, None + + +class HiFloat8Tensor(torch.Tensor): + """Experimental tensor class with HIF8 data + + The tensor presents as having a standard, higher-precision dtype, + but the data itself is (scaled) HIF8. For most tensor operations, + the data will be cast to the nominal dtype before performing the + operation. + + Parameters + ---------- + data: torch.Tensor + Raw HIF8 data in a uint8 tensor + fp8_scale: torch.Tensor + Reciprocal of the scaling factor applied when + casting to HIF8, i.e. the scaling factor that must + be applied when casting from HIF8 to higher + precision. Can be inferred from hif8_meta if + provided. + dtype: torch.dtype, default = torch.float32 + Nominal tensor datatype. 
+ + """ + + def __new__( + cls, + *, + data: torch.Tensor, + fp8_scale: Optional[torch.Tensor] = None, + dtype: torch.dtype = torch.float32, + ): + # Check that data buffer is valid + if data.element_size() != 1: + raise ValueError( + f"HiFloat8Tensor requires data buffer with 8-bit dtype (got dtype={data.dtype})" + + pta_error(ErrCode.VALUE) + ) + if data.requires_grad: + raise ValueError( + "HiFloat8Tensor requires non-differentiable data buffer" + + pta_error(ErrCode.VALUE) + ) + if not data.is_npu: + data = data.npu() + + # Initialize tensor object + self = torch.Tensor._make_wrapper_subclass( + cls, + data.size(), + strides=data.stride(), + storage_offset=data.storage_offset(), + dtype=dtype, + layout=data.layout, + requires_grad=data.requires_grad, + device=data.device, + ) + self._data: torch.Tensor = data + + if fp8_scale is None or not isinstance(fp8_scale, torch.Tensor): + fp8_scale = torch.full( + [1], + 1, + dtype=torch.float32, + device=self._data.device, + ) + if fp8_scale.numel() != 1: + raise ValueError( + "Attempted to initialize HiFloat8Tensor with invalid scale tensor" + + pta_error(ErrCode.VALUE) + ) + if fp8_scale.dim() != 1: + fp8_scale = fp8_scale.reshape(1) + if fp8_scale.device != self._data.device or fp8_scale.dtype != torch.float32: + fp8_scale = fp8_scale.to( + device=self._data.device, + dtype=torch.float32, + ) + self._scale: Optional[torch.Tensor] = fp8_scale + + return self + + @classmethod + def make_like( + cls, + tensor: HiFloat8Tensor, + *, + data: torch.Tensor, + **kwargs, + ) -> HiFloat8Tensor: + """Use attributes of a HiFloat8Tensor to create another HiFloat8Tensor + + See constructor for list of keyword arguments. + + """ + default_kwargs = dict( + fp8_scale=tensor._scale, + dtype=tensor.dtype, + ) + for key, val in default_kwargs.items(): + if key not in kwargs: + kwargs[key] = val + return HiFloat8Tensor(data=data, **kwargs) + + def __repr__(self): + return ( + "HiFloat8Tensor(" + f"scale={self._scale.item()}, " + f"data={self.from_hifloat8(dtype=self.dtype)}" + ")" + ) + + def from_hifloat8(self, dtype: Optional[torch.dtype] = None) -> torch.Tensor: + """ + Construct plain PyTorch tensor from HiFloat8Tensor + + By default the resulting tensor's dtype is the + HiFloat8Tensor's nominal dtype. + """ + return _FromHiFloat8Func.apply(self, dtype) + + @classmethod + def to_hifloat8( + cls, + tensor: torch.Tensor, + *, + scale: Optional[torch.Tensor] = None + ): + """Construct HiFloat8Tensor from plain PyTorch tensor""" + return _ToHiFloat8Func.apply( + tensor, + scale + ) + + def float(self) -> torch.Tensor: + return self.from_hifloat8(dtype=torch.float32) + + def bfloat16(self) -> torch.Tensor: + return self.from_hifloat8(dtype=torch.bfloat16) + + def half(self) -> torch.Tensor: + return self.from_hifloat8(dtype=torch.float16) + + def cpu(self) -> torch.Tensor: + return self.from_hifloat8().cpu() + + def clone(self) -> HiFloat8Tensor: + return _IdentityFunc.apply(self, {"data": self._data.detach().clone()}) + + def view(self, *shape: Tuple[int]) -> HiFloat8Tensor: + return _ViewFunc.apply(self, shape) + + def reshape(self, *shape: Tuple[int]) -> HiFloat8Tensor: + return _ReshapeFunc.apply(self, shape) + + def contiguous( + self, + *, + memory_format: torch.memory_format = torch.contiguous_format, + ) -> HiFloat8Tensor: + """Returns tensor with data in provided memory format + + Returns `self` if data is already in correct memory format. 
+ + """ + if self._data.is_contiguous(memory_format=memory_format): + return self + return _IdentityFunc.apply( + self, + {"data": self._data.detach().contiguous(memory_format=memory_format)}, + ) + + def to_dtype(self, dtype: torch.dtype) -> HiFloat8Tensor: + """Create `HiFloat8Tensor` with given nominal dtype + + The new tensor has the same underlying HIF8 data. + + """ + return HiFloat8Tensor.make_like( + self, + data=self._data, + dtype=dtype, + ) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs=None): + + # In-place copy op + if func == aten.copy_.default: + + # Check tensors + dst = args[0] + src = args[1] + if not isinstance(dst, torch.Tensor): + raise RuntimeError( + "Attempted to copy into something that isn't a PyTorch tensor" + + pta_error(ErrCode.TYPE) + ) + if not isinstance(src, torch.Tensor): + raise RuntimeError( + "Attempted to copy from something that isn't a PyTorch tensor" + + pta_error(ErrCode.TYPE) + ) + + # Special handling based on which tensors are HIF8 + dst_is_hif8 = isinstance(dst, HiFloat8Tensor) + src_is_hif8 = isinstance(src, HiFloat8Tensor) + if dst_is_hif8 and src_is_hif8: + # Directly copy HIF8 data if possible + dst._data.copy_(src._data) + dst._scale.copy_(src._scale.detach()) + + elif not dst_is_hif8 and src_is_hif8: + # Cast source tensor to higher precision + dst.copy_(src.from_hifloat8()) + + elif dst_is_hif8 and not src_is_hif8: + # Make sure input is in expected format + src = src.expand(dst.size()) + src = src.to( + device=dst.device, + memory_format=torch.contiguous_format, + ) + + # Cast to HIF8 + if not dst._data.is_contiguous(): + raise RuntimeError( + "Transformer Engine cast kernels require contiguous data" + + pta_error(ErrCode.INTERNAL) + ) + tex.cast_to_fp8_noalloc( + src.view(1, -1), + dst._data.view(1, -1), + tex.DType.hifloat8, + ) + else: + # Invalid case + raise RuntimeError( + "Using HiFloat8Tensor copy logic, but no HiFloat8Tensor found" + + pta_error(ErrCode.INTERNAL) + ) + + # Nothing to return for in-place ops + return None + + # Slice op + if func == aten.slice.Tensor: + tensor = args[0] + data = tensor._data + data_slice = data.__torch_dispatch__( + func, + types, + [data] + list(args[1:]), + kwargs, + ) + return HiFloat8Tensor.make_like(tensor, data=data_slice) + + # Detach op + if func == aten.detach.default: + # Simply return a new HiFloat8Tensor with the same attrs + return HiFloat8Tensor.make_like( + args[0], + data=args[0]._data, + ) + + # View op + if func == aten.view.default: + tensor = args[0] + data = tensor._data + data_view = data.__torch_dispatch__( + func, + types, + [data] + list(args[1:]), + kwargs, + ) + return HiFloat8Tensor.make_like( + tensor, + data=data_view, + ) + + def maybe_unwrap(t): + if isinstance(t, HiFloat8Tensor): + return t.from_hifloat8() + return t + + def maybe_update_inplace(arg, new_arg, schema_arg): + """Update values of HIF8 tensors + + Keep the same HIF8 scaling factors. 
+ + """ + check_args = isinstance(arg, HiFloat8Tensor) and isinstance(new_arg, torch.Tensor) + check_schema = ( + hasattr(schema_arg, "alias_info") + and hasattr(schema_arg.alias_info, "is_write") + and schema_arg.alias_info.is_write + ) + + if check_args and check_schema: + arg.copy_(new_arg) + + # In-place op + if func._schema.is_mutable: + # Cast to higher precision, perform op, and cast values + # back to original HIF8 buffers + new_args = tree_map(maybe_unwrap, args) + new_kwargs = tree_map(maybe_unwrap, kwargs) + schema_args = func._schema.arguments + args_len = len(args) + out = super().__torch_dispatch__(func, types, new_args, new_kwargs) + for arg, new_arg, schema_arg in zip(args, new_args, schema_args): + maybe_update_inplace(arg, new_arg, schema_arg) + for kwarg, new_kwarg, schema_arg in zip(kwargs, new_kwargs, schema_args[args_len:]): + if not (kwarg == new_kwarg == schema_arg.name): + raise ValueError('name of the kw argument should match' + pta_error(ErrCode.VALUE)) + maybe_update_inplace(kwargs[kwarg], new_kwargs[new_kwarg], schema_arg) + return None + + # Default op + # Note: cast to higher precision and perform op + args = tree_map(maybe_unwrap, args) + if kwargs is not None: + kwargs = tree_map(maybe_unwrap, kwargs) + out = super().__torch_dispatch__(func, types, args, kwargs) + return out + + @classmethod + def _make_in_reduce_ex( + cls, + data: torch.Tensor, + fp8_scale: torch.Tensor, + dtype: torch.dtype, + ) -> HiFloat8Tensor: + """Build HiFloat8Tensor, for use in __reduce__ + + __reduce_ex__ assumes object constructor has positional + arguments. + + """ + return HiFloat8Tensor( + data=data, + fp8_scale=fp8_scale, + dtype=dtype, + ) + + def __reduce_ex__(self, protocol: int) -> tuple: + """Custom pickling to remove references to HIF8 metadata objects""" + return ( + HiFloat8Tensor._make_in_reduce_ex, + (self._data, self._scale, self.dtype), + ) + + def _get_data(self) -> HiFloat8Tensor: + """Get tensor data property""" + return super().data + + def _set_data(self, tensor: torch.Tensor) -> None: + """Set tensor data property + + Cast tensor to HIF8 and store in HIF8 buffer. 
+ + """ + with torch.no_grad(): + self.copy_(tensor) + + # Cast to HIF8 when setting HiFloat8Tensor.data + data = property(_get_data, _set_data) + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + return torch._C._disabled_torch_function_impl(func, types, args, kwargs) + + def transpose(self, dim0, dim1): + return _TransposeFunc.apply(self, dim0, dim1) -- Gitee From 8b919205c0948e3901d02123a563d218a68280a0 Mon Sep 17 00:00:00 2001 From: chuboning Date: Fri, 6 Jun 2025 15:00:38 +0800 Subject: [PATCH 02/11] Remove scale from HiFloat8Tensor --- torch_npu/utils/hif8_tensor.py | 64 +++------------------------------- 1 file changed, 4 insertions(+), 60 deletions(-) diff --git a/torch_npu/utils/hif8_tensor.py b/torch_npu/utils/hif8_tensor.py index f8dc2d4742..353bf1bf51 100644 --- a/torch_npu/utils/hif8_tensor.py +++ b/torch_npu/utils/hif8_tensor.py @@ -62,7 +62,6 @@ class _ToHiFloat8Func(torch.autograd.Function): def forward( _ctx: torch.autograd.function.FunctionCtx, # unused tensor: torch.Tensor, - scale: Optional[torch.Tensor] = None ) -> HiFloat8Tensor: # Check input tensor TODO @@ -70,23 +69,6 @@ class _ToHiFloat8Func(torch.autograd.Function): if tensor.dtype not in (torch.float32, torch.bfloat16, torch.float16): tensor = tensor.float() - # Check scale - if not isinstance(scale, torch.Tensor): - if scale is None: - scale = 1 - scale = torch.full( - [1], - scale, - dtype=torch.float32, - device=tensor.device, - ) - if scale.numel() != 1: - raise ValueError( - "Attempted to initialize HiFloat8Tensor with invalid scale tensor" - + pta_error(ErrCode.VALUE) - ) - scale = scale.to(device=tensor.device, dtype=torch.float32) - # Cast data to HIF8 data = tex.cast_to_fp8( tensor.view(1, -1), @@ -97,7 +79,6 @@ class _ToHiFloat8Func(torch.autograd.Function): # Construct HIF8 tensor return HiFloat8Tensor( data=data, - fp8_scale=scale, dtype=tensor.dtype, ) @@ -107,7 +88,7 @@ class _ToHiFloat8Func(torch.autograd.Function): grad: torch.Tensor, ) -> Tuple[Optional[torch.Tensor], ...]: # Assume that we want gradients in full precision - return grad, None, None + return grad, None class _IdentityFunc(torch.autograd.Function): @@ -133,7 +114,6 @@ class _IdentityFunc(torch.autograd.Function): # Construct new tensor if constructor kwargs are provided default_kwargs = dict( data=tensor._data, - fp8_scale=tensor._scale, dtype=tensor.dtype, ) for key, val in default_kwargs.items(): @@ -271,12 +251,6 @@ class HiFloat8Tensor(torch.Tensor): ---------- data: torch.Tensor Raw HIF8 data in a uint8 tensor - fp8_scale: torch.Tensor - Reciprocal of the scaling factor applied when - casting to HIF8, i.e. the scaling factor that must - be applied when casting from HIF8 to higher - precision. Can be inferred from hif8_meta if - provided. dtype: torch.dtype, default = torch.float32 Nominal tensor datatype. 
@@ -286,7 +260,6 @@ class HiFloat8Tensor(torch.Tensor): cls, *, data: torch.Tensor, - fp8_scale: Optional[torch.Tensor] = None, dtype: torch.dtype = torch.float32, ): # Check that data buffer is valid @@ -316,27 +289,6 @@ class HiFloat8Tensor(torch.Tensor): ) self._data: torch.Tensor = data - if fp8_scale is None or not isinstance(fp8_scale, torch.Tensor): - fp8_scale = torch.full( - [1], - 1, - dtype=torch.float32, - device=self._data.device, - ) - if fp8_scale.numel() != 1: - raise ValueError( - "Attempted to initialize HiFloat8Tensor with invalid scale tensor" - + pta_error(ErrCode.VALUE) - ) - if fp8_scale.dim() != 1: - fp8_scale = fp8_scale.reshape(1) - if fp8_scale.device != self._data.device or fp8_scale.dtype != torch.float32: - fp8_scale = fp8_scale.to( - device=self._data.device, - dtype=torch.float32, - ) - self._scale: Optional[torch.Tensor] = fp8_scale - return self @classmethod @@ -353,7 +305,6 @@ class HiFloat8Tensor(torch.Tensor): """ default_kwargs = dict( - fp8_scale=tensor._scale, dtype=tensor.dtype, ) for key, val in default_kwargs.items(): @@ -364,7 +315,6 @@ class HiFloat8Tensor(torch.Tensor): def __repr__(self): return ( "HiFloat8Tensor(" - f"scale={self._scale.item()}, " f"data={self.from_hifloat8(dtype=self.dtype)}" ")" ) @@ -381,14 +331,11 @@ class HiFloat8Tensor(torch.Tensor): @classmethod def to_hifloat8( cls, - tensor: torch.Tensor, - *, - scale: Optional[torch.Tensor] = None + tensor: torch.Tensor ): """Construct HiFloat8Tensor from plain PyTorch tensor""" return _ToHiFloat8Func.apply( - tensor, - scale + tensor ) def float(self) -> torch.Tensor: @@ -467,7 +414,6 @@ class HiFloat8Tensor(torch.Tensor): if dst_is_hif8 and src_is_hif8: # Directly copy HIF8 data if possible dst._data.copy_(src._data) - dst._scale.copy_(src._scale.detach()) elif not dst_is_hif8 and src_is_hif8: # Cast source tensor to higher precision @@ -587,7 +533,6 @@ class HiFloat8Tensor(torch.Tensor): def _make_in_reduce_ex( cls, data: torch.Tensor, - fp8_scale: torch.Tensor, dtype: torch.dtype, ) -> HiFloat8Tensor: """Build HiFloat8Tensor, for use in __reduce__ @@ -598,7 +543,6 @@ class HiFloat8Tensor(torch.Tensor): """ return HiFloat8Tensor( data=data, - fp8_scale=fp8_scale, dtype=dtype, ) @@ -606,7 +550,7 @@ class HiFloat8Tensor(torch.Tensor): """Custom pickling to remove references to HIF8 metadata objects""" return ( HiFloat8Tensor._make_in_reduce_ex, - (self._data, self._scale, self.dtype), + (self._data, self.dtype), ) def _get_data(self) -> HiFloat8Tensor: -- Gitee From d922412b38740553af105055bfd646712f28c565 Mon Sep 17 00:00:00 2001 From: chuboning Date: Mon, 9 Jun 2025 15:24:16 +0800 Subject: [PATCH 03/11] Fix npu_format_cast --- torch_npu/csrc/aten/common/CopyKernel.cpp | 6 +-- .../csrc/aten/common/FormatCastHelper.cpp | 4 +- .../csrc/aten/common/FormatCastKernelNpu.cpp | 41 +++++++++++++------ .../csrc/aten/common/LocalScalarDenseNpu.cpp | 24 +++++------ torch_npu/csrc/aten/common/NpuFastReshape.cpp | 2 +- torch_npu/csrc/aten/common/ResizeNpu.cpp | 2 +- torch_npu/csrc/aten/npu_native_functions.yaml | 3 +- .../aten/ops/FlattenDenseTensorsKernelNpu.cpp | 3 +- torch_npu/csrc/core/NPUSerialization.cpp | 2 +- torch_npu/csrc/core/npu/NPUFormat.cpp | 2 +- torch_npu/csrc/distributed/Init.cpp | 3 +- .../csrc/distributed/ProcessGroupHCCL.cpp | 2 +- torch_npu/csrc/distributed/reducer.cpp | 2 +- torch_npu/csrc/framework/utils/NpuUtils.cpp | 2 +- .../csrc/framework/utils/OpPreparation.cpp | 10 ++--- torch_npu/utils/hif8_tensor.py | 4 +- 16 files changed, 64 insertions(+), 48 deletions(-) diff 
--git a/torch_npu/csrc/aten/common/CopyKernel.cpp b/torch_npu/csrc/aten/common/CopyKernel.cpp index 03af121f76..feb0f9b887 100644 --- a/torch_npu/csrc/aten/common/CopyKernel.cpp +++ b/torch_npu/csrc/aten/common/CopyKernel.cpp @@ -76,7 +76,7 @@ void copy_d2d_dtype_format(at::Tensor& self, const at::Tensor& src, bool non_blo at::Tensor src_4D = FormatCastHelper::ApplyBaseFormatTensorBy(src); at::Tensor dst_4D = FormatCastHelper::ApplyBaseFormatTensorBy(self); copy_d2d_dtype_baseformat(dst_4D, src_4D, non_blocking); - NPUNativeFunctions::npu_format_cast_(self, dst_4D, c10::nullopt); + NPUNativeFunctions::npu_format_cast_(self, dst_4D); return; } copy_d2d_dtype_baseformat(self, src, non_blocking); @@ -213,7 +213,7 @@ void copy_h2d(at::Tensor& self, const at::Tensor& src, bool non_blocking) if (!FormatHelper::IsBaseFormatType(self)) { at::Tensor dst = OpPreparation::ApplyTensorWithSizes(self.sizes(), self.options()); copy_h2d_baseformat(dst, src, non_blocking, true); - NPUNativeFunctions::npu_format_cast_(self, dst, c10::nullopt); + NPUNativeFunctions::npu_format_cast_(self, dst); return; } copy_h2d_baseformat(self, src, non_blocking); @@ -326,7 +326,7 @@ void copy_d2d_dtype(at::Tensor& self, const at::Tensor& src, bool non_blocking) } at::Tensor dst_4D = FormatCastHelper::ApplyBaseFormatTensorBy(self); copy_d2d_dtype_baseformat(dst_4D, src_4D, non_blocking); - NPUNativeFunctions::npu_format_cast_(self, dst_4D, c10::nullopt); + NPUNativeFunctions::npu_format_cast_(self, dst_4D); return; } copy_d2d_dtype_format(self, src, non_blocking); diff --git a/torch_npu/csrc/aten/common/FormatCastHelper.cpp b/torch_npu/csrc/aten/common/FormatCastHelper.cpp index ade6a42686..2f61a7c782 100644 --- a/torch_npu/csrc/aten/common/FormatCastHelper.cpp +++ b/torch_npu/csrc/aten/common/FormatCastHelper.cpp @@ -71,13 +71,13 @@ bool FormatCastHelper::format_cast_between_group( at::Tensor FormatCastHelper::ApplyBaseFormatTensorBy(const at::Tensor& src) { auto format = FormatHelper::GetBaseFormat(src); - return custom_ops::npu_format_cast(src, format, c10::nullopt); + return custom_ops::npu_format_cast(src, format); } at::Tensor& FormatCastHelper::CovertSelfToBaseFormat(at::Tensor& src) { auto format = FormatHelper::GetBaseFormat(src); - return NPUNativeFunctions::npu_format_cast_(src, format, c10::nullopt); + return NPUNativeFunctions::npu_format_cast_(src, format); } } // namespace native diff --git a/torch_npu/csrc/aten/common/FormatCastKernelNpu.cpp b/torch_npu/csrc/aten/common/FormatCastKernelNpu.cpp index 9bdaadd0f7..2d83d17634 100644 --- a/torch_npu/csrc/aten/common/FormatCastKernelNpu.cpp +++ b/torch_npu/csrc/aten/common/FormatCastKernelNpu.cpp @@ -181,16 +181,6 @@ at::Tensor npu_format_cast_impl( int64_t acl_format) { auto src_desc = torch_npu::NPUBridge::GetNpuStorageImpl(src)->npu_desc_; - if (src_desc.npu_format_ == acl_format) { - ASCEND_LOGD("no need to do format cast"); - return src; - } - if (FormatHelper::IsBaseFormatType(src) && - FormatHelper::IsBaseFormatType(static_cast(acl_format))) { - FormatCastHelper::format_cast_as_base_format(src, static_cast(acl_format)); - return src; - } - at::Tensor dst = OpPreparation::ApplyTensorWithFormat( src_desc.base_sizes_, src.options(), acl_format); @@ -259,9 +249,33 @@ int64_t NPUNativeFunctions::get_npu_format(const at::Tensor& self) return src_desc.npu_format_; } +at::Tensor NPUNativeFunctions::_npu_format_cast(const at::Tensor& self, int64_t acl_format) +{ + auto src_desc = torch_npu::NPUBridge::GetNpuStorageImpl(self)->npu_desc_; + if (src_desc.npu_format_ 
== acl_format) { + ASCEND_LOGD("no need to do format cast"); + return self; + } + if (FormatHelper::IsBaseFormatType(self) && + FormatHelper::IsBaseFormatType(static_cast(acl_format))) { + FormatCastHelper::format_cast_as_base_format(self, static_cast(acl_format)); + return self; + } + auto [useAclnn, outFormat, StorageShape] = MaybeUseAclnnNpuFormatCast(self, acl_format, c10::nullopt); + if (useAclnn == false) { + return npu_format_cast_impl(self, acl_format); + } + return format_cast_impl_out_npu_aclnn(self, outFormat, StorageShape); +} + at::Tensor NPUNativeFunctions::_npu_format_cast(const at::Tensor& self, int64_t acl_format, - c10::optional customize_dtype) + int64_t customize_dtype) { + auto src_desc = torch_npu::NPUBridge::GetNpuStorageImpl(self)->npu_desc_; + if (src_desc.npu_format_ == acl_format) { + ASCEND_LOGD("no need to do format cast"); + return self; + } if (FormatHelper::IsBaseFormatType(self) && FormatHelper::IsBaseFormatType(static_cast(acl_format))) { FormatCastHelper::format_cast_as_base_format(self, static_cast(acl_format)); @@ -282,7 +296,10 @@ at::Tensor NPUNativeFunctions::npu_format_cast(const at::Tensor& self, int64_t a ASCEND_LOGD("no need to do format cast"); return self; } - return custom_ops::_npu_format_cast(self, acl_format, customize_dtype); + if (customize_dtype.has_value()) { + return custom_ops::_npu_format_cast(self, acl_format, customize_dtype.value()); + } + return custom_ops::_npu_format_cast(self, acl_format); } } // namespace native diff --git a/torch_npu/csrc/aten/common/LocalScalarDenseNpu.cpp b/torch_npu/csrc/aten/common/LocalScalarDenseNpu.cpp index 5f9fcdc03b..685f907653 100644 --- a/torch_npu/csrc/aten/common/LocalScalarDenseNpu.cpp +++ b/torch_npu/csrc/aten/common/LocalScalarDenseNpu.cpp @@ -11,22 +11,22 @@ namespace at_npu { namespace native { #define AT_DISPATCH_CASE_ALL_TYPES_AND5( \ - SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5,...) \ - AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__) \ - AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \ - AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \ - AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \ - AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \ - AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__) + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5, ...) \ + AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__) #define AT_DISPATCH_ALL_TYPES_AND5( \ SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5, TYPE, NAME, ...) 
\ - AT_DISPATCH_SWITCH( \ - TYPE, \ - NAME, \ - AT_DISPATCH_CASE_ALL_TYPES_AND5( \ - SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5, __VA_ARGS__)) + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_ALL_TYPES_AND5( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5, __VA_ARGS__)) c10::Scalar NPUNativeFunctions::_local_scalar_dense(const at::Tensor& self) diff --git a/torch_npu/csrc/aten/common/NpuFastReshape.cpp b/torch_npu/csrc/aten/common/NpuFastReshape.cpp index dfd7d69e0a..c07136aa50 100644 --- a/torch_npu/csrc/aten/common/NpuFastReshape.cpp +++ b/torch_npu/csrc/aten/common/NpuFastReshape.cpp @@ -32,7 +32,7 @@ void npu_fast_reshape_(at::Tensor& tensor) // refresh matadata to input tensor StorageDescHelper::ReflushDescBySelf(tensor); auto base_format = InferFormat::GuessBaseFormat(tensor.sizes()); - NPUNativeFunctions::npu_format_cast_(tensor, base_format, c10::nullopt); + NPUNativeFunctions::npu_format_cast_(tensor, base_format); } } // namespace native diff --git a/torch_npu/csrc/aten/common/ResizeNpu.cpp b/torch_npu/csrc/aten/common/ResizeNpu.cpp index e329ad41e8..af49fa1c33 100644 --- a/torch_npu/csrc/aten/common/ResizeNpu.cpp +++ b/torch_npu/csrc/aten/common/ResizeNpu.cpp @@ -46,7 +46,7 @@ const at::Tensor& NPUNativeFunctions::resize_( // no need to reflush NpuStorageDesc here. at::Tensor temp_self = self; if (!FormatHelper::IsBaseFormatType(self)) { - NPUNativeFunctions::npu_format_cast_(temp_self, FormatHelper::GetBaseFormat(self), c10::nullopt); + NPUNativeFunctions::npu_format_cast_(temp_self, FormatHelper::GetBaseFormat(self)); } auto* self_ = self.unsafeGetTensorImpl(); resize_impl_npu_(self_, size, c10::nullopt); diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml index 3bc305d03f..b186df7651 100644 --- a/torch_npu/csrc/aten/npu_native_functions.yaml +++ b/torch_npu/csrc/aten/npu_native_functions.yaml @@ -84,7 +84,8 @@ custom: - func: get_storage_size(Tensor self) -> int - func: npu_format_cast(Tensor self, int acl_format, int? customize_dtype=None) -> Tensor exposed: True - - func: _npu_format_cast(Tensor self, int acl_format, int? customize_dtype=None) -> Tensor + - func: _npu_format_cast(Tensor self, int acl_format) -> Tensor + - func: _npu_format_cast.aclnn(Tensor self, int acl_format, int customize_dtype) -> Tensor - func: empty_with_swapped_memory(int[] size, *, ScalarType? dtype=None, Device? 
device=None) -> Tensor dispatch: CompositeExplicitAutograd: empty_with_swapped_memory diff --git a/torch_npu/csrc/aten/ops/FlattenDenseTensorsKernelNpu.cpp b/torch_npu/csrc/aten/ops/FlattenDenseTensorsKernelNpu.cpp index 2f06acbaea..e2be317874 100644 --- a/torch_npu/csrc/aten/ops/FlattenDenseTensorsKernelNpu.cpp +++ b/torch_npu/csrc/aten/ops/FlattenDenseTensorsKernelNpu.cpp @@ -7,8 +7,7 @@ namespace native { at::Tensor NPUNativeFunctions::flatten_dense_tensors(at::TensorList tensors) { static auto cast_back_to_ori_format = [](const at::Tensor& t) { - return custom_ops::npu_format_cast(t, torch_npu::NPUBridge::GetNpuStorageImpl(t)->npu_desc_.origin_format_, - c10::nullopt); + return custom_ops::npu_format_cast(t, torch_npu::NPUBridge::GetNpuStorageImpl(t)->npu_desc_.origin_format_); }; static auto flatten = [](const at::Tensor& t) { return cast_back_to_ori_format(t).contiguous().view({-1}); diff --git a/torch_npu/csrc/core/NPUSerialization.cpp b/torch_npu/csrc/core/NPUSerialization.cpp index 5cbe642478..1ae122f342 100644 --- a/torch_npu/csrc/core/NPUSerialization.cpp +++ b/torch_npu/csrc/core/NPUSerialization.cpp @@ -48,7 +48,7 @@ void npu_info_deserialization(const at::Tensor &t, std::unordered_map(t), format, c10::nullopt); + at_npu::native::NPUNativeFunctions::npu_format_cast_(const_cast(t), format); if (revert_flag) { t.set_requires_grad(true); } diff --git a/torch_npu/csrc/core/npu/NPUFormat.cpp b/torch_npu/csrc/core/npu/NPUFormat.cpp index eed97ce703..b087842cc3 100644 --- a/torch_npu/csrc/core/npu/NPUFormat.cpp +++ b/torch_npu/csrc/core/npu/NPUFormat.cpp @@ -37,7 +37,7 @@ std::vector get_npu_storage_sizes(const at::Tensor& self) at::Tensor npu_format_cast(const at::Tensor& self, int64_t acl_format) { - return NPUNativeFunctions::npu_format_cast(self, acl_format, c10::nullopt); + return NPUNativeFunctions::npu_format_cast(self, acl_format); } at::Tensor empty_with_format(c10::IntArrayRef sizes, const c10::TensorOptions& options, diff --git a/torch_npu/csrc/distributed/Init.cpp b/torch_npu/csrc/distributed/Init.cpp index 4f9bb5ef62..1e40130f30 100644 --- a/torch_npu/csrc/distributed/Init.cpp +++ b/torch_npu/csrc/distributed/Init.cpp @@ -98,8 +98,7 @@ public: inline std::vector cast_tensors(at::TensorList tensors) const { static auto cast_back_to_ori_format = [](const at::Tensor &t) { - return at_npu::native::custom_ops::npu_format_cast(t, - torch_npu::NPUBridge::GetNpuStorageImpl(t)->npu_desc_.origin_format_, c10::nullopt); + return at_npu::native::custom_ops::npu_format_cast(t, torch_npu::NPUBridge::GetNpuStorageImpl(t)->npu_desc_.origin_format_); }; return c10::fmap(tensors, cast_back_to_ori_format); } diff --git a/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp b/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp index a7d61f130a..4b2d28025d 100644 --- a/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp +++ b/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp @@ -2565,7 +2565,7 @@ std::vector cast_to_origin_format(const std::vector& inp inputTensors_[index] = tensor; } else { auto origin_format = torch_npu::NPUBridge::GetNpuStorageImpl(tensor)->npu_desc_.origin_format_; - inputTensors_[index] = at_npu::native::custom_ops::npu_format_cast(tensor, origin_format, c10::nullopt); + inputTensors_[index] = at_npu::native::custom_ops::npu_format_cast(tensor, origin_format); } index++; } diff --git a/torch_npu/csrc/distributed/reducer.cpp b/torch_npu/csrc/distributed/reducer.cpp index f18614f23a..da3664149b 100644 --- a/torch_npu/csrc/distributed/reducer.cpp +++ 
b/torch_npu/csrc/distributed/reducer.cpp @@ -368,7 +368,7 @@ void Reducer::mark_variable_ready_dense(size_t variable_index) if (torch_npu::NPUBridge::GetNpuStorageImpl(grad)->npu_desc_.npu_format_ != torch_npu::NPUBridge::GetNpuStorageImpl(variable)->npu_desc_.npu_format_) { grad = at_npu::native::NPUNativeFunctions::npu_format_cast(grad, - torch_npu::NPUBridge::GetNpuStorageImpl(variable)->npu_desc_.npu_format_, c10::nullopt); + torch_npu::NPUBridge::GetNpuStorageImpl(variable)->npu_desc_.npu_format_); } if (comm_hook_ == nullptr) { if (!grad.requires_grad()) { diff --git a/torch_npu/csrc/framework/utils/NpuUtils.cpp b/torch_npu/csrc/framework/utils/NpuUtils.cpp index 6e51d36b6d..e54c951dc1 100644 --- a/torch_npu/csrc/framework/utils/NpuUtils.cpp +++ b/torch_npu/csrc/framework/utils/NpuUtils.cpp @@ -129,7 +129,7 @@ at::Tensor metadata_convert_match(const at::Tensor &src, bool numelEq) // NCHW will generate a temporary tensor, which always monopolizes its own // storage. if (numelEq && (!FormatHelper::IsBaseFormatType(src))) { - at::Tensor tempTensor = custom_ops::npu_format_cast(src, FormatHelper::GetBaseFormat(src), c10::nullopt); + at::Tensor tempTensor = custom_ops::npu_format_cast(src, FormatHelper::GetBaseFormat(src)); custom_ops::npu_reshape_out(tempTensor, tempTensor.sizes(), true, tempTensor); NpuUtils::RefreshFormat(tempTensor); return tempTensor; diff --git a/torch_npu/csrc/framework/utils/OpPreparation.cpp b/torch_npu/csrc/framework/utils/OpPreparation.cpp index bdaeaacee6..530d359df2 100644 --- a/torch_npu/csrc/framework/utils/OpPreparation.cpp +++ b/torch_npu/csrc/framework/utils/OpPreparation.cpp @@ -227,14 +227,14 @@ void OpPreparation::check_memory(const std::initializer_list &inputs at::Tensor OpPreparation::cast_to_ori_format(const at::Tensor &tensor) { auto &tensor_desc = torch_npu::NPUBridge::GetNpuStorageImpl(tensor)->npu_desc_; - auto ret = custom_ops::npu_format_cast(tensor, tensor_desc.origin_format_, c10::nullopt); + auto ret = custom_ops::npu_format_cast(tensor, tensor_desc.origin_format_); return ret; } at::Tensor &OpPreparation::cast_to_ori_format(at::Tensor &tensor) { auto &tensor_desc = torch_npu::NPUBridge::GetNpuStorageImpl(tensor)->npu_desc_; - NPUNativeFunctions::npu_format_cast_(tensor, tensor_desc.origin_format_, c10::nullopt); + NPUNativeFunctions::npu_format_cast_(tensor, tensor_desc.origin_format_); return tensor; } @@ -361,21 +361,21 @@ void OpPreparation::CheckOut(const std::initializer_list &input, if (CalcuOpUtil::GetTensorNpuFormat(output) != format) { TORCH_CHECK(!is_read_write, "can not cast format when output is input", OPS_ERROR(ErrCode::NOT_SUPPORT)); - NPUNativeFunctions::npu_format_cast_(output, format, c10::nullopt); + NPUNativeFunctions::npu_format_cast_(output, format); } } at::Tensor OpPreparation::CastBackToOriFormat(const at::Tensor &tensor) { auto &tensor_desc = torch_npu::NPUBridge::GetNpuStorageImpl(tensor)->npu_desc_; - auto ret = custom_ops::npu_format_cast(tensor, tensor_desc.origin_format_, c10::nullopt); + auto ret = custom_ops::npu_format_cast(tensor, tensor_desc.origin_format_); return ret; } at::Tensor &OpPreparation::CastBackToOriFormat(at::Tensor &tensor) { auto &tensor_desc = torch_npu::NPUBridge::GetNpuStorageImpl(tensor)->npu_desc_; - NPUNativeFunctions::npu_format_cast_(tensor, tensor_desc.origin_format_, c10::nullopt); + NPUNativeFunctions::npu_format_cast_(tensor, tensor_desc.origin_format_); return tensor; } diff --git a/torch_npu/utils/hif8_tensor.py b/torch_npu/utils/hif8_tensor.py index 353bf1bf51..a1d3e225ea 
100644 --- a/torch_npu/utils/hif8_tensor.py +++ b/torch_npu/utils/hif8_tensor.py @@ -321,7 +321,7 @@ class HiFloat8Tensor(torch.Tensor): def from_hifloat8(self, dtype: Optional[torch.dtype] = None) -> torch.Tensor: """ - Construct plain PyTorch tensor from HiFloat8Tensor + Construct PyTorch tensor from HiFloat8Tensor By default the resulting tensor's dtype is the HiFloat8Tensor's nominal dtype. @@ -333,7 +333,7 @@ class HiFloat8Tensor(torch.Tensor): cls, tensor: torch.Tensor ): - """Construct HiFloat8Tensor from plain PyTorch tensor""" + """Construct HiFloat8Tensor from PyTorch tensor""" return _ToHiFloat8Func.apply( tensor ) -- Gitee From 9d5145b14c8526d0b2bb52fe059eb6b2113a48c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=AE=8B=E6=81=BA?= Date: Wed, 11 Jun 2025 01:14:03 +0000 Subject: [PATCH 04/11] !217 Change npu_rotary_mul onnx * Change npu_rotary_mul onnx --- test/onnx/test_wrapper_onnx_ops.py | 22 ++++++++++++++++++++++ torch_npu/onnx/wrapper_onnx_ops.py | 16 ++++++++++++---- 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/test/onnx/test_wrapper_onnx_ops.py b/test/onnx/test_wrapper_onnx_ops.py index 1363719fbd..7e072e9653 100644 --- a/test/onnx/test_wrapper_onnx_ops.py +++ b/test/onnx/test_wrapper_onnx_ops.py @@ -1221,6 +1221,28 @@ class TestOnnxOps(TestCase): export_onnx(onnx_model_name) assert (os.path.isfile(os.path.join(TestOnnxOps.test_onnx_path, onnx_model_name))) + @SupportedDevices(['Ascend910B']) + def test_wrapper_npu_rotary_mul_with_mode(self): + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, r1, r2, mode): + return torch_npu.npu_rotary_mul(x, r1, r2, mode) + + def export_onnx(onnx_model_name): + x = torch.rand([8192, 2, 5, 128], dtype=torch.float32).npu() + r1 = torch.rand([8192, 1, 1, 128], dtype=torch.float32).npu() + r2 = torch.rand([8192, 1, 1, 128], dtype=torch.float32).npu() + rotary_mode = "interleave" + model = Model().to("npu") + model(x, r1, r2, rotary_mode) + self.onnx_export(model, (x, r1, r2, rotary_mode), onnx_model_name, ["x", "r1", "r2", "rotary_mode"]) + + onnx_model_name = "model_npu_rotary_mul_interleave.onnx" + export_onnx(onnx_model_name) + assert (os.path.isfile(os.path.join(TestOnnxOps.test_onnx_path, onnx_model_name))) + @SupportedDevices(['Ascend910B']) def test_wrapper_npu_masked_softmax_with_rel_pos_bias(self): class Model(torch.nn.Module): diff --git a/torch_npu/onnx/wrapper_onnx_ops.py b/torch_npu/onnx/wrapper_onnx_ops.py index d69a875a38..b7b4f40974 100644 --- a/torch_npu/onnx/wrapper_onnx_ops.py +++ b/torch_npu/onnx/wrapper_onnx_ops.py @@ -649,8 +649,16 @@ class _NPURotaryMulOP(torch.autograd.Function): return torch.ops.npu.npu_rotary_mul(*args, **kwargs) @staticmethod - def symbolic(g, x: Tensor, r1: Tensor, r2: Tensor): - return g.op("npu::NPURotaryMul", x, r1, r2) + def symbolic(g, x: Tensor, r1: Tensor, r2: Tensor, rotary_mode: str = "half"): + if rotary_mode == "half": + return g.op("npu::NPURotaryMul", x, r1, r2) + elif rotary_mode == "interleave": + mode = 1 + elif rotary_mode == "quarter": + mode = 2 + elif rotary_mode == "interleave-half": + mode = 3 + return g.op("npu::NPURotaryPositionEmbedding", x, r1, r2, mode_i=mode) class _NPUPromptFlashAttentionOP(torch.autograd.Function): @@ -1310,8 +1318,8 @@ def _wrapper_npu_mish(self): return _NPUMishOP.apply(self) -def _wrapper_npu_rotary_mul(x, r1, r2): - return _NPURotaryMulOP.apply(x, r1, r2) +def _wrapper_npu_rotary_mul(x, r1, r2, rotary_mode="half"): + return _NPURotaryMulOP.apply(x, r1, r2, rotary_mode) def 
_wrapper_npu_prompt_flash_attention(self, query, key, value, padding_mask, atten_mask, pse_shift, actual_seq_lengths, -- Gitee From 5d2bc5abfc9e41174f76c9a5676836f5f875bd33 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=AB=98=E7=BB=B4=E5=81=A5?= Date: Thu, 12 Jun 2025 06:40:53 +0000 Subject: [PATCH 05/11] !225 quantconv2d offset must be False dfx * quantconv2d only support offset False dfx --- torch_npu/contrib/module/quant_conv2d.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch_npu/contrib/module/quant_conv2d.py b/torch_npu/contrib/module/quant_conv2d.py index 5bca024785..e7a395e850 100644 --- a/torch_npu/contrib/module/quant_conv2d.py +++ b/torch_npu/contrib/module/quant_conv2d.py @@ -6,6 +6,7 @@ from torch.nn.common_types import _size_2_t from torch.nn.parameter import Parameter from torch.nn.modules.utils import _pair import torch_npu +from torch_npu.utils._error_code import ErrCode, ops_error __all__ = ['QuantConv2d'] @@ -115,7 +116,7 @@ class QuantConv2d(nn.Module): else: self.register_parameter('bias', None) if offset: - self.offset = Parameter(torch.empty(out_channels, dtype=torch.float32), False) + raise ValueError("offset must be False" + ops_error(ErrCode.VALUE)) else: self.register_parameter('offset', None) -- Gitee From 3d22c33dd7ffb58267b7df7ce073cbfdb02042c8 Mon Sep 17 00:00:00 2001 From: chuboning Date: Thu, 12 Jun 2025 21:16:54 +0800 Subject: [PATCH 06/11] Fix npu_format_cast --- codegen/utils.py | 2 +- test/torch_npu_schema.json | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/codegen/utils.py b/codegen/utils.py index 03e8d2f79c..5810d3b3e6 100644 --- a/codegen/utils.py +++ b/codegen/utils.py @@ -483,7 +483,7 @@ if (({force_aclnn} || at_npu::native::env::CheckJitDisable()){tensor_check_str}) """ else: return_code = f"""\ -if (({is_ascend910_95_version} || {force_aclnn} || at_npu::native::env::CheckJitDisable())) {{ +if (({force_aclnn} || at_npu::native::env::CheckJitDisable())) {{ return {op_api_impl_name}({args_exprs_str}); }} else {{ return {impl_name}({args_exprs_str}); diff --git a/test/torch_npu_schema.json b/test/torch_npu_schema.json index ea7a738c78..1e5b2151c3 100644 --- a/test/torch_npu_schema.json +++ b/test/torch_npu_schema.json @@ -2865,7 +2865,10 @@ "signature": "" }, "func: _npu_format_cast": { - "signature": "(Tensor self, int acl_format, int? 
customize_dtype=None) -> Tensor" + "signature": "(Tensor self, int acl_format) -> Tensor" + }, + "func: _npu_format_cast": { + "signature": "(Tensor self, int acl_format, int customize_dtype) -> Tensor" }, "torch_npu_public_env: INF_NAN_MODE_ENABLE": { "mode": "std::unordered_map infNanMode = {{0, \"max\"}, {1, \"inf_nan\"}}" -- Gitee From db2c44df3ceebec808c068524fcd8f9a441b7685 Mon Sep 17 00:00:00 2001 From: chuboning Date: Fri, 13 Jun 2025 17:27:35 +0800 Subject: [PATCH 07/11] Modify transformer_engine --- CMakeLists.txt | 6 +- torch_npu/__init__.py | 4 +- torch_npu/csrc/InitNpuBindings.cpp | 4 +- .../csrc/aten/common/FormatCastKernelNpu.cpp | 4 +- .../CMakeLists.txt | 4 +- .../CastKernelTeOpApi.cpp | 18 ++- .../Init.cpp | 136 +++++++++--------- .../Init.h | 82 ++++++----- .../extension.h | 6 +- torch_npu/utils/hif8_tensor.py | 8 +- 10 files changed, 132 insertions(+), 140 deletions(-) rename torch_npu/csrc/{transformer_engine => custom_dtype}/CMakeLists.txt (33%) rename torch_npu/csrc/{transformer_engine => custom_dtype}/CastKernelTeOpApi.cpp (62%) rename torch_npu/csrc/{transformer_engine => custom_dtype}/Init.cpp (66%) rename torch_npu/csrc/{transformer_engine => custom_dtype}/Init.h (32%) rename torch_npu/csrc/{transformer_engine => custom_dtype}/extension.h (73%) diff --git a/CMakeLists.txt b/CMakeLists.txt index 0150bc1aeb..d38f3d95dc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -246,7 +246,7 @@ add_subdirectory(${TORCHNPU_ROOT}/core) add_subdirectory(${TORCHNPU_ROOT}/framework) add_subdirectory(${TORCHNPU_ROOT}/flopcount) add_subdirectory(${TORCHNPU_ROOT}/logging) -add_subdirectory(${TORCHNPU_ROOT}/transformer_engine) +add_subdirectory(${TORCHNPU_ROOT}/custom_dtype) if (NOT DEFINED BUILD_LIBTORCH) add_subdirectory(${TORCHNPU_ROOT}/distributed) @@ -273,10 +273,10 @@ if (DEFINED BUILD_TENSORPIPE) endif() if (DEFINED BUILD_LIBTORCH) - set(CPP_SRCS ${ATEN_SRCS} ${CORE_SRCS} ${OPS_PLUGIN_SRCS} ${FLOP_SRCS} ${TE_SRCS} ${FRAMEWORK_SRCS} ${LOGGING_SRCS} ${NPU_CPP_LIBS_SRCS}) + set(CPP_SRCS ${ATEN_SRCS} ${CORE_SRCS} ${OPS_PLUGIN_SRCS} ${FLOP_SRCS} ${CUS_DTYPE_SRCS} ${FRAMEWORK_SRCS} ${LOGGING_SRCS} ${NPU_CPP_LIBS_SRCS}) else() # Compile code with pybind11 - set(CPP_SRCS ${ATEN_SRCS} ${CORE_SRCS} ${OPS_PLUGIN_SRCS} ${DIST_SRCS} ${FLOP_SRCS} ${TE_SRCS} ${LOGGING_SRCS} ${FRAMEWORK_SRCS} ${NPU_SRCS} ${PROF_SRCS} ${UTILS_SRCS} ${SAN_SRCS}) + set(CPP_SRCS ${ATEN_SRCS} ${CORE_SRCS} ${OPS_PLUGIN_SRCS} ${DIST_SRCS} ${FLOP_SRCS} ${CUS_DTYPE_SRCS} ${LOGGING_SRCS} ${FRAMEWORK_SRCS} ${NPU_SRCS} ${PROF_SRCS} ${UTILS_SRCS} ${SAN_SRCS}) endif() add_library(${PLUGIN_NAME} SHARED ${CPP_SRCS}) diff --git a/torch_npu/__init__.py b/torch_npu/__init__.py index 78790b96d5..74f42a5389 100644 --- a/torch_npu/__init__.py +++ b/torch_npu/__init__.py @@ -91,10 +91,10 @@ for name in dir(torch.ops.npu): __all__.append(name) setattr(torch, name, _wrap_torch_error_func(getattr(torch.ops.npu, name))) -for name in dir(torch_npu._C._te.DType): +for name in dir(torch_npu._C._cd.DType): if name.startswith('__') or name in ['_dir', 'name']: continue - setattr(torch_npu, name, getattr(torch_npu._C._te.DType, name)) + setattr(torch_npu, name, getattr(torch_npu._C._cd.DType, name)) all_monkey_patches = [ ["nn.functional", npu_functional], diff --git a/torch_npu/csrc/InitNpuBindings.cpp b/torch_npu/csrc/InitNpuBindings.cpp index 8bc7a11471..3e4a8e5fbc 100644 --- a/torch_npu/csrc/InitNpuBindings.cpp +++ b/torch_npu/csrc/InitNpuBindings.cpp @@ -15,7 +15,7 @@ #include "torch_npu/csrc/flopcount/Init.h" #include 
"torch_npu/csrc/logging/Init.h" #include "torch_npu/csrc/npu/Module.h" -#include "torch_npu/csrc/transformer_engine/Init.h" +#include "torch_npu/csrc/custom_dtype/Init.h" #include "torch_npu/csrc/npu/Stress_detect.h" #include "torch_npu/csrc/utils/TensorType.h" #include "torch_npu/csrc/utils/AutocastMode.h" @@ -168,7 +168,7 @@ PyObject* initModule() AddPyMethodDefs(methods, torch_npu::autocast::autocast_mode_functions()); AddPyMethodDefs(methods, torch_npu::flopcount::flops_count_functions()); AddPyMethodDefs(methods, torch_npu::logging::logging_functions()); - AddPyMethodDefs(methods, torch_npu::te::te_functions()); + AddPyMethodDefs(methods, c10_npu::custom_dtype_functions()); static struct PyModuleDef torchnpu_module = { PyModuleDef_HEAD_INIT, "torch_npu._C", diff --git a/torch_npu/csrc/aten/common/FormatCastKernelNpu.cpp b/torch_npu/csrc/aten/common/FormatCastKernelNpu.cpp index 2d83d17634..4e683ade9a 100644 --- a/torch_npu/csrc/aten/common/FormatCastKernelNpu.cpp +++ b/torch_npu/csrc/aten/common/FormatCastKernelNpu.cpp @@ -8,7 +8,7 @@ #include "torch_npu/csrc/core/NPUStorageImpl.h" #include "torch_npu/csrc/core/npu/NpuVariables.h" #include "torch_npu/csrc/aten/CustomFunctions.h" -#include "torch_npu/csrc/transformer_engine/Init.h" +#include "torch_npu/csrc/custom_dtype/Init.h" #include "third_party/op-plugin/op_plugin/utils/op_api_common.h" namespace at_npu { @@ -32,7 +32,7 @@ std::tuple> MaybeUseAclnnNpuForma int dstFormat; at::SmallVector outputShape = {}; aclDataType customizeAcltype = (customize_dtype.has_value()) ? - torch_npu::te::GetAclDataType(customize_dtype.value()) : + c10_npu::GetAclDataType(customize_dtype.value()) : at_npu::native::OpPreparation::convert_to_acl_data_type(src.scalar_type()); if (c10_npu::IsAscend910_95Version()) { diff --git a/torch_npu/csrc/transformer_engine/CMakeLists.txt b/torch_npu/csrc/custom_dtype/CMakeLists.txt similarity index 33% rename from torch_npu/csrc/transformer_engine/CMakeLists.txt rename to torch_npu/csrc/custom_dtype/CMakeLists.txt index b5350d4957..bdad67c05d 100644 --- a/torch_npu/csrc/transformer_engine/CMakeLists.txt +++ b/torch_npu/csrc/custom_dtype/CMakeLists.txt @@ -1,6 +1,6 @@ FILE(GLOB _TE_SRCS *.cpp) -LIST(APPEND TE_SRCS ${_TE_SRCS}) +LIST(APPEND CUS_DTYPE_SRCS ${_TE_SRCS}) # Pass to parent -set(TE_SRCS ${TE_SRCS} PARENT_SCOPE) +set(CUS_DTYPE_SRCS ${CUS_DTYPE_SRCS} PARENT_SCOPE) diff --git a/torch_npu/csrc/transformer_engine/CastKernelTeOpApi.cpp b/torch_npu/csrc/custom_dtype/CastKernelTeOpApi.cpp similarity index 62% rename from torch_npu/csrc/transformer_engine/CastKernelTeOpApi.cpp rename to torch_npu/csrc/custom_dtype/CastKernelTeOpApi.cpp index afcfd6f7af..2293ba94dd 100644 --- a/torch_npu/csrc/transformer_engine/CastKernelTeOpApi.cpp +++ b/torch_npu/csrc/custom_dtype/CastKernelTeOpApi.cpp @@ -1,21 +1,20 @@ -#include "torch_npu/csrc/transformer_engine/extension.h" +#include "torch_npu/csrc/custom_dtype/extension.h" #include "op_plugin/AclOpsInterface.h" #include "op_plugin/OpApiInterface.h" #include "op_plugin/utils/op_api_common.h" -namespace torch_npu { -namespace te { +namespace c10_npu { at::Tensor cast_to_fp8(const at::Tensor &input, int otype) { - auto output = at::empty_like(input, torch_npu::te::GetATenDType(otype)); + auto output = at::empty_like(input, c10_npu::GetATenDType(otype)); if (input.numel() == 0) { return output; } - aclDataType out_acltype = torch_npu::te::GetAclDataType(otype); + aclDataType out_acltype = c10_npu::GetAclDataType(otype); TensorWrapper out_wrapper = {output, out_acltype}; 
EXEC_NPU_CMD(aclnnCast, input, out_acltype, out_wrapper); @@ -24,7 +23,7 @@ at::Tensor cast_to_fp8(const at::Tensor &input, int otype) void cast_to_fp8_noalloc(const at::Tensor &input, at::Tensor output, int otype) { - aclDataType out_acltype = torch_npu::te::GetAclDataType(otype); + aclDataType out_acltype = c10_npu::GetAclDataType(otype); TensorWrapper out_wrapper = {output, out_acltype}; EXEC_NPU_CMD(aclnnCast, input, out_acltype, out_wrapper); return; @@ -32,9 +31,9 @@ void cast_to_fp8_noalloc(const at::Tensor &input, at::Tensor output, int otype) at::Tensor cast_from_fp8(const at::Tensor &input, int itype, int otype) { - aclDataType input_acltype = torch_npu::te::GetAclDataType(itype); - aclDataType out_acltype = torch_npu::te::GetAclDataType(otype); - auto output = at::empty_like(input, torch_npu::te::GetATenDType(otype)); + aclDataType input_acltype = c10_npu::GetAclDataType(itype); + aclDataType out_acltype = c10_npu::GetAclDataType(otype); + auto output = at::empty_like(input, c10_npu::GetATenDType(otype)); TensorWrapper input_wrapper = {input, input_acltype}; TensorWrapper out_wrapper = {output, out_acltype}; EXEC_NPU_CMD(aclnnCast, input_wrapper, out_acltype, out_wrapper); @@ -42,4 +41,3 @@ at::Tensor cast_from_fp8(const at::Tensor &input, int itype, int otype) return output; } } -} diff --git a/torch_npu/csrc/transformer_engine/Init.cpp b/torch_npu/csrc/custom_dtype/Init.cpp similarity index 66% rename from torch_npu/csrc/transformer_engine/Init.cpp rename to torch_npu/csrc/custom_dtype/Init.cpp index bea7de1f36..90644aa1e3 100644 --- a/torch_npu/csrc/transformer_engine/Init.cpp +++ b/torch_npu/csrc/custom_dtype/Init.cpp @@ -1,13 +1,12 @@ -#include "torch_npu/csrc/transformer_engine/Init.h" +#include "torch_npu/csrc/custom_dtype/Init.h" #ifndef BUILD_LIBTORCH #include #include #endif -#include "torch_npu/csrc/transformer_engine/extension.h" +#include "torch_npu/csrc/custom_dtype/extension.h" -namespace torch_npu { -namespace te { +namespace c10_npu { struct DTypeConstants { static const int float32_value; static const int float16_value; @@ -38,43 +37,43 @@ struct DTypeConstants { static const int float4_e1m2_value; }; -const int DTypeConstants::float32_value = static_cast(DType::TE_FLOAT); -const int DTypeConstants::float16_value = static_cast(DType::TE_FLOAT16); -const int DTypeConstants::int8_value = static_cast(DType::TE_INT8); -const int DTypeConstants::int32_value = static_cast(DType::TE_INT32); -const int DTypeConstants::uint8_value = static_cast(DType::TE_UINT8); -const int DTypeConstants::int16_value = static_cast(DType::TE_INT16); -const int DTypeConstants::uint16_value = static_cast(DType::TE_UINT16); -const int DTypeConstants::uint32_value = static_cast(DType::TE_UINT32); -const int DTypeConstants::int64_value = static_cast(DType::TE_INT64); -const int DTypeConstants::uint64_value = static_cast(DType::TE_UINT64); -const int DTypeConstants::float64_value = static_cast(DType::TE_DOUBLE); -const int DTypeConstants::bool_value = static_cast(DType::TE_BOOL); -const int DTypeConstants::string_value = static_cast(DType::TE_STRING); -const int DTypeConstants::complex64_value = static_cast(DType::TE_COMPLEX64); -const int DTypeConstants::complex128_value = static_cast(DType::TE_COMPLEX128); -const int DTypeConstants::bfloat16_value = static_cast(DType::TE_BF16); -const int DTypeConstants::int4_value = static_cast(DType::TE_INT4); -const int DTypeConstants::uint1_value = static_cast(DType::TE_UINT1); -const int DTypeConstants::complex32_value = 
static_cast(DType::TE_COMPLEX32); -const int DTypeConstants::hifloat8_value = static_cast(DType::TE_HIFLOAT8); -const int DTypeConstants::float8_e5m2_value = static_cast(DType::TE_FLOAT8_E5M2); -const int DTypeConstants::float8_e4m3fn_value = static_cast(DType::TE_FLOAT8_E4M3FN); -const int DTypeConstants::float8_e8m0_value = static_cast(DType::TE_FLOAT8_E8M0); -const int DTypeConstants::float6_e3m2_value = static_cast(DType::TE_FLOAT6_E3M2); -const int DTypeConstants::float6_e2m3_value = static_cast(DType::TE_FLOAT6_E2M3); -const int DTypeConstants::float4_e2m1_value = static_cast(DType::TE_FLOAT4_E2M1); -const int DTypeConstants::float4_e1m2_value = static_cast(DType::TE_FLOAT4_E1M2); +const int DTypeConstants::float32_value = static_cast(DType::FLOAT); +const int DTypeConstants::float16_value = static_cast(DType::FLOAT16); +const int DTypeConstants::int8_value = static_cast(DType::INT8); +const int DTypeConstants::int32_value = static_cast(DType::INT32); +const int DTypeConstants::uint8_value = static_cast(DType::UINT8); +const int DTypeConstants::int16_value = static_cast(DType::INT16); +const int DTypeConstants::uint16_value = static_cast(DType::UINT16); +const int DTypeConstants::uint32_value = static_cast(DType::UINT32); +const int DTypeConstants::int64_value = static_cast(DType::INT64); +const int DTypeConstants::uint64_value = static_cast(DType::UINT64); +const int DTypeConstants::float64_value = static_cast(DType::DOUBLE); +const int DTypeConstants::bool_value = static_cast(DType::BOOL); +const int DTypeConstants::string_value = static_cast(DType::STRING); +const int DTypeConstants::complex64_value = static_cast(DType::COMPLEX64); +const int DTypeConstants::complex128_value = static_cast(DType::COMPLEX128); +const int DTypeConstants::bfloat16_value = static_cast(DType::BF16); +const int DTypeConstants::int4_value = static_cast(DType::INT4); +const int DTypeConstants::uint1_value = static_cast(DType::UINT1); +const int DTypeConstants::complex32_value = static_cast(DType::COMPLEX32); +const int DTypeConstants::hifloat8_value = static_cast(DType::HIFLOAT8); +const int DTypeConstants::float8_e5m2_value = static_cast(DType::FLOAT8_E5M2); +const int DTypeConstants::float8_e4m3fn_value = static_cast(DType::FLOAT8_E4M3FN); +const int DTypeConstants::float8_e8m0_value = static_cast(DType::FLOAT8_E8M0); +const int DTypeConstants::float6_e3m2_value = static_cast(DType::FLOAT6_E3M2); +const int DTypeConstants::float6_e2m3_value = static_cast(DType::FLOAT6_E2M3); +const int DTypeConstants::float4_e2m1_value = static_cast(DType::FLOAT4_E2M1); +const int DTypeConstants::float4_e1m2_value = static_cast(DType::FLOAT4_E1M2); #ifndef BUILD_LIBTORCH -PyObject* te_initExtension(PyObject*, PyObject *) +PyObject* cd_initExtension(PyObject*, PyObject *) { auto torch_npu_C_module = THPObjectPtr(PyImport_ImportModule("torch_npu._C")); if (!torch_npu_C_module) { return nullptr; } auto torch_npu_C_m = py::handle(torch_npu_C_module).cast(); - auto m = torch_npu_C_m.def_submodule("_te", "_te bindings"); + auto m = torch_npu_C_m.def_submodule("_cd", "_cd bindings"); py::class_(m, "DType") .def_readonly_static("float32", &DTypeConstants::float32_value) @@ -113,53 +112,52 @@ PyObject* te_initExtension(PyObject*, PyObject *) Py_RETURN_NONE; } -static PyMethodDef NPUTeMethods[] = { // NOLINT - {"_te_init", te_initExtension, METH_NOARGS, nullptr}, +static PyMethodDef NPUCustomDtypeMethods[] = { // NOLINT + {"_cd_init", cd_initExtension, METH_NOARGS, nullptr}, {nullptr, nullptr, 0, nullptr} }; #endif -const 
std::string TeDataTypeToString(int64_t dType) +const std::string CustomDataTypeToString(int64_t dType) { const std::map - TE_TYPE_TO_STRING_MAP = { - {DType::TE_FLOAT, "torch_npu.float32"}, - {DType::TE_FLOAT16, "torch_npu.float16"}, - {DType::TE_INT8, "torch_npu.int8"}, - {DType::TE_INT32, "torch_npu.int32"}, - {DType::TE_UINT8, "torch_npu.uint8"}, - {DType::TE_INT16, "torch_npu.int16"}, - {DType::TE_UINT16, "torch_npu.uint16"}, - {DType::TE_UINT32, "torch_npu.uint32"}, - {DType::TE_INT64, "torch_npu.int64"}, - {DType::TE_UINT64, "torch_npu.uint64"}, - {DType::TE_DOUBLE, "torch_npu.float64"}, - {DType::TE_BOOL, "torch_npu.bool"}, - {DType::TE_STRING, "torch_npu.string"}, - {DType::TE_COMPLEX64, "torch_npu.complex64"}, - {DType::TE_COMPLEX128, "torch_npu.complex128"}, - {DType::TE_BF16, "torch_npu.bfloat16"}, - {DType::TE_INT4, "torch_npu.int4"}, - {DType::TE_UINT1, "torch_npu.uint1"}, - {DType::TE_COMPLEX32, "torch_npu.complex32"}, - {DType::TE_HIFLOAT8, "torch_npu.hifloat8"}, - {DType::TE_FLOAT8_E5M2, "torch_npu.float8_e5m2"}, - {DType::TE_FLOAT8_E4M3FN, "torch_npu.float8_e4m3fn"}, - {DType::TE_FLOAT8_E8M0, "torch_npu.float8_e8m0"}, - {DType::TE_FLOAT6_E3M2, "torch_npu.float6_e3m2"}, - {DType::TE_FLOAT6_E2M3, "torch_npu.float6_e2m3"}, - {DType::TE_FLOAT4_E2M1, "torch_npu.float4_e2m1"}, - {DType::TE_FLOAT4_E1M2, "torch_npu.float4_e1m2"}}; + TYPE_TO_STRING_MAP = { + {DType::FLOAT, "torch_npu.float32"}, + {DType::FLOAT16, "torch_npu.float16"}, + {DType::INT8, "torch_npu.int8"}, + {DType::INT32, "torch_npu.int32"}, + {DType::UINT8, "torch_npu.uint8"}, + {DType::INT16, "torch_npu.int16"}, + {DType::UINT16, "torch_npu.uint16"}, + {DType::UINT32, "torch_npu.uint32"}, + {DType::INT64, "torch_npu.int64"}, + {DType::UINT64, "torch_npu.uint64"}, + {DType::DOUBLE, "torch_npu.float64"}, + {DType::BOOL, "torch_npu.bool"}, + {DType::STRING, "torch_npu.string"}, + {DType::COMPLEX64, "torch_npu.complex64"}, + {DType::COMPLEX128, "torch_npu.complex128"}, + {DType::BF16, "torch_npu.bfloat16"}, + {DType::INT4, "torch_npu.int4"}, + {DType::UINT1, "torch_npu.uint1"}, + {DType::COMPLEX32, "torch_npu.complex32"}, + {DType::HIFLOAT8, "torch_npu.hifloat8"}, + {DType::FLOAT8_E5M2, "torch_npu.float8_e5m2"}, + {DType::FLOAT8_E4M3FN, "torch_npu.float8_e4m3fn"}, + {DType::FLOAT8_E8M0, "torch_npu.float8_e8m0"}, + {DType::FLOAT6_E3M2, "torch_npu.float6_e3m2"}, + {DType::FLOAT6_E2M3, "torch_npu.float6_e2m3"}, + {DType::FLOAT4_E2M1, "torch_npu.float4_e2m1"}, + {DType::FLOAT4_E1M2, "torch_npu.float4_e1m2"}}; - const auto iter = TE_TYPE_TO_STRING_MAP.find(static_cast(dType)); - return iter != TE_TYPE_TO_STRING_MAP.end() ? iter->second : "Unknown dtype"; + const auto iter = TYPE_TO_STRING_MAP.find(static_cast(dType)); + return iter != TYPE_TO_STRING_MAP.end() ? 
iter->second : "Unknown dtype"; } #ifndef BUILD_LIBTORCH -PyMethodDef* te_functions() +PyMethodDef* custom_dtype_functions() { - return NPUTeMethods; + return NPUCustomDtypeMethods; } #endif } -} diff --git a/torch_npu/csrc/transformer_engine/Init.h b/torch_npu/csrc/custom_dtype/Init.h similarity index 32% rename from torch_npu/csrc/transformer_engine/Init.h rename to torch_npu/csrc/custom_dtype/Init.h index 01b697c34b..23235a0027 100644 --- a/torch_npu/csrc/transformer_engine/Init.h +++ b/torch_npu/csrc/custom_dtype/Init.h @@ -9,60 +9,59 @@ #include "torch_npu/csrc/framework/utils/OpPreparation.h" #include "third_party/acl/inc/acl/acl_base.h" -namespace torch_npu { -namespace te { -const int g_teToAclOffset = 256; +namespace c10_npu { +const int g_toAclOffset = 256; -#define TE_ENUM_OFFSET(new_name, old_name) new_name = static_cast(old_name) + g_teToAclOffset, +#define ENUM_OFFSET(new_name, old_name) new_name = static_cast(old_name) + g_toAclOffset, #ifndef BUILD_LIBTORCH -TORCH_NPU_API PyMethodDef* te_functions(); +TORCH_NPU_API PyMethodDef* custom_dtype_functions(); #endif enum class DType { - TE_DT_UNDEFINED = -1, - TE_ENUM_OFFSET(TE_FLOAT, ACL_FLOAT) - TE_ENUM_OFFSET(TE_FLOAT16, ACL_FLOAT16) - TE_ENUM_OFFSET(TE_INT8, ACL_INT8) - TE_ENUM_OFFSET(TE_INT32, ACL_INT32) - TE_ENUM_OFFSET(TE_UINT8, ACL_UINT8) - TE_ENUM_OFFSET(TE_INT16, ACL_INT16) - TE_ENUM_OFFSET(TE_UINT16, ACL_UINT16) - TE_ENUM_OFFSET(TE_UINT32, ACL_UINT32) - TE_ENUM_OFFSET(TE_INT64, ACL_INT64) - TE_ENUM_OFFSET(TE_UINT64, ACL_UINT64) - TE_ENUM_OFFSET(TE_DOUBLE, ACL_DOUBLE) - TE_ENUM_OFFSET(TE_BOOL, ACL_BOOL) - TE_ENUM_OFFSET(TE_STRING, ACL_STRING) - TE_ENUM_OFFSET(TE_COMPLEX64, ACL_COMPLEX64) - TE_ENUM_OFFSET(TE_COMPLEX128, ACL_COMPLEX128) - TE_ENUM_OFFSET(TE_BF16, ACL_BF16) - TE_ENUM_OFFSET(TE_INT4, ACL_INT4) - TE_ENUM_OFFSET(TE_UINT1, ACL_UINT1) - TE_ENUM_OFFSET(TE_COMPLEX32, ACL_COMPLEX32) - TE_ENUM_OFFSET(TE_HIFLOAT8, ACL_HIFLOAT8) - TE_ENUM_OFFSET(TE_FLOAT8_E5M2, ACL_FLOAT8_E5M2) - TE_ENUM_OFFSET(TE_FLOAT8_E4M3FN, ACL_FLOAT8_E4M3FN) - TE_ENUM_OFFSET(TE_FLOAT8_E8M0, ACL_FLOAT8_E8M0) - TE_ENUM_OFFSET(TE_FLOAT6_E3M2, ACL_FLOAT6_E3M2) - TE_ENUM_OFFSET(TE_FLOAT6_E2M3, ACL_FLOAT6_E2M3) - TE_ENUM_OFFSET(TE_FLOAT4_E2M1, ACL_FLOAT4_E2M1) - TE_ENUM_OFFSET(TE_FLOAT4_E1M2, ACL_FLOAT4_E1M2) + UNDEFINED = -1, + ENUM_OFFSET(FLOAT, ACL_FLOAT) + ENUM_OFFSET(FLOAT16, ACL_FLOAT16) + ENUM_OFFSET(INT8, ACL_INT8) + ENUM_OFFSET(INT32, ACL_INT32) + ENUM_OFFSET(UINT8, ACL_UINT8) + ENUM_OFFSET(INT16, ACL_INT16) + ENUM_OFFSET(UINT16, ACL_UINT16) + ENUM_OFFSET(UINT32, ACL_UINT32) + ENUM_OFFSET(INT64, ACL_INT64) + ENUM_OFFSET(UINT64, ACL_UINT64) + ENUM_OFFSET(DOUBLE, ACL_DOUBLE) + ENUM_OFFSET(BOOL, ACL_BOOL) + ENUM_OFFSET(STRING, ACL_STRING) + ENUM_OFFSET(COMPLEX64, ACL_COMPLEX64) + ENUM_OFFSET(COMPLEX128, ACL_COMPLEX128) + ENUM_OFFSET(BF16, ACL_BF16) + ENUM_OFFSET(INT4, ACL_INT4) + ENUM_OFFSET(UINT1, ACL_UINT1) + ENUM_OFFSET(COMPLEX32, ACL_COMPLEX32) + ENUM_OFFSET(HIFLOAT8, ACL_HIFLOAT8) + ENUM_OFFSET(FLOAT8_E5M2, ACL_FLOAT8_E5M2) + ENUM_OFFSET(FLOAT8_E4M3FN, ACL_FLOAT8_E4M3FN) + ENUM_OFFSET(FLOAT8_E8M0, ACL_FLOAT8_E8M0) + ENUM_OFFSET(FLOAT6_E3M2, ACL_FLOAT6_E3M2) + ENUM_OFFSET(FLOAT6_E2M3, ACL_FLOAT6_E2M3) + ENUM_OFFSET(FLOAT4_E2M1, ACL_FLOAT4_E2M1) + ENUM_OFFSET(FLOAT4_E1M2, ACL_FLOAT4_E1M2) }; -inline bool IsTEDType(int64_t t) +inline bool IsCustomDType(int64_t t) { - if (t >= g_teToAclOffset) { + if (t >= g_toAclOffset) { return true; } return false; } -// Both torch_npu::te::DType and ScalarType are supported +// Both c10_npu::DType and 
ScalarType are supported inline aclDataType GetAclDataType(int64_t t) { - if (t >= g_teToAclOffset) { - return static_cast(t - g_teToAclOffset); + if (t >= g_toAclOffset) { + return static_cast(t - g_toAclOffset); } return at_npu::native::OpPreparation::convert_to_acl_data_type( static_cast(t)); @@ -70,7 +69,7 @@ inline aclDataType GetAclDataType(int64_t t) inline aclDataType GetAclDataType(DType t) { - return static_cast(static_cast(t) - g_teToAclOffset); + return static_cast(static_cast(t) - g_toAclOffset); } inline at::ScalarType GetATenDType(int64_t t) @@ -79,7 +78,6 @@ inline at::ScalarType GetATenDType(int64_t t) return at_npu::native::OpPreparation::convert_to_scalar_type(aclType); } -const std::string TeDataTypeToString(int64_t dType); +const std::string CustomDataTypeToString(int64_t dType); -} // namespace te -} // namespace torch_npu +} // namespace c10_npu diff --git a/torch_npu/csrc/transformer_engine/extension.h b/torch_npu/csrc/custom_dtype/extension.h similarity index 73% rename from torch_npu/csrc/transformer_engine/extension.h rename to torch_npu/csrc/custom_dtype/extension.h index 2adb9bc3bf..91ef1df8a5 100644 --- a/torch_npu/csrc/transformer_engine/extension.h +++ b/torch_npu/csrc/custom_dtype/extension.h @@ -1,14 +1,12 @@ #pragma once #include -#include "torch_npu/csrc/transformer_engine/Init.h" +#include "torch_npu/csrc/custom_dtype/Init.h" -namespace torch_npu { -namespace te { +namespace c10_npu { at::Tensor cast_to_fp8(const at::Tensor &input, int otype); void cast_to_fp8_noalloc(const at::Tensor &input, at::Tensor output, int otype); at::Tensor cast_from_fp8(const at::Tensor &input, int itype, int otype); } -} diff --git a/torch_npu/utils/hif8_tensor.py b/torch_npu/utils/hif8_tensor.py index a1d3e225ea..3070aa96fb 100644 --- a/torch_npu/utils/hif8_tensor.py +++ b/torch_npu/utils/hif8_tensor.py @@ -12,12 +12,12 @@ from torch_npu.utils._error_code import ErrCode, pta_error # init transformer engine -torch_npu._C._te_init() +torch_npu._C._cd_init() -tex = torch_npu._C._te +tex = torch_npu._C._cd aten = torch.ops.aten -NPU_TE_DType = { +NPU_CUSTOM_DType = { torch.uint8: tex.DType.uint8, torch.int32: tex.DType.int32, torch.float32: tex.DType.float32, @@ -41,7 +41,7 @@ class _FromHiFloat8Func(torch.autograd.Function): out = tex.cast_from_fp8( data, tex.DType.hifloat8, - NPU_TE_DType[dtype], + NPU_CUSTOM_DType[dtype], ) out = out.view(tensor.size()) return out -- Gitee From d4d4fa0b393d3b53ebcd24f2de028884de3b69d6 Mon Sep 17 00:00:00 2001 From: chuboning Date: Sat, 14 Jun 2025 10:57:42 +0800 Subject: [PATCH 08/11] Revert some files --- codegen/utils.py | 2 +- test/onnx/test_wrapper_onnx_ops.py | 22 ------ torch_npu/contrib/module/linear_quant.py | 81 +++------------------- torch_npu/contrib/module/quant_conv2d.py | 7 +- torch_npu/csrc/custom_dtype/CMakeLists.txt | 4 +- torch_npu/onnx/wrapper_onnx_ops.py | 78 ++------------------- 6 files changed, 21 insertions(+), 173 deletions(-) diff --git a/codegen/utils.py b/codegen/utils.py index 5810d3b3e6..c603e77fe1 100644 --- a/codegen/utils.py +++ b/codegen/utils.py @@ -475,7 +475,7 @@ if (({force_aclnn} || at_npu::native::env::CheckJitDisable()){tensor_check_str}) }} else {{ if ({is_ascend910_95_version}) {{ TORCH_CHECK(false, - "Current aclnn operator {impl_name} do not support internal format.", + "Ascend910_95 series only support aclnn operator, and current operator {impl_name} do not support internal format.", PTA_ERROR(ErrCode::NOT_SUPPORT)); }} return {impl_name}({args_exprs_str}); diff --git 
a/test/onnx/test_wrapper_onnx_ops.py b/test/onnx/test_wrapper_onnx_ops.py index 7e072e9653..1363719fbd 100644 --- a/test/onnx/test_wrapper_onnx_ops.py +++ b/test/onnx/test_wrapper_onnx_ops.py @@ -1221,28 +1221,6 @@ class TestOnnxOps(TestCase): export_onnx(onnx_model_name) assert (os.path.isfile(os.path.join(TestOnnxOps.test_onnx_path, onnx_model_name))) - @SupportedDevices(['Ascend910B']) - def test_wrapper_npu_rotary_mul_with_mode(self): - class Model(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, r1, r2, mode): - return torch_npu.npu_rotary_mul(x, r1, r2, mode) - - def export_onnx(onnx_model_name): - x = torch.rand([8192, 2, 5, 128], dtype=torch.float32).npu() - r1 = torch.rand([8192, 1, 1, 128], dtype=torch.float32).npu() - r2 = torch.rand([8192, 1, 1, 128], dtype=torch.float32).npu() - rotary_mode = "interleave" - model = Model().to("npu") - model(x, r1, r2, rotary_mode) - self.onnx_export(model, (x, r1, r2, rotary_mode), onnx_model_name, ["x", "r1", "r2", "rotary_mode"]) - - onnx_model_name = "model_npu_rotary_mul_interleave.onnx" - export_onnx(onnx_model_name) - assert (os.path.isfile(os.path.join(TestOnnxOps.test_onnx_path, onnx_model_name))) - @SupportedDevices(['Ascend910B']) def test_wrapper_npu_masked_softmax_with_rel_pos_bias(self): class Model(torch.nn.Module): diff --git a/torch_npu/contrib/module/linear_quant.py b/torch_npu/contrib/module/linear_quant.py index 52662252c4..5b051b8911 100644 --- a/torch_npu/contrib/module/linear_quant.py +++ b/torch_npu/contrib/module/linear_quant.py @@ -36,19 +36,6 @@ class LinearQuant(nn.Module): If :attr:`bias` is ``True``, the values are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where :math:`k = \frac{1}{\text{in\_features}}` - x1_dtype: Only support torch_npu.hifloat8, torch_npu.float4_e2m1, torch_npu.float4_e1m2, default to None. - x2_dtype: Only support torch_npu.hifloat8, torch_npu.float4_e2m1, torch_npu.float4_e1m2, default to None. - pertoken_scale_dtype: Only support torch_npu.float8_e8m0, default to None. - scale_dtype: Only support torch_npu.float8_e8m0, default to None. - group_sizes: a list of Int, the length of the list is 3. - the first element is group_size_m, means for input, group_size_m elements according to single scale in m dim, - default to 0. - the second element is the group_size_n, means for weight, group_size_n elements according to single scale - in n dim, default to 0. - the third element is the group_size_k, means for input or weight, group_size_k elements according to - single scale in k dim, default to 0. - if any of group_size_m, group_size_n, group_size_k calculated by group_size is 0, recalculate it by - input shape, eg: group_size_m = m / scale_m (m % scale_m must be 0). 
A4W4 Examples:: >>> x1 = torch.randint(-1, 1, (1, 2), dtype=torch.int32).npu() @@ -73,20 +60,6 @@ class LinearQuant(nn.Module): >>> output = model(x1) >>> print(output.size()) torch.Size(1, 127) - - A8W8 && preblock quantization Examples:: - >>> x1 = torch.randint(-1, 1, (2048, 1024), dtype=torch.float8_e4m3).npu() - >>> x2 = torch.randint(-1, 1, (4096, 1024), dtype=torch.float8_e4m3).npu() - >>> scale = torch.randn((16, 8), dtype=torch.float32).npu() - >>> pertoken_scale = torch.randn((32, 8), dtype=torch.float32).npu() - >>> model = LinearQuant(in_features, out_features, False, group_sizes=[128,128,128]) - >>> model = model.npu() - >>> model.weight.data = x2 - >>> model.scale.data = scale - >>> model.pertoken_scale.data = pertoken_scale - >>> output = model(x1) - >>> print(output.size()) - torch.Size(22048, 4096) """ in_features: int out_features: int @@ -96,21 +69,8 @@ class LinearQuant(nn.Module): pertoken_scale: Tensor bias: Tensor - def __init__(self, - in_features: int, - out_features: int, - *, - bias: bool = True, - offset: bool = False, - pertoken_scale: bool = False, - device=None, - dtype=None, - output_dtype=None, - x1_dtype=None, - x2_dtype=None, - pertoken_scale_dtype=None, - scale_dtype=None, - group_sizes=None) -> None: + def __init__(self, in_features: int, out_features: int, *, bias: bool = True, offset: bool = False, + pertoken_scale: bool = False, device=None, dtype=None, output_dtype=None) -> None: super(LinearQuant, self).__init__() self.in_features = in_features @@ -118,11 +78,6 @@ class LinearQuant(nn.Module): self.weight = Parameter(torch.empty((out_features, in_features)), False) self.scale = Parameter(torch.empty(out_features), False) self.output_dtype = output_dtype - self.x1_dtype = x1_dtype - self.x2_dtype = x2_dtype - self.pertoken_scale_dtype = pertoken_scale_dtype - self.scale_dtype = scale_dtype - self.group_sizes = group_sizes if offset: self.offset = Parameter(torch.empty(out_features, dtype=torch.float32), False) else: @@ -142,30 +97,16 @@ class LinearQuant(nn.Module): scale_quant = self.scale first_last_dim = self.weight.dim() - 1 second_last_dim = self.weight.dim() - 2 + if not ((linear_quant_input.dtype == torch.int32 and self.weight.dtype == torch.int32) or + (linear_quant_input.dtype == torch.int8 and self.weight.dtype == torch.int8)): + raise ValueError("input and weight should be both torch.int32 or both torch.int8 datatype, " + f"but now input is {linear_quant_input.dtype}, weight is {self.weight.dtype}." 
+ ops_error(ErrCode.TYPE)) - is_not_int_input = self.weight.dtype not in [torch.int8, torch.int32] - is_check_dtype_ok = (self.scale.dtype == torch.float32 - and (self.output_dtype != torch.bfloat16 or is_not_int_input) - and self.output_dtype != torch.int32) + is_check_dtype_ok = (self.scale.dtype == torch.float32 and + self.output_dtype not in [torch.bfloat16, torch.int32]) if self.pertoken_scale is None and is_check_dtype_ok: scale_quant = torch_npu.npu_trans_quant_param(self.scale, self.offset) - has_group = (self.group_sizes is not None - and (isinstance(self.group_sizes, list) or isinstance(self.group_sizes, tuple)) - and len(self.group_sizes) == 3 and (self.group_sizes[1] > 1 or self.group_sizes[2] > 1)) - if (scale_quant.dim() > 1 and has_group): - scale_first_last_dim = scale_quant.dim() - 1 - scale_second_last_dim = scale_quant.dim() - 2 - scale_quant = scale_quant.transpose(scale_second_last_dim, scale_first_last_dim) - return torch_npu.npu_quant_matmul(linear_quant_input, - self.weight.transpose(second_last_dim, first_last_dim), - scale_quant, - offset=self.offset, - pertoken_scale=self.pertoken_scale, - bias=self.bias, - output_dtype=self.output_dtype, - x1_dtype=self.x1_dtype, - x2_dtype=self.x2_dtype, - pertoken_scale_dtype=self.pertoken_scale_dtype, - scale_dtype=self.scale_dtype, - group_sizes=self.group_sizes) + return torch_npu.npu_quant_matmul(linear_quant_input, self.weight.transpose(second_last_dim, first_last_dim), + scale_quant, offset=self.offset, pertoken_scale=self.pertoken_scale, bias=self.bias, + output_dtype=self.output_dtype) diff --git a/torch_npu/contrib/module/quant_conv2d.py b/torch_npu/contrib/module/quant_conv2d.py index e7a395e850..1aa59bce43 100644 --- a/torch_npu/contrib/module/quant_conv2d.py +++ b/torch_npu/contrib/module/quant_conv2d.py @@ -6,7 +6,6 @@ from torch.nn.common_types import _size_2_t from torch.nn.parameter import Parameter from torch.nn.modules.utils import _pair import torch_npu -from torch_npu.utils._error_code import ErrCode, ops_error __all__ = ['QuantConv2d'] @@ -109,14 +108,14 @@ class QuantConv2d(nn.Module): self.output_dtype = output_dtype self.weight = \ - Parameter(torch.empty((self.out_channels, self.in_channels // self.groups, *self.kernel_size)), False) + Parameter(torch.empty((self.out_channels, self.in_channels, *self.kernel_size), dtype=torch.int8), False) self.scale = Parameter(torch.empty(self.out_channels, dtype=torch.int64), False) if bias: - self.bias = Parameter(torch.empty(self.out_channels), False) + self.bias = Parameter(torch.empty(self.out_channels, dtype=torch.int32), False) else: self.register_parameter('bias', None) if offset: - raise ValueError("offset must be False" + ops_error(ErrCode.VALUE)) + self.offset = Parameter(torch.empty(out_channels, dtype=torch.float32), False) else: self.register_parameter('offset', None) diff --git a/torch_npu/csrc/custom_dtype/CMakeLists.txt b/torch_npu/csrc/custom_dtype/CMakeLists.txt index bdad67c05d..7d3d7c0e53 100644 --- a/torch_npu/csrc/custom_dtype/CMakeLists.txt +++ b/torch_npu/csrc/custom_dtype/CMakeLists.txt @@ -1,6 +1,6 @@ -FILE(GLOB _TE_SRCS *.cpp) +FILE(GLOB _CUS_DTYPE_SRCS *.cpp) -LIST(APPEND CUS_DTYPE_SRCS ${_TE_SRCS}) +LIST(APPEND CUS_DTYPE_SRCS ${_CUS_DTYPE_SRCS}) # Pass to parent set(CUS_DTYPE_SRCS ${CUS_DTYPE_SRCS} PARENT_SCOPE) diff --git a/torch_npu/onnx/wrapper_onnx_ops.py b/torch_npu/onnx/wrapper_onnx_ops.py index b7b4f40974..68113cb5ef 100644 --- a/torch_npu/onnx/wrapper_onnx_ops.py +++ b/torch_npu/onnx/wrapper_onnx_ops.py @@ -649,16 +649,8 @@ class 
_NPURotaryMulOP(torch.autograd.Function): return torch.ops.npu.npu_rotary_mul(*args, **kwargs) @staticmethod - def symbolic(g, x: Tensor, r1: Tensor, r2: Tensor, rotary_mode: str = "half"): - if rotary_mode == "half": - return g.op("npu::NPURotaryMul", x, r1, r2) - elif rotary_mode == "interleave": - mode = 1 - elif rotary_mode == "quarter": - mode = 2 - elif rotary_mode == "interleave-half": - mode = 3 - return g.op("npu::NPURotaryPositionEmbedding", x, r1, r2, mode_i=mode) + def symbolic(g, x: Tensor, r1: Tensor, r2: Tensor): + return g.op("npu::NPURotaryMul", x, r1, r2) class _NPUPromptFlashAttentionOP(torch.autograd.Function): @@ -738,45 +730,6 @@ class _NPUFusedInferAttentionScoreOP(torch.autograd.Function): key_antiquant_mode, value_antiquant_mode) -class _NPUFusedInferAttentionScoreOP(torch.autograd.Function): - - @staticmethod - def forward(ctx, *args, **kwargs): - return torch.ops.npu.fused_infer_attention_score(*args, **kwargs) - - @staticmethod - def symbolic(g, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - pse_shift: Optional[Tensor], atten_mask: Optional[Tensor], - actual_seq_lengths: Optional[Tensor], - actual_seq_lengths_kv: Optional[Tensor], - dequant_scale1: Optional[Tensor], quant_scale1: Optional[Tensor], - dequant_scale2: Optional[Tensor], quant_scale2: Optional[Tensor], - quant_offset2: Optional[Tensor], antiquant_scale: Optional[Tensor], - antiquant_offset: Optional[Tensor], block_table: Optional[Tensor], - query_padding_size: Optional[Tensor], kv_padding_size: Optional[Tensor], - key_antiquant_scale: Optional[Tensor], key_antiquant_offset: Optional[Tensor], - value_antiquant_scale: Optional[Tensor], value_antiquant_offset: Optional[Tensor], - key_shared_prefix: Optional[Tensor], value_shared_prefix: Optional[Tensor], - actual_shared_prefix_len: Optional[Tensor], - query_rope: Optional[Tensor], - key_rope: Optional[Tensor], - num_heads: int = 1, scale: float = 1.0, - pre_tokens: int = 2147483647, next_tokens: int = 2147483647, - input_layout: str = "BSH", num_key_value_heads: int = 0, - sparse_mode: int = 0, inner_precise: int = 0, block_size: int = 0, - antiquant_mode: int = 0, softmax_lse_flag: bool = False, - key_antiquant_mode: int = 0, value_antiquant_mode: int = 0): - return g.op("npu::NPUFusedInferAttentionScoreOP", self, query, key, value, - pse_shift, atten_mask, actual_seq_lengths, actual_seq_lengths_kv, - dequant_scale1, quant_scale1, dequant_scale2, quant_scale2, quant_offset2, - antiquant_scale, antiquant_offset, block_table, query_padding_size, kv_padding_size, - key_antiquant_scale, key_antiquant_offset, value_antiquant_scale, value_antiquant_offset, - key_shared_prefix, value_shared_prefix, actual_shared_prefix_len, query_rope, key_rope, - num_heads, scale, pre_tokens, next_tokens, input_layout, num_key_value_heads, - sparse_mode, inner_precise, block_size, antiquant_mode, softmax_lse_flag, - key_antiquant_mode, value_antiquant_mode) - - class _NPUMaskedSoftmaxWithRelPosBiasOP(torch.autograd.Function): @staticmethod @@ -1318,8 +1271,8 @@ def _wrapper_npu_mish(self): return _NPUMishOP.apply(self) -def _wrapper_npu_rotary_mul(x, r1, r2, rotary_mode="half"): - return _NPURotaryMulOP.apply(x, r1, r2, rotary_mode) +def _wrapper_npu_rotary_mul(x, r1, r2): + return _NPURotaryMulOP.apply(x, r1, r2) def _wrapper_npu_prompt_flash_attention(self, query, key, value, padding_mask, atten_mask, pse_shift, actual_seq_lengths, @@ -1363,29 +1316,6 @@ def _wrapper_npu_fused_infer_attention_score(self, query, key, value, pse_shift, key_antiquant_mode, 
value_antiquant_mode) -def _wrapper_npu_fused_infer_attention_score(self, query, key, value, pse_shift, atten_mask, actual_seq_lengths, actual_seq_lengths_kv, - dequant_scale1, quant_scale1, dequant_scale2, quant_scale2, quant_offset2, - antiquant_scale, - antiquant_offset, block_table, query_padding_size, kv_padding_size, - num_heads, scale, pre_tokens, next_tokens, input_layout, - key_antiquant_scale, key_antiquant_offset, value_antiquant_scale, value_antiquant_offset, - key_shared_prefix, value_shared_prefix, actual_shared_prefix_len, query_rope, key_rope, - num_key_value_heads, - sparse_mode, inner_precise, block_size, antiquant_mode, softmax_lse_flag, - key_antiquant_mode, value_antiquant_mode): - return _NPUFusedInferAttentionScoreOP.apply(self, query, key, value, pse_shift, atten_mask, actual_seq_lengths, - actual_seq_lengths_kv, - dequant_scale1, quant_scale1, dequant_scale2, quant_scale2, - quant_offset2, antiquant_scale, - antiquant_offset, block_table, query_padding_size, kv_padding_size, - key_antiquant_scale, key_antiquant_offset, value_antiquant_scale, value_antiquant_offset, - key_shared_prefix, value_shared_prefix, actual_shared_prefix_len, query_rope, key_rope, - num_heads, scale, pre_tokens, next_tokens, input_layout, - num_key_value_heads, - sparse_mode, inner_precise, block_size, antiquant_mode, softmax_lse_flag, - key_antiquant_mode, value_antiquant_mode) - - def _wrapper_npu_mm_all_reduce_base(x1, x2, hcom, reduce_op, bias, antiquant_scale, antiquant_offset, x3, dequant_scale, pertoken_scale, comm_quant_scale_1, comm_quant_scale_2, antiquant_group_size, comm_turn): -- Gitee From d83aea44d1ec2d842965153a0095ea150f8c7cec Mon Sep 17 00:00:00 2001 From: chuboning Date: Mon, 16 Jun 2025 09:21:03 +0800 Subject: [PATCH 09/11] Add 910_95 support.Part1 --- torch_npu/__init__.py | 8 +- torch_npu/utils/hif8_tensor.py | 579 --------------------------------- 2 files changed, 1 insertion(+), 586 deletions(-) delete mode 100644 torch_npu/utils/hif8_tensor.py diff --git a/torch_npu/__init__.py b/torch_npu/__init__.py index 74f42a5389..a5e842a2a0 100644 --- a/torch_npu/__init__.py +++ b/torch_npu/__init__.py @@ -1,4 +1,4 @@ -__all__ = ["erase_stream", "matmul_checksum", "HiFloat8Tensor"] +__all__ = ["erase_stream", "matmul_checksum"] import os import sys @@ -61,7 +61,6 @@ from torch_npu.utils.exposed_api import public_npu_functions from torch_npu.distributed.checkpoint.checkpoint import _apply_dcp_patch from torch_npu.npu._stream_check import apply_sanitizer_patch from torch_npu.npu.utils import _erase_stream as erase_stream -from torch_npu.utils.hif8_tensor import HiFloat8Tensor from torch_npu.utils._error_code import ErrCode, pta_error, _except_handler from torch_npu.asd.asd import _asd_patch from torch_npu.asd.checksum import _matmul_checksum as matmul_checksum @@ -91,11 +90,6 @@ for name in dir(torch.ops.npu): __all__.append(name) setattr(torch, name, _wrap_torch_error_func(getattr(torch.ops.npu, name))) -for name in dir(torch_npu._C._cd.DType): - if name.startswith('__') or name in ['_dir', 'name']: - continue - setattr(torch_npu, name, getattr(torch_npu._C._cd.DType, name)) - all_monkey_patches = [ ["nn.functional", npu_functional], ["nn", npu_modules], diff --git a/torch_npu/utils/hif8_tensor.py b/torch_npu/utils/hif8_tensor.py deleted file mode 100644 index 3070aa96fb..0000000000 --- a/torch_npu/utils/hif8_tensor.py +++ /dev/null @@ -1,579 +0,0 @@ -"""Tensor class with HIF8 data""" -from __future__ import annotations - -__all__ = ["HiFloat8Tensor"] - -from typing import Any, 
Dict, Optional, Tuple, Union - -import torch -from torch.utils._pytree import tree_map -import torch_npu -from torch_npu.utils._error_code import ErrCode, pta_error - - -# init transformer engine -torch_npu._C._cd_init() - -tex = torch_npu._C._cd -aten = torch.ops.aten - -NPU_CUSTOM_DType = { - torch.uint8: tex.DType.uint8, - torch.int32: tex.DType.int32, - torch.float32: tex.DType.float32, - torch.half: tex.DType.float16, - torch.bfloat16: tex.DType.bfloat16, -} - - -class _FromHiFloat8Func(torch.autograd.Function): - """Cast from HIF8 to other dtype""" - - @staticmethod - def forward( - _ctx: torch.autograd.function.FunctionCtx, # unused - tensor: HiFloat8Tensor, - dtype: Optional[torch.dtype] = None, - ) -> torch.Tensor: - if dtype is None: - dtype = tensor.dtype - data = tensor._data.contiguous().view(1, -1).detach() - out = tex.cast_from_fp8( - data, - tex.DType.hifloat8, - NPU_CUSTOM_DType[dtype], - ) - out = out.view(tensor.size()) - return out - - @staticmethod - def backward( - _ctx: torch.autograd.function.FunctionCtx, # unused - grad: torch.Tensor, - ) -> Tuple[Optional[torch.Tensor], ...]: - # Assume that we want gradients in full precision - return grad, None - - -class _ToHiFloat8Func(torch.autograd.Function): - """Cast to HIF8 from other dtype""" - - @staticmethod - def forward( - _ctx: torch.autograd.function.FunctionCtx, # unused - tensor: torch.Tensor, - ) -> HiFloat8Tensor: - - # Check input tensor TODO - tensor = tensor.contiguous().npu().detach() - if tensor.dtype not in (torch.float32, torch.bfloat16, torch.float16): - tensor = tensor.float() - - # Cast data to HIF8 - data = tex.cast_to_fp8( - tensor.view(1, -1), - tex.DType.hifloat8, - ) - data = data.view(tensor.size()) - - # Construct HIF8 tensor - return HiFloat8Tensor( - data=data, - dtype=tensor.dtype, - ) - - @staticmethod - def backward( - _ctx: torch.autograd.function.FunctionCtx, # unused - grad: torch.Tensor, - ) -> Tuple[Optional[torch.Tensor], ...]: - # Assume that we want gradients in full precision - return grad, None - - -class _IdentityFunc(torch.autograd.Function): - """Identity function - - If constructor keyword-arguments are provided, then construct a - new HiFloat8Tensor using the provided tensor's attributes. - - """ - - @staticmethod - def forward( - ctx, - tensor: HiFloat8Tensor, - init_kwargs: Optional[Dict[str, Any]] = None, - ) -> torch.Tensor: - - # Return input tensor if constructor kwargs are not provided - ctx.input_dtype = tensor.dtype - if init_kwargs is None: - return tensor - - # Construct new tensor if constructor kwargs are provided - default_kwargs = dict( - data=tensor._data, - dtype=tensor.dtype, - ) - for key, val in default_kwargs.items(): - if key not in init_kwargs: - init_kwargs[key] = val - return HiFloat8Tensor(**init_kwargs) - - @staticmethod - def backward(ctx, grad): - return grad.to(ctx.input_dtype), None - - -class _ViewFunc(torch.autograd.Function): - """View function - - View the HiFloat8Tensor using the provided shape. 
- - """ - - @staticmethod - def forward( - ctx, - tensor: torch.Tensor, - shape: Tuple[int] = None, - ) -> torch.Tensor: - - # Return input tensor if shape is not provided - ctx.shape = tensor.shape - if shape is None: - return tensor - - # Construct new tensor if shape is provided - if isinstance(tensor, HiFloat8Tensor): - return HiFloat8Tensor.make_like( - tensor, - data=tensor._data.view(*shape), - ) - return tensor.view(*shape) - - @staticmethod - def backward( - ctx, - grad: torch.Tensor, - ) -> Tuple[Union[torch.Tensor, None], ...]: - - if isinstance(grad, HiFloat8Tensor): - dgrad = HiFloat8Tensor.make_like( - grad, - data=grad._data.view(ctx.shape), - ) - return dgrad, None - return grad.view(ctx.shape), None - - -class _ReshapeFunc(torch.autograd.Function): - """Reshape function - - Reshape the HiFloat8Tensor using the provided shape. - - """ - - @staticmethod - def forward( - ctx, - tensor: torch.Tensor, - shape: Tuple[int] = None, - ) -> torch.Tensor: - - # Return input tensor if shape is not provided - ctx.shape = tensor.shape - if shape is None: - return tensor - - # Construct new tensor if shape is provided - if isinstance(tensor, HiFloat8Tensor): - return HiFloat8Tensor.make_like( - tensor, - data=tensor._data.reshape(*shape), - ) - return tensor.reshape(*shape) - - @staticmethod - def backward( - ctx, - grad: torch.Tensor, - ) -> Tuple[Union[torch.Tensor, None], ...]: - - if isinstance(grad, HiFloat8Tensor): - dgrad = HiFloat8Tensor.make_like( - grad, - data=grad._data.reshape(ctx.shape), - ) - return dgrad, None - return grad.reshape(ctx.shape), None - - -class _TransposeFunc(torch.autograd.Function): - """Transpose function - - Transpose the HiFloat8Tensor. - - """ - - @staticmethod - def forward(ctx, tensor, dim0, dim1): - ctx.save_for_backward(dim0, dim1) - if isinstance(tensor, HiFloat8Tensor): - return HiFloat8Tensor.make_like( - tensor, - data=tensor._data.transpose(dim0, dim1), - ) - return tensor.transpose(dim0, dim1) - - @staticmethod - def backward(ctx, grad): - dim0, dim1 = ctx.saved_tensors - if isinstance(grad, HiFloat8Tensor): - dgrad = HiFloat8Tensor.make_like( - grad, - data=grad._data.transpose(dim0, dim1), - ) - return dgrad, None - return grad.transpose(dim0, dim1), None, None - - -class HiFloat8Tensor(torch.Tensor): - """Experimental tensor class with HIF8 data - - The tensor presents as having a standard, higher-precision dtype, - but the data itself is (scaled) HIF8. For most tensor operations, - the data will be cast to the nominal dtype before performing the - operation. - - Parameters - ---------- - data: torch.Tensor - Raw HIF8 data in a uint8 tensor - dtype: torch.dtype, default = torch.float32 - Nominal tensor datatype. 
- - """ - - def __new__( - cls, - *, - data: torch.Tensor, - dtype: torch.dtype = torch.float32, - ): - # Check that data buffer is valid - if data.element_size() != 1: - raise ValueError( - f"HiFloat8Tensor requires data buffer with 8-bit dtype (got dtype={data.dtype})" - + pta_error(ErrCode.VALUE) - ) - if data.requires_grad: - raise ValueError( - "HiFloat8Tensor requires non-differentiable data buffer" - + pta_error(ErrCode.VALUE) - ) - if not data.is_npu: - data = data.npu() - - # Initialize tensor object - self = torch.Tensor._make_wrapper_subclass( - cls, - data.size(), - strides=data.stride(), - storage_offset=data.storage_offset(), - dtype=dtype, - layout=data.layout, - requires_grad=data.requires_grad, - device=data.device, - ) - self._data: torch.Tensor = data - - return self - - @classmethod - def make_like( - cls, - tensor: HiFloat8Tensor, - *, - data: torch.Tensor, - **kwargs, - ) -> HiFloat8Tensor: - """Use attributes of a HiFloat8Tensor to create another HiFloat8Tensor - - See constructor for list of keyword arguments. - - """ - default_kwargs = dict( - dtype=tensor.dtype, - ) - for key, val in default_kwargs.items(): - if key not in kwargs: - kwargs[key] = val - return HiFloat8Tensor(data=data, **kwargs) - - def __repr__(self): - return ( - "HiFloat8Tensor(" - f"data={self.from_hifloat8(dtype=self.dtype)}" - ")" - ) - - def from_hifloat8(self, dtype: Optional[torch.dtype] = None) -> torch.Tensor: - """ - Construct PyTorch tensor from HiFloat8Tensor - - By default the resulting tensor's dtype is the - HiFloat8Tensor's nominal dtype. - """ - return _FromHiFloat8Func.apply(self, dtype) - - @classmethod - def to_hifloat8( - cls, - tensor: torch.Tensor - ): - """Construct HiFloat8Tensor from PyTorch tensor""" - return _ToHiFloat8Func.apply( - tensor - ) - - def float(self) -> torch.Tensor: - return self.from_hifloat8(dtype=torch.float32) - - def bfloat16(self) -> torch.Tensor: - return self.from_hifloat8(dtype=torch.bfloat16) - - def half(self) -> torch.Tensor: - return self.from_hifloat8(dtype=torch.float16) - - def cpu(self) -> torch.Tensor: - return self.from_hifloat8().cpu() - - def clone(self) -> HiFloat8Tensor: - return _IdentityFunc.apply(self, {"data": self._data.detach().clone()}) - - def view(self, *shape: Tuple[int]) -> HiFloat8Tensor: - return _ViewFunc.apply(self, shape) - - def reshape(self, *shape: Tuple[int]) -> HiFloat8Tensor: - return _ReshapeFunc.apply(self, shape) - - def contiguous( - self, - *, - memory_format: torch.memory_format = torch.contiguous_format, - ) -> HiFloat8Tensor: - """Returns tensor with data in provided memory format - - Returns `self` if data is already in correct memory format. - - """ - if self._data.is_contiguous(memory_format=memory_format): - return self - return _IdentityFunc.apply( - self, - {"data": self._data.detach().contiguous(memory_format=memory_format)}, - ) - - def to_dtype(self, dtype: torch.dtype) -> HiFloat8Tensor: - """Create `HiFloat8Tensor` with given nominal dtype - - The new tensor has the same underlying HIF8 data. 
- - """ - return HiFloat8Tensor.make_like( - self, - data=self._data, - dtype=dtype, - ) - - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs=None): - - # In-place copy op - if func == aten.copy_.default: - - # Check tensors - dst = args[0] - src = args[1] - if not isinstance(dst, torch.Tensor): - raise RuntimeError( - "Attempted to copy into something that isn't a PyTorch tensor" - + pta_error(ErrCode.TYPE) - ) - if not isinstance(src, torch.Tensor): - raise RuntimeError( - "Attempted to copy from something that isn't a PyTorch tensor" - + pta_error(ErrCode.TYPE) - ) - - # Special handling based on which tensors are HIF8 - dst_is_hif8 = isinstance(dst, HiFloat8Tensor) - src_is_hif8 = isinstance(src, HiFloat8Tensor) - if dst_is_hif8 and src_is_hif8: - # Directly copy HIF8 data if possible - dst._data.copy_(src._data) - - elif not dst_is_hif8 and src_is_hif8: - # Cast source tensor to higher precision - dst.copy_(src.from_hifloat8()) - - elif dst_is_hif8 and not src_is_hif8: - # Make sure input is in expected format - src = src.expand(dst.size()) - src = src.to( - device=dst.device, - memory_format=torch.contiguous_format, - ) - - # Cast to HIF8 - if not dst._data.is_contiguous(): - raise RuntimeError( - "Transformer Engine cast kernels require contiguous data" - + pta_error(ErrCode.INTERNAL) - ) - tex.cast_to_fp8_noalloc( - src.view(1, -1), - dst._data.view(1, -1), - tex.DType.hifloat8, - ) - else: - # Invalid case - raise RuntimeError( - "Using HiFloat8Tensor copy logic, but no HiFloat8Tensor found" - + pta_error(ErrCode.INTERNAL) - ) - - # Nothing to return for in-place ops - return None - - # Slice op - if func == aten.slice.Tensor: - tensor = args[0] - data = tensor._data - data_slice = data.__torch_dispatch__( - func, - types, - [data] + list(args[1:]), - kwargs, - ) - return HiFloat8Tensor.make_like(tensor, data=data_slice) - - # Detach op - if func == aten.detach.default: - # Simply return a new HiFloat8Tensor with the same attrs - return HiFloat8Tensor.make_like( - args[0], - data=args[0]._data, - ) - - # View op - if func == aten.view.default: - tensor = args[0] - data = tensor._data - data_view = data.__torch_dispatch__( - func, - types, - [data] + list(args[1:]), - kwargs, - ) - return HiFloat8Tensor.make_like( - tensor, - data=data_view, - ) - - def maybe_unwrap(t): - if isinstance(t, HiFloat8Tensor): - return t.from_hifloat8() - return t - - def maybe_update_inplace(arg, new_arg, schema_arg): - """Update values of HIF8 tensors - - Keep the same HIF8 scaling factors. 
- - """ - check_args = isinstance(arg, HiFloat8Tensor) and isinstance(new_arg, torch.Tensor) - check_schema = ( - hasattr(schema_arg, "alias_info") - and hasattr(schema_arg.alias_info, "is_write") - and schema_arg.alias_info.is_write - ) - - if check_args and check_schema: - arg.copy_(new_arg) - - # In-place op - if func._schema.is_mutable: - # Cast to higher precision, perform op, and cast values - # back to original HIF8 buffers - new_args = tree_map(maybe_unwrap, args) - new_kwargs = tree_map(maybe_unwrap, kwargs) - schema_args = func._schema.arguments - args_len = len(args) - out = super().__torch_dispatch__(func, types, new_args, new_kwargs) - for arg, new_arg, schema_arg in zip(args, new_args, schema_args): - maybe_update_inplace(arg, new_arg, schema_arg) - for kwarg, new_kwarg, schema_arg in zip(kwargs, new_kwargs, schema_args[args_len:]): - if not (kwarg == new_kwarg == schema_arg.name): - raise ValueError('name of the kw argument should match' + pta_error(ErrCode.VALUE)) - maybe_update_inplace(kwargs[kwarg], new_kwargs[new_kwarg], schema_arg) - return None - - # Default op - # Note: cast to higher precision and perform op - args = tree_map(maybe_unwrap, args) - if kwargs is not None: - kwargs = tree_map(maybe_unwrap, kwargs) - out = super().__torch_dispatch__(func, types, args, kwargs) - return out - - @classmethod - def _make_in_reduce_ex( - cls, - data: torch.Tensor, - dtype: torch.dtype, - ) -> HiFloat8Tensor: - """Build HiFloat8Tensor, for use in __reduce__ - - __reduce_ex__ assumes object constructor has positional - arguments. - - """ - return HiFloat8Tensor( - data=data, - dtype=dtype, - ) - - def __reduce_ex__(self, protocol: int) -> tuple: - """Custom pickling to remove references to HIF8 metadata objects""" - return ( - HiFloat8Tensor._make_in_reduce_ex, - (self._data, self.dtype), - ) - - def _get_data(self) -> HiFloat8Tensor: - """Get tensor data property""" - return super().data - - def _set_data(self, tensor: torch.Tensor) -> None: - """Set tensor data property - - Cast tensor to HIF8 and store in HIF8 buffer. 
- - """ - with torch.no_grad(): - self.copy_(tensor) - - # Cast to HIF8 when setting HiFloat8Tensor.data - data = property(_get_data, _set_data) - - @classmethod - def __torch_function__(cls, func, types, args=(), kwargs=None): - if kwargs is None: - kwargs = {} - return torch._C._disabled_torch_function_impl(func, types, args, kwargs) - - def transpose(self, dim0, dim1): - return _TransposeFunc.apply(self, dim0, dim1) -- Gitee From b3124f211dbc79426fb6b2f87c93f19416c5e599 Mon Sep 17 00:00:00 2001 From: chuboning Date: Thu, 19 Jun 2025 17:21:10 +0800 Subject: [PATCH 10/11] rename --- codegen/utils.py | 6 +++--- test/npu/test_tensors.py | 8 ++++---- torch_npu/csrc/aten/common/FormatCastKernelNpu.cpp | 4 ++-- torch_npu/csrc/core/npu/NpuVariables.cpp | 8 ++++---- torch_npu/csrc/core/npu/NpuVariables.h | 4 ++-- torch_npu/csrc/core/npu/interface/AclInterface.cpp | 2 +- torch_npu/csrc/core/npu/register/OptionRegister.cpp | 6 +++--- 7 files changed, 19 insertions(+), 19 deletions(-) diff --git a/codegen/utils.py b/codegen/utils.py index c603e77fe1..1df2bfcd01 100644 --- a/codegen/utils.py +++ b/codegen/utils.py @@ -401,7 +401,7 @@ const DeviceGuard device_guard(device_or_default(device));""" device_guard = f"const OptionalDeviceGuard device_guard(device_of({device_of}));" op_key = str(f.func.name) - is_ascend910_95_version = "c10_npu::IsAscend910_95Version()" + is_ascend910_xx_version = "c10_npu::IsAscend910_xxVersion()" if enable_opplugin(): if op_key in GLOBAL_STRUCTURED_OP_INFO_CACHE: impl_name = f"op_plugin::{GLOBAL_STRUCTURED_OP_INFO_CACHE[op_key]}" @@ -473,9 +473,9 @@ if (C10_UNLIKELY(at_npu::native::env::CheckOpHookEnable())) {{ if (({force_aclnn} || at_npu::native::env::CheckJitDisable()){tensor_check_str}) {{ return {op_api_impl_name}({args_exprs_str}); }} else {{ - if ({is_ascend910_95_version}) {{ + if ({is_ascend910_xx_version}) {{ TORCH_CHECK(false, - "Ascend910_95 series only support aclnn operator, and current operator {impl_name} do not support internal format.", + "Ascend910_xx series only support aclnn operator, and current operator {impl_name} do not support internal format.", PTA_ERROR(ErrCode::NOT_SUPPORT)); }} return {impl_name}({args_exprs_str}); diff --git a/test/npu/test_tensors.py b/test/npu/test_tensors.py index ff51e258d3..044a7f7d36 100644 --- a/test/npu/test_tensors.py +++ b/test/npu/test_tensors.py @@ -23,11 +23,11 @@ types = [ ] -def skipIfUnsupport910_95(): +def skipIfUnsupport910_xx(): def skip_dec(func): def wrapper(self): - if "Ascend910_95" not in torch_npu.npu.get_device_name(): - return unittest.SkipTest("Device 910_95 condition not satisfied") + if "Ascend910_xx" not in torch_npu.npu.get_device_name(): + return unittest.SkipTest("Device 910_xx condition not satisfied") return func(self) return wrapper return skip_dec @@ -395,7 +395,7 @@ class TestViewOps(TestCase): class TestTensorDtype(TestCase): - @skipIfUnsupport910_95() + @skipIfUnsupport910_xx() def test_fp8(self): tensor1 = torch.randn([2, 2], dtype=torch.float32).npu() tensor2 = torch.randn([2, 2], dtype=torch.float32).npu() diff --git a/torch_npu/csrc/aten/common/FormatCastKernelNpu.cpp b/torch_npu/csrc/aten/common/FormatCastKernelNpu.cpp index 4e683ade9a..f7ead13ccc 100644 --- a/torch_npu/csrc/aten/common/FormatCastKernelNpu.cpp +++ b/torch_npu/csrc/aten/common/FormatCastKernelNpu.cpp @@ -35,7 +35,7 @@ std::tuple> MaybeUseAclnnNpuForma c10_npu::GetAclDataType(customize_dtype.value()) : at_npu::native::OpPreparation::convert_to_acl_data_type(src.scalar_type()); - if 
(c10_npu::IsAscend910_95Version()) { + if (c10_npu::IsAscend910_xxVersion()) { if (aclnnNpuFormatCastExist) { auto api_ret = GetFormat(ConvertType(src), acl_format, customizeAcltype, &dstStorageShape, &dstShapeSize, &dstFormat); @@ -47,7 +47,7 @@ std::tuple> MaybeUseAclnnNpuForma return std::make_tuple(true, dstFormat, outputShape); } TORCH_CHECK(false, - "aclnnNpuFormatCast does not exist, Ascend910_95 series only support aclnn operators.", + "aclnnNpuFormatCast does not exist, Ascend910_xx series only support aclnn operators.", PTA_ERROR(ErrCode::NOT_SUPPORT)); } if (at_npu::native::env::CheckJitDisable()) { diff --git a/torch_npu/csrc/core/npu/NpuVariables.cpp b/torch_npu/csrc/core/npu/NpuVariables.cpp index 4a222171ea..9aaca59b54 100644 --- a/torch_npu/csrc/core/npu/NpuVariables.cpp +++ b/torch_npu/csrc/core/npu/NpuVariables.cpp @@ -48,13 +48,13 @@ void SetSocVersion(const char* const socVersion) SocVersion curSocVersion = SocVersion::UnsupportedSocVersion; std::string inputVersion = socVersion; - std::string ascend95Version = "Ascend910_95"; + std::string ascend95Version = "Ascend910_xx"; auto const& iter = socVersionMap.find(socVersion); if (iter != socVersionMap.end()) { curSocVersion = iter->second; } else if ((inputVersion.compare(0, ascend95Version.size(), ascend95Version) == 0)) { - curSocVersion = SocVersion::Ascend910_95; + curSocVersion = SocVersion::Ascend910_xx; } else { std::string unsupported_soc(socVersion); std::replace(std::begin(unsupported_soc), std::end(unsupported_soc), '_', ' '); @@ -104,9 +104,9 @@ bool IsBF16Supported() return GetSocVersion() >= SocVersion::Ascend910B1; } -bool IsAscend910_95Version() +bool IsAscend910_xxVersion() { - return GetSocVersion() == SocVersion::Ascend910_95; + return GetSocVersion() == SocVersion::Ascend910_xx; } } // namespace c10_npu diff --git a/torch_npu/csrc/core/npu/NpuVariables.h b/torch_npu/csrc/core/npu/NpuVariables.h index b5a55c5f69..6481369f58 100644 --- a/torch_npu/csrc/core/npu/NpuVariables.h +++ b/torch_npu/csrc/core/npu/NpuVariables.h @@ -31,7 +31,7 @@ enum class SocVersion { Ascend910_9382, Ascend910_9372, Ascend910_9362, - Ascend910_95 = 260 + Ascend910_xx = 260 }; void SetSocVersion(const char* const socVersion); @@ -42,7 +42,7 @@ bool IsSupportInfNan(); bool IsBF16Supported(); -bool IsAscend910_95Version(); +bool IsAscend910_xxVersion(); } // namespace c10_npu #endif diff --git a/torch_npu/csrc/core/npu/interface/AclInterface.cpp b/torch_npu/csrc/core/npu/interface/AclInterface.cpp index 520393355b..54190a681d 100644 --- a/torch_npu/csrc/core/npu/interface/AclInterface.cpp +++ b/torch_npu/csrc/core/npu/interface/AclInterface.cpp @@ -846,7 +846,7 @@ bool IsCaptureSupported() static bool default_support_capture = ((GetSocVersion() >= SocVersion::Ascend910B1) && (GetSocVersion() < SocVersion::Ascend310B1)) || ((GetSocVersion() >= SocVersion::Ascend910_9391) && - (GetSocVersion() < SocVersion::Ascend910_95)); + (GetSocVersion() < SocVersion::Ascend910_xx)); if (default_support_capture && !have_load_func) { have_load_func = true; typedef aclError (*AclmdlRICaptureGetInfo)(aclrtStream, aclmdlRICaptureStatus *, aclmdlRI *); diff --git a/torch_npu/csrc/core/npu/register/OptionRegister.cpp b/torch_npu/csrc/core/npu/register/OptionRegister.cpp index 9e0c356a04..0451bfdc14 100644 --- a/torch_npu/csrc/core/npu/register/OptionRegister.cpp +++ b/torch_npu/csrc/core/npu/register/OptionRegister.cpp @@ -85,13 +85,13 @@ OptionInterfaceBuilder::OptionInterfaceBuilder(const std::string &name, ::std::u void SetOption(const std::string 
&key, const std::string &val)
 {
-    if (c10_npu::IsAscend910_95Version()) {
+    if (c10_npu::IsAscend910_xxVersion()) {
         if (key == "jitCompile" && val == "enable") {
-            TORCH_NPU_WARN_ONCE("Ascend910_95 series only support jit_compile=False, ",
+            TORCH_NPU_WARN_ONCE("Ascend910_xx series only support jit_compile=False, ",
                 "the requested value True is invalid and has been reverted to False.");
         }
         if (key == "ALLOW_INTERNAL_FORMAT" && val == "enable") {
-            TORCH_NPU_WARN_ONCE("Ascend910_95 series only support allow_internal_format=False, ",
+            TORCH_NPU_WARN_ONCE("Ascend910_xx series only support allow_internal_format=False, ",
                 "the requested value True is invalid and has been reverted to False.");
         }
         return register_options::OptionRegister::GetInstance()->Set(key, "disable");
--
Gitee

From 652a62d9f297d639ed535e7af0e98743229649d3 Mon Sep 17 00:00:00 2001
From: chuboning
Date: Thu, 19 Jun 2025 14:44:15 +0000
Subject: [PATCH 11/11] !242 Fix SetOption error

* Fix SetOption error
---
 torch_npu/csrc/core/npu/register/OptionRegister.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/torch_npu/csrc/core/npu/register/OptionRegister.cpp b/torch_npu/csrc/core/npu/register/OptionRegister.cpp
index 0451bfdc14..69acfc234f 100644
--- a/torch_npu/csrc/core/npu/register/OptionRegister.cpp
+++ b/torch_npu/csrc/core/npu/register/OptionRegister.cpp
@@ -89,12 +89,13 @@ void SetOption(const std::string &key, const std::string &val)
         if (key == "jitCompile" && val == "enable") {
             TORCH_NPU_WARN_ONCE("Ascend910_xx series only support jit_compile=False, ",
                 "the requested value True is invalid and has been reverted to False.");
+            return register_options::OptionRegister::GetInstance()->Set(key, "disable");
         }
         if (key == "ALLOW_INTERNAL_FORMAT" && val == "enable") {
             TORCH_NPU_WARN_ONCE("Ascend910_xx series only support allow_internal_format=False, ",
                 "the requested value True is invalid and has been reverted to False.");
+            return register_options::OptionRegister::GetInstance()->Set(key, "disable");
         }
-        return register_options::OptionRegister::GetInstance()->Set(key, "disable");
     }
     register_options::OptionRegister::GetInstance()->Set(key, val);
 }
--
Gitee
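For readers following the series, below is a minimal usage sketch of the HiFloat8Tensor wrapper added in torch_npu/utils/hif8_tensor.py. It is illustrative only and not part of any patch in this series; it assumes an Ascend 910_xx device, that the torch_npu._C._cd transformer-engine bindings built above are available, and that the class is imported directly from its module path (an assumption based on the diffstat, not a documented public re-export).

import torch
import torch_npu
# Import path assumed from the file added by this series.
from torch_npu.utils.hif8_tensor import HiFloat8Tensor

x = torch.randn(4, 4, dtype=torch.float32).npu()

# Quantize: the HIF8 payload is stored as uint8, the nominal dtype stays float32.
x_hif8 = HiFloat8Tensor.to_hifloat8(x)

# Ordinary ops are intercepted by __torch_dispatch__, which casts the HIF8 data
# back to the nominal dtype before computing, so y is a regular float32 tensor.
y = x_hif8 + 1.0

# Explicit dequantization when a higher-precision copy is needed.
x_fp32 = x_hif8.from_hifloat8(dtype=torch.float32)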