登录
注册
开源
企业版
高校版
搜索
帮助中心
使用条款
关于我们
开源
企业版
高校版
私有云
模力方舟
登录
注册
代码拉取完成,页面将自动刷新
开源项目
>
人工智能
>
机器学习/深度学习
&&
捐赠
捐赠前请先登录
取消
前往登录
扫描微信二维码支付
取消
支付完成
支付提示
将跳转至支付宝完成支付
确定
取消
Watch
不关注
关注所有动态
仅关注版本发行动态
关注但不提醒动态
69
Star
298
Fork
184
MindSpore
/
akg
代码
Issues
17
Pull Requests
37
Wiki
统计
流水线
服务
质量分析
Jenkins for Gitee
腾讯云托管
腾讯云 Serverless
悬镜安全
阿里云 SAE
Codeblitz
SBOM
我知道了,不再自动展开
更新失败,请稍后重试!
移除标识
内容风险标识
本任务被
标识为内容中包含有代码安全 Bug 、隐私泄露等敏感信息,仓库外成员不可访问
[AIKG] 高性能算子知识库共建
TODO
#ID248M
Task-Tracking
Yanzhi_YI
成员
创建于
2025-10-15 23:40
| name | about | labels | | ---- | ----------------------------------- | --------- | | Task | 高性能算子代码生成优化知识共建 | kind/task | <!-- Thanks for sending an issue! Here are some tips for you: 1) If this is your first time, please read our contributor guidelines: https://gitee.com/mindspore/mindspore/blob/master/CONTRIBUTING.md 2) If you want to get the answer quickly, please add label `mindspore-assistant` to the issue, we will find it and answer you as soon as possible. --> ## 问题描述 AIKG(AI-driven Kernel Generator)是一个基于大模型驱动的 AI 算子生成工具,虽然能够生成**功能正确**的 CUDA C 和 C++ 算子代码,但生成代码的性能往往不够理想。 <br/> <br/> --- ## 典型案例 ### 案例 1:CUDA C - Reduce 算子 **❌ 当前可能生成的简单实现**(几乎没有并行) ```cuda __global__ void simple_reduce(float* input, float* output, int n) { int idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx == 0) { float sum = 0; for (int i = 0; i < n; i++) { sum += input[i]; // 串行累加,性能极差 } output[0] = sum; } } ``` **✅ 高性能实现的其中一种方案示例** ```cuda __global__ void optimized_reduce(float* input, float* output, int n) { __shared__ float shared_data[256]; // 1️⃣ Thread-level reduce: 每个线程处理多个元素 int tid = threadIdx.x; int idx = blockIdx.x * blockDim.x * 4 + threadIdx.x; float thread_sum = 0; #pragma unroll for (int i = 0; i < 4; i++) { if (idx + i * blockDim.x < n) { thread_sum += input[idx + i * blockDim.x]; } } // 2️⃣ Warp-level reduce: 利用 warp shuffle 指令 #pragma unroll for (int offset = 16; offset > 0; offset /= 2) { thread_sum += __shfl_down_sync(0xffffffff, thread_sum, offset); } // 3️⃣ Block-level reduce: 通过 shared memory if (tid % 32 == 0) { shared_data[tid / 32] = thread_sum; } __syncthreads(); // 4️⃣ Final reduce and atomic write if (tid < 8) { thread_sum = (tid < blockDim.x / 32) ? shared_data[tid] : 0; for (int offset = 4; offset > 0; offset /= 2) { thread_sum += __shfl_down_sync(0xff, thread_sum, offset); } if (tid == 0) { atomicAdd(output, thread_sum); } } } ``` **性能差异**: 优化版本可能快 **50-100x** ### 核心问题 **简单的文档提示不足以让 LLM 主动生成复杂的性能优化代码** LLM 倾向于生成最直接的朴素实现,缺少: - 多级优化(如 reduce 的 thread + warp + block 多级归约) - 内存访问优化(shared memory、memory coalescing) - 并行优化(向量化、warp primitives) --- ### 案例 2:C++ - 向量化操作 **❌ 当前可能生成的简单实现** ```cpp torch::Tensor relu_forward(torch::Tensor x) { auto output = torch::empty_like(x); auto x_ptr = x.data_ptr<float>(); auto out_ptr = output.data_ptr<float>(); int64_t n = x.numel(); // 朴素标量循环 for (int64_t i = 0; i < n; i++) { out_ptr[i] = x_ptr[i] > 0 ? x_ptr[i] : 0; } return output; } ``` **✅ 高性能实现的其中一种方案示例** ```cpp #include <arm_neon.h> // ARM NEON 指令集 torch::Tensor relu_forward_optimized(torch::Tensor x) { auto output = torch::empty_like(x); auto x_ptr = x.data_ptr<float>(); auto out_ptr = output.data_ptr<float>(); int64_t n = x.numel(); // 1️⃣ 向量化主循环: 每次处理 4 个 float (NEON) int64_t vec_size = 4; int64_t vec_end = (n / vec_size) * vec_size; float32x4_t zero_vec = vdupq_n_f32(0.0f); for (int64_t i = 0; i < vec_end; i += vec_size) { float32x4_t x_vec = vld1q_f32(&x_ptr[i]); uint32x4_t mask = vcgtq_f32(x_vec, zero_vec); float32x4_t result = vreinterpretq_f32_u32(vandq_u32( vreinterpretq_u32_f32(x_vec), mask)); vst1q_f32(&out_ptr[i], result); } // 2️⃣ 处理剩余元素 for (int64_t i = vec_end; i < n; i++) { out_ptr[i] = x_ptr[i] > 0 ? x_ptr[i] : 0; } return output; } ``` **性能差异**: 向量化版本可能快 **4-8x**(取决于 CPU 架构) --- ## 共建高性能算子知识库 为了提升 AIKG 对各类 DSL 的代码生成性能,我们诚挚欢迎 **SIG 成员、学生、算子专家、DSL 开发者等社区开发者**通过提供高质量的文档和示例,帮助 AIKG 生成更高性能的算子代码! ### 参与方式 1. **贡献高性能示例代码** - 为各类 DSL(CUDA C、C++、Triton、TileLang 等)提供高性能实现示例 - 包含对应的优化思路注释和性能数据 2. **完善性能优化文档** - 补充 DSL 的性能优化模式文档(如多级 reduce、tiling、向量化等) - 添加算法选择建议和最佳实践(如卷积算法选择、内存访问优化等) 3. **优化参考文档** - 改进 DSL 基础文档、API 说明、专家建议文档 - 使文档更易于 LLM 理解和学习高性能实现模式 4. **改进 AIKG 生成流程** - 提出更好的代码生成策略和 Prompt 设计 - 优化文档驱动接入(DDI)机制 - 分享性能分析和优化经验
| name | about | labels | | ---- | ----------------------------------- | --------- | | Task | 高性能算子代码生成优化知识共建 | kind/task | <!-- Thanks for sending an issue! Here are some tips for you: 1) If this is your first time, please read our contributor guidelines: https://gitee.com/mindspore/mindspore/blob/master/CONTRIBUTING.md 2) If you want to get the answer quickly, please add label `mindspore-assistant` to the issue, we will find it and answer you as soon as possible. --> ## 问题描述 AIKG(AI-driven Kernel Generator)是一个基于大模型驱动的 AI 算子生成工具,虽然能够生成**功能正确**的 CUDA C 和 C++ 算子代码,但生成代码的性能往往不够理想。 <br/> <br/> --- ## 典型案例 ### 案例 1:CUDA C - Reduce 算子 **❌ 当前可能生成的简单实现**(几乎没有并行) ```cuda __global__ void simple_reduce(float* input, float* output, int n) { int idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx == 0) { float sum = 0; for (int i = 0; i < n; i++) { sum += input[i]; // 串行累加,性能极差 } output[0] = sum; } } ``` **✅ 高性能实现的其中一种方案示例** ```cuda __global__ void optimized_reduce(float* input, float* output, int n) { __shared__ float shared_data[256]; // 1️⃣ Thread-level reduce: 每个线程处理多个元素 int tid = threadIdx.x; int idx = blockIdx.x * blockDim.x * 4 + threadIdx.x; float thread_sum = 0; #pragma unroll for (int i = 0; i < 4; i++) { if (idx + i * blockDim.x < n) { thread_sum += input[idx + i * blockDim.x]; } } // 2️⃣ Warp-level reduce: 利用 warp shuffle 指令 #pragma unroll for (int offset = 16; offset > 0; offset /= 2) { thread_sum += __shfl_down_sync(0xffffffff, thread_sum, offset); } // 3️⃣ Block-level reduce: 通过 shared memory if (tid % 32 == 0) { shared_data[tid / 32] = thread_sum; } __syncthreads(); // 4️⃣ Final reduce and atomic write if (tid < 8) { thread_sum = (tid < blockDim.x / 32) ? shared_data[tid] : 0; for (int offset = 4; offset > 0; offset /= 2) { thread_sum += __shfl_down_sync(0xff, thread_sum, offset); } if (tid == 0) { atomicAdd(output, thread_sum); } } } ``` **性能差异**: 优化版本可能快 **50-100x** ### 核心问题 **简单的文档提示不足以让 LLM 主动生成复杂的性能优化代码** LLM 倾向于生成最直接的朴素实现,缺少: - 多级优化(如 reduce 的 thread + warp + block 多级归约) - 内存访问优化(shared memory、memory coalescing) - 并行优化(向量化、warp primitives) --- ### 案例 2:C++ - 向量化操作 **❌ 当前可能生成的简单实现** ```cpp torch::Tensor relu_forward(torch::Tensor x) { auto output = torch::empty_like(x); auto x_ptr = x.data_ptr<float>(); auto out_ptr = output.data_ptr<float>(); int64_t n = x.numel(); // 朴素标量循环 for (int64_t i = 0; i < n; i++) { out_ptr[i] = x_ptr[i] > 0 ? x_ptr[i] : 0; } return output; } ``` **✅ 高性能实现的其中一种方案示例** ```cpp #include <arm_neon.h> // ARM NEON 指令集 torch::Tensor relu_forward_optimized(torch::Tensor x) { auto output = torch::empty_like(x); auto x_ptr = x.data_ptr<float>(); auto out_ptr = output.data_ptr<float>(); int64_t n = x.numel(); // 1️⃣ 向量化主循环: 每次处理 4 个 float (NEON) int64_t vec_size = 4; int64_t vec_end = (n / vec_size) * vec_size; float32x4_t zero_vec = vdupq_n_f32(0.0f); for (int64_t i = 0; i < vec_end; i += vec_size) { float32x4_t x_vec = vld1q_f32(&x_ptr[i]); uint32x4_t mask = vcgtq_f32(x_vec, zero_vec); float32x4_t result = vreinterpretq_f32_u32(vandq_u32( vreinterpretq_u32_f32(x_vec), mask)); vst1q_f32(&out_ptr[i], result); } // 2️⃣ 处理剩余元素 for (int64_t i = vec_end; i < n; i++) { out_ptr[i] = x_ptr[i] > 0 ? x_ptr[i] : 0; } return output; } ``` **性能差异**: 向量化版本可能快 **4-8x**(取决于 CPU 架构) --- ## 共建高性能算子知识库 为了提升 AIKG 对各类 DSL 的代码生成性能,我们诚挚欢迎 **SIG 成员、学生、算子专家、DSL 开发者等社区开发者**通过提供高质量的文档和示例,帮助 AIKG 生成更高性能的算子代码! ### 参与方式 1. **贡献高性能示例代码** - 为各类 DSL(CUDA C、C++、Triton、TileLang 等)提供高性能实现示例 - 包含对应的优化思路注释和性能数据 2. **完善性能优化文档** - 补充 DSL 的性能优化模式文档(如多级 reduce、tiling、向量化等) - 添加算法选择建议和最佳实践(如卷积算法选择、内存访问优化等) 3. **优化参考文档** - 改进 DSL 基础文档、API 说明、专家建议文档 - 使文档更易于 LLM 理解和学习高性能实现模式 4. **改进 AIKG 生成流程** - 提出更好的代码生成策略和 Prompt 设计 - 优化文档驱动接入(DDI)机制 - 分享性能分析和优化经验
评论 (
0
)
登录
后才可以发表评论
状态
TODO
TODO
ACCEPTED
WIP
VALIDATION
DONE
CLOSED
REJECTED
负责人
未设置
标签
未设置
项目
未立项任务
未立项任务
里程碑
未关联里程碑
未关联里程碑
Pull Requests
未关联
未关联
关联的 Pull Requests 被合并后可能会关闭此 issue
分支
未关联
分支 (17)
标签 (11)
master
br_aikg
ms_custom_ops
ms_custom_ops_0902_br_infer_iter
code_clean
r2.4
r2.2
r2.3
r2.1
r2.0
r1.9
r1.8
r1.7
r1.6
r1.3
r1.5
r1.2
v2.4.0
v2.1.0
v2.0.0
v1.9.0
v1.8.0
v1.7.0
v1.6.0
v1.5.0
v1.4.0
v1.3.0
v1.2.0
开始日期   -   截止日期
-
置顶选项
不置顶
置顶等级:高
置顶等级:中
置顶等级:低
优先级
不指定
严重
主要
次要
不重要
预计工期
(小时)
参与者(1)
Python
1
https://gitee.com/mindspore/akg.git
git@gitee.com:mindspore/akg.git
mindspore
akg
akg
点此查找更多帮助
搜索帮助
Git 命令在线学习
如何在 Gitee 导入 GitHub 仓库
Git 仓库基础操作
企业版和社区版功能对比
SSH 公钥设置
如何处理代码冲突
仓库体积过大,如何减小?
如何找回被删除的仓库数据
Gitee 产品配额说明
GitHub仓库快速导入Gitee及同步更新
什么是 Release(发行版)
将 PHP 项目自动发布到 packagist.org
仓库举报
回到顶部
登录提示
该操作需登录 Gitee 帐号,请先登录后再操作。
立即登录
没有帐号,去注册