pytorch,用lenet5识别cifar10数据集(训练+测试+单张图片识别)

news2024/9/23 1:22:06

目录

LeNet-5

LeNet-5 结构

CIFAR-10

pytorch实现

lenet模型

训练模型

1.导入数据

2.训练模型

3.测试模型

测试单张图片

代码

运行结果


LeNet-5

LeNet-5 是由 Yann LeCun 等人在 1998 年提出的一种经典卷积神经网络(CNN)模型,主要用于手写数字识别任务。它在 MNIST 数据集上表现出色,并且是深度学习历史上的一个重要里程碑。

LeNet-5 结构

LeNet-5 的结构包括以下几个层次:

  1. 输入层: 32x32 的灰度图像。
  2. 卷积层 C1: 包含 6 个 5x5 的滤波器,输出尺寸为 28x28x6。
  3. 池化层 S2: 平均池化层,输出尺寸为 14x14x6。
  4. 卷积层 C3: 包含 16 个 5x5 的滤波器,输出尺寸为 10x10x16。
  5. 池化层 S4: 平均池化层,输出尺寸为 5x5x16。
  6. 卷积层 C5: 包含 120 个 5x5 的滤波器,输出尺寸为 1x1x120。
  7. 全连接层 F6: 包含 84 个神经元。
  8. 输出层: 包含 10 个神经元,对应于 10 个类别。

CIFAR-10

CIFAR-10 是一个常用的图像分类数据集,包含 10 个类别的 60,000 张 32x32 彩色图像。每个类别有 6,000 张图像,其中 50,000 张用于训练,10,000 张用于测试。

1. 标注数据量训练集:50000张图像测试集:10000张图像

2. 标注类别数据集共有10个类别。具体分类见图1。

3. 可视化

pytorch实现

lenet模型

  • 平均池化(Average Pooling):对池化窗口内所有像素的值取平均,适合保留图像的背景信息。
  • 最大池化(Max Pooling):对池化窗口内的最大值进行选择,适合提取显著特征并具有降噪效果。

在实际应用中,最大池化更常用,因为它通常能更好地保留重要特征并提高模型的性能。

import torch.nn as nn
import torch.nn.functional as func


class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, kernel_size=5)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = func.relu(self.conv1(x))
        x = func.max_pool2d(x, 2)
        x = func.relu(self.conv2(x))
        x = func.max_pool2d(x, 2)
        x = x.view(x.size(0), -1)
        x = func.relu(self.fc1(x))
        x = func.relu(self.fc2(x))
        x = self.fc3(x)
        return x

训练模型

1.导入数据

导入训练数据和测试数据

    def load_data(self):
        #transforms.RandomHorizontalFlip() 是 pytorch 中用来进行随机水平翻转的函数。它将以一定概率(默认为0.5)对输入的图像进行水平翻转,并返回翻转后的图像。这可以用于数据增强,使模型能够更好地泛化。

        train_transform = transforms.Compose([transforms.RandomHorizontalFlip(), transforms.ToTensor()])
        test_transform = transforms.Compose([transforms.ToTensor()])

        train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
        self.train_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=self.train_batch_size, shuffle=True)
        
        # shuffle=True 表示在每次迭代时,数据集都会被重新打乱。这可以防止模型在训练过程中过度拟合训练数据,并提高模型的泛化能力。
        test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)
        self.test_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=self.test_batch_size, shuffle=False)

2.训练模型

    def train(self):
        print("train:")
        self.model.train()
        train_loss = 0
        train_correct = 0
        total = 0

        for batch_num, (data, target) in enumerate(self.train_loader):
            data, target = data.to(self.device), target.to(self.device)
            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            prediction = torch.max(output, 1)  # second param "1" represents the dimension to be reduced
            total += target.size(0)

            # train_correct incremented by one if predicted right
            train_correct += np.sum(prediction[1].cpu().numpy() == target.cpu().numpy())

            progress_bar(batch_num, len(self.train_loader), 'Loss: %.4f | Acc: %.3f%% (%d/%d)'
                         % (train_loss / (batch_num + 1), 100. * train_correct / total, train_correct, total))

        return train_loss, train_correct / total

3.测试模型

    def test(self):
        print("test:")
        self.model.eval()
        test_loss = 0
        test_correct = 0
        total = 0

        with torch.no_grad():
            for batch_num, (data, target) in enumerate(self.test_loader):
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)
                loss = self.criterion(output, target)
                test_loss += loss.item()
                prediction = torch.max(output, 1)
                total += target.size(0)
                test_correct += np.sum(prediction[1].cpu().numpy() == target.cpu().numpy())

                progress_bar(batch_num, len(self.test_loader), 'Loss: %.4f | Acc: %.3f%% (%d/%d)'
                             % (test_loss / (batch_num + 1), 100. * test_correct / total, test_correct, total))

        return test_loss, test_correct / total

测试单张图片

网上随便下载一个图片

