diff --git a/ops/c_api/scatter_nd_update/scatter_nd_update.cc b/ops/c_api/scatter_nd_update/scatter_nd_update.cc
new file mode 100644
index 0000000000000000000000000000000000000000..d17dce47fca889e0ce7682c98709a656baa64605
--- /dev/null
+++ b/ops/c_api/scatter_nd_update/scatter_nd_update.cc
@@ -0,0 +1,169 @@
+/**
+ * Copyright 2025 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// =============================================================================
+// GRAPH MODE IMPLEMENTATION
+// =============================================================================
+
+#include <memory>
+#include <set>
+#include <vector>
+#include "ops/framework/aclnn/graphmode/aclnn_kernel_mod.h"
+#include "ops/framework/utils.h"
+
+namespace ms_custom_ops {
+enum ScatterNdUpdateInputIndex : size_t {
+  kScatterNdUpdateInputIndex = 0,
+  kScatterNdUpdateIndicesIndex,
+  kScatterNdUpdateUpdatesIndex,
+  kScatterNdUpdateInputsNum,
+};
+
+enum ScatterNdUpdateOutputIndex : size_t {
+  kScatterNdUpdateOutputIndex = 0,
+  kScatterNdUpdateOutputNums,
+};
+
+class OPS_API ScatterNdUpdateCustomOpFuncImpl : public OpFuncImpl {
+ public:
+  ShapeArray InferShape(const PrimitivePtr &primitive, const InferInfoPtrList &input_infos) const override {
+    MS_EXCEPTION_IF_NULL(primitive);
+
+    if (input_infos[kScatterNdUpdateInputIndex]->IsDynamicRank()) {
+      return {input_infos[kScatterNdUpdateInputIndex]->GetShape()};
+    }
+    auto indices_shape = input_infos[kScatterNdUpdateIndicesIndex]->GetShape();
+    if (indices_shape.size() < 2) {
+      MS_LOG(EXCEPTION) << "For 'scatter_nd_update', the rank of 'indices' should not be less than 2, but got "
+                        << indices_shape.size();
+    }
+    auto input_shape = input_infos[kScatterNdUpdateInputIndex]->GetShape();
+    auto updates_shape = input_infos[kScatterNdUpdateUpdatesIndex]->GetShape();
+    auto last_indices_dim = indices_shape.back();
+
+    // The rank of 'input' should be in [1, 8].
+    if (input_shape.size() < 1 || input_shape.size() > 8) {
+      MS_LOG(EXCEPTION) << "For 'scatter_nd_update', the rank of the input tensor 'input' (index 0) should be in "
+                           "[1, 8], but got "
+                        << input_shape.size();
+    }
+    if (last_indices_dim > static_cast<int64_t>(input_shape.size())) {
+      MS_LOG(EXCEPTION) << "For 'scatter_nd_update', the last dimension of 'indices' (" << last_indices_dim
+                        << ") should not exceed the rank of 'input' (" << input_shape.size() << ").";
+    }
+
+    // expected updates shape = indices_shape[:-1] ++ input_shape[last_indices_dim:]
+    auto expected_update_shapes_dim = indices_shape.size() - 1 + input_shape.size() - last_indices_dim;
+    ShapeVector expected_update_shape(expected_update_shapes_dim, 0);
+    size_t idx = 0;
+    for (; idx < indices_shape.size() - 1; idx++) {
+      expected_update_shape[idx] = indices_shape[idx];
+    }
+
+    for (size_t i = 0; idx < expected_update_shapes_dim; idx++, i++) {
+      expected_update_shape[idx] = input_shape[i + last_indices_dim];
+    }
+
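+    // Illustrative example: with input shape (4, 5, 6) and indices shape (3, 2), last_indices_dim is 2,
+    // so the expected updates shape is (3, 6).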
+    if (expected_update_shape != updates_shape) {
+      MS_LOG(EXCEPTION) << "For 'scatter_nd_update', let the last dimension of 'indices' have size a. Then the "
+                           "shape of 'updates' must be the shape of 'indices' without its last dimension, followed "
+                           "by the shape of 'input' without its first a dimensions. But got indices shape: "
+                        << indices_shape << ", input shape: " << input_shape << ", updates shape: " << updates_shape;
+    }
+
+    return {input_infos[kScatterNdUpdateInputIndex]->GetShape()};
+  }
+
+  std::vector<TypeId> InferType(const PrimitivePtr &primitive, const InferInfoPtrList &input_infos) const override {
+    MS_EXCEPTION_IF_NULL(primitive);
+    auto op_name = primitive->name();
+    auto input_dtype = input_infos[kScatterNdUpdateInputIndex]->GetType();
+    auto updates_dtype = input_infos[kScatterNdUpdateUpdatesIndex]->GetType();
+    if (input_dtype != updates_dtype) {
+      MS_LOG(EXCEPTION) << "For 'scatter_nd_update', the dtype of 'input' should be the same as that of 'updates', "
+                           "but got input dtype: "
+                        << input_dtype << ", updates dtype: " << updates_dtype;
+    }
+    auto ms_context = MsContext::GetInstance();
+    MS_EXCEPTION_IF_NULL(ms_context);
+    const auto &soc_version = ms_context->ascend_soc_version();
+
+    if (soc_version == kAscendVersion910_93 || soc_version == kAscendVersion910b) {
+      const std::set<TypeId> valid_types = {
+        kNumberTypeFloat16, kNumberTypeBFloat16, kNumberTypeFloat32, kNumberTypeInt64, kNumberTypeBool, kNumberTypeInt8,
+      };
+      CheckAndConvertUtils::CheckTypeIdValid("input", input_dtype, valid_types, op_name);
+      CheckAndConvertUtils::CheckTypeIdValid("updates", updates_dtype, valid_types, op_name);
+    } else if (soc_version == kAscendVersion310p) {
+      const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeBool};
+      CheckAndConvertUtils::CheckTypeIdValid("input", input_dtype, valid_types, op_name);
+      CheckAndConvertUtils::CheckTypeIdValid("updates", updates_dtype, valid_types, op_name);
+    } else {
+      MS_LOG(EXCEPTION) << "'scatter_nd_update' only supports [" << kAscendVersion910b << ", " << kAscendVersion910_93
+                        << ", " << kAscendVersion310p << "], but got " << soc_version;
+    }
+
+    return {input_dtype};
+  }
+
+  bool GeneralInferRegistered() const override { return true; }
+};
+
+class ScatterNdUpdateCustomAscend : public AclnnCustomKernelMod {
+ public:
+  ScatterNdUpdateCustomAscend() : AclnnCustomKernelMod("aclnnScatterNdUpdate") {}
+  ~ScatterNdUpdateCustomAscend() = default;
+
+  bool Launch(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &workspace,
+              const std::vector<KernelTensor *> &outputs, void *stream_ptr) override {
+    MS_EXCEPTION_IF_NULL(stream_ptr);
+    RunOp(stream_ptr, workspace, inputs[kScatterNdUpdateInputIndex], inputs[kScatterNdUpdateIndicesIndex],
+          inputs[kScatterNdUpdateUpdatesIndex]);
+    return true;
+  }
+
+  void GetWorkSpaceInfo(const std::vector<KernelTensor *> &inputs,
+                        const std::vector<KernelTensor *> &outputs) override {
+    GetWorkspaceForResize(inputs[kScatterNdUpdateInputIndex], inputs[kScatterNdUpdateIndicesIndex],
+                          inputs[kScatterNdUpdateUpdatesIndex]);
+  }
+
+ private:
+  DEFINE_GET_WORKSPACE_FOR_RESIZE();
+};
+}  // namespace ms_custom_ops
+
+REG_GRAPH_MODE_OP(scatter_nd_update, ms_custom_ops::ScatterNdUpdateCustomOpFuncImpl,
+                  ms_custom_ops::ScatterNdUpdateCustomAscend);
+
+// =============================================================================
+// PYBOOST MODE IMPLEMENTATION
+// =============================================================================
+
+namespace ms_custom_ops {
+using namespace mindspore;
+using namespace mindspore::device::ascend;
+
+std::vector<ms::Tensor> scatter_nd_update_custom(const ms::Tensor &input, const ms::Tensor &indices,
+                                                 const ms::Tensor &updates) {
+  std::vector<ms::Tensor> outputs = {};
+  auto runner = std::make_shared<ms::pynative::PyboostRunner>("aclnnScatterNdUpdate");
+  runner->SetLaunchFunc(LAUNCH_ACLNN_FUNC(aclnnScatterNdUpdate, input, indices, updates));
+  // The op updates 'input' in place, so only the input tensors are passed and 'input' itself is returned.
+  runner->Run({input, indices, updates}, outputs);
+  return {input};
+}
+}  // namespace ms_custom_ops
+
+MS_CUSTOM_OPS_EXTENSION_MODULE(m) {
+  m.def("scatter_nd_update",
+        PYBOOST_CALLER(ms_custom_ops::kScatterNdUpdateOutputNums, ms_custom_ops::scatter_nd_update_custom));
+}
diff --git a/ops/c_api/scatter_nd_update/scatter_nd_update.md b/ops/c_api/scatter_nd_update/scatter_nd_update.md
new file mode 100644
index 0000000000000000000000000000000000000000..5a39017a5ebedc7b305bfe0b7c02f4e968dd66d2
--- /dev/null
+++ b/ops/c_api/scatter_nd_update/scatter_nd_update.md
@@ -0,0 +1,65 @@
+# scatter_nd_update operator
+
+## Description
+The scatter_nd_update operator writes the values in `updates` into the tensor `varRef` at the positions selected by `indices`, updating `varRef` in place.
+It is implemented on top of the aclnnScatterNdUpdate operator.
+
+## Input parameters
+
+| Name    | DType  | Shape           | Optional | Inplace | Format | Description                                                                   |
+|---------|--------|-----------------|----------|---------|--------|-------------------------------------------------------------------------------|
+| varRef  | Tensor | 1 to 8 dims     | No       | Yes     | ND     | Original data to be updated; its dtype must match that of updates             |
+| indices | Tensor | at least 2 dims | No       | No      | ND     | Index tensor of dtype int32 or int64; out-of-range indices are not supported  |
+| updates | Tensor | see constraints | No       | No      | ND     | Update values; its dtype must match that of varRef                            |
+
+Note:
+Atlas A2 training series / Atlas 800I A2 inference products / A200I A2 Box heterogeneous components, Atlas A3 training series / Atlas A3 inference series:
++ varRef and updates support: FLOAT16, BFLOAT16, FLOAT32, INT64, BOOL, INT8
++ indices supports: INT32, INT64
+Atlas training series / Atlas inference series:
++ varRef and updates support: FLOAT16, FLOAT32, BOOL
++ indices supports: INT32, INT64
+
+Constraints:
+1. indices has at least 2 dimensions, and the size of its last dimension (denoted a) must not exceed the rank of varRef.
+2. The shape of updates must equal the shape of indices without its last dimension, followed by the shape of varRef without its first a dimensions.
+For example: if varRef has shape (4, 5, 6) and indices has shape (3, 2), then updates must have shape (3, 6).
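+
+The shape rule in constraint 2 can be checked with plain NumPy; a minimal sketch (shapes chosen only for illustration):
+
+```python
+import numpy as np
+
+var = np.zeros((4, 5, 6), dtype=np.float16)
+indices = np.zeros((3, 2), dtype=np.int32)
+
+a = indices.shape[-1]
+expected_updates_shape = indices.shape[:-1] + var.shape[a:]
+print(expected_updates_shape)  # (3, 6)
+```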
+
+## Output parameters
+
+| Name | DType  | Shape          | Description               |
+|------|--------|----------------|---------------------------|
+| out  | Tensor | same as varRef | The updated output tensor |
+
+For more details, see: [aclnnScatterNdUpdate](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/83RC1alpha002/API/aolapi/context/aclnnScatterNdUpdate.md)
+
+## Special notes
+
+## Usage examples
+
+### Basic usage (graph mode)
+
+```python
+import mindspore as ms
+import numpy as np
+import ms_custom_ops
+from mindspore import Tensor
+
+ms.set_device("Ascend")
+
+@ms.jit
+def scatter_nd_update_func(var, indices, updates):
+    return ms_custom_ops.scatter_nd_update(var, indices, updates)
+
+data_var = np.random.uniform(0, 1, [24, 128]).astype(np.float16)
+data_indices = np.random.uniform(0, 12, [12, 1]).astype(np.int32)
+data_updates = np.random.uniform(1, 2, [12, 128]).astype(np.float16)
+
+var = Tensor(data_var, dtype=ms.float16)
+indices = Tensor(data_indices, dtype=ms.int32)
+updates = Tensor(data_updates, dtype=ms.float16)
+out = scatter_nd_update_func(var, indices, updates)
+```
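+
+The same call is expected to work in PyNative (eager) mode without `@ms.jit`; a minimal sketch under that assumption:
+
+```python
+import mindspore as ms
+import numpy as np
+import ms_custom_ops
+from mindspore import Tensor
+
+ms.set_device("Ascend")
+ms.set_context(mode=ms.PYNATIVE_MODE)
+
+var = Tensor(np.zeros((24, 128), np.float16))
+indices = Tensor(np.arange(12, dtype=np.int32).reshape(12, 1))
+updates = Tensor(np.ones((12, 128), np.float16))
+out = ms_custom_ops.scatter_nd_update(var, indices, updates)
+```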
diff --git a/ops/c_api/scatter_nd_update/scatter_nd_update_op.yaml b/ops/c_api/scatter_nd_update/scatter_nd_update_op.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a9acaf7358d513ab6e785805280ed470fed7283c
--- /dev/null
+++ b/ops/c_api/scatter_nd_update/scatter_nd_update_op.yaml
@@ -0,0 +1,17 @@
+#operator scatter_nd_update
+scatter_nd_update:
+  args:
+    input:
+      dtype: tensor
+    indices:
+      dtype: tensor
+    updates:
+      dtype: tensor
+  args_signature:
+    rw_write: input
+  labels:
+    side_effect_mem: True
+  returns:
+    out:
+      dtype: tensor
+      inplace: input
\ No newline at end of file
diff --git a/tests/st/test_custom_scatter_nd_update.py b/tests/st/test_custom_scatter_nd_update.py
new file mode 100644
index 0000000000000000000000000000000000000000..22e0fbe72887fb0a2bfe52e553ab341665590d8a
--- /dev/null
+++ b/tests/st/test_custom_scatter_nd_update.py
@@ -0,0 +1,148 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import os
+import numpy as np
+import pytest
+from functools import wraps
+
+import mindspore.ops as ops
+import mindspore.nn as nn
+import mindspore as ms
+from mindspore import context, Tensor
+from mindspore.common.np_dtype import bfloat16
+from mindspore._c_expression import MSContext
+import ms_custom_ops
+
+
+def get_ms_dtype(query_dtype):
+    if query_dtype == np.float32:
+        ms_dtype = ms.float32
+    elif query_dtype == np.float16:
+        ms_dtype = ms.float16
+    elif query_dtype == bfloat16:
+        ms_dtype = ms.bfloat16
+    elif query_dtype == np.int64:
+        ms_dtype = ms.int64
+    elif query_dtype == np.int8:
+        ms_dtype = ms.int8
+    else:
+        raise ValueError(f"Unsupported dtype: {query_dtype}")
+    return ms_dtype
+
+
+def scatter_nd_update(varRef, indices, updates):
+    """
+    NumPy reference implementation of scatter_nd_update.
+    Args:
+        varRef: original data tensor (n-dimensional)
+        indices: index tensor (at least 2-D, last dimension size k <= varRef.ndim)
+        updates: update values (shape must satisfy the constraint below)
+    Returns:
+        An updated copy of varRef.
+    """
+    # Copy the original data so the input is not modified.
+    result = np.copy(varRef)
+
+    # Key dimension information.
+    k = indices.shape[-1]  # index depth
+    idx_shape = indices.shape[:-1]  # leading shape of indices
+    var_shape = varRef.shape
+
+    # Validate the updates shape: [idx_shape] + [var_shape[k:]].
+    expected_updates_shape = idx_shape + var_shape[k:]
+    if updates.shape != expected_updates_shape:
+        raise ValueError(
+            f"Updates shape mismatch. Expected {expected_updates_shape}, got {updates.shape}"
+        )
+
+    # Reshape indices to a 2-D array [N, k].
+    flat_indices = indices.reshape(-1, k)
+    # Reshape updates to [N, ...].
+    flat_updates = updates.reshape(-1, *var_shape[k:])
+
+    # Apply each index.
+    for i in range(flat_indices.shape[0]):
+        idx = tuple(flat_indices[i])
+        # Validate index bounds.
+        if any((idx[j] < 0 or idx[j] >= var_shape[j]) for j in range(k)):
+            raise IndexError(
+                f"Index {idx} out of bounds for tensor shape {var_shape[:k]}"
+            )
+
+        # Build the full slice index.
+        full_idx = idx
+        # Add full slices for the remaining dimensions.
+        if k < len(var_shape):
+            full_idx += (slice(None),) * (len(var_shape) - k)
+
+        # Apply the update.
+        result[full_idx] = flat_updates[i]
+
+    return result
+
+
+class ScatterNdUpdateNet(ms.nn.Cell):
+    def __init__(self):
+        super().__init__()
+
+    def construct(self, var, indices, updates):
+        return ms_custom_ops.scatter_nd_update(var, indices, updates)
+
+
+def run(
+    net,
+    var_shape,
+    indices_shape,
+    updates_shape,
+    dtype,
+    indices_dtype,
+):
+    var = np.random.uniform(0, 1, var_shape).astype(dtype)
+    # Draw unique indices so the NumPy reference and the kernel agree regardless of update order.
+    indices = (
+        np.random.choice(var_shape[0], size=indices_shape[0], replace=False)
+        .reshape(indices_shape)
+        .astype(indices_dtype)
+    )
+    updates = np.random.uniform(1, 2, updates_shape).astype(dtype)
+
+    var_tensor = Tensor(var, dtype=get_ms_dtype(dtype))
+    indices_tensor = Tensor(indices, dtype=ms.int32 if indices_dtype == np.int32 else ms.int64)
+    updates_tensor = Tensor(updates, dtype=get_ms_dtype(dtype))
+    out = net(var_tensor, indices_tensor, updates_tensor)
+
+    golden = scatter_nd_update(var, indices, updates)
+
+    np.testing.assert_allclose(out.asnumpy(), golden, rtol=1e-2, atol=1e-2)
+
+
+@pytest.mark.level0
+@pytest.mark.env_onecard
+@pytest.mark.platform_ascend910b
+@pytest.mark.platform_ascend310p
+@pytest.mark.parametrize("mode", [context.GRAPH_MODE, context.PYNATIVE_MODE])
+@pytest.mark.parametrize("dtype", [np.float16, np.float32, np.int64, np.int8])
+@pytest.mark.parametrize("indice_dtype", [np.int64, np.int32])
+@pytest.mark.parametrize("var_shape", [[24, 128]])
+@pytest.mark.parametrize("indices_shape", [[12, 1]])
+@pytest.mark.parametrize("updates_shape", [[12, 128]])
+def test_scatter_nd_update(
+    mode,
+    dtype,
+    indice_dtype,
+    var_shape,
+    indices_shape,
+    updates_shape,
+):
+    ms.set_context(device_target="Ascend", mode=mode)
+    ms.set_context(jit_config={"jit_level": "O0", "infer_boost": "on"})
+    net = ScatterNdUpdateNet()
+    run(
+        net,
+        var_shape,
+        indices_shape,
+        updates_shape,
+        dtype,
+        indice_dtype,
+    )