diff --git a/inferrt/python/mrt/torch/fx_backend.py b/inferrt/python/mrt/torch/fx_backend.py index 961533d038fd9e297bcda9aa66512f59dc7acb24..3bde93e7cb249aa11e32d287825eb91ee958bfd2 100644 --- a/inferrt/python/mrt/torch/fx_backend.py +++ b/inferrt/python/mrt/torch/fx_backend.py @@ -154,6 +154,7 @@ _OP_MAP = { torch.sigmoid: Op.sigmoid, torch.empty: Op.empty, torch.zeros: Op.zeros, + torch.Tensor.contiguous: Op.contiguous, torch.ops._c10d_functional.all_gather_into_tensor: Op.all_gather, torch.ops._c10d_functional.all_reduce: Op.all_reduce, torch.ops._c10d_functional.reduce_scatter_tensor: Op.reduce_scatter, @@ -217,6 +218,7 @@ _OP_MAP = { "view": Op.reshape, # view is often used like reshape "copy_": Op.copy, "long": Op.cast, + "contiguous": Op.contiguous, } diff --git a/inferrt/src/ops/ascend/aclnn/aclnn_contiguous.cc b/inferrt/src/ops/ascend/aclnn/aclnn_contiguous.cc new file mode 100644 index 0000000000000000000000000000000000000000..15c3123403d5678c459db033234810a6eb02ed17 --- /dev/null +++ b/inferrt/src/ops/ascend/aclnn/aclnn_contiguous.cc @@ -0,0 +1,39 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include <vector> + +#include "ops/ascend/aclnn/aclnn_contiguous.h" +#include "ops/ascend/aclnn/utils/opapi_utils.h" +#include "ops/op_register.h" + +namespace mrt { +namespace ops { +OpsErrorCode AclnnContiguous::CalcWorkspace(const std::vector<ir::Value *> &input, const ir::Value *output, + size_t *workspaceSize) { + executor_->GetWorkspaceSize(static_cast<uint64_t *>(workspaceSize), output->ToTensor(), input[kIndex0]->ToTensor()); + return SUCCESS; +} + +OpsErrorCode AclnnContiguous::Launch(const std::vector<ir::Value *> &input, void *workspace, size_t workspaceSize, + ir::Value *output, void *stream) { + executor_->Launch(workspace, workspaceSize, stream, output->ToTensor(), input[kIndex0]->ToTensor()); + return SUCCESS; +} + +MRT_REG_OP(contiguous, AclnnContiguous, Ascend); +} // namespace ops +} // namespace mrt diff --git a/inferrt/src/ops/ascend/aclnn/aclnn_contiguous.h b/inferrt/src/ops/ascend/aclnn/aclnn_contiguous.h new file mode 100644 index 0000000000000000000000000000000000000000..627ad2700df9db5a2b2946daabc2ebe612d0e9a8 --- /dev/null +++ b/inferrt/src/ops/ascend/aclnn/aclnn_contiguous.h @@ -0,0 +1,41 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __OPS_ASCEND_ACLNN_ACLNN_CONTIGUOUS_H__ +#define __OPS_ASCEND_ACLNN_ACLNN_CONTIGUOUS_H__ + +#include "ops/operator.h" +#include "ops/ascend/aclnn/utils/aclnn_executor.h" + +namespace mrt { +namespace ops { +class AclnnContiguous : public Operator { + public: + AclnnContiguous() { executor_ = std::make_unique<AclnnExecutor>("aclnnInplaceCopy"); } + ~AclnnContiguous() override = default; + + OpsErrorCode CalcWorkspace(const std::vector<ir::Value *> &input, const ir::Value *output, + size_t *workspaceSize) override; + OpsErrorCode Launch(const std::vector<ir::Value *> &input, void *workspace, size_t workspaceSize, + ir::Value *output, void *stream) override; + + private: + std::unique_ptr<AclnnExecutor> executor_{nullptr}; +}; + +} // namespace ops +} // namespace mrt +#endif // __OPS_ASCEND_ACLNN_ACLNN_CONTIGUOUS_H__ diff --git a/inferrt/src/ops/op_def/ops.list index 65929d31841728b422269d35f022db9856da8e45..a42db2f3328c6923a656d9ff400a572568b5cae4 100644 --- a/inferrt/src/ops/op_def/ops.list +++ b/inferrt/src/ops/op_def/ops.list @@ -27,6 +27,7 @@ OP(reshape) OP(reshape_ext) OP(rsqrt) OP(cumsum_ext) +OP(contiguous) OP(sigmoid) OP(square) OP(add_rms_norm) diff --git a/mopt/include/mopt/Dialect/Mrt/MrtOps.td b/mopt/include/mopt/Dialect/Mrt/MrtOps.td index b6cc783c87126b483478a15dc4cf9a6570c71453..745a1e417bd8a18cfa8088feba293076d4980c75 100644 --- a/mopt/include/mopt/Dialect/Mrt/MrtOps.td +++ b/mopt/include/mopt/Dialect/Mrt/MrtOps.td @@ -244,6 +244,24 @@ def Mrt_ConvOp : Mrt_Op<"conv", [Pure]> { }]; } +def Mrt_ContiguousOp : Mrt_Op<"contiguous", [Pure]> { + let summary = "tensor contiguous operation"; + let description = [{ + Performs contiguous operation on input tensors. 
+ }]; + + let arguments = (ins + MrtAnyTensor:$input + ); + + let results = (outs MrtAnyTensor:$result); + + let assemblyFormat = [{ + $input + attr-dict `:` functional-type(operands, results) + }]; +} + def Mrt_DivOp : Mrt_Op<"div", [Pure]> { let summary = "aclnnDiv"; let description = [{ diff --git a/tests/st/inferrt/ops/test_aclnn_contiguous.py b/tests/st/inferrt/ops/test_aclnn_contiguous.py new file mode 100644 index 0000000000000000000000000000000000000000..928bf2c00bfd5a65258d11ef531700af46e88c97 --- /dev/null +++ b/tests/st/inferrt/ops/test_aclnn_contiguous.py @@ -0,0 +1,38 @@ +import pytest +import torch + +from tests.mark_utils import arg_mark +from tests.ops_utils import AssertRtolEqual +from mrt.torch import backend + + +def op_func(input_tensor): + return input_tensor.contiguous() + + +def get_op_func_compiled(): + def custom_op_func(input_tensor): + return input_tensor.contiguous() + return torch.compile(custom_op_func, backend=backend) + + +@arg_mark(plat_marks=["platform_ascend"], level_mark="level0", card_mark="onecard", essential_mark="essential") +@pytest.mark.parametrize("pipeline", (True, False)) +@pytest.mark.parametrize("shape", [[10, 40], [20, 30, 35]]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) +def test_contiguous(pipeline, monkeypatch, shape, dtype): + """ + Feature: Test tensor.contiguous + Description: Test contiguous with dtype inputs + Expectation: The result is correct + """ + if pipeline: + monkeypatch.setenv("MRT_ENABLE_PIPELINE", "on") + + input_tensor = torch.rand(shape, dtype=dtype) + input_tensor = input_tensor.transpose(1, 0) + input_tensor_npu = input_tensor.npu() + cpu_output0 = op_func(input_tensor) + op_func_compiled = get_op_func_compiled() + npu_output = op_func_compiled(input_tensor_npu) + AssertRtolEqual(cpu_output0, npu_output.detach().cpu())