diff --git a/tf_adapter/kernels/aicore/npu_aicore_ops.cc b/tf_adapter/kernels/aicore/npu_aicore_ops.cc index cfa22fb75a44cf834c59ef5ad766a7a60841d6cf..34de0097e2e095143ee53154eb0ccbfde13b99f7 100644 --- a/tf_adapter/kernels/aicore/npu_aicore_ops.cc +++ b/tf_adapter/kernels/aicore/npu_aicore_ops.cc @@ -43,6 +43,36 @@ class FastGeluOp : public tensorflow::OpKernel { } }; +class EmbeddingHashTableImportOp : public tensorflow::OpKernel { +public: + explicit EmbeddingHashTableImportOp(tensorflow::OpKernelConstruction *context) : OpKernel(context) {} + ~EmbeddingHashTableImportOp() override {} + void Compute(tensorflow::OpKernelContext *context) override {} +}; + +REGISTER_KERNEL_BUILDER(Name("EmbeddingHashTableImport") +.Device(tensorflow::DEVICE_CPU), EmbeddingHashTableImportOp); + +class EmbeddingHashTableExportOp : public tensorflow::OpKernel { +public: + explicit EmbeddingHashTableExportOp(tensorflow::OpKernelConstruction *context) : OpKernel(context) {} + ~EmbeddingHashTableExportOp() override {} + void Compute(tensorflow::OpKernelContext *context) override {} +}; + +REGISTER_KERNEL_BUILDER(Name("EmbeddingHashTableExport") +.Device(tensorflow::DEVICE_CPU), EmbeddingHashTableExportOp); + +class EmbeddingHashTableLookupOrInsertOp : public tensorflow::OpKernel { +public: + explicit EmbeddingHashTableLookupOrInsertOp(tensorflow::OpKernelConstruction *context) : OpKernel(context) {} + ~EmbeddingHashTableLookupOrInsertOp() override {} + void Compute(tensorflow::OpKernelContext *context) override {} +}; + +REGISTER_KERNEL_BUILDER(Name("EmbeddingHashTableLookupOrInsert") +.Device(tensorflow::DEVICE_CPU), EmbeddingHashTableLookupOrInsertOp); + REGISTER_KERNEL_BUILDER( Name("FastGelu") . @@ -122,6 +152,26 @@ Device(tensorflow::DEVICE_CPU) .TypeConstraint("T"), FastGeluGradOp); +class InitEmbeddingHashTableOp : public tensorflow::OpKernel { +public: + explicit InitEmbeddingHashTableOp(tensorflow::OpKernelConstruction *context) : OpKernel(context) {} + ~InitEmbeddingHashTableOp() override {} + void Compute(tensorflow::OpKernelContext *context) override {} +}; + +REGISTER_KERNEL_BUILDER(Name("InitEmbeddingHashTable").Device(tensorflow::DEVICE_CPU), InitEmbeddingHashTableOp); + +class EmbeddingHashTableApplyAdamWOp : public tensorflow::OpKernel { +public: + explicit EmbeddingHashTableApplyAdamWOp(tensorflow::OpKernelConstruction *context) + : OpKernel(context) {} + ~EmbeddingHashTableApplyAdamWOp() override {} + void Compute(tensorflow::OpKernelContext *context) override {} +}; + +REGISTER_KERNEL_BUILDER(Name("EmbeddingHashTableApplyAdamW").Device(tensorflow::DEVICE_CPU), + EmbeddingHashTableApplyAdamWOp); + class EmbeddingHashTableEvictOp : public tensorflow::OpKernel { public: explicit EmbeddingHashTableEvictOp(tensorflow::OpKernelConstruction *context) : OpKernel(context) {} diff --git a/tf_adapter/ops/aicore/npu_aicore_ops.cc b/tf_adapter/ops/aicore/npu_aicore_ops.cc index acfad59f3734e738bb0fb3ba1eb5b79343146e15..bb5e99e7f50ebb25804b2a9c598c1fcfa554043f 100644 --- a/tf_adapter/ops/aicore/npu_aicore_ops.cc +++ b/tf_adapter/ops/aicore/npu_aicore_ops.cc @@ -44,6 +44,63 @@ REGISTER_OP("FastGeluGrad") .Attr("T: realnumbertype") .SetShapeFn(tensorflow::shape_inference::MergeBothInputsShapeFn); +REGISTER_OP("EmbeddingHashTableImport") + .Input("table_handles: int64") + .Input("embedding_dims: int64") + .Input("bucket_sizes: int64") + .Input("keys: num * int64") + .Input("counters: num * uint64") + .Input("filter_flags: num * uint8") + .Input("values: num * float32") + .Attr("num: int >= 1") + .SetShapeFn(tensorflow::shape_inference::NoOutputs); + +REGISTER_OP("EmbeddingHashTableExport") + .Input("table_handles: int64") + .Input("table_sizes: int64") + .Input("embedding_dims: int64") + .Input("bucket_sizes: int64") + .Output("keys: num * int64") + .Output("counters: num * uint64") + .Output("filter_flags: num * uint8") + .Output("values: num * float") + .Attr("export_mode: string = 'all'") + .Attr("filter_export_flag: bool = false") + .Attr("num: int >= 1") + .SetIsStateful() + .SetShapeFn([](shape_inference::InferenceContext *c) { + int64 num = 0; + c->GetAttr("num", &num); + for (int64_t i = 0; i < num; ++i) { + c->set_output(i, c->Vector(c->UnknownDim())); + c->set_output(i + num, c->Vector(c->UnknownDim())); + c->set_output(i + 2 * num, c->Vector(c->UnknownDim())); + c->set_output(i + 3 * num, c->Vector(c->UnknownDim())); + } + return Status::OK(); + }); + +REGISTER_OP("EmbeddingHashTableApplyAdamW") + .Input("table_handle: int64") + .Input("keys: int64") + .Input("m: Ref(T)") + .Input("v: Ref(T)") + .Input("beta1_power: Ref(T)") + .Input("beta2_power: Ref(T)") + .Input("lr: T") + .Input("weight_decay: T") + .Input("beta1: T") + .Input("beta2: T") + .Input("epsilon: T") + .Input("grad: T") + .Input("max_grad_norm: Ref(T)") + .Attr("embedding_dim: int") + .Attr("bucket_size: int") + .Attr("amsgrad: bool = false") + .Attr("maximize: bool = false") + .Attr("T: {float16, float32}") + .SetShapeFn(tensorflow::shape_inference::NoOutputs); + REGISTER_OP("DynamicGruV2") .Input("x: T") .Input("weight_input: T") @@ -459,6 +516,27 @@ REGISTER_OP("DynamicRnnGrad") return Status::OK(); }); +REGISTER_OP("EmbeddingHashTableLookupOrInsert") + .Input("table_handle: int64") + .Input("keys:int64") + .Output("values: float") + .Attr("bucket_size:int") + .Attr("embedding_dim:int") + .Attr("filter_mode:string='no_filter'") + .Attr("filter_freq:int=0") + .Attr("default_key_or_value:bool = false") + .Attr("default_key: int = 0") + .Attr("default_value: float = 0.0") + .SetIsStateful() + .SetShapeFn([](InferenceContext* c) { + int64 num = 0; + c->GetAttr("embedding_dim", &num); + auto key_num = c->input(1); + int64_t nsample = InferenceContext::Value(c->Dim(key_num, 0)); + c->set_output(0, c->MakeShape({c->MakeDim(nsample), c->MakeDim(num)})); + return Status::OK(); + }); + REGISTER_OP("LRUCacheV2") .Input("index_list: T") .Input("data: Ref(dtype)") @@ -793,6 +871,15 @@ REGISTER_OP("TabulateFusionGrad") return Status::OK(); }); +REGISTER_OP("InitEmbeddingHashTable") + .Input("table_handle: int64") + .Input("sampled_values: float") + .Attr("bucket_size : int") + .Attr("embedding_dim : int") + .Attr("initializer_mode : string='random'") + .Attr("constant_value : float=0.0") + .SetShapeFn(shape_inference::NoOutputs); + REGISTER_OP("EmbeddingHashTableEvict") .Input("table_handle: int64") .Input("keys: int64") diff --git a/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py b/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py index de958655170721fa6e7ef9a49b8c4a8c276b1485..e7c5fa61fc2d1410d1f819c7b0217d7bc56542a3 100644 --- a/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py +++ b/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py @@ -25,6 +25,27 @@ from npu_bridge.helper import helper gen_npu_cpu_ops = helper.get_gen_ops() +## 提供device侧FeatureMapping LookupOrInsert功能 +# @param table_handle int64 类型 +# @param keys int64 类型 +# @param bucket_size int 类型 +# @param embedding_dim int 类型 +# @param filter_mode string 类型 +# @param filter_freq int 类型 +# @param default_key_or_value bool 类型 +# @param default_key int 类型 +# @param default_value float 类型 +# @return values float 类型 +def embedding_hashtable_lookup_or_insert(table_handle, keys, bucket_size, embedding_dim, filter_mode, filter_freq, + default_key_or_value, default_key, default_value): + """ device embedding feature mapping lookup or insert. """ + result = gen_npu_cpu_ops.EmbeddingHashTableLookupOrInsert( + table_handle=table_handle, keys=keys, bucket_size=bucket_size, embedding_dim=embedding_dim, + filter_mode=filter_mode, filter_freq=filter_freq, default_key_or_value=default_key_or_value, + default_key=default_key, default_value=default_value) + return result + + ## 提供embeddingrankid功能 # @param addr_tensor tensorflow的tensor类型,embeddingrankid操作的输入; # @param index tensorflow的tensor类型,embeddingrankid操作的输入; @@ -579,6 +600,84 @@ def embedding_hashmap_import_v2(file_path, table_ids, table_sizes, table_names, return result +## EmbeddingHashTable Init功能 +# @param table_handle int64 类型 +# @param sampled_values float 类型 +# @param bucket_size int 类型 +# @param embedding_dim int 类型 +# @param initializer_mode string 类型 +# @param constant_value int 类型 +def init_embedding_hashtable(table_handle, sampled_values, bucket_size, embedding_dim, initializer_mode, + constant_value): + """ device init embedding hashtable. """ + result = gen_npu_cpu_ops.InitEmbeddingHashTable( + table_handle=table_handle, sampled_values=sampled_values, bucket_size=bucket_size, embedding_dim=embedding_dim, + initializer_mode=initializer_mode, constant_value=constant_value) + return result + + +## 提供host侧hashTable导入功能 +# @param table_handles int64 类型 +# @param embedding_dims int64 类型 +# @param bucket_sizes int64 类型 +# @param keys int64 类型 +# @param counters uint64 类型 +# @param filter_flags uint8 类型 +# @param values float 类型 +def embedding_hash_table_import(table_handles, embedding_dims, bucket_sizes, keys, counters, filter_flags, values): + """ host embedding feature hash table import. """ + result = gen_npu_cpu_ops.EmbeddingHashTableImport( + table_handles=table_handles, embedding_dims=embedding_dims, bucket_sizes=bucket_sizes, + keys=keys, counters=counters, filter_flags=filter_flags, values=values) + return result + + +## 提供host侧hashTable导出功能 +# @param table_handles int64 类型 +# @param table_sizes int64 类型 +# @param embedding_dims int64 类型 +# @param bucket_sizes int64 类型 +# @param export_mode string 类型 +# @param filtered_export_flag bool 类型 +def embedding_hash_table_export(table_handles, table_sizes, embedding_dims, bucket_sizes, export_mode='all', + filter_export_flag=False): + """ host embedding feature hash table export. """ + result = gen_npu_cpu_ops.EmbeddingHashTableExport( + table_handles=table_handles, table_sizes=table_sizes, embedding_dims=embedding_dims, bucket_sizes=bucket_sizes, + export_mode=export_mode, filter_export_flag=filter_export_flag) + return result + + +## EmbeddingHashTableApplyAdamW AdamW 更新功能 +# @param table_handle int64 类型 +# @param keys int64 类型 +# @param m float16, float32 类型 +# @param v float16, float32 类型 +# @param beta1_power float16, float32 类型 +# @param beta2_power float16, float32 类型 +# @param lr float16, float32 类型 +# @param weight_decay float16, float32 类型 +# @param beta1 float16, float32 类型 +# @param beta2 float16, float32 类型 +# @param epsilon float16, float32 类型 +# @param grad float16, float32 类型 +# @param max_grad_norm float16, float32 类型 +# @param embedding_dim int 类型 +# @param bucket_size int 类型 +# @param amsgrad bool 类型 +# @param maximize bool 类型 +def embedding_hashtable_apply_adam_w(table_handle, keys, m, v, beta1_power, beta2_power, lr, weight_decay, + beta1, beta2, epsilon, grad, max_grad_norm, embedding_dim, + bucket_size, amsgrad, maximize): + """ device update embedding hashtable using AdamW. """ + result = gen_npu_cpu_ops.EmbeddingHashTableApplyAdamW( + table_handle=table_handle, keys=keys, m=m, v=v, beta1_power=beta1_power, beta2_power=beta2_power, + lr=lr, weight_decay=weight_decay, beta1=beta1, beta2=beta2, epsilon=epsilon, grad=grad, + max_grad_norm=max_grad_norm, embedding_dim=embedding_dim, bucket_size=bucket_size, + amsgrad=amsgrad, maximize=maximize) + return result + + ## 提供device侧FeatureMapping Evict功能 # @param table_handle int64 类型 # @param keys int64 类型