MNIST手写字符分类

news2025/1/13 13:13:37

MNIST手写字符分类

文章目录

  • MNIST手写字符分类
    • 1 数据集
    • 2 模型构建
    • 3 训练
    • 4 模型保存
    • 5 推理
    • 6 模型导出
    • 7 导出模型测试

1 数据集

  MNIST手写字符集包括60000张用于训练的训练集图片和10000张用于测试的测试集图片,所有图片均归一化为28*28的灰度图像。其中字符区域为白色,非字符区域为黑色。
  该数据集可以直接通过pytorch dataset进行下载。

2 模型构建

  本次试验采用图像分类的方案进行手写字符分类。分类网络采用线性层->非线性函数堆叠的方式实现。本次试验设计了两个网络,分别是3层网络模型以及4层网络模型。

  模型定义如下:

import torch
from torch import nn
from torch.utils.data import DataLoader

class ZKNNNet_3Layer(nn.Module):
    def __init__(self):
        super(ZKNNNet_3Layer, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

class ZKNNNet_5Layer(nn.Module):
    def __init__(self):
        super(ZKNNNet_5Layer, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

3 训练

  模型训练过程主要包括确定数据集、构建dataloader、确定训练设备(device)、生成模型对象、定义优化器、定义目标函数、设定模型为训练模式、获取输入与标签、进行模型推理、计算损失函数、优化器梯度清零、损失反向传播、优化器步进等。

  训练代码如下:

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda, Compose
from ZKNNNet import ZKNNNet_3Layer, ZKNNNet_5Layer
import os

# Download training data from open datasets.
training_data = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# Download test data from open datasets.
test_data = datasets.MNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

# Get cpu or gpu device for training.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))

model = ZKNNNet_3Layer()
if os.path.exists("model/model_3layer.pth"):
    model.load_state_dict(torch.load("model/model_3layer.pth"))
model = model.to(device)
print(model)

# Optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# Loss function
loss_fn = nn.CrossEntropyLoss()

# Train
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

# Test
def test(dataloader, model):
    size = len(dataloader.dataset)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= size
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    return correct

epochs = 200
maxAcc = 0
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    currentAcc = test(test_dataloader, model)
    if maxAcc < currentAcc:
        maxAcc = currentAcc
        torch.save(model.state_dict(), "model/model_3layer.pth")
print("Done!")

  以上脚本训练的是3层网络,经过200个epoch的训练,模型精度可以达到96%左右。

  以下为训练5层网络模型的训练代码。

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda, Compose
from ZKNNNet import ZKNNNet_3Layer, ZKNNNet_5Layer
import os
# Download training data from open datasets.
training_data = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# Download test data from open datasets.
test_data = datasets.MNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

# Get cpu or gpu device for training.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))

model = ZKNNNet_5Layer()
if os.path.exists("model/model_5layer.pth"):
    model.load_state_dict(torch.load("model/model_5layer.pth"))
model = model.to(device)
print(model)

# Optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# Loss function
loss_fn = nn.CrossEntropyLoss()

# Train
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

# Test
def test(dataloader, model):
    size = len(dataloader.dataset)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= size
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    return correct

epochs = 200
maxAcc = 0
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    currentAcc = test(test_dataloader, model)
    if maxAcc < currentAcc:
        maxAcc = currentAcc
        torch.save(model.state_dict(), "model/model_5layer.pth")
print("Done!")

4 模型保存

  在pytorch中,模型保存通过torch.save即可完成模型的保存。通常使用的方式是torch.save(model.state_dice(),'save_model_path.pth')
  这里需要注意的时候,通常我们只需要保存在测试集上准确率最高的模型,此时,可以结合训练过程中的测试过程,在计算了在测试集上的准确率之后,根据准确率进行保存。

# Test
def test(dataloader, model):
    size = len(dataloader.dataset)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= size
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    return correct

epochs = 200
maxAcc = 0
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    currentAcc = test(test_dataloader, model)
    if maxAcc < currentAcc:
        maxAcc = currentAcc
        torch.save(model.state_dict(), "model/model_5layer.pth")

这部分代码显示了在test()测试过程中计算模型在测试集上的准确率并返回,在保存模型时,判断当前测试集准确率是否优于历史最优准确率,高于的话就保存模型。

5 推理

  在进行完模型训练并保存了最优模型之后,我们需要对模型进行推理测试。通常用于在模型训练完成后,使用少量数据进行模型推理结果可视化,方便排查模型性能。

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torchvision import datasets
from ZKNNNet import ZKNNNet_3Layer

import matplotlib.pyplot as plt

# Get cpu or gpu device for inference.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device for inference".format(device))

# Load the trained model
model = ZKNNNet_3Layer()
model.load_state_dict(torch.load("model/model_3layer.pth"))
model.to(device)
model.eval()

# Download test data from open datasets.
test_data = datasets.MNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

# Create data loader.
test_dataloader = DataLoader(test_data, batch_size=64)

# Perform inference
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_dataloader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        # Visualize the image and its predicted result
        for i in range(len(images)):
            image = images[i].cpu()
            label = labels[i].cpu()
            prediction = predicted[i].cpu()

            plt.imshow(image.squeeze(), cmap='gray')
            plt.title(f"Label: {label}, Predicted: {prediction}")
            plt.show()

    accuracy = 100 * correct / total
    print("Accuracy on test set: {:.2f}%".format(accuracy))

在这里插入图片描述

6 模型导出

  在模型训练完成,经过推理可视化,证明模型精度可用的时候,可以对模型进行导出。因为pytorch模型结构为pth结构,该模型文件不利于部署。通常可以将pth模型文件导出为onnx模型文件。onnx模型文件通常是各种推理框架支持的中间模型文件格式。

import torch
import torch.utils
import os
from ZKNNNet import ZKNNNet_3Layer,ZKNNNet_5Layer

device = "cpu"
print("Using {} device".format(device))
model_3Layer = ZKNNNet_3Layer()
if os.path.exists('./model/model_3layer.pth'):
    model_3Layer.load_state_dict(torch.load('./model/model_3layer.pth'))
model_3Layer = model_3Layer.to(device)

model_3Layer.eval()

# export pytorch model to onnx
torch.onnx.export(model_3Layer, torch.randn(1, 1, 28, 28), './model/model_3layer.onnx', verbose=True)

model_5Layer = ZKNNNet_5Layer()
if os.path.exists('./model/model_5layer.pth'):
    model_5Layer.load_state_dict(torch.load('./model/model_5layer.pth'))
model_5Layer = model_5Layer.to(device)
model_5Layer.eval()
torch.onnx.export(model_5Layer,torch.randn(1,1,28,28),'./model/model_5layer.onnx',verbose=True)

通过上述脚本,可以将之前训练好的3层模型和5层模型文件进行导出,生成对应的onnx模型文件。

7 导出模型测试

  在onnx模型文件生成之后,需要对onnx模型文件进行测试,判断整个模型文件导出过程、推理过程是否正确。

import onnxruntime as rt
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torchvision import datasets

import matplotlib.pyplot as plt

from PIL import Image

sess = rt.InferenceSession("model/model_3layer.onnx")
input_name = sess.get_inputs()[0].name
print(input_name)

image = Image.open('./data/test/2.png')
image_data = np.array(image)
image_data = image_data.astype(np.float32)/255.0
image_data = image_data[None,None,:,:]
print(image_data.shape)

outputs = sess.run(None,{input_name:image_data})
outputs = np.array(outputs).flatten()

prediction = np.argmax(outputs)
plt.imshow(image, cmap='gray')
plt.title(f"Predicted: {prediction}")
plt.show()

# Download test data from open datasets.
test_data = datasets.MNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

# Create data loader.
test_dataloader = DataLoader(test_data, batch_size=1)

with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_dataloader:
        images = images.numpy()
        labels = labels.numpy()
        outputs = sess.run(None,{input_name:images})[0]
        outputs = np.array(outputs).flatten()
        prediction = np.argmax(outputs)

        # Visualize the image and its predicted result
        for i in range(len(images)):
            image = images[i]
            label = labels[i]

            plt.imshow(image.squeeze(), cmap='gray')
            plt.title(f"Label: {label}, Predicted: {prediction}")
            plt.show()

以上推理onnx模型的过程中,分别选择了直接图片输入和来自dataloader输入两种方式。这里需要注意的是,在使用图片进行输入时,需要注意数据范围需要与训练时的dataloader输入方式时一致。在使用dataloader时,图像数据是归一化到0-1的范围内的,而使用PIL读取图片之后是uint8的范围0-255的数据,因此在推理之前需要先进行数据范围转化,将输入数据转换成范围0-1的float32类型的数据。否则可能导致推理结果错误。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1817217.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

LabVIEW水箱液位控制系统

介绍了如何使用LabVIEW软件和硬件工具开发水箱液位控制系统。系统集成了数据采集、实时控制和模拟仿真技术&#xff0c;展示了高精度和高可靠性的特点&#xff0c;适用于需要精细水位调节的工业应用。 项目背景 在制造和化工行业&#xff0c;液位控制是保证生产安全与效率的关…

第3章 Unity 3D着色器系统

3.1 从一个外观着色器程序谈起 新建名为basic_diffuse.shader的文件&#xff0c;被一个名为basic_diffuse.mat的材质文件所引用&#xff0c;而basic_diffuse.mat文件则被场景中名为Sphere的game object的MeshRenderer组件所使用。 basic_diffuse.shader代码文件的内容如下所示…

51.Python-web框架-Django开始第一个应用的增删改查

目录 1.概述 2.创建应用 创建app01 在settings.py里引用app01 3.定义模型 在app01\models.py里创建模型 数据库迁移 4.创建视图 引用头 部门列表视图 部门添加视图 部门编辑视图 部门删除视图 5.创建Template 在app01下创建目录templates 部门列表模板depart.ht…

java+vue3+el-tree实现树形结构操作

基于springboot vue3 elementPlus实现树形结构数据的添加、删除和页面展示 效果如下 代码如下&#xff0c;业务部分可以自行修改 java后台代码 import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; import com.daztk.mes.common.annotation.LogOperation…

高通Android 12 右边导航栏改成底部显示

最近同事说需要修改右边导航栏到底部&#xff0c;问怎么搞&#xff1f;然后看下源码尝试下。 1、Android 12修改代码路径 frameworks/base/services/core/java/com/android/server/wm/DisplayPolicy.java a/frameworks/base/services/core/java/com/android/server/wm/Display…

树莓派4B_OpenCv学习笔记6:OpenCv识别已知颜色_运用掩膜

今日继续学习树莓派4B 4G&#xff1a;&#xff08;Raspberry Pi&#xff0c;简称RPi或RasPi&#xff09; 本人所用树莓派4B 装载的系统与版本如下: 版本可用命令 (lsb_release -a) 查询: Opencv 版本是4.5.1&#xff1a; 学了这些OpenCv的理论性知识&#xff0c;不进行实践实在…

Roboflow 图片分类打标

今天准备找个图片标注工具&#xff0c;在网上搜了一下&#xff0c;看 Yolo 的视频中都是用 Roboflow 工具去尝试了一下&#xff0c;标注确实挺好用的&#xff0c;可以先用一些图片训练一个模型&#xff0c;随后用模型进行智能标注。我主要是做标注然后到处到本地进行模型的训练…

springboot的WebFlux 和Servlet

Spring Boot 中的 Servlet 定义&#xff1a; 在 Spring Boot 中&#xff0c;Servlet 应用程序通常基于 Spring MVC&#xff0c;它是一个基于 Servlet API 的 Web 框架。Spring MVC 提供了模型-视图-控制器&#xff08;MVC&#xff09;架构&#xff0c;用于构建 Web 应用程序。…

【Mac】增加 safari 体验的插件笔记

Safari 本身的功能不全面&#xff0c;探索积累了一点插件笔记&#xff0c;提升使用体验&#xff1b;但后面因为插件或会影响运行速度&#xff0c;就全部都禁止了。做个笔记记录一下。 Cascadea 相当于 stylus&#xff0c;可以自定义页面。测试过几个&#xff0c;只有几个可行。…

Java:爬虫htmlunit抓取a标签

如果对htmlunit还不了解的话可以参考Java&#xff1a;爬虫htmlunit-CSDN博客 了解了htmlunit之后&#xff0c;我们再来学习如何在页面中抓取我们想要的数据&#xff0c;我们在学习初期可以找一些结构比较清晰的网站来做测试爬取&#xff0c;首先我们随意找个网站如下&#xff…

【StableDiffusion】Embedding 底层原理,Prompt Embedding,嵌入向量

Embedding 是什么&#xff1f; Embedding 是将自然语言词汇&#xff0c;映射为 固定长度 的词向量 的技术 说到这里&#xff0c;需要介绍一下 One-Hot 编码 是什么。 One-Hot 编码 使用了众多 5000 长度的1维矩阵&#xff0c;每个矩阵代表一个词语。 这有坏处&#xff0c…

美国空军发布类ChatGPT产品—NIPRGPT

6月11日&#xff0c;美国空军研究实验室&#xff08;AFRL&#xff09;官网消息&#xff0c;空军部已经发布了一款生成式AI产品NIPRGPT。 据悉&#xff0c;NIPRGPT是一款类ChatGPT产品&#xff0c;可生成文本、代码、摘要等内容&#xff0c;主要为为飞行员、文职人员和承包商提…

Python 中浅拷贝(copy)和深拷贝(deepcopy)

1. 浅拷贝&#xff1a; 它创建一个新的对象&#xff0c;但对于原始对象内的子对象&#xff08;如列表中的嵌套列表&#xff09;&#xff0c;只是复制了引用。例如&#xff1a; import copy original_list [1, 2, 3] shallow_copied_list copy.copy(original_list) original_…

【PIXEL】2024年 Pixel 解除 4G限制

首先在谷歌商店下载 Shizuku 和 pixel IMS 两个app 然后打开shizuku &#xff0c;按照它的方法启动 推荐用adb 启动&#xff08; 电脑连手机 &#xff0c;使用Qtscrcpy最简洁&#xff09; 一条指令解决 shell sh /storage/emulated/0/Android/data/moe.shizuku.privileged.ap…

Chrome/Edge浏览器视频画中画可拉动进度条插件

目录 前言 一、Separate Window 忽略插件安装&#xff0c;直接使用 注意事项 插件缺点 1 .无置顶功能 2.保留原网页&#xff0c;但会刷新原网页 3.窗口不够美观 二、弹幕画中画播放器 三、失败的尝试 三、Potplayer播放器 总结 前言 平时看一些视频的时候&#xff…

ListView的使用

&#x1f4d6;ListView的使用 ✅1. 创建ListView✅2. 创建适配器Adapter✅3. 开始渲染数据 主要3步骤&#xff1a; 创建ListView 创建适配器Adapter&#xff0c;和Adapter对应的视图 开始渲染数据 效果图&#xff1a; ✅1. 创建ListView 例如现有DemoActivity页面&#xf…

C# WinForm —— 33 ContextMenuStrip介绍

1. 简介 右键某个控件/窗体时&#xff0c;弹出来的菜单&#xff0c;比如VS中右键窗体&#xff0c;弹出来的这个菜单&#xff1a; 和MenuStrip类似&#xff0c;ContextMenuStrip主菜单下面可以有子菜单&#xff0c;子菜单下面可以有下一级子菜单 2. 属性 和MenuStrip一样 …

国内服务器安装 Docker 服务和拉取 dockerhub 镜像

前提: 有一台海外的VPS,目的是安装dockerhub镜像.适用debian系统 1: 安装 docker-ce (国内服务器) sudo apt-get update sudo apt-get install ca-certificates curl sudo install -m 0755 -d /etc/apt/keyrings sudo curl -fsSL https://mirrors.aliyun.com/docker-ce/linux/…

如何免费用 Qwen2 辅助你翻译与数据分析?

对于学生用户来说&#xff0c;这可是个好消息。 开源 从前人们有一种刻板印象——大语言模型里好用的&#xff0c;基本上都是闭源模型。而前些日子&#xff0c;Meta推出了Llama3后&#xff0c;你可能已经从中感受到现在开源模型日益增长的威力。当时我也写了几篇文章来介绍这个…

【DevOps】Ubuntu基本使用教程

目录 引言 Ubuntu简介 安装Ubuntu 准备工作 创建启动盘 安装过程 桌面环境 基本操作 定制桌面 文件管理 文件操作 文件权限 软件管理 安装软件 更新软件 系统设置 用户账户 网络设置 电源管理 命令行操作 常用命令 管理权限 安全与维护 系统更新 备份…