diff --git a/tf_adapter/ops/aicpu/npu_cpu_ops.cc b/tf_adapter/ops/aicpu/npu_cpu_ops.cc index 7e02ef5e75658e9849d461e9bf2422852ad82718..912dfdb45d92713fdd6854ac01943a8578a09735 100644 --- a/tf_adapter/ops/aicpu/npu_cpu_ops.cc +++ b/tf_adapter/ops/aicpu/npu_cpu_ops.cc @@ -241,6 +241,8 @@ REGISTER_OP("InitEmbeddingHashmap") .Attr("seed: int = 0") .Attr("seed2: int = 0") .Attr("filter_mode: string = 'no_filter' ") + .Attr("optimizer_mode: string = '' ") + .Attr("optimizer_params: list(float)") .SetShapeFn(shape_inference::NoOutputs); REGISTER_OP("EmbeddingTableImport") @@ -304,6 +306,8 @@ REGISTER_OP("EmbeddingTableFindAndInit") .Attr("default_key_or_value: bool = false") .Attr("default_key: int = 0") .Attr("default_value: float = 0") + .Attr("optimizer_mode: string = '' ") + .Attr("optimizer_params: list(float)") .SetShapeFn([](shape_inference::InferenceContext *c) { ShapeHandle keys_shape; TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &keys_shape));