diff --git a/tf_adapter/kernels/prod_virial_se_a_ops.cc b/tf_adapter/kernels/prod_virial_se_a_ops.cc
new file mode 100644
index 0000000000000000000000000000000000000000..049d01782b5ea5c3f891713c4b75e2000c258aa6
--- /dev/null
+++ b/tf_adapter/kernels/prod_virial_se_a_ops.cc
@@ -0,0 +1,40 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2021. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+
+namespace tensorflow {
+template <typename T>
+class ProdVirialSeAOp : public OpKernel {
+public:
+  explicit ProdVirialSeAOp(OpKernelConstruction *ctx) : OpKernel(ctx) {
+    LOG(INFO) << "new ProdVirialSeAOp";
+  }
+  ~ProdVirialSeAOp() {
+    LOG(INFO) << "del ProdVirialSeAOp";
+  }
+  void Compute(OpKernelContext *ctx) override {
+    LOG(INFO) << "ProdVirialSeAOp Compute";
+  }
+  bool IsExpensive() override {
+    LOG(INFO) << "ProdVirialSeAOp IsExpensive";
+    return false;
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("ProdVirialSeA").Device(DEVICE_CPU), ProdVirialSeAOp<float>);
+} // namespace tensorflow
diff --git a/tf_adapter/ops/aicore/npu_aicore_ops.cc b/tf_adapter/ops/aicore/npu_aicore_ops.cc
index 8bf4a4bce9292b6d2410643081640b9eda927173..d5b963277b89f6f28d94973ffa1fc99d0eedc5d3 100644
--- a/tf_adapter/ops/aicore/npu_aicore_ops.cc
+++ b/tf_adapter/ops/aicore/npu_aicore_ops.cc
@@ -477,5 +477,24 @@ REGISTER_OP("FusedLayerNormGrad")
       c->set_output(2, c->input(4));
       return Status::OK();
     });
+
+REGISTER_OP("ProdVirialSeA")
+    .Input("net_deriv:T")
+    .Input("in_deriv:T")
+    .Input("rij:T")
+    .Input("nlist:int32")
+    .Input("natoms:int32")
+    .Output("virial:T")
+    .Output("atom_virial:T")
+    .Attr("n_a_sel:int = 0")
+    .Attr("n_r_sel:int = 0")
+    .Attr("T: {float32, float64}")
+    .SetIsStateful()
+    .SetShapeFn([](InferenceContext* c) {
+      int64_t nframes = c->Value(c->Dim(c->input(0), 0));
+      c->set_output(0, c->MakeShape({nframes, 9}));
+      c->set_output(1, c->MakeShape({nframes, 254952}));
+      return Status::OK();
+    });
 } // namespace
 } // namespace tensorflow
diff --git a/tf_adapter/python/npu_bridge/estimator/npu_aicore_ops.py b/tf_adapter/python/npu_bridge/estimator/npu_aicore_ops.py
index 622e92a2ee96639665397bbbcfaa49d1aa73208e..f3603b0fbb8fe26f8d0a5f5b2f53386907d85034 100644
--- a/tf_adapter/python/npu_bridge/estimator/npu_aicore_ops.py
+++ b/tf_adapter/python/npu_bridge/estimator/npu_aicore_ops.py
@@ -181,3 +181,17 @@ def _layer_norm_grad(op, *grad):
                                               op.inputs[1])
 
     return [pd_x, pd_gamma, pd_beta]
+
+
+def prodvirialsea(net_deriv, in_deriv, rij, nlist, natoms, n_a_sel=0, n_r_sel=0, name=None):
+    """
+    ProdVirialSeA op
+    """
+    net_deriv = ops.convert_to_tensor(net_deriv, name="net_deriv")
+    in_deriv = ops.convert_to_tensor(in_deriv, name="in_deriv")
+    rij = ops.convert_to_tensor(rij, name="rij")
+    nlist = ops.convert_to_tensor(nlist, name="nlist")
+    natoms = ops.convert_to_tensor(natoms, name="natoms")
+    result = npu_aicore_ops.prod_virial_se_a(net_deriv, in_deriv, rij, nlist, natoms, n_a_sel, n_r_sel, name=name)
+    return result
+
diff --git a/tf_adapter/tests/st/kernels/testcase/prod_virial_se_a_ops_test.cc b/tf_adapter/tests/st/kernels/testcase/prod_virial_se_a_ops_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..bafb3c0aca151bf89ae067aa32fb5edc6318f668
--- /dev/null
+++ b/tf_adapter/tests/st/kernels/testcase/prod_virial_se_a_ops_test.cc
@@ -0,0 +1,68 @@
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/platform/test.h"
+#include "tf_adapter/kernels/prod_virial_se_a_ops.cc"
+#include "gtest/gtest.h"
+#include <vector>
+namespace tensorflow {
+namespace {
+
+PartialTensorShape TShape(std::initializer_list<int64> dims) {
+  return PartialTensorShape(dims);
+}
+
+FakeInputFunctor FakeInputStub(DataType dt) {
+  return [dt](const OpDef &op_def, int in_index, const NodeDef &node_def,
+              NodeDefBuilder *builder) {
+    char c = 'a' + (in_index % 26);
+    string in_node = string(&c, 1);
+    builder->Input(in_node, 0, dt);
+    return Status::OK();
+  };
+}
+
+TEST(ProdVirialSeAOpTest, TestProdVirialSeA) {
+  DataTypeSlice input_types({DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_INT32, DT_INT32});
+  MemoryTypeSlice input_memory_types;
+  DataTypeSlice output_types({DT_FLOAT, DT_FLOAT});
+  MemoryTypeSlice output_memory_types;
+  DeviceBase *device = new DeviceBase(Env::Default());
+  NodeDef *node_def = new NodeDef();
+  OpDef *op_def = new OpDef();
+  OpKernelConstruction *context = new OpKernelConstruction(
+      DEVICE_CPU, device, nullptr, node_def, op_def, nullptr, input_types,
+      input_memory_types, output_types, output_memory_types, 1, nullptr);
+  ProdVirialSeAOp<float> prod_virial_se_a(context);
+  OpKernelContext *ctx = nullptr;
+  prod_virial_se_a.Compute(ctx);
+  prod_virial_se_a.IsExpensive();
+  delete device;
+  delete node_def;
+  delete op_def;
+  delete context;
+}
+
+TEST(ProdVirialSeAOpTest, TestProdVirialSeAShapeInference) {
+  const OpRegistrationData *reg;
+  TF_CHECK_OK(OpRegistry::Global()->LookUp("ProdVirialSeA", &reg));
+  OpDef op_def = reg->op_def;
+  NodeDef def;
+  TF_CHECK_OK(NodeDefBuilder("dummy", &op_def)
+                  .Attr("T", DT_FLOAT)
+                  .Attr("n_a_sel", DT_INT32)
+                  .Attr("n_r_sel", DT_INT32)
+                  .Input(FakeInputStub(DT_FLOAT))
+                  .Input(FakeInputStub(DT_FLOAT))
+                  .Input(FakeInputStub(DT_FLOAT))
+                  .Input(FakeInputStub(DT_INT32))
+                  .Input(FakeInputStub(DT_INT32))
+                  .Finalize(&def));
+  shape_inference::InferenceContext c(0, &def, op_def, {TShape({1, 6782976}), TShape({1, 20348928}),
+                                      TShape({1, 5087232}), TShape({1, 1695744}), TShape({4})}, {}, {}, {});
+  std::vector<shape_inference::ShapeHandle> input_shapes;
+  TF_CHECK_OK(reg->shape_inference_fn(&c));
+  ASSERT_EQ("[1,9]", c.DebugString(c.output(0)));
+  ASSERT_EQ("[1,254952]", c.DebugString(c.output(1)));
+}
+} // namespace
+} // namespace tensorflow
\ No newline at end of file
diff --git a/tf_adapter/tests/ut/kernels/testcase/prod_virial_se_a_ops_test.cc b/tf_adapter/tests/ut/kernels/testcase/prod_virial_se_a_ops_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..bafb3c0aca151bf89ae067aa32fb5edc6318f668
--- /dev/null
+++ b/tf_adapter/tests/ut/kernels/testcase/prod_virial_se_a_ops_test.cc
@@ -0,0 +1,68 @@
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/platform/test.h"
+#include "tf_adapter/kernels/prod_virial_se_a_ops.cc"
+#include "gtest/gtest.h"
+#include <vector>
+namespace tensorflow {
+namespace {
+
+PartialTensorShape TShape(std::initializer_list<int64> dims) {
+  return PartialTensorShape(dims);
+}
+
+FakeInputFunctor FakeInputStub(DataType dt) {
+  return [dt](const OpDef &op_def, int in_index, const NodeDef &node_def,
+              NodeDefBuilder *builder) {
+    char c = 'a' + (in_index % 26);
+    string in_node = string(&c, 1);
+    builder->Input(in_node, 0, dt);
+    return Status::OK();
+  };
+}
+
+TEST(ProdVirialSeAOpTest, TestProdVirialSeA) {
+  DataTypeSlice input_types({DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_INT32, DT_INT32});
+  MemoryTypeSlice input_memory_types;
+  DataTypeSlice output_types({DT_FLOAT, DT_FLOAT});
+  MemoryTypeSlice output_memory_types;
+  DeviceBase *device = new DeviceBase(Env::Default());
+  NodeDef *node_def = new NodeDef();
+  OpDef *op_def = new OpDef();
+  OpKernelConstruction *context = new OpKernelConstruction(
+      DEVICE_CPU, device, nullptr, node_def, op_def, nullptr, input_types,
+      input_memory_types, output_types, output_memory_types, 1, nullptr);
+  ProdVirialSeAOp<float> prod_virial_se_a(context);
+  OpKernelContext *ctx = nullptr;
+  prod_virial_se_a.Compute(ctx);
+  prod_virial_se_a.IsExpensive();
+  delete device;
+  delete node_def;
+  delete op_def;
+  delete context;
+}
+
+TEST(ProdVirialSeAOpTest, TestProdVirialSeAShapeInference) {
+  const OpRegistrationData *reg;
+  TF_CHECK_OK(OpRegistry::Global()->LookUp("ProdVirialSeA", &reg));
+  OpDef op_def = reg->op_def;
+  NodeDef def;
+  TF_CHECK_OK(NodeDefBuilder("dummy", &op_def)
+                  .Attr("T", DT_FLOAT)
+                  .Attr("n_a_sel", DT_INT32)
+                  .Attr("n_r_sel", DT_INT32)
+                  .Input(FakeInputStub(DT_FLOAT))
+                  .Input(FakeInputStub(DT_FLOAT))
+                  .Input(FakeInputStub(DT_FLOAT))
+                  .Input(FakeInputStub(DT_INT32))
+                  .Input(FakeInputStub(DT_INT32))
+                  .Finalize(&def));
+  shape_inference::InferenceContext c(0, &def, op_def, {TShape({1, 6782976}), TShape({1, 20348928}),
+                                      TShape({1, 5087232}), TShape({1, 1695744}), TShape({4})}, {}, {}, {});
+  std::vector<shape_inference::ShapeHandle> input_shapes;
+  TF_CHECK_OK(reg->shape_inference_fn(&c));
+  ASSERT_EQ("[1,9]", c.DebugString(c.output(0)));
+  ASSERT_EQ("[1,254952]", c.DebugString(c.output(1)));
+}
+} // namespace
+} // namespace tensorflow
\ No newline at end of file
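
Note (not part of the patch): a minimal sketch of how the new prodvirialsea wrapper might be exercised, using the placeholder input shapes from the shape-inference tests above. The npu_bridge.estimator.npu_aicore_ops import path and the TF1-style graph construction are assumptions about the surrounding environment, not something this diff establishes.

```python
# Hedged usage sketch for the new prodvirialsea wrapper (assumed import path;
# input shapes copied from the shape-inference test case in this patch).
import tensorflow as tf
from npu_bridge.estimator.npu_aicore_ops import prodvirialsea

nframes = 1
net_deriv = tf.zeros([nframes, 6782976], dtype=tf.float32)   # net_deriv:T
in_deriv = tf.zeros([nframes, 20348928], dtype=tf.float32)   # in_deriv:T
rij = tf.zeros([nframes, 5087232], dtype=tf.float32)         # rij:T
nlist = tf.zeros([nframes, 1695744], dtype=tf.int32)         # nlist:int32
natoms = tf.zeros([4], dtype=tf.int32)                       # natoms:int32

# The registered shape function infers virial as [nframes, 9] and
# atom_virial as [nframes, 254952].
virial, atom_virial = prodvirialsea(net_deriv, in_deriv, rij, nlist, natoms,
                                    n_a_sel=0, n_r_sel=0, name="prod_virial_se_a")
print(virial.shape, atom_virial.shape)
```

On a CPU-only build the registered kernel is just the logging stub added by this patch, so the sketch only exercises graph construction and shape inference, not a real virial computation.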