diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..a4eabe559e088631de8e2a895c925373696b89e0 --- /dev/null +++ b/LICENSE @@ -0,0 +1,28 @@ +BSD 3-Clause License + +Copyright (c) 2025, wumingyang + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/patch/kdnn.patch b/patch/kdnn.patch new file mode 100644 index 0000000000000000000000000000000000000000..c586fe80051018a9444c3d6a137993d3b6990467 --- /dev/null +++ b/patch/kdnn.patch @@ -0,0 +1,2848 @@ +diff --git a/CMakeLists.txt b/CMakeLists.txt +index 98593c2de9..5279402572 100644 +--- a/CMakeLists.txt ++++ b/CMakeLists.txt +@@ -325,6 +325,14 @@ cmake_dependent_option(USE_ITT "Use Intel(R) VTune Profiler ITT functionality" + cmake_dependent_option( + USE_MKLDNN "Use MKLDNN. Only available on x86, x86_64, and AArch64." + "${CPU_INTEL}" "CPU_INTEL OR CPU_AARCH64" OFF) ++cmake_dependent_option( ++ USE_KDNN "Use KDNN. Only available on AArch64." ++ "${CPU_AARCH64}" "CPU_INTEL OR CPU_AARCH64" OFF) ++if(USE_KDNN) ++ set(AT_KDNN_ENABLED 1) ++else() ++ set(AT_KDNN_ENABLED 0) ++endif() + cmake_dependent_option( + USE_MKLDNN_ACL "Use Compute Library for the Arm architecture." 
OFF + "USE_MKLDNN AND CPU_AARCH64" OFF) +diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt +index 6d9152a4d0..75a70b84f5 100644 +--- a/aten/src/ATen/CMakeLists.txt ++++ b/aten/src/ATen/CMakeLists.txt +@@ -85,6 +85,8 @@ file(GLOB mkldnn_xpu_cpp "native/mkldnn/xpu/*.cpp" "native/mkldnn/xpu/detail/*.c + file(GLOB native_cpp "native/*.cpp") + file(GLOB native_mkl_cpp "native/mkl/*.cpp") + file(GLOB native_mkldnn_cpp "native/mkldnn/*.cpp") ++file(GLOB native_kdnn_h "native/kdnn/*.h") ++file(GLOB native_kdnn_cpp "native/kdnn/*.cpp") + file(GLOB vulkan_cpp "vulkan/*.cpp") + file(GLOB native_vulkan_cpp "native/vulkan/*.cpp" "native/vulkan/api/*.cpp" "native/vulkan/impl/*.cpp" "native/vulkan/ops/*.cpp") + +@@ -234,6 +236,9 @@ if(USE_VULKAN) + else() + set(all_cpu_cpp ${all_cpu_cpp} ${vulkan_cpp}) + endif() ++if(AT_KDNN_ENABLED) ++ set(all_cpu_cpp ${all_cpu_cpp} ${native_kdnn_cpp}) ++endif() + + if(USE_XPU) + list(APPEND ATen_XPU_SRCS ${mkldnn_xpu_cpp}) +@@ -570,7 +575,7 @@ install(FILES "${CMAKE_CURRENT_BINARY_DIR}/cmake-exports/ATenConfig.cmake" + + set(INSTALL_HEADERS ${base_h} ${ATen_CORE_HEADERS} ${native_nested_h} ${ATen_TRANSFORMER_HEADERS}) + if(NOT INTERN_BUILD_MOBILE) +- list(APPEND INSTALL_HEADERS ${native_h} ${native_cpu_h} ${native_ao_sparse_h} ${native_quantized_h} ${cuda_h} ${native_cuda_h} ${native_hip_h} ${cudnn_h} ${hip_h} ${xpu_h} ${mps_h} ${native_mps_h} ${native_utils_h} ${miopen_h}) ++ list(APPEND INSTALL_HEADERS ${native_h} ${native_cpu_h} ${native_ao_sparse_h} ${native_quantized_h} ${cuda_h} ${native_cuda_h} ${native_hip_h} ${cudnn_h} ${hip_h} ${xpu_h} ${mps_h} ${native_mps_h} ${native_utils_h} ${miopen_h} ${native_kdnn_h}) + # Metal + if(USE_PYTORCH_METAL_EXPORT) + # Add files needed from exporting metal models(optimized_for_mobile) +diff --git a/aten/src/ATen/Config.h.in b/aten/src/ATen/Config.h.in +index fdd2ac2bc5..4d27d81d90 100644 +--- a/aten/src/ATen/Config.h.in ++++ b/aten/src/ATen/Config.h.in +@@ -19,3 +19,4 @@ + #define AT_PARALLEL_NATIVE @AT_PARALLEL_NATIVE@ + #define AT_BLAS_F2C() @AT_BLAS_F2C@ + #define AT_BLAS_USE_CBLAS_DOT() @AT_BLAS_USE_CBLAS_DOT@ ++#define AT_KDNN_ENABLED() @AT_KDNN_ENABLED@ +diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp +index ff8ceed7e8..d7b3cb6082 100644 +--- a/aten/src/ATen/Context.cpp ++++ b/aten/src/ATen/Context.cpp +@@ -48,6 +48,14 @@ void Context::setUserEnabledMkldnn(bool e) { + enabled_mkldnn = e; + } + ++bool Context::userEnabledKdnn() const { ++ return enabled_kdnn; ++} ++ ++void Context::setUserEnabledKdnn(bool e) { ++ enabled_kdnn = e; ++} ++ + bool Context::deterministicCuDNN() const { + return deterministic_cudnn; + } +@@ -360,6 +368,14 @@ bool Context::hasMKLDNN() { + #endif + } + ++bool Context::hasKDNN() { ++#if AT_KDNN_ENABLED() ++ return true; ++#else ++ return false; ++#endif ++} ++ + bool Context::hasOpenMP() { + #ifdef _OPENMP + return true; +diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h +index 5118009595..49fa914a9a 100644 +--- a/aten/src/ATen/Context.h ++++ b/aten/src/ATen/Context.h +@@ -117,6 +117,7 @@ class TORCH_API Context { + static bool hasMKL(); + static bool hasLAPACK(); + static bool hasMKLDNN(); ++ static bool hasKDNN(); + static bool hasMAGMA() { + return detail::getCUDAHooks().hasMAGMA(); + } +@@ -200,6 +201,8 @@ class TORCH_API Context { + void setUserEnabledCuDNN(bool e); + bool userEnabledMkldnn() const; + void setUserEnabledMkldnn(bool e); ++ bool userEnabledKdnn() const; ++ void setUserEnabledKdnn(bool e); + bool benchmarkCuDNN() const; + 
void setBenchmarkCuDNN(bool); + int benchmarkLimitCuDNN() const; +@@ -408,6 +411,7 @@ class TORCH_API Context { + bool allow_fp16_reduction_cublas = true; + bool allow_bf16_reduction_cublas = true; + bool enabled_mkldnn = true; ++ bool enabled_kdnn = true; + bool enabled_nnpack = true; + at::LinalgBackend linalg_preferred_backend = + c10::utils::check_env("TORCH_LINALG_PREFER_CUSOLVER") == true +@@ -541,6 +545,10 @@ inline bool hasMKLDNN() { + return globalContext().hasMKLDNN(); + } + ++inline bool hasKDNN() { ++ return globalContext().hasKDNN(); ++} ++ + inline void manual_seed(uint64_t seed) { + auto gen = globalContext().defaultGenerator(c10::DeviceType::CPU); + { +diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp +index 2f85c97243..e5ecf5995d 100644 +--- a/aten/src/ATen/native/Convolution.cpp ++++ b/aten/src/ATen/native/Convolution.cpp +@@ -15,6 +15,7 @@ + #include + #include + #include ++#include + + #ifndef AT_PER_OPERATOR_HEADERS + #include +@@ -957,6 +958,14 @@ at::Tensor conv2d_symint( + if (at::isComplexType(input_.scalar_type())) { + output = complex_convolution(input, weight, bias, stride, padding, dilation, false, {{0, 0}}, groups); + } else { ++#if AT_KDNN_ENABLED() ++ bool validate = (at::globalContext().userEnabledKdnn() && groups.expect_int()==1 && at::native::kdnn_validate_utils::isValidateInputTensor(input) && ++ at::native::kdnn_validate_utils::isValidateInputTensor(weight)); ++ if (validate) { ++ output = at::native::kdnn_conv(input, weight, bias, stride, padding, dilation, {{0, 0}}, groups); ++ return is_batched ? std::move(output) : output.squeeze(0); ++ } ++#endif + output = at::convolution_symint(input, weight, bias, stride, padding, dilation, false, {{0, 0}}, groups); + } + return is_batched ? std::move(output) : output.squeeze(0); +@@ -982,6 +991,14 @@ at::Tensor conv3d_symint( + if (at::isComplexType(input_.scalar_type())) { + output = complex_convolution(input, weight, bias, stride, padding, dilation, false, {{0, 0, 0}}, groups); + } else { ++#if AT_KDNN_ENABLED() ++ bool validate = (at::globalContext().userEnabledKdnn() && groups.expect_int()==1 && at::native::kdnn_validate_utils::isValidateInputTensor(input) && ++ at::native::kdnn_validate_utils::isValidateInputTensor(weight)); ++ if (validate) { ++ output = at::native::kdnn_conv(input, weight, bias, stride, padding, dilation, {{0, 0, 0}}, groups); ++ return is_batched ? std::move(output) : output.squeeze(0); ++ } ++#endif + output = at::convolution_symint(input, weight, bias, stride, padding, dilation, false, {{0, 0, 0}}, groups); + } + return is_batched ? 
std::move(output) : output.squeeze(0); +diff --git a/aten/src/ATen/native/Embedding.cpp b/aten/src/ATen/native/Embedding.cpp +index b0c4644e57..363fc9329f 100644 +--- a/aten/src/ATen/native/Embedding.cpp ++++ b/aten/src/ATen/native/Embedding.cpp +@@ -1,4 +1,5 @@ + #define TORCH_ASSERT_ONLY_METHOD_OPERATORS ++#include + #include + #include + #include +@@ -7,6 +8,7 @@ + #include + #include + #include ++#include // kdnn + + #ifndef AT_PER_OPERATOR_HEADERS + #include +@@ -40,6 +42,16 @@ Tensor embedding_symint(const Tensor & weight, const Tensor & indices, + auto indices_arg = TensorArg(indices, "indices", 1); + checkScalarTypes("embedding", indices_arg, {kLong, kInt}); + ++ // 根据AT_KDNN_ENABLED判断是否使用KDNN实现 ++#if AT_KDNN_ENABLED() ++ bool validate = (at::globalContext().userEnabledKdnn() && ++ at::native::kdnn_validate_utils::isValidateInputTensor(weight) && ++ at::native::kdnn_validate_utils::isValidateInputTensor(indices)); ++ if (validate) { ++ return kdnn_embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse); ++ } ++#endif ++ + // TODO: use tensor.index() after improving perf + if (indices.dim() == 1) { + return weight.index_select(0, indices); +diff --git a/aten/src/ATen/native/Linear.cpp b/aten/src/ATen/native/Linear.cpp +index e0af0a6b75..01debde832 100644 +--- a/aten/src/ATen/native/Linear.cpp ++++ b/aten/src/ATen/native/Linear.cpp +@@ -1,4 +1,5 @@ + #define TORCH_ASSERT_ONLY_METHOD_OPERATORS ++#include + #include + #include + #include +@@ -8,6 +9,7 @@ + #include + #include + #include ++#include // kdnn + + #ifndef AT_PER_OPERATOR_HEADERS + #include +@@ -83,6 +85,16 @@ Tensor linear(const Tensor& input, const Tensor& weight, const std::optional::borrowed(*bias_opt) + : c10::MaybeOwned::owned(std::in_place); ++ ++ // test ++#if AT_KDNN_ENABLED() ++ bool validate = (at::globalContext().userEnabledKdnn() && at::native::kdnn_validate_utils::isValidateInputTensor(input) && ++ at::native::kdnn_validate_utils::isValidateInputTensor(weight)); ++ if (validate) { ++ return kdnn_linear(input, weight, *bias); ++ } ++#endif ++ // testend + if (input.is_mkldnn()) { + return at::mkldnn_linear(input, weight, *bias); + } +diff --git a/aten/src/ATen/native/SoftMax.cpp b/aten/src/ATen/native/SoftMax.cpp +index 94090fe079..7a880b202c 100644 +--- a/aten/src/ATen/native/SoftMax.cpp ++++ b/aten/src/ATen/native/SoftMax.cpp +@@ -9,7 +9,6 @@ + #include + #include + #include +- + #ifndef AT_PER_OPERATOR_HEADERS + #include + #include +@@ -32,18 +31,17 @@ + #include + #endif + ++#include ++#include + #include + #include + #include +- + namespace at::meta { + TORCH_META_FUNC(_softmax) + (const Tensor& input, const int64_t dim, const bool half_to_float) { + int64_t dim_ = maybe_wrap_dim(dim, input.dim()); +- + auto output_options = + input.options().memory_format(LEGACY_CONTIGUOUS_MEMORY_FORMAT); +- + if (half_to_float) { + output_options = output_options.dtype(ScalarType::Float); + } +@@ -52,7 +50,6 @@ TORCH_META_FUNC(_softmax) + TORCH_CHECK( + dim_ >= 0 && dim_ < input_dim, + "dim must be non-negative and less than input dimensions"); +- + set_output_raw_strided(0, input.sizes(), {}, output_options); + } + +@@ -337,7 +334,6 @@ TORCH_IMPL_FUNC(softmax_cpu_out) + const int64_t dim, + const bool half_to_float, + const Tensor& output) { +- TORCH_CHECK(!half_to_float, "softmax with half to float conversion is not supported on CPU"); + + if (input.numel() == 0) { + return; +@@ -353,11 +349,30 @@ TORCH_IMPL_FUNC(softmax_cpu_out) + TORCH_CHECK( + dim_ >= 0 && dim_ < input_.dim(), + "dim must be non-negative 
and less than input dimensions"); +- if (input_.ndimension() > 0 && dim_ == input_.ndimension() - 1) { ++ bool is_last_dim = (input_.ndimension() > 0) && (dim_ == input_.ndimension() - 1); ++#if AT_KDNN_ENABLED() ++ bool kdnnUsable = at::globalContext().userEnabledKdnn() && ++ at::native::kdnn_validate_utils::isValidateInputTensor(input_) && ++ at::native::kdnn_validate_utils::isValidateInputTensor(output); ++ if (kdnnUsable) { ++ at::native::kdnn_softmax_kernel(kCPU, output, input_, dim_); ++ } ++ else { ++ if (is_last_dim) { ++ softmax_lastdim_kernel(kCPU, output, input_); ++ } ++ else { ++ softmax_kernel(kCPU, output, input_, dim_); ++ } ++ } ++#else ++ if (is_last_dim) { + softmax_lastdim_kernel(kCPU, output, input_); +- } else { ++ } ++ else { + softmax_kernel(kCPU, output, input_, dim_); + } ++#endif + } + + TORCH_IMPL_FUNC(log_softmax_cpu_out) +@@ -379,12 +394,30 @@ TORCH_IMPL_FUNC(log_softmax_cpu_out) + if (input_.dim() == 0) { + input_ = input_.view(1); + } +- +- if (input_.ndimension() > 0 && dim_ == input_.ndimension() - 1) { ++ bool is_last_dim = (input_.ndimension() > 0) && (dim_ == input_.ndimension() - 1); ++#if AT_KDNN_ENABLED() ++ bool kdnnUsable = at::globalContext().userEnabledKdnn() && ++ at::native::kdnn_validate_utils::isValidateInputTensor(input_) && ++ at::native::kdnn_validate_utils::isValidateInputTensor(output); ++ if (kdnnUsable) { ++ at::native::kdnn_softmax_kernel(kCPU, output, input_, dim_); ++ } ++ else { ++ if (is_last_dim) { ++ log_softmax_lastdim_kernel(kCPU, output, input_); ++ } ++ else { ++ log_softmax_kernel(kCPU, output, input_, dim_); ++ } ++ } ++#else ++ if (is_last_dim) { + log_softmax_lastdim_kernel(kCPU, output, input_); +- } else { ++ } ++ else { + log_softmax_kernel(kCPU, output, input_, dim_); + } ++#endif + } + + TORCH_IMPL_FUNC(softmax_backward_cpu_out) +@@ -451,6 +484,7 @@ Tensor softmax(const Tensor& input_, const int64_t dim_, std::optional dtype) { ++ TORCH_CHECK(self.dim() > 0, "Softmax is not defined for scalars"); + return at::softmax(self, dimname_to_position(self, dim), dtype); + } + + Tensor log_softmax(const Tensor& self, Dimname dim, std::optional dtype) { ++ TORCH_CHECK(self.dim() > 0, "Softmax is not defined for scalars"); + return at::log_softmax(self, dimname_to_position(self, dim), dtype); + } + +diff --git a/aten/src/ATen/native/cpu/Activation.cpp b/aten/src/ATen/native/cpu/Activation.cpp +index 88b43015d9..be47114ad0 100644 +--- a/aten/src/ATen/native/cpu/Activation.cpp ++++ b/aten/src/ATen/native/cpu/Activation.cpp +@@ -1399,7 +1399,6 @@ void prelu_backward_kernel(TensorIterator& iter) { + + } // namespace + +- + REGISTER_DISPATCH(hardsigmoid_stub, &hardsigmoid_kernel); + REGISTER_DISPATCH(hardsigmoid_backward_stub, &hardsigmoid_backward_kernel); + REGISTER_DISPATCH(threshold_stub, &threshold_kernel); +diff --git a/aten/src/ATen/native/group_norm.cpp b/aten/src/ATen/native/group_norm.cpp +index 655b2c1e72..41616b345f 100644 +--- a/aten/src/ATen/native/group_norm.cpp ++++ b/aten/src/ATen/native/group_norm.cpp +@@ -4,12 +4,16 @@ + #include + #include + #include ++#include ++#include + + #ifndef AT_PER_OPERATOR_HEADERS + #include + #include + #else + #include ++#include ++#include + #include + #include + #include +@@ -97,6 +101,39 @@ std::tuple native_group_norm( + const auto dtype = param_scalar_type(X, mixed_type); + Tensor mean = at::empty({N, group}, X.options().dtype(dtype)); + Tensor rstd = at::empty({N, group}, X.options().dtype(dtype)); ++#if AT_KDNN_ENABLED() ++ bool validation = 
at::native::kdnn_validate_utils::isValidateInputTensor(X); ++ auto XDtype = X.scalar_type(); ++ at::TensorOptions options = X.options().dtype(XDtype).device(X.device()); ++ ++ IntArrayRef input_size(X.sizes()); ++ at::Tensor gamma_t, beta_t; ++ if (!gamma_opt.has_value()) { ++ auto dummy_weight = at::ones(input_size, options); ++ gamma_t = dummy_weight; ++ } else { ++ gamma_t = gamma_opt.value(); ++ } ++ ++ if (!beta_opt.has_value()) { ++ auto dummy_bias = at::zeros(input_size, options); ++ beta_t = dummy_bias; ++ } else { ++ beta_t = beta_opt.value(); ++ } ++ ++ bool isUsableDtypeForGNorm = at::native::isSupportedDtypeForNorm(X, gamma_t, beta_t); ++ bool kdnn_usable = at::globalContext().userEnabledKdnn() && validation && isUsableDtypeForGNorm && (!mixed_type) && (memory_format == at::MemoryFormat::Contiguous) && (eps != 0.0); ++ if (kdnn_usable) { ++ float eps_value = static_cast(eps); ++ auto mean_float = mean.to(at::kFloat); ++ auto rstd_float = rstd.to(at::kFloat); ++ at::native::kdnn_group_norm_forward(X, Y, gamma_t, beta_t, eps_value, group, N, C, HxW, mean_float, rstd_float); ++ mean = mean_float.to(dtype); ++ rstd = rstd_float.to(dtype); ++ return std::make_tuple(Y, mean, rstd); ++ } ++#endif + GroupNormKernel( + X.device().type(), X, gamma, beta, N, C, HxW, group, eps, Y, mean, rstd); + return std::make_tuple(Y, mean, rstd); +@@ -202,6 +239,7 @@ Tensor group_norm( + const auto& beta = bias.defined() ? bias.contiguous() : kEmpty; + TORCH_CHECK(!gamma.defined() || gamma.sym_numel() == C); + TORCH_CHECK(!beta.defined() || beta.sym_numel() == C); ++ + return std::get<0>( + at::native_group_norm_symint(X, gamma, beta, N, C, HxW, num_groups, eps)); + } +diff --git a/aten/src/ATen/native/kdnn/dependencies/.gitkeep b/aten/src/ATen/native/kdnn/dependencies/.gitkeep +new file mode 100644 +index 0000000000..e69de29bb2 +diff --git a/aten/src/ATen/native/kdnn/kdnn.cpp b/aten/src/ATen/native/kdnn/kdnn.cpp +new file mode 100644 +index 0000000000..82065e1528 +--- /dev/null ++++ b/aten/src/ATen/native/kdnn/kdnn.cpp +@@ -0,0 +1,92 @@ ++#include ++#include ++#include "kdnn.hpp" ++ ++namespace at::native { ++ namespace kdnn_validate_utils { ++ bool isValueIllegal(const Tensor& tensor) ++ { ++ return tensor.isnan().any().item().to() || tensor.isinf().any().item().to(); ++ } ++ ++ bool isKDNNDTypeUnSupported(const Tensor& tensor) ++ { ++ auto dtype = tensor.scalar_type(); ++ auto it = at::native::type_map.find(dtype); ++ if (it == at::native::type_map.end()) { ++ return true; ++ } ++ return false; ++ } ++ ++ bool isKDNNLayoutUnSupported(const Tensor& tensor) ++ { ++ // check is dense ++ if (tensor.layout() != c10::kStrided) { ++ return true; ++ } ++ ++ auto dim = tensor.dim(); ++ auto it = at::native::layout_map.find(dim); ++ if (it == at::native::layout_map.end()) { ++ return true; ++ } ++ return false; ++ } ++ ++ bool isTensorEmpty(const Tensor& tensor) ++ { ++ if (tensor.numel() == 0) { ++ return true; ++ } ++ return false; ++ } ++ ++ bool isValidateInputTensor(const Tensor& input) ++ { ++ if (isKDNNDTypeUnSupported(input) || isKDNNLayoutUnSupported(input) || isTensorEmpty(input) || isValueIllegal(input)) { ++ return false; ++ } ++ return true; ++ } ++ } ++ ++ ++ KDNN::TensorInfo getKDNNTensor(const Tensor& tensor) ++ { ++ auto dims = tensor.sizes(); ++ std::vector dims_vec(dims.begin(), dims.end()); ++ KDNN::Shape shape(dims_vec.data(), dims.size()); ++ auto dtype = tensor.scalar_type(); ++ ++ KDNN::Element::TypeT type; ++ try { ++ auto it_type = type_map.find(dtype); ++ if (it_type == 
type_map.end()) { ++ throw DtypeNotFoundException(dtype); ++ } ++ type = it_type->second; ++ } catch (const DtypeNotFoundException& e) { ++ std::cerr << "Error: " << e.what() << '\n'; ++ } ++ ++ KDNN::Layout layout; ++ try { ++ auto it_layout = layout_map.find(dims.size()); ++ if (it_layout == layout_map.end()) { ++ throw LayoutNotFoundException(dims.size()); ++ } ++ layout = it_layout->second; ++ } catch (const LayoutNotFoundException& e) { ++ std::cerr << "Error: " << e.what() << '\n'; ++ } ++ return KDNN::TensorInfo(shape, type, layout); ++ } ++ ++ template ++ std::string kdnn_map_value_to_String(const T& value) // 泛型打印kdnn所有类型map中的value类型 ++ { ++ return typeid(value).name(); ++ } ++} +diff --git a/aten/src/ATen/native/kdnn/kdnn.h b/aten/src/ATen/native/kdnn/kdnn.h +new file mode 100644 +index 0000000000..cc8f9c72d5 +--- /dev/null ++++ b/aten/src/ATen/native/kdnn/kdnn.h +@@ -0,0 +1,130 @@ ++#pragma once ++#include ++#include ++#include ++ ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include "kdnn.hpp" ++ ++namespace std { ++ struct KeyType { ++ c10::ScalarType a; ++ c10::ScalarType b; ++ c10::ScalarType c; ++ KeyType(c10::ScalarType a, c10::ScalarType b, c10::ScalarType c) : a(a), b(b), c(c) {} ++ bool operator==(const KeyType& other) const { ++ return a == other.a && b == other.b && c == other.c; ++ } ++ }; ++ ++ template<> ++ struct hash { ++ size_t operator()(const KeyType& key) const { ++ size_t hash_val = 0; ++ std::hash> hasher; ++ hash_val ^= hasher(static_cast>(key.a)) + 0x9e3779b9 + (hash_val << 6) + (hash_val >> 2); ++ hash_val ^= hasher(static_cast>(key.b)) + 0x9e3779b9 + (hash_val << 6) + (hash_val >> 2); ++ hash_val ^= hasher(static_cast>(key.c)) + 0x9e3779b9 + (hash_val << 6) + (hash_val >> 2); ++ return hash_val; ++ } ++ }; ++} ++ ++namespace at::native { ++ ++ constexpr auto kUInt8 = at::kByte; ++ constexpr auto kInt8 = at::kChar; ++ constexpr auto kInt16 = at::kShort; ++ constexpr auto kInt32 = at::kInt; ++ constexpr auto kInt64 = at::kLong; ++ constexpr auto kUInt16 = at::kUInt16; ++ constexpr auto kUInt32 = at::kUInt32; ++ constexpr auto kUInt64 = at::kUInt64; ++ constexpr auto kFloat16 = at::kHalf; ++ constexpr auto kFloat32 = at::kFloat; ++ constexpr auto kFloat64 = at::kDouble; ++ ++ static const std::unordered_map layout_map = { ++ {1, KDNN::Layout::A}, ++ {2, KDNN::Layout::AB}, ++ {3, KDNN::Layout::ABC}, ++ {4, KDNN::Layout::ABCD}, ++ {5, KDNN::Layout::ABCDE} ++ }; ++ ++ static const std::unordered_map type_map = { ++ {kFloat32, KDNN::Element::TypeT::F32}, ++ {kFloat16, KDNN::Element::TypeT::F16}, ++ // {kFloat, KDNN::Element::TypeT::BF16}, 目前不支持BF16 ++ {kInt32, KDNN::Element::TypeT::S32}, ++ {kInt8, KDNN::Element::TypeT::S8}, ++ {kUInt8, KDNN::Element::TypeT::U8}, ++ }; ++ ++ static const std::unordered_map output_type_map = { ++ {std::KeyType(kFloat32, kFloat32, kFloat32), kFloat32}, ++ {std::KeyType(kFloat16, kFloat16, kFloat16), kFloat16}, ++ {std::KeyType(kInt8, kInt8, kInt32), kInt32}, ++ {std::KeyType(kInt8, kUInt8, kInt32), kInt32}, ++ {std::KeyType(kUInt8, kInt8, kInt32), kInt32}, ++ {std::KeyType(kUInt8, kUInt8, kInt32), kInt32}, ++ {std::KeyType(kFloat32, kFloat32, kFloat32), kFloat32}, ++ {std::KeyType(kFloat16, kInt8, kFloat16), kFloat32}, ++ {std::KeyType(kFloat16, kUInt8, kFloat16), kFloat32} ++ }; ++ ++ // utils ++ KDNN::TensorInfo getKDNNTensor(const Tensor& tensor); ++ template ++ std::string kdnn_map_value_to_String(const T& value); ++ // KDNN Validate Utils ++ namespace kdnn_validate_utils { ++ bool 
isKDNNDTypeUnSupported(const Tensor& tensor); ++ bool isKDNNLayoutUnSupported(const Tensor& tensor); ++ bool isTensorEmpty(const Tensor& tensor); ++ bool isValueIllegal(const Tensor& tensor); ++ bool isValidateInputTensor(const Tensor& input); ++ } ++ // softmax ++ void execute_softmax_layer(const Tensor& input, const Tensor& output, KDNN::SoftmaxAlgorithmKind algorithm_kind, int64_t dim); ++ void kdnn_softmax_lastdim_kernel(const DeviceType device_type, const Tensor& ou, const Tensor& in); ++ void kdnn_softmax_kernel(const DeviceType device_type, const Tensor& ou, const Tensor& in, int64_t dim); ++ void kdnn_log_softmax_lastdim_kernel(const DeviceType device_type, const Tensor& ou, const Tensor& in); ++ void kdnn_log_softmax_kernel(const DeviceType device_type, const Tensor& ou, const Tensor& in, int64_t dim); ++ void kdnn_softmax_backward_data_kernel(const DeviceType device_type, const Tensor& ou, const Tensor& in, Tensor grad_ou); ++ ++ // Linear ++ void kdnn_linear_forward(const Tensor& input, const Tensor& weight, Tensor& output, const Tensor& bias); ++ void kdnn_linear_forward(const Tensor& input, const Tensor& weight, Tensor& output); ++ Tensor kdnn_linear(const Tensor& input, const Tensor& weight, const std::optional& bias_opt); ++ ++ ++ // GroupNorm ++ void kdnn_group_norm_forward(const Tensor& input, Tensor& output, const Tensor& gamma, ++ const Tensor& beta, float eps, int64_t group, int64_t N, ++ int64_t C, int64_t HxW, Tensor& mean, Tensor& rstd); ++ bool isSupportedDtypeForNorm(const Tensor& input, const Tensor& gamma, const Tensor& beta); ++ ++ // layerNorm ++ void kdnn_layer_norm_forward(at::Tensor& out, at::Tensor& mean, at::Tensor& rstd, ++ const at::Tensor& input, IntArrayRef normalized_shape, const Tensor& gamma, ++ const Tensor& beta, double eps, int64_t M, int64_t N); ++ ++ //RMS Norm ++ Tensor kdnn_rms_norm_forward(const Tensor& input, IntArrayRef normalized_shape,std::optional eps, const std::optional& weight_opt); ++ ++ // Embedding ++ Tensor kdnn_embedding(const Tensor & weight, const Tensor & indices, ++ c10::SymInt padding_idx, bool scale_grad_by_freq, bool sparse); ++ ++ // Conv ++ Tensor kdnn_conv(const Tensor &input, const Tensor &weight, const Tensor &bias, SymIntArrayRef stride, ++ SymIntArrayRef padding, SymIntArrayRef dilation, SymIntArrayRef output_padding, ++ c10::SymInt groups); ++} +diff --git a/aten/src/ATen/native/kdnn/kdnn_conv.cpp b/aten/src/ATen/native/kdnn/kdnn_conv.cpp +new file mode 100644 +index 0000000000..ecccd2e242 +--- /dev/null ++++ b/aten/src/ATen/native/kdnn/kdnn_conv.cpp +@@ -0,0 +1,144 @@ ++#define TORCH_ASSERT_ONLY_METHOD_OPERATORS ++#include ++#include ++#include ++#include ++#include ++ ++#ifndef AT_PER_OPERATOR_HEADERS ++#include ++#include ++#else ++#include ++#include ++#endif ++#include ++ ++using SizeType = KDNN::SizeType; ++using Shape = KDNN::Shape; ++ ++#if !AT_KDNN_ENABLED() ++ ++namespace at { ++namespace native { ++ ++Tensor kdnn_conv(const Tensor &input, const Tensor &weight, const Tensor &bias, SymIntArrayRef stride, ++ SymIntArrayRef padding, SymIntArrayRef dilation, SymIntArrayRef output_padding, c10::SymInt groups) { ++ TORCH_CHECK(false, "kdnn_conv: ATen not compiled with KDNN support"); ++} ++ ++} // namespace native ++} // namespace at ++#else // AT_KDNN_ENABLED ++ ++namespace at { ++namespace native { ++static const std::unordered_map SupportedDtypeForConv = { ++ {kFloat32, KDNN::Element::TypeT::F32}, ++ {kFloat16, KDNN::Element::TypeT::F16}, ++}; ++ ++static bool isSupportedDtypeForConv(const Tensor& tensor) { 
++    auto dtype = tensor.scalar_type();
++    auto it = at::native::SupportedDtypeForConv.find(dtype);
++    if (it != at::native::SupportedDtypeForConv.end()) {
++        return true;
++    }
++    return false;
++}
++
++static inline std::vector<int64_t> getVec(const SymIntArrayRef& array, int64_t offset = 0) {
++    std::vector<int64_t> result;
++    result.reserve(array.size());
++    for (auto a : array) {
++        result.push_back(a.expect_int() + offset);
++    }
++    return result;
++}
++
++static inline int64_t getOutShapeByPadding(int64_t inSize, int64_t kernelSize, int64_t stride, int64_t dilation_kdnn,
++    int64_t paddingL, int64_t paddingR)
++{
++    int64_t outSize = (inSize + paddingL + paddingR - ((dilation_kdnn + 1) * (kernelSize - 1) + 1)) / stride + 1;
++    return outSize;
++}
++
++static inline std::vector<int64_t> getOutSize(const Tensor &input, const Tensor &weight, std::vector<int64_t> strideVec,
++    Shape paddingL, Shape paddingR, std::vector<int64_t> dilationVec,
++    int64_t dim)
++{
++    if (dim == 2) {
++        // 2D input(N, IC, IH, IW) weight(OC, IC, KH, KW) output(N, OC, OH, OW) bias(OC)
++        int64_t N = input.size(0), OC = weight.size(0), IH = input.size(-2), IW = input.size(-1), KH = weight.size(-2), KW = weight.size(-1);
++        auto OH = getOutShapeByPadding(IH, KH, strideVec[0], dilationVec[0], static_cast<int64_t>(paddingL[0]),
++            static_cast<int64_t>(paddingR[0]));
++        auto OW = getOutShapeByPadding(IW, KW, strideVec[1], dilationVec[1], static_cast<int64_t>(paddingL[1]),
++            static_cast<int64_t>(paddingR[1]));
++        return {N, OC, OH, OW}; // 2D
++    } else if (dim == 3) {
++        // 3D input(N, IC, ID, IH, IW) weight(OC, IC, KD, KH, KW) output(N, OC, OD, OH, OW) bias(OC)
++        int64_t N = input.size(0), OC = weight.size(0), IH = input.size(-2), IW = input.size(-1), KH = weight.size(-2), KW = weight.size(-1);
++        int64_t ID = input.size(2), KD = weight.size(2);
++        auto OD = getOutShapeByPadding(ID, KD, strideVec[0], dilationVec[0], static_cast<int64_t>(paddingL[0]),
++            static_cast<int64_t>(paddingR[0]));
++        auto OH = getOutShapeByPadding(IH, KH, strideVec[1], dilationVec[1], static_cast<int64_t>(paddingL[1]),
++            static_cast<int64_t>(paddingR[1]));
++        auto OW = getOutShapeByPadding(IW, KW, strideVec[2], dilationVec[2], static_cast<int64_t>(paddingL[2]),
++            static_cast<int64_t>(paddingR[2]));
++        return {N, OC, OD, OH, OW}; // 3D
++    } else {
++        TORCH_CHECK(false, "Invalid input dimension. Expected 2 or 3 dimensions.");
++    }
++    return {};
++}
++
++Tensor kdnn_conv(const Tensor &input, const Tensor &weight, const Tensor &bias, SymIntArrayRef stride,
++    SymIntArrayRef padding, SymIntArrayRef dilation, SymIntArrayRef output_padding, c10::SymInt groups)
++{
++    auto dim = output_padding.size(); // 2D or 3D
++    auto inputContig = input.contiguous();
++    auto weightContig = weight.contiguous();
++    auto biasContig = bias.contiguous();
++
++    if (!isSupportedDtypeForConv(inputContig)) {
++        TORCH_CHECK(false, "Unsupported input data type for kdnn_conv.");
++    }
++
++    // convert SymIntArrayRef to kdnn shape
++    std::vector<int64_t> strideVec = getVec(stride);
++    std::vector<int64_t> paddingVec = getVec(padding);
++    std::vector<int64_t> dilationVec = getVec(dilation, -1);
++    Shape strideShape(strideVec.data(), strideVec.size());
++    Shape dilationShape = {dilationVec.data(), dilationVec.size()};
++
++    TORCH_CHECK(paddingVec.size() == dim, "Invalid padding size. Expected ", dim, " elements.");
++
++    // get paddingL paddingR from pytorch padding
++    // 2D paddingL(top, left), paddingR(down, right) from pytorch padding(top=down, left=right)
++    // 3D paddingL(depth_front, top, left), paddingR(depth_back, down, right) from pytorch padding(depth_front=depth_back, top=down, left=right)
++    Shape paddingL(paddingVec.data(), paddingVec.size()), paddingR(paddingVec.data(), paddingVec.size());
++
++    // create output tensor
++    auto outputSize = getOutSize(input, weight, strideVec, paddingL, paddingR, dilationVec, dim);
++    auto output = at::empty(outputSize, input.options());
++    auto outputContig = output.contiguous();
++
++    // create kdnn input, weight, output, bias TensorInfo
++    KDNN::ConvolutionAlgorithm alg(KDNN::ConvolutionAlgorithm::AUTO);
++    KDNN::TensorInfo inputTensor = getKDNNTensor(inputContig);
++    KDNN::TensorInfo weightTensor = getKDNNTensor(weightContig);
++    KDNN::TensorInfo outputTensor = getKDNNTensor(outputContig);
++    KDNN::TensorInfo biasTensor = bias.defined() ?
++        getKDNNTensor(biasContig) :
++        KDNN::TensorInfo(KDNN::Shape(0), outputTensor.GetType(), KDNN::Layout::A);
++    void *biasPtr = bias.defined() ? biasContig.data_ptr() : nullptr;
++
++    KDNN::ConvolutionLayerFWD convFwdLayer(inputTensor, weightTensor, outputTensor, biasTensor, strideShape,
++        dilationShape, paddingL, paddingR, alg);
++    convFwdLayer.Run(inputContig.data_ptr(), weightContig.data_ptr(), outputContig.data_ptr(), biasPtr);
++    return output;
++}
++
++} // namespace native
++} // namespace at
++#endif // AT_KDNN_ENABLED
+diff --git a/aten/src/ATen/native/kdnn/kdnn_embedding.cpp b/aten/src/ATen/native/kdnn/kdnn_embedding.cpp
+new file mode 100644
+index 0000000000..f8a7e830d8
+--- /dev/null
++++ b/aten/src/ATen/native/kdnn/kdnn_embedding.cpp
+@@ -0,0 +1,72 @@
++#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
++#include
++#include
++#include
++#include
++#include // kdnn
++#include
++
++#ifndef AT_PER_OPERATOR_HEADERS
++#include
++#include
++#else
++#include
++#include
++#include
++#endif
++#include
++// #include "kdnn.hpp"
++#include
++
++#if !AT_KDNN_ENABLED()
++
++namespace at {
++namespace native {
++
++Tensor kdnn_embedding(const Tensor &weight, const Tensor &indices, c10::SymInt padding_idx, bool scale_grad_by_freq, bool sparse)
++{
++    TORCH_CHECK(false, "kdnn_embedding: ATen not compiled with KDNN support");
++}
++
++} // namespace native
++} // namespace at
++#else // AT_KDNN_ENABLED
++
++namespace at {
++namespace native {
++
++Tensor kdnn_embedding(const Tensor &weight, const Tensor &indices, c10::SymInt padding_idx, bool scale_grad_by_freq,
++    bool sparse)
++{
++    // Get the attributes of the weight tensor
++    int num_embeddings = weight.size(0);
++    int embedding_dim = weight.size(1);
++    __fp16 *weight_ptr = (__fp16 *)weight.data_ptr();
++
++    // Create an EmbeddingLayerFWD object
++    KDNN::EmbeddingLayerFWD embedding(num_embeddings, embedding_dim, weight_ptr);
++
++    // Process the indices tensor
++    auto indice_data = indices.data_ptr();
++    int num = indices.numel();
++
++    // Call the Run method to get the result pointer
++    __fp16 *out_ptr = embedding.Run(indice_data, num);
++
++    // Get the shape of indices
++    auto indice_shape = indices.sizes();
++    std::vector<int64_t> output_shape(indice_shape.begin(), indice_shape.end());
++    output_shape.push_back(embedding_dim);
++
++    // Convert __fp16* to void*
++    void *out_ptr_void = static_cast<void *>(out_ptr);
++
++    // Create the result tensor
++    at::Tensor output = at::from_blob(out_ptr_void, output_shape, at::kHalf);
++
++    return output;
++}
++
++} // namespace native
++} // namespace at
++#endif // AT_KDNN_ENABLED
+diff --git a/aten/src/ATen/native/kdnn/kdnn_exception.h
b/aten/src/ATen/native/kdnn/kdnn_exception.h +new file mode 100644 +index 0000000000..a5fd968be0 +--- /dev/null ++++ b/aten/src/ATen/native/kdnn/kdnn_exception.h +@@ -0,0 +1,19 @@ ++#include ++#include ++ ++namespace at { ++namespace native { ++ ++class LayoutNotFoundException : public std::runtime_error { ++public: ++ explicit LayoutNotFoundException(int dim_size) ++ : std::runtime_error("Layout not found for " + std::to_string(dim_size) + " dimensions") {} ++}; ++ ++class DtypeNotFoundException : public std::runtime_error { ++public: ++ explicit DtypeNotFoundException(at::ScalarType type) ++ : std::runtime_error("KDNN does not support the data type " + std::string(c10::scalarTypeToTypeMeta(type).name()) + " yet") {} ++}; ++} // native ++} // at +diff --git a/aten/src/ATen/native/kdnn/kdnn_linear.cpp b/aten/src/ATen/native/kdnn/kdnn_linear.cpp +new file mode 100644 +index 0000000000..753d2fa543 +--- /dev/null ++++ b/aten/src/ATen/native/kdnn/kdnn_linear.cpp +@@ -0,0 +1,156 @@ ++#define TORCH_ASSERT_ONLY_METHOD_OPERATORS ++#include ++#include ++#include ++#include ++#include ++ ++#ifndef AT_PER_OPERATOR_HEADERS ++#include ++#include ++#else ++#include ++#include ++#include ++#include ++#endif ++#include ++// #include "kdnn.hpp" ++#include ++ ++#if !AT_KDNN_ENABLED() ++ ++namespace at { ++namespace native { ++ ++Tensor kdnn_linear( ++ const Tensor& self, ++ const Tensor& weight, const std::optional& bias_opt) { ++ TORCH_CHECK(false, "kdnn_linear: ATen not compiled with KDNN support"); ++} ++ ++} // namespace native ++} // namespace at ++#else // AT_KDNN_ENABLED ++ ++ ++namespace at { ++namespace native { ++ ++ ++void kdnn_linear_forward(const Tensor& input, const Tensor& weight, Tensor& output, const Tensor& bias) ++{ ++ auto input_contig = input.contiguous(); ++ auto bias_contig = bias.expand({input.size(0), bias.size(0)}).contiguous(); // broadcast bias to match input size ++ auto weightT_contig = weight.t().contiguous(); ++ ++ auto inputDtype = input_contig.scalar_type(); ++ auto weightDtype = weightT_contig.scalar_type(); ++ auto biasDtype = bias_contig.scalar_type(); ++ ++ // Infer output dtype and shape ++ at::ScalarType outputDtype = kFloat32; ++ auto key_struct = std::KeyType(inputDtype, weightDtype, biasDtype); ++ auto it = output_type_map.find(key_struct); ++ if (it!=output_type_map.end()) { ++ outputDtype = it->second; ++ } ++ ++ std::vector output_size = {input.size(0), weight.size(0)}; ++ at::TensorOptions options = input.options().dtype(outputDtype).device(input.device()); ++ output = at::empty(output_size, options); ++ auto output_contig = output.contiguous(); ++ ++ // check contiguous ++ if (!input_contig.is_contiguous() || !weightT_contig.is_contiguous() || ++ !bias_contig.is_contiguous() || !output_contig.is_contiguous()) { ++ throw std::runtime_error("kdnn_linear_forward: input, weight, bias, output must be contiguous"); ++ } ++ ++ void * input_ptr = input_contig.data_ptr(); ++ void * weight_ptr = weightT_contig.data_ptr(); ++ void * bias_ptr = bias_contig.data_ptr(); ++ void * output_ptr = output_contig.data_ptr(); ++ ++ KDNN::TensorInfo inputTensor = getKDNNTensor(input_contig); ++ KDNN::TensorInfo weightTensor = getKDNNTensor(weightT_contig); ++ KDNN::TensorInfo biasTensor = getKDNNTensor(bias_contig); ++ KDNN::TensorInfo outputTensor = getKDNNTensor(output_contig); ++ ++ KDNN::Gemm gemmLayer(inputTensor, weightTensor, outputTensor, biasTensor); ++ gemmLayer.Run(input_ptr, weight_ptr, output_ptr, bias_ptr); ++ return; ++} ++ ++void kdnn_linear_forward(const Tensor& 
input, const Tensor& weight, Tensor& output) ++{ ++ auto input_contig = input.contiguous(); ++ auto weightT_contig = weight.t().contiguous(); ++ ++ auto inputDtype = input_contig.scalar_type(); ++ auto weightDtype = weightT_contig.scalar_type(); ++ ++ // Infer output dtype and shape ++ at::ScalarType outputDtype = kFloat32; ++ auto key_struct = std::KeyType(inputDtype, weightDtype, inputDtype); // bias is not defined use same type as input ++ auto it = output_type_map.find(key_struct); ++ if (it!=output_type_map.end()) { ++ outputDtype = it->second; ++ } ++ ++ std::vector output_size = {input.size(0), weight.size(0)}; ++ at::TensorOptions options = input.options().dtype(outputDtype).device(input.device()); ++ output = at::empty(output_size, options); ++ auto output_contig = output.contiguous(); ++ ++ // check contiguous ++ if (!input_contig.is_contiguous() || !weightT_contig.is_contiguous() || !output_contig.is_contiguous()) { ++ throw std::runtime_error("kdnn_linear_forward: input, weight, output must be contiguous"); ++ } ++ ++ void * input_ptr = input_contig.data_ptr(); ++ void * weight_ptr = weightT_contig.data_ptr(); ++ void * output_ptr = output_contig.data_ptr(); ++ ++ KDNN::TensorInfo inputTensor = getKDNNTensor(input_contig); ++ KDNN::TensorInfo weightTensor = getKDNNTensor(weightT_contig); ++ KDNN::TensorInfo outputTensor = getKDNNTensor(output_contig); ++ ++ KDNN::Gemm gemmLayer(inputTensor, weightTensor, outputTensor); ++ gemmLayer.Run(input_ptr, weight_ptr, output_ptr); ++ return; ++} ++ ++Tensor kdnn_linear(const Tensor& input, const Tensor& weight, const std::optional& bias_opt) ++{ ++ c10::MaybeOwned bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt); ++ const Tensor& bias = *bias_maybe_owned; ++ const int64_t dim = input.dim(); ++ TORCH_CHECK( ++ input.dim() != 0, ++ "kdnn_linear: input needs to has dim at least 1, input dim ", ++ input.dim()); ++ ++ // reshape first if input dim != 2 and the reshape will cost a memory copy. ++ auto input_reshaped = ++ dim == 2 ? 
input : input.reshape({-1, input.size(input.dim() - 1)}); ++ ++ at::Tensor output; ++ if (bias.defined()) { ++ kdnn_linear_forward(input_reshaped, weight, output, bias); ++ } else { ++ kdnn_linear_forward(input_reshaped, weight, output); ++ } ++ ++ if (dim != 2) { ++ auto input_size = input.sizes(); ++ std::vector output_size(input_size.begin(), input_size.end() - 1); ++ output_size.push_back(weight.size(0)); ++ return output.reshape(output_size); ++ } ++ return output; ++} ++ ++} // namespace native ++} // namespace at ++#endif // AT_KDNN_ENABLED +diff --git a/aten/src/ATen/native/kdnn/kdnn_norm.cpp b/aten/src/ATen/native/kdnn/kdnn_norm.cpp +new file mode 100644 +index 0000000000..bdfbf263a7 +--- /dev/null ++++ b/aten/src/ATen/native/kdnn/kdnn_norm.cpp +@@ -0,0 +1,242 @@ ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++ ++#ifndef AT_PER_OPERATOR_HEADERS ++#include ++#else ++#include ++#include ++#endif ++ ++namespace at::native{ ++using scalar_t = float; ++static const std::unordered_map SupportedDtypeForNorm = { ++ {ScalarType::Float, KDNN::Element::TypeT::F32}, ++ {ScalarType::Half, KDNN::Element::TypeT::F16}, ++}; ++ ++static bool is_all_zero(const torch::Tensor& tensor) { ++ if (tensor.numel() == 0) return false; // 空张量直接返回false ++ return torch::all(tensor == 0).item(); ++} ++ ++static bool is_all_one(const torch::Tensor& tensor) { ++ if (tensor.numel() == 0) return false; ++ return torch::all(tensor == 1).item(); ++} ++ ++bool isSupportedDtypeForNorm(const Tensor& input, const Tensor& gamma, const Tensor& beta) ++{ ++ at::ScalarType dtypes[] = {input.scalar_type(), gamma.scalar_type(), beta.scalar_type()}; ++ for (at::ScalarType dtype : dtypes) { ++ if (SupportedDtypeForNorm.find(dtype) == SupportedDtypeForNorm.end()) { ++ return false; ++ } ++ } ++ return true; ++} ++ ++static KDNN::NormalizationFlags checkScaleAndShiftUsableThenSetFlags(const Tensor& scale, const Tensor& shift) ++{ ++ bool usable_scale = scale.numel() != 0; ++ bool usable_shift = shift.numel() != 0; ++ bool is_default_data = is_all_zero(scale) && is_all_one(shift); ++ if (!usable_scale && !usable_shift) { ++ return KDNN::NormalizationFlags::NONE; ++ } ++ if (is_default_data) { ++ return KDNN::NormalizationFlags::NONE; ++ } ++ if (!usable_scale && usable_shift) { ++ return KDNN::NormalizationFlags::USE_SHIFT; ++ } ++ if (!usable_shift && usable_scale) { ++ return KDNN::NormalizationFlags::USE_SCALE; ++ } ++ return KDNN::NormalizationFlags::USE_SCALE | KDNN::NormalizationFlags::USE_SHIFT; ++} ++ ++void kdnn_group_norm_forward( ++ const Tensor& input, ++ Tensor& output, ++ const Tensor& gamma, ++ const Tensor& beta, ++ float eps, ++ int64_t group, ++ int64_t N, ++ int64_t C, ++ int64_t HxW, ++ Tensor& mean, ++ Tensor& rstd) ++{ ++ TORCH_CHECK(input.numel() == N * C * HxW); ++ TORCH_CHECK(gamma.numel() == C, "size of weight should be equal to number of channels"); ++ ++ KDNN::NormalizationFlags flags = checkScaleAndShiftUsableThenSetFlags(gamma, beta); ++ KDNN::TensorInfo srcInfo = at::native::getKDNNTensor(input); ++ KDNN::TensorInfo dstInfo = at::native::getKDNNTensor(output); ++ auto weightType = SupportedDtypeForNorm.find(gamma.scalar_type())->second; ++ KDNN::TensorInfo weightBiasInfo = {{static_cast(C)}, weightType, KDNN::Layout::A}; ++ ++ KDNN::GroupNormalizationLayerFWD gnormLayer(srcInfo, weightBiasInfo, group, dstInfo, flags); // declare gnorm layer ++ 
gnormLayer.Run(input.contiguous().data_ptr(), output.contiguous().data_ptr(), ++ gamma.contiguous().data_ptr(), beta.contiguous().data_ptr(), ++ static_cast(mean.contiguous().data_ptr()), static_cast(rstd.contiguous().data_ptr()), ++ true, eps); ++} ++ ++void kdnn_layer_norm_forward(at::Tensor& out, at::Tensor& mean, at::Tensor& rstd, ++ const at::Tensor& input, IntArrayRef normalized_shape, const Tensor& gamma, ++ const Tensor& beta, double eps, int64_t M, int64_t N) ++{ ++ float eps_val = static_cast(eps); // convert eps to float ++ // Get the length of the normalized shape ++ int64_t normalized_shape_len = normalized_shape.size(); ++ // Get the sizes of the input tensor ++ auto InputSizes = input.sizes(); ++ ++ // Calculate the number of dimensions of the input tensor minus the length of the normalized shape ++ int64_t dims = input.ndimension() - normalized_shape_len; ++ ++ // Create a vector to store the input shape ++ std::vector input_shape(InputSizes.begin(), InputSizes.begin() + dims); ++ ++ // Create a vector to store the shape to reshape ++ std::vector shape_to_reshape = input_shape; ++ // Calculate the number of elements to normalize and add it to the reshape vector ++ shape_to_reshape.push_back(N); ++ ++ // Create a KDNN::Shape object to store the information of dimensions doesn't need to normalize ++ KDNN::Shape statInfoShape(input_shape.data(), input_shape.size()); ++ ++ // Find the layout information ++ KDNN::Layout layout = layout_map.find(dims)->second; ++ ++ // Check if gamma and beta are usable and set the normalization flags ++ KDNN::NormalizationFlags flags = checkScaleAndShiftUsableThenSetFlags(gamma, beta); ++ ++ // Create a view of the input tensor for normalization ++ Tensor inRef = input.reshape(shape_to_reshape); ++ ++ // Get the TensorInfo for the input tensor ++ KDNN::TensorInfo srcInfo = at::native::getKDNNTensor(inRef); ++ ++ // Create a view of the output tensor for normalization ++ out = out.reshape(shape_to_reshape); ++ ++ // Get the TensorInfo for the output tensor ++ KDNN::TensorInfo dstInfo = at::native::getKDNNTensor(out); ++ ++ // Create views for gamma and beta for normalization ++ Tensor gammaRef = gamma.reshape({static_cast(N)}); ++ Tensor betaRef = beta.reshape({static_cast(N)}); ++ ++ //Create views for mean and rstd for nomalization ++ mean = mean.reshape(input_shape); ++ rstd = rstd.reshape(input_shape); ++ // Find the weight type ++ const auto weightType = SupportedDtypeForNorm.find(gamma.scalar_type())->second; ++ // Create TensorInfo for weights and biases ++ KDNN::TensorInfo weightBiasInfo = {{static_cast(N)}, weightType, KDNN::Layout::A}; ++ KDNN::TensorInfo statInfo = {statInfoShape, weightType, layout}; ++ ++ // Create the layer normalization layer ++ KDNN::NormalizationLayerFWD lnormLayer(srcInfo, statInfo, weightBiasInfo, dstInfo, flags); ++ ++ // Execute the layer normalization operation ++ lnormLayer.Run(inRef.contiguous().data_ptr(), out.contiguous().data_ptr(), gammaRef.contiguous().data_ptr(), ++ betaRef.contiguous().data_ptr(), static_cast(mean.contiguous().data_ptr()), static_cast(rstd.contiguous().data_ptr()), ++ true, eps_val); ++ //reshape back out, mean and rstd then expand mean and rstd into n dims ++ IntArrayRef reshapeBack(InputSizes); ++ out = out.reshape(reshapeBack); ++ mean = mean.reshape({static_cast(M)}); ++ rstd = rstd.reshape({static_cast(M)}); ++ DimVector stat_shape; ++ for (const auto idx : c10::irange(dims)) { ++ stat_shape.emplace_back(input.sizes()[idx]); ++ } ++ for (const auto idx C10_UNUSED : 
c10::irange(dims, input.dim())) { ++ stat_shape.emplace_back(1); ++ } ++ mean = mean.reshape(stat_shape); ++ rstd = rstd.reshape(stat_shape); ++} ++ ++Tensor kdnn_rms_norm_forward(const Tensor& input, IntArrayRef normalized_shape,std::optional eps, const std::optional& weight_opt) ++{ ++ int64_t normalized_shape_len = normalized_shape.size(); ++ int64_t unnormalized_shape_len = input.dim() - normalized_shape_len; ++ auto origin_shape = input.sizes(); ++ std::vector new_input_shape(origin_shape.begin(), origin_shape.begin() + unnormalized_shape_len); ++ uint64_t N = 1; ++ for( const auto idx : c10::irange(normalized_shape_len)){ ++ N *= normalized_shape[idx]; ++ } ++ new_input_shape.push_back(N); ++ at::Tensor X = input.reshape(new_input_shape); ++ auto rmsDtype = ScalarType::Half; ++ at::TensorOptions options = X.options().dtype(rmsDtype).device(X.device()); ++ ++ IntArrayRef input_size(X.sizes()); ++ at::Tensor Y = at::zeros(input_size, options); ++ at::Tensor gamma; ++ if (!weight_opt.has_value()) { ++ auto dummy_weight = at::ones({X.size(X.dim() - 1)}, options); ++ gamma = dummy_weight; ++ } else { ++ gamma = weight_opt.value(); ++ gamma = gamma.reshape({static_cast(N)}); ++ gamma = gamma.to(rmsDtype); ++ } ++ scalar_t eps_val; ++ if (!eps.has_value()) { ++ eps_val = std::numeric_limits::type>::epsilon(); ++ } else { ++ eps_val = static_cast(eps.value()); ++ } ++ ++ auto X_contig = X.contiguous(); ++ auto Y_contig = Y.contiguous(); ++ auto gamma_contig = gamma.contiguous(); ++ ++ void *X_ptr = X_contig.data_ptr(); ++ void *Y_ptr = Y_contig.data_ptr(); ++ void *gamma_ptr = gamma_contig.data_ptr(); ++ ++ KDNN::TensorInfo srcInfo = getKDNNTensor(X_contig); ++ KDNN::TensorInfo dstInfo = getKDNNTensor(Y_contig); ++ KDNN::TensorInfo weightInfo = getKDNNTensor(gamma_contig); ++ KDNN::NormalizationFlags flags = KDNN::NormalizationFlags::USE_SCALE; ++ uint64_t statSize = 1; ++ for( const auto idx : c10::irange(unnormalized_shape_len)){ ++ statSize *= input.sizes()[idx]; ++ } ++ new_input_shape.pop_back(); ++ KDNN::Shape statShape(new_input_shape.data(), new_input_shape.size()); ++ float *variance = (float *)malloc(sizeof(float) * statSize); ++ KDNN::TensorInfo statInfo = {statShape, KDNN::Element::TypeT::F16, layout_map.find(unnormalized_shape_len) -> second}; ++ KDNN::RMSNormalizationLayerFWD rmsNormLayer(srcInfo, statInfo, weightInfo, dstInfo, flags); ++ rmsNormLayer.Run(X_ptr, Y_ptr, gamma_ptr, variance, false, eps_val); ++ free(variance); ++ gamma.reshape(normalized_shape); ++ IntArrayRef output_size(origin_shape); ++ return Y.reshape(output_size); ++} ++} +\ No newline at end of file +diff --git a/aten/src/ATen/native/kdnn/kdnn_softmax.cpp b/aten/src/ATen/native/kdnn/kdnn_softmax.cpp +new file mode 100644 +index 0000000000..9dd129b0bb +--- /dev/null ++++ b/aten/src/ATen/native/kdnn/kdnn_softmax.cpp +@@ -0,0 +1,31 @@ ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++ ++namespace at::native { ++using scalar_t = float; ++void execute_softmax_layer(const Tensor& input, const Tensor& output, KDNN::SoftmaxAlgorithmKind algorithm_kind, int64_t dim) { ++ KDNN::TensorInfo in_info = getKDNNTensor(input); ++ KDNN::TensorInfo ou_info = getKDNNTensor(output); ++ KDNN::SoftmaxLayerFWD softmaxLayerFwd(in_info, ou_info, dim, algorithm_kind); ++ softmaxLayerFwd.Run(input.data_ptr(), output.contiguous().data_ptr()); ++} ++ ++void kdnn_softmax_kernel(const DeviceType device_type, const Tensor& ou, const Tensor& in, int64_t dim) { ++ 
TORCH_CHECK(dim >= 0 && dim < in.dim(), "dim must be non-negative and less than input dimensions"); ++ execute_softmax_layer(in, ou, KDNN::SoftmaxAlgorithmKind::SOFTMAX, dim); ++} ++ ++void kdnn_log_softmax_kernel(const DeviceType device_type, const Tensor& ou, const Tensor& in, int64_t dim) { ++ TORCH_CHECK(dim >= 0 && dim < in.dim(), "dim must be non-negative and less than input dimensions"); ++ execute_softmax_layer(in, ou, KDNN::SoftmaxAlgorithmKind::LOGSOFTMAX, dim); ++} ++} +\ No newline at end of file +diff --git a/aten/src/ATen/native/layer_norm.cpp b/aten/src/ATen/native/layer_norm.cpp +index b11bcaba38..0ea91da9fe 100644 +--- a/aten/src/ATen/native/layer_norm.cpp ++++ b/aten/src/ATen/native/layer_norm.cpp +@@ -6,6 +6,10 @@ + #include + #include + #include ++#include ++#if AT_KDNN_ENABLED() ++#include ++#endif + + #ifndef AT_PER_OPERATOR_HEADERS + #include +@@ -42,6 +46,20 @@ static void layer_norm_with_mean_rstd_out( + double eps, + int64_t M, + int64_t N) { ++#if AT_KDNN_ENABLED() ++ bool validation = at::native::kdnn_validate_utils::isValidateInputTensor(input) && at::native::kdnn_validate_utils::isValidateInputTensor(gamma) && at::native::kdnn_validate_utils::isValidateInputTensor(beta); ++ bool isUsableDtypeForLNorm = at::native::isSupportedDtypeForNorm(input, gamma, beta); ++ bool mixed_type = is_mixed_type(input, gamma, beta); ++ bool kdnn_usable = at::globalContext().userEnabledKdnn() && validation && isUsableDtypeForLNorm && (!mixed_type) && (eps != 0.0); ++ if(kdnn_usable) { ++ auto mean_float = mean.to(at::kFloat); ++ auto rstd_float = rstd.to(at::kFloat); ++ at::native::kdnn_layer_norm_forward(out, mean_float, rstd_float, input, normalized_shape, gamma, beta, eps, M, N); ++ mean = mean_float.to(input.scalar_type()); ++ rstd = rstd_float.to(input.scalar_type()); ++ return; ++ } ++#endif + LayerNormKernel(kCPU, input, gamma, beta, M, N, eps, &out, &mean, &rstd); + const auto input_shape = input.sizes(); + const size_t axis = input.dim() - normalized_shape.size(); +@@ -282,6 +300,15 @@ Tensor rms_norm( + } + IntArrayRef dims_to_reduce_ref = IntArrayRef(dims_to_reduce); + ++#if AT_KDNN_ENABLED() ++ bool validation = at::native::kdnn_validate_utils::isValidateInputTensor(input); ++ bool isFP16 = input.scalar_type() == at::ScalarType::Half; ++ bool is_kdnn_usable = at::globalContext().userEnabledKdnn() && validation && isFP16; ++ if (is_kdnn_usable) { ++ return at::native::kdnn_rms_norm_forward(input, normalized_shape, eps, weight_opt); ++ } ++#endif ++ + auto result = AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, +diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt +index 9be7f3732f..06b8bd40d0 100644 +--- a/caffe2/CMakeLists.txt ++++ b/caffe2/CMakeLists.txt +@@ -756,6 +756,8 @@ endif() + + if(NOT BUILD_LIBTORCHLESS) + add_library(torch_cpu ${Caffe2_CPU_SRCS}) ++ ++ + if(HAVE_SOVERSION) + set_target_properties(torch_cpu PROPERTIES + VERSION ${TORCH_VERSION} SOVERSION ${TORCH_SOVERSION}) +@@ -1585,6 +1587,20 @@ endif() + + target_link_libraries(torch_cpu PRIVATE flatbuffers) + ++set(KDNN_ROOT_DIR "/usr/local/kdnn" CACHE PATH "Path to KDNN") ++set(BLAS_ROOT_DIR "${PROJECT_SOURCE_DIR}/aten/src/ATen/native/kdnn/dependencies" CACHE PATH "Path to BLAS") ++ ++include("${PROJECT_SOURCE_DIR}/cmake/Modules/FindKDNN.cmake") ++if (KDNN_FOUND) ++ include_directories(${KDNN_INCLUDE_DIRS}) ++ message(STATUS "GTest Library: ${KDNN_LIBRARIES}") ++ message(STATUS "GTest Library headers: ${KDNN_INCLUDE_DIRS}") ++endif () 
++target_include_directories(torch_cpu PUBLIC ${KDNN_INCLUDE_DIRS}) ++target_link_libraries(torch_cpu PUBLIC ${KDNN_LIBRARIES}) ++ ++get_target_property(LINK_LIBS torch_cpu INTERFACE_LINK_LIBRARIES) ++ + # Note [Global dependencies] + # Some libraries (e.g. OpenMPI) like to dlopen plugins after they're initialized, + # and they assume that all of their symbols will be available in the global namespace. +diff --git a/cmake/Modules/FindKDNN.cmake b/cmake/Modules/FindKDNN.cmake +new file mode 100644 +index 0000000000..415f18f1c6 +--- /dev/null ++++ b/cmake/Modules/FindKDNN.cmake +@@ -0,0 +1,55 @@ ++# Distributed under the OSI-approved BSD 3-Clause License. See accompanying ++# file Copyright.txt or https://cmake.org/licensing for details. ++ ++# ---------- ++# FindKPL ++# ---------- ++# ++# This module defines the following variables: ++# ++# KDNN_FOUND - True if KPL was found ++# KPL_INCLUDE_DIRS - include directories for KPL ++# KPL_LIBRARIES - link against this library to use KPL ++# ++ ++# Use KDNN_ROOT_DIR environment variable to find the library and headers ++find_path(KDNN_INCLUDE_DIR ++ NAMES kdnn.hpp ++ PATHS ${KDNN_ROOT_DIR} ++ PATH_SUFFIXES include ++ NO_DEFAULT_PATH ++) ++set(AARCH64_CODEGEN_PATH "$ENV{KDNN_SRC_DIRS}/src/aarch64_codegen/") ++ ++find_library(KDNN_LIBRARY ++ NAMES kdnn ++ PATHS ${KDNN_ROOT_DIR} ++ PATH_SUFFIXES lib ++ NO_DEFAULT_PATH ++ ) ++ ++find_library(KPL_BLAS_LIBRARY ++ NAMES kblas ++ PATHS ${BLAS_ROOT_DIR} ++ NO_DEFAULT_PATH ++) ++ ++include(FindPackageHandleStandardArgs) ++find_package_handle_standard_args(KDNN DEFAULT_MSG ++ KDNN_INCLUDE_DIR ++ KDNN_LIBRARY ++ KPL_BLAS_LIBRARY ++) ++ ++mark_as_advanced( ++ KDNN_LIBRARY ++ KPL_BLAS_LIBRARY ++ KDNN_INCLUDE_DIR ++) ++ ++# Find the extra libraries and include dirs ++if(KDNN_FOUND) ++ list(APPEND KDNN_INCLUDE_DIRS ${KDNN_INCLUDE_DIR}) ++ list(APPEND KDNN_LIBRARIES ++ ${KDNN_LIBRARY} ${KPL_BLAS_LIBRARY}) ++endif() +\ No newline at end of file +diff --git a/copy_kml.sh b/copy_kml.sh +new file mode 100644 +index 0000000000..9ad38fd729 +--- /dev/null ++++ b/copy_kml.sh +@@ -0,0 +1,103 @@ ++#!/bin/sh ++ ++# Default parameter initialization ++KML_PATH="" ++TARGET_DIR="./aten/src/ATen/native/kdnn/dependencies" ++MATCH=sme ++MATCH_OPT="" ++COMPILER=gcc ++COMPILER_OPT="" ++ ++# Function to rename the library using patchelf ++function rename_lib() { ++ if [ -f "${TARGET_DIR}"/libkblas.so ]; then ++ patchelf --set-soname libkblas.so ${TARGET_DIR}/libkblas.so ++ else ++ exit 1 ++ fi ++} ++ ++# Function to print help information ++function print_help() { ++ echo "[Usage] sh copy_kml.sh --xxx=xxx" ++ echo "Parameters:" ++ echo "--path=/path/to/HPCKit | absolute_path_HPCKit " ++ echo "--match=[sme(default)|sve] | sme or sve " ++ echo "--compiler=[gcc(default)|clang] | g++ or clang " ++ return ++} ++ ++# Parse command line arguments ++for arg in "$@"; do ++ case "$arg" in ++ --path=*) ++ KML_PATH="${arg#*=}" ++ ;; ++ --match=*) ++ MATCH_OPT="${arg#*=}" ++ ;; ++ --compiler=*) ++ COMPILER_OPT="${arg#*=}" ++ ;; ++ --help|-h) ++ print_help ++ exit ++ ;; ++ *) ++ echo "Unknown parameter: $arg" ++ exit 1 ++ ;; ++ esac ++done ++ ++# Update MATCH based on MATCH_OPT ++if [ "${MATCH_OPT}" = "sve" ]; then ++ MATCH=sve ++fi ++ ++# Update COMPILER based on COMPILER_OPT ++if [ "${COMPILER_OPT}" = "clang" ]; then ++ COMPILER=clang ++fi ++ ++# Construct the KML_PATH ++KML_PATH=${KML_PATH}"/HPCKit/latest/kml/"${COMPILER}"/lib/"${MATCH}"/kblas/multi" ++ ++# Check if the necessary parameter is provided ++if [ -z "$KML_PATH" ]; then ++ echo 
"Error: Source directory path must be specified using --path=" ++ exit 1 ++fi ++ ++# Check if the source directory exists ++if [ ! -d "$KML_PATH" ]; then ++ echo "Error: Directory does not exist - $KML_PATH" ++ exit 1 ++fi ++ ++# Create the target directory if it does not exist ++mkdir -p "$TARGET_DIR" ++if [ $? -ne 0 ]; then ++ echo "Error: Unable to create target directory - $TARGET_DIR" ++ exit 1 ++fi ++ ++# Execute the copy operation ++echo "Copying *.so files from $KML_PATH to $TARGET_DIR ..." ++cp -v "$KML_PATH"/*.so "$TARGET_DIR"/ ++ ++# Rename the library using patchelf ++rename_lib ++ ++# Add the path of kblas.so to the LD_LIBRARY_PATH environment variable ++export LD_LIBRARY_PATH=${KML_PATH}:${LD_LIBRARY_PATH} ++echo "Added to LD_LIBRARY_PATH" ++echo ${LD_LIBRARY_PATH} ++ ++# Check the result of the copy operation ++if [ $? -eq 0 ]; then ++ echo "Copy completed!" ++else ++ echo "Error: There was a problem during the copy process, please check the path and permissions" ++ exit 1 ++fi +\ No newline at end of file +diff --git a/test/test_embedding.py b/test/test_embedding.py +new file mode 100644 +index 0000000000..f1f960781b +--- /dev/null ++++ b/test/test_embedding.py +@@ -0,0 +1,52 @@ ++import torch ++ ++def test_1d_index_select(): ++ # 设置随机种子以保证结果可重复 ++ torch.manual_seed(42) ++ ++ # 创建一个embedding层,输入维度为5,输出维度为3 ++ embedding = torch.nn.Embedding(5, 3) ++ ++ # 创建输入索引(1维情况) ++ input_indices = torch.tensor([0, 2, 4]) ++ ++ # 使用embedding层获取输出 ++ output = embedding(input_indices) ++ ++ # 手动计算预期结果 ++ weight = embedding.weight ++ expected_output = torch.index_select(weight, 0, input_indices) ++ ++ # 验证结果 ++ assert torch.allclose(output, expected_output), "1D index select test failed" ++ print("1D index select test passed") ++ ++def test_multi_dimensional_index_select(): ++ # 设置随机种子以保证结果可重复 ++ torch.manual_seed(42) ++ ++ # 创建一个embedding层,输入维度为5,输出维度为3 ++ embedding = torch.nn.Embedding(5, 3) ++ ++ # 创建输入索引(多维情况) ++ input_indices = torch.tensor([[0, 2], ++ [1, 4]]) ++ ++ # 使用embedding层获取输出 ++ output = embedding(input_indices) ++ ++ # 手动计算预期结果 ++ weight = embedding.weight ++ expected_output = torch.stack([ ++ torch.index_select(weight, 0, input_indices[0]), ++ torch.index_select(weight, 0, input_indices[1]) ++ ], dim=0) ++ ++ # 验证结果 ++ assert torch.allclose(output, expected_output), "Multi-dimensional index select test failed" ++ print("Multi-dimensional index select test passed") ++ ++if __name__ == "__main__": ++ test_1d_index_select() ++ test_multi_dimensional_index_select() ++ print("All tests passed!") +diff --git a/test/test_kdnn_conv.py b/test/test_kdnn_conv.py +new file mode 100644 +index 0000000000..6c02e0d83a +--- /dev/null ++++ b/test/test_kdnn_conv.py +@@ -0,0 +1,186 @@ ++import torch ++import torch.nn as nn ++import os ++os.environ["OMP_NUM_THREADS"] = "36" ++ ++def create_config(case_name, dtype, batch_size, in_channels, out_channels, kernel_size, stride, padding, input_size_h, input_size_w, bias, depth_size=None): ++ config = { ++ "case_name": case_name, ++ "dtype": dtype, ++ "batch_size": batch_size, ++ "in_channels": in_channels, ++ "out_channels": out_channels, ++ "kernel_size": kernel_size, ++ "stride": stride, ++ "padding": padding, ++ "input_size_h": input_size_h, ++ "input_size_w": input_size_w, ++ "bias": bias ++ } ++ if depth_size is not None: ++ config["depth_size"] = depth_size ++ return config ++ ++# Conv2D Configurations ++conv2d_cases = [ ++ ("test_conv2d_with_bias", torch.float32, 2, 3, 16, 3, 2, 1, 32, 32, True), ++ ("test_conv2d_no_bias", torch.float32, 2, 3, 
16, 3, 2, 1, 32, 32, False),
++    ("test_conv2d_fp16", torch.float16, 2, 3, 16, 3, 2, 1, 360, 640, True),
++    ("test_conv2d_fp16", torch.float16, 2, 3, 16, 3, 2, 1, 720, 1280, True)
++]
++
++conv2d_case_list = [create_config(*args) for args in conv2d_cases]
++
++# Conv3D Configurations
++conv3d_cases = [
++    ("test_conv3d_with_bias", torch.float32, 2, 3, 16, 3, 2, 1, 32, 32, True, 2),
++    ("test_conv3d_no_bias", torch.float32, 2, 3, 16, 3, 2, 1, 32, 32, False, 2),
++    ("test_conv3d_fp16", torch.float16, 2, 3, 16, 3, 2, 1, 360, 640, True, 2),
++    ("test_conv3d_fp16", torch.float16, 2, 3, 16, 3, 2, 1, 720, 1280, True, 2)
++]
++
++conv3d_case_list = [create_config(*args) for args in conv3d_cases]
++
++#######################################################################################
++
++def test_conv2d(config):
++    # Re-seed the RNG before each initialization
++    torch.manual_seed(42)
++    conv = nn.Conv2d(
++        in_channels=config["in_channels"],
++        out_channels=config["out_channels"],
++        kernel_size=config["kernel_size"],
++        stride=config["stride"],
++        padding=config["padding"],
++        bias=config["bias"]
++    ).to(dtype=config["dtype"])
++    # Fix the weights
++    nn.init.xavier_normal_(conv.weight)
++    if config["bias"]:
++        nn.init.constant_(conv.bias, 0.1)
++
++    batch_size = config["batch_size"]
++    in_channels = config["in_channels"]
++    input_size_h = config["input_size_h"]
++    input_size_w = config["input_size_w"]
++    # Generate a random input
++    with torch.no_grad():
++        x = torch.randn(batch_size, in_channels, input_size_h, input_size_w).to(dtype=config["dtype"])
++        # Convolution output (KDNN or native, depending on the backend flag)
++        output = conv(x)
++    return output
++
++def test_conv3d(config):
++    # Re-seed the RNG before each initialization
++    torch.manual_seed(42)
++    conv = nn.Conv3d(
++        in_channels=config["in_channels"],
++        out_channels=config["out_channels"],
++        kernel_size=config["kernel_size"],
++        stride=config["stride"],
++        padding=config["padding"],
++        bias=config["bias"]
++    ).to(dtype=config["dtype"])
++    # Fix the weights
++    nn.init.xavier_normal_(conv.weight)
++    if config["bias"]:
++        nn.init.constant_(conv.bias, 0.1)
++
++    batch_size = config["batch_size"]
++    depth_size = config["depth_size"]
++    in_channels = config["in_channels"]
++    input_size_h = config["input_size_h"]
++    input_size_w = config["input_size_w"]
++    # Generate a random input
++    with torch.no_grad():
++        x = torch.randn(batch_size, in_channels, depth_size, input_size_h, input_size_w).to(dtype=config["dtype"])
++        # Convolution output (KDNN or native, depending on the backend flag)
++        output = conv(x)
++    return output
++
++def compare_traditional(kdnn, native):
++    # Compare the results
++    diff = (native - kdnn).abs()
++    max_error = diff.max().item()
++    mean_error = diff.mean().item()
++    rmse_error = torch.sqrt(torch.mean(diff**2)).item()
++    return max_error, mean_error, rmse_error
++
++def compare_kdnn(kdnn, native):
++    diff = torch.abs(native - kdnn)
++    error = diff.sum() / diff.numel()
++    return error
++
++def run(config, test_conv):
++    # execute conv
++    try:
++        torch._C._set_kdnn_enabled(True)
++        output_kdnn = test_conv(config)
++        torch._C._set_kdnn_enabled(False)
++        output_native = test_conv(config)
++    except RuntimeError as e:
++        print("config list:")
++        for key, value in config.items():
++            print(f"{key}:{value}")
++        print(e)
++        return
++    # initialize variables
++    case_name = config["case_name"]
++    dtype = config["dtype"]
++    batch_size = config["batch_size"]
++    depth_size = config.get("depth_size", None)  # depth is only defined for 3D
++    in_channels = config["in_channels"]
++    out_channels = config["out_channels"]
++    kernel_size = config["kernel_size"]
++    stride = config["stride"]
++    padding = config["padding"]
++    input_size_h = config["input_size_h"]
++    input_size_w = 
config["input_size_w"] ++ bias = config["bias"] ++ ++ kdnn_shape = output_kdnn.shape ++ native_shape = output_native.shape ++ max_error, mean_error, rmse_error = compare_traditional(output_kdnn, output_native) ++ ave_error = compare_kdnn(output_kdnn, output_native) ++ # format output ++ print(f""" ++ {'=' * 50} ++ [ 测试用例: {case_name}] ++ {'=' * 50} ++ ++ [ 参数配置 ] ++ - dtype : {dtype} ++ - batch_size : {batch_size} ++ - depth_size : {depth_size} ++ - in_channels : {in_channels} ++ - out_channels : {out_channels} ++ - kernel_size : {kernel_size} ++ - stride : {stride} ++ - padding : {padding} ++ - input_size : {input_size_h}x{input_size_w} ++ - bias : {bias} ++ ++ [ 输出形状验证 ] ++ - KDNN 卷积形状 : {kdnn_shape} ++ - 原生卷积形状 : {native_shape} ++ {'✔' if kdnn_shape == native_shape else '✘'} 形状一致性验证通过 ++ ++ [ 前20个元素对比 ] ++ KDNN : {', '.join(f'{v:.4f}' for v in output_kdnn.flatten()[:20])} ++ Native: {', '.join(f'{v:.4f}' for v in output_native.flatten()[:20])} ++ ++ [ 差异统计 ] ++ - 最大绝对误差 (MAX AE) : {max_error:.6e} ++ - 平均绝对误差 (MAE) : {mean_error:.6e} ++ - 均方根误差 (RMSE) : {rmse_error:.6e} ++ - KDNN平均误差 (AVEE) : {ave_error:.6e} ++ {'✔' if ave_error < 1e-3 else '✘'} 数值一致性验证通过 ++ {'=' * 50} ++ """) ++ ++for case in conv2d_case_list: ++ run(case, test_conv2d) ++ ++for case in conv3d_case_list: ++ run(case, test_conv3d) ++ +diff --git a/test/test_kdnn_group_norm.py b/test/test_kdnn_group_norm.py +new file mode 100644 +index 0000000000..348949348f +--- /dev/null ++++ b/test/test_kdnn_group_norm.py +@@ -0,0 +1,119 @@ ++import torch ++import torch.nn as nn ++import torch.nn.functional as F ++import time ++error_threshold = 1e-3 ++ ++def compare_kdnn(kdnn, native): ++ diff = torch.abs(native - kdnn) ++ error = diff.sum() / diff.numel() ++ return error ++ ++def test_fp32(size, groups, channels): ++ global pass_num ++ global fail_num ++ print("\n===========fp32 start test===========") ++ print(f"size: {size}") ++ ++ # run F32 ++ data = torch.randn(size, dtype = torch.float32) ++ ++ # Enable KDNN ++ torch._C._set_kdnn_enabled(True) ++ ++ start = time.time() ++ gn = nn.GroupNorm(num_groups=groups, num_channels=channels) ++ x_normalized_1 = gn(data) ++ end = time.time() ++ print(f'kdnn run time: {end - start} seconds') ++ #print(x_normalized_1) ++ ++ # Disable KDNN ++ torch._C._set_kdnn_enabled(False) ++ ++ start = time.time() ++ gn = nn.GroupNorm(num_groups=groups, num_channels=channels) ++ x_normalized_2 = gn(data) ++ end = time.time() ++ print(f'origin run time: {end - start} seconds') ++ #print(x_normalized_2) ++ ++ max_err = torch.max(torch.abs(x_normalized_1 - x_normalized_2)) ++ err = compare_kdnn(x_normalized_1, x_normalized_2) ++ print(f"F32 the max absolute err of gnorm between kdnn and native result is: {max_err.item()}") ++ print(f"FP32 average absolute error: {err.item():.10f}") ++ if err > error_threshold: ++ print("Test Failed") ++ fail_num += 1 ++ else: ++ print("Test Passed") ++ pass_num += 1 ++ print("====================================") ++ ++def test_fp16(size, groups, channels): ++ global pass_num ++ global fail_num ++ print("\n===========fp16 start test===========") ++ print(f"size: {size}") ++ ++ # run F16 ++ data = torch.randn(size, dtype = torch.half) ++ ++ # Enable KDNN ++ torch._C._set_kdnn_enabled(True) ++ ++ start = time.time() ++ gn = nn.GroupNorm(num_groups=groups, num_channels=channels, dtype = torch.half) ++ x_normalized_3 = gn(data) ++ end = time.time() ++ print(f'kdnn run time: {end - start} seconds') ++ #print(x_normalized_3) ++ ++ # Disable KDNN ++ torch._C._set_kdnn_enabled(False) 
++ ++ start = time.time() ++ gn = nn.GroupNorm(num_groups=groups, num_channels=channels, dtype = torch.half) ++ x_normalized_4 = gn(data) ++ end = time.time() ++ print(f'origin run time: {end - start} seconds') ++ #print(x_normalized_4) ++ ++ max_err = torch.max(torch.abs(x_normalized_3 - x_normalized_4)) ++ print(f"F16 the max absolute err of gnorm between kdnn and native result is: {max_err.item()}") ++ err = compare_kdnn(x_normalized_3, x_normalized_4) ++ print(f"FP16 average absolute error: {err.item():.10f}\n") ++ if err > error_threshold: ++ print("Test Failed") ++ fail_num += 1 ++ else: ++ print("Test Passed") ++ pass_num += 1 ++ print("====================================") ++ ++global pass_num ++global fail_num ++pass_num = 0 ++fail_num = 0 ++ ++if __name__ == "__main__": ++ # fp32 ++ test_fp32((32, 32, 224, 224), 8, 32) ++ test_fp32((4, 128, 18, 320), 2, 128) ++ test_fp32((1, 32, 10, 90, 16), 4, 32) ++ test_fp32((1, 16, 20, 12, 8), 4, 16) ++ test_fp32((2, 8, 256, 256, 72), 4, 8) ++ test_fp32((2, 8, 50, 50, 8), 4, 8) ++ ++ # fp16 ++ test_fp16((32, 32, 224, 224), 8, 32) ++ test_fp16((4, 128, 18, 320), 2, 128) ++ test_fp16((1, 32, 10, 90, 16), 4, 32) ++ test_fp16((1, 16, 20, 12, 8), 4, 16) ++ test_fp16((2, 8, 256, 256, 72), 4, 8) ++ test_fp16((2, 8, 50, 50, 8), 4, 8) ++ ++ print("====================================") ++ print("TEST PASSED:", pass_num) ++ print("TEST FAILED:", fail_num) ++ print("====================================") +diff --git a/test/test_kdnn_layer_norm.py b/test/test_kdnn_layer_norm.py +new file mode 100644 +index 0000000000..c6cdc968f6 +--- /dev/null ++++ b/test/test_kdnn_layer_norm.py +@@ -0,0 +1,126 @@ ++import torch ++import torch.nn as nn ++import torch.nn.functional as F ++import time ++error_threshold = 1e-3 ++def compare_kdnn(kdnn, native): ++ diff = torch.abs(native - kdnn) ++ error = diff.sum() / diff.numel() ++ return error ++ ++def bench_fp32(size): ++ global pass_num ++ global fail_num ++ print("\n===========start test===========") ++ print(f"size: {size}") ++ # 输入数据:[batch_size, channels, height, width] ++ # data = torch.randn(32, 32, 224, 224, dtype=torch.float32) ++ data = torch.randn(size, dtype=torch.float32) ++ # --- 启用 KDNN --- ++ torch._C._set_kdnn_enabled(True) ++ start = time.time() ++ ln = nn.LayerNorm((size[-2], size[-1])) # LayerNorm 需要指定 normalized_shape ++ ln(data) ++ x_normalized_1 = ln(data) ++ end = time.time() ++ print(f'KDNN (FP32) run time: {end - start:.10f} seconds') ++ ++ # --- 禁用 KDNN --- ++ torch._C._set_kdnn_enabled(False) ++ start = time.time() ++ ln = nn.LayerNorm((size[-2], size[-1])) ++ x_normalized_2 = ln(data) ++ end = time.time() ++ print(f'Native (FP32) run time: {end - start:.10f} seconds') ++ ++ # 计算误差 ++ err = torch.max(torch.abs(x_normalized_1 - x_normalized_2)) ++ print(f"FP32 max absolute error: {err.item():.10f}") ++ ++ err = compare_kdnn(x_normalized_1,x_normalized_2) ++ print(f"FP32 average absolute error: {err.item():.10f}") ++ if err > error_threshold: ++ print("Test Failed") ++ fail_num += 1 ++ else: ++ print("Test Passed") ++ pass_num += 1 ++ print("====================================") ++ ++def bench_fp16(size): ++ global pass_num ++ global fail_num ++ print("\n===========start test===========") ++ print(f"size: {size}") ++ # 输入数据(FP16) ++ data = torch.randn(size, dtype=torch.half) ++ ++ # --- 启用 KDNN --- ++ torch._C._set_kdnn_enabled(True) ++ print("KDNN", torch._C._get_kdnn_enabled()) ++ start = time.time() ++ ln = nn.LayerNorm((size[-2], size[-1]),eps=1e-5) ++ ln.weight.data=ln.weight.data.half() 
++    ln.bias.data=ln.bias.data.half()
++    print("ln.input.data.dtype: ",data.dtype)
++    print("ln.weight.data.dtype: ",ln.weight.data.dtype)
++    print("ln.bias.data.dtype: ",ln.bias.data.dtype)
++    x_normalized_3 = ln(data)
++    end = time.time()
++    print(f'KDNN (FP16) run time: {end - start:.10f} seconds')
++
++    # --- Disable KDNN ---
++    torch._C._set_kdnn_enabled(False)
++    print("KDNN", torch._C._get_kdnn_enabled())
++    start = time.time()
++    ln = nn.LayerNorm((size[-2], size[-1]),eps=1e-5)
++    ln.weight.data=ln.weight.data.half()
++    ln.bias.data=ln.bias.data.half()
++    x_normalized_4 = ln(data)
++    end = time.time()
++    print(f'Native (FP16) run time: {end - start:.10f} seconds')
++
++    # Compute the error
++    err = torch.max(torch.abs(x_normalized_3 - x_normalized_4))
++    print(f"FP16 max absolute error: {err.item():.10f}")
++
++    err = compare_kdnn(x_normalized_3,x_normalized_4)
++    print(f"FP16 average absolute error: {err.item():.10f}\n")
++    if err > error_threshold:
++        print("Test Failed")
++        fail_num += 1
++    else:
++        print("Test Passed")
++        pass_num += 1
++    print("====================================")
++
++global pass_num
++global fail_num
++pass_num = 0
++fail_num = 0
++
++if __name__ == "__main__":
++    print("=== Testing LayerNorm (FP32) ===")
++    bench_fp32((2,256,256,144))
++    bench_fp32((4,8,10,10))
++    bench_fp32((2,3,4))
++    bench_fp32((5,8,25,30))
++    bench_fp32((3,5,10,25,25))
++    bench_fp32((2,10,512,512))
++    bench_fp32((256, 256, 144))
++    bench_fp32((1,8,256,256,72))
++
++    print("\n=== Testing LayerNorm (FP16) ===")
++    bench_fp16((2,256,256,144))
++    bench_fp16((4,8,10,10))
++    bench_fp16((2,3,4))
++    bench_fp16((5,8,25,30))
++    bench_fp16((3,5,10,25,25))
++    bench_fp16((2,10,512,512))
++    bench_fp16((256, 256, 144))
++    bench_fp16((1,8,256,256,72))
++
++    print("====================================")
++    print("TEST PASSED:", pass_num)
++    print("TEST FAILED:", fail_num)
++    print("====================================")
+\ No newline at end of file
+diff --git a/test/test_kdnn_linear/compare.py b/test/test_kdnn_linear/compare.py
+new file mode 100644
+index 0000000000..ae97e13005
+--- /dev/null
++++ b/test/test_kdnn_linear/compare.py
+@@ -0,0 +1,102 @@
++import torch
++import torch.nn as nn
++from config import *
++
++global pass_num
++global failed_num
++pass_num = 0
++failed_num = 0
++
++def elementwise_error_check_kdnn(tensor1, tensor2, ref_max, K, threshold=0.001):
++    """
++    Args:
++        tensor1 (Tensor): first input tensor
++        tensor2 (Tensor): second input tensor
++        threshold (float): error threshold, 0.001 by default
++
++    Returns:
++        error_mask (BoolTensor): boolean mask, True where the element-wise error exceeds the threshold
++        error_count (int): total number of elements exceeding the threshold
++    """
++    # Shape consistency check
++    assert tensor1.shape == tensor2.shape, "Tensor shapes do not match"
++    eps = threshold * K
++
++    # Compute the relative error
++    diff = torch.abs(tensor1 - tensor2)
++    e = diff / ref_max
++    if tensor1.dtype == torch.int32:
++        eps = K / 350 + 1
++        e = diff
++    # Generate the boolean mask
++    error_mask = e > eps
++
++    # Count the elements that exceed the threshold
++    error_count = torch.sum(error_mask).item()
++
++    return error_mask, error_count
++
++def compare(tensor_kdnn, tensor_ref, K, i):
++    global pass_num
++    global failed_num
++    # Compare the results
++    diff = (tensor_ref - tensor_kdnn).abs()
++    print("#"*40)
++    print(f"Native output shape: {tensor_ref.shape}")
++    print(f"KDNN output shape: {tensor_kdnn.shape}")
++    print("\nDifference statistics:")
++    print(f"Max absolute error : {diff.max().item():.6f}")
++    print(f"Mean absolute error: {diff.mean().item():.6f}")
++    print(f"RMSE               : {torch.sqrt(torch.mean(diff**2)).item():.6f}")
++    ref_abs = torch.abs(tensor_ref)
++    ref_max = torch.max(ref_abs)
++    print("ref_max={}".format(ref_max))
++    mask, count = elementwise_error_check_kdnn(tensor_kdnn, 
tensor_ref, ref_max, K, 0.001) ++ print("误差掩码矩阵:\n", mask.flatten()[:20]) ++ print("超标元素总数:", count) ++ if count == 0 : ++ print("========================================") ++ print(caseNameList[i], "Pass") ++ print("========================================") ++ pass_num += 1 ++ else : ++ print("========================================") ++ print(caseNameList[i], "Failed") ++ print("========================================") ++ failed_num += 1 ++ print("#"*40) ++ ++def run(batches, out_features, in_features): ++ for i in range(len(dtypeList)): ++ print("compare", caseNameList[i], "M ", batches, "N ", out_features, "K ", in_features) ++ head = caseNameList[i]+"_"+str(batches)+"_"+str(out_features)+"_"+str(in_features) ++ # compare ++ output_kdnn_id = head+"_output_kdnn" ++ output_kdnn = kdnn_result_data[output_kdnn_id] ++ ++ output_native_id = head+"_output_native" ++ output_native = native_result_data[output_native_id] ++ compare(output_kdnn, output_native, in_features, i) ++ ++# read kdnn data ++kdnn_result_data = torch.load('linear_outputtensors_kdnn.pt') ++ ++# read native data ++native_result_data = torch.load('linear_outputtensors_native.pt') ++ ++for batches in batches_list: ++ for N_K in NK_List: ++ N = N_K[0] ++ K = N_K[1] ++ run(batches, N, K) ++ ++for shape in specified_shape: ++ batches = shape[0] ++ N = shape[1] ++ K = shape[2] ++ run(batches, N, K) ++ ++print("#"*40) ++print("pass_num:", pass_num) ++print("failed_test:", failed_num) ++print("#"*40) +\ No newline at end of file +diff --git a/test/test_kdnn_linear/config.py b/test/test_kdnn_linear/config.py +new file mode 100644 +index 0000000000..7ff4e71254 +--- /dev/null ++++ b/test/test_kdnn_linear/config.py +@@ -0,0 +1,35 @@ ++import torch ++import torch.nn as nn ++ ++src_Idx = 0 ++wei_Idx = 1 ++dst_Idx = 2 ++bias_Idx = 3 ++ ++dtypeList = [ ++ [torch.float16, torch.float16, torch.float16, torch.float16], ++ [torch.float16, torch.float16, torch.float32, torch.float32], ++ [torch.int8, torch.int8, torch.int32, torch.int32], # w8a8 ++ [torch.float16, torch.int8, torch.int32, torch.float32], # w8a16 ++ [torch.float32, torch.float32, torch.float32, torch.float32], ++] ++ ++caseNameList = [ ++ "fp16_fp16_fp16_fp16", ++ "FP16_FP16_FP32_FP32", ++ "INT8_INT8_INT32_INT32", ++ "FP16_INT8_INT32_FP32", ++ "FP32_FP32_FP32_FP32" ++] ++ ++batches_list = [20, 200, 1000, 10000, 20000] ++NK_List=[ ++ [5,5], ++ [20,10], ++ [100,200] ++] ++ ++specified_shape = [ ++ [4096, 32768, 512], ++ [4096, 3072, 5120] ++] +\ No newline at end of file +diff --git a/test/test_kdnn_linear/generate.py b/test/test_kdnn_linear/generate.py +new file mode 100644 +index 0000000000..9994f971a4 +--- /dev/null ++++ b/test/test_kdnn_linear/generate.py +@@ -0,0 +1,64 @@ ++import torch ++import torch.nn as nn ++from config import * ++ ++data_input = {} ++ ++def generate(batches, out_features, in_features): ++ for i in range(len(dtypeList)): ++ print("=========================") ++ print("test", caseNameList[i], "M ", batches, "N ", out_features, "K ", in_features) ++ ++ generator = torch.Generator().manual_seed(42) ++ input_tensor = torch.empty(batches, in_features) ++ input_tensor.requires_grad = False ++ ++ # 不能随机生成char类型,所以先统一随机生成fp32类型 ++ input_tensor = torch.randn(batches, in_features, generator=generator, requires_grad=False) # fp32 ++ weight_tensor = torch.randn(out_features, in_features, generator=generator, requires_grad=False) # fp32 ++ bias_tensor = torch.randn(out_features, generator=generator, requires_grad=False) # fp32 ++ src_dtype = dtypeList[i][src_Idx] 
++ wei_dtype = dtypeList[i][wei_Idx] ++ bias_dtype = dtypeList[i][bias_Idx] ++ ++ if src_dtype == torch.int8: ++ input_tensor = torch.randint(-127, 127, (batches, in_features), generator=generator) ++ else: ++ input_tensor = input_tensor.to(src_dtype) ++ if wei_dtype == torch.int8: ++ weight_tensor = torch.randint(-127, 127, (out_features, in_features), generator=generator) ++ else: ++ weight_tensor = weight_tensor.to(wei_dtype) ++ if bias_dtype == torch.int32: ++ bias_tensor = torch.randint(-65536, 65536, (out_features,), generator=generator) ++ else: ++ bias_tensor = bias_tensor.to(bias_dtype) ++ ++ head = caseNameList[i]+"_"+str(batches)+"_"+str(out_features)+"_"+str(in_features) ++ input_tensor_id = head+"_input_tensor" ++ weight_tensor_id = head+"_weight_tensor" ++ bias_tensor_id = head+"_bias_tensor" ++ ++ data_input[input_tensor_id] = input_tensor ++ data_input[weight_tensor_id] = weight_tensor ++ data_input[bias_tensor_id] = bias_tensor ++ print("============",i,"============") ++ print("save input:",input_tensor_id) ++ print("save wei:",weight_tensor_id) ++ print("save bias:",bias_tensor_id) ++ ++####################################################################################################### ++ ++for batches in batches_list: ++ for N_K in NK_List: ++ N = N_K[0] ++ K = N_K[1] ++ generate(batches, N, K) ++ ++for shape in specified_shape: ++ batches = shape[0] ++ N = shape[1] ++ K = shape[2] ++ generate(batches, N, K) ++ ++torch.save(data_input, "linear_inputtensors.pt") +\ No newline at end of file +diff --git a/test/test_kdnn_linear/kdnn_cal.py b/test/test_kdnn_linear/kdnn_cal.py +new file mode 100644 +index 0000000000..ee07780394 +--- /dev/null ++++ b/test/test_kdnn_linear/kdnn_cal.py +@@ -0,0 +1,69 @@ ++import sys ++import torch ++import torch.nn as nn ++import pickle ++from datetime import datetime ++from config import * ++ ++data_output = {} ++ ++def test_linear_kdnn(turn, input, wei, bias, batches, out_features, in_features): ++ src_test_dtype = dtypeList[turn][src_Idx] ++ input = input.to(src_test_dtype) ++ # 创建Linear层 ++ linear = nn.Linear(in_features, out_features, bias=True) ++ linear.weight.requires_grad = False ++ linear.bias.requires_grad = False ++ # 手动设置确定参数 ++ with torch.no_grad(): ++ wei_dtype = dtypeList[turn][wei_Idx] ++ linear.weight.data = wei.to(wei_dtype) ++ bias_dtype = dtypeList[turn][dst_Idx] ++ linear.bias.data = bias.to(bias_dtype) ++ ++ out = linear(input) ++ return out ++ ++def cal_kdnn(batches, out_features, in_features): ++ for i in range(len(dtypeList)): ++ print("=========================") ++ print("test", caseNameList[i], "M ", batches, "N ", out_features, "K ", in_features) ++ # get input ++ head = caseNameList[i]+"_"+str(batches)+"_"+str(out_features)+"_"+str(in_features) ++ input_tensor_id = head+"_input_tensor" ++ weight_tensor_id = head+"_weight_tensor" ++ bias_tensor_id = head+"_bias_tensor" ++ ++ input_tensor = loaded_data[input_tensor_id] ++ weight_tensor = loaded_data[weight_tensor_id] ++ bias_tensor = loaded_data[bias_tensor_id] ++ ++ # cal ++ torch._C._set_kdnn_enabled(True) ++ print("KDNN:", torch._C._get_kdnn_enabled()) ++ start_time_kdnn = datetime.now() ++ output_kdnn= test_linear_kdnn(i, input_tensor, weight_tensor, bias_tensor, batches, out_features, in_features) ++ end_time_kdnn = datetime.now() ++ elapsed_time_kdnn = end_time_kdnn - start_time_kdnn ++ print("kdnn elapsed_time:", elapsed_time_kdnn) ++ # save ++ output_kdnn_id = head+"_output_kdnn" ++ data_output[output_kdnn_id] = output_kdnn ++ print("save output:", 
output_kdnn_id) ++ ++####################################################################################################### ++loaded_data = torch.load('linear_inputtensors.pt') ++ ++for batches in batches_list: ++ for N_K in NK_List: ++ N = N_K[0] ++ K = N_K[1] ++ cal_kdnn(batches, N, K) ++ ++for shape in specified_shape: ++ batches = shape[0] ++ N = shape[1] ++ K = shape[2] ++ cal_kdnn(batches, N, K) ++ ++torch.save(data_output, "linear_outputtensors_kdnn.pt") +diff --git a/test/test_kdnn_linear/native_cal.py b/test/test_kdnn_linear/native_cal.py +new file mode 100644 +index 0000000000..9ad18bbd11 +--- /dev/null ++++ b/test/test_kdnn_linear/native_cal.py +@@ -0,0 +1,62 @@ ++import sys ++import torch ++import torch.nn as nn ++import pickle ++from datetime import datetime ++from config import * ++ ++data_output = {} ++ ++def test_linear_native(turn, input, wei, bias, batches, out_features, in_features): ++ # 创建Linear层 ++ linear = nn.Linear(in_features, out_features, bias=True) ++ linear.weight.requires_grad = False ++ linear.bias.requires_grad = False ++ # 手动设置确定参数 ++ with torch.no_grad(): ++ linear.weight.data = wei.to(torch.float32) ++ linear.bias.data = bias.to(torch.float32) ++ out = linear(input.to(torch.float32)) ++ return out ++ ++def cal_native(batches, out_features, in_features): ++ for i in range(len(dtypeList)): ++ print("=========================") ++ print("test", caseNameList[i], "M ", batches, "N ", out_features, "K ", in_features) ++ # get input ++ head = caseNameList[i]+"_"+str(batches)+"_"+str(out_features)+"_"+str(in_features) ++ input_tensor_id = head+"_input_tensor" ++ weight_tensor_id = head+"_weight_tensor" ++ bias_tensor_id = head+"_bias_tensor" ++ ++ input_tensor = loaded_data[input_tensor_id] ++ weight_tensor = loaded_data[weight_tensor_id] ++ bias_tensor = loaded_data[bias_tensor_id] ++ ++ # cal ++ start_time_kdnn = datetime.now() ++ output_native= test_linear_native(i, input_tensor, weight_tensor, bias_tensor, batches, out_features, in_features) ++ end_time_kdnn = datetime.now() ++ elapsed_time_kdnn = end_time_kdnn - start_time_kdnn ++ print("native elapsed_time:", elapsed_time_kdnn) ++ # save ++ output_native_id = head+"_output_native" ++ data_output[output_native_id] = output_native ++ print("save output:", output_native_id) ++ ++####################################################################################################### ++loaded_data = torch.load('linear_inputtensors.pt') ++ ++for batches in batches_list: ++ for N_K in NK_List: ++ N = N_K[0] ++ K = N_K[1] ++ cal_native(batches, N, K) ++ ++for shape in specified_shape: ++ batches = shape[0] ++ N = shape[1] ++ K = shape[2] ++ cal_native(batches, N, K) ++ ++torch.save(data_output, "linear_outputtensors_native.pt") +diff --git a/test/test_kdnn_rms_norm.py b/test/test_kdnn_rms_norm.py +new file mode 100644 +index 0000000000..b25300d2e9 +--- /dev/null ++++ b/test/test_kdnn_rms_norm.py +@@ -0,0 +1,130 @@ ++import torch ++import torch.nn as nn ++import datetime ++ ++class RMSNorm(nn.Module): ++ def __init__(self, normalized_shape, eps=1e-6, weight=None): ++ super(RMSNorm, self).__init__() ++ if isinstance(normalized_shape, int): ++ normalized_shape = (normalized_shape,) ++ self.normalized_shape = normalized_shape ++ self.eps = eps ++ if weight is not None: ++ self.weight = weight ++ else: ++ self.weight = nn.Parameter(torch.ones(normalized_shape)) ++ self.reset_parameters() ++ ++ def reset_parameters(self): ++ if isinstance(self.weight, nn.Parameter): ++ nn.init.ones_(self.weight) ++ ++ def forward(self, 
x):
++        # Make sure the input dtype is fp16 or fp32
++        assert x.dtype in [torch.float16, torch.float32], "Input must be fp16 or fp32"
++
++        # Compute the root mean square
++        dims = tuple(range(-len(self.normalized_shape), 0))
++
++        # If the input is fp16, promote it to fp32 for the computation, then cast back to fp16
++        if x.dtype == torch.float16:
++            x_fp32 = x.float()  # cast to fp32
++            rms = torch.sqrt(torch.mean(x_fp32 ** 2, dim=dims, keepdim=True) + self.eps)
++            normalized_x = x_fp32 / rms  # normalize
++            normalized_x = normalized_x.half()  # cast back to fp16
++        else:
++            rms = torch.sqrt(torch.mean(x ** 2, dim=dims, keepdim=True) + self.eps)
++            normalized_x = x / rms  # normalize
++
++        # Apply the weight
++        normalized_x = normalized_x * self.weight
++
++        return normalized_x
++
++def compareDiff(tensor, ref):
++    mae_loss = nn.L1Loss()
++    mae = mae_loss(tensor, ref)
++    print(f"MAE: {mae.item()}")
++    diff = torch.abs(tensor - ref)
++    me = torch.mean(diff)
++    print(f"Mean absolute error: {me.item()}")
++    mean = torch.mean(tensor - ref)
++    print(f"Mean error: {mean.item()}")
++
++def helper(size, normalized_shape, subweight, epsVal):
++    print("\n===========start test===========")
++    print(f"size: {size}, normalized_shape:{normalized_shape}")
++
++    # Test the reference (naive) fp32 implementation
++    input_tensor = torch.randn(size, dtype = torch.float32)
++    starttime = datetime.datetime.now()
++    rms_ref = RMSNorm(normalized_shape = normalized_shape, eps = epsVal)
++    rms_ref.weight.data.fill_(subweight)
++    output_ref = rms_ref(input_tensor)
++    endtime = datetime.datetime.now()
++    time1 = (endtime - starttime)
++    print(f"ref dtype F32,ref F32 time: {time1.total_seconds():.6f} seconds")
++
++    # Test native PyTorch fp32
++    torch._C._set_kdnn_enabled(False)
++    starttime = datetime.datetime.now()
++    rms_torch = nn.RMSNorm(normalized_shape = normalized_shape, eps = epsVal)
++    rms_torch.weight.data.fill_(subweight)
++    output_torch = rms_torch(input_tensor)
++    endtime = datetime.datetime.now()
++    time2 = (endtime - starttime)
++    print(f"torch dtype F32,pytorch-origin time: {time2.total_seconds():.6f} seconds")
++
++    # Test the reference (naive) fp16 implementation
++    input_tensor_ref_fp16 = input_tensor.half()
++    starttime = datetime.datetime.now()
++    rms_ref_f16 = RMSNorm(normalized_shape = normalized_shape, eps = epsVal)
++    rms_ref_f16.weight.data.fill_(subweight)
++    output_ref_f16 = rms_ref_f16(input_tensor_ref_fp16)
++    endtime = datetime.datetime.now()
++    time3 = (endtime - starttime)
++    print(f"ref dtype F16,ref F16 time: {time3.total_seconds():.6f} seconds")
++
++    # Test native PyTorch fp16
++    starttime = datetime.datetime.now()
++    rms_torch_f16 = nn.RMSNorm(normalized_shape = normalized_shape, eps = epsVal)
++    rms_torch_f16.weight.data.fill_(subweight)
++    output_torch_f16 = rms_torch_f16(input_tensor_ref_fp16)
++    endtime = datetime.datetime.now()
++    time4 = (endtime - starttime)
++    print(f"torch dtype F16,pytorch-origin time: {time4.total_seconds():.6f} seconds")
++
++    # Test KDNN fp16
++    torch._C._set_kdnn_enabled(True)
++    starttime = datetime.datetime.now()
++    rms_kdnn = nn.RMSNorm(normalized_shape = normalized_shape, eps = epsVal)
++    rms_kdnn.weight.data.fill_(subweight)
++    output_kdnn = rms_kdnn(input_tensor_ref_fp16)
++    endtime = datetime.datetime.now()
++    time5 = (endtime - starttime)
++    print(f"kdnn dtype F16,pytorch-kdnn time: {time5.total_seconds():.6f} seconds")
++
++    # Print the results
++    print("===========kdnn f16 vs native f32 ref===========")
++    compareDiff(output_kdnn, output_ref)
++
++    print("===========kdnn f16 vs pytorch f32 ref==========")
++    compareDiff(output_kdnn, output_torch)
++
++    print("===========kdnn f16 vs native f16 ref===========")
++    compareDiff(output_kdnn, output_ref_f16)
++
++    print("===========kdnn f16 vs pytorch f16 
ref===========")
++    compareDiff(output_kdnn, output_torch_f16)
++
++# Example usage
++if __name__ == "__main__":
++    helper((4,8,10,10), (10,10), 1.5, 1e-5)
++    helper((4,8,10,10), (10), 1.1, 1e-6)
++    helper((2,3,4), (4), 1, 1e-5)
++    helper((5,8,25,30), (25,30), 1.6, 1e-5)
++    helper((3,5,10,25,25), (25,25), 1.2, 1e-5)
++    helper((3,5,10,25,25), (10,25,25), 1, 1e-6)
++    helper((2,10,512,512), (512), 1.8, 1e-5)
++    helper((2,256,256,144),(256,144), 1.15235, 1e-5)
++    helper((1,8,256,256,72), (256,72), 1.327689541, 1e-8)
+\ No newline at end of file
+diff --git a/test/test_kdnn_softmax_accuracy.py b/test/test_kdnn_softmax_accuracy.py
+new file mode 100644
+index 0000000000..12cf6b502f
+--- /dev/null
++++ b/test/test_kdnn_softmax_accuracy.py
+@@ -0,0 +1,46 @@
++import torch
++import torch.nn.functional as F
++
++data = torch.randn(18, 16, 3600, 3600)
++print("*" * 10)
++torch._C._set_kdnn_enabled(True)
++prob1 = F.softmax(data, dim = 0)
++print("*" * 10)
++
++torch._C._set_kdnn_enabled(False)
++prob2 = F.softmax(data, dim = 0)
++print("*" * 10)
++
++error = torch.max(torch.abs(prob1 - prob2))
++print(f"dim = 0 The max absolute error of softmax between kdnn and native result is:{error.item()}")
++
++torch._C._set_kdnn_enabled(True)
++prob3 = F.softmax(data, dim = 1)
++
++torch._C._set_kdnn_enabled(False)
++prob4 = F.softmax(data, dim = 1)
++print("*" * 10)
++
++error = torch.max(torch.abs(prob3 - prob4))
++print(f"dim = 1 The max absolute error of softmax between kdnn and native result is:{error.item()}")
++
++torch._C._set_kdnn_enabled(True)
++prob5 = F.softmax(data, dim = 2)
++
++torch._C._set_kdnn_enabled(False)
++prob6 = F.softmax(data, dim = 2)
++print("*" * 10)
++
++error = torch.max(torch.abs(prob5 - prob6))
++print(f"dim = 2 The max absolute error of softmax between kdnn and native result is:{error.item()}")
++
++torch._C._set_kdnn_enabled(True)
++prob7 = F.softmax(data, dim = 3)
++
++torch._C._set_kdnn_enabled(False)
++prob8 = F.softmax(data, dim = 3)
++print("*" * 10)
++
++error = torch.max(torch.abs(prob7 - prob8))
++print(f"dim = 3 The max absolute error of softmax between kdnn and native result is:{error.item()}")
++
+diff --git a/test/test_kdnn_softmax_performance.py b/test/test_kdnn_softmax_performance.py
+new file mode 100644
+index 0000000000..d0e6180913
+--- /dev/null
++++ b/test/test_kdnn_softmax_performance.py
+@@ -0,0 +1,32 @@
++import torch
++import torch.nn.functional as F
++import time
++
++# Enable KDNN
++torch._C._set_kdnn_enabled(True)
++
++# Test KDNN softmax with different tensor shapes
++shapes = [
++    (18, 16, 3600, 3600),
++    (7200, 16, 9, 9),
++    (4, 1, 14400, 512),
++    (4, 14400, 512)
++]
++
++dims = [3, 3, 3, 2]
++
++for shape, dim in zip(shapes, dims):
++    print("*" * 10)
++    print(f"Testing shape: {shape}")
++    data = torch.randn(shape)
++
++    for _ in range(5):
++        k = dim
++        while k > -1:
++            start = time.time()
++            prob = F.softmax(data, dim=k)
++            end = time.time()
++            print(f'dim {k}:')
++            print(f'kdnn run time: {end - start} seconds')
++            k -= 1
++    print("*" * 10)
+diff --git a/torch/backends/kdnn/__init__.py b/torch/backends/kdnn/__init__.py
+new file mode 100644
+index 0000000000..edd14c32cc
+--- /dev/null
++++ b/torch/backends/kdnn/__init__.py
+@@ -0,0 +1,33 @@
++import sys
++import torch
++from contextlib import contextmanager
++from torch.backends import ContextProp, PropModule, __allow_nonbracketed_mutation
++
++def is_available():
++    r"""Returns whether PyTorch is built with KDNN support."""
++    return torch._C._has_kdnn
++
++def set_flags(_enabled):
++    orig_flags = 
(torch._C._get_kdnn_enabled(),) ++ torch._C._set_kdnn_enabled(_enabled) ++ return orig_flags ++ ++@contextmanager ++def flags(enabled=False): ++ with __allow_nonbracketed_mutation(): ++ orig_flags = set_flags(enabled) ++ try: ++ yield ++ finally: ++ with __allow_nonbracketed_mutation(): ++ set_flags(orig_flags[0]) ++ ++class KdnnModule(PropModule): ++ def __init__(self, m, name): ++ super(KdnnModule, self).__init__(m, name) ++ ++ enabled = ContextProp(torch._C._get_kdnn_enabled, torch._C._set_kdnn_enabled) ++ ++# Cool stuff from torch/backends/cudnn/__init__.py and ++# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273 ++sys.modules[__name__] = KdnnModule(sys.modules[__name__], __name__) +\ No newline at end of file +diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp +index 8b0a44c45c..fe95c558a0 100644 +--- a/torch/csrc/Module.cpp ++++ b/torch/csrc/Module.cpp +@@ -833,6 +833,25 @@ PyObject* THPModule_userEnabledMkldnn(PyObject* _unused, PyObject* noargs) { + Py_RETURN_FALSE; + } + ++PyObject* THPModule_setUserEnabledKdnn(PyObject* _unused, PyObject* arg) { ++ HANDLE_TH_ERRORS ++ TORCH_CHECK( ++ PyBool_Check(arg), ++ "set_enabled_kdnn expects a bool, " ++ "but got ", ++ THPUtils_typename(arg)); ++ at::globalContext().setUserEnabledKdnn(arg == Py_True); ++ Py_RETURN_NONE; ++ END_HANDLE_TH_ERRORS ++} ++ ++PyObject* THPModule_userEnabledKdnn(PyObject* _unused, PyObject* noargs) { ++ if (at::globalContext().userEnabledKdnn()) ++ Py_RETURN_TRUE; ++ else ++ Py_RETURN_FALSE; ++} ++ + PyObject* THPModule_setDeterministicCuDNN(PyObject* _unused, PyObject* arg) { + HANDLE_TH_ERRORS + TORCH_CHECK( +@@ -1407,6 +1426,8 @@ static PyMethodDef TorchMethods[] = { // NOLINT + {"_set_cudnn_enabled", THPModule_setUserEnabledCuDNN, METH_O, nullptr}, + {"_get_mkldnn_enabled", THPModule_userEnabledMkldnn, METH_NOARGS, nullptr}, + {"_set_mkldnn_enabled", THPModule_setUserEnabledMkldnn, METH_O, nullptr}, ++ {"_get_kdnn_enabled", THPModule_userEnabledKdnn, METH_NOARGS, nullptr}, ++ {"_set_kdnn_enabled", THPModule_setUserEnabledKdnn, METH_O, nullptr}, + {"_get_cudnn_allow_tf32", THPModule_allowTF32CuDNN, METH_NOARGS, nullptr}, + {"_set_cudnn_allow_tf32", THPModule_setAllowTF32CuDNN, METH_O, nullptr}, + {"_get_cudnn_benchmark", THPModule_benchmarkCuDNN, METH_NOARGS, nullptr}, +@@ -2199,7 +2220,7 @@ Call this whenever a new thread is created in order to propagate values from + ASSERT_TRUE(set_module_attr("_has_xpu", has_xpu)); + ASSERT_TRUE( + set_module_attr("_has_mkldnn", at::hasMKLDNN() ? Py_True : Py_False)); +- ++ ASSERT_TRUE(set_module_attr("_has_kdnn", at::hasKDNN() ? Py_True : Py_False)); + #ifdef _GLIBCXX_USE_CXX11_ABI + ASSERT_TRUE(set_module_attr( + "_GLIBCXX_USE_CXX11_ABI", _GLIBCXX_USE_CXX11_ABI ? Py_True : Py_False)); \ No newline at end of file