From 618d38d90ac90840f0dcd31ac2efa5b7dcd017e9 Mon Sep 17 00:00:00 2001 From: weidandan 00687068 Date: Mon, 20 Feb 2023 17:26:00 +0800 Subject: [PATCH] =?UTF-8?q?=E6=95=B0=E6=8D=AE=E9=A2=84=E5=A4=84=E7=90=86?= =?UTF-8?q?=E5=8A=A8=E6=80=81=E6=8E=A7=E5=88=B6device=E5=86=85=E5=AD=98?= =?UTF-8?q?=E5=8D=A0=E7=94=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../kernels/aicpu/host_queue_dataset_op.cc | 49 ++++++++++++++++++- .../depends/ascendcl/src/ascendcl_stub.cc | 4 ++ 2 files changed, 51 insertions(+), 2 deletions(-) diff --git a/tf_adapter/kernels/aicpu/host_queue_dataset_op.cc b/tf_adapter/kernels/aicpu/host_queue_dataset_op.cc index f032e347f..6f5637918 100644 --- a/tf_adapter/kernels/aicpu/host_queue_dataset_op.cc +++ b/tf_adapter/kernels/aicpu/host_queue_dataset_op.cc @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include #include #include #include @@ -280,6 +281,14 @@ class HostQueueDatasetOp : public DatasetOpKernel { dataset()->device_id_, dataset()->local_device_list_, dataset()->channel_name_)) { + size_t output_shape_size = dataset()->output_types_.size(); /* Scan the dataset's output types: any DT_STRING output sets is_hold_data_trans, which gates the channel-size check in SendDataByAclQueue. */ + for (size_t i = 0UL; i < output_shape_size; i++) { + DataType tensor_data_type = dataset()->output_types_.at(i); + if (tensor_data_type == DT_STRING) { + is_hold_data_trans = true; + break; + } + } } ~Iterator() override { @@ -581,7 +590,10 @@ class HostQueueDatasetOp : public DatasetOpKernel { Status SendDataByAclQueue(const vector &args, const acltdtTensorType &data_type) { Status status; + aclError acl_status; + size_t size = 0UL; /* Zero-initialised so the value is defined even if acltdtQueryChannelSize reports success without writing through the pointer, as the unit-test stub added by this patch does. */ bool is_need_resend = false; + bool is_need_recompute_mbuf = false; /* True while the send is deferred and the channel size has to be queried again. */ do { { mutex_lock lck(mu_); @@ -589,8 +601,29 @@ class HostQueueDatasetOp : public DatasetOpKernel { break; } } - status = SendTensorsByAcl(acl_handle_, data_type, args, is_need_resend); - } while (status.ok() && is_need_resend); + if (!is_hold_data_trans || is_need_resend) { /* Send without the size check when is_hold_data_trans is false or is_need_resend is set. */ 
status = SendTensorsByAcl(acl_handle_, data_type, args, is_need_resend); + } else { /* Throttled path: query the channel size before deciding whether to send. */ + acl_status = acltdtQueryChannelSize(acl_handle_, &size); + if (acl_status != ACL_SUCCESS) { + return errors::InvalidArgument("Failed to get the mbuf size!"); + } + if (size <= 1 || GetMbufTotalBytes(size) <= kMaxBytes) { /* Send only if the reported size is at most 1 or the recorded bytes of the last size sends are within kMaxBytes; otherwise yield the CPU and re-check. */ + status = SendTensorsByAcl(acl_handle_, data_type, args, is_need_resend); + is_need_recompute_mbuf = false; + } else { + is_need_recompute_mbuf = true; + sched_yield(); + } + } + } while ((status.ok() && is_need_resend) || is_need_recompute_mbuf); + + uint64_t bytes_sum = 0ULL; /* Record the total bytes of this call's tensors in the args_bytes ring and advance the write index. NOTE(review): this also runs when the loop exits via break or with a non-OK status - confirm that is intended. */ + for (auto &tensor : args) { + bytes_sum += tensor.TotalBytes(); + } + args_bytes[args_bytes_rear] = bytes_sum; + args_bytes_rear = (args_bytes_rear + 1) % kStringTypeDepth; return status; } @@ -643,6 +676,15 @@ class HostQueueDatasetOp : public DatasetOpKernel { return true; } + uint64_t GetMbufTotalBytes(size_t size) { /* Returns the sum of the size most recent entries of the args_bytes ring, walking backwards from args_bytes_rear. */ + uint64_t sum = 0; + for (size_t i = 1; i <= size && i <= static_cast<size_t>(kStringTypeDepth); i++) { /* Never walk back further than the ring holds, so an oversized size cannot count a slot twice or wrap the index arithmetic. */ + size_t index = (args_bytes_rear - i + kStringTypeDepth) % kStringTypeDepth; + sum += args_bytes[index]; + } + return sum; + } + void SendDataByQueueThread(const std::shared_ptr &ctx) { ADP_LOG(INFO) << "Begin to send data to the NPU. 
"; rtError_t rt = rtSetDevice(dataset()->device_id_); @@ -976,6 +1018,9 @@ class HostQueueDatasetOp : public DatasetOpKernel { acltdtChannelHandle *acl_handle_; uint32_t queue_id_; int active_thread_num = 0; + uint64_t args_bytes[kStringTypeDepth] = {}; /* Ring of per-call byte totals written by SendDataByAclQueue; zero-initialised so slots that were never written read as 0. */ + int args_bytes_rear = 0; /* Next write position in args_bytes. */ + bool is_hold_data_trans = false; /* Set in the Iterator constructor when any output type is DT_STRING. */ }; const std::vector inputs_; std::string channel_name_; diff --git a/tf_adapter/tests/depends/ascendcl/src/ascendcl_stub.cc b/tf_adapter/tests/depends/ascendcl/src/ascendcl_stub.cc index 52109405e..87862c1d0 100644 --- a/tf_adapter/tests/depends/ascendcl/src/ascendcl_stub.cc +++ b/tf_adapter/tests/depends/ascendcl/src/ascendcl_stub.cc @@ -103,6 +103,10 @@ aclError aclrtResetDevice(int32_t deviceId) { return ACL_SUCCESS; } +aclError acltdtQueryChannelSize(const acltdtChannelHandle *handle, size_t *size) { /* Test stub: always reports success. NOTE(review): it never writes *size, so callers must not rely on the out-parameter being set by the stub. */ + return ACL_SUCCESS; +} + acltdtChannelHandle *acltdtCreateChannelWithCapacity(uint32_t deviceId, const char *name, size_t capacity) { -- Gitee