From fc7644bd0e6f1095061c6af1d626ae6d59602751 Mon Sep 17 00:00:00 2001 From: medivh-x Date: Tue, 28 Sep 2021 17:05:03 +0800 Subject: [PATCH] bugfix-for-remove-redundant-ctrl-edges --- tf_adapter_2.x/npu_device/core/npu_device.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tf_adapter_2.x/npu_device/core/npu_device.cpp b/tf_adapter_2.x/npu_device/core/npu_device.cpp index acb6eeeca..74e48d7ad 100644 --- a/tf_adapter_2.x/npu_device/core/npu_device.cpp +++ b/tf_adapter_2.x/npu_device/core/npu_device.cpp @@ -57,6 +57,7 @@ const static std::string kDropOutGenMaskV3 = "DropOutGenMaskV3"; const static std::string kDropOutDoMaskV3 = "DropOutDoMaskV3"; const static std::string kNpuLossScaleAttr = "_npu_loss_scale"; const static std::string kNpuAllocFloatStatusOp = "NpuAllocFloatStatus"; +const static std::string kNpuGetFloatStatusOp = "NpuGetFloatStatus"; const static std::string kEnable = "1"; const static std::string kWeightUpdateGroupingAttr = "_weight_update_grouping"; const static std::string kReadVariableOp = "ReadVariableOp"; @@ -87,7 +88,7 @@ size_t RemoveRedundantControlEdges(tensorflow::Graph *graph) { std::vector edges_to_remove; for (auto edge : graph->edges()) { if (edge->IsControlEdge()) { - if (edge->dst()->type_string() == kHcomAllReduce || + if ((edge->dst()->type_string() == kHcomAllReduce && edge->src()->type_string() != kNpuGetFloatStatusOp) || (edge->src()->type_string() == kHcomAllReduce && edge->src()->attrs().Find(kNpuLossScaleAttr) == nullptr)) { edges_to_remove.push_back(edge); } else if (edge->src()->type_string() == kDropOutDoMaskV3 && edge->dst()->type_string() == kDropOutGenMaskV3) { -- Gitee