diff --git a/tf_adapter/kernels/aicpu/npu_cpu_ops.cc b/tf_adapter/kernels/aicpu/npu_cpu_ops.cc
index 1c00ce3b615dd74c408d7714a2a1b51a9dd6abe8..14093961e26f16d5c68c70437d3c47939b331b04 100644
--- a/tf_adapter/kernels/aicpu/npu_cpu_ops.cc
+++ b/tf_adapter/kernels/aicpu/npu_cpu_ops.cc
@@ -242,6 +242,41 @@ public:
   void Compute(OpKernelContext *context) override { ADP_LOG(INFO) << "UninitEmbeddingHashmapOp Compute"; }
 };
 
+class TableToResourceOp : public OpKernel {  // stub kernel: Compute only emits an INFO log
+public:
+  explicit TableToResourceOp(OpKernelConstruction *ctx) : OpKernel(ctx) {}
+  ~TableToResourceOp() override = default;
+  void Compute(OpKernelContext *) override { ADP_LOG(INFO) << "TableToResourceOp Compute"; }
+};
+
+class EmbeddingTableFindAndInitOp : public OpKernel {  // stub kernel: Compute only emits an INFO log
+public:
+  explicit EmbeddingTableFindAndInitOp(OpKernelConstruction *ctx) : OpKernel(ctx) {}
+  ~EmbeddingTableFindAndInitOp() override = default;
+  void Compute(OpKernelContext *) override { ADP_LOG(INFO) << "EmbeddingTableFindAndInitOp Compute"; }
+};
+
+class EmbeddingApplyAdamOp : public OpKernel {  // stub kernel: Compute only emits an INFO log
+public:
+  explicit EmbeddingApplyAdamOp(OpKernelConstruction *ctx) : OpKernel(ctx) {}
+  ~EmbeddingApplyAdamOp() override = default;
+  void Compute(OpKernelContext *) override { ADP_LOG(INFO) << "EmbeddingApplyAdamOp Compute"; }
+};
+
+class EmbeddingApplyAdaGradOp : public OpKernel {  // stub kernel: Compute only emits an INFO log
+public:
+  explicit EmbeddingApplyAdaGradOp(OpKernelConstruction *ctx) : OpKernel(ctx) {}
+  ~EmbeddingApplyAdaGradOp() override = default;
+  void Compute(OpKernelContext *) override { ADP_LOG(INFO) << "EmbeddingApplyAdaGradOp Compute"; }
+};
+
+class EmbeddingTableExportOp : public OpKernel {  // stub kernel: Compute only emits an INFO log
+public:
+  explicit EmbeddingTableExportOp(OpKernelConstruction *ctx) : OpKernel(ctx) {}
+  ~EmbeddingTableExportOp() override = default;
+  void Compute(OpKernelContext *) override { ADP_LOG(INFO) << "EmbeddingTableExportOp Compute"; }
+};
+
 REGISTER_KERNEL_BUILDER(Name("ScatterElementsV2").Device(DEVICE_CPU), ScatterElementsV2Op);
REGISTER_KERNEL_BUILDER(Name("EmbeddingRankId").Device(DEVICE_CPU), EmbeddingRankIdOpKernel);
 REGISTER_KERNEL_BUILDER(Name("EmbeddingLocalIndex").Device(DEVICE_CPU), EmbeddingLocalIndexOpKernel);
