CNN从搭建到部署实战(pytorch+libtorch)

news2025/1/11 12:52:57

模型搭建

下面的代码搭建了CNN的开山之作LeNet的网络结构。

import torch


class LeNet(torch.nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, 6, 5), # in_channels, out_channels, kernel_size
            torch.nn.Sigmoid(),
            torch.nn.MaxPool2d(2, 2), # kernel_size, stride
            torch.nn.Conv2d(6, 16, 5),
            torch.nn.Sigmoid(),
            torch.nn.MaxPool2d(2, 2)
        )
        self.fc = torch.nn.Sequential(
            torch.nn.Linear(16*4*4, 120),
            torch.nn.Sigmoid(),
            torch.nn.Linear(120, 84),
            torch.nn.Sigmoid(),
            torch.nn.Linear(84, 10)
        )

    def forward(self, img):
        feature = self.conv(img)
        flat = feature.view(img.shape[0], -1)
        output = self.fc(flat)
        return output
 
 
net = LeNet()   
print(net)
print('parameters:', sum(param.numel() for param in net.parameters()))

运行代码,输出结果:

LeNet(
  (conv): Sequential(
    (0): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
    (1): Sigmoid()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
    (4): Sigmoid()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc): Sequential(
    (0): Linear(in_features=256, out_features=120, bias=True)
    (1): Sigmoid()
    (2): Linear(in_features=120, out_features=84, bias=True)
    (3): Sigmoid()
    (4): Linear(in_features=84, out_features=10, bias=True)
  )
)
parameters: 44426

模型训练

编写训练代码如下:16-18行解析参数;20-21行加载网络;23-24行定义损失函数和优化器;26-36行定义数据集路径和数据变换,加载训练集和测试集(实际上应该是验证集);37-57行for循环中开始训练num_epochs轮次,计算训练集和测试集(验证集)上的精度,并保存权重。

import torch
import torchvision
import time
import argparse
from models.lenet import net


def parse_args():
    parser = argparse.ArgumentParser('training')
    parser.add_argument('--batch_size', default=128, type=int, help='batch size in training')
    parser.add_argument('--num_epochs', default=5, type=int, help='number of epoch in training')
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_args()
    batch_size = args.batch_size
    num_epochs = args.num_epochs
        
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = net.to(device)

    loss = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
          
    train_path = r'./Datasets/mnist_png/training'
    test_path = r'./Datasets/mnist_png/testing'
    transform_list = [torchvision.transforms.Grayscale(num_output_channels=1), torchvision.transforms.ToTensor()]
    transform = torchvision.transforms.Compose(transform_list)

    train_dataset = torchvision.datasets.ImageFolder(train_path, transform=transform)
    test_dataset = torchvision.datasets.ImageFolder(test_path, transform=transform)

    train_iter = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    test_iter = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

    for epoch in range(num_epochs):
            train_l, train_acc, test_acc, m, n, batch_count, start = 0.0, 0.0, 0.0, 0, 0, 0, time.time()
            for X, y in train_iter:
                X, y = X.to(device), y.to(device)
                y_hat = net(X)
                l = loss(y_hat, y)
                optimizer.zero_grad()
                l.backward()
                optimizer.step()
                train_l += l.cpu().item()
                train_acc += (y_hat.argmax(dim=1) == y).sum().cpu().item()
                m += y.shape[0]
                batch_count += 1
            with torch.no_grad():
                for X, y in test_iter:
                    net.eval() # 评估模式
                    test_acc += (net(X.to(device)).argmax(dim=1) == y.to(device)).float().sum().cpu().item()
                    net.train() # 改回训练模式
                    n += y.shape[0]
            print('epoch %d, loss %.6f, train acc %.3f, test acc %.3f, time %.1fs'% (epoch, train_l / batch_count, train_acc / m, test_acc / n, time.time() - start))
            torch.save(net, "checkpoint.pth")

