pytorch搭建LeNet网络实现图像分类器

news2025/1/15 6:57:08

pytorch搭建LeNet网络实现图像分类器

  • 一、定义LeNet网络模型
    • 1,卷积 Conv2d
    • 2,池化 MaxPool2d
    • 3,Tensor的展平:view()
    • 4,全连接 Linear
    • 5,代码:定义 LeNet 网络模型
  • 二、训练并保存网络参数
    • 1,数据预处理
    • 2,数据集
    • 3,代码
  • 三、图像分类测试

一、定义LeNet网络模型

pytorch 中的卷积、池化、输入输出层中参数的含义与位置,可参考下图:
在这里插入图片描述

1,卷积 Conv2d

常用的卷积(Conv2d)在pytorch中对应的函数是

# in_channels:输入特征矩阵的深度。如输入一张RGB彩色图像,那in_channels=3
# out_channels:输入特征矩阵的深度。也等于卷积核的个数,使用n个卷积核输出的特征矩阵深度就是n
# kernel_size:卷积核的尺寸。可以是int类型,如3 代表卷积核的height=width=3,也可以是tuple类型如(3, 5)代表卷积核的height=3,width=5
# stride:卷积核的步长。默认为1,和kernel_size一样输入可以是int型,也可以是tuple类型
# padding:补零操作,默认为0。可以为int型如1即补一圈0,如果输入为tuple型如(2, 1) 代表在上下补2行,左右补1列。
torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros')

经卷积后的输出层尺寸计算公式为:
在这里插入图片描述
例如:当定义 Conv2d(3, 16, 5) 和 input(3, 32, 32),步长 S 为 1,P 为0时,此时卷积核尺度 F 为5,W 为32,计算得到 output(16, 28, 28)

2,池化 MaxPool2d

最大池化(MaxPool2d)在 pytorch 中对应的函数是:

MaxPool2d(kernel_size, stride)

3,Tensor的展平:view()

注意到,在经过第二个池化层后,数据还是一个三维的Tensor (32, 5, 5),需要先经过展平后(3255)再传到全连接层:

  x = self.pool2(x)            # 第二个池化层 output(32, 5, 5)
  x = x.view(-1, 32*5*5)       # 展平 output(32*5*5)
  x = F.relu(self.fc1(x))      # 传到全连接层 output(120)

4,全连接 Linear

全连接(Linear)在 pytorch 中对应的函数是:

Linear(in_features, out_features, bias=True)

5,代码:定义 LeNet 网络模型

model.py

# 定义LeNet网络模型
import torch.nn as nn
import torch.nn.functional as F

class LeNet(nn.Module):                  # 继承于nn.Module这个父类
    def __init__(self):                  # 初始化网络结构
        super(LeNet, self).__init__()    # 多继承需用到super函数
        self.conv1 = nn.Conv2d(3, 16, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, 5)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(32*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):            # 正向传播过程
        x = F.relu(self.conv1(x))    # input(3, 32, 32) output(16, 28, 28)
        x = self.pool1(x)            # output(16, 14, 14)
        x = F.relu(self.conv2(x))    # output(32, 10, 10)
        x = self.pool2(x)            # output(32, 5, 5)
        x = x.view(-1, 32*5*5)       # output(32*5*5)
        x = F.relu(self.fc1(x))      # output(120)
        x = F.relu(self.fc2(x))      # output(84)
        x = self.fc3(x)              # output(10)
        return x

二、训练并保存网络参数

1,数据预处理

ToTensor:把输入的图像数据为 shape (H x W x C) in the range [0, 255] 转化为 shape (C x H x W) in the range [0.0, 1.0],同时将 image 和 numpy 输入格式转化为 tensor
Normalize:标准化

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

2,数据集

用的是CIFAR10数据集,是 pytorch 自带的一个很经典的图像分类数据集,一共包含 10 个类别的 RGB 彩色图片。
在这里插入图片描述
注意:第一次运行程序,需要下载数据集到本地,所以第一次运行训练集下载时download=True为True,下载完成后改为False。测试集的加载则不用变化。

3,代码

