Pytorch 实践手写数字识别深度学习网络 LeNet-5

news2024/10/6 12:30:23

Pytorch 实践手写数字识别深度学习网络 LeNet-5

文章目录

  • Pytorch 实践手写数字识别深度学习网络 LeNet-5
    • 认识 LeNet-5
    • 认识数据集
    • 处理数据集
      • 下载数据集
      • 读取数据
      • 定义Dataset的继承类
      • 把数据进行载入
      • 载入`dataloader`
    • 编写网络
    • 编写训练与测试代码
    • 实践结果展示
    • 完整代码

训练手写体识别任务是一个非常简单的学习任务,理论简单是简单,但是我相信很多人都和我一样,想体验一下一个学习任务的全流程,有原始数据,处理数据,编写网络,训练模型,测试模型,使用模型这个过程。今天我们就由我来带大家体验一下。

认识 LeNet-5

LeNet-5出自论文《Gradient-Based Learning Applied to Document Recognition》, 原本是一种用于手写体字符识别的非常高效的卷积神经网络,包含了深度学习的基本模块:卷 积层,池化层,全连接层。

在这里插入图片描述

  • INPUT(输入层) :输入28∗28的图片。
  • C1(卷积层):选取6个5∗5卷积核(不包含偏置),得到6个特征图,每个特征
  • 图的一个边为28−5+1=24。
  • S2(池化层):池化层是一个下采样层,输出12∗12∗6的特征图。
  • C3(卷积层):选取16个大小为5∗5卷积核,得到特征图大小为8∗8∗16。
  • S4(池化层):窗口大小为2∗2,输出4∗4∗16的特征图。
  • F5(全连接层):120个神经元。
  • F6(全连接层):84个神经元。
  • OUTPUT(输出层):10个神经元,10分类问题。

认识数据集

MNIST数据集来自美国国家标准与技术研究所,National Institute of Standards and Technology(NIST),数据集由来自250个不同人手写的数字构 成,其中50%是高中学生,50%来自人口普查局(the Census Bureau)的工 作人员。

训练集:60000,测试集:10000

MNIST数据集可在 http://yann.lecun.com/exdb/mnist/ 获取

大家如果想要的话可以联系我邮箱2837468248@qq.com,也可以直接发给你。

在这里插入图片描述

处理数据集

一般我们进行这个实践的时候,因为这个太经典了,很多的框架就直接集成了这个数据集,不用我们自己处理了,直接拿来用就好了。如下:

train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

其中的 torchvision 的包的 MNIST 数据集已经替我们解析好了一切。下载、解压、加载都做好了,效果如下:

在这里插入图片描述

虽然这样我们也能用,但是我们今天是来体验全流程的,我们要自己处理。

下载数据集

可以通过上面的数据集下载网址进行下载

在这里插入图片描述

下载好后是 gz 格式,你可以选择用程序对其进行解压,或者直接用解压软件解压。

解压后我是这样存放数据的

在这里插入图片描述

读取数据

一开始看到这个 ubyte 形式的数据我直接震惊,这啥数据呀,哪里有图片呢,真不清楚,后面了解到这是把数据进行压缩了。 具体解释请看文章

读取图片数据:

# 读取图像文件
def load_mnist_images(file_path):
    with open(file_path, 'rb') as f:
        magic, num_images, rows, cols = struct.unpack('>IIII', f.read(16))
        if magic != 2051:
            raise ValueError(f'Invalid magic number {magic} in file: {file_path}')
        images = np.frombuffer(f.read(), dtype=np.uint8).reshape(num_images, 1, rows, cols)
    return images

读取标签数据

# 读取标签文件
def load_mnist_labels(file_path):
    with open(file_path, 'rb') as f:
        magic, num_labels = struct.unpack('>II', f.read(8))
        if magic != 2049:
            raise ValueError(f'Invalid magic number {magic} in file: {file_path}')
        labels = np.frombuffer(f.read(), dtype=np.uint8)
    return labels

定义Dataset的继承类

原来直接用 torch vision 的数据的话,他会直接包装好,但是我们这里是要全部体验全流程,所以我们自己包装。

为什么要有 Dataset 类呢?

因为在进行深度学习训练的时候都是一批一批进行训练的,需要把数据载入到 Pytorch 提供的 dataloader 中去,方便 pytorch 后面对我们的数据方便进行操作。

我们自己定义一个 Dataset 类的话,一定要实现三个函数

__init__

__len__

__getitem__

这里我们的实现如下:

# 自定义 Dataset 类
class MNISTDataset(Dataset):
    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label

把数据进行载入

这里载入的时候我们有个小地方需要注意一下,对数据进行一个简单的处理,就是把数据的维度进行放小,原来图片的维度如下:

在这里插入图片描述

我们需要的数据维度是 (60000,28,28)的数据,所以要进行降维

