From fb228fc9c0d5261000e639bca7aec8c9ab611daa Mon Sep 17 00:00:00 2001
From: carol233
Date: Fri, 20 Jan 2023 12:50:21 +0800
Subject: [PATCH 1/5] Modified tf_adapter/kernels/aicpu/npu_cpu_ops.cc Modified tf_adapter/ops/aicpu/npu_cpu_ops.cc

---
 tf_adapter/kernels/aicpu/npu_cpu_ops.cc | 40 ++++++++++++++
 tf_adapter/ops/aicpu/npu_cpu_ops.cc     | 71 +++++++++++++++++++++++++
 2 files changed, 111 insertions(+)

diff --git a/tf_adapter/kernels/aicpu/npu_cpu_ops.cc b/tf_adapter/kernels/aicpu/npu_cpu_ops.cc
index 1c00ce3b6..14093961e 100644
--- a/tf_adapter/kernels/aicpu/npu_cpu_ops.cc
+++ b/tf_adapter/kernels/aicpu/npu_cpu_ops.cc
@@ -242,6 +242,41 @@ public:
   void Compute(OpKernelContext *context) override { ADP_LOG(INFO) << "UninitEmbeddingHashmapOp Compute"; }
 };
 
+class TableToResourceOp : public OpKernel {
+public:
+  explicit TableToResourceOp(OpKernelConstruction *context) : OpKernel(context) {}
+  ~TableToResourceOp() override {}
+  void Compute(OpKernelContext *context) override { ADP_LOG(INFO) << "TableToResourceOp Compute"; }
+};
+
+class EmbeddingTableFindAndInitOp : public OpKernel {
+public:
+  explicit EmbeddingTableFindAndInitOp(OpKernelConstruction *context) : OpKernel(context) {}
+  ~EmbeddingTableFindAndInitOp() override {}
+  void Compute(OpKernelContext *context) override { ADP_LOG(INFO) << "EmbeddingTableFindAndInitOp Compute"; }
+};
+
+class EmbeddingApplyAdamOp : public OpKernel {
+public:
+  explicit EmbeddingApplyAdamOp(OpKernelConstruction *context) : OpKernel(context) {}
+  ~EmbeddingApplyAdamOp() override {}
+  void Compute(OpKernelContext *context) override { ADP_LOG(INFO) << "EmbeddingApplyAdamOp Compute"; }
+};
+
+class EmbeddingApplyAdaGradOp : public OpKernel {
+public:
+  explicit EmbeddingApplyAdaGradOp(OpKernelConstruction *context) : OpKernel(context) {}
+  ~EmbeddingApplyAdaGradOp() override {}
+  void Compute(OpKernelContext *context) override { ADP_LOG(INFO) << "EmbeddingApplyAdaGradOp Compute"; }
+};
+
+class EmbeddingTableExportOp : public OpKernel {
+public:
+  explicit EmbeddingTableExportOp(OpKernelConstruction *context) : OpKernel(context) {}
+  ~EmbeddingTableExportOp() override {}
+  void Compute(OpKernelContext *context) override { ADP_LOG(INFO) << "EmbeddingTableExportOp Compute"; }
+};
+
 REGISTER_KERNEL_BUILDER(Name("ScatterElementsV2").Device(DEVICE_CPU), ScatterElementsV2Op);
 REGISTER_KERNEL_BUILDER(Name("EmbeddingRankId").Device(DEVICE_CPU), EmbeddingRankIdOpKernel);
 REGISTER_KERNEL_BUILDER(Name("EmbeddingLocalIndex").Device(DEVICE_CPU), EmbeddingLocalIndexOpKernel);
