From 644cfb6d45adbfa33c258b8b93af0bbb619619d9 Mon Sep 17 00:00:00 2001 From: zhao-lupeng Date: Fri, 24 Feb 2023 15:11:19 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0precision=5Fmode=E9=80=89?= =?UTF-8?q?=E9=A1=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tf_adapter/util/npu_attrs.cc | 12 +++++++++--- .../python/npu_device/configs/npu_config.py | 3 ++- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/tf_adapter/util/npu_attrs.cc b/tf_adapter/util/npu_attrs.cc index 3c88960c4..e7a1af5c4 100644 --- a/tf_adapter/util/npu_attrs.cc +++ b/tf_adapter/util/npu_attrs.cc @@ -1734,9 +1734,15 @@ Status NpuAttrs::SetNpuOptimizerAttr(const GraphOptimizationPassOptions &options } if (params.count("precision_mode") > 0) { precision_mode = params.at("precision_mode").s(); - const static std::vector kPrecisionModeList = {"force_fp32", "allow_fp32_to_fp16", - "force_fp16", "must_keep_origin_dtype", - "allow_mix_precision", "cube_fp16in_fp32out"}; + const static std::vector kPrecisionModeList = {"force_fp32", + "allow_fp32_to_fp16", + "force_fp16", + "must_keep_origin_dtype", + "allow_mix_precision", + "cube_fp16in_fp32out", + "allow_mix_precision_fp16", + "allow_mix_precision_bf16", + "allow_fp32_to_bp16"}; NPU_REQUIRES_OK(CheckValueAllowed(precision_mode, kPrecisionModeList)); } else { if (static_cast(graph_run_mode)) { diff --git a/tf_adapter_2.x/python/npu_device/configs/npu_config.py b/tf_adapter_2.x/python/npu_device/configs/npu_config.py index 30791404b..1f74782af 100644 --- a/tf_adapter_2.x/python/npu_device/configs/npu_config.py +++ b/tf_adapter_2.x/python/npu_device/configs/npu_config.py @@ -37,7 +37,8 @@ class NpuConfig(NpuBaseConfig): self.fusion_switch_file = OptionValue(None, None) self.precision_mode = OptionValue('allow_fp32_to_fp16', ['force_fp32', 'allow_fp32_to_fp16', 'force_fp16', 'must_keep_origin_dtype', - 'allow_mix_precision', 'cube_fp16in_fp32out']) + 'allow_mix_precision', 'cube_fp16in_fp32out', 'allow_mix_precision_fp16', + 'allow_mix_precision_bf16', 'allow_fp32_to_bp16']) self.op_select_implmode = DeprecatedValue(['high_performance', 'high_precision'], replacement='op_precision_mode') self.optypelist_for_implmode = DeprecatedValue(None, replacement='op_precision_mode') -- Gitee