名词定义
epoch对训练集的全部数据进行一次完整的训练,称为 一次 epoch
batch由于硬件算力有限,实际训练时将训练集分成多个批次训练,每批数据的大小为 batch_size
iteration 或 step对一个batch的数据训练的过程称为 一个 iteration 或 step
# 加载数据集并训练,训练集计算loss,测试集计算accuracy,保存训练好的网络参数
import torch
import torchvision
import torch.nn as nn
from model import LeNet 
import torch.optim as optim
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import time

# 数据预处理
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# 导入、加载训练集
# 导入50000张训练图片
train_set = torchvision.datasets.CIFAR10(root='./data',      # 数据集存放目录
                                        train=True,          # 表示是数据集中的训练集
                                        download=False,       # 第一次运行时为True,下载数据集,下载完成后改为False
                                        transform=transform) # 预处理过程
# 加载训练集,实际过程需要分批次(batch)训练                                        
train_loader = torch.utils.data.DataLoader(train_set,      # 导入的训练集
                                           batch_size=50,  # 每批训练的样本数
                                           shuffle=False,  # 是否打乱训练集
                                           num_workers=0)  # 使用线程数,在windows下设置为0

# 导入测试集
# 导入10000张测试图片
test_set = torchvision.datasets.CIFAR10(root='./data', 
                                        train=False,    # 表示是数据集中的测试集
                                        download=False,transform=transform)
# 加载测试集
test_loader = torch.utils.data.DataLoader(test_set, 
                                          batch_size=10000, # 每批用于验证的样本数
                                          shuffle=False, num_workers=0)
# 获取测试集中的图像和标签,用于accuracy计算
test_data_iter = iter(test_loader)
test_image, test_label = test_data_iter.next()

#训练过程
net = LeNet()                                       # 定义训练的网络模型
loss_function = nn.CrossEntropyLoss()               # 定义损失函数为交叉熵损失函数 
optimizer = optim.Adam(net.parameters(), lr=0.001)  # 定义优化器(训练参数,学习率)

for epoch in range(5):  # 一个epoch即对整个训练集进行一次训练
    running_loss = 0.0	# 累加过程中的损失
    time_start = time.perf_counter()
    
    for step, data in enumerate(train_loader, start=0):   # enumerate遍历训练集,可以同时返回 data 和 step对应下标,start表示从0下标开始遍历
        inputs, labels = data 	# 获取训练集的图像和标签
        optimizer.zero_grad()   # 清除历史损失梯度
        
        # forward + backward + optimize
        outputs = net(inputs)  				  # 正向传播
        loss = loss_function(outputs, labels) # 计算损失
        loss.backward() 					  # 反向传播
        optimizer.step() 					  # 优化器更新参数

        # 打印耗时、损失、准确率等数据
        running_loss += loss.item()
        if step % 1000 == 999:    # print every 1000 mini-batches,每1000步打印一次
            with torch.no_grad(): # 在以下步骤中(验证过程中)不用计算每个节点的损失梯度,防止内存占用
                outputs = net(test_image) 				 # 测试集传入网络(test_batch_size=10000),output维度为[10000,10]
                predict_y = torch.max(outputs, dim=1)[1] # 以output中值最大位置对应的索引(标签)作为预测输出
                accuracy = (predict_y == test_label).sum().item() / test_label.size(0)
                
                print('[%d, %5d] train_loss: %.3f  test_accuracy: %.3f' %  # 打印epoch,step,loss,accuracy
                      (epoch + 1, step + 1, running_loss / 500, accuracy))
                
                print('%f s' % (time.perf_counter() - time_start))        # 打印耗时
                running_loss = 0.0

print('Finished Training')

# 保存训练得到的参数
save_path = './Lenet.pth'
torch.save(net.state_dict(), save_path)

三、图像分类测试

使用训练并保存好的网络参数,从数据集外找一张图像进行分类测试

# 导入包
import torch
import torchvision.transforms as transforms
from PIL import Image
from model import LeNet

# 数据预处理
transform = transforms.Compose(
    [transforms.Resize((32, 32)), # 首先需resize成跟训练集图像一样的大小
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])    # 数据标准化