然后使用图片编辑工具,把图片设置为32x32大小

通过导入模型,然后测试一下

代码

import torch
import cv2
import torch.nn.functional as F
#from model import Net  ##重要,虽然显示灰色(即在次代码中没用到),但若没有引入这个模型代码,加载模型时会找不到模型
from torch.autograd import Variable
from torchvision import datasets, transforms
import numpy as np
 
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = torch.load('lenet.pth')  # 加载模型
    model = model.to(device)
    model.eval()  # 把模型转为test模式
 
    img = cv2.imread("bird1.png")  # 读取要预测的图片
    trans = transforms.Compose(
        [
            transforms.ToTensor()
        ])
 
    img = trans(img)
    img = img.to(device)
    img = img.unsqueeze(0)  # 图片扩展多一维,因为输入到保存的模型中是4维的[batch_size,通道,长,宽],而普通图片只有三维,[通道,长,宽]
    # 扩展后,为[1,1,28,28]
    output = model(img)
    prob = F.softmax(output,dim=1) #prob是10个分类的概率
    print(prob)
    value, predicted = torch.max(output.data, 1)
    print(predicted.item())
    print(value)
    pred_class = classes[predicted.item()]
    print(pred_class)

运行结果

tensor([[1.8428e-01, 1.3935e-06, 7.8295e-01, 8.5042e-04, 3.0219e-06, 1.6916e-04,
         5.8798e-06, 3.1647e-02, 1.7037e-08, 8.9128e-05]], device='cuda:0',
       grad_fn=<SoftmaxBackward0>)
2
tensor([4.0915], device='cuda:0')
bird

从结果看,效果还不错。记录一下

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2034241.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

dos命令获取java进程的pid并停止 bat脚本 第二篇

最近要做一个java程序一键重启的功能,主要思路是用批处理命令先将java程序停止,然后重新启动,研究了一把dos命令, taskkill /?取得帮助, taskkill /FI是筛选器: 然后很happy的写好停止脚本如下: taskkill /f /fi "IMAGENAME eq javaw.exe"是不是这样就行了…

spring中使用到的设计模式有哪些

Spring 框架是一个高度模块化和灵活的框架&#xff0c;广泛使用了各种设计模式来实现其核心功能和架构。这些设计模式帮助 Spring 提供了高可配置性、可扩展性和可维护性。以下是 Spring 框架中使用到的一些关键设计模式&#xff1a;

linux 安装jdk步骤

建议用免安装版的&#xff0c;安装方法如下&#xff1a; 一、软件下载 查看系统多少位 getconf LONG_BIT 下载JDK&#xff08;下面分别是32位系统和64位系统下的版本&#xff09; # 32位 http://download.oracle.com/otn-pub/java/jdk/7u9-b05/jdk-7u9-linux-i586.tar.gz?…

来电、消息提醒延时很久,该如何解决

使用华为穿戴设备且同时使用三方安卓手机的朋友们&#xff0c;是否发现自己的华为手表经常接不到电话&#xff0c;接到消息提醒也是延时很久&#xff1f;不是手表有问题&#xff0c;而是因为三方安卓手机系统管控华为运动健康App&#xff0c;导致推动来电和消息有延迟。 若您使…

《实现 DevOps 平台(2) · GitLab CI/CD 交互》

&#x1f4e2; 大家好&#xff0c;我是 【战神刘玉栋】&#xff0c;有10多年的研发经验&#xff0c;致力于前后端技术栈的知识沉淀和传播。 &#x1f497; &#x1f33b; CSDN入驻不久&#xff0c;希望大家多多支持&#xff0c;后续会继续提升文章质量&#xff0c;绝不滥竽充数…

猫头虎 分享:Python库 Flask 的简介、安装、用法详解入门教程

&#x1f42f; 猫头虎 分享&#xff1a;Python库 Flask 的简介、安装、用法详解入门教程 这是猫头虎带给大家的一篇关于Flask框架的入门教程&#xff01;&#x1f389; 今天猫头虎要跟大家聊聊Python中的一个非常重要且流行的库——Flask。如果你正在寻找一个轻量级、易上手、…

基于CANopen的LabVIEW同步与PDO通信示例

该程序展示了在LabVIEW中使用CANopen协议实现同步消息&#xff08;SYNC&#xff09;和PDO&#xff08;过程数据对象&#xff09;通信的流程。 以下是程序各部分的详细解释&#xff1a; 接口创建 (Interface Create)&#xff1a; 创建一个CANopen接口&#xff0c;并设定通信的波…

git常见命令和常见问题解决

文章目录 常见命令问题问题1&#xff08;git push相关&#xff09;问题2&#xff08;git push相关&#xff09;问题3&#xff08;git push相关&#xff09;删除github的仓库github新创建本地仓库的操作…or create a new repository on the command line…or push an existing …

【burp + ddddocr 加载验证码识别插件对登录页面进行爆破】