该代码支持cpu和gpu训练,损失函数是CrossEntropyLoss,优化器是Adam,数据集用的是手写数字mnist数据集。训练的部分打印日志如下:

epoch 0, loss 1.486503, train acc 0.506, test acc 0.884, time 25.8s
epoch 1, loss 0.312726, train acc 0.914, test acc 0.938, time 33.3s
epoch 2, loss 0.185561, train acc 0.946, test acc 0.960, time 27.4s
epoch 3, loss 0.135757, train acc 0.960, test acc 0.968, time 24.9s
epoch 4, loss 0.108427, train acc 0.968, test acc 0.972, time 19.0s

模型测试

测试的代码非常简单,流程是加载网络和权重,然后读入数据进行变换再前向推理即可。

import cv2
import torch
from pathlib import Path
from models.lenet import net
import torchvision.transforms.functional


if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = net.to(device)
    net = torch.load('checkpoint.pth')
    net.eval()

    with torch.no_grad():
        imgs_path = Path(r"./Datasets/mnist_png/testing/0/").glob("*")
        for img_path in imgs_path:
            img = cv2.imread(str(img_path), 0)
            img_tensor = torchvision.transforms.functional.to_tensor(img)
            img_tensor = torch.unsqueeze(img_tensor, 0)
            print(net(img_tensor.to(device)).argmax(dim=1).item())

输出部分结果如下:

0
0
0
0
0
0
0
0

模型转换

下面的脚本提供了pytorch模型转换torchscript和onnx的功能。

import torch
from models.lenet import net


if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = net.to(device)
    net = torch.load('checkpoint.pth')

    x = torch.rand(1, 1, 28, 28)
    x = x.to(device)

    traced_script_module = torch.jit.trace(net, x)
    traced_script_module.save("checkpoint.pt")

    torch.onnx.export(net,
                    x,
                    "checkpoint.onnx",
                    opset_version = 11
                    )

onnx模型netron可视化结果如下:
在这里插入图片描述

模型部署

10.png图片
在这里插入图片描述

libtorch部署

#include <iostream>
#include <opencv2/opencv.hpp>
#include <torch/torch.h>
#include <torch/script.h> 


int main(int argc, char* argv[])
{
	std::string model = "checkpoint.pt";
	torch::jit::script::Module module = torch::jit::load(model);
	module.to(torch::kCUDA);

	cv::Mat image = cv::imread("10.png", 0);
	image.convertTo(image, CV_32F, 1.0 / 255);
	at::Tensor img_tensor = torch::from_blob(image.data, { 1, 1, image.rows, image.cols }, torch::kFloat32).to(torch::kCUDA);
	
	torch::Tensor result = module.forward({ img_tensor }).toTensor();
	std::cout << result << std::endl;
	std::cout << result.argmax(1) << std::endl;

	return 0;
}

输出结果:

 8.0872 -6.3622  0.0291 -1.7327 -4.0367  0.8192  0.8159 -3.2559 -1.8254 -2.2787
[ CUDAFloatType{1,10} ]
 0
[ CUDALongType{1} ]

opencv dnn部署

只用OpenCV也可以部署模型,其中的dnn模块可以解析onnx格式模型。

#include <iostream>
#include <opencv2/opencv.hpp>


int main(int argc, char* argv[])
{
	std::string model = "checkpoint.onnx";
	cv::dnn::Net net = cv::dnn::readNet(model);

	cv::Mat image = cv::imread("10.png", 0), blob;
	cv::dnn::blobFromImage(image, blob, 1. / 255., cv::Size(28, 28), cv::Scalar(), true, false);

	net.setInput(blob);
	std::vector<cv::Mat> output;
	net.forward(output, net.getUnconnectedOutLayersNames());

	std::vector<float> values;
	for (size_t i = 0; i < output[0].cols; i++)
	{
		values.push_back(output[0].at<float>(0, i));
	}
	std::cout << std::distance(values.begin(), std::max_element(values.begin(), values.end())) << std::endl;

	return 0;
}

