diff --git a/src/tensor/tensor.h b/src/tensor/tensor.h index 13356550b59131d6c5faa2414fa174ed3b0023c4..0f2f878dbbca86941887595c1e4926e3561c9855 100644 --- a/src/tensor/tensor.h +++ b/src/tensor/tensor.h @@ -40,14 +40,17 @@ struct Tensor { Tensor(void *data_ptr, std::vector sizes, std::vector strides, int64_t dim, ::kutacc::DType dtype) : data(SimpleTensor(data_ptr, sizes, strides, dim, dtype)) { - KUTACC_CHECK(data_ptr != nullptr, "data_ptr is a nullptr"); KUTACC_CHECK(sizes.size() != 0, "invalid tensor size, size is empty"); KUTACC_CHECK(strides.size() != 0, "invalid tensor stride, stride is empty"); - KUTACC_CHECK(dim > 0, "invalid tensor dim, dim is:", dim); + KUTACC_CHECK(dim >= 0, "invalid tensor dim, dim is:", dim); for (uint32_t i = 0; i < sizes.size(); ++i) { - KUTACC_CHECK(sizes[i] > 0, "sizes[%u] is less than 0", i); + KUTACC_CHECK(sizes[i] >= 0, "sizes[", i, "] is less than 0"); } KUTACC_CHECK(dtype != kInt64 || dtype != kBF16, "invalid dtype which is not BF16 or INT64"); + for (uint32_t i = 0; i < strides.size(); ++i) { + KUTACC_CHECK(strides[i] >= 0, "strides[", i, "] is less than 0"); + } + KUTACC_CHECK(strides.size() == sizes.size(), "strides size dismatches tensor dim"); }; void* data_ptr() const; std::vector sizes() const;