From 6af9d8572ca1be9d7c50c44e69224448c7b03a86 Mon Sep 17 00:00:00 2001
From: lyl <452929843@qq.com>
Date: Wed, 30 Jul 2025 14:42:44 +0800
Subject: [PATCH] kpex code migration

---
 README.en.md | 4 +
 README.md | 50 +++++++++
 csrc/kpex.cpp | 11 ++
 csrc/tpp/alphafold/bind.h | 21 ++++
 csrc/tpp/alphafold/gating_attention.cpp | 109 +++++++++++++++++++
 csrc/tpp/alphafold/gating_attention.h | 45 ++++++++
 csrc/utils/bf16.h | 20 ++++
 csrc/utils/check.h | 50 +++++++++
 csrc/utils/memory.cpp | 33 ++++++
 csrc/utils/memory.h | 37 +++++++
 kpex/__init__.py | 1 +
 kpex/frontend.py | 0
 kpex/tpp/__init__.py | 1 +
 kpex/tpp/alphafold/__init__.py | 1 +
 kpex/tpp/alphafold/alphafold.py | 139 ++++++++++++++++++++++++
 setup.py | 57 ++++++++++
 16 files changed, 579 insertions(+)
 create mode 100644 README.en.md
 create mode 100644 README.md
 create mode 100644 csrc/kpex.cpp
 create mode 100644 csrc/tpp/alphafold/bind.h
 create mode 100644 csrc/tpp/alphafold/gating_attention.cpp
 create mode 100644 csrc/tpp/alphafold/gating_attention.h
 create mode 100644 csrc/utils/bf16.h
 create mode 100644 csrc/utils/check.h
 create mode 100644 csrc/utils/memory.cpp
 create mode 100644 csrc/utils/memory.h
 create mode 100644 kpex/__init__.py
 create mode 100644 kpex/frontend.py
 create mode 100644 kpex/tpp/__init__.py
 create mode 100644 kpex/tpp/alphafold/__init__.py
 create mode 100644 kpex/tpp/alphafold/alphafold.py
 create mode 100644 setup.py