输出结果:

0

本文的完整工程可见:https://github.com/taifyang/deep-learning-pytorch-demo

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/753781.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

mysql 第三章

目录 1.索引 2.事务 3.总结 1.索引 2.事务 3.总结 事务是一种机制&#xff0c;一个操作序列。包含了一组数据库操作命令。

静态数码管——FPGA

文章目录 前言一、数码管1、数码管简介2、共阴极数码管or共阳极数码管3、共阴极与共阳极的真值表 二、系统设计1、模块框图2、RTL视图 三、源码1、seg_led_static模块2、time_count模块3、top_seg_led_static(顶层文件) 四、效果五、总结六、参考资料 前言 环境&#xff1a; 1、…

大学生用一周时间给麦当劳做了个App(微信小程序版)

背景 有个大学生粉丝最近私信联系我&#xff0c;说基于我之前开源的多语言项目做了个仿麦当劳的项目&#xff0c;虽然只是个样子货&#xff0c;但是收获颇多&#xff0c;希望把自己写的代码开源出来供大家一起学习进度。这个小伙伴确实是非常积极上进&#xff0c;很多大学生&a…

MySQL数据库(三)

前言 聚合查询、分组查询、联合查询是数据库知识中最重要的一部分&#xff0c;是将表的行与行之间进行运算。 目录 前言 一、聚合查询 &#xff08;一&#xff09;聚合函数 1、count 2、sum 3、avg 4、max 5、min 二、分组查询 &#xff08;一&#xff09;group by …

Docker架构

目录 Docker总架构图Docker ClientDocker DaemonDocker ServerDocker EngineJob Docker RegistryGraphDriverGraphDriverNetworkDriverExecDriver LibcontainerDocker Container Docker可以帮助用户在容器内部快速自动化部署应用&#xff0c;并利用Linux内核特性命名空间&#…

微软将推出更多Edge特有功能,与Chrome展开竞争

微软在 2018 年宣布将推出基于 Chromium 构建的 Edge 浏览器&#xff0c;并于 2020 年 1 月推出了新版 Edge。如今时隔三年&#xff0c;根据统计 Edge 全平台的市场占有率仅为 4.23%&#xff0c;如果只考虑桌面端的话&#xff0c;Edge 的市场占有率则是 10.98%&#xff0c;这两…

系统设计蓝图 / 备忘单

开发一个强大、可扩展和高效的系统可能会令人望而却步。然而&#xff0c;了解关键概念和组件可以使这个过程更可管理。在本博客文章中&#xff0c;我们将探讨系统设计的关键概念和组件&#xff0c;如DNS、负载均衡、API网关等&#xff0c;以及一个简明的备忘单&#xff0c;可以…

inux运维面试题(二)之系统管理类面试题

Linux运维面试题&#xff08;二&#xff09;之系统管理类面试题 1.权限优化1.1 简述Linux权限划分原则文件基本权限默认权限特殊权限sudo授权文件系统属性权限 解答 2.备份策略2.1需要备份的内容备份策略备份频率备份存储位置 2.2网站服务器每天产生的日志数量较大&#xff0c;…

LLM - 读取 Lora 模型进行文本生成

目录 一.引言 二.Lora 模型文本生成 1.模型读取 1.1 AutoModelForCausalLM.from_pretrained 1.2 PeftModel.from_pretrained 2.文本生成 2.1 Tokenizer 2.2 model.generate 3.输出实践 三.总结 一.引言 前面介绍了使用 Baichuan7B 从样本生成到 Lora 模型微调和存储…

磁盘擦写次数计算

1.让机器能有外网 2,安装工具 sudo apt-get install smartmontools 3,输入查询命令 sudo smartctl -x /dev/sda |egrep Device Model|User Capacity|Sector Size|173|Logical Sectors Written|Percentage Used Endurance Indicator 4,计算擦写次数 计算方法&#xff1a;25…