@@ -266,6 +301,11 @@ REGISTER_KERNEL_BUILDER(Name("EmbeddingTableFind").Device(DEVICE_CPU), Embedding
 REGISTER_KERNEL_BUILDER(Name("EmbeddingTableImport").Device(DEVICE_CPU), EmbeddingTableImportOp);
 REGISTER_KERNEL_BUILDER(Name("UninitPartitionMap").Device(DEVICE_CPU), UninitPartitionMapOp);
 REGISTER_KERNEL_BUILDER(Name("UninitEmbeddingHashmap").Device(DEVICE_CPU), UninitEmbeddingHashmapOp);
+REGISTER_KERNEL_BUILDER(Name("TableToResource").Device(DEVICE_CPU), TableToResourceOp);
+REGISTER_KERNEL_BUILDER(Name("EmbeddingTableFindAndInit").Device(DEVICE_CPU), EmbeddingTableFindAndInitOp);
+REGISTER_KERNEL_BUILDER(Name("EmbeddingApplyAdam").Device(DEVICE_CPU), EmbeddingApplyAdamOp);
+REGISTER_KERNEL_BUILDER(Name("EmbeddingApplyAdaGrad").Device(DEVICE_CPU), EmbeddingApplyAdaGradOp);
+REGISTER_KERNEL_BUILDER(Name("EmbeddingTableExport").Device(DEVICE_CPU), EmbeddingTableExportOp);
 
 class DecodeImageV3Op : public OpKernel {
 public:
diff --git a/tf_adapter/ops/aicpu/npu_cpu_ops.cc b/tf_adapter/ops/aicpu/npu_cpu_ops.cc
index 4d5875817..b5a7b30c2 100644
--- a/tf_adapter/ops/aicpu/npu_cpu_ops.cc
+++ b/tf_adapter/ops/aicpu/npu_cpu_ops.cc
@@ -266,6 +266,77 @@ REGISTER_OP("UninitEmbeddingHashmap")
     .Input("table_id: int32")
     .SetShapeFn(shape_inference::NoOutputs);
 
+REGISTER_OP("TableToResource")
+    .Input("table_id: int32")
+    .Output("table_handle: resource")
+    .SetShapeFn([](shape_inference::InferenceContext *c) {
+      auto data_shape = c->input(0);
+      c->set_output(0, data_shape);
+      return Status::OK();
+    });
+
+REGISTER_OP("EmbeddingTableFindAndInit")
+    .Input("table_id: int32")
+    .Input("keys: int64")
+    .Output("values: float32")
+    .Attr("embedding_dim: int = 0")
+    .Attr("random_alg: string = 'random_uniform' ")
+    .Attr("seed: int = 0")
+    .Attr("seed2: int = 0")
+    .SetShapeFn([](shape_inference::InferenceContext *c) {
+      auto data_shape = c->input(1);
+      c->set_output(0, data_shape);
+      return Status::OK();
+    });
+
+REGISTER_OP("EmbeddingApplyAdam")
+    .Input("var_handle: resource")
+    .Input("beta1_power: T")
+    .Input("beta2_power: T")
+    .Input("lr: T")
+    .Input("beta1: T")
+    .Input("beta2: T")
+    .Input("epsilon: T")
+    .Input("grad: T")
+    .Input("keys: int64")
+    .Input("global_step: Tstep")
+    .Output("var_handle: resource")
+    .Attr("embedding_dim: int = 0")
+    .Attr("T: {float32, float16}")
+    .Attr("Tstep: {int32, int64}")
+    .SetShapeFn([](shape_inference::InferenceContext *c) {
+      auto data_shape = c->input(0);
+      c->set_output(0, data_shape);
+      return Status::OK();
+    });
+
+REGISTER_OP("EmbeddingApplyAdaGrad")
+    .Input("var_handle: resource")
+    .Input("lr: T")
+    .Input("grad: T")
+    .Input("keys: int64")
+    .Input("global_step: Tstep")
+    .Output("var_handle: resource")
+    .Attr("embedding_dim: int = 0")
+    .Attr("T: {float32, float16}")
+    .Attr("Tstep: {int32, int64}")
+    .SetShapeFn([](shape_inference::InferenceContext *c) {
+      auto data_shape = c->input(0);
+      c->set_output(0, data_shape);
+      return Status::OK();
+    });
+
+REGISTER_OP("EmbeddingTableExport")
+    .Input("file_path: string")
+    .Input("file_name: string")
+    .Input("ps_id: int32")
+    .Input("table_id: int32")
+    .Attr("embedding_dim: int = 0")
+    .Attr("value_total_len: int = 0")
+    .Attr("only_var_flag: bool = false")
+    .Attr("file_type: string = 'bin' ")
+    .SetShapeFn(shape_inference::NoOutputs);
+
 // regist dense image warp op
 REGISTER_OP("DenseImageWarp")
     .Input("image: T")
-- 
Gitee

From 2b9e4718ddba9a1acc1105dfc99914e8c7e3d581 Mon Sep 17 00:00:00 2001
From: carol233
Date: Fri, 20 Jan 2023 13:09:52 +0800
Subject: [PATCH 2/5] Modified tf_adapter/kernels/aicpu/npu_cpu_ops.cc Modified tf_adapter/ops/aicpu/npu_cpu_ops.cc

---
 tf_adapter/ops/aicpu/npu_cpu_ops.cc           |  8 ++--
 .../st/kernels/testcase/npu_cpu_ops_test.cc   | 40 +++++++++++++++++++
 2 files changed, 44 insertions(+), 4 deletions(-)

diff --git a/tf_adapter/ops/aicpu/npu_cpu_ops.cc b/tf_adapter/ops/aicpu/npu_cpu_ops.cc
index b5a7b30c2..23fb074e9 100644
--- a/tf_adapter/ops/aicpu/npu_cpu_ops.cc
+++ b/tf_adapter/ops/aicpu/npu_cpu_ops.cc
@@ -290,7 +290,7 @@ REGISTER_OP("EmbeddingTableFindAndInit")
     });
 
 REGISTER_OP("EmbeddingApplyAdam")