# 导入要测试的图像
im = Image.open('./car.jpg').convert('RGB')    # 若图像为4通道,则用 convert('RGB') 转化为3通道,否则 transform 会报错
im = transform(im)  # [C, H, W]
im = torch.unsqueeze(im, dim=0)  # 对数据增加一个新维度,因为tensor的参数是[batch, channel, height, width] 

# 实例化网络,加载训练好的模型参数
net = LeNet()
net.load_state_dict(torch.load('Lenet.pth'))

# 预测
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
with torch.no_grad():
    outputs = net(im)
    predict = torch.max(outputs, dim=1)[1].data.numpy()    # 找出最大概率的下标
	predicts = torch.softmax(outputs , dim=1)    # 所有分类的预测概率
print(classes[int(predict)])
print(predicts)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/659057.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

GAMES101笔记 Lecture03 Transformation

目录 Transoformation(变换)Why stuty transformation(为什么要学习变换呢?)2D transformations(2D变换)Scale transformation(缩放变换)Reflection Matrix(反射矩阵)Shear Matrix(切变矩阵) Rotate transformation(旋转变换)Linear Transforms Matrices(线性变换 矩阵) Hom…

Java泛型详解,史上最全图文详解

泛型在java中有很重要的地位,无论是开源框架还是JDK源码都能看到它。 毫不夸张的说,泛型是通用设计上必不可少的元素,所以真正理解与正确使用泛型,是一门必修课。 一:泛型本质 Java 泛型(generics&#…

C#程序设计——Windows应用程序开发,1、初步掌握Windows应用程序的设计方法。2、掌握常用窗体控件的使用方法。

Windows应用程序开发 一、实验目的 初步掌握Windows应用程序的设计方法。掌握常用窗体控件的使用方法。 二、实验内容 1、设计一个Windows应用程序,创建一个用于添加学生个人基本信息的窗体,窗体下方法同时滚动信息“天行健&#xf…

前端学习-html基础

html学习与总结 一、基础认知 1.1.1 认识网页(了解) ➢ 问题1:网页由哪些部分组成? ✓ 文字、图片、音频、视频、超链接 ➢ 问题2:我们看到的网页背后本质是什么? ✓ 前端程序员写的代码 ➢ 问题3&a…

设计模式的几大原则

设计模式原则 前言一.单一职责原则1.1 定义1.2 例子1.3 总结 二.里氏替换原则2.1 定义1.2 例子1.3 总结 三.依赖倒置原则3.1 定义3.2例子3.3总结 四.接口隔离原则4.1 定义4.2 例子4.3 总结五.迪米特法则5.1 定义5.2 例子5.3 总结 六.开闭原则6.1 定义6.2 例子6.3 结论 前言 设…

MongoDB复制(副本)集实战及其原理分析-04

MongoDB复制集 复制集架构 在生产环境中,不建议使用单机版的MongoDB服务器。 原因如下: 单机版的MongoDB无法保证可靠性,一旦进程发生故障或是服务器宕机,业务 将直接不可用。 一旦服务器上的磁盘损坏,数据会直接丢…

UDS系列-31服务(Routine Control)

诊断协议那些事儿 诊断协议那些事儿专栏系列文章,本文介绍例程控制服务RoutineControl,该服务的目的是Client端使用Routine Control服务来执行定义的步骤序列并获取特定序列的相关结果。这个服务经常在EOL、Bootloader中使用,比如,检查刷写条件是否满足、擦除内存、覆盖正…

post接口请求测试,通俗易懂

目录 前言: GET方法和POST方法传递数据的异同 POST方法如何传递数据 接口测试软件简介 POST请求接口的测试 测试方法 3.保存接口测试用例,生成自动化测试套件 总结 前言: Post请求是HTTP中请求方法之一,用于向服务器提交…

AI 绘画(2):Ai模型训练,Embedding模型,实现“人物模型“自由

文章目录 文章回顾感谢人员题外话Ai绘画公约Ai模型训练硬件要求显存设置查看显存大小显存过小解决方法 视频教程前期准备SD配置设置SD设置配置SD训练配置pt生成训练集收集训练集要求截图软件推荐训练集版权声明一键重命名图片训练图片来源批量修改图片尺寸 开始训练导入训练集&…

MQTTX的使用

1.MQTT介绍 MQTT是一种常用的物联网协议。MQTT(Message Queuing Telemetry Transport)是一种轻量级的发布/订阅通信协议,用于在物联网(IoT)和机器对机器(M2M)通信中传输消息。 MQTT协议被设计用…

013.【排序算法】合并排序法

1. 合并排序法 合并排序法是针对已经排序好的两个或两个以上的数列,通过合并的方式,将其组合成一个大的且排序好的数列。首先是将无序的数列分成若干小份,分若干份的规则就是不断把每段长度除以2(对半分),…

Jmeter断言详细使用教程

目录 前言: 断言介绍与使用 响应断言 断言持续时间 XML断言 1、响应断言 2、JSON Assertion 3、Size Assertion(见图知意) 4、JSR223 Assertion JSR223 Assertion实例: 5、XPath Assertion 6、Compare Assertion 7、断言持续时间…

如何获得忠诚的铁粉

目录 1.选择热门主题 2.提供独特观点(原创精神) 3.写作风格(目录定位分点总结) 4.提供有价值的内容 5.总结: 📢导语:赢得铁粉(粉丝)的支持对于一个作者来说至关重要。…

前端Vue加载中页面动画弹跳动画loading

前端Vue加载中页面动画弹跳动画loading&#xff0c; 下载完整代码请访问uni-app插件市场址:https://ext.dcloud.net.cn/plugin?id13091 效果图如下&#xff1a; #### 使用方法 使用方法 <!-- ref:唯一ref top&#xff1a;距离中间顶部距离 --> <cc-loading ref&…

Postman大势已去,Apifox的时代已到来

目录 前言&#xff1a; 前情简介&#xff1a;亲身经历节选 Code: 403 “将我踢飞” 浓眉大眼的 Swagger 把我欺骗 工作提效的版本答案 为什么是Apifox 贴心为你 写在最后 前言&#xff1a; Apifox是一款基于web的API设计工具&#xff0c;提供了简洁明了的界面和丰富的…

Debezium系列之:Outbox Event Router

Debezium系列之&#xff1a;Outbox Event Router 一、认识Outbox Event Router二、使用发件箱模式进行可靠的微服务数据交换三、双写问题四、发件箱模式五、基于变更数据捕获的实现六、发件箱表七、发送事件到发件箱八、注册 Debezium 连接器九、主题路由十、Apache Kafka 中的…

交叉编译libcurl libosip libeXosip(包含openssl)

交叉编译libcurl ./configure --with-ssl/home/zx/zxapp/openssl-1.1.0l/output --without-zlib --enable-shared --enable-static --hostarm-linux-gnueabihf CCarm-linux-gnueabihf-gcc --prefix$PWD/build 交叉编译openssl ./config no-asm shared -fPIC --prefix/home/…

ColorUI 全网最全使用文档(建议收藏)

Color UI 我想大家都知晓吧&#xff0c;我就不过多阐述了&#xff0c;是 文晓港 大佬开发的一款适应于H5、微信小程序、安卓、ios、支付宝的高颜值&#xff0c;高度自定义的 Css 组件库.&#xff0c;属于出道即巅峰的史诗级大作&#xff0c;众所周知&#xff0c;万物皆可 Color…

【CEEMDAN-CNN-LSTM】完备集合经验模态分解-卷积神经长短时记忆神经网络研究(Python代码实现)

&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜密&#xff0c;逻辑清晰&#xff0c;为了方便读者。 ⛳️座右铭&a…

【雕爷学编程】Arduino动手做(115)---HB100多普勒雷达模块

37款传感器与执行器的提法&#xff0c;在网络上广泛流传&#xff0c;其实Arduino能够兼容的传感器模块肯定是不止这37种的。鉴于本人手头积累了一些传感器和执行器模块&#xff0c;依照实践出真知&#xff08;一定要动手做&#xff09;的理念&#xff0c;以学习和交流为目的&am…