diff --git a/mindspore-lite/tools/lite_exporter/fetch_content.cc b/mindspore-lite/tools/lite_exporter/fetch_content.cc index 17a0cadf34e810d402ba1e8108005ccabea27f15..9822d6dcc8a19711ac9a622e98d9e0e7f710b0ff 100644 --- a/mindspore-lite/tools/lite_exporter/fetch_content.cc +++ b/mindspore-lite/tools/lite_exporter/fetch_content.cc @@ -48,6 +48,16 @@ constexpr int kNumTransposePermSize = 4; constexpr size_t kTensorListMinSize = 3 * sizeof(int32_t); static const std::unordered_map TypeToTypeMap = { {kNumberTypeInt, kNumberTypeInt32}, {kNumberTypeUInt, kNumberTypeUInt32}, {kNumberTypeFloat, kNumberTypeFloat32}}; + +size_t GetTensorDataNBytes(const tensor::TensorPtr &tensor) { + MS_EXCEPTION_IF_NULL(tensor->device_address()); + if (tensor->device_address()->data() != nullptr) { + return static_cast(tensor->device_address()->data()->nbytes()); + } else { + return tensor->DataNBytes(); + } +} + STATUS GetShapeVectorFromStringTensor(const tensor::TensorPtr &tensor_info, ShapeVector *shape_vector, size_t *offset) { MS_ASSERT(tensor_info != nullptr && shape_vector != nullptr && offset != nullptr); auto data_type = tensor_info->data_type(); @@ -62,9 +72,7 @@ STATUS GetShapeVectorFromStringTensor(const tensor::TensorPtr &tensor_info, Shap std::string shape_size_str; *offset = 0; size_t cnt = 0; - MS_EXCEPTION_IF_NULL(tensor_info->device_address()); - MS_EXCEPTION_IF_NULL(tensor_info->device_address()->data()); - auto tensor_info_nbytes = static_cast(tensor_info->device_address()->data()->nbytes()); + auto tensor_info_nbytes = GetTensorDataNBytes(tensor_info); for (; *offset < tensor_info_nbytes; (*offset)++) { if (tensor_data[*offset] == ',') { (*offset)++; @@ -162,9 +170,7 @@ int FetchFromTensorValue(const ValueNodePtr &value_node, converter::FmkType fmk_ // process weight tensor if (copy_data) { - MS_EXCEPTION_IF_NULL(data->device_address()); - MS_EXCEPTION_IF_NULL(data->device_address()->data()); - auto data_nbytes = static_cast(data->device_address()->data()->nbytes()); + auto data_nbytes = GetTensorDataNBytes(data); data_info->data_.resize(data_nbytes); if (data_nbytes > 0 && memcpy_s(data_info->data_.data(), data_nbytes, data->data_c(), data_nbytes) != EOK) { MS_LOG(ERROR) << "memcpy_s error."; @@ -266,9 +272,7 @@ int SetTensorData(const tensor::TensorPtr &tensor_info, DataInfo *data_info, Typ bool copy_data) { MS_CHECK_TRUE_RET(data_info != nullptr, RET_NULL_PTR); MS_CHECK_TRUE_RET(tensor_info != nullptr, RET_NULL_PTR); - MS_EXCEPTION_IF_NULL(tensor_info->device_address()); - MS_EXCEPTION_IF_NULL(tensor_info->device_address()->data()); - auto tensor_info_nbytes = static_cast(tensor_info->device_address()->data()->nbytes()); + auto tensor_info_nbytes = GetTensorDataNBytes(tensor_info); if (data_type == kObjectTypeTensorType && tensor_info_nbytes >= kTensorListMinSize) { data_info->data_.resize(tensor_info_nbytes - offset); if (EOK != common::huge_memcpy(data_info->data_.data(), data_info->data_.size(), @@ -319,9 +323,7 @@ int FetchFromDefaultParam(const ParameterPtr ¶m_node, const converter::FmkTy std::vector dims(shape_vector.begin(), shape_vector.end()); data_info->shape_ = dims; if (tensor_info != nullptr) { - MS_EXCEPTION_IF_NULL(tensor_info->device_address()); - MS_EXCEPTION_IF_NULL(tensor_info->device_address()->data()); - auto tensor_info_nbytes = static_cast(tensor_info->device_address()->data()->nbytes()); + auto tensor_info_nbytes = GetTensorDataNBytes(tensor_info); if (tensor_info_nbytes != 0) { // tensor_list tensor status = SetTensorData(tensor_info, data_info, data_type, offset, copy_data); @@ -330,8 +332,7 @@ int FetchFromDefaultParam(const ParameterPtr ¶m_node, const converter::FmkTy return RET_ERROR; } } - } - if (tensor_info != nullptr) { + data_info->compress_type_ = tensor_info->compression_type(); data_info->quant_params_ = tensor_info->quant_params(); } @@ -458,9 +459,7 @@ int FetchDataFromCNode(const CNodePtr &cnode, size_t index, DataInfo *data_info) } auto tensor_value = tensor_info->cast(); MS_CHECK_TRUE_MSG(tensor_value != nullptr, RET_ERROR, "cast ptr failed"); - MS_EXCEPTION_IF_NULL(tensor_value->device_address()); - MS_EXCEPTION_IF_NULL(tensor_value->device_address()->data()); - auto tensor_value_nbytes = static_cast(tensor_value->device_address()->data()->nbytes()); + auto tensor_value_nbytes = GetTensorDataNBytes(tensor_value); if (tensor_value_nbytes >= kTensorListMinSize) { data_info->data_.resize(tensor_value_nbytes); if (memcpy_s(data_info->data_.data(), tensor_value_nbytes, tensor_value->data_c(), tensor_value_nbytes) != EOK) { @@ -525,9 +524,7 @@ int FetchDataFromAbstract(const AbstractBasePtr &abstract, DataInfo *data_info) } auto tensor_value = tensor_info->cast(); MS_CHECK_TRUE_MSG(tensor_value != nullptr, RET_ERROR, "cast ptr failed"); - MS_EXCEPTION_IF_NULL(tensor_value->device_address()); - MS_EXCEPTION_IF_NULL(tensor_value->device_address()->data()); - auto tensor_value_nbytes = static_cast(tensor_value->device_address()->data()->nbytes()); + auto tensor_value_nbytes = GetTensorDataNBytes(tensor_value); if (tensor_value_nbytes >= kTensorListMinSize) { data_info->data_.resize(tensor_value_nbytes); if (memcpy_s(data_info->data_.data(), tensor_value_nbytes, tensor_value->data_c(), tensor_value_nbytes) != EOK) {