@@ -266,6 +301,11 @@ REGISTER_KERNEL_BUILDER(Name("EmbeddingTableFind").Device(DEVICE_CPU), Embedding
 REGISTER_KERNEL_BUILDER(Name("EmbeddingTableImport").Device(DEVICE_CPU), EmbeddingTableImportOp);
 REGISTER_KERNEL_BUILDER(Name("UninitPartitionMap").Device(DEVICE_CPU), UninitPartitionMapOp);
 REGISTER_KERNEL_BUILDER(Name("UninitEmbeddingHashmap").Device(DEVICE_CPU), UninitEmbeddingHashmapOp);
+REGISTER_KERNEL_BUILDER(Name("TableToResource").Device(DEVICE_CPU), TableToResourceOp);
+REGISTER_KERNEL_BUILDER(Name("EmbeddingTableFindAndInit").Device(DEVICE_CPU), EmbeddingTableFindAndInitOp);
+REGISTER_KERNEL_BUILDER(Name("EmbeddingApplyAdam").Device(DEVICE_CPU), EmbeddingApplyAdamOp);
+REGISTER_KERNEL_BUILDER(Name("EmbeddingApplyAdaGrad").Device(DEVICE_CPU), EmbeddingApplyAdaGradOp);
+REGISTER_KERNEL_BUILDER(Name("EmbeddingTableExport").Device(DEVICE_CPU), EmbeddingTableExportOp);
 
 class DecodeImageV3Op : public OpKernel {
 public:
diff --git a/tf_adapter/ops/aicpu/npu_cpu_ops.cc b/tf_adapter/ops/aicpu/npu_cpu_ops.cc
index 4d58758179795f4e5b4738ad98f9aac1d0880300..23fb074e91a6401b58db5aad8581d08de59def74 100644
--- a/tf_adapter/ops/aicpu/npu_cpu_ops.cc
+++ b/tf_adapter/ops/aicpu/npu_cpu_ops.cc
@@ -266,6 +266,82 @@ REGISTER_OP("UninitEmbeddingHashmap")
   .Input("table_id: int32")
   .SetShapeFn(shape_inference::NoOutputs);
 
+// TableToResource: table_id (int32) in, table_handle (resource) out.
+// Shape inference copies the shape of input 0 onto output 0.
+REGISTER_OP("TableToResource")
+  .Input("table_id: int32")
+  .Output("table_handle: resource")
+  .SetShapeFn([](shape_inference::InferenceContext *ctx) {
+    ctx->set_output(0, ctx->input(0));
+    return Status::OK();
+  });
+
+// EmbeddingTableFindAndInit: table_id (int32) and keys (int64) in, values (float32) out.
+// Shape inference gives `values` the shape of `keys` (input 1); embedding_dim is not consulted by it.
+REGISTER_OP("EmbeddingTableFindAndInit")
+  .Input("table_id: int32")
+  .Input("keys: int64")
+  .Output("values: float32")
+  .Attr("embedding_dim: int = 0")
+  .Attr("random_alg: string = 'random_uniform' ")
+  .Attr("seed: int = 0")
+  .Attr("seed2: int = 0")
+  .SetShapeFn([](shape_inference::InferenceContext *ctx) {
+    ctx->set_output(0, ctx->input(1));
+    return Status::OK();
+  });
+
+// EmbeddingApplyAdam: var_handle, the T-typed beta1_power/beta2_power/lr/beta1/beta2/epsilon/grad inputs,
+// int64 keys and Tstep global_step in; var_handle1 out, inferred with the shape of input 0.
+REGISTER_OP("EmbeddingApplyAdam")
+  .Input("var_handle: Ref(resource)")
+  .Input("beta1_power: T")
+  .Input("beta2_power: T")
+  .Input("lr: T")
+  .Input("beta1: T")
+  .Input("beta2: T")
+  .Input("epsilon: T")
+  .Input("grad: T")
+  .Input("keys: int64")
+  .Input("global_step: Tstep")
+  .Output("var_handle1: Ref(resource)")
+  .Attr("embedding_dim: int = 0")
+  .Attr("T: {float32, float16}")
+  .Attr("Tstep: {int32, int64}")
+  .SetShapeFn([](shape_inference::InferenceContext *ctx) {
+    ctx->set_output(0, ctx->input(0));
+    return Status::OK();
+  });
+
+// EmbeddingApplyAdaGrad: var_handle, T-typed lr and grad, int64 keys and Tstep global_step in;
+// var_handle1 out, inferred with the shape of input 0.
+REGISTER_OP("EmbeddingApplyAdaGrad")
+  .Input("var_handle: Ref(resource)")
+  .Input("lr: T")
+  .Input("grad: T")
+  .Input("keys: int64")
+  .Input("global_step: Tstep")
+  .Output("var_handle1: Ref(resource)")
+  .Attr("embedding_dim: int = 0")
+  .Attr("T: {float32, float16}")
+  .Attr("Tstep: {int32, int64}")
+  .SetShapeFn([](shape_inference::InferenceContext *ctx) {
+    ctx->set_output(0, ctx->input(0));
+    return Status::OK();
+  });
+
+// EmbeddingTableExport: file_path and file_name (string), ps_id and table_id (int32) in; no outputs.
+REGISTER_OP("EmbeddingTableExport")
+  .Input("file_path: string")
+  .Input("file_name: string")
+  .Input("ps_id: int32")
+  .Input("table_id: int32")
+  .Attr("embedding_dim: int = 0")
+  .Attr("value_total_len: int = 0")
+  .Attr("only_var_flag: bool = false")
+  .Attr("file_type: string = 'bin' ")
+  .SetShapeFn(shape_inference::NoOutputs);
+
 // regist dense image warp op
 REGISTER_OP("DenseImageWarp")
   .Input("image: T")
diff --git a/tf_adapter/tests/st/kernels/testcase/npu_cpu_ops_test.cc b/tf_adapter/tests/st/kernels/testcase/npu_cpu_ops_test.cc
index 827b714dc51d620e6f2aa143f82c47f90884e1be..dc845f8ccdfa2bf6c9c19cb1c92d278a73f17544 100644
--- a/tf_adapter/tests/st/kernels/testcase/npu_cpu_ops_test.cc
+++ 
b/tf_adapter/tests/st/kernels/testcase/npu_cpu_ops_test.cc
@@ -146,4 +146,47 @@ TEST(EmbeddingOpsTest, TestUninitEmbeddingHashmap) {
   delete op_def;
   delete context;
 }
+
+// Builds a minimal CPU OpKernelConstruction, constructs KernelT from it and calls Compute once. A null
+// OpKernelContext is passed because the kernels exercised below only log inside Compute.
+// NOTE: assumes DataTypeSlice is TensorFlow's non-owning array view. The slices are parameters so that a
+// braced list such as {DT_INT32} outlives the call; a named slice initialised from a braced list would be
+// left pointing at a temporary array destroyed at the end of its declaration.
+template <typename KernelT>
+void ConstructAndCompute(DataTypeSlice input_types, DataTypeSlice output_types) {
+  MemoryTypeSlice input_memory_types;
+  MemoryTypeSlice output_memory_types;
+  DeviceBase device(Env::Default());
+  NodeDef node_def;
+  OpDef op_def;
+  OpKernelConstruction context(DEVICE_CPU, &device, nullptr, &node_def, &op_def, nullptr, input_types,
+                               input_memory_types, output_types, output_memory_types, 1, nullptr);
+  KernelT kernel(&context);
+  kernel.Compute(nullptr);
+}
+
+// TableToResource: int32 table_id in, resource table_handle out.
+TEST(EmbeddingOpsTest, TestTableToResource) {
+  ConstructAndCompute<TableToResourceOp>({DT_INT32}, {DT_RESOURCE});
+}
+
+// EmbeddingTableFindAndInit: int32 table_id and int64 keys in, float32 values out.
+TEST(EmbeddingOpsTest, TestEmbeddingTableFindAndInit) {
+  ConstructAndCompute<EmbeddingTableFindAndInitOp>({DT_INT32, DT_INT64}, {DT_FLOAT});
+}
+
+// EmbeddingTableExport: string file_path and file_name, int32 ps_id and table_id in; the op has no outputs.
+TEST(EmbeddingOpsTest, TestEmbeddingTableExport) {
+  ConstructAndCompute<EmbeddingTableExportOp>({DT_STRING, DT_STRING, DT_INT32, DT_INT32}, DataTypeSlice());
+}
+
+// EmbeddingApplyAdam: only the leading resource input and the resource output are declared here.
+TEST(EmbeddingOpsTest, TestEmbeddingApplyAdam) {
+  ConstructAndCompute<EmbeddingApplyAdamOp>({DT_RESOURCE}, {DT_RESOURCE});
+}
+
+// EmbeddingApplyAdaGrad: only the leading resource input and the resource output are declared here.
+TEST(EmbeddingOpsTest, TestEmbeddingApplyAdaGrad) {
+  ConstructAndCompute<EmbeddingApplyAdaGradOp>({DT_RESOURCE}, {DT_RESOURCE});
+}
 }
\ No newline at end of file
diff --git a/tf_adapter/tests/ut/kernels/testcase/npu_cpu_ops_test.cc b/tf_adapter/tests/ut/kernels/testcase/npu_cpu_ops_test.cc
index 5e9f047f34c6512ea8df79d30c5f51553c45223b..d5b31002e6e10e506cf3ea129d1b2a4a409ebd22 100644
--- a/tf_adapter/tests/ut/kernels/testcase/npu_cpu_ops_test.cc
+++ b/tf_adapter/tests/ut/kernels/testcase/npu_cpu_ops_test.cc
@@ -166,4 +166,47 @@ TEST(EmbeddingOpsTest, TestUninitEmbeddingHashmap) {
   delete op_def;
   delete context;
 }
+
+// Builds a minimal CPU OpKernelConstruction, constructs KernelT from it and calls Compute once. A null
+// OpKernelContext is passed because the kernels exercised below only log inside Compute.
+// NOTE: assumes DataTypeSlice is TensorFlow's non-owning array view. The slices are parameters so that a
+// braced list such as {DT_INT32} outlives the call; a named slice initialised from a braced list would be
+// left pointing at a temporary array destroyed at the end of its declaration.
+template <typename KernelT>
+void ConstructAndCompute(DataTypeSlice input_types, DataTypeSlice output_types) {
+  MemoryTypeSlice input_memory_types;
+  MemoryTypeSlice output_memory_types;
+  DeviceBase device(Env::Default());
+  NodeDef node_def;
+  OpDef op_def;
+  OpKernelConstruction context(DEVICE_CPU, &device, nullptr, &node_def, &op_def, nullptr, input_types,
+                               input_memory_types, output_types, output_memory_types, 1, nullptr);
+  KernelT kernel(&context);
+  kernel.Compute(nullptr);
+}
+
+// TableToResource: int32 table_id in, resource table_handle out.
+TEST(EmbeddingOpsTest, TestTableToResource) {
+  ConstructAndCompute<TableToResourceOp>({DT_INT32}, {DT_RESOURCE});
+}
+
+// EmbeddingTableFindAndInit: int32 table_id and int64 keys in, float32 values out.
+TEST(EmbeddingOpsTest, TestEmbeddingTableFindAndInit) {
+  ConstructAndCompute<EmbeddingTableFindAndInitOp>({DT_INT32, DT_INT64}, {DT_FLOAT});
+}
+
+// EmbeddingTableExport: string file_path and file_name, int32 ps_id and table_id in; the op has no outputs.
+TEST(EmbeddingOpsTest, TestEmbeddingTableExport) {
+  ConstructAndCompute<EmbeddingTableExportOp>({DT_STRING, DT_STRING, DT_INT32, DT_INT32}, DataTypeSlice());
+}
+
+// EmbeddingApplyAdam: only the leading resource input and the resource output are declared here.
+TEST(EmbeddingOpsTest, TestEmbeddingApplyAdam) {
+  ConstructAndCompute<EmbeddingApplyAdamOp>({DT_RESOURCE}, {DT_RESOURCE});
+}
+
+// EmbeddingApplyAdaGrad: only the leading resource input and the resource output are declared here.
+TEST(EmbeddingOpsTest, TestEmbeddingApplyAdaGrad) {
+  ConstructAndCompute<EmbeddingApplyAdaGradOp>({DT_RESOURCE}, {DT_RESOURCE});
+}
 }
\ No newline at end of file