计算机视觉-05-目标检测:LeNet的PyTorch复现(MNIST手写数据集篇)(包含数据和代码)

news2024/12/17 17:33:43

文章目录

    • 0. 数据下载
    • 1. 背景描述
    • 2. 预测目的
    • 3. 数据总览
    • 4. 数据预处理
      • 4.1 下载并加载数据,并做出一定的预先处理
      • 4.2 搭建 LeNet-5 神经网络结构,并定义前向传播的过程
      • 4.3 将定义好的网络结构搭载到 GPU/CPU,并定义优化器
      • 4.4 定义训练过程
      • 4.5 定义测试过程
      • 4.6 运行
      • 4.7 保存模型
      • 4.8 手写图片的测试

0. 数据下载

关注公众号:『AI学习星球
回复:LeNet的PyTorch复现 即可获取数据下载。
算法学习4对1辅导论文辅导核心期刊可以通过公众号CSDN滴滴我
在这里插入图片描述


1. 背景描述

MINIST是一个非常有名的手写体数字识别数据集,训练样本:共60000个,其中55000个用于训练,另外5000个用于检测;测试样本:共10000个。MINIST数据集每张图片是单通道的,大小为28 * 28。

2. 预测目的

使用LeNet-5网络结构创建MNIST手写数字识别分类器,

3. 数据总览

4. 数据预处理

4.1 下载并加载数据,并做出一定的预先处理

由于MNIST数据集的图片尺寸是28 * 28单通道的,而LeNet-5网络输入Input图片尺寸是32 * 32,因此使用transform.Resize将输入图片尺寸调整为32 * 32。

首先导入Pytorch的相关算法库

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import time
from matplotlib import pyplot as plt
pipline_train = transforms.Compose([
    #随机旋转图片
    transforms.RandomHorizontalFlip(),
    #将图片尺寸resize到32x32
    transforms.Resize((32,32)),
    #将图片转化为Tensor格式
    transforms.ToTensor(),
    #正则化(当模型出现过拟合的情况时,用来降低模型的复杂度)
    transforms.Normalize((0.1307,),(0.3081,))    
])
pipline_test = transforms.Compose([
    #将图片尺寸resize到32x32
    transforms.Resize((32,32)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,),(0.3081,))
])
#下载数据集
train_set = datasets.MNIST(root="./data", train=True, download=True, transform=pipline_train)
test_set = datasets.MNIST(root="./data", train=False, download=True, transform=pipline_test)
#加载数据集
trainloader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
testloader = torch.utils.data.DataLoader(test_set, batch_size=32, shuffle=False)

这里要解释一下 Pytorch MNIST 数据集标准化为什么是 transforms.Normalize((0.1307,), (0.3081,))?

标准化(Normalization)是神经网络对数据的一种经常性操作。标准化处理指的是:样本减去它的均值,再除以它的标准差,最终样本将呈现均值为 0 方差为 1 的数据分布。

神经网络模型偏爱标准化数据,原因是均值为0方差为1的数据在 sigmoid、tanh 经过激活函数后求导得到的导数很大,反之原始数据不仅分布不均(噪声大)而且数值通常都很大(本例中数值范围是 0~255),激活函数后求导得到的导数则接近与 0,这也被称为梯度消失。所以说,数据的标准化有利于加快神经网络的训练。

除此之外,还需要保持 train_set、val_set 和 test_set 标准化系数的一致性。标准化系数就是计算要用到的均值和标准差,在本例中是((0.1307,), (0.3081,)),均值是 0.1307,标准差是 0.3081,这些系数都是数据集提供方计算好的数据。不同数据集就有不同的标准化系数,例如([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])就是 ImageNet dataset 的标准化系数(RGB三个通道对应三组系数),当需要将 Imagenet 预训练的参数迁移到另一神经网络时,被迁移的神经网络就需要使用 Imagenet的系数,否则预训练不仅无法起到应有的作用甚至还会帮倒忙。

