diff --git a/torch_npu/csrc/aten/common/ResizeNpu.h b/torch_npu/csrc/aten/common/ResizeNpu.h index e94fb80ce1b6bfad60c60315e3ca954f2fec2910..8153c56f9e7fbb844df8604d8fa9d0e0668f993c 100644 --- a/torch_npu/csrc/aten/common/ResizeNpu.h +++ b/torch_npu/csrc/aten/common/ResizeNpu.h @@ -20,7 +20,6 @@ #include "torch_npu/csrc/core/npu/NPUStream.h" #include "torch_npu/csrc/core/npu/interface/AsyncTaskQueueInterface.h" #include "torch_npu/csrc/framework/utils/NpuUtils.h" -#include "torch_npu/csrc/framework/FormatHelper.h" #include "torch_npu/csrc/framework/StorageDescHelper.h" #include "torch_npu/csrc/core/NPUBridge.h" #include "torch_npu/csrc/core/NPUStorageImpl.h" @@ -32,80 +31,57 @@ namespace native { static void storage_resize_npu( torch_npu::NPUStorageImpl& storage, ptrdiff_t size, - c10::IntArrayRef new_size) -{ - if (!storage.resizable()) { - TORCH_CHECK(false, "Trying to resize storage that is not resizable", OPS_ERROR(ErrCode::NOT_SUPPORT)); - return; - } - - auto &storage_desc = torch_npu::NPUBridge::GetNpuStorageImpl(&storage)->npu_desc_; - if (!FormatHelper::IsBaseFormatType(storage_desc.npu_format_)) { - TORCH_CHECK(false, "Cannot resize storage without base format", OPS_ERROR(ErrCode::NOT_SUPPORT)); - return; - } + c10::IntArrayRef new_size) { + if (!storage.resizable()) { + AT_ERROR("Trying to resize storage that is not resizable"); + return; + } - at::DataPtr new_data; - size_t itemsize = storage_desc.data_type_.itemsize(); - if (size != 0) { - new_data = storage.allocator()->allocate(size); - } - at::DataPtr old_data = storage.set_data_ptr(std::move(new_data)); - ptrdiff_t old_size = static_cast(storage.nbytes()); - storage.set_nbytes(size); + at::DataPtr new_data; + auto storage_desc = torch_npu::NPUBridge::GetNpuStorageImpl(&storage)->npu_desc_; + size_t itemsize = storage_desc.data_type_.itemsize(); + if (size != 0) { + new_data = storage.allocator()->allocate(size); + } + at::DataPtr old_data = storage.set_data_ptr(std::move(new_data)); + ptrdiff_t old_size = static_cast(storage.nbytes()); + storage.set_nbytes(size); - if (itemsize == 0) { - AT_ERROR("When resizing, item size of storage cannot be zero."); - return; + if (itemsize == 0) { + AT_ERROR("When resizing, item size of storage cannot be zero."); + return; + } + if ((size % static_cast(itemsize)) != 0) { + AT_ERROR("The specified storage nbytes cannot be divided by item size.", + "Please check the input parameter size."); + return; + } + std::vector resize_shape = { + size/static_cast(itemsize) + }; + // It is necessary to properly refresh the storage according to sizes and strides, + // not just new sizes. + StorageDescHelper::UpdateDesc( + torch_npu::NPUBridge::GetNpuStorageImpl(&storage)->npu_desc_, resize_shape, new_size); + + if (old_data != nullptr) { + ptrdiff_t copy_size = old_size; + if (static_cast(storage.nbytes()) < copy_size) { + copy_size = static_cast(storage.nbytes()); } - if ((size % static_cast(itemsize)) != 0) { - AT_ERROR("The specified storage nbytes cannot be divided by item size.", - "Please check the input parameter size."); + if (copy_size > 0) { + aclError error = CalcuOpUtil::LaunchAsyncCopyTaskWithModeSwitch( + storage, + copy_size, + old_data.get(), + copy_size, + ACL_MEMCPY_DEVICE_TO_DEVICE); + if (error != ACL_ERROR_NONE) { + AT_ERROR("ACL_Memcpy device to device error."); return; + } } - std::vector resize_shape = { - size/static_cast(itemsize) - }; - // It is necessary to properly refresh the storage according to sizes and strides, - // not just new sizes. - int64_t new_data_numel = c10::multiply_integers(resize_shape); - int64_t new_shape_numel = c10::multiply_integers(new_size); - const c10::IntArrayRef &refresh_size = new_data_numel > new_shape_numel ? resize_shape : new_size; - - // 计算连续场景下size对应的stride值 - int64_t dim_ = static_cast(refresh_size.size()); - c10::SmallVector new_stride(dim_); - if (dim_ > 0) { - int64_t last_idx = dim_ - 1; - new_stride[last_idx] = 1; - for (auto i = last_idx - 1; i >= 0; --i) { - new_stride[i] = new_stride[i + 1] * std::max(refresh_size[i + 1], 1); - } - } - - storage_desc.base_sizes_ = refresh_size; - storage_desc.base_strides_ = new_stride; - storage_desc.npu_format_ = ACL_FORMAT_ND; - storage_desc.storage_sizes_ = storage_desc.base_sizes_; - - if (old_data != nullptr) { - ptrdiff_t copy_size = old_size; - if (static_cast(storage.nbytes()) < copy_size) { - copy_size = static_cast(storage.nbytes()); - } - if (copy_size > 0) { - aclError error = CalcuOpUtil::LaunchAsyncCopyTaskWithModeSwitch( - storage, - copy_size, - old_data.get(), - copy_size, - ACL_MEMCPY_DEVICE_TO_DEVICE); - if (error != ACL_ERROR_NONE) { - AT_ERROR("ACL_Memcpy device to device error."); - return; - } - } - } + } } static inline void maybe_resize_storage_npu(