hadoop启动无法启动datanode或者namenode

首先进入hadoop安装目录下例如&#xff1a; 再进入dfs目录下&#xff0c;没有出现在hdfs-site.xml配置的data或者name cd data/current vim VERSION 打开VERSION文件复制&#xff1a;clusterID下的内容 将内容复制到name中的VERSION。再进行重启Hadoop

常见的前端安全以及常规安全策略

1、CSRF&#xff1a;跨站请求伪造&#xff08;Cross-site request forgery&#xff09;&#xff1b; 原理&#xff1a; &#xff08;1&#xff09; 用户C打开浏览器&#xff0c;访问受信任网站A&#xff0c;输入用户名和密码请求登录网站A&#xff1b; &#xff08;2&#xff…

mapping文件目录生成修改

参考文章&#xff1a; gradle编译完成Copy mapping文件 - 简书 (jianshu.com) 第一步&#xff1a;在app的build.gradle中做如下配置&#xff1a; android {android.applicationVariants.all {variant ->def buildType variant.buildType.nametasks.all {def mappingDir …

米尔基于STM32MP135核心板,助力充电桩发展

随着电动车的普及和人们环保意识的增强&#xff0c;充电桩作为电动车充电设备的重要一环&#xff0c;充电桩行业正迅速发展&#xff0c;消费市场的大量应用也造就市场的需求量不断增长。因此&#xff0c;产品的功能、可靠性、安全性等要求也变得尤为重要&#xff0c;而采用传统…

Leetcode---353周赛

周赛题目 2769. 找出最大的可达成数字 2770. 达到末尾下标所需的最大跳跃次数 2771. 构造最长非递减子数组 2772. 使数组中的所有元素都等于零 一、找出最大的可达成数字 这题就是简单的不能在简单的简单题&#xff0c; 题目意思是&#xff1a;给你一个数num和操作数t&…

智能分析网关V2有抓拍告警但无法推送到EasyCVR,是什么原因?

我们在此前的文章中也介绍了关于智能分析网关V2接入EasyCVR平台的操作步骤&#xff0c;感兴趣的用户可以查看这篇文章&#xff1a;在EasyCVR新版本v3.3中&#xff0c;如何正确接入智能分析网关V2&#xff1f; 智能分析网关V2是基于边缘AI计算技术&#xff0c;可对前端摄像头采…

机器学习 day27(反向传播)

导数 函数在某点的导数为该点处的斜率&#xff0c;用height / width表示&#xff0c;可以看作当w增加ε&#xff0c;J(w,b)增加k倍的ε&#xff0c;则k为该点的导数 反向传播 tensorflow使用反向传播来自动计算神经网络模型中的导数

中国区域地面气象要素驱动数据集(1979-2018)

中国区域地面气象要素驱动数据集&#xff08;1979-2018&#xff09; 摘要 中国区域地面气象要素驱动数据集&#xff0c;包括近地面气温、近地面气压、近地面空气比湿、近地面全风速、地面向下短波辐射、地面向下长波辐射、地面降水率共7个要素。数据为NETCDF格式&#xff0c;时…

Redis_非关系型数据库

一、 Redis介绍 1.NoSQL 也叫Not Only SQL(不仅仅是SQL, 不用 sql语言操作的数据库), 一般指非关系型数据库 关系型数据库: 以数据库表为单位存储,表与表之间存在某种关系 非关系型数据库: 数据与数据之间没有关系, 数据就是以键值对的形式存储, 通过键获取到值 在互联网发展中…

cjson的内存泄漏案例

1、当我们使用下面这些创建json对象时&#xff0c;需要用cJSON_Delete();释放&#xff0c;&#xff08;当然&#xff0c;释放父JSON对象后&#xff0c;子JSON对象也会被释放&#xff09; 2、多次释放同一内存空间 在recv_write_property函数中的data&#xff0c;在Equipment_re…