From 7b49e1c986a65c7ac21545eb2407253f6913dd10 Mon Sep 17 00:00:00 2001 From: w00520461 Date: Mon, 10 Apr 2023 14:32:40 +0800 Subject: [PATCH] embeeding_init_op --- tf_adapter/ops/aicpu/npu_cpu_ops.cc | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/tf_adapter/ops/aicpu/npu_cpu_ops.cc b/tf_adapter/ops/aicpu/npu_cpu_ops.cc index 93c822cca..44471265d 100644 --- a/tf_adapter/ops/aicpu/npu_cpu_ops.cc +++ b/tf_adapter/ops/aicpu/npu_cpu_ops.cc @@ -232,7 +232,12 @@ REGISTER_OP("InitEmbeddingHashmap") .Attr("value_total_len: int = 0") .Attr("dtype: {uint8, uint16, float32} = DT_FLOAT") .Attr("embedding_dim: int = 0") - .Attr("random_alg: string = '' ") + .Attr("initializer_mode: string = '' ") + .Attr("constant_value: float = 0") + .Attr("min: float = -2") + .Attr("max: float = 2") + .Attr("mu: float = 0") + .Attr("sigma: float = 1") .Attr("seed: int = 0") .Attr("seed2: int = 0") .SetShapeFn(shape_inference::NoOutputs); @@ -286,7 +291,12 @@ REGISTER_OP("EmbeddingTableFindAndInit") .Output("values: float32") .Attr("embedding_dim: int = 0") .Attr("value_total_len: int = 0") - .Attr("random_alg: string = 'random_uniform'") + .Attr("initializer_mode: string = 'random_uniform'") + .Attr("constant_value: float = 0") + .Attr("min: float = -2") + .Attr("max: float = 2") + .Attr("mu: float = 0") + .Attr("sigma: float = 1") .Attr("seed: int = 0") .Attr("seed2: int = 0") .SetShapeFn([](shape_inference::InferenceContext *c) { @@ -344,6 +354,7 @@ REGISTER_OP("EmbeddingTableExport") .Input("table_id: int32") .Attr("embedding_dim: int = 0") .Attr("value_total_len: int = 0") + .Attr("export_mode: {'all', 'old', 'new', 'specifiednew'} = 'all'") .Attr("only_var_flag: bool = false") .Attr("file_type: string = 'bin' ") .SetShapeFn(shape_inference::NoOutputs); -- Gitee