4.2 搭建 LeNet-5 神经网络结构,并定义前向传播的过程

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5) 
        self.relu = nn.ReLU()
        self.maxpool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.maxpool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = self.maxpool2(x)
        x = x.view(-1, 16*5*5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        output = F.log_softmax(x, dim=1)
        return output

4.3 将定义好的网络结构搭载到 GPU/CPU,并定义优化器

#创建模型,部署gpu
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = LeNet().to(device)
#定义优化器
optimizer = optim.Adam(model.parameters(), lr=0.001)

4.4 定义训练过程

def train_runner(model, device, trainloader, optimizer, epoch):
    #训练模型, 启用 BatchNormalization 和 Dropout, 将BatchNormalization和Dropout置为True
    model.train()
    total = 0
    correct =0.0

    #enumerate迭代已加载的数据集,同时获取数据和数据下标
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        #把模型部署到device上
        inputs, labels = inputs.to(device), labels.to(device)
        #初始化梯度
        optimizer.zero_grad()
        #保存训练结果
        outputs = model(inputs)
        #计算损失和
        #多分类情况通常使用cross_entropy(交叉熵损失函数), 而对于二分类问题, 通常使用sigmod
        loss = F.cross_entropy(outputs, labels)
        #获取最大概率的预测结果
        #dim=1表示返回每一行的最大值对应的列下标
        predict = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predict == labels).sum().item()
        #反向传播
        loss.backward()
        #更新参数
        optimizer.step()
        if i % 1000 == 0:
            #loss.item()表示当前loss的数值
            print("Train Epoch{} \t Loss: {:.6f}, accuracy: {:.6f}%".format(epoch, loss.item(), 100*(correct/total)))
            Loss.append(loss.item())
            Accuracy.append(correct/total)
    return loss.item(), correct/total

4.5 定义测试过程

def test_runner(model, device, testloader):
    #模型验证, 必须要写, 否则只要有输入数据, 即使不训练, 它也会改变权值
    #因为调用eval()将不启用 BatchNormalization 和 Dropout, BatchNormalization和Dropout置为False
    model.eval()
    #统计模型正确率, 设置初始值
    correct = 0.0
    test_loss = 0.0
    total = 0
    #torch.no_grad将不会计算梯度, 也不会进行反向传播
    with torch.no_grad():
        for data, label in testloader:
            data, label = data.to(device), label.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output, label).item()
            predict = output.argmax(dim=1)
            #计算正确数量
            total += label.size(0)
            correct += (predict == label).sum().item()
        #计算损失值
        print("test_avarage_loss: {:.6f}, accuracy: {:.6f}%".format(test_loss/total, 100*(correct/total)))

4.6 运行

LeNet-5 网络模型定义好,训练函数、验证函数也定义好了,就可以直接使用 MNIST 数据集进行训练了。

# 调用
epoch = 5
Loss = []
Accuracy = []
for epoch in range(1, epoch+1):
    print("start_time",time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())))
    loss, acc = train_runner(model, device, trainloader, optimizer, epoch)
    Loss.append(loss)
    Accuracy.append(acc)
    test_runner(model, device, testloader)
    print("end_time: ",time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())),'\n')

print('Finished Training')
plt.subplot(2,1,1)
plt.plot(Loss)
plt.title('Loss')
plt.show()
plt.subplot(2,1,2)
plt.plot(Accuracy)
plt.title('Accuracy')
plt.show()

经历 5 次 epoch 的 loss 和 accuracy 后,最终在 10000 张测试样本上,average_loss降到了 0.00228,accuracy 达到了 97.72%。可以说 LeNet-5 的效果非常好!

4.7 保存模型

print(model)
torch.save(model, './models/model-mnist.pth') #保存模型

LeNet-5 的模型会 print 出来,并将模型模型命令为 model-mnist.pth 保存在固定目录下。

LeNet(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (relu): ReLU()
  (maxpool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (maxpool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)

4.8 手写图片的测试

下面,我们将利用刚刚训练的 LeNet-5 模型进行手写数字图片的测试。

import cv2

if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = torch.load('./models/model-mnist.pth') #加载模型
    model = model.to(device)
    model.eval()    #把模型转为test模式

    #读取要预测的图片
    img = cv2.imread("./images/test_mnist.jpg")
    img=cv2.resize(img,dsize=(32,32),interpolation=cv2.INTER_NEAREST)
    plt.imshow(img,cmap="gray") # 显示图片
    plt.axis('off') # 不显示坐标轴
    plt.show()

    # 导入图片,图片扩展后为[1,1,32,32]
    trans = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])
    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)#图片转为灰度图,因为mnist数据集都是灰度图
    img = trans(img)
    img = img.to(device)
    img = img.unsqueeze(0)  #图片扩展多一维,因为输入到保存的模型中是4维的[batch_size,通道,长,宽],而普通图片只有三维,[通道,长,宽]

    # 预测 
    output = model(img)
    prob = F.softmax(output,dim=1) #prob是10个分类的概率
    print("概率:",prob)
    value, predicted = torch.max(output.data, 1)
    predict = output.argmax(dim=1)
    print("预测类别:",predict.item())

