From 8bc1a6cfe3134396c91399ae8e5a4cfc6dccb475 Mon Sep 17 00:00:00 2001
From: yiyanzhi_akane
Date: Sat, 15 May 2021 11:05:00 +0800
Subject: [PATCH] add Kahan algorithm for high-precision reduction

---
 src/akg_reduce/reduce.cuh                          | 18 ++++
 src/poly/gpu_isl_emitter.cc                        | 99 +++++++++++++++++--
 src/poly/gpu_isl_emitter.h                         | 13 ++-
 .../incubator-tvm/src/codegen/codegen_cuda.cc      | 23 +++++
 .../incubator-tvm/src/codegen/codegen_cuda.h       |  7 ++
 5 files changed, 148 insertions(+), 12 deletions(-)

diff --git a/src/akg_reduce/reduce.cuh b/src/akg_reduce/reduce.cuh
index 0c8d5af8..0be409e7 100644
--- a/src/akg_reduce/reduce.cuh
+++ b/src/akg_reduce/reduce.cuh
@@ -60,6 +60,24 @@ __inline__ __device__ void AkgReduce(const ReduceOp op,  // The operator
   }
 }
 
+/**
+ * @brief Accumulation with the Kahan algorithm, only for the sum operator
+ * @tparam T Dtype: half, float, double, int;
+ */
+template <typename T>
+__device__ __forceinline__ void AkgKahanAccumulation(T *y,
+                                                     T *t,
+                                                     T *c,
+                                                     T *acc,
+                                                     const T input
+) {
+  y[0] = input - c[0];
+  t[0] = acc[0] + y[0];
+  c[0] = (t[0] - acc[0]) - y[0];
+  acc[0] = t[0];
+}
+
+
 /**
  * @brief Atomic return function, from shared memory to global memory
  * @tparam T Dtype: half, float, double, int, signed char, bool;
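The four compensated steps above are classic Kahan summation: c carries the low-order bits that a plain floating-point add drops once acc grows much larger than the inputs, and feeds them back into the next step. A minimal host-side sketch of the effect (illustrative only, not part of this patch; naive_sum/kahan_sum are invented names, and IEEE semantics are assumed, so -ffast-math would defeat the compensation):

  #include <cstdio>

  int main() {
    float naive_sum = 0.0f;
    float kahan_sum = 0.0f, c = 0.0f;  // c = running compensation, as in AkgKahanAccumulation
    for (int i = 0; i < 10000000; ++i) {
      naive_sum += 0.1f;  // drifts: once the total is large, most of 0.1f is rounded away

      float y = 0.1f - c;        // apply the compensation to the new input
      float t = kahan_sum + y;   // the add that loses low-order bits of y
      c = (t - kahan_sum) - y;   // recover exactly what was lost
      kahan_sum = t;
    }
    // The exact result is 10^7 * 0.1 = 1000000; kahan_sum stays close to it,
    // while naive_sum is off by a visible margin in float.
    printf("naive: %f  kahan: %f\n", naive_sum, kahan_sum);
    return 0;
  }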
diff --git a/src/poly/gpu_isl_emitter.cc b/src/poly/gpu_isl_emitter.cc
index 413017d5..f405884d 100644
--- a/src/poly/gpu_isl_emitter.cc
+++ b/src/poly/gpu_isl_emitter.cc
@@ -291,16 +291,15 @@ Stmt GpuIslEmitter::EmitSync() {
   return Evaluate::make(Call::make(Int(32), STORAGE_SYNC, {StringImm::make(SYNC_SCOP_SHARED)}, Call::Intrinsic));
 }
 
-void GpuIslEmitter::SetScalarTensorBind() {
+void GpuIslEmitter::SetScalarTensorBind(std::string scalar_tensor_name) {
   Array<Expr> shapes;
   shapes.push_back(Expr(1));
   Type type = reduce_info_.reduce_data_type_info_;
-  std::string scalar_tensor_name = reduce_info_.scalar_tensor_name_;
 
   reduce_info_.added_tensors_.insert(scalar_tensor_name);
   Tensor tensor = placeholder(shapes, type, scalar_tensor_name);
   const Buffer buffer = decl_buffer(shapes, type, scalar_tensor_name);
-  reduce_info_.scalar_tensor_ = tensor;
+  reduce_info_.scalar_tensor_[scalar_tensor_name] = tensor;
   info_.user_config_.SetBind(tensor, buffer);
 }
 
@@ -332,6 +331,11 @@ Stmt GpuIslEmitter::EmitReduceInit(const isl::ast_node_user &node) {
   auto stmt_id = usr_expr.get_arg(0).as<isl::ast_expr_id>().get_id();
 
   CHECK(!reduce_info_.scalar_tensor_name_.empty()) << "scalar tensor info should not be empty!";
+  if (reduce_info_.reduce_op_.find("SumOp") != std::string::npos) {
+    CHECK(!reduce_info_.scalar_kht_name_.empty()) << "scalar tensor kht info should not be empty!";
+    CHECK(!reduce_info_.scalar_khy_name_.empty()) << "scalar tensor khy info should not be empty!";
+    CHECK(!reduce_info_.scalar_khc_name_.empty()) << "scalar tensor khc info should not be empty!";
+  }
 
   std::vector<std::string> strs = common::Split(stmt_id.name(), "_");
   CHECK_EQ(strs.size(), REDUCE_FLAG_SIZE) << "red init format is not right!.";
@@ -345,20 +349,32 @@ Stmt GpuIslEmitter::EmitReduceInit(const isl::ast_node_user &node) {
     }
   }
 
-  Array<Expr> args;
-  args.push_back(Expr(0));
-  Stmt scalar_stmt = Provide::make(reduce_info_.scalar_tensor_->op, 0, init_value, args);
-
   CHECK(reduce_info_.reduce_area_stmt_.defined());
   reduce_info_.stmts_.insert(reduce_info_.stmts_.begin(), reduce_info_.reduce_area_stmt_);
 
+  Array<Expr> args;
+  args.push_back(Expr(0));
+  Stmt scalar_stmt =
+    Provide::make(reduce_info_.scalar_tensor_[reduce_info_.scalar_tensor_name_]->op, 0, init_value, args);
   CHECK(scalar_stmt.defined());
   reduce_info_.stmts_.insert(reduce_info_.stmts_.begin(), scalar_stmt);
 
+  if (reduce_info_.reduce_op_.find("SumOp") != std::string::npos) {
+    Stmt scalar_khc =
+      Provide::make(reduce_info_.scalar_tensor_[reduce_info_.scalar_khc_name_]->op, 0, init_value, args);
+    CHECK(scalar_khc.defined());
+    reduce_info_.stmts_.insert(reduce_info_.stmts_.begin(), scalar_khc);
+  }
+
   MakeReduceStmt();
 
   Stmt stmt = Block::make(reduce_info_.stmts_);
   stmt = InsertRealizeWithMemType(stmt, isl::id(stmt_id.ctx(), reduce_info_.scalar_tensor_name_), MEM_TYPE_LOCAL);
+  if (reduce_info_.reduce_op_.find("SumOp") != std::string::npos) {
+    stmt = InsertRealizeWithMemType(stmt, isl::id(stmt_id.ctx(), reduce_info_.scalar_kht_name_), MEM_TYPE_LOCAL);
+    stmt = InsertRealizeWithMemType(stmt, isl::id(stmt_id.ctx(), reduce_info_.scalar_khy_name_), MEM_TYPE_LOCAL);
+    stmt = InsertRealizeWithMemType(stmt, isl::id(stmt_id.ctx(), reduce_info_.scalar_khc_name_), MEM_TYPE_LOCAL);
+  }
   stmt = InsertRealizeWithMemType(stmt, isl::id(stmt_id.ctx(), reduce_info_.shared_compute_name_), MEM_TYPE_SHARED);
 
   ResetStatus();
@@ -402,6 +418,7 @@ void GpuIslEmitter::ResetStatus() {
   reduce_info_.origin_reduce_stmt_ = Stmt();
   reduce_info_.gm_write_stmt_ = Stmt();
   reduce_info_.atomic_rhs_ = Expr();
+  reduce_info_.input_tensor_expr_ = Expr();
   is_out_most_stmt_ = true;
 }
 
@@ -427,6 +444,16 @@ Stmt GpuIslEmitter::EmitReduceUpdate(const isl::ast_node_user &node) {
     reduce_info_.reduce_op_ += strs[REDUCE_FLAG_TYPE_POS];
   }
   CHECK(!reduce_info_.reduce_op_.empty()) << "reduce op should not be empty!";
+
+  if (reduce_info_.reduce_op_.find("SumOp") != std::string::npos) {
+    reduce_info_.scalar_kht_name_ = SCALAR_KHT_PREFIX;
+    reduce_info_.scalar_kht_name_ += reduce_info_.reduce_stmt_index_;
+    reduce_info_.scalar_khy_name_ = SCALAR_KHY_PREFIX;
+    reduce_info_.scalar_khy_name_ += reduce_info_.reduce_stmt_index_;
+    reduce_info_.scalar_khc_name_ = SCALAR_KHC_PREFIX;
+    reduce_info_.scalar_khc_name_ += reduce_info_.reduce_stmt_index_;
+  }
+
   std::string stmt_name = strs[REDUCE_FLAG_STMT_PREFIX_POS] + "_" + strs[REDUCE_FLAG_STMT_NUM_POS];
   std::string origin_tensor_name = "";
   for (auto it : info_.analysis_result_.GetReduceTensorInfoMap()) {
@@ -448,12 +475,49 @@ Stmt GpuIslEmitter::EmitReduceUpdate(const isl::ast_node_user &node) {
   }
 
   MakeAkgReduceFuncName();
-  SetScalarTensorBind();
+  SetScalarTensorBind(reduce_info_.scalar_tensor_name_);
+  if (reduce_info_.reduce_op_.find("SumOp") != std::string::npos) {
+    SetScalarTensorBind(reduce_info_.scalar_kht_name_);
+    SetScalarTensorBind(reduce_info_.scalar_khy_name_);
+    SetScalarTensorBind(reduce_info_.scalar_khc_name_);
+  }
   SetSharedTensorBind();
 
   return Stmt();
 }
 
+Stmt GpuIslEmitter::TransferToKahanInterface() {
+  std::string func_name = AKG_REDUCE_LIB_SPACE;
+  func_name += "::";
+  func_name += AKG_KAHAN_LIB_NAME;
+  Expr template_arg0 = make_const(reduce_info_.reduce_data_type_info_, 1);
+
+  Array<Expr> args;
+  args.push_back(Expr(0));
+
+  Tensor tt = reduce_info_.scalar_tensor_[reduce_info_.scalar_khy_name_];
+  Expr a1 = Call::make(tt->dtype, tt->op->func_name(), args, Call::Halide, tt->op, 0);
+  a1 = Call::make(a1.type(), "&", {a1}, Call::Extern);
+
+  tt = reduce_info_.scalar_tensor_[reduce_info_.scalar_kht_name_];
+  Expr a2 = Call::make(tt->dtype, tt->op->func_name(), args, Call::Halide, tt->op, 0);
+  a2 = Call::make(a2.type(), "&", {a2}, Call::Extern);
+
+  tt = reduce_info_.scalar_tensor_[reduce_info_.scalar_khc_name_];
+  Expr a3 = Call::make(tt->dtype, tt->op->func_name(), args, Call::Halide, tt->op, 0);
+  a3 = Call::make(a3.type(), "&", {a3}, Call::Extern);
+
+  tt = reduce_info_.scalar_tensor_[reduce_info_.scalar_tensor_name_];
+  Expr a4 = Call::make(tt->dtype, tt->op->func_name(), args, Call::Halide, tt->op, 0);
+  a4 = Call::make(a4.type(), "&", {a4}, Call::Extern);
+
+  CHECK(reduce_info_.input_tensor_expr_.defined());
+  Stmt stmt = Evaluate::make(
+    Call::make(Int(32), func_name, {template_arg0, a1, a2, a3, a4, reduce_info_.input_tensor_expr_}, Call::Extern));
+
+  return stmt;
+}
+
 void GpuIslEmitter::MakeReduceStmt() {
   std::string func_name = reduce_info_.akg_reduce_api_;
   std::string op_info = reduce_info_.reduce_op_ + "()";
@@ -482,7 +546,7 @@ void GpuIslEmitter::MakeReduceStmt() {
 
   CHECK(buffer.defined());
 
-  Tensor tt = reduce_info_.scalar_tensor_;
+  Tensor tt = reduce_info_.scalar_tensor_[reduce_info_.scalar_tensor_name_];
   Array<Expr> args;
   args.push_back(Expr(0));
   Expr a4 = Call::make(tt->dtype, tt->op->func_name(), args, Call::Halide, tt->op, 0);
@@ -563,8 +627,23 @@ Stmt GpuIslEmitter::EmitReduceArea(const isl::ast_node_user &node) {
 
   Array<Expr> args_scalar;
   args_scalar.push_back(Expr(0));
-  stmt = AkgReduceStmtChange(reduce_info_.scalar_tensor_, args_scalar, reduce_info_.promoted_tensor_name_for_reduce_)
+  stmt = AkgReduceStmtChange(reduce_info_.scalar_tensor_[reduce_info_.scalar_tensor_name_], args_scalar,
+                             reduce_info_.promoted_tensor_name_for_reduce_)
           .Mutate(stmt);
+
+  if (reduce_info_.reduce_op_.find("SumOp") != std::string::npos) {
+    auto pro = stmt.as<Provide>();
+    CHECK(pro);
+    auto value = pro->value;
+    auto add = value.as<Add>();
+    CHECK(add);
+    auto add_a = add->a;
+    auto add_b = add->b;
+    reduce_info_.input_tensor_expr_ =
+      (add->a.as<Call>() && add->a.as<Call>()->name == reduce_info_.scalar_tensor_name_) ? add_b : add_a;
+    stmt = TransferToKahanInterface();
+  }
+
   if (add_to_reduce_area) {
     reduce_info_.reduce_area_stmt_ = stmt;
     return Stmt();
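For a sum statement, EmitReduceArea rewrites the promoted accumulation in place: after AkgReduceStmtChange the statement has the shape acc(0) = acc(0) + <input>, the non-accumulator operand of the Add is saved as input_tensor_expr_ (checking both operand orders), and the whole Provide is swapped for the extern call built by TransferToKahanInterface. Schematically, in Halide-IR-style pseudocode (illustrative; tensor names assume a float sum with reduce statement index 0, and input_red/cc0 are invented placeholders):

  // before (after AkgReduceStmtChange):
  //   acc_0(0) = acc_0(0) + input_red(cc0)
  // after TransferToKahanInterface:
  //   akg_reduce::AkgKahanAccumulation(1f, &(kahan_y0(0)), &(kahan_t0(0)), &(kahan_c0(0)), &(acc_0(0)), input_red(cc0))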
diff --git a/src/poly/gpu_isl_emitter.h b/src/poly/gpu_isl_emitter.h
index 0f2a2fd6..b6bae30f 100644
--- a/src/poly/gpu_isl_emitter.h
+++ b/src/poly/gpu_isl_emitter.h
@@ -52,6 +52,9 @@ constexpr auto DEFAULT_TENSOR_INDEX = "[0]";
 constexpr auto USELESS_INDEX = "0";
 constexpr auto USELESS_SHAPE_SIZE = "1";
 constexpr auto SCALAR_TENSOR_PREFIX = "acc_";
+constexpr auto SCALAR_KHT_PREFIX = "kahan_t";
+constexpr auto SCALAR_KHY_PREFIX = "kahan_y";
+constexpr auto SCALAR_KHC_PREFIX = "kahan_c";
 constexpr auto SHARED_MEMORY_PREFIX = "__shared__";
 constexpr auto SHARED_TENSOR_PREFIX = "red_buf";
 
@@ -59,6 +62,7 @@ constexpr auto REDUCE_LIB_TYPE_ORIGIN = "origin";
 constexpr auto REDUCE_LIB_TYPE_PARIS = "paris";
 constexpr auto AKG_REDUCE_LIB_SPACE = "akg_reduce";
 constexpr auto AKG_REDUCE_LIB_NAME = "AkgReduce";
+constexpr auto AKG_KAHAN_LIB_NAME = "AkgKahanAccumulation";
 constexpr auto PARIS_REDUCE_LIB_SPACE = "paris_reduce";
 constexpr auto PARIS_REDUCE_LIB_NAME = "ParisReduce";
 constexpr auto AKG_REDUCE_RETURN_NAME = "AkgAtomicReturn";
@@ -109,6 +113,10 @@ class ReduceEmitInfo {
   std::string shared_compute_name_;
   std::string scalar_tensor_name_;
+  std::string scalar_kht_name_;
+  std::string scalar_khy_name_;
+  std::string scalar_khc_name_;
+  Expr input_tensor_expr_;
 
   std::string reduce_op_;
   std::string reduce_stmt_index_;
@@ -119,7 +127,7 @@ class ReduceEmitInfo {
   std::set<std::string> added_tensors_;
   Stmt reduce_area_stmt_;
   Stmt origin_reduce_stmt_;
-  Tensor scalar_tensor_;
+  std::map<std::string, Tensor> scalar_tensor_;
   Tensor shared_tensor_;
   std::vector<Stmt> stmts_;
   Expr atomic_rhs_;
@@ -234,9 +242,10 @@ class GpuIslEmitter : public IslEmitter {
   void MakeAkgReduceFuncName();
   void ConstructAtomicReturnFuncName();
   void MakeReduceStmt();
+  Stmt TransferToKahanInterface();
   Stmt MakeAtomicStmt();
 
-  void SetScalarTensorBind();
+  void SetScalarTensorBind(std::string scalar_tensor_name);
   void SetSharedTensorBind();
   void ResetStatus();
 
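With the prefixes above, each sum-reduce statement owns three extra single-element local buffers besides its accumulator, and the compensation buffer is the only one that needs an initial value. A hypothetical sketch of the kernel fragment AKG would emit for a float sum with statement index 0 (illustrative only; buffer names follow the SCALAR_* prefixes, but the loop structure and input expression are invented):

  #include "akg_reduce/reduce.cuh"  // provides akg_reduce::AkgKahanAccumulation

  __global__ void reduce_kernel_sketch(const float *input, int len) {
    float acc_0[1], kahan_t0[1], kahan_y0[1], kahan_c0[1];
    acc_0[0] = 0.0f;     // init_value for the accumulator (scalar_stmt)
    kahan_c0[0] = 0.0f;  // init_value for the compensation (scalar_khc)
    for (int cc0 = threadIdx.x; cc0 < len; cc0 += blockDim.x) {
      // replaces "acc_0[0] = acc_0[0] + input[cc0]" in the sum case:
      akg_reduce::AkgKahanAccumulation<float>(&kahan_y0[0], &kahan_t0[0], &kahan_c0[0], &acc_0[0], input[cc0]);
    }
    // ... AkgReduce / AkgAtomicReturn then combine acc_0 across threads ...
  }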
diff --git a/third_party/incubator-tvm/src/codegen/codegen_cuda.cc b/third_party/incubator-tvm/src/codegen/codegen_cuda.cc
index 74230df5..3519a187 100644
--- a/third_party/incubator-tvm/src/codegen/codegen_cuda.cc
+++ b/third_party/incubator-tvm/src/codegen/codegen_cuda.cc
@@ -54,6 +54,14 @@
  * Print offset shared memory when use total shared_memory of VisitStmt_(const Allocate* op)
  */
 
+/*
+ * 2021.5.17
+ * Modify the functions:
+ * add Kahan interface processing logic in VisitExpr_(const Call *op, std::ostream& os)
+ * for the reduce sum operator
+ */
+
+
 #include 
 #include 
 #include 
@@ -553,6 +561,21 @@ void CodeGenCUDA::VisitExpr_(const Call *op, std::ostream& os) {
       os << ")";
       return;
     }
+    if (op->name == AKG_KAHAN) {
+      CHECK_GE(op->args.size(), 1);
+      os << op->name << "<";
+      Expr template_arg0 = op->args[0];
+      this->PrintType(template_arg0.type(), os);
+      os << ">(";
+      for (size_t i = 1; i < op->args.size(); i++) {
+        this->PrintExpr(op->args[i], os);
+        if (i < op->args.size() - 1) {
+          os << ", ";
+        }
+      }
+      os << ")";
+      return;
+    }
     CodeGenC::VisitExpr_(op, os);
   } else if (op->is_intrinsic(Call::reinterpret_cast_op)) {
     os << "*(reinterpret_cast<";
diff --git a/third_party/incubator-tvm/src/codegen/codegen_cuda.h b/third_party/incubator-tvm/src/codegen/codegen_cuda.h
index 64531c04..05d1a094 100644
--- a/third_party/incubator-tvm/src/codegen/codegen_cuda.h
+++ b/third_party/incubator-tvm/src/codegen/codegen_cuda.h
@@ -42,6 +42,12 @@
  * wmma_scope.
  */
 
+
+/*
+ * 2021.05.17
+ * Add the constant akg_reduce::AkgKahanAccumulation for the reduce sum operator
+ */
+
 #ifndef TVM_CODEGEN_CODEGEN_CUDA_H_
 #define TVM_CODEGEN_CODEGEN_CUDA_H_
 
@@ -57,6 +63,7 @@ namespace codegen {
 constexpr auto REDUCE_LIB_TYPE = "reduceLibType";
 constexpr auto AKG_REDUCE = "akg_reduce::AkgReduce";
 constexpr auto AKG_ATOMIC_RETURN = "akg_reduce::AkgAtomicReturn";
+constexpr auto AKG_KAHAN = "akg_reduce::AkgKahanAccumulation";
 constexpr auto PARIS_REDUCE = "paris_reduce::ParisReduce";
 constexpr auto PARIS_ATOMIC_RETURN = "paris_reduce::ParisReturn";
 constexpr auto ORIGIN_REDUCE_LIB = "origin";
-- 
Gitee
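The AKG_KAHAN branch in CodeGenCUDA::VisitExpr_ treats args[0] purely as a type carrier: TransferToKahanInterface passes make_const(reduce_data_type, 1) first, the printer emits that argument's dtype as the template argument, and args[1..n] become the actual call arguments. A minimal standalone sketch of that printing logic (plain C++; PrintKahanCall and the argument strings are invented stand-ins, since the real code walks TVM Expr/Call nodes):

  #include <iostream>
  #include <string>
  #include <vector>

  // Arg 0 of the IR call only carries the dtype; the remaining args are
  // printed as the call's argument list, comma-separated, as in the diff.
  void PrintKahanCall(const std::string &dtype, const std::vector<std::string> &args, std::ostream &os) {
    os << "akg_reduce::AkgKahanAccumulation<" << dtype << ">(";
    for (size_t i = 0; i < args.size(); i++) {
      os << args[i];
      if (i + 1 < args.size()) os << ", ";
    }
    os << ")";
  }

  int main() {
    // Mirrors the argument order from TransferToKahanInterface: y, t, c, acc, input.
    PrintKahanCall("float", {"&kahan_y0[0]", "&kahan_t0[0]", "&kahan_c0[0]", "&acc_0[0]", "input[cc0]"}, std::cout);
    std::cout << "\n";  // akg_reduce::AkgKahanAccumulation<float>(&kahan_y0[0], ..., input[cc0])
    return 0;
  }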