train_images = train_images.squeeze(axis=1) # 把第一维度进行减掉

在这里插入图片描述

后面测试集的数据的话就是一样的处理。

载入dataloader

transform = transforms.Compose([
    transforms.ToTensor(), #把numpy数据转化为tensor
    transforms.Normalize((0.5,), (0.5,)) #对数据进行归一化处理
])
from torch.utils.data import DataLoader
train_dataset = MNISTDataset(train_images, train_labels,transform=transform)
test_dataset = MNISTDataset(test_images,test_labels,transform=transform)

train_loader = DataLoader(dataset=train_dataset, batch_size=64,shuffle=True)
test_loader = DataLoader(dataset=test_dataset,batch_size=64,shuffle=False)

编写网络

这个网络是非常简单的,直接定义一个类继承 torch.nn.Module进行实现就好了,代码如下:

# 网络定义
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import torch.optim as optim
import torchvision.transforms as transforms


class LeNet5(nn.Module):
    def __init__(self):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = torch.relu(self.conv1(x)) #卷积以后进行激活 
        x = torch.max_pool2d(x, kernel_size=2, stride=2) #最大池化,提取特征
        x = torch.relu(self.conv2(x))
        x = torch.max_pool2d(x, kernel_size=2, stride=2)
        x = x.view(-1, 16 * 5 * 5) #把数据进行展平方便全连接层的输入
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

编写训练与测试代码

训练的话就是一般深度学习的流程,直接调用 pytorch 的API进行解决了。

# 训练和测试函数
def train(model, device, train_loader, optimizer, criterion, epoch):
    model.train()
    best_model = model
    min_loss = 1
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if min_loss > loss.item():
            best_model, best_loss = model, loss.item()
            print("update")
        if batch_idx % 100 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')
    print("模型训练结束")
    print("保存最好 loss 模型,loss:",min_loss)
    torch.save(best_model.state_dict(),'best-lenet5.pth')

def test(model, device, test_loader, criterion):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.0f}%)\n')
# 训练和测试模型
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = LeNet5().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

num_epochs = 10
for epoch in range(1, num_epochs + 1):
    train(model, device, train_loader, optimizer, criterion, epoch)
    test(model, device, test_loader, criterion)

实践结果展示

载入保存的模型 pth 文件,然后手写一个 28 * 28 的数字图片进行处理识别效果如下

在这里插入图片描述

代码如下:

model = LeNet5()
model.load_state_dict(torch.load('best-lenet5.pth',map_location=torch.device('cpu')))
# 假设图像路径
image_path = '5.png'  # 替换为你的图像路径
from PIL import Image
# 使用 PIL 库打开图像
image = Image.open(image_path)

# 使用 torchvision.transforms 进行数据转换和归一化
transform = transforms.Compose([
    transforms.Resize((28, 28)),  # 调整大小到 28x28
    transforms.Grayscale(),
    transforms.ToTensor(),        # 转为 Tensor
    transforms.Normalize((0.5,), (0.5,))  # 归一化
])

# 应用转换
image_tensor = transform(image).unsqueeze(0)
predict_output = model(image_tensor)
pred_num = predict_output.argmax(dim=1,keepdim=True)
print(pred_num) # 数据要写满28*28的格子才能预测)

如果需要 pth 文件也可以联系我,不过这个训练很快,可以自己训练玩!

完整代码

# 网络定义
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import torch.optim as optim
import torchvision.transforms as transforms


class LeNet5(nn.Module):
    def __init__(self):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.max_pool2d(x, kernel_size=2, stride=2)
        x = torch.relu(self.conv2(x))
        x = torch.max_pool2d(x, kernel_size=2, stride=2)
        x = x.view(-1, 16 * 5 * 5)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x
# 自定义 Dataset 类
class MNISTDataset(Dataset):
    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label
import numpy as np
import struct


# 读取标签文件
def load_mnist_labels(file_path):
    with open(file_path, 'rb') as f:
        magic, num_labels = struct.unpack('>II', f.read(8))
        if magic != 2049:
            raise ValueError(f'Invalid magic number {magic} in file: {file_path}')
        labels = np.frombuffer(f.read(), dtype=np.uint8)
    return labels


# 读取图像文件
def load_mnist_images(file_path):
    with open(file_path, 'rb') as f:
        magic, num_images, rows, cols = struct.unpack('>IIII', f.read(16))
        if magic != 2051:
            raise ValueError(f'Invalid magic number {magic} in file: {file_path}')
        images = np.frombuffer(f.read(), dtype=np.uint8).reshape(num_images, 1, rows, cols)
    return images

