2.3K Star 8.1K Fork 4.3K

GVPMindSpore / mindspore

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
model_parser_registry_test.cc 2.56 KB
一键复制 编辑 原始数据 按行查看 历史
luoyuan 提交于 2022-03-23 15:49 . add core ops api and adapter new mindapi
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include "common/common_test.h"
#include "ut/tools/converter/registry/parser/model_parser_test.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "mindspore/core/ir/anf.h"
#include "mindapi/ir/func_graph.h"
using mindspore::converter::ConverterParameters;
using mindspore::converter::kFmkTypeCaffe;
namespace mindspore {
namespace {
// Unwraps an api::FuncGraph (the mindapi wrapper type) into the underlying
// mindspore::FuncGraph so the internal IR API (e.g. TopoSort) can be used on it.
// Returns nullptr when func_graph itself is null, or when its impl object is not
// a mindspore::FuncGraph (dynamic_pointer_cast fails).
// Taken by const reference: the caller keeps ownership, so copying the
// shared_ptr (an atomic refcount round-trip) is unnecessary.
FuncGraphPtr ConvertGraph(const api::FuncGraphPtr &func_graph) {
  if (func_graph == nullptr) {
    return nullptr;
  }
  const auto &impl = func_graph->impl();
  return std::dynamic_pointer_cast<FuncGraph>(impl);
}
} // namespace
// gtest fixture for the ModelParserRegistry tests. It declares no state, hooks,
// or special members of its own; the implicitly generated public default
// constructor is all TEST_F needs, and everything else comes from
// mindspore::CommonTest.
class ModelParserRegistryTest : public mindspore::CommonTest {};
// Checks the converter's model-parser registration path end to end:
//  1. the "add" and "proposal" node parsers are present in NodeParserTestRegistry;
//  2. after registering TestModelParserCreator for kFmkTypeCaffe, the registry
//     hands back a parser for that framework type;
//  3. parsing with default-constructed ConverterParameters yields a graph whose
//     CNodes, in topological order, are exactly AddFusion, AddFusion, Return.
TEST_F(ModelParserRegistryTest, TestRegistry) {
  auto node_parser_reg = NodeParserTestRegistry::GetInstance();
  auto add_parser = node_parser_reg->GetNodeParser("add");
  ASSERT_NE(add_parser, nullptr);
  auto proposal_parser = node_parser_reg->GetNodeParser("proposal");
  ASSERT_NE(proposal_parser, nullptr);

  REG_MODEL_PARSER(kFmkTypeCaffe,
                   TestModelParserCreator);  // register test model parser creator, which will overwrite existing.
  auto model_parser = registry::ModelParserRegistry::GetModelParser(kFmkTypeCaffe);
  ASSERT_NE(model_parser, nullptr);
  // NOTE(review): if GetModelParser returns an owning raw pointer produced by the
  // creator, it is never released in this test -- confirm the ownership contract
  // and consider wrapping it in a smart pointer.

  ConverterParameters converter_parameters;
  auto func_graph = model_parser->Parse(converter_parameters);
  ASSERT_NE(func_graph, nullptr);

  // ConvertGraph yields nullptr when the api wrapper's impl is not a FuncGraph;
  // fail the test cleanly here rather than dereferencing null below.
  auto graph = ConvertGraph(func_graph);
  ASSERT_NE(graph, nullptr);

  // Keep only the CNodes from the topological order; other node kinds are skipped.
  auto node_list = graph->TopoSort(graph->get_return());
  std::vector<AnfNodePtr> cnode_list;
  for (const auto &node : node_list) {
    if (node->isa<CNode>()) {
      cnode_list.push_back(node);
    }
  }

  // Unsigned expected value avoids a signed/unsigned comparison inside ASSERT_EQ.
  constexpr size_t kExpectedCNodeCount = 3;
  ASSERT_EQ(cnode_list.size(), kExpectedCNodeCount);
  ASSERT_TRUE(opt::CheckPrimitiveType(cnode_list[0], prim::kPrimAddFusion));
  ASSERT_TRUE(opt::CheckPrimitiveType(cnode_list[1], prim::kPrimAddFusion));
  ASSERT_TRUE(opt::CheckPrimitiveType(cnode_list[2], prim::kPrimReturn));
}
} // namespace mindspore
Python
1
https://gitee.com/mindspore/mindspore.git
git@gitee.com:mindspore/mindspore.git
mindspore
mindspore
mindspore
r1.7

搜索帮助