diff --git a/torch_npu/csrc/aten/common/ResizeNpu.h b/torch_npu/csrc/aten/common/ResizeNpu.h index 7d9c3a2b8c2a28093cdc75ddf2ffcc543a144541..715e2d4b9c0dc16a6a33ce4b19c38b3aec7de7d4 100644 --- a/torch_npu/csrc/aten/common/ResizeNpu.h +++ b/torch_npu/csrc/aten/common/ResizeNpu.h @@ -20,12 +20,17 @@ static void storage_resize_npu( c10::IntArrayRef new_size) { if (!storage.resizable()) { - AT_ERROR("Trying to resize storage that is not resizable"); + TORCH_CHECK(false, "Trying to resize storage that is not resizable", OPS_ERROR(ErrCode::NOT_SUPPORT)); + return; + } + + auto &storage_desc = torch_npu::NPUBridge::GetNpuStorageImpl(&storage)->npu_desc_; + if (!FormatHelper::IsBaseFormatType(storage_desc.npu_format_)) { + TORCH_CHECK(false, "Cannot resize storage without base format", OPS_ERROR(ErrCode::NOT_SUPPORT)); return; } at::DataPtr new_data = storage.allocator()->allocate(size); - auto storage_desc = torch_npu::NPUBridge::GetNpuStorageImpl(&storage)->npu_desc_; size_t itemsize = storage_desc.data_type_.itemsize(); at::DataPtr old_data = storage.set_data_ptr(std::move(new_data)); ptrdiff_t old_size = static_cast(storage.nbytes()); @@ -45,8 +50,25 @@ static void storage_resize_npu( }; // It is necessary to properly refresh the storage according to sizes and strides, // not just new sizes. - StorageDescHelper::UpdateDesc( - torch_npu::NPUBridge::GetNpuStorageImpl(&storage)->npu_desc_, resize_shape, new_size); + int64_t new_data_numel = c10::multiply_integers(resize_shape); + int64_t new_shape_numel = c10::multiply_integers(new_size); + const c10::IntArrayRef &refresh_size = new_data_numel > new_shape_numel ? resize_shape : new_size; + + // 计算连续场景下size对应的stride值 + int64_t dim_ = static_cast(refresh_size.size()); + c10::SmallVector new_stride(dim_); + if (dim_ > 0) { + int64_t last_idx = dim_ - 1; + new_stride[last_idx] = 1; + for (auto i = last_idx - 1; i >= 0; --i) { + new_stride[i] = new_stride[i + 1] * std::max(refresh_size[i + 1], 1); + } + } + + storage_desc.base_sizes_ = refresh_size; + storage_desc.base_strides_ = new_stride; + storage_desc.npu_format_ = ACL_FORMAT_ND; + storage_desc.storage_sizes_ = storage_desc.base_sizes_; if (old_data != nullptr) { ptrdiff_t copy_size = old_size;