登录
注册
开源
企业版
高校版
搜索
帮助中心
使用条款
关于我们
开源
企业版
高校版
私有云
模力方舟
AI 队友
登录
注册
轻量养虾,开箱即用!低 Token + 稳定算力,Gitee & 模力方舟联合出品的 PocketClaw 正式开售!点击了解详情~
代码拉取完成,页面将自动刷新
仓库状态说明
开源项目
>
人工智能
>
机器学习/深度学习
&&
捐赠
捐赠前请先登录
取消
前往登录
扫描微信二维码支付
取消
支付完成
支付提示
将跳转至支付宝完成支付
确定
取消
Watch
不关注
关注所有动态
仅关注版本发行动态
关注但不提醒动态
2.8K
Star
9.1K
Fork
5.3K
GVP
MindSpore
/
mindspore
关闭
代码
Issues
220
Pull Requests
283
Wiki
统计
流水线
服务
质量分析
Jenkins for Gitee
腾讯云托管
腾讯云 Serverless
悬镜安全
阿里云 SAE
Codeblitz
SBOM
开发画像分析
我知道了,不再自动展开
更新失败,请稍后重试!
移除标识
内容风险标识
本任务被
标识为内容中包含有代码安全 Bug 、隐私泄露等敏感信息,仓库外成员不可访问
[Feature]: ==号支持重载,对齐PTA
TODO
#IDCQ7L
XianglongZeng
成员
创建于
2025-12-12 21:29
### 🚀 背景描述 当前 MindSpore 的 Equal 算子(==)在处理 Tensor 与标量(Scalar)的比较时,可能存在类型提升或回退问题。为了提升易用性并与 PyTorch 行为对齐,需要支持 Python 的 `==` 符号重载,并确保在各个后端高效执行,特别是利用 `eq_scalar` 算子来处理 Tensor 与标量的比较,避免不必要的 Tensor 转换以及同步拷贝带来的性能损失。 ### 设计思路 ### 标杆与接口(Benchmark & API) - 标杆接口:`torch.eq(input, other, *, out=None)` - 功能:逐元素比较 `input` 和 `other`。支持广播。 - MindSpore 接口: - `mindspore.ops.equal(input, other)` - `Tensor.__eq__(other)` ### 任务清单(Tasks) | 序号 | 任务项 | 任务子项 | 状态(新增/修改/无变更/不涉及) | 备注 | | ---- | ------------------ | ----------------- | ------------------------------ | ---- | | 1 | 接口基本功能 | Primitive | 修改 | 引用 eq_scalar_op 和 equal_op | | | | functional | 新增 | api_def/equal.yaml, __eq__.yaml | | | | multitype_ops | 修改 | equal_impl.py 拆分 Tensor/Scalar 逻辑 | | 2 | 后端及数据类型支持 | Ascend | 不涉及 | 已有 Kernel 支持 | | | | GPU | 不涉及 | — | | | | CPU | 不涉及 | — | | 3 | 支持 vmap | | 不涉及 | — | | 4 | 支持动态 Shape | | 不涉及 | — | | 5 | 支持反向 | | 不涉及 | Equal 不可导 | | 6 | 补齐资料 | API 映射 | 新增 | __eq__ 映射到 equal | ### YAML 定义(Definition 摘要) - `mindspore/ops/api_def/equal.yaml`: ```yaml equal: - op_yaml: eq_scalar_op.yaml py_method: tensor_equal Ascend: pyboost CPU: pyboost GPU: pyboost interface: tensor, function - op_yaml: equal_op.yaml py_method: tensor_equal Ascend: pyboost CPU: pyboost GPU: pyboost interface: tensor, function ``` - `mindspore/ops/api_def/__eq__.yaml`: ```yaml __eq__: alias: equal ``` ### 实现方案(Implementation) - **api_def 定义**:通过 `equal.yaml` 定义多态分发,优先匹配 `eq_scalar_op`(Tensor, Number),其次匹配 `equal_op`(Tensor, Tensor)。 - **MultitypeFuncGraph 适配**:修改 `mindspore/python/mindspore/ops/composite/multitype_ops/equal_impl.py`,将 `_tensor_equal_tensor` 拆分为: - `_tensor_equal_scalar`:调用 `ops.eq_scalar(tensor, number)`。 - `_scalar_equal_tensor`:调用 `ops.eq_scalar(tensor, number)`(交换参数)。 - `_tensor_equal_tensor`:调用 `ops.equal(tensor, tensor)`。 - **eq_scalar_op 配置**:修改 `eq_scalar_op.yaml`,启用 functional 接口生成(`function: disable: False`)。 ### 测试方案(Test Plan) - 执行模式与平台: - 模式:`Pynative`、`Graph O0`(KBK) - 平台:`Ascend`、`GPU`、`CPU`(Equal/eq_scalar 为通用基础算子,各后端均需覆盖) - 功能与一致性(Functional & Consistency): 1) 接口覆盖: - `mindspore.ops.equal(input, other)`(functional) - `Tensor.__eq__(other)`(运算符 `==`) - `ops.equal(tensor, scalar)` / `ops.equal(scalar, tensor)`:确认分发走 `eq_scalar` 路径(Tensor-Scalar / Scalar-Tensor) - `ops.equal(tensor, tensor)`:确认分发走 `equal` 路径(Tensor-Tensor) 2) 广播(broadcast)覆盖: - `input.shape=(N, D)`,`other.shape=(N, 1)` / `(1, D)` / `()`(scalar)等可广播组合 - 对齐 PyTorch `torch.eq` 的广播规则与结果 3) 标量类型覆盖(Scalar coverage): - `other` 取 `int/float/bool`(Python Number/Bool) - 特殊值:`0/-1/1`、极大/极小数、`inf/-inf/nan`(浮点标量) 4) dtype 覆盖(Tensor dtype coverage): - `bool` - `int8/int16/int32/int64` - `uint8`(若后端支持) - `float16/float32/float64`(按平台能力可选 `bfloat16`) - 期望:输出 dtype 为 `bool`,逐元素比较结果与 torch 对齐 5) shape/rank 覆盖: - 0D(标量 Tensor)、1D、2D、3D(至少覆盖常见 N 维场景) - 空 Tensor:`shape` 含 0 维(如 `(0,)`、`(0, D)`),期望可执行且输出形状符合广播规则 6) 非连续输入(non-contiguous): - 例如 `x = ops.transpose(t, (1, 0))` 等视图场景,比较结果正确、无崩溃 - 动态形态(Dynamic Shape/Rank): - 使用 TEST_OP/动态形态框架覆盖: - 动态 rank:输入 rank 动态变化(如 1D/2D/3D),比较可达 - 动态 shape:`N/D` 动态(含广播维度变化) - 动态标量:`other` 为 scalar 与 0D Tensor 两条路径均覆盖 - 异常与边界(Errors & Edge Cases): - `input` 与 `other` 不可广播:期望抛出 `ValueError`(或框架定义的形状不匹配错误) - `other` 类型不支持(如 list/dict/自定义对象):期望抛出 `TypeError` - `input` 为不支持的 dtype(如 complex64/complex128):期望抛出 `TypeError`(若当前不支持) - `Tensor.__eq__`:验证 `tensor == other` 与 `ops.equal(tensor, other)` 结果一致;并覆盖 `other == tensor` 的反向比较(例如 scalar 侧 `__eq__` 返回 `NotImplemented` 时回落到 Tensor 比较) - 标杆对比(Benchmarking): - 首选与 `torch.eq`(CPU)结果对齐;对浮点场景包含 `nan` 时按 IEEE 规则(`nan != nan`)校验 - 对比项:输出 shape、逐元素布尔结果、广播行为 ### 验收报告 - 基本信息: 1. 要验收的算子:`mindspore.ops.equal` / `Tensor.__eq__` 2. 需求 ISSUE 单:— 3. mindspore 仓 PR:— 4. MindSporeTest 测试仓 PR:— 5. 对比标杆版本:`torch 2.x` 6. 标杆算子:`torch.eq` 7. 是否为副作用算子:否 8. 报错规范:MindSpore 日志与错误信息规范 – 资料验证: | 自测内容 | 自测结果 | 备注 | | --- | --- | --- | | 新增接口列表 | 不涉及 | 对外接口均为存量接口 | | 提供典型场景的ut/st用例 | 是 | | | 接口是否提供中文rst并与英文注释对应 | 是 | 存量接口 | | 接口描述是否详细准确 | 是 | | | 与pytorch的接口是否一致 | 是 | | | summary部分是否有提供公式 | 不涉及 | | | 属性描述是否完整正确 | 不涉及 | | | input描述是否完整正确 | 不涉及 | | | 输出描述是否完整正确 | 不涉及 | | | 输出尺寸和input是否一样 | 否 | 满足广播原则 | | Raises项描述完整正确 | 是 | | | 支持的平台都有填写 | 是 | | | 资料格式(包括样例格式)检查是否ok |是 | | | 样例是否有提供 | 是 | | | 样例是否有打印结果 | 是 | | | 样例执行情况是否ok | 是 | | | 算子与API能力沙盘是否补齐 | 是 | | – 功能验证: | 自测内容 | 自测结果 | 备注 | | --- | --- | --- | | 默认参数场景是否验证 | | | | 空Tensor输入的正反向是否验证 | | | | inf和nan是否验证 | | | | 算子支持数据类型是否与标杆对齐(pytorch npu/gpu/cpu) | | | | 输入取值范围是否有验证 | | | | 输入维度是否有覆盖0D-8D | | | | 输入支持的dtype是否全覆盖 | | | | 输入是否支持隐式类型转换? | | | | 输入是否支持广播 | | | | 输入之间的约束是否有验证 | | | | 正向的精度验证是否通过 | | | | 反向是否支持 | | | | 反向是否是单算子实现 | | | | 异常用例是否校验具体报错信息 | | | | 是否提供报错白名单 | | | | 是否提供functional用例 | | | | 动态shape/rank/属性是否都支持 | | | | 是否关闭退避功能验证 export MS_DISABLE_KERNEL_BACKOFF=1 | | | | 测试仓接口相关用例是否全部可以PASS,接口没有遗留问题单 | | | | 是否支持bf16 | | | | 多输入算子的bprop函数是否有考虑反向按需求导 | | | | 算子输出shape是否依赖于算子的计算结果 | | | | 是否支持非连续输入支持情况 | | | | 是否与PTA计算结果0偏差(请将MD5对比截图贴于右侧) | | | | 是否会使得运算符或存量ops接口调用到新增的原语 | | | | 是否已支持amp(混合精度)特性 | | | | 若多Tensor输入,是否支持各Tensor数据类型不一致 | | | - 性能验证(占位,后续补充): | 自测内容 | 自测结果 | 备注 | | --- | --- | --- | | 性能是否验证广播场景 | | | | 是否考虑到反向显存优化,未用到的输入是否添加到了SetUnusedInputs中(串讲时需展示代码) | | | | 性能测试是否覆盖不同规格的数据(3种以上),且算子性能不低于友商,允许波动范围10%(CPU和GPU3090平台及以上) | | | | 显存是否持平PTA | | | 显存自验结果: 910b tensor-scalar - 算子自测: | 自测内容 | 自测结果 | 备注 | | --- | --- | --- | | 生成算子自测用例的json文件是否提供 | | | | 是否自测无遗留问题 | | | - 安全编码检视: | 自测内容 | 自测结果 | 备注 | | --- | --- | --- | | 指针是否未判空 | | | | 指针是否先使用后校验 | | | | 数组、指针是否访存越界 | | | | 是否存在除零 | | | | 是否存在内存泄露(new/mallco内存未释放 | | | | 是否存在异常、错误处理分支未对内存、文件句柄等资源释放 | | | | 是否存在'使用new创建对象未声明为nothrow | | | | 是否未使用安全函数库进行内存操作 | | | | 是否存在数据类型转换导致数值移除(上溢或下溢) | | | | 是否存在冗余代码(冗余校验、不可达代码等) | | | | 是否存在暴露敏感信息 | | | | 是否使用了弱随机数生成器 | | | | 是否将敏感信息硬编码 | | | ### 其他信息
### 🚀 背景描述 当前 MindSpore 的 Equal 算子(==)在处理 Tensor 与标量(Scalar)的比较时,可能存在类型提升或回退问题。为了提升易用性并与 PyTorch 行为对齐,需要支持 Python 的 `==` 符号重载,并确保在各个后端高效执行,特别是利用 `eq_scalar` 算子来处理 Tensor 与标量的比较,避免不必要的 Tensor 转换以及同步拷贝带来的性能损失。 ### 设计思路 ### 标杆与接口(Benchmark & API) - 标杆接口:`torch.eq(input, other, *, out=None)` - 功能:逐元素比较 `input` 和 `other`。支持广播。 - MindSpore 接口: - `mindspore.ops.equal(input, other)` - `Tensor.__eq__(other)` ### 任务清单(Tasks) | 序号 | 任务项 | 任务子项 | 状态(新增/修改/无变更/不涉及) | 备注 | | ---- | ------------------ | ----------------- | ------------------------------ | ---- | | 1 | 接口基本功能 | Primitive | 修改 | 引用 eq_scalar_op 和 equal_op | | | | functional | 新增 | api_def/equal.yaml, __eq__.yaml | | | | multitype_ops | 修改 | equal_impl.py 拆分 Tensor/Scalar 逻辑 | | 2 | 后端及数据类型支持 | Ascend | 不涉及 | 已有 Kernel 支持 | | | | GPU | 不涉及 | — | | | | CPU | 不涉及 | — | | 3 | 支持 vmap | | 不涉及 | — | | 4 | 支持动态 Shape | | 不涉及 | — | | 5 | 支持反向 | | 不涉及 | Equal 不可导 | | 6 | 补齐资料 | API 映射 | 新增 | __eq__ 映射到 equal | ### YAML 定义(Definition 摘要) - `mindspore/ops/api_def/equal.yaml`: ```yaml equal: - op_yaml: eq_scalar_op.yaml py_method: tensor_equal Ascend: pyboost CPU: pyboost GPU: pyboost interface: tensor, function - op_yaml: equal_op.yaml py_method: tensor_equal Ascend: pyboost CPU: pyboost GPU: pyboost interface: tensor, function ``` - `mindspore/ops/api_def/__eq__.yaml`: ```yaml __eq__: alias: equal ``` ### 实现方案(Implementation) - **api_def 定义**:通过 `equal.yaml` 定义多态分发,优先匹配 `eq_scalar_op`(Tensor, Number),其次匹配 `equal_op`(Tensor, Tensor)。 - **MultitypeFuncGraph 适配**:修改 `mindspore/python/mindspore/ops/composite/multitype_ops/equal_impl.py`,将 `_tensor_equal_tensor` 拆分为: - `_tensor_equal_scalar`:调用 `ops.eq_scalar(tensor, number)`。 - `_scalar_equal_tensor`:调用 `ops.eq_scalar(tensor, number)`(交换参数)。 - `_tensor_equal_tensor`:调用 `ops.equal(tensor, tensor)`。 - **eq_scalar_op 配置**:修改 `eq_scalar_op.yaml`,启用 functional 接口生成(`function: disable: False`)。 ### 测试方案(Test Plan) - 执行模式与平台: - 模式:`Pynative`、`Graph O0`(KBK) - 平台:`Ascend`、`GPU`、`CPU`(Equal/eq_scalar 为通用基础算子,各后端均需覆盖) - 功能与一致性(Functional & Consistency): 1) 接口覆盖: - `mindspore.ops.equal(input, other)`(functional) - `Tensor.__eq__(other)`(运算符 `==`) - `ops.equal(tensor, scalar)` / `ops.equal(scalar, tensor)`:确认分发走 `eq_scalar` 路径(Tensor-Scalar / Scalar-Tensor) - `ops.equal(tensor, tensor)`:确认分发走 `equal` 路径(Tensor-Tensor) 2) 广播(broadcast)覆盖: - `input.shape=(N, D)`,`other.shape=(N, 1)` / `(1, D)` / `()`(scalar)等可广播组合 - 对齐 PyTorch `torch.eq` 的广播规则与结果 3) 标量类型覆盖(Scalar coverage): - `other` 取 `int/float/bool`(Python Number/Bool) - 特殊值:`0/-1/1`、极大/极小数、`inf/-inf/nan`(浮点标量) 4) dtype 覆盖(Tensor dtype coverage): - `bool` - `int8/int16/int32/int64` - `uint8`(若后端支持) - `float16/float32/float64`(按平台能力可选 `bfloat16`) - 期望:输出 dtype 为 `bool`,逐元素比较结果与 torch 对齐 5) shape/rank 覆盖: - 0D(标量 Tensor)、1D、2D、3D(至少覆盖常见 N 维场景) - 空 Tensor:`shape` 含 0 维(如 `(0,)`、`(0, D)`),期望可执行且输出形状符合广播规则 6) 非连续输入(non-contiguous): - 例如 `x = ops.transpose(t, (1, 0))` 等视图场景,比较结果正确、无崩溃 - 动态形态(Dynamic Shape/Rank): - 使用 TEST_OP/动态形态框架覆盖: - 动态 rank:输入 rank 动态变化(如 1D/2D/3D),比较可达 - 动态 shape:`N/D` 动态(含广播维度变化) - 动态标量:`other` 为 scalar 与 0D Tensor 两条路径均覆盖 - 异常与边界(Errors & Edge Cases): - `input` 与 `other` 不可广播:期望抛出 `ValueError`(或框架定义的形状不匹配错误) - `other` 类型不支持(如 list/dict/自定义对象):期望抛出 `TypeError` - `input` 为不支持的 dtype(如 complex64/complex128):期望抛出 `TypeError`(若当前不支持) - `Tensor.__eq__`:验证 `tensor == other` 与 `ops.equal(tensor, other)` 结果一致;并覆盖 `other == tensor` 的反向比较(例如 scalar 侧 `__eq__` 返回 `NotImplemented` 时回落到 Tensor 比较) - 标杆对比(Benchmarking): - 首选与 `torch.eq`(CPU)结果对齐;对浮点场景包含 `nan` 时按 IEEE 规则(`nan != nan`)校验 - 对比项:输出 shape、逐元素布尔结果、广播行为 ### 验收报告 - 基本信息: 1. 要验收的算子:`mindspore.ops.equal` / `Tensor.__eq__` 2. 需求 ISSUE 单:— 3. mindspore 仓 PR:— 4. MindSporeTest 测试仓 PR:— 5. 对比标杆版本:`torch 2.x` 6. 标杆算子:`torch.eq` 7. 是否为副作用算子:否 8. 报错规范:MindSpore 日志与错误信息规范 – 资料验证: | 自测内容 | 自测结果 | 备注 | | --- | --- | --- | | 新增接口列表 | 不涉及 | 对外接口均为存量接口 | | 提供典型场景的ut/st用例 | 是 | | | 接口是否提供中文rst并与英文注释对应 | 是 | 存量接口 | | 接口描述是否详细准确 | 是 | | | 与pytorch的接口是否一致 | 是 | | | summary部分是否有提供公式 | 不涉及 | | | 属性描述是否完整正确 | 不涉及 | | | input描述是否完整正确 | 不涉及 | | | 输出描述是否完整正确 | 不涉及 | | | 输出尺寸和input是否一样 | 否 | 满足广播原则 | | Raises项描述完整正确 | 是 | | | 支持的平台都有填写 | 是 | | | 资料格式(包括样例格式)检查是否ok |是 | | | 样例是否有提供 | 是 | | | 样例是否有打印结果 | 是 | | | 样例执行情况是否ok | 是 | | | 算子与API能力沙盘是否补齐 | 是 | | – 功能验证: | 自测内容 | 自测结果 | 备注 | | --- | --- | --- | | 默认参数场景是否验证 | | | | 空Tensor输入的正反向是否验证 | | | | inf和nan是否验证 | | | | 算子支持数据类型是否与标杆对齐(pytorch npu/gpu/cpu) | | | | 输入取值范围是否有验证 | | | | 输入维度是否有覆盖0D-8D | | | | 输入支持的dtype是否全覆盖 | | | | 输入是否支持隐式类型转换? | | | | 输入是否支持广播 | | | | 输入之间的约束是否有验证 | | | | 正向的精度验证是否通过 | | | | 反向是否支持 | | | | 反向是否是单算子实现 | | | | 异常用例是否校验具体报错信息 | | | | 是否提供报错白名单 | | | | 是否提供functional用例 | | | | 动态shape/rank/属性是否都支持 | | | | 是否关闭退避功能验证 export MS_DISABLE_KERNEL_BACKOFF=1 | | | | 测试仓接口相关用例是否全部可以PASS,接口没有遗留问题单 | | | | 是否支持bf16 | | | | 多输入算子的bprop函数是否有考虑反向按需求导 | | | | 算子输出shape是否依赖于算子的计算结果 | | | | 是否支持非连续输入支持情况 | | | | 是否与PTA计算结果0偏差(请将MD5对比截图贴于右侧) | | | | 是否会使得运算符或存量ops接口调用到新增的原语 | | | | 是否已支持amp(混合精度)特性 | | | | 若多Tensor输入,是否支持各Tensor数据类型不一致 | | | - 性能验证(占位,后续补充): | 自测内容 | 自测结果 | 备注 | | --- | --- | --- | | 性能是否验证广播场景 | | | | 是否考虑到反向显存优化,未用到的输入是否添加到了SetUnusedInputs中(串讲时需展示代码) | | | | 性能测试是否覆盖不同规格的数据(3种以上),且算子性能不低于友商,允许波动范围10%(CPU和GPU3090平台及以上) | | | | 显存是否持平PTA | | | 显存自验结果: 910b tensor-scalar - 算子自测: | 自测内容 | 自测结果 | 备注 | | --- | --- | --- | | 生成算子自测用例的json文件是否提供 | | | | 是否自测无遗留问题 | | | - 安全编码检视: | 自测内容 | 自测结果 | 备注 | | --- | --- | --- | | 指针是否未判空 | | | | 指针是否先使用后校验 | | | | 数组、指针是否访存越界 | | | | 是否存在除零 | | | | 是否存在内存泄露(new/mallco内存未释放 | | | | 是否存在异常、错误处理分支未对内存、文件句柄等资源释放 | | | | 是否存在'使用new创建对象未声明为nothrow | | | | 是否未使用安全函数库进行内存操作 | | | | 是否存在数据类型转换导致数值移除(上溢或下溢) | | | | 是否存在冗余代码(冗余校验、不可达代码等) | | | | 是否存在暴露敏感信息 | | | | 是否使用了弱随机数生成器 | | | | 是否将敏感信息硬编码 | | | ### 其他信息
评论 (
2
)
登录
后才可以发表评论
状态
TODO
TODO
ACCEPTED
WIP
VALIDATION
DONE
CLOSED
REJECTED
负责人
未设置
标签
feature
未设置
项目
未立项任务
未立项任务
里程碑
未关联里程碑
未关联里程碑
Pull Requests
未关联
未关联
关联的 Pull Requests 被合并后可能会关闭此 issue
分支
未关联
分支 (
-
)
标签 (
-
)
开始日期   -   截止日期
-
置顶选项
不置顶
置顶等级:高
置顶等级:中
置顶等级:低
优先级
不指定
严重
主要
次要
不重要
预计工期
(小时)
参与者(2)
Python
1
https://gitee.com/mindspore/mindspore.git
git@gitee.com:mindspore/mindspore.git
mindspore
mindspore
mindspore
点此查找更多帮助
搜索帮助
Git 命令在线学习
如何在 Gitee 导入 GitHub 仓库
Git 仓库基础操作
企业版和社区版功能对比
SSH 公钥设置
如何处理代码冲突
仓库体积过大,如何减小?
如何找回被删除的仓库数据
Gitee 产品配额说明
GitHub仓库快速导入Gitee及同步更新
什么是 Release(发行版)
将 PHP 项目自动发布到 packagist.org
评论
仓库举报
回到顶部
登录提示
该操作需登录 Gitee 帐号,请先登录后再操作。
立即登录
没有帐号,去注册