-    .Input("var_handle: resource")
+    .Input("var_handle: Ref(resource)")
     .Input("beta1_power: T")
     .Input("beta2_power: T")
     .Input("lr: T")
@@ -300,7 +300,7 @@ REGISTER_OP("EmbeddingApplyAdam")
     .Input("grad: T")
     .Input("keys: int64")
     .Input("global_step: Tstep")
-    .Output("var_handle: resource")
+    .Output("var_handle1: Ref(resource)")
     .Attr("embedding_dim: int = 0")
     .Attr("T: {float32, float16}")
     .Attr("Tstep: {int32, int64}")
@@ -311,12 +311,12 @@ REGISTER_OP("EmbeddingApplyAdam")
     });
 
 REGISTER_OP("EmbeddingApplyAdaGrad")
-    .Input("var_handle: resource")
+    .Input("var_handle: Ref(resource)")
     .Input("lr: T")
     .Input("grad: T")
     .Input("keys: int64")
     .Input("global_step: Tstep")
-    .Output("var_handle: resource")
+    .Output("var_handle1: Ref(resource)")
     .Attr("embedding_dim: int = 0")
     .Attr("T: {float32, float16}")
     .Attr("Tstep: {int32, int64}")
diff --git a/tf_adapter/tests/st/kernels/testcase/npu_cpu_ops_test.cc b/tf_adapter/tests/st/kernels/testcase/npu_cpu_ops_test.cc
index 827b714dc..e93122c68 100644
--- a/tf_adapter/tests/st/kernels/testcase/npu_cpu_ops_test.cc
+++ b/tf_adapter/tests/st/kernels/testcase/npu_cpu_ops_test.cc
@@ -146,4 +146,44 @@ TEST(EmbeddingOpsTest, TestUninitEmbeddingHashmap) {
   delete op_def;
   delete context;
 }
