diff --git a/src/poly/tiling/hermes/axis.h b/src/poly/tiling/hermes/axis.h index 8ea718684b5025b6add7818aaba4663a6d754563..27df78f5fd4e2374aeee4a023aa1d4a6f4564c78 100644 --- a/src/poly/tiling/hermes/axis.h +++ b/src/poly/tiling/hermes/axis.h @@ -36,7 +36,7 @@ class Axis { kVectorization }; - Axis(); + Axis() = default; std::string name_; std::string gemm_axis_; diff --git a/src/poly/tiling/hermes/hardware.cc b/src/poly/tiling/hermes/hardware.cc index bd248edc34284eea7615c8a5848c4f8200bc56f7..d09cd05df805ec3782d00a6212089d6d7b96d15b 100644 --- a/src/poly/tiling/hermes/hardware.cc +++ b/src/poly/tiling/hermes/hardware.cc @@ -30,11 +30,11 @@ Hardware::Hardware(int num_core, int mem_VC_size, int mem_C1_size, int mem_C0_si mem_C1_align_{mem_C1_align}, vblocknum_{vblocknum}, vblocksize_{vblocksize} { - // we divide UB by 2 for each UB alloc error + // we divide VC by 2 for each VC alloc error this->mem_VC_size_ = mem_VC_size / (1 << Hardware::mem_VC_alloc_failed_); } -bool Hardware::HasUBFail(const std::string allocation_error_buf) { +bool Hardware::HasVCFail(const std::string allocation_error_buf) { if (allocation_error_buf == "local.UB") { return true; } diff --git a/src/poly/tiling/hermes/hardware.h b/src/poly/tiling/hermes/hardware.h index 9dade828751fcac90d2d4ab36e379db0475bc79c..b217834cc31e17e3699b881e84bd185288d23025 100644 --- a/src/poly/tiling/hermes/hardware.h +++ b/src/poly/tiling/hermes/hardware.h @@ -25,9 +25,9 @@ class Hardware { public: Hardware(int, int, int, int, int, int, int, int); - static bool HasUBFail(const std::string); - static void AddUBFailCounter() { Hardware::mem_VC_alloc_failed_++; } - static void ResetUBFailCounter() { Hardware::mem_VC_alloc_failed_ = 0; } + static bool HasVCFail(const std::string); + static void AddVCFailCounter() { Hardware::mem_VC_alloc_failed_++; } + static void ResetVCFailCounter() { Hardware::mem_VC_alloc_failed_ = 0; } int num_core_; int mem_VC_size_; diff --git a/src/poly/tiling/hermes/init_graph.cc b/src/poly/tiling/hermes/init_graph.cc new file mode 100644 index 0000000000000000000000000000000000000000..64c068355565f877dbaf336eebfc7c7c43c07798 --- /dev/null +++ b/src/poly/tiling/hermes/init_graph.cc @@ -0,0 +1,349 @@ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +#include "poly/tiling/hermes/init_graph.h" +#include "poly/tiling/hermes/tensor.h" +#include "poly/tiling/hermes/utils.h" + +namespace akg { +namespace ir { +namespace poly { +InitGraph::InitGraph(const std::string &name, std::vector> &nodes, + const std::vector> &inputs, + const std::vector> &outputs) + : name_{name}, nodes_{nodes}, inputs_{inputs}, outputs_{outputs} {} + +InitGraph::InitGraph(const std::vector> &check_visitor_nodes) + : nodes_{check_visitor_nodes} { + SetInputNodes(); + SetOutputNodes(); + SetConstantNodes(this->nodes_, inputs_); + if (!outputs_.empty()) { + this->name_ = outputs_.back()->name_; + } +} + +void InitGraph::SetConstantNodes(const std::vector> &nodes, + const std::vector> &inputs) { + for (auto const &node : nodes) { + if (node->op_.IsInput() && (std::find(inputs.begin(), inputs.end(), node) == inputs.end())) { + node->op_.op_type_ = Op::OpType::Constant; + } + } +} + +void InitGraph::SetInputNodes() { + for (auto const &node : nodes_) { + if (node->op_.IsInput()) { + inputs_.push_back(node); + } + } +} + +void InitGraph::SetOutputNodes() { + std::vector predecessor_names; + for (auto const &node : nodes_) { + for (auto const &pred : node->pred_) { + predecessor_names.push_back(pred->name_); + } + } + for (auto const &node : nodes_) { + for (auto const &succ : node->succ_) { + if (std::find(outputs_.begin(), outputs_.end(), succ) != outputs_.end()) { + continue; + } + if (std::find(predecessor_names.begin(), predecessor_names.end(), succ->name_) != predecessor_names.end()) { + continue; + } + outputs_.push_back(succ); + } + } + if (outputs_.empty()) { + outputs_.push_back(nodes_.back()); + } +} + +void InitGraph::RemoveNameless() { + std::set to_remove; + for (size_t i = 0; i < nodes_.size(); ++i) { + if (!(nodes_[i]->HasName())) { + if (nodes_[i]->op_.RemoveUselessInput()) { // InplaceAssign + to_remove = UselessInput(nodes_[i]->pred_, nodes_, to_remove); + } else { + FixGraph(nodes_, i); + } + to_remove.insert(static_cast(i)); + } + } + + for (auto nid = to_remove.rbegin(); nid != to_remove.rend(); ++nid) { + nodes_.erase(nodes_.begin() + (*nid)); + } +} + +std::set InitGraph::UselessInput(const std::vector> &inputs, + const std::vector> &nodes, std::set to_remove) { + for (auto const &node : inputs) { + if ((node->succ_.size() == 1) && node->pred_.empty()) { + to_remove.insert(IdOfNodeName(node->name_, nodes)); + } + } + return to_remove; +} + +void InitGraph::FixGraph(std::vector> nodes, size_t zombie_id) { + for (auto const &input : nodes[zombie_id]->pred_) { + for (auto const &output : nodes[zombie_id]->succ_) { + input->succ_.push_back(output); + output->pred_.push_back(input); + for (auto const &out_tensor : nodes[zombie_id]->output_tensors_) { + auto ipt_tsr = std::find(output->input_tensors_.begin(), output->input_tensors_.end(), out_tensor); + if (ipt_tsr != output->input_tensors_.end()) { + output->input_tensors_.erase(ipt_tsr); + } + } + output->input_tensors_.insert(output->input_tensors_.end(), input->output_tensors_.begin(), + input->output_tensors_.end()); + } + } +} + +// Buffers Naming +int InitGraph::IdOfNodeName(const std::string &name, const std::vector> &nodes) { + for (size_t i = 0; i < nodes.size(); i++) { + if (nodes[i]->name_ == name) { + return static_cast(i); + } + } + LOG(FATAL) << "No node with the name" + name; + return -1; +} + +// True if all inputs of a node were assigned a name +bool AreAllInputsAssigned(std::set> assigned, std::vector> inputs) { + bool is_assigned = true; + for_each(std::begin(inputs), std::end(inputs), [&is_assigned, &assigned](const std::shared_ptr &node) { + is_assigned &= (assigned.find(node) != assigned.end()) || (node->op_.IsConstant()); + }); + return is_assigned; +} + +void InitGraph::AddNodesName(std::vector names) { + std::set> nexts = std::set>(); // yet to assign + std::set> assigned; + std::string found; + bool is_buffer_stitch = false; + + if (names.empty()) { + LOG(FATAL) << "No buffer names given"; + } + + LOG(INFO) << "AddNodesName: " << nodes_.size() << " nodes; " << outputs_.size() << " outputs; " << inputs_.size() + << " inputs"; + + std::set intermed_output_names = GetIntermediateOutputsNames(names); + if (!intermed_output_names.empty()) { + is_buffer_stitch = true; + inputs_ = GetInputs(intermed_output_names, nodes_); + } + // give tensor name as name for input node + for (auto const &input_node : inputs_) { + input_node->name_ = input_node->output_tensors_[0]->name_; + assigned.insert(input_node); + + for (auto const &node : input_node->succ_) { + if (std::find(inputs_.begin(), inputs_.end(), node) == inputs_.end()) { + nexts.insert(node); + } + } + + FilterNames(names, input_node->output_tensors_[0]->name_); + } + + for (auto const &node : nodes_) { + if (node->op_.IsConstant()) { + nexts.insert(node->succ_.begin(), node->succ_.end()); + } + } + + std::shared_ptr node; + + // Breadth-First Search + while (!nexts.empty()) { + auto it1 = std::find_if(nexts.begin(), nexts.end(), [&assigned](const std::shared_ptr &node) { + return (AreAllInputsAssigned(assigned, node->pred_)); + }); + if (it1 == nexts.end()) { + auto it2 = std::find_if(nodes_.begin(), nodes_.end(), [&assigned](const std::shared_ptr &node) { + return ((assigned.find(node) == assigned.end()) && (!node->op_.IsInput()) && + (AreAllInputsAssigned(assigned, node->pred_))); + }); + if (it2 == nodes_.end()) { // raise an error + LOG(WARNING) << "some nodes are left unassigned\n"; + break; + } + node = *it2; + } else { + node = *it1; + } + + nexts.erase(node); + found = FindName(names, node); + + node->name_ = found; + assigned.insert(node); + + // add to next only if not assigned already + for_each(node->succ_.begin(), node->succ_.end(), [&assigned, &nexts](const std::shared_ptr &node) { + if (assigned.find(node) == assigned.end()) { + nexts.insert(node); + } + }); + + FilterNames(names, found); + } + + if (is_buffer_stitch) { + for (auto const &input : inputs_) { + input->name_ = ""; + } + } +} + +std::set InitGraph::GetIntermediateOutputsNames(const std::vector &names) { + std::set intermed_output_names; + size_t pos = 0; + size_t end = 0; + std::string output; + for (std::string const &name : names) { + pos = name.find("output"); + while (pos != std::string::npos) { + end = name.find('_', pos + 1); + end = name.find('_', end + 1); + end = name.find('_', end + 1); + if (end == std::string::npos) { + output = name.substr(pos); + } else { + output = name.substr(pos, end - pos); + } + intermed_output_names.insert(output); + pos = name.find("output", end); + } + } + return intermed_output_names; +} + +std::vector> InitGraph::GetInputs(std::set intermed_output_names, + const std::vector> &nodes) { + std::vector> n_inputs; + for (auto const &node : nodes) { + auto name = intermed_output_names.find(node->output_tensors_[0]->name_); + if (name != intermed_output_names.end()) { + n_inputs.push_back(node); + node->op_.op_type_ = Op::OpType::Input; + intermed_output_names.erase(name); + } + } + return n_inputs; +} + +bool InputsAreInside(std::vector> inputs, const std::string &name) { + size_t pos = 0; + + for (size_t i = 0; i < inputs.size(); i++) { + std::string input = StripRename(inputs[i]->name_); + if (i > 1) { + return true; + } + + if ((name.find(input, pos) == std::string::npos) && + (name.find(inputs[i]->output_tensors_[0]->name_, pos) == std::string::npos)) { + return false; + } + pos++; + } + return true; +} + +// find the name of a node +std::string InitGraph::FindName(std::vector names, std::shared_ptr node) { + std::vector possibles; + std::vector> inputs; + std::string result; + if (!node->op_.IsLonely()) { // ie Exp does't have its input written + inputs = node->pred_; + } + bool cst_input = HasConstantInput(node); + std::copy_if(names.begin(), names.end(), std::back_inserter(possibles), + [&inputs, &node, &cst_input](const std::string &name) { + bool is_inside = InputsAreInside(inputs, name); + bool reduceName = !(name.find("red", name.size() - 4) == std::string::npos); + reduceName = (node->op_.IsReduce()) ? reduceName : !reduceName; + return (is_inside && reduceName && node->op_.FitBufferName(name, cst_input)); + }); + + std::sort(possibles.begin(), possibles.end(), // avoid intermediate buffer + [](const std::string &str_a, const std::string &str_b) { return (str_a.size() < str_b.size()); }); + + if (possibles.empty()) { + LOG(INFO) << "Careful: No suitable name left for node " << node->op_.ToString(); + return ""; // Accepted because nodes may be simplified (e.g. RealDiv(x, 1)) + } + return possibles[0]; +} + +bool InitGraph::HasConstantInput(const std::shared_ptr &node) { + int size = node->pred_.size(); + bool all_consts = true; + for (int i = 0; i < size; ++i) { + all_consts &= node->pred_[i]->op_.IsConstant(); + } + return all_consts; +} + +// remove name assigned & similar ones (with local_UB.*) +void InitGraph::FilterNames(std::vector names, const std::string &out) { + std::vector to_remove; + for (int i = names.size() - 1; i >= 0; --i) { + if (names[i] == out) { + to_remove.push_back(i); + } + } + + for_each(std::begin(to_remove), std::end(to_remove), [&](int i) { names.erase(names.begin() + i); }); +} + +std::string InitGraph::ToString() { + std::stringstream buf; + buf << "{ name = " << this->name_; + int size = this->nodes_.size(); + for (int i = 0; i < size; i++) { + buf << "{node:" << nodes_[i] << "}" << std::endl; + } + return buf.str(); +} + +Op::OpCategory InitGraph::OperatorCategory() { + Op::OpCategory result = Op::OpCategory::Input; + for (auto const &node : nodes_) { + result = Op::DominantCategory(result, node->op_.Category()); + } + return result; +} +} // namespace poly +} // namespace ir +} // namespace akg diff --git a/src/poly/tiling/hermes/init_graph.h b/src/poly/tiling/hermes/init_graph.h new file mode 100644 index 0000000000000000000000000000000000000000..eca9dee06db891a6da65128a6cc186c1c659e8ec --- /dev/null +++ b/src/poly/tiling/hermes/init_graph.h @@ -0,0 +1,68 @@ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef POLY_TILING_HERMES_INIT_GRAPH_H_ +#define POLY_TILING_HERMES_INIT_GRAPH_H_ + +#include +#include +#include +#include + +#include "poly/tiling/hermes/node.h" + +namespace akg { +namespace ir { +namespace poly { +class InitGraph { + public: + InitGraph(const std::string &, std::vector> &, const std::vector> &, + const std::vector> &); + explicit InitGraph(const std::vector> &); + InitGraph() = default; + + Op::OpCategory OperatorCategory(); + void RemoveNameless(); + void AddNodesName(std::vector); + std::string ToString(); + + std::string name_; + std::vector> nodes_; + std::vector> inputs_; + std::vector> outputs_; + + private: + void SetInputNodes(); + void SetOutputNodes(); + static void SetConstantNodes(const std::vector> &nodes, + const std::vector> &inputs); + static std::set UselessInput(const std::vector> &inputs, + const std::vector> &nodes, std::set to_remove); + static void FixGraph(std::vector> nodes, size_t zombie_id); + static int IdOfNodeName(const std::string &name, const std::vector> &nodes); + static std::set GetIntermediateOutputsNames(const std::vector &names); + static std::vector> GetInputs(std::set intermed_output_names, + const std::vector> &nodes); + static std::string FindName(std::vector names, std::shared_ptr node); + static bool HasConstantInput(const std::shared_ptr &node); + static void FilterNames(std::vector names, const std::string &out); + + template + std::vector> RefVecFromIdxVec(std::vector indexes, std::vector> refs); +}; +} // namespace poly +} // namespace ir +} // namespace akg +#endif // POLY_TILING_HERMES_INIT_GRAPH_H_ diff --git a/src/poly/tiling/hermes/model_graph.cc b/src/poly/tiling/hermes/model_graph.cc new file mode 100644 index 0000000000000000000000000000000000000000..1162b00a9201e22cfa3bf55f053bd550e82069c1 --- /dev/null +++ b/src/poly/tiling/hermes/model_graph.cc @@ -0,0 +1,210 @@ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "poly/tiling/hermes/model_graph.h" +#include "poly/tiling/hermes/op.h" + +namespace akg { +namespace ir { +namespace poly { +std::vector ModelGraph::global_axis_vec_; +std::set> ModelGraph::name_dim_set_; + +ModelGraph::ModelGraph(InitGraph &init_graph) { + // Add additional nodes generated by Reduce Op to InitGraph. + CompleteNodesGeneratedByReduce(init_graph); + std::vector> critical_nodes = GetCriticalNodes(init_graph); + this->name_ = init_graph.name_; + this->nodes_ = init_graph.nodes_; + this->inputs_ = init_graph.inputs_; + this->outputs_ = init_graph.outputs_; + this->critical_nodes_ = critical_nodes; +} + +ModelGraph::ModelGraph(InitGraph &init_graph, const std::vector> &critical_nodes) + : InitGraph{init_graph.name_, init_graph.nodes_, init_graph.inputs_, init_graph.outputs_}, + critical_nodes_{critical_nodes} {} + +void ModelGraph::CompleteNodesGeneratedByReduce(InitGraph &init_graph) { + int index_reduce = -1; + for (size_t index_node = init_graph.nodes_.size(); index_node > 0; --index_node) { + if (init_graph.nodes_[index_node - 1]->op_.IsReduce()) { + index_reduce = static_cast(index_node) - 1; + break; + } + } + if (index_reduce < 0) { + return; + } + + auto &reduce_node = init_graph.nodes_[index_reduce]; + + ReduceDirection reduce_type = GetReduceDirection(reduce_node); + + int ax0 = 0; + int ax1 = 0; + for (auto const &axis : ModelGraph::global_axis_vec_) { + if (axis.dim_axis_ == 0) { + ax0 = static_cast(axis.range_); + } else if (axis.dim_axis_ == 1) { + ax1 = static_cast(axis.range_); + } + } + + int dst_shape_size = 0; + int src_shape_size = 0; + + if (reduce_type == ReduceDirection::ALL) { + reduce_node->op_.op_type_ = Op::OpType::AllReduce; + reduce_node->output_tensors_[0]->shape_[0] = kExtraMemoryCoeffRequiredByReduceDst; + int shape_val = kExtraMemoryCoeffRequiredByReduceDst; + reduce_node->transformed_output_shape_[0].shape_.push_back(shape_val); + + dst_shape_size = kExtraMemoryCoeffRequiredByReduceDst; + src_shape_size = kExtraMemoryCoeffRequiredByAllReduce; + } else { + if (reduce_type == ReduceDirection::Y) { + reduce_node->op_.op_type_ = Op::OpType::ReduceY; + } else { + reduce_node->op_.op_type_ = Op::OpType::ReduceX; + } + dst_shape_size = ax0 * kExtraMemoryCoeffRequiredByReduceDst; + src_shape_size = ax0 * ax1 / kExtraMemoryCoeffRequiredByReduceSrc; + } + + std::shared_ptr dst_node = + SetReduceSrcDstNodes(reduce_node, kDstTmpSuffix, Op::OpType::ReduceDST, dst_shape_size); + std::shared_ptr src_node = + SetReduceSrcDstNodes(reduce_node, kSrcTmpSuffix, Op::OpType::ReduceSRC, src_shape_size); + + dst_node->axis_of_node_ = reduce_node->axis_of_node_; + dst_node->axis_to_tensor_to_shape_id_map_ = reduce_node->axis_to_tensor_to_shape_id_map_; + + src_node->axis_of_node_ = reduce_node->pred_[0]->axis_of_node_; + if (reduce_type == ReduceDirection::ALL) { + src_node->axis_to_tensor_to_shape_id_map_ = reduce_node->axis_to_tensor_to_shape_id_map_; + } else { + src_node->axis_to_tensor_to_shape_id_map_ = reduce_node->pred_[0]->axis_to_tensor_to_shape_id_map_; + } + + init_graph.nodes_.push_back(dst_node); + init_graph.nodes_.push_back(src_node); +} + +std::shared_ptr ModelGraph::SetReduceSrcDstNodes(const std::shared_ptr &reduce_node, + const std::string &suffix, Op::OpType op_type, int shape_size) { + std::shared_ptr node = std::make_shared(); + + node->name_ = reduce_node->name_ + suffix; + node->op_.op_type_ = op_type; + + std::shared_ptr output_tensor = std::make_shared(); + output_tensor->shape_.push_back(shape_size); + output_tensor->datatype_ = reduce_node->output_tensors_[0]->datatype_; + output_tensor->format_ = reduce_node->output_tensors_[0]->format_; + node->output_tensors_.push_back(output_tensor); + + std::vector tensor_shape; + tensor_shape.push_back(shape_size); + node->transformed_output_shape_.emplace_back( + Tensor(tensor_shape, reduce_node->output_tensors_[0]->datatype_, reduce_node->output_tensors_[0]->format_)); + + node->input_tensors_ = reduce_node->input_tensors_; + node->succ_ = reduce_node->succ_; + node->pred_ = reduce_node->pred_; + + return node; +} + +ReduceDirection ModelGraph::GetReduceDirection(const std::shared_ptr &reduce_node) { + if (reduce_node->output_tensors_.size() == 1 && reduce_node->output_tensors_[0]->shape_.size() == 1 && + reduce_node->output_tensors_[0]->shape_[0] == 1) { + return ReduceDirection::ALL; + } + ReduceDirection reduce_type = ReduceDirection::UNKNOWN; + for (auto const &axis : ModelGraph::global_axis_vec_) { + if (!axis.is_inner_ && axis.is_reduce_axis_) { + if (axis.is_reduce_src_last_) { + reduce_type = ReduceDirection::X; + } else { + reduce_type = ReduceDirection::Y; + } + } + } + if (reduce_type == ReduceDirection::UNKNOWN) { + LOG(FATAL) << "unknown reduce type"; + } + return reduce_type; +} + +std::tuple ModelGraph::GetMinShapeAndDataCoef(const Axis &axis) const { + int min_shape = INT32_MAX; + int data_coef = INT32_MAX; + for (auto const &node : this->nodes_) { + for (auto const &node_axis : node->axis_of_node_) { + if (node_axis.dim_axis_ == axis.dim_axis_ && static_cast(node_axis.range_) < min_shape) { + min_shape = static_cast(node_axis.range_); + data_coef = node->output_tensors_[0]->GetDataTypeCoef(); + break; + } + } + } + return std::make_tuple(min_shape, data_coef); +} + +bool ModelGraph::IsInVector(const std::string &name, const std::vector> &node_vec) { + return std::any_of(node_vec.begin(), node_vec.end(), + [&name](const std::shared_ptr &node) { return name == node->name_; }); +} + +std::vector> ModelGraph::GetCriticalNodes(const InitGraph &init_graph) { + std::vector> critical_nodes; + if (!init_graph.nodes_.empty()) { + critical_nodes.push_back(init_graph.nodes_[0]); + } + for (size_t i = 1; i < init_graph.nodes_.size(); i++) { + if (init_graph.nodes_[i]->name_.find(HInputOp::input) == 0 || + critical_nodes.back()->name_.find(HInputOp::input) == 0) { + critical_nodes.push_back(init_graph.nodes_[i]); + } else if (init_graph.nodes_[i]->op_.op_type_ == Op::OpType::AllReduce || + init_graph.nodes_[i]->op_.op_type_ == Op::OpType::ReduceX || + init_graph.nodes_[i]->op_.op_type_ == Op::OpType::ReduceY || + init_graph.nodes_[i]->op_.op_type_ == Op::OpType::ReduceDST || + init_graph.nodes_[i]->op_.op_type_ == Op::OpType::ReduceSRC) { + critical_nodes.push_back(init_graph.nodes_[i]); + } else if (IsInVector(init_graph.nodes_[i]->name_, init_graph.outputs_)) { + critical_nodes.push_back(init_graph.nodes_[i]); + } else if (init_graph.nodes_[i]->succ_.size() > 1) { + critical_nodes.push_back(init_graph.nodes_[i]); + } else { + int curr_node_prod_out_shape = kMinShapeSize; + int last_critc_node_prod_out_shape = kMinShapeSize; + for (auto const &out_t : init_graph.nodes_[i]->output_tensors_) { + curr_node_prod_out_shape *= out_t->GetShapeProduct() * out_t->GetDataTypeCoef(); + } + for (auto const &out_t : critical_nodes.back()->output_tensors_) { + last_critc_node_prod_out_shape *= out_t->GetShapeProduct() * out_t->GetDataTypeCoef(); + } + if (curr_node_prod_out_shape >= last_critc_node_prod_out_shape) { + critical_nodes.pop_back(); + critical_nodes.push_back(init_graph.nodes_[i]); + } + } + } + return critical_nodes; +} +} // namespace poly +} // namespace ir +} // namespace akg diff --git a/src/poly/tiling/hermes/model_graph.h b/src/poly/tiling/hermes/model_graph.h new file mode 100644 index 0000000000000000000000000000000000000000..961604af630ff36f659df4bf84f37062a55fcca0 --- /dev/null +++ b/src/poly/tiling/hermes/model_graph.h @@ -0,0 +1,64 @@ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef POLY_TILING_HERMES_MODEL_GRAPH_H_ +#define POLY_TILING_HERMES_MODEL_GRAPH_H_ + +#include +#include +#include +#include + +#include "poly/poly_util.h" +#include "poly/tiling/hermes/axis.h" +#include "poly/tiling/hermes/init_graph.h" +#include "poly/tiling/hermes/node.h" + +namespace akg { +namespace ir { +namespace poly { +class ModelGraph : public InitGraph { + public: + ModelGraph(InitGraph &, const std::vector> &); + explicit ModelGraph(InitGraph &init_graph); + ModelGraph() = default; + + std::tuple GetMinShapeAndDataCoef(const Axis &axis) const; + + std::vector> critical_nodes_; + bool is_activated_double_buffer_{false}; + + static std::vector global_axis_vec_; + static std::set> name_dim_set_; + + private: + static void CompleteNodesGeneratedByReduce(InitGraph &init_graph); + static bool IsInVector(const std::string &name, const std::vector> &node_vec); + static ReduceDirection GetReduceDirection(const std::shared_ptr &reduce_node); + static std::shared_ptr SetReduceSrcDstNodes(const std::shared_ptr &reduce_node, const std::string &suffix, + Op::OpType op_type, int shape_size); + static std::vector> GetCriticalNodes(const InitGraph &init_graph); + + static const int kExtraMemoryCoeffRequiredByAllReduce = 16; + static const int kExtraMemoryCoeffRequiredByReduceDst = 8; + static const int kExtraMemoryCoeffRequiredByReduceSrc = 64; + static const int kMinShapeSize = 1; + inline static const std::string kSrcTmpSuffix = "_src_tmp"; + inline static const std::string kDstTmpSuffix = "_dst_tmp"; +}; +} // namespace poly +} // namespace ir +} // namespace akg +#endif // POLY_TILING_HERMES_MODEL_GRAPH_H_ diff --git a/src/poly/tiling/hermes/op.h b/src/poly/tiling/hermes/op.h index 126c1528180765e064067199dae07a7d62bcabca..d8fa465234509aa7e5c0ebff30c03c07e79fd113 100644 --- a/src/poly/tiling/hermes/op.h +++ b/src/poly/tiling/hermes/op.h @@ -145,11 +145,7 @@ class Op { ReduceY = 8, MatMul = 9 }; - enum Source { - Enum = 0, - Info = 1, - IR = 2 - }; + enum Source { Enum = 0, Info = 1, IR = 2 }; static int Priority(Op::OpCategory cat); }; diff --git a/src/poly/tiling/hermes/utils.cc b/src/poly/tiling/hermes/utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..5bb001a7da09e9ec00eb27065f146dce01fd9829 --- /dev/null +++ b/src/poly/tiling/hermes/utils.cc @@ -0,0 +1,82 @@ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "poly/tiling/hermes/utils.h" + +namespace akg { +namespace ir { +namespace poly { +std::string ParseString(air::Expr expression) { + if (const auto *const strimm = expression.as()) { + return strimm->value; + } else { + LOG(FATAL) << "String cannot be parsed"; + return ""; + } +} + +int ParseInt(air::Integer integer) { + if (const auto *const intimm = integer.as()) { + return intimm->value; + } else { + LOG(FATAL) << "Int cannot be parsed"; + return -1; + } +} + +std::vector ParseIntArray(air::Array arr) { + std::vector vec; + for (air::Integer i : arr) { + vec.push_back(ParseInt(i)); + } + return vec; +} + +std::vector ParseStringArray(air::Array arr) { + std::vector vec; + for (air::Expr s : arr) { + vec.push_back(ParseString(s)); + } + return vec; +} + +std::string StripRename(std::string name) { + size_t pos = name.rfind("_rename"); + if (pos == std::string::npos) { + return name; + } + return name.substr(0, pos); +} + +int Get2PowerBelow(int n) { + int result = n; + constexpr int twice = 2; + for (int i = 1; i < n; i *= twice) { + result = i; + } + return result; +} + +int SearchDownDivisibleNumber(int begin, size_t divided_size) { + for (int i = begin; i > 0; i--) { + if (divided_size % i == 0) { + return i; + } + } + return 1; +} +} // namespace poly +} // namespace ir +} // namespace akg diff --git a/src/poly/tiling/hermes/utils.h b/src/poly/tiling/hermes/utils.h new file mode 100644 index 0000000000000000000000000000000000000000..59b0036b24f188ed8467a01ba2df1be380c2618f --- /dev/null +++ b/src/poly/tiling/hermes/utils.h @@ -0,0 +1,38 @@ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef POLY_TILING_HERMES_UTILS_H_ +#define POLY_TILING_HERMES_UTILS_H_ + +#include + +#include +#include + +namespace akg { +namespace ir { +namespace poly { +std::string ParseString(air::Expr); +int ParseInt(air::Integer); +std::vector ParseIntArray(air::Array); +std::vector ParseStringArray(air::Array arr); +std::string StripRename(std::string name); +int Get2PowerBelow(int); +int SearchDownDivisibleNumber(int begin, size_t divided_size); +} // namespace poly +} // namespace ir +} // namespace akg + +#endif // POLY_TILING_HERMES_UTILS_H_