在这里插入图片描述

输出:
概率:tensor([[2.0888e-07, 1.1599e-07, 6.1852e-05, 1.5797e-04, 1.4975e-09, 9.9977e-01,1.9271e-06, 3.1589e-06, 1.2186e-07, 4.3405e-07]],grad_fn=)
预测类别:5模型预测结果正确!
以上就是 PyTorch 构建 LeNet-5 卷积神经网络并用它来识别 MNIST 数据集的例子。


关注公众号:『AI学习星球
回复:LeNet的PyTorch复现 即可获取数据下载。
算法学习4对1辅导论文辅导核心期刊可以通过公众号CSDN滴滴我
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1299102.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

机器学习算法(9)——集成技术(Bagging——随机森林分类器和回归)

一、说明 在这篇文章,我将向您解释集成技术和著名的集成技术之一,它属于装袋技术,称为随机森林分类器和回归。 集成技术是机器学习技术,它结合多个基本模块和模型来创建最佳预测模型。为了更好地理解这个定义,我们需要…

C语言进阶之路之结构体、枚举关卡篇

目录 一、学习目标 二、组合数据类型-结构体 结构体基本概念 结构体的声明: 小怪实战 结构体初始化 指定成员初始化的好处: 结构体成员引用 结构体指针与数组 关卡BOOS 三、结构体的尺寸 CPU字长 地址对齐 结构体的M值 可移植性 四、联合体…

ssm的健身房预约系统(有报告)。Javaee项目。ssm项目。

演示视频: ssm的健身房预约系统(有报告)。Javaee项目。ssm项目。 项目介绍: 采用M(model)V(view)C(controller)三层体系结构,通过Spring Spring…

极智一周 | AI 算力国产化、通义开源、Gemini、鸿蒙、蔚来 And so on

欢迎关注我的公众号 [极智视界],获取我的更多技术分享 大家好,我是极智视界,带来本周的 [极智一周],关键词:AI 算力国产化、通义开源、Gemini、鸿蒙、蔚来 And so on。 邀您加入我的知识星球「极智视界」,…

c-语言->数据在内存的存储

系列文章目录 文章目录 系列文章目录前言 前言 目的:学习整数在内存的储存,什么是大小端,浮点数的储存。 1. 整数在内存中的存储 在讲解操作符的时候,我们就讲过了下⾯的内容: 整数的2进制表⽰⽅法有三种&#xff0…

带有 RaspiCam 的 Raspberry Pi 监控和延时摄影摄像机

一、说明 一段时间以来,我一直想构建一个运动激活且具有延时功能的树莓派相机,但从未真正找到我喜欢的案例。我在thingiverse上找到了这个适合树莓派和相机的好案例。它是为特定的鱼眼相机设计的,但从模型来看,我拥有的廉价中国鱼…

STM32F1之CAN报文传输

目录 报文传输 1. 帧类型 1.1 数据帧 1.1.1 帧起始 1.1.2 仲裁场 1.1.3 控制场 1.1.4 数据场 1.1.5 CRC 场 1.1.6 应答场 1.1.7 帧结尾 1.2 远程帧 1.3 错误帧 1.4 过载帧 1.5 帧间空间(INTERFRAME SPACING) 2. 发送器/接收器的…

【动态规划】【广度优先】LeetCode2258:逃离火灾

作者推荐 本文涉及的基础知识点 二分查找算法合集 动态规划 二分查找 题目 给你一个下标从 0 开始大小为 m x n 的二维整数数组 grid ,它表示一个网格图。每个格子为下面 3 个值之一: 0 表示草地。 1 表示着火的格子。 2 表示一座墙,你跟…

【C语言】内联函数

一、内联函数 在C语言中,内联函数(Inline function)是一种代码优化技术,它的目的是减少函数调用的开销。内联函数通知编译器在每个函数调用的位置插入函数的实际代码,而不是进行传统的函数调用。这避免了调用函数时的…

球上进攻^^

