diff --git a/tf_adapter/kernels/geop_npu.cc b/tf_adapter/kernels/geop_npu.cc index 87dd62d71512d8e34789944d324956d73f306c5f..d866b67531732191cccfd31950a0b5f592aca46e 100644 --- a/tf_adapter/kernels/geop_npu.cc +++ b/tf_adapter/kernels/geop_npu.cc @@ -1734,9 +1734,19 @@ Status GeOp::BuildInputTensorInfo(OpKernelContext *const ctx, std::vectornum_inputs(); std::string cur_input_shapes; + input_shapes_vec_.resize(num_inputs, TensorShape()); + + if (init_options_["ge.jit_compile"] == "0") { + FuzzAllShape(ctx); + } else { + MaybeUpdateShape(ctx); + } + // populate inputs for (int i = 0; i < num_inputs; i++) { Tensor tensor(ctx->input(i)); + ADP_LOG(INFO) << "Refresh input tensor " << i << " tensor1 to " << tensor.DebugString(); + bool is_equal = false; if (GraphCheckInputEqualConstOp(tensor, i, is_equal) != Status::OK()) { return errors::Internal("Const op value not equal with tensor :", i); @@ -1786,16 +1796,9 @@ Status GeOp::BuildInputTensorInfo(OpKernelContext *const ctx, std::vectorop_kernel().name(), - " is dynamic, please ensure that npu option[dynamic_input] is set" - " correctly, for more details please refer to the migration guide."); - } + Tensor tensor1 = Tensor(data_type, input_shapes_vec_[i]); + ADP_LOG(INFO) << "Refresh input tensor1 " << i << " tensor1 to " << tensor1.DebugString(); + input_vec.push_back(tensor1); } return Status::OK(); } diff --git a/tf_adapter/kernels/geop_npu.h b/tf_adapter/kernels/geop_npu.h index 50c38544a9c4c0c7a1ceaef7b02d51c9cfd6161d..385b43848427ba52a4655a1b85b5e048e1a2faa3 100644 --- a/tf_adapter/kernels/geop_npu.h +++ b/tf_adapter/kernels/geop_npu.h @@ -20,6 +20,7 @@ #include #include +#include "tf_adapter/common/adapter_logger.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/platform/mutex.h" @@ -128,6 +129,65 @@ private: bool IsDynamicConfig(); + /* Overwrites input_shapes_vec_[i] for every input of ctx with a shape built from a dims list that holds one -1 per dimension of that input's current shape. Called from BuildInputTensorInfo when init_options_["ge.jit_compile"] == "0". */ void FuzzAllShape(OpKernelContext *const ctx) { + for (size_t i = 0UL; i < 
static_cast(ctx->num_inputs()); i++) { + auto &value_shape = ctx->input(static_cast(i)).shape(); + TensorShape shape; + std::vector dims; + ADP_LOG(INFO) << "value_shape dims " << value_shape.dims(); + /* One -1 entry per dimension, so only the rank of the input is carried over. */ for (int j = 0; j < value_shape.dims(); j++) { + dims.push_back(-1); + } + ADP_LOG(INFO) << "dims size is " << dims.size(); + for (auto dim : dims) { + ADP_LOG(INFO) << "dim is : " << dim; + } + /* NOTE(review): the returned status is only logged below; shape is stored even when status is not ok, in which case it is still the default-constructed TensorShape. Confirm that MakeShape accepts -1 dims for a TensorShape target. */ auto status = TensorShapeUtils::MakeShape(dims.data(), static_cast(dims.size()), &shape); + ADP_LOG(INFO) << "FuzzAllShape Init input " << i << " shape " << shape.DebugString() << "status is" + << status.ok(); + input_shapes_vec_[i] = shape; + } + } + + /* Reconciles each cached entry of input_shapes_vec_ with the current shape of the matching input of ctx. A cached shape with dims() == 0 is replaced by the current shape; a cached shape that is not compatible with the current shape is replaced by MakeCompatShape(cached, current). Both cases set build_flag_ to false; a compatible cached shape is left untouched. */ void MaybeUpdateShape(OpKernelContext *const ctx) { + for (size_t i = 0UL; i < static_cast(ctx->num_inputs()); i++) { + auto &shape = input_shapes_vec_[i]; + auto &value_shape = ctx->input(static_cast(i)).shape(); + /* NOTE(review): dims() == 0 is handled as the "Init" case (see the log text). It is also the rank of the default TensorShape() used by resize in BuildInputTensorInfo and returned by MakeCompatShape as its fallback, so presumably a genuine rank-0 input takes this branch on every call - confirm. */ if (shape.dims() == 0) { + build_flag_ = false; + shape = value_shape; + ADP_LOG(INFO) << "Init input " << i << " shape " << shape.DebugString(); + } else { + /* The cached shape is converted to PartialTensorShape to use its IsCompatibleWith check. */ PartialTensorShape sub_shape = shape; + if (sub_shape.IsCompatibleWith(value_shape)) { + continue; + } else { + build_flag_ = false; + ADP_LOG(INFO) << "Compat input " << i << " shape " << shape.DebugString() << " vs. " + << value_shape.DebugString(); + shape = MakeCompatShape(shape, value_shape); + ADP_LOG(INFO) << "Refresh input " << i << " shape to " << shape.DebugString(); + } + } + } + } + + /* Merges two shapes: returns a default-constructed TensorShape when the ranks differ or when MakeShape fails; otherwise returns a shape that keeps a's size where a and b agree and uses -1 where they differ. */ TensorShape MakeCompatShape(const TensorShape &a, const TensorShape &b) { + const static auto kUnknownRankShape = TensorShape(); + if (a.dims() != b.dims()) { + return kUnknownRankShape; + } + TensorShape shape; + static constexpr int64 kUnknownDim = -1; + std::vector dims; + for (int i = 0; i < a.dims(); i++) { + dims.push_back((a.dim_size(i) != b.dim_size(i)) ? kUnknownDim : a.dim_size(i)); + } + /* NOTE(review): the error logging on the next line is commented out, so a MakeShape failure silently yields kUnknownRankShape. */ auto status = TensorShapeUtils::MakeShape(dims.data(), static_cast(dims.size()), &shape); + //NPU_LOG_IF_ERROR(status); + return status.ok() ? 
shape : kUnknownRankShape; + } + static const std::string INPUT_DESC; static const std::string OUTPUT_DESC; static const std::string SERIALIZE_FORMAT; @@ -180,6 +240,7 @@ private: std::string recompute_mode_; std::string enable_graph_parallel_; std::string graph_parallel_option_path_; + /* Per-input shape cache: resized in BuildInputTensorInfo, written by FuzzAllShape and MaybeUpdateShape, and read there to construct the tensors pushed into input_vec. */ std::vector input_shapes_vec_; SessionId session_id_; AoeInitializeFunc aoe_initialize_;