From 991074179276ef7c12cbe31e646c6a77d56b867c Mon Sep 17 00:00:00 2001 From: PengC Date: Sat, 23 Nov 2024 16:45:19 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9tf2.x=E6=97=B6=E7=9A=84add=5F?= =?UTF-8?q?framework?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../AscendCustomToTensorFlowBuildIn/README.md | 11 ++++++++++- .../run_add_custom_tf2.py | 3 +++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/operator/AddCustomSample/FrameworkLaunch/TensorflowInvocation/AscendCustomToTensorFlowBuildIn/README.md b/operator/AddCustomSample/FrameworkLaunch/TensorflowInvocation/AscendCustomToTensorFlowBuildIn/README.md index ecf71c8c6..42bc3c733 100644 --- a/operator/AddCustomSample/FrameworkLaunch/TensorflowInvocation/AscendCustomToTensorFlowBuildIn/README.md +++ b/operator/AddCustomSample/FrameworkLaunch/TensorflowInvocation/AscendCustomToTensorFlowBuildIn/README.md @@ -4,6 +4,14 @@ ## 运行样例算子 ### 1.编译算子工程 运行此样例前,请参考[编译算子工程](../../README.md#operatorcompile)完成前期准备。 +注意:若tensorflow版本是2.x,需注意插件代码适配,路径为: samples/operator/AddCustomSample/FrameworkLaunch/AddCustom/framework/tf_plugin/tensorflow_add_custom_plugin.cc +需修改插件代码中的TensorFlow调用算子名称OriginOpType为"AddV2",如下所示: +```c++ +REGISTER_CUSTOM_OP("AddCustom") + .FrameworkType(TENSORFLOW) // type: TENSORFLOW + .OriginOpType("AddV2") // name in tf module + .ParseParamsByOperatorFn(AutoMappingByOpFn); +``` ### 2.tensorflow调用的方式调用样例运行 - 进入到样例目录 @@ -33,4 +41,5 @@ ## 更新说明 | 时间 | 更新事项 | | ---------- | ------------ | -| 2024/05/22 | 新增本readme | \ No newline at end of file +| 2024/05/22 | 新增本readme | +| 2024/11/25 | 修改tensorflow2.x时编译算子工程的说明 | \ No newline at end of file diff --git a/operator/AddCustomSample/FrameworkLaunch/TensorflowInvocation/AscendCustomToTensorFlowBuildIn/run_add_custom_tf2.py b/operator/AddCustomSample/FrameworkLaunch/TensorflowInvocation/AscendCustomToTensorFlowBuildIn/run_add_custom_tf2.py index b60f23b10..686710fcb 100755 --- a/operator/AddCustomSample/FrameworkLaunch/TensorflowInvocation/AscendCustomToTensorFlowBuildIn/run_add_custom_tf2.py +++ b/operator/AddCustomSample/FrameworkLaunch/TensorflowInvocation/AscendCustomToTensorFlowBuildIn/run_add_custom_tf2.py @@ -16,6 +16,9 @@ from __future__ import print_function import logging import tensorflow as tf import numpy as np +import npu_device +from npu_device.compat.v1.npu_init import * +npu_device.compat.enable_v1() tf.compat.v1.disable_v2_behavior() tf.compat.v1.flags.DEFINE_string("local_log_dir", "output/train_logs.txt", "Log file path") -- Gitee