# 获取到标签,图像数据
train_images = load_mnist_images('./data/train/train-images-idx3-ubyte')
train_labels = load_mnist_labels('./data/train/train-labels-idx1-ubyte')
print('train_images.shape', train_images.shape)
print('label.shape', train_labels.shape)
train_images = train_images.squeeze(axis=1)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
test_images = load_mnist_images('./data/test/t10k-images-idx3-ubyte')
test_labels = load_mnist_labels('./data/test/t10k-labels-idx1-ubyte')
test_images = test_images.squeeze(axis = 1)
print(test_images.shape)
print(test_labels.shape)
from torch.utils.data import DataLoader
train_dataset = MNISTDataset(train_images, train_labels,transform=transform)
test_dataset = MNISTDataset(test_images,test_labels,transform=transform)

train_loader = DataLoader(dataset=train_dataset, batch_size=64,shuffle=True)
test_loader = DataLoader(dataset=test_dataset,batch_size=64,shuffle=False)
# 训练和测试函数
def train(model, device, train_loader, optimizer, criterion, epoch):
    model.train()
    best_model = model
    min_loss = 100000.0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if min_loss > loss.item():
            best_model, best_loss = model, loss.item()
            print("update")
        if batch_idx % 100 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')
    print("模型训练结束")
    print("保存最好 loss 模型,loss:",min_loss)
    torch.save(best_model.state_dict(),'best-lenet5.pth')

def test(model, device, test_loader, criterion):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.0f}%)\n')

# 训练和测试模型
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = LeNet5().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

num_epochs = 10
for epoch in range(1, num_epochs + 1):
    train(model, device, train_loader, optimizer, criterion, epoch)
    test(model, device, test_loader, criterion)


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1903448.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

什么是反射?

什么是反射? 1、反射的基本概念2、 获取Class对象3、获取类的成员变量、方法和构造方法3.1 获取成员变量3.2 获取方法3.3 获取构造方法3.4 动态调用方法 4、反射的优缺点 💖The Begin💖点点关注,收藏不迷路💖 反射&…

Unity3D 转换微信小游戏指引 02

Unity3D 转换微信小游戏指引系列(第二期) 云开发 当小游戏打包后的首包占用内存比较大(大约是 14M 左右),首包资源加载方式就不能选择小游戏包内了。 这时就需要购买服务器,把首包放到服务器上&#xff…

Drools开源业务规则引擎(二)- Drools规则语言(DRL)

文章目录 1.DRL文件的组成:2.package3.import4.function5.query6.declare7.global8.rule8.1.规则属性8.2.LHS8.2.1.语法格式8.2.2.运算符优先级8.2.3.特殊的运算符1.matches, not matches2.contains, not contains3.memberOf, not memberOf4.in, notin5.soundslike6…

尚品汇-(十三)

&#xff08;1&#xff09;查询sku列表 在ManageService 中添加 /*** SKU分页列表* param pageParam* return*/ IPage<SkuInfo> getPage(Page<SkuInfo> pageParam);接口实现类 Override public IPage<SkuInfo> getPage(Page<SkuInfo> pageParam) {Qu…

STM32-01 推挽输出-点亮LED

本文以STM32中点亮LED为例&#xff0c;解读推挽输出的原理 推挽输出介绍 所谓的推挽输出&#xff0c;就是通过控制输出控制模块&#xff0c;打开或者关闭P-MOS或者N-MOS。 ─ 推挽模式下&#xff1a;输出寄存器上的’0’激活N-MOS&#xff0c;而输出寄存器上的’1’将激活P-M…

IDEA与通义灵码的智能编程之旅

1 概述 本文主要介绍在IDEA中如何安装和使用通义灵码来助力软件编程,从而提高编程效率,创造更大的个人同企业价值。 2 安装通义灵码 2.1 打开IDEA插件市场 点击IDEA的设置按钮,下拉选择Plugins,如下: 2.2 搜索通义灵码 在搜索框中输入“通义灵码”,如下: 2.3 安…

74HC165芯片验证

目录 0x01 74HC165芯片介绍0x02 编程实现 0x01 74HC165芯片介绍 74HC165的引脚定义如下&#xff0c;长这个样子 ABCDEFGH是它的八个输入引脚&#xff0c;例如你可以将它连接按键&#xff0c;让它来读取8个按键值。也可以将他级联其它的74165&#xff0c;无需增加单片机GPIO引…

Apache Seata Mac下的Seata Demo环境搭建

本文来自 Apache Seata官方文档&#xff0c;欢迎访问官网&#xff0c;查看更多深度文章。 本文来自 Apache Seata官方文档&#xff0c;欢迎访问官网&#xff0c;查看更多深度文章。 Mac下的Seata Demo环境搭建&#xff08;AT模式&#xff09; 前言 最近因为工作需要&#xf…

强化学习的数学原理:时序差分算法

