diff --git a/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp b/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp
index 9e63eb10327ae2f4ac8fb7e99d9201540fc21b1d..370bc57d9f738f2cc800311adaafe3e0ef3ae4fa 100644
--- a/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp
+++ b/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp
@@ -67,6 +67,7 @@ std::map<at::ScalarType, HcclDataType> kScalarTypeToHcclDataType = {
     {at::kFloat, HCCL_DATA_TYPE_FP32},
     {at::kDouble, HCCL_DATA_TYPE_FP64},
     {at::kBool, HCCL_DATA_TYPE_UINT8},
+    {at::kBFloat16, HCCL_DATA_TYPE_BFP16},
 };
 
 std::map<HcclDataType, std::string> kHcclDataTypeToStringMap = {
@@ -78,6 +79,7 @@ std::map<HcclDataType, std::string> kHcclDataTypeToStringMap = {
     {HCCL_DATA_TYPE_FP16, "at::kHalf"},
     {HCCL_DATA_TYPE_FP32, "at::kFloat"},
     {HCCL_DATA_TYPE_FP64, "at::kDouble"},
+    {HCCL_DATA_TYPE_BFP16, "at::kBFloat16"},
 };
 
 int64_t physical_numel(at::Tensor& self){
@@ -136,7 +138,7 @@ HcclReduceOp getHcclReduceOp(const c10d::ReduceOp reduceOp, at::Tensor& input) {
 void checkSupportedDataTypeOfAllReduce(HcclDataType type) {
     static std::set<HcclDataType> allReduceSupportedDataTypes = {HCCL_DATA_TYPE_INT8, HCCL_DATA_TYPE_INT16,
                                                                  HCCL_DATA_TYPE_INT32, HCCL_DATA_TYPE_FP16,
-                                                                 HCCL_DATA_TYPE_FP32};
+                                                                 HCCL_DATA_TYPE_FP32, HCCL_DATA_TYPE_BFP16};
     TORCH_CHECK(allReduceSupportedDataTypes.count(type) != 0, "HCCL AllReduce & Reduce: Unsupported data type ",
                 getHcclDataTypeSerialString(type));
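
Reviewer note: the sketch below is a self-contained approximation, not the torch_npu source, of the path this patch extends. It shows a bfloat16 scalar type resolving through kScalarTypeToHcclDataType to HCCL_DATA_TYPE_BFP16 and then passing checkSupportedDataTypeOfAllReduce. The ScalarType stand-in enum, the getHcclDataType helper name, and std::runtime_error in place of TORCH_CHECK are assumptions made so the example compiles on its own.

```cpp
// Minimal sketch (assumptions noted above): exercises the lookup path the
// diff extends. kScalarTypeToHcclDataType and checkSupportedDataTypeOfAllReduce
// mirror the patched file; getHcclDataType is a hypothetical helper.
#include <iostream>
#include <map>
#include <set>
#include <stdexcept>

enum HcclDataType {
    HCCL_DATA_TYPE_INT8,
    HCCL_DATA_TYPE_INT16,
    HCCL_DATA_TYPE_INT32,
    HCCL_DATA_TYPE_FP16,
    HCCL_DATA_TYPE_FP32,
    HCCL_DATA_TYPE_FP64,
    HCCL_DATA_TYPE_UINT8,
    HCCL_DATA_TYPE_BFP16,  // the value the patch starts mapping bfloat16 to
};

// Stand-in for at::ScalarType so the sketch compiles without PyTorch headers.
enum class ScalarType { Float, Double, Bool, BFloat16 };

std::map<ScalarType, HcclDataType> kScalarTypeToHcclDataType = {
    {ScalarType::Float, HCCL_DATA_TYPE_FP32},
    {ScalarType::Double, HCCL_DATA_TYPE_FP64},
    {ScalarType::Bool, HCCL_DATA_TYPE_UINT8},
    {ScalarType::BFloat16, HCCL_DATA_TYPE_BFP16},  // new entry from the patch
};

// Hypothetical lookup helper: translate a tensor scalar type to HCCL's enum.
HcclDataType getHcclDataType(ScalarType type) {
    auto it = kScalarTypeToHcclDataType.find(type);
    if (it == kScalarTypeToHcclDataType.end()) {
        throw std::runtime_error("Unsupported tensor data type for HCCL");
    }
    return it->second;
}

// Mirrors the patched check; std::runtime_error stands in for TORCH_CHECK.
void checkSupportedDataTypeOfAllReduce(HcclDataType type) {
    static std::set<HcclDataType> allReduceSupportedDataTypes = {
        HCCL_DATA_TYPE_INT8, HCCL_DATA_TYPE_INT16, HCCL_DATA_TYPE_INT32,
        HCCL_DATA_TYPE_FP16, HCCL_DATA_TYPE_FP32, HCCL_DATA_TYPE_BFP16};
    if (allReduceSupportedDataTypes.count(type) == 0) {
        throw std::runtime_error("HCCL AllReduce & Reduce: Unsupported data type");
    }
}

int main() {
    // With the patch, both calls succeed for bfloat16; without it, the map
    // lookup finds no entry and the allreduce check rejects the type.
    HcclDataType dt = getHcclDataType(ScalarType::BFloat16);
    checkSupportedDataTypeOfAllReduce(dt);
    std::cout << "bfloat16 allreduce accepted" << std::endl;
    return 0;
}
```

The patch touches both ends of that path deliberately: adding only the map entry would still trip the allreduce check, and extending only the check would leave bfloat16 tensors unable to resolve a wire type, so the two hunks plus the string map (used for error messages via getHcclDataTypeSerialString) move together.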