78 Star 600 Fork 1.2K

Ascend/pytorch

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
libtorch_resnet.cpp 973 Bytes
一键复制 编辑 原始数据 按行查看 历史
#include <iostream>
#include <string>
#include <vector>
#include <torch/torch.h>
#include <torch/script.h>
#include <torch_npu/torch_npu.h>
int main(int argc, const char* argv[]) {
if (argc != 2) {
TORCH_CHECK(false, "Please input the model name!");
}
if (!argv[1]) {
TORCH_CHECK(false, "Got invalid model name!");
}
std::cout << "module: " << argv[1] << std::endl;
// init npu
torch_npu::init_npu("npu:0");
auto device = at::Device("npu:0");
// load model
torch::jit::script::Module module = torch::jit::load(argv[1]);
module.to(device);
// run model
torch::jit::setGraphExecutorOptimize(false);
std::vector<torch::jit::IValue> input_tensor;
input_tensor.push_back(torch::randn({1, 3, 244, 244}).to(device));
at::Tensor output = module.forward(input_tensor).toTensor();
std::cout << output.slice(1, 0, 5) << std::endl;
std::cout << "resnet_model run success!" << std::endl;
torch_npu::finalize_npu();
return 0;
}
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/ascend/pytorch.git
git@gitee.com:ascend/pytorch.git
ascend
pytorch
pytorch
master

搜索帮助