+
+TEST(EmbeddingOpsTest, TestTableToResource) {
+  DataTypeSlice input_types({DT_INT32});
+  MemoryTypeSlice input_memory_types;
+  DataTypeSlice output_types({DT_RESOURCE});
+  MemoryTypeSlice output_memory_types;
+  DeviceBase *device = new DeviceBase(Env::Default());
+  NodeDef *node_def = new NodeDef();
+  OpDef *op_def = new OpDef();
+  OpKernelConstruction *context = new OpKernelConstruction(DEVICE_CPU, device, nullptr, node_def, op_def, nullptr,
+                                                           input_types, input_memory_types, output_types, output_memory_types,
+                                                           1, nullptr);
+  UninitEmbeddingHashmapOp cache(context);
+  OpKernelContext *ctx = nullptr;
+  cache.Compute(ctx);
+  delete device;
+  delete node_def;
+  delete op_def;
+  delete context;
+}
+
+TEST(EmbeddingOpsTest, TestEmbeddingTableFindAndInit) {
+  DataTypeSlice input_types({DT_INT32});
+  MemoryTypeSlice input_memory_types;
+  DataTypeSlice output_types({DT_INT32});
+  MemoryTypeSlice output_memory_types;
+  DeviceBase *device = new DeviceBase(Env::Default());
+  NodeDef *node_def = new NodeDef();
+  OpDef *op_def = new OpDef();
+  OpKernelConstruction *context = new OpKernelConstruction(DEVICE_CPU, device, nullptr, node_def, op_def, nullptr,
+                                                           input_types, input_memory_types, output_types, output_memory_types,
+                                                           1, nullptr);
+  UninitEmbeddingHashmapOp cache(context);
+  OpKernelContext *ctx = nullptr;
+  cache.Compute(ctx);
+  delete device;
+  delete node_def;
+  delete op_def;
+  delete context;
+}
 }
\ No newline at end of file
-- 
Gitee

From aefa7a193961844d7ccb054bbeeefb18174c747d Mon Sep 17 00:00:00 2001
From: carol233
Date: Fri, 20 Jan 2023 13:29:17 +0800
Subject: [PATCH 3/5] Modified tf_adapter/kernels/aicpu/npu_cpu_ops.cc Modified tf_adapter/ops/aicpu/npu_cpu_ops.cc

---
 .../st/kernels/testcase/npu_cpu_ops_test.cc | 44 ++++++++++++++++++-
 1 file changed, 42 insertions(+), 2 deletions(-)

