diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/gather_nd_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/gather_nd_infer.c index 27512661bcbb2d914db848a0d19900bc6e480537..4c287cdc8d39345ce0a955b0bf25d080b4554721 100644 --- a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/gather_nd_infer.c +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/gather_nd_infer.c @@ -37,9 +37,7 @@ int GatherNdInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC } int in_rank = (int)(input->shape_size_); int indices_rank = (int)(indices->shape_size_); - for (int i = 0; i < indices_rank; i++) { - NNACL_CHECK_FALSE(indices->shape_[i] == 0, NNACL_ERR); - } + if (indices->shape_[indices_rank - 1] > in_rank) { return NNACL_OK; } diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/reduce.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/reduce.c index 4752357a3639307afd74def38ccce5286f3f85c9..79ef6324dbf0992fec3a10c704c46dde823398cd 100644 --- a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/reduce.c +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/reduce.c @@ -368,16 +368,54 @@ int ReduceResize(struct KernelBase *self) { return NNACL_OK; } +int HandleShapeContain0(struct KernelBase *self, bool *contain) { + int *in_shape = self->in_[FIRST_INPUT]->shape_; + size_t shape_size = self->in_[FIRST_INPUT]->shape_size_; + size_t out_shape_size = self->out_[OUTPUT_INDEX]->shape_size_; + size_t out_len = 1; + for (size_t i = 0; i < out_shape_size; ++i) { + out_len *= self->out_[OUTPUT_INDEX]->shape_[i]; + } + for (size_t i = 0; i < shape_size; ++i) { + if (in_shape[i] != Num0) { + continue; + } + *contain = true; + switch (self->out_[OUTPUT_INDEX]->data_type_) { + case kNumberTypeInt32: + for (size_t j = 0; j < out_len; ++j) { + *((int32_t *)(self->out_[OUTPUT_INDEX]->data_) + j) = 0; + } + break; + case kNumberTypeFloat32: + for (size_t j = 0; j < out_len; ++j) { + *((float *)(self->out_[OUTPUT_INDEX]->data_) + j) = 0.0f; + } + break; + default: + return NNACL_ERR; + } + return NNACL_OK; + } + *contain = false; + return NNACL_OK; +} + int ReduceCompute(struct KernelBase *self) { NNACL_CHECK_NULL_RETURN_ERR(self); ReduceStruct *reduce = (ReduceStruct *)self; NNACL_CHECK_FALSE(self->in_[FIRST_INPUT]->data_type_ != reduce->data_type_, NNACL_ERR); + bool in_shape_contain_0 = false; + int ret = HandleShapeContain0(self, &in_shape_contain_0); + if (in_shape_contain_0) { + return ret; + } if (reduce->only_copy_) { return CopyReduceyInputToOutput(reduce); } - int ret = MallocReduceTmpBuffer(reduce); + ret = MallocReduceTmpBuffer(reduce); if (ret != NNACL_OK) { FreeReduceTmpBuffer(reduce); return ret;