diff --git a/README.en.md b/README.en.md
new file mode 100644
index 0000000..8570e49
--- /dev/null
+++ b/README.en.md
@@ -0,0 +1,4 @@
+# kunpeng-extension-for-pytorch
+
+#### Description
+kunpeng-extension-for-pytorch is a package that extends the official PyTorch so that models can easily obtain better performance on the Kunpeng platform.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..6281289
--- /dev/null
+++ b/README.md
@@ -0,0 +1,50 @@
+# kunpeng-extension-for-pytorch
+
+#### Introduction
+kunpeng-extension-for-pytorch is a package that extends the official PyTorch so that models can easily obtain better performance on the Kunpeng platform.
+
+#### Installation
+
+**Step 1: Set up the environment**
+
+1. Install PyTorch.
+
+2. Load the kutacc environment:
+
+```
+module use /xxx/xxx/xxx/xxx/xxx/modulefiles
+module load xxx/xxx/kutacc
+```
+
+3. Install the dependencies:
+
+```
+pip install ninja==1.11.1.1 pybind11==2.11.1
+```
+**Step 2: Run the installation command**
+```
+CFLAGS="-stdlib=libc++ -lc++abi" KPEX_BUILD_TYPE=release KUTACC_ROOT=/xxx/xxx/kutacc pip install --editable .
+```
+#### Uninstallation
+```
+pip uninstall kunpeng-pytorch-extension -y
+```
+#### API reference
+API name:
+kpex_alphafold
+
+API description:
+Optimizes an AlphaFold2 model by replacing operators inside the model.
+
+API parameters:
+|Parameter|Type|Description|Input/Output|
+|-----|----|-----|---------|
+|model|model class|the AlphaFold2 model|input|
+|model_config|config class|the config of the AlphaFold2 model|input|
+|new_model|model class|the optimized AlphaFold2 model|output|
+
+#### Usage
+
+```
+model = kpex.tpp.alphafold.alphafold.kpex_alphafold(model, model_config)
+```
\ No newline at end of file
diff --git a/csrc/kpex.cpp b/csrc/kpex.cpp
new file mode 100644
index 0000000..e6d96b6
--- /dev/null
+++ b/csrc/kpex.cpp
@@ -0,0 +1,11 @@
+// Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved
+
+#include "tpp/alphafold/bind.h"
+#include "utils/memory.h"
+#include <torch/extension.h>
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+    alphafold::bind(m);
+    m.def("device", &kpex::device);
+}
\ No newline at end of file
diff --git a/csrc/tpp/alphafold/bind.h b/csrc/tpp/alphafold/bind.h
new file mode 100644
index 0000000..fb3a5cb
--- /dev/null
+++ b/csrc/tpp/alphafold/bind.h
@@ -0,0 +1,21 @@
+// Copyright (c) Huawei Technologies Co., Ltd. 2025.
All rights reserved + +#pragma once + +#include + +#include "gating_attention.h" + +namespace alphafold { +inline void bind(pybind11::module &m) +{ + auto submodule = m.def_submodule("alphafold"); + py::class_(submodule, "GatingAttentionWeight") + .def(py::init(), + py::arg("query_w"), py::arg("key_w"), py::arg("value_w"), py::arg("gate_w"), py::arg("gate_b"), + py::arg("output_w"), py::arg("output_b")); + submodule.def("gating_attention", &gating_attention, py::arg("q_data"), py::arg("m_data"), py::arg("bias"), + py::arg("nonbatched_bias"), py::arg("weights"), py::arg("block_size") = std::nullopt); +} +} \ No newline at end of file diff --git a/csrc/tpp/alphafold/gating_attention.cpp b/csrc/tpp/alphafold/gating_attention.cpp new file mode 100644 index 0000000..58e4459 --- /dev/null +++ b/csrc/tpp/alphafold/gating_attention.cpp @@ -0,0 +1,109 @@ +// Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved + +#include "gating_attention.h" +#include "utils/bf16.h" +#include +#include +#include +#include +#include + +namespace alphafold { +namespace { +int64_t default_block_size(int64_t seq_len) +{ + if (seq_len < 300) { + return 176; + } else if (seq_len < 600) { + return 128; + } else if (seq_len < 800) { + return 80; + } else if (seq_len < 1300) { + return 64; + } else if (seq_len < 1700) { + return 48; + } else { + return 32; + } +} +} + +GatingAttentionWeight::GatingAttentionWeight(at::Tensor &query_w, at::Tensor &key_w, at::Tensor &value_w, + at::Tensor &gating_w, at::Tensor &gating_b, at::Tensor &output_w, at::Tensor &output_b) + :nchannels(query_w.sizes()[2]), nheads(query_w.sizes()[0]), head_size(query_w.sizes()[1]) +{ + KPEX_CHECK(nchannels == nheads * head_size, "invalid query_w shape [", nchannels, ", ", nheads, ", ", head_size, + "]"); + KPEX_CHECK_TENSOR_SHAPE(query_w, nheads, head_size, nchannels); + KPEX_CHECK_TENSOR_SHAPE(key_w, nheads, head_size, nchannels); + KPEX_CHECK_TENSOR_SHAPE(value_w, nheads, head_size, nchannels); + KPEX_CHECK_TENSOR_SHAPE(gating_w, nheads, head_size,nchannels); + KPEX_CHECK_TENSOR_SHAPE(gating_b, nheads, head_size); + KPEX_CHECK_TENSOR_SHAPE(output_w, nchannels, nheads, head_size); + KPEX_CHECK_TENSOR_SHAPE(output_b, nchannels); + + auto float_opt = query_w.options().device(kpex::device()).dtype(c10::kFloat); + auto bf16_opt = query_w.options().device(kpex::device()).dtype(c10::kBFloat16); + query_w = query_w.to(bf16_opt).contiguous().view({nchannels, nchannels}); + key_w = key_w.to(bf16_opt).contiguous().view({nchannels, nchannels}); + value_w = value_w.to(bf16_opt).contiguous().view({nchannels, nchannels}); + gating_w = gating_w.to(bf16_opt).contiguous().view({nchannels, nchannels}); + output_w = output_w.to(bf16_opt).contiguous().view({nchannels, nchannels}); + + this->query_w = query_w; + this->key_w = key_w; + this->value_w = value_w; + this->gating_w = gating_w; + this->output_w = output_w; + + this->gating_b = gating_b.to(float_opt).contiguous(); + this->output_b = output_b.to(float_opt).contiguous(); +} + +at::Tensor gating_attention(at::Tensor &q_data, at::Tensor &m_data, at::Tensor &bias, at::Tensor &nonbatched_bias, + const GatingAttentionWeight &weights, std::optional block_size) +{ + at::Tensor out = at::empty(q_data.sizes(), q_data.options()); + int64_t batch = q_data.sizes()[0]; + int64_t seq_len = q_data.sizes()[1]; + int64_t nchannels = weights.nchannels; + int64_t nheads = weights.nheads; + int64_t head_size = weights.head_size; + int64_t block_size_ = default_block_size(seq_len); + block_size_ = 
block_size.value_or(block_size_); + + RECORD_FUNCTION("gating_attention", c10::ArrayRef({batch, seq_len, nheads, head_size})); + KPEX_CHECK(q_data.dtype() == c10::kBFloat16, q_data.dtype()); + KPEX_CHECK(m_data.dtype() == c10::kBFloat16, m_data.dtype()); + KPEX_CHECK(bias.dtype() == c10::kBFloat16, bias.dtype()); + KPEX_CHECK(nonbatched_bias.dtype() == c10::kBFloat16, bias.dtype()); + KPEX_CHECK_TENSOR_SHAPE(q_data, batch, seq_len, nchannels); + KPEX_CHECK_TENSOR_SHAPE(bias, batch, 1, 1, seq_len); + if (nonbatched_bias.sizes()[0] != 0) { + KPEX_CHECK_TENSOR_SHAPE(nonbatched_bias, nheads, seq_len, seq_len); + } + + bias = bias.contiguous(); + nonbatched_bias = nonbatched_bias.contiguous(); + + auto q = q_data.new_empty({batch, seq_len, nheads, head_size}); + auto k = q_data.new_empty({batch, seq_len, nheads, head_size}); + auto v = q_data.new_empty({nheads, head_size, batch, seq_len}); + auto gate = q_data.new_empty({batch, seq_len, nheads, head_size}); + auto weighted_avg = q_data.new_empty({batch, seq_len, nheads, head_size}); + at::Tensor input; + { + RECORD_FUNCTION("input_prepack", c10::ArrayRef({})); + input = q_data.view({batch * seq_len, nchannels}); + } + kutacc::gating_attention( + batch, seq_len, nchannels, nheads, head_size, block_size_, + bias.data_ptr(), bias.strides().vec(), nonbatched_bias.data_ptr(), nonbatched_bias.sizes().vec(), nonbatched_bias.strides().vec(), + input.data_ptr(), out.data_ptr(), out.strides().vec(), gate.data_ptr(), k.data_ptr(), v.data_ptr(), q.data_ptr(), q.strides().vec(), + gate.strides().vec(), v.strides().vec(), k.strides().vec(), weights.value_w.data_ptr(), + weighted_avg.data_ptr(), weighted_avg.strides().vec(), weights.query_w.data_ptr(), weights.key_w.data_ptr(), + weights.gating_w.data_ptr(), weights.gating_b.data_ptr(),weights.output_w.data_ptr(), weights.output_b.data_ptr() + ); + return out; +} +} // namespace alphafold \ No newline at end of file diff --git a/csrc/tpp/alphafold/gating_attention.h b/csrc/tpp/alphafold/gating_attention.h new file mode 100644 index 0000000..48e85a3 --- /dev/null +++ b/csrc/tpp/alphafold/gating_attention.h @@ -0,0 +1,45 @@ +// Copyright (c) Huawei Technologies Co., Ltd. 2025. 
All rights reserved
+
+#pragma once
+
+#include "utils/check.h"
+#include <ATen/ATen.h>
+
+namespace alphafold {
+struct GatingAttentionWeight {
+    int64_t nchannels;
+    int64_t nheads;
+    int64_t head_size;
+
+    at::Tensor query_w; //dtype=bf16
+    at::Tensor key_w; //dtype=bf16
+    at::Tensor value_w; //dtype=bf16
+    at::Tensor gating_w; //dtype=bf16
+    at::Tensor gating_b; //dtype=fp32
+    at::Tensor output_w; //dtype=bf16
+    at::Tensor output_b; //dtype=fp32
+
+    /**
+     * @param query_w shape [nheads, head_size, nchannels], dtype=any
+     * @param key_w shape [nheads, head_size, nchannels], dtype=any
+     * @param value_w shape [nheads, head_size, nchannels], dtype=any
+     * @param gating_w shape [nheads, head_size, nchannels], dtype=any
+     * @param gating_b shape [nheads, head_size], dtype=any
+     * @param output_w shape [nchannels, nheads, head_size], dtype=any
+     * @param output_b shape [nchannels], dtype=any
+     */
+    GatingAttentionWeight(at::Tensor &query_w, at::Tensor &key_w, at::Tensor &value_w, at::Tensor &gating_w,
+        at::Tensor &gating_b, at::Tensor &output_w, at::Tensor &output_b);
+};
+
+/**
+ * @param q_data shape [batch, seq_len, nchannels], bf16
+ * @param m_data shape [batch, seq_len, nchannels], bf16
+ * @param bias shape [batch, 1, 1, seq_len], bf16
+ * @param nonbatched_bias shape [nheads, seq_len, seq_len] or [0], bf16
+ */
+at::Tensor gating_attention(at::Tensor &q_data, at::Tensor &m_data, at::Tensor &bias, at::Tensor &nonbatched_bias,
+    const GatingAttentionWeight &weights, std::optional<int64_t> block_size);
+} // namespace alphafold
+
+
diff --git a/csrc/utils/bf16.h b/csrc/utils/bf16.h
new file mode 100644
index 0000000..e25e557
--- /dev/null
+++ b/csrc/utils/bf16.h
@@ -0,0 +1,20 @@
+// Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved
+
+// Temporary file: the current bisheng compiler version has poor bf16 support, so NEON intrinsics are used instead.
+
+#pragma once
+
+#include <arm_bf16.h>
+#include <arm_neon.h>
+
+namespace kpex {
+static inline __bf16 to_bf16(float x)
+{
+    return vcvth_bf16_f32(x);
+}
+
+static inline float to_float(__bf16 x)
+{
+    return vcvtah_f32_bf16(x);
+}
+} // namespace kpex
\ No newline at end of file
diff --git a/csrc/utils/check.h b/csrc/utils/check.h
new file mode 100644
index 0000000..5d68539
--- /dev/null
+++ b/csrc/utils/check.h
@@ -0,0 +1,50 @@
+// Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved
+
+#pragma once
+
+#include <c10/util/ArrayRef.h>
+#include <cstdlib>
+#include <iostream>
+#include <sstream>
+#include <utility>
+
+namespace kpex {
+namespace internal {
+inline void check_fail_print(std::stringstream &stream)
+{}
+
+template <typename Arg, typename... Rest>
+inline void check_fail_print(std::stringstream &stream, Arg &&arg, Rest &&...rest)
+{
+    stream << std::forward<Arg>(arg);
+    check_fail_print(stream, rest...);
+}
+
+template <typename... Args>
+inline void check_fail(std::string func, std::string file, int line, Args &&...args)
+{
+    std::stringstream stream;
+    stream << "KPEX_CHECK fail in " << func << " at " << file << ":" << line << ", ";
+    check_fail_print(stream, std::forward<Args>(args)...);
+    stream << "\n";
+    std::cerr << stream.str();
+    abort();
+}
+} //namespace internal
+} //namespace kpex
+
+#define KPEX_CHECK(condition, ...) \
+    do { \
+        if (__builtin_expect(!(condition), 0)) { \
+            kpex::internal::check_fail(__func__, __FILE__, __LINE__, __VA_ARGS__); \
+        } \
+    } while (0)
+
+#define KPEX_CHECK_TENSOR_SHAPE(tensor, ...) \
+    KPEX_CHECK((tensor).sizes() == c10::IntArrayRef({__VA_ARGS__}), "invalid tensor shape: ", (tensor).sizes(), \
+        ", expect: ", c10::IntArrayRef({__VA_ARGS__}))
+
+#define KPEX_CHECK_TENSORWRAPPER_SHAPE(tensor, ...)
\ + KPEX_CHECK(c10::IntArrayRef((tensor).sizes) == c10::IntArrayRef({__VA_ARGS__}), \ + "invalid tensor wrapper shape: ", c10::IntArrayRef((tensor).sizes), \ + ", expect: ", c10::IntArrayRef({__VA_ARGS__})) \ No newline at end of file diff --git a/csrc/utils/memory.cpp b/csrc/utils/memory.cpp new file mode 100644 index 0000000..a3c7d96 --- /dev/null +++ b/csrc/utils/memory.cpp @@ -0,0 +1,33 @@ +#include +#include +#include +#include +#ifdef USE_HBM +#include +#endif + +bool kpex_use_hbm() { +#ifdef USE_HBM + return true; +#else + return false; +#endif +} + +int kpex_posix_memalign(void **memptr, size_t alignment, size_t size) +{ +#ifdef USE_HBM + return memkind_posix_memalign(MEMKIND_HBW_HUGETLB, memptr, alignment, size); +#else + return posix_memalign(memptr, alignment, size); +#endif +} + +void kpex_free(void *ptr) +{ +#ifdef USE_HBM + memkind_free(MEMKIND_HBW_HUGETLB, ptr); +#else + return free(ptr); +#endif +} diff --git a/csrc/utils/memory.h b/csrc/utils/memory.h new file mode 100644 index 0000000..46aa327 --- /dev/null +++ b/csrc/utils/memory.h @@ -0,0 +1,37 @@ +// Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved + +#pragma once + +#include +#include +#include +#include +#include + +bool kpex_use_hbm(); + +int kpex_posix_memalign(void **memptr, size_t alignment, size_t size); + +void kpex_free(void *ptr); + +namespace kpex { +inline c10::Device device() { + return c10::Device(c10::kCPU, (int)kpex_use_hbm()); +} + +template +struct KpexMallocDeleter { + void operator()(T *ptr) const + { + kpex_free(ptr); + } +}; + +template +inline std::unique_ptr > alloc(int64_t size) +{ + void *ptr; + kpex_posix_memalign(&ptr, 64, size * sizeof(T)); + return std::unique_ptr >((T *)ptr); +} +} // namespace kpex \ No newline at end of file diff --git a/kpex/__init__.py b/kpex/__init__.py new file mode 100644 index 0000000..56ffa47 --- /dev/null +++ b/kpex/__init__.py @@ -0,0 +1 @@ +from . import tpp \ No newline at end of file diff --git a/kpex/frontend.py b/kpex/frontend.py new file mode 100644 index 0000000..e69de29 diff --git a/kpex/tpp/__init__.py b/kpex/tpp/__init__.py new file mode 100644 index 0000000..54122cd --- /dev/null +++ b/kpex/tpp/__init__.py @@ -0,0 +1 @@ +from . import alphafold \ No newline at end of file diff --git a/kpex/tpp/alphafold/__init__.py b/kpex/tpp/alphafold/__init__.py new file mode 100644 index 0000000..54122cd --- /dev/null +++ b/kpex/tpp/alphafold/__init__.py @@ -0,0 +1 @@ +from . import alphafold \ No newline at end of file diff --git a/kpex/tpp/alphafold/alphafold.py b/kpex/tpp/alphafold/alphafold.py new file mode 100644 index 0000000..7896026 --- /dev/null +++ b/kpex/tpp/alphafold/alphafold.py @@ -0,0 +1,139 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. 
All rights reserved + +import copy +import time +import types + +import torch +from torch import nn +import torch.distributed as dist +import kpex._C as kernel + + +class Fast_GatingAttention(nn.Module): + def __init__(self, config, global_config, a_dim, m_dim, output_dim): + super().__init__() + self.config = config + self.global_config = global_config + self.output_dim = output_dim + # k,v dim + self.key_dim = self.config.get('key_dim', int(a_dim)) + self.value_dim = self.config.get('value_dim', int(m_dim)) + self.num_head = self.config['num_head'] + assert self.key_dim % self.num_head == 0 + assert self.value_dim % self.num_head == 0 + self.key_dim = self.key_dim // self.num_head + self.value_dim = self.value_dim // self.num_head + # q,k,v weights + self.query_w = nn.Parameter(torch.Tensor(a_dim,self.num_head,self.key_dim),requires_grad=False) + self.key_w = nn.Parameter(torch.Tensor(m_dim,self.num_head,self.key_dim),requires_grad=False) + self.value_w = nn.Parameter(torch.Tensor(m_dim,self.num_head,self.value_dim),requires_grad=False) + self.gating_w = nn.Parameter(torch.Tensor(a_dim,self.num_head,self.value_dim),requires_grad=False) + self.gating_b = nn.Parameter(torch.Tensor(self.num_head,self.value_dim),requires_grad=False) + self.output_w = nn.Parameter(torch.Tensor(self.num_head,self.value_dim, self.output_dim),requires_grad=False) + self.output_b = nn.Parameter(torch.Tensor(self.output_dim),requires_grad=False) + # softmax & act fn + self.softmax = nn.Softmax(dim=-1) + self.sigmoid = nn.Sigmoid() + + @torch.jit.ignore + def read_time(self) -> float: + return time.time() + + def forward(self, q_data, m_data, bias, nonbatched_bias=torch.Tensor(), block_size=None): + if not hasattr(self, "kpex_weights"): + self.kpex_weights = kernel.alphafold.GatingAttentionWeight( + self.query_w.permute(1, 2, 0), + self.key_w.permute(1, 2, 0), + self.value_w.permute(1, 2, 0), + self.gating_w.permute(1, 2, 0), + self.gating_b, + self.output_w.permute(2, 0, 1), + self.output_b, + ) + act = q_data.to(torch.bfloat16) + out = kernel.alphafold.gating_attention( + act, + act, + bias.to(torch.bfloat16), + nonbatched_bias.to(torch.bfloat16), + self.kpex_weights, + block_size, + ) + return out + + +def gating_attention_forward(self, q_data, m_data, bias, nonbatched_bias=torch.Tensor(), block_size=None): + if not hasattr(self, "kpex_weights"): + self.kpex_weights = kernel.alphafold.GatingAttentionWeight( + self.query_w.permute(1, 2, 0), + self.key_w.permute(1, 2, 0), + self.value_w.permute(1, 2, 0), + self.gating_w.permute(1, 2, 0), + self.gating_b, + self.output_w.permute(2, 0, 1), + self.output_b, + ) + act = q_data.to(torch.bfloat16) + out = kernel.alphafold.gating_attention( + act, + act, + bias.to(torch.bfloat16), + nonbatched_bias.to(torch.bfloat16), + self.kpex_weights, + block_size, + ) + return out + + + +def kpex_alphafold(model, model_config, dtype=torch.float): + new_model = copy.deepcopy(model) + evoformer = new_model.model.impl.evoformer + + if hasattr(evoformer, "extra_msa_stack"): + for block in evoformer.extra_msa_stack: + block.msa_row_attention_with_pair_bias.attention.forward = types.MethodType( + gating_attention_forward, + block.msa_row_attention_with_pair_bias.attention + ) + block.triangle_attention_starting_node.attention.forward = types.MethodType( + gating_attention_forward, + block.triangle_attention_starting_node.attention + ) + block.triangle_attention_ending_node.attention.forward = types.MethodType( + gating_attention_forward, + block.triangle_attention_ending_node.attention + 
)
+    if hasattr(evoformer, "evoformer_iteration"):
+        for block in evoformer.evoformer_iteration:
+            block.msa_row_attention_with_pair_bias.attention.forward = types.MethodType(
+                gating_attention_forward,
+                block.msa_row_attention_with_pair_bias.attention
+            )
+            block.msa_column_attention.attention.forward = types.MethodType(
+                gating_attention_forward,
+                block.msa_column_attention.attention
+            )
+            block.triangle_attention_starting_node.attention.forward = types.MethodType(
+                gating_attention_forward,
+                block.triangle_attention_starting_node.attention
+            )
+            block.triangle_attention_ending_node.attention.forward = types.MethodType(
+                gating_attention_forward,
+                block.triangle_attention_ending_node.attention
+            )
+    if hasattr(evoformer, "template_embedding"):
+        template_pair_sub_stack = evoformer.template_embedding.single_template_embedding.template_pair_stack.template_pair_sub_stack
+        for block in template_pair_sub_stack:
+            block.triangle_attention_starting_node.attention.forward = types.MethodType(
+                gating_attention_forward,
+                block.triangle_attention_starting_node.attention
+            )
+            block.triangle_attention_ending_node.attention.forward = types.MethodType(
+                gating_attention_forward,
+                block.triangle_attention_ending_node.attention
+            )
+    return new_model
+
+
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..5b21ca5
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,57 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved
+
+from setuptools import setup, find_packages
+from torch.utils import cpp_extension
+import glob
+from pathlib import Path
+import os
+
+root = Path(__file__).parent
+
+sources = [
+    "csrc/*.cpp",
+    "csrc/utils/*.cpp",
+    "csrc/aten/*.cpp",
+    "csrc/comm/*.cpp",
+    "csrc/comm/local/*.cpp",
+    "csrc/tpp/alphafold/*.cpp",
+]
+
+sources = [j for i in sources for j in glob.glob(i)]
+extra_compile_args = []
+include_dirs = [(root / "csrc").as_posix()]
+
+extra_compile_args += ["-fopenmp"]
+if os.environ.get("KPEX_BUILD_TYPE") == "release":
+    extra_compile_args += ["-O3"]
+elif os.environ.get("KPEX_BUILD_TYPE") == "debug":
+    extra_compile_args += ["-O0", "-g"]
+else:
+    print("requires env KPEX_BUILD_TYPE (release / debug)")
+
+library_dirs = []
+libraries = []
+KUTACC_ROOT = os.environ.get("KUTACC_ROOT", None)
+if KUTACC_ROOT:
+    include_dirs += [f"{KUTACC_ROOT}/include"]
+    library_dirs = [f"{KUTACC_ROOT}/lib"]
+    libraries = ["kutacc"]
+    extra_compile_args += [f"-Wl,-rpath={KUTACC_ROOT}/lib"]
+
+setup(
+    name="kunpeng-pytorch-extension",
+    version="0.0.1",
+    author="KPEX",
+    packages=find_packages(),
+    ext_modules=[
+        cpp_extension.CppExtension(
+            name="kpex._C",
+            sources=sources,
+            include_dirs=include_dirs,
+            libraries=libraries,
+            library_dirs=library_dirs,
+            extra_compile_args={"cxx": extra_compile_args},
+        )
+    ],
+    cmdclass={"build_ext": cpp_extension.BuildExtension}
+)
--
Gitee