1 插件下载 项目地址&#xff1a; https://github.com/c0ny1/captcha-killer https://github.com/f0ng/captcha-killer-modified 安装burp插件&#xff1a; 下载已编译好的jar文件 https://github.com/f0ng/captcha-killer-modified/releases 2 验证码识别平台使用 https://g…

三防平板满足多样化定制为工业领域打造硬件解决方案

在当今工业领域&#xff0c;数字化、智能化的发展趋势日益显著&#xff0c;对于高效、可靠且适应各种复杂环境的硬件设备需求不断增长。三防平板作为一种具有坚固耐用、防水防尘防摔特性的工业级设备&#xff0c;正以其出色的性能和多样化的定制能力&#xff0c;为不同行业的应…

从OFD文件提取数字证书过程详解

OFD 文件是“Open Fixed Document”的缩写&#xff0c;它是一种用于电子文档的开放标准格式。OFD 文件格式由中国国家标准化管理委员会&#xff08;SAC&#xff09;制定&#xff0c;目的是提供一种开放、稳定且兼容性强的电子文档格式。下面是 OFD 文件的一些主要特点&#xff…

火语言RPA--火语言自动化插件安装方法

使用自动化控制浏览器插件实现自动化操作&#xff1a; 自动安装插件步骤 ① 点击下图中星标位置图标按钮 ② 点击上一步骤图标按钮后&#xff0c;弹出如下图所示插件安装对话框 ③ 点击下图中星标位置图标安装按钮 ④ 如果Chrome浏览器正在运行&#xff0c;则会弹出如下图所示…

Python大数据分析——DBSCAN聚类模型(密度聚类)

Python大数据分析——DBSCAN聚类模型 介绍数学基础模型步骤函数密度聚类对比Kmeans聚类球形簇聚类情况非球形簇的情况 示例 介绍 Kmeans聚类存在两个致命缺点&#xff0c;一是聚类效果容易受到异常样本点的影响&#xff08;因为求的是均值&#xff0c;而异常值对于均值聚类非常…

移动UI:阅读类应用如何从设计上吸引读者?

阅读类应用的用户界面设计对于吸引读者和提升用户体验至关重要。 以下是一些设计上的建议&#xff0c;可以帮助阅读类应用吸引读者&#xff1a; 1. 清晰的内容布局&#xff1a; 确保内容排版清晰&#xff0c;字体大小适中&#xff0c;行间距和段落间距合适&#xff0c;让用户…

纯技巧,伦敦金投资入门阶段怎么学习K线?

投资者刚进入伦敦金市场&#xff0c;都需要学习一定的交易知识&#xff0c;否则没办法在市场中立足。而K线很可能是我们做伦敦金投资入门时第一个碰到的需要学习的理论。下面我们就来讨论一下&#xff0c;在投资入门阶段&#xff0c;我们怎么学习K线。其实作为伦敦金投资的工具…

vega ai创作平台官网基础教程-文生图功能使用

我们都知道vega ai创作平台是右脑科技公司发布的一款革新性的在线AI艺术创作工具&#xff0c;它凭借先进的人工智能技术&#xff0c;为艺术家们打开了一扇通往无限创作可能的大门。无论是将文字灵感转化为视觉艺术&#xff0c;还是通过融合多张图片来训练出独特的艺术风格&…

2024.8.12

2024.8.12 【梦最让我费解的地方在于&#xff0c;明明你看不清梦里人们的脸&#xff0c;却清晰地知道他们是谁。】 Monday 七月初九 序理论 最小链覆盖&最长反链长度 我们设定一个二元关系符R和一个集合A 我们设定<A,R>这样一个类群&#xff0c;那么对于任意 a i…

酒店行业如何利用XML进行营销短信

随着信息社会的到来&#xff0c;消费者获得会所的服务也从单纯的电话方式&#xff0c;逐渐转变为电话、互联网、传真&#xff0c;群发短信等多种媒体并行的方式。今天着重介绍下酒店行业如何利用短信平台进行营销。 群发短信业务对酒店起到的效率&#xff1a;根据新产品或服务向…

20240813 每日AI必读资讯

Flux生成网红博主因太逼真爆火&#xff01;有人用Claude写代码识破“AI美女” - Flux生成的情侣合照逼真程度达到恐怖级别&#xff0c;挑战人类视觉辨识能力。 - 网友发现Flux生成的照片几乎完美&#xff0c;但仍有细微瑕疵可供识别。 - 有人利用Flux等工具制作逼真的YouTub…

分班查询一键发布,老师们都在用

新学期的钟声即将敲响&#xff0c;校园里又将迎来一批充满好奇和期待的新生。对于老师们来说&#xff0c;这不仅仅是一个新起点&#xff0c;更是一项挑战——如何高效而准确地将新生的分班信息传达给每一位家长。传统的方法是通过私信逐一发送&#xff0c;这不仅耗时耗力&#…