diff --git a/torch_npu/csrc/aten/common/ResizeNpu.h b/torch_npu/csrc/aten/common/ResizeNpu.h index e71914c9adaca20590f4d61b95fef47febe0d1dc..78dbfac58bea2304202cd8653194ab2935426604 100644 --- a/torch_npu/csrc/aten/common/ResizeNpu.h +++ b/torch_npu/csrc/aten/common/ResizeNpu.h @@ -20,17 +20,12 @@ static void storage_resize_npu( c10::IntArrayRef new_size) { if (!storage.resizable()) { - TORCH_CHECK(false, "Trying to resize storage that is not resizable", OPS_ERROR(ErrCode::NOT_SUPPORT)); - return; - } - - auto &storage_desc = torch_npu::NPUBridge::GetNpuStorageImpl(&storage)->npu_desc_; - if (!FormatHelper::IsBaseFormatType(storage_desc.npu_format_)) { - TORCH_CHECK(false, "Cannot resize storage without base format", OPS_ERROR(ErrCode::NOT_SUPPORT)); + AT_ERROR("Trying to resize storage that is not resizable"); return; } at::DataPtr new_data = storage.allocator()->allocate(size); + auto storage_desc = torch_npu::NPUBridge::GetNpuStorageImpl(&storage)->npu_desc_; size_t itemsize = storage_desc.data_type_.itemsize(); at::DataPtr old_data = storage.set_data_ptr(std::move(new_data)); ptrdiff_t old_size = static_cast(storage.nbytes()); @@ -50,25 +45,8 @@ static void storage_resize_npu( }; // It is necessary to properly refresh the storage according to sizes and strides, // not just new sizes. - int64_t new_data_numel = c10::multiply_integers(resize_shape); - int64_t new_shape_numel = c10::multiply_integers(new_size); - const c10::IntArrayRef &refresh_size = new_data_numel > new_shape_numel ? resize_shape : new_size; - - // 计算连续场景下size对应的stride值 - int64_t dim_ = static_cast(refresh_size.size()); - c10::SmallVector new_stride(dim_); - if (dim_ > 0) { - int64_t last_idx = dim_ - 1; - new_stride[last_idx] = 1; - for (auto i = last_idx - 1; i >= 0; --i) { - new_stride[i] = new_stride[i + 1] * std::max(refresh_size[i + 1], 1); - } - } - - storage_desc.base_sizes_ = refresh_size; - storage_desc.base_strides_ = new_stride; - storage_desc.npu_format_ = ACL_FORMAT_ND; - storage_desc.storage_sizes_ = storage_desc.base_sizes_; + StorageDescHelper::UpdateDesc( + torch_npu::NPUBridge::GetNpuStorageImpl(&storage)->npu_desc_, resize_shape, new_size); if (old_data != nullptr) { ptrdiff_t copy_size = old_size;