diff --git a/operator/AddCustomSample/FrameworkLaunch/TensorflowInvocation/AscendCustomToTensorFlowBuildIn/README.md b/operator/AddCustomSample/FrameworkLaunch/TensorflowInvocation/AscendCustomToTensorFlowBuildIn/README.md index ecf71c8c6cb968fda3ee8827611468e584baa940..42bc3c733380e4721a9c2268e66e62168e9c4443 100644 --- a/operator/AddCustomSample/FrameworkLaunch/TensorflowInvocation/AscendCustomToTensorFlowBuildIn/README.md +++ b/operator/AddCustomSample/FrameworkLaunch/TensorflowInvocation/AscendCustomToTensorFlowBuildIn/README.md @@ -4,6 +4,14 @@ ## 运行样例算子 ### 1.编译算子工程 运行此样例前,请参考[编译算子工程](../../README.md#operatorcompile)完成前期准备。 +注意:若tensorflow版本是2.x,需注意插件代码适配,路径为: samples/operator/AddCustomSample/FrameworkLaunch/AddCustom/framework/tf_plugin/tensorflow_add_custom_plugin.cc +需修改插件代码中的TensorFlow调用算子名称OriginOpType为"AddV2",如下所示: +```c++ +REGISTER_CUSTOM_OP("AddCustom") + .FrameworkType(TENSORFLOW) // type: TENSORFLOW + .OriginOpType("AddV2") // name in tf module + .ParseParamsByOperatorFn(AutoMappingByOpFn); +``` ### 2.tensorflow调用的方式调用样例运行 - 进入到样例目录 @@ -33,4 +41,5 @@ ## 更新说明 | 时间 | 更新事项 | | ---------- | ------------ | -| 2024/05/22 | 新增本readme | \ No newline at end of file +| 2024/05/22 | 新增本readme | +| 2024/11/25 | 修改tensorflow2.x时编译算子工程的说明 | \ No newline at end of file diff --git a/operator/AddCustomSample/FrameworkLaunch/TensorflowInvocation/AscendCustomToTensorFlowBuildIn/run_add_custom_tf2.py b/operator/AddCustomSample/FrameworkLaunch/TensorflowInvocation/AscendCustomToTensorFlowBuildIn/run_add_custom_tf2.py index b60f23b100c9beb8ea7f50a26a347498e188490d..686710fcbb3afd6ed7501ae7a1640b2d55859435 100755 --- a/operator/AddCustomSample/FrameworkLaunch/TensorflowInvocation/AscendCustomToTensorFlowBuildIn/run_add_custom_tf2.py +++ b/operator/AddCustomSample/FrameworkLaunch/TensorflowInvocation/AscendCustomToTensorFlowBuildIn/run_add_custom_tf2.py @@ -16,6 +16,9 @@ from __future__ import print_function import logging import tensorflow as tf import numpy as np +import npu_device +from npu_device.compat.v1.npu_init import * +npu_device.compat.enable_v1() tf.compat.v1.disable_v2_behavior() tf.compat.v1.flags.DEFINE_string("local_log_dir", "output/train_logs.txt", "Log file path")