Table of Contents
- Environment
- PyTorch Serialization
- Libtorch Download
- VS Configuration
- Main Program
- Possible Issues
- References
Environment
- Windows 10
- VS2019
- OpenCV 4.7.0
- Libtorch 1.13
- PyTorch 1.12.1
PyTorch Serialization
import torch
from torchvision.models import resnet50

# Load a pretrained ResNet50, move it to the GPU and switch to eval mode
net = resnet50(pretrained=True)
net = net.cuda()
net.eval()

# Trace the model with a dummy input and save the TorchScript module
x = torch.ones(1, 3, 224, 224)
x = x.cuda()
traced_module = torch.jit.trace(net, x)
traced_module.save("resnet50_trace.pt")
Libtorch Download
libtorch-win-shared-with-deps-1.13.0+cu117.zip (shared via Baidu Netdisk)
Link: https://pan.baidu.com/s/1NY25r9HnHXhy4kzIn-jlnQ  Extraction code: v5sq
After extracting, add the lib directory to the system environment variables:
D:\envPath\libtorch_11.3\libtorch\lib
VS Configuration
Right-click the project and select Properties.
Set the Configuration to Release and the Platform to x64.
VC++ Directories - Include Directories:
D:\envPath\libtorch_11.3\libtorch\include
D:\envPath\libtorch_11.3\libtorch\include\torch\csrc\api\include
C:\opencv\build\include\opencv2
C:\opencv\build\include\
Library Directories:
C:\opencv\build\x64\vc16\lib
D:\envPath\libtorch_11.3\libtorch\lib
Linker - Input - Additional Dependencies:
asmjit.lib
c10.lib
c10_cuda.lib
caffe2_detectron_ops_gpu.lib
caffe2_module_test_dynamic.lib
caffe2_nvrtc.lib
clog.lib
cpuinfo.lib
dnnl.lib
fbgemm.lib
libprotobuf.lib
libprotobuf-lite.lib
libprotoc.lib
mkldnn.lib
torch.lib
torch_cpu.lib
torch_cuda.lib
opencv_world470.lib
Linker - Command Line, add under Additional Options:
/INCLUDE:?ignore_this_library_placeholder@@YAHXZ
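With the project configured, a minimal sketch like the one below (not from the original post, just a sanity check under the settings above) can be built first to confirm that the headers, libraries, and the CUDA device are all found:

#include <torch/torch.h>
#include <opencv2/opencv.hpp>
#include <iostream>

int main()
{
    // Print the OpenCV version and whether LibTorch can see a CUDA device;
    // if this prints 0 even with a GPU present, revisit the /INCLUDE option above.
    std::cout << "OpenCV: " << CV_VERSION << std::endl;
    std::cout << "CUDA available: " << torch::cuda::is_available() << std::endl;
    // A tiny tensor op to confirm torch.lib / torch_cpu.lib are linked correctly
    std::cout << torch::rand({ 2, 3 }) << std::endl;
    return 0;
}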
Main Program
#include <torch/script.h>
#include <torch/torch.h>
#include <opencv2/opencv.hpp>
#include <chrono>
#include <iostream>
#include <string>
#include <vector>

// Network input size (the image is resized to 512*512*3 before inference)
const int input_shape = 512;

int main(int argc, char** argv)
{
    //Pro_info();
    if (argc < 2) return -1;            // the test image path is required
    std::string image_path = argv[1];
    //std::string model_path = argv[2];

    // Pick the device type: CUDA if available, otherwise CPU
    torch::DeviceType deviceType = torch::kCPU;
    if (torch::cuda::is_available())
    {
        deviceType = torch::kCUDA;
        std::cout << "The cuda is available" << std::endl;
    }
    else
    {
        std::cout << "The cuda isn't available" << std::endl;
    }
    torch::Device device(deviceType);
    std::cout << deviceType << std::endl;

    // Load the serialized TorchScript model
    torch::jit::script::Module module;
    try
    {
        module = torch::jit::load("D:/traced_resnet_model.pt");
        printf("The model load success!\n");
    }
    catch (const std::exception& e)
    {
        std::cout << e.what() << std::endl;
        std::cerr << "error loading the model\n";
        return -1;                      // don't continue with an unloaded module
    }
    module.to(device);
    module.eval();

    double total_time = 0;
    cv::Mat result;
    for (int i = 0; i < 100; i++)
    {
        auto start = std::chrono::system_clock::now();
        cv::Mat img = cv::imread(image_path);
        if (img.empty())
        {
            std::cerr << "The image can't open!\n";
            return -1;
        }
        // Preprocessing: resize to input_shape x input_shape and convert BGR -> RGB
        cv::Mat input;
        cv::resize(img, img, cv::Size(input_shape, input_shape));
        cv::cvtColor(img, input, cv::COLOR_BGR2RGB);
        // from_blob: Mat -> Tensor, shape {batchsize, height, width, channels}
        torch::Tensor tensor_image = torch::from_blob(input.data, { 1, input.rows, input.cols, 3 }, torch::kByte);
        // shape -> {batchsize, channels, height, width}
        tensor_image = tensor_image.permute({ 0, 3, 1, 2 });
        tensor_image = tensor_image.toType(torch::kFloat);
        // Normalize to [0,1] and standardize with the ImageNet mean/std
        tensor_image = tensor_image.div(255);
        tensor_image[0][0] = tensor_image[0][0].sub(0.485).div(0.229);
        tensor_image[0][1] = tensor_image[0][1].sub(0.456).div(0.224);
        tensor_image[0][2] = tensor_image[0][2].sub(0.406).div(0.225);

        std::vector<torch::jit::IValue> inputs;
        inputs.push_back(tensor_image.to(device));

        // Forward pass with gradients disabled
        torch::NoGradGuard no_grad;
        auto output = module.forward(inputs).toTensor();
        // softmax over the class dimension, then argmax -> per-pixel class index
        output = torch::squeeze(torch::argmax(torch::softmax(output, 1), 1), 0);
        //std::cout << output.sizes() << std::endl;
        output = output.to(torch::kU8).to(torch::kCPU);
        // Tensor -> cv::Mat; clone so the Mat owns its data after the tensor goes out of scope
        result = cv::Mat(output.sizes()[0], output.sizes()[1], CV_8U, output.data_ptr()).clone();
        result = result * 255;

        auto end = std::chrono::system_clock::now();
        double spend_time = std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count();
        std::cout << spend_time << "ms" << std::endl;
        total_time = total_time + spend_time;
    }
    std::cout << "average time: " << total_time / 100 << "ms" << std::endl;

    // Show the segmentation result of the last run
    cv::imshow("result", result);
    cv::waitKey(0);
    return 0;
}
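To run the program, pass the path of a test image as the first command-line argument, e.g. (the exe name here is just an example) LibtorchDemo.exe D:\test.jpg. Note that the model path "D:/traced_resnet_model.pt" is hard-coded above and has to be changed to wherever your own traced .pt file is stored.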
Possible Issues
- Make sure the PyTorch serialization (tracing) step finishes without warnings; otherwise LibTorch may fail when loading the model later.
- The LibTorch version must not be lower than the PyTorch version used for export, otherwise a c10 error will occur.
- Different image segmentation methods may use different pre- and post-processing, which has to be re-implemented in C++ accordingly (see the sketch after this list).
- If the program fails at runtime because an entry point cannot be located in a dynamic link library, copy the DLL files from the libtorch lib directory next to the exe.
- If the GPU is not being used, add /INCLUDE:?ignore_this_library_placeholder@@YAHXZ under Linker - Command Line; the exact statement differs between LibTorch versions, see Terod's summary for details.
- If the 'torch/torch.h' header is reported as missing, check that "\libtorch\include\torch\csrc\api\include" has been added to the include directories.
- For other problems, wrap the code in the following try/catch block to print the cause of the error:
try
{
// your code goes here...
}
catch (std::exception &e)
{
std::cout << e.what() << std::endl;
}
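Tied to the pre/post-processing point above: the main program displays the mask by simply multiplying the class-index map by 255, which only makes sense for two classes. Below is a minimal sketch of a multi-class alternative, assuming the per-pixel class-index cv::Mat produced there; the helper name colorize_mask and the 3-color palette are placeholders to adapt to your own model:

#include <opencv2/opencv.hpp>

// Map each per-pixel class index to a color instead of multiplying by 255
// (class_index is the CV_8U mask produced after argmax in the main program)
cv::Mat colorize_mask(const cv::Mat& class_index)
{
    // Placeholder 3-class palette (BGR); extend to the number of classes of your model
    const cv::Vec3b palette[3] = { {0, 0, 0}, {0, 0, 255}, {0, 255, 0} };
    cv::Mat color(class_index.size(), CV_8UC3);
    for (int r = 0; r < class_index.rows; ++r)
        for (int c = 0; c < class_index.cols; ++c)
            color.at<cv::Vec3b>(r, c) = palette[class_index.at<uchar>(r, c) % 3];
    return color;
}

In the main program, calling this helper on the class-index mask would replace the result = result * 255 step before cv::imshow.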
References
- Classic image segmentation network: UNet with libtorch deployment and inference (with code)
- Converting a PyTorch model to Libtorch on Windows and loading it with C++
- Configuring CUDA acceleration for PyTorch (LibTorch) on Windows
If you run into other problems, feel free to reach out!