欢迎来到程序小院 球上进攻 玩法&#xff1a;点击鼠标走动躲避滚动的球球&#xff0c;球球碰到即为游戏结束&#xff0c;看看你能坚持多久&#xff0c;快去玩吧^^。开始游戏https://www.ormcc.com/play/gameStart/214 html <div id"game" class"game" …

【基于Flask、MySQL和Echarts的热门游戏数据可视化平台设计与实现】

基于Flask、MySQL和Echarts的热门游戏数据可视化平台设计与实现 前言数据获取与清洗数据集数据获取数据清洗 数据分析与可视化数据分析功能可视化功能 创新点结语 前言 随着游戏产业的蓬勃发展&#xff0c;了解游戏销售数据对于游戏从业者和游戏爱好者都至关重要。为了更好地分…

(六)五种最新算法(SWO、COA、LSO、GRO、LO)求解无人机路径规划MATLAB

一、五种算法&#xff08;SWO、COA、LSO、GRO、LO&#xff09;简介 1、蜘蛛蜂优化算法SWO 蜘蛛蜂优化算法&#xff08;Spider wasp optimizer&#xff0c;SWO&#xff09;由Mohamed Abdel-Basset等人于2023年提出&#xff0c;该算法模型雌性蜘蛛蜂的狩猎、筑巢和交配行为&…

DataFrame的使用

查看数据类型及属性 # 查看df类型 type(df) # 查看df的shape属性&#xff0c;可以获取DataFrame的行数&#xff0c;列数 df.shape # 查看df的columns属性&#xff0c;获取DataFrame中的列名 df.columns # 查看df的dtypes属性&#xff0c;获取每一列的数据类型 df.dtypes df.i…

模型能力赋能搜索——零样本分类(Zero-Shot Classification)在搜索意图识别上的探索

什么是Zero-Shot Classification https://huggingface.co/tasks/zero-shot-classification hugging face上的零样本分类模型 facebook/bart-large-mnli https://huggingface.co/facebook/bart-large-mnli 当然这是一个英文模型&#xff0c;我们要去用一些多语言的模型。 可以在…

Android 样式小结

关于作者&#xff1a;CSDN内容合伙人、技术专家&#xff0c; 从零开始做日活千万级APP。 专注于分享各领域原创系列文章 &#xff0c;擅长java后端、移动开发、商业变现、人工智能等&#xff0c;希望大家多多支持。 目录 一、导读二、概览三、使用3.1 创建并应用样式3.2 创建并…

Azure Machine Learning - 使用 Azure OpenAI 服务生成图像

在浏览器/Python中使用 Azure OpenAI 生成图像&#xff0c;图像生成 API 根据文本提示创建图像。 关注TechLead&#xff0c;分享AI全维度知识。作者拥有10年互联网服务架构、AI产品研发经验、团队管理经验&#xff0c;同济本复旦硕&#xff0c;复旦机器人智能实验室成员&#x…

点击el-tree小三角后去除点击后的高亮背景样式,el-tree样式修改

<div class"videoTree" v-loading"loadingTree" element-loading-text"加载中..." element-loading-spinner"el-icon-loading" element-loading-background"rgba(0, 0, 0, 0.8)" > <el-tree :default-expand-all&q…

可视化监控云平台/智能监控平台EasyCVR国标设备开启音频没有声音是什么原因?

视频云存储/安防监控EasyCVR视频汇聚平台基于云边端智能协同&#xff0c;支持海量视频的轻量化接入与汇聚、转码与处理、全网智能分发、视频集中存储等。GB28181视频平台EasyCVR拓展性强&#xff0c;视频能力丰富&#xff0c;具体可实现视频监控直播、视频轮播、视频录像、云存…

Nacos源码解读09——配置中心配置信息创建修改怎么处理的

存储配置 从整体上Nacos服务端的配置存储分为三层&#xff1a; 内存&#xff1a;Nacos每个节点都在内存里缓存了配置&#xff0c;但是只包含配置的md5&#xff08;缓存配置文件太多了&#xff09;&#xff0c;所以内存级别的配置只能用于比较配置是否发生了变更&#xff0c;只用…

基于SSM实现的公文管理系统

一、技术架构 前端&#xff1a;jsp | jquery | bootstrap 后端&#xff1a;spring | springmvc | mybatis 环境&#xff1a;jdk1.8 | mysql | maven 二、代码及数据库 三、功能介绍 01. 登录页 02. 首页 03. 系统管理-角色管理 04. 系统管理-功能管理 05. 系统管理-用…