概述 之前第五次课时学习的 蒙特卡洛 的方法是全课程当中第一次介绍的第一种 model-free 的方法&#xff0c;而本次课的 Temporal-Difference Learning 简称 TD learning &#xff08;时序差分算法&#xff09;就是第二种 model-free 的方法。而对于 蒙特卡洛方法其是一种 non…

DropNotch for Mac v1.0.1 在 Mac 刘海快速使用 AirDrop

应用介绍 DropNotch 是一款专为Mac设计的应用程序&#xff0c;可以将MacBook的凹口区域&#xff08;刘海&#xff09;转换为文件放置区。 功能特点 文件共享: 用户可以将文件拖放到MacBook的凹口区域&#xff0c;并通过AirDrop、邮件、消息等方式轻松共享。多显示器支持: 即…

【经验篇】Spring Data JPA开启批量更新时乐观锁失效问题

乐观锁机制 什么是乐观锁&#xff1f; 乐观锁的基本思想是&#xff0c;认为在大多数情况下&#xff0c;数据访问不会导致冲突。因此&#xff0c;乐观锁允许多个事务同时读取和修改相同的数据&#xff0c;而不进行显式的锁定。在提交事务之前&#xff0c;会检查是否有其他事务…

3.js - 裁剪场景(多个scence)

不给newScence添加background、environment时 给newScence添加background、environment时 源码 // ts-nocheck// 引入three.js import * as THREE from three// 导入轨道控制器 import { OrbitControls } from three/examples/jsm/controls/OrbitControls// 导入lil.gui impor…

leetcode每日一题-3033. 修改矩阵

题目描述&#xff1a; 解题思路&#xff1a;简单题目&#xff0c;思路非常直接。对列进行遍历&#xff0c;记录下最大值&#xff0c;然后再遍历一遍&#xff0c;把-1替换为最大值。需要注意的是进行列遍历和行遍历是不同的。 官方题解&#xff1a; class Solution { public:v…

工控人最爱的PLC触摸屏一体机,有多香

PLC触摸屏一体机是什么 PLC触摸屏一体机&#xff0c;听起来可能有点技术化&#xff0c;但简单来说&#xff0c;它就是一个集成了可编程逻辑控制器&#xff08;PLC&#xff09;和触摸屏的智能设备。这种设备不仅能够执行自动化控制任务&#xff0c;还能实时显示和操作设备状态&a…

作业训练二编程题3. 数的距离差

【问题描述】 给定一组正整数&#xff0c;其中最大值和最小值分别为Max和Min, 其中一个数x到Max和Min的距离差定义为&#xff1a; abs(abs(x-Max)-(x-Min)) 其中abs()为求一个数的绝对值 【输入形式】 包括两行&#xff0c;第一行一个数n&#xff0c;表示第二行有n个正整数…

如何在PD虚拟机中开启系统的嵌套虚拟化功能?pd虚拟机怎么用 Parallels Desktop 19 for Mac

PD虚拟机是一款可以在Mac电脑中运行Windows系统的应用软件。使用 Parallels Desktop for Mac 体验 macOS 和 Windows 的最优性能&#xff0c;解锁强大性能和无缝交互。 在ParallelsDesktop&#xff08;PD虚拟机&#xff09;中如何开启系统的嵌套虚拟化功能&#xff1f;下面我们…

新手教学系列——前后端分离API优化版

在之前的文章《Vue 前后端分离开发:懒人必备的API SDK》中,我介绍了通过Object对象自动生成API的方法。然而,之前的代码存在一些冗余之处。今天,我将分享一个改进版本,帮助你更高效地管理API。 改进版API SDK 首先,让我们来看一下改进后的代码: import request from …

华为OD机试 - 来自异国的客人(Java 2024 D卷 100分)

华为OD机试 2024D卷题库疯狂收录中&#xff0c;刷题点这里 专栏导读 本专栏收录于《华为OD机试&#xff08;JAVA&#xff09;真题&#xff08;D卷C卷A卷B卷&#xff09;》。 刷的越多&#xff0c;抽中的概率越大&#xff0c;每一题都有详细的答题思路、详细的代码注释、样例测…

自动控制:反馈控制

自动控制&#xff1a;反馈控制 反馈控制&#xff08;Feedback Control&#xff09;是一种在控制系统中通过测量输出信号&#xff0c;并将其与期望信号进行比较&#xff0c;产生误差信号&#xff0c;再根据误差信号调整输入来达到控制目标的方法。反馈控制是自动控制系统中最常…

C#——使用ini-parser第三方操作ini文件

使用ini-parser第三方操作ini文件 IniParser - 一个轻量级的.NET类库&#xff0c;用于读写INI文件。 安装 在NuGet程序包中下载IniParser第三方 使用IniParser第三方操作Ini文件 读取 // 初始化解析器var parser new FileIniDataParser();// 读取INI文件string iniFilePat…