代码拉取完成,页面将自动刷新
#include <iostream>
#include <string>
#include <vector>
#include <torch/torch.h>
#include <torch/script.h>
#include <torch_npu/torch_npu.h>
int main(int argc, const char* argv[]) {
if (argc != 2) {
TORCH_CHECK(false, "Please input the model name!");
}
if (!argv[1]) {
TORCH_CHECK(false, "Got invalid model name!");
}
std::cout << "module: " << argv[1] << std::endl;
// init npu
torch_npu::init_npu("npu:0");
auto device = at::Device("npu:0");
// load model
torch::jit::script::Module module = torch::jit::load(argv[1]);
module.to(device);
// run model
torch::jit::setGraphExecutorOptimize(false);
std::vector<torch::jit::IValue> input_tensor;
input_tensor.push_back(torch::randn({1, 3, 244, 244}).to(device));
at::Tensor output = module.forward(input_tensor).toTensor();
std::cout << output.slice(1, 0, 5) << std::endl;
std::cout << "resnet_model run success!" << std::endl;
torch_npu::finalize_npu();
return 0;
}
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。