From 8ffb38381188e6524864a0be9f1acb066328cc18 Mon Sep 17 00:00:00 2001 From: XeonYZhang Date: Thu, 25 Sep 2025 09:45:28 +0800 Subject: [PATCH] discard kutacc tensor constructor args check on dataptr --- src/tensor/tensor.h | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/tensor/tensor.h b/src/tensor/tensor.h index 1335655..0f2f878 100644 --- a/src/tensor/tensor.h +++ b/src/tensor/tensor.h @@ -40,14 +40,17 @@ struct Tensor { Tensor(void *data_ptr, std::vector sizes, std::vector strides, int64_t dim, ::kutacc::DType dtype) : data(SimpleTensor(data_ptr, sizes, strides, dim, dtype)) { - KUTACC_CHECK(data_ptr != nullptr, "data_ptr is a nullptr"); KUTACC_CHECK(sizes.size() != 0, "invalid tensor size, size is empty"); KUTACC_CHECK(strides.size() != 0, "invalid tensor stride, stride is empty"); - KUTACC_CHECK(dim > 0, "invalid tensor dim, dim is:", dim); + KUTACC_CHECK(dim >= 0, "invalid tensor dim, dim is:", dim); for (uint32_t i = 0; i < sizes.size(); ++i) { - KUTACC_CHECK(sizes[i] > 0, "sizes[%u] is less than 0", i); + KUTACC_CHECK(sizes[i] >= 0, "sizes[", i, "] is less than 0"); } KUTACC_CHECK(dtype != kInt64 || dtype != kBF16, "invalid dtype which is not BF16 or INT64"); + for (uint32_t i = 0; i < strides.size(); ++i) { + KUTACC_CHECK(strides[i] >= 0, "strides[", i, "] is less than 0"); + } + KUTACC_CHECK(strides.size() == sizes.size(), "strides size dismatches tensor dim"); }; void* data_ptr() const; std::vector sizes() const; -- Gitee