From ead0a32b72c366a19a7942775a8c51f28c632cb2 Mon Sep 17 00:00:00 2001
From: qingfenxiaochong
Date: Tue, 25 Jul 2023 11:31:30 +0800
Subject: [PATCH] support bf16 for v2.0.1-5.0-rc2

---
 torch_npu/csrc/distributed/ProcessGroupHCCL.cpp | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp b/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp
index 9e63eb1032..370bc57d9f 100644
--- a/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp
+++ b/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp
@@ -67,6 +67,7 @@ std::map<at::ScalarType, HcclDataType> kScalarTypeToHcclDataType = {
     {at::kFloat, HCCL_DATA_TYPE_FP32},
     {at::kDouble, HCCL_DATA_TYPE_FP64},
     {at::kBool, HCCL_DATA_TYPE_UINT8},
+    {at::kBFloat16, HCCL_DATA_TYPE_BFP16},
 };
 
 std::map<HcclDataType, std::string> kHcclDataTypeToStringMap = {
@@ -78,6 +79,7 @@ std::map<HcclDataType, std::string> kHcclDataTypeToStringMap = {
     {HCCL_DATA_TYPE_FP16, "at::kHalf"},
     {HCCL_DATA_TYPE_FP32, "at::kFloat"},
     {HCCL_DATA_TYPE_FP64, "at::kDouble"},
+    {HCCL_DATA_TYPE_BFP16, "at::kBFloat16"},
 };
 
 int64_t physical_numel(at::Tensor& self){
@@ -136,7 +138,7 @@ HcclReduceOp getHcclReduceOp(const c10d::ReduceOp reduceOp, at::Tensor& input) {
 void checkSupportedDataTypeOfAllReduce(HcclDataType type) {
     static std::set<HcclDataType> allReduceSupportedDataTypes = {HCCL_DATA_TYPE_INT8, HCCL_DATA_TYPE_INT16,
         HCCL_DATA_TYPE_INT32, HCCL_DATA_TYPE_FP16,
-        HCCL_DATA_TYPE_FP32};
+        HCCL_DATA_TYPE_FP32, HCCL_DATA_TYPE_BFP16};
     TORCH_CHECK(allReduceSupportedDataTypes.count(type) != 0,
         "HCCL AllReduce & Reduce: Unsupported data type ",
         getHcclDataTypeSerialString(type));
--
Gitee
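
Note on the change above: the two maps are the translation layer between PyTorch scalar types and HCCL wire types (kScalarTypeToHcclDataType for dispatch, kHcclDataTypeToStringMap for error messages), and the third hunk lets bf16 tensors pass the all-reduce/reduce validation. The standalone C++ sketch below mirrors that lookup-and-validate pattern so the effect of the new entries is visible in isolation; the enum values and the helper names toHcclDataType / checkAllReduceSupport are illustrative stand-ins, not the torch_npu or HCCL API.

// Minimal, self-contained sketch of the lookup/validation pattern the patch
// extends. The enums and helper names below are illustrative stand-ins, not
// the actual at::ScalarType / HcclDataType / torch_npu definitions.
#include <iostream>
#include <map>
#include <set>
#include <stdexcept>

enum class ScalarType { Float, Double, Half, BFloat16, Bool };  // stand-in for at::ScalarType
enum class HcclDataType { FP32, FP64, FP16, BFP16, UINT8 };     // stand-in for HcclDataType

// Analogue of kScalarTypeToHcclDataType, including the new bf16 entry.
const std::map<ScalarType, HcclDataType> kScalarToHccl = {
    {ScalarType::Float,    HcclDataType::FP32},
    {ScalarType::Double,   HcclDataType::FP64},
    {ScalarType::Half,     HcclDataType::FP16},
    {ScalarType::Bool,     HcclDataType::UINT8},
    {ScalarType::BFloat16, HcclDataType::BFP16},  // entry added by this patch
};

// Analogue of the scalar-type-to-HCCL-type lookup: fail loudly when no mapping exists.
HcclDataType toHcclDataType(ScalarType t) {
    auto it = kScalarToHccl.find(t);
    if (it == kScalarToHccl.end()) {
        throw std::runtime_error("No HCCL mapping for this scalar type");
    }
    return it->second;
}

// Analogue of checkSupportedDataTypeOfAllReduce: bf16 is now in the allowed set.
void checkAllReduceSupport(HcclDataType type) {
    static const std::set<HcclDataType> supported = {
        HcclDataType::FP16, HcclDataType::FP32, HcclDataType::BFP16};
    if (supported.count(type) == 0) {
        throw std::runtime_error("HCCL AllReduce: unsupported data type");
    }
}

int main() {
    // With the bf16 entries in place, a bf16 all-reduce passes both checks.
    HcclDataType dt = toHcclDataType(ScalarType::BFloat16);
    checkAllReduceSupport(dt);
    std::cout << "bf16 maps to BFP16 and is accepted for all-reduce\n";
    return 0;
}

With the bf16 entries present, toHcclDataType(ScalarType::BFloat16) resolves to BFP16 and checkAllReduceSupport accepts it; without them the equivalent lookup and TORCH_CHECK in ProcessGroupHCCL.cpp would reject bf16 collectives, which is the failure mode this patch addresses.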