diff --git a/tf_adapter/tests/st/kernels/testcase/npu_cpu_ops_test.cc b/tf_adapter/tests/st/kernels/testcase/npu_cpu_ops_test.cc
index e93122c68..c01019243 100644
--- a/tf_adapter/tests/st/kernels/testcase/npu_cpu_ops_test.cc
+++ b/tf_adapter/tests/st/kernels/testcase/npu_cpu_ops_test.cc
@@ -158,7 +158,7 @@ TEST(EmbeddingOpsTest, TestTableToResource) {
   OpKernelConstruction *context = new OpKernelConstruction(DEVICE_CPU, device, nullptr, node_def, op_def, nullptr,
                                                            input_types, input_memory_types, output_types, output_memory_types,
                                                            1, nullptr);
-  UninitEmbeddingHashmapOp cache(context);
+  TableToResourceOp cache(context);
   OpKernelContext *ctx = nullptr;
   cache.Compute(ctx);
   delete device;
@@ -178,7 +178,47 @@ TEST(EmbeddingOpsTest, TestEmbeddingTableFindAndInit) {
   OpKernelConstruction *context = new OpKernelConstruction(DEVICE_CPU, device, nullptr, node_def, op_def, nullptr,
                                                            input_types, input_memory_types, output_types, output_memory_types,
                                                            1, nullptr);
-  UninitEmbeddingHashmapOp cache(context);
+  EmbeddingTableFindAndInitOp cache(context);
+  OpKernelContext *ctx = nullptr;
+  cache.Compute(ctx);
+  delete device;
+  delete node_def;
+  delete op_def;
+  delete context;
+}
+
+TEST(EmbeddingOpsTest, TestEmbeddingTableExport) {
+  DataTypeSlice input_types({DT_STRING});
+  MemoryTypeSlice input_memory_types;
+  DataTypeSlice output_types({DT_STRING});
+  MemoryTypeSlice output_memory_types;
+  DeviceBase *device = new DeviceBase(Env::Default());
+  NodeDef *node_def = new NodeDef();
+  OpDef *op_def = new OpDef();
+  OpKernelConstruction *context = new OpKernelConstruction(DEVICE_CPU, device, nullptr, node_def, op_def, nullptr,
+                                                           input_types, input_memory_types, output_types, output_memory_types,
+                                                           1, nullptr);
+  EmbeddingTableExportOp cache(context);
+  OpKernelContext *ctx = nullptr;
+  cache.Compute(ctx);
+  delete device;
+  delete node_def;
+  delete op_def;
+  delete context;
+}
+
+TEST(EmbeddingOpsTest, TestEmbeddingTableExport) {
+  DataTypeSlice input_types({DT_RESOURCE});
+  MemoryTypeSlice input_memory_types;
+  DataTypeSlice output_types({DT_RESOURCE});
+  MemoryTypeSlice output_memory_types;
+  DeviceBase *device = new DeviceBase(Env::Default());
+  NodeDef *node_def = new NodeDef();
+  OpDef *op_def = new OpDef();
+  OpKernelConstruction *context = new OpKernelConstruction(DEVICE_CPU, device, nullptr, node_def, op_def, nullptr,
+                                                           input_types, input_memory_types, output_types, output_memory_types,
+                                                           1, nullptr);
+  EmbeddingApplyAdamOp cache(context);
   OpKernelContext *ctx = nullptr;
   cache.Compute(ctx);
   delete device;
-- 
Gitee

From 7761ddd8a9a8b29d18e283fa28bd0e65f3b2eb13 Mon Sep 17 00:00:00 2001
From: carol233
Date: Fri, 20 Jan 2023 13:32:57 +0800
Subject: [PATCH 4/5] Modified tf_adapter/kernels/aicpu/npu_cpu_ops.cc Modified tf_adapter/ops/aicpu/npu_cpu_ops.cc

---
 tf_adapter/tests/st/kernels/testcase/npu_cpu_ops_test.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tf_adapter/tests/st/kernels/testcase/npu_cpu_ops_test.cc b/tf_adapter/tests/st/kernels/testcase/npu_cpu_ops_test.cc
index c01019243..dc845f8cc 100644
--- a/tf_adapter/tests/st/kernels/testcase/npu_cpu_ops_test.cc
+++ b/tf_adapter/tests/st/kernels/testcase/npu_cpu_ops_test.cc
@@ -207,7 +207,7 @@ TEST(EmbeddingOpsTest, TestEmbeddingTableExport) {
   delete context;
 }
 
-TEST(EmbeddingOpsTest, TestEmbeddingTableExport) {
+TEST(EmbeddingOpsTest, TestEmbeddingApplyAdam) {
   DataTypeSlice input_types({DT_RESOURCE});
   MemoryTypeSlice input_memory_types;
   DataTypeSlice output_types({DT_RESOURCE});
-- 
Gitee

From 3847f94031658284c688938406b32ca529f5f02a Mon Sep 17 00:00:00 2001
From: carol233
Date: Fri, 20 Jan 2023 13:38:21 +0800
Subject: [PATCH 5/5] Modified tf_adapter/kernels/aicpu/npu_cpu_ops.cc Modified tf_adapter/ops/aicpu/npu_cpu_ops.cc

---
 .../ut/kernels/testcase/npu_cpu_ops_test.cc | 79 +++++++++++++++++++
 1 file changed, 79 insertions(+)

diff --git a/tf_adapter/tests/ut/kernels/testcase/npu_cpu_ops_test.cc b/tf_adapter/tests/ut/kernels/testcase/npu_cpu_ops_test.cc
index 5e9f047f3..d5b31002e 100644
--- a/tf_adapter/tests/ut/kernels/testcase/npu_cpu_ops_test.cc
+++ b/tf_adapter/tests/ut/kernels/testcase/npu_cpu_ops_test.cc
@@ -166,4 +166,83 @@ TEST(EmbeddingOpsTest, TestUninitEmbeddingHashmap) {
   delete op_def;
   delete context;
 }
+TEST(EmbeddingOpsTest, TestTableToResource) {
+  DataTypeSlice input_types({DT_INT32});
+  MemoryTypeSlice input_memory_types;
+  DataTypeSlice output_types({DT_RESOURCE});
+  MemoryTypeSlice output_memory_types;
+  DeviceBase *device = new DeviceBase(Env::Default());
+  NodeDef *node_def = new NodeDef();
+  OpDef *op_def = new OpDef();
+  OpKernelConstruction *context = new OpKernelConstruction(DEVICE_CPU, device, nullptr, node_def, op_def, nullptr,
+                                                           input_types, input_memory_types, output_types, output_memory_types,
+                                                           1, nullptr);
+  TableToResourceOp cache(context);
+  OpKernelContext *ctx = nullptr;
+  cache.Compute(ctx);
+  delete device;
+  delete node_def;
+  delete op_def;
+  delete context;
+}
+
+TEST(EmbeddingOpsTest, TestEmbeddingTableFindAndInit) {
+  DataTypeSlice input_types({DT_INT32});
+  MemoryTypeSlice input_memory_types;
+  DataTypeSlice output_types({DT_INT32});
+  MemoryTypeSlice output_memory_types;
+  DeviceBase *device = new DeviceBase(Env::Default());
+  NodeDef *node_def = new NodeDef();
+  OpDef *op_def = new OpDef();
+  OpKernelConstruction *context = new OpKernelConstruction(DEVICE_CPU, device, nullptr, node_def, op_def, nullptr,
+                                                           input_types, input_memory_types, output_types, output_memory_types,
+                                                           1, nullptr);
+  EmbeddingTableFindAndInitOp cache(context);
+  OpKernelContext *ctx = nullptr;
+  cache.Compute(ctx);
+  delete device;
+  delete node_def;
+  delete op_def;
+  delete context;
+}
+
+TEST(EmbeddingOpsTest, TestEmbeddingTableExport) {
+  DataTypeSlice input_types({DT_STRING});
+  MemoryTypeSlice input_memory_types;
+  DataTypeSlice output_types({DT_STRING});
+  MemoryTypeSlice output_memory_types;
+  DeviceBase *device = new DeviceBase(Env::Default());
+  NodeDef *node_def = new NodeDef();
+  OpDef *op_def = new OpDef();
+  OpKernelConstruction *context = new OpKernelConstruction(DEVICE_CPU, device, nullptr, node_def, op_def, nullptr,
+                                                           input_types, input_memory_types, output_types, output_memory_types,
+                                                           1, nullptr);
+  EmbeddingTableExportOp cache(context);
+  OpKernelContext *ctx = nullptr;
+  cache.Compute(ctx);
+  delete device;
+  delete node_def;
+  delete op_def;
+  delete context;
+}
+
+TEST(EmbeddingOpsTest, TestEmbeddingApplyAdam) {
+  DataTypeSlice input_types({DT_RESOURCE});
+  MemoryTypeSlice input_memory_types;
+  DataTypeSlice output_types({DT_RESOURCE});
+  MemoryTypeSlice output_memory_types;
+  DeviceBase *device = new DeviceBase(Env::Default());
+  NodeDef *node_def = new NodeDef();
+  OpDef *op_def = new OpDef();
+  OpKernelConstruction *context = new OpKernelConstruction(DEVICE_CPU, device, nullptr, node_def, op_def, nullptr,
+                                                           input_types, input_memory_types, output_types, output_memory_types,
+                                                           1, nullptr);
+  EmbeddingApplyAdamOp cache(context);
+  OpKernelContext *ctx = nullptr;
+  cache.Compute(ctx);
+  delete device;
+  delete node_def;
+  delete op_def;
+  delete context;
+}
 }
\ No newline at end of file
-- 
Gitee