From 7610045ef563cf2a2ac6302373114b2cbb3cfa40 Mon Sep 17 00:00:00 2001
From: qingfenxiaochong
Date: Tue, 25 Jul 2023 10:51:23 +0800
Subject: [PATCH] support bf16 for v1.11.0-5.0-rc2

---
 torch_npu/csrc/distributed/ProcessGroupHCCL.cpp | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp b/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp
index d353bdc57a..68110266f7 100644
--- a/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp
+++ b/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp
@@ -67,6 +67,7 @@ std::map kScalarTypeToHcclDataType = {
     {at::kFloat, HCCL_DATA_TYPE_FP32},
     {at::kDouble, HCCL_DATA_TYPE_FP64},
     {at::kBool, HCCL_DATA_TYPE_UINT8},
+    {at::kBFloat16, HCCL_DATA_TYPE_BFP16},
 };
 
 std::map kHcclDataTypeToStringMap = {
@@ -78,6 +79,7 @@ std::map kHcclDataTypeToStringMap = {
     {HCCL_DATA_TYPE_FP16, "at::kHalf"},
     {HCCL_DATA_TYPE_FP32, "at::kFloat"},
     {HCCL_DATA_TYPE_FP64, "at::kDouble"},
+    {HCCL_DATA_TYPE_BFP16, "at::kBFloat16"},
 };
 
 int64_t physical_numel(at::Tensor& self){
@@ -136,7 +138,7 @@ HcclReduceOp getHcclReduceOp(const c10d::ReduceOp reduceOp, at::Tensor& input) {
 void checkSupportedDataTypeOfAllReduce(HcclDataType type) {
   static std::set allReduceSupportedDataTypes = {HCCL_DATA_TYPE_INT8, HCCL_DATA_TYPE_INT16,
                                                  HCCL_DATA_TYPE_INT32, HCCL_DATA_TYPE_FP16,
-                                                 HCCL_DATA_TYPE_FP32};
+                                                 HCCL_DATA_TYPE_FP32, HCCL_DATA_TYPE_BFP16};
   TORCH_CHECK(allReduceSupportedDataTypes.count(type) != 0,
               "HCCL AllReduce & Reduce: Unsupported data type ",
               getHcclDataTypeSerialString(type));
-- 
Gitee