【深度学习】用Pytorch完成MNIST手写数字数据集的训练和测试

news2025/1/18 16:46:56

模型训练相关

思路:

  1. 导入数据集(对数据集转换为张量)
  2. 加载数据集(使数据集成为可以进行迭代)
  3. 搭建卷积模型
  4. 进行模型训练(每训练一轮查看一次在测试集上的准确率)
  5. 使用tensorboard进行可视化
  6. 保存训练后的模型
  7. 加载训练好的模型进行测试.

选择的模型结构
imagepng

训练模型和评估模型

在conda命令行输入
tensorboard --logdir=“tensorboard --logdir=D:\student\ai-study\02框架学习\logs\mnist”
打开可视化面板
imagepng
imagepng
在测试集上的准确率不断上升

训练和评估完整代码

"""
@author:Lunau
@file:022_mnist.py
@time:2024/08/07
@任务:使用pytorch对mnist数据集进行训练和测试
"""
import torch
import torchvision
import time
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

"""
@root:存放数据集的目录
@train:为True表示是作为训练集
@transforms:导入数据集的同时进行预处理
@download:为True表示从网络下载模型
"""
# 导入数据集 这里没有归一化
train_data = torchvision.datasets.MNIST("./dataset/MNIST", train=True, transform=
                                        torchvision.transforms.ToTensor(), download=True)
test_data = torchvision.datasets.MNIST("./dataset/MNIST", train=False, transform=
                                       torchvision.transforms.ToTensor(), download=True)


# # 加载数据集,方便进行迭代
train_dataloader = DataLoader(dataset=train_data, batch_size=64)
test_dataloader = DataLoader(dataset=test_data, batch_size=64)
# img, target = test_data[0]
# print(img)
# print(target)

# 构建卷积层
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5, padding=0, stride=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=10, out_channels=20, kernel_size=5, padding=0, stride=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Flatten(),
            nn.Linear(in_features=320, out_features=10)
        )

    def forward(self, x):
        x = self.model(x)
        return x

# 创建卷积模型
model = Model()

# 损失函数 交叉熵
loss_fn = nn.CrossEntropyLoss()

# 优化器
learning_rate = 0.01
optimizer = torch.optim.SGD(params=model.parameters(), lr=learning_rate)


# 可视化
writer = SummaryWriter("./logs/mnist")

# 训练网络的参数
total_train_step = 0  # 训练次数
total_test_step = 0  # 测试次数

# 训练
def train():
    # 训练步骤开始
    model.train()
    global total_train_step
    for data in train_dataloader:
        imgs, targets = data
        outputs = model(imgs)
        loss = loss_fn(outputs, targets)  # 计算当前损失
        # 优化器进行优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_train_step += 1
        if total_train_step % 100 ==0:
            writer.add_scalar("train_loss", loss, total_train_step) # 可视化每轮的损失
            print(f"训练次数:{total_train_step}, Loss:{loss}")

def test():
    global total_test_step

    model.eval()
    total_test_loss = 0
    total_accuracy = 0
    test_data_len = len(test_data)

    with torch.no_grad():
        for data in test_dataloader:
            imgs, targets = data
            output = model(imgs)
            loss = loss_fn(output, targets)
            total_test_loss +=loss
            accuracy = (output.argmax(1) == targets).sum().item() # 计算出正确的次数
            total_accuracy+=accuracy
        total_accuracy = total_accuracy / test_data_len
    # 在整体测试集上的损失
    print(f"整体测试损失Loss:{total_test_loss}")
    # 整体测试的正确率
    print(f"整体测试的正确率acc:{total_accuracy}")

    writer.add_scalar("test_loss", total_test_loss, total_test_step)
    writer.add_scalar("test_acc", total_accuracy, total_test_step)
    total_test_step +=1
    return total_accuracy



if __name__ == '__main__':

    epoch = 10  # 训练的轮数
    for i in range(epoch):
        print(f"第{i + 1}轮训练开始")
        train()
        test()
    torch.save(model, "mnist1.pth")

测试模型

测试的照片
mnist3jpgmnist5jpgmnist9jpg

测试代码

"""
@author:Lunau
@file:023_mnist_test.py
@time:2024/08/07
"""
import cv2
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from torch import nn

# 测试模型

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5, padding=0, stride=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=10, out_channels=20, kernel_size=5, padding=0, stride=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Flatten(),
            nn.Linear(in_features=320, out_features=10)
        )

    def forward(self, x):
        x = self.model(x)
        return x
# 测试单张照片
def test_one_image():
    image_path = "./images/mnist_3.jpg"
    image = Image.open(image_path)
    print(image)
    image = image.convert('L')
    image.show()

    # 图片进行转换
    transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize((28, 28)),
        torchvision.transforms.ToTensor()
    ])

    image = transform(image)
    print(image.shape)

    # 加载模型 若模型是在gpu训练出来,需要在cpu上运行需要进行一个映射
    model = torch.load("./mnist1.pth")
    image = torch.reshape(image, (1, 1, 28, 28))  # 转换一下尺寸,为输入要求的尺寸
    # 测试
    model.eval()
    with torch.no_grad():
        output = model(image)
    print(output)
    print(f"手写数字是:{output.argmax(1).item()}")

test_one_image()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1989968.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

MySQL3 DQL数据查询语言

DQL SQL-DQL重要地位简单查询selectjia简单查询数据准备别名(AS)消除重复行(DISTINCT去重)算数运算符0.优先级1.算数运算符2.比较运算符3.逻辑运算符4.位运算符 空值空值参与运算 条件查询普通条件查询特殊比较运算符BETWEEN...AND...INLIKEIS NULLleast,greatest运…

Unity补完计划 之 SpriteEditer SingleMode

本文仅作笔记学习和分享,不用做任何商业用途 本文包括但不限于unity官方手册,unity唐老狮等教程知识,如有不足还请斧正 因为unity不只是3d需要,还有2d游戏需要大量编辑处理图片素材,所以需要了解Sprite(精灵…

ASC格式的协议数据解析

函数来自RTT的AT组件 - at_client.c 例如,数据是 CGREG: 0,1,通过at_resp_parse_line_args_by_kw把1赋予link_stat。 简化从AT响应中提取信息的过程,使得编写与硬件通信的代码更加简洁和易于维护。 这么提数据也太方便了 at_resp_parse_l…

结构体练习作业

作业一:结构体数组存储学生信息(姓名,年龄,分数),完成输入学生信息,输出学生信息,求学生成绩之和,求最低学生成绩。 .h文件 main.c .c文件 输入信息 输出信息 平均值 最低值 作业二:在堆区,申…

STC-ISP升级MCU

STC-ISP升级mcu步骤: 1、RS232线连接电脑,芯片型号选择STC8H8K64U 2、波特率选择115200 3、IRC频率选择24MHz 4、设置EEPROM大小为64K 如下图设置: 插上RS232选择相应的COM口: 我这里的COM口是COM5. 打开程序文件&#xff1…

揭秘Redis的“隐藏武器”:跳跃表的原理与应用

1. 引言 1.1 Redis的快速崛起 Redis,全名为Remote Dictionary Server,是一个开源的高性能键值对存储系统,它提供了多种类型的数据结构,如字符串、列表、集合、有序集合等。由于其高性能、持久化选项以及丰富的特性,Re…

【已解决】如何获取到DF数据里最新的调薪时间,就是薪资最高且时间最早?

问题说明: 前几天在Python最强王者交流群【群除我佬】问了一个Pandas处理的问题,这里拿出来给大家分享下。 看上去不太好理解,其实说白了,就是在工资最高里,再找时间最早的。 换句话说就是,这三个人&…

益九未来CEO曾宪军:创新引领,打造智能售货机行业新标杆

在智能零售行业迅速发展的今天,益九未来(天津)科技发展有限公司正以其创新精神和前瞻性的战略布局,引领着智能售货机市场的潮流。而这一切的背后,离不开总经理(CEO)曾宪军先生的卓越领导和远见卓…

人类预期寿命数据-1960至2022年(世界各国与中国各省)

数据简介:人类预期寿命是指在特定年龄出生的人群,按照当前的死亡率水平,预期平均能够存活的年数。预期寿命衡量一个国家和地区卫生健康状况、社会经济发展水平和生活条件的重要参数,这次数据包含世界各国(1960-2022年&…

代理IP类型详细解析:那么多种协议的代理如何选?

代理IP已经成为跨境业务的得力工具,但是仍有许多新手小白在初次接触到代理IP服务商时,不知道具体如何选择代理IP类型,面对五花八门的代理类型名称,往往需要付出一定的试错成本才知道哪个适合自己的业务。今天就来给大家科普科普&a…

深度学习中的规范化-层规范化

文章目录 层规范化层规范化参数与公式normalized_shape传入一个整数接口函数LayerNorm计算手动计算 normalized_shape传入一个列表接口函数LayerNorm计算手动计算 层规范化 在批量规范化这篇文章里详细介绍了批量规范化在卷积神经网络里的使用,本篇文章将继续介绍另…

LVS中NAT模式和DR模式实战讲解

1DR模式 DR:Direct Routing,直接路由,LVS默认模式,应用最广泛,通过为请求报文重新封装一个MAC首部进行 转发,源MAC是DIP所在的接口的MAC,目标MAC是某挑选出的RS的RIP所在接口的MAC地址;源 IP/PORT&#xf…

C++:auto关键字、内联函数、引用、带默认形参值的函数、函数重载

一、auto关键字 在C中,auto关键字是一个类型说明符,用于自动类型推导。 使用 auto 关键字时,变量的类型 是在编译时由编译器 根据 初始化表达式 自动推导出来的。这意味着你 不能在 声明 auto 变量时 不进行初始化 声明 auto 变量时&#x…

“八股文”:是助力还是阻力?

在程序员面试中,“八股文”是一个绕不开的话题。所谓“八股文”,指的是那些在面试中经常出现的标准问题及其答案,例如“解释一下死锁的概念”、“CAP理论是什么”等。这些内容通常被求职者反复练习,以至于变成了某种固定的模式或套…

分享6类10种政务AI大模型应用场景

大模型是指具有大规模参数和复杂计算结构的机器学习模型。这些模型通常由深度神经网络构建而成,拥有数十亿甚至数千亿个参数。大模型的设计目的是提高模型的表达能力和预测性能,能够处理更加复杂的任务和数据。大模型在各种领域都有广泛的应用&#xff0…

adword — Recho | pwn题目记录

涉及到以前没接触过的点,记录下。 checksec: IDA: 很明显的一个栈溢出,但是一直有一个while循环,就算劫持控制流后也出不了这个循环。这里学到了一个新方法: pwntools的shutdown(send) def shutdown(…

C++现代教程四

float转string不带多余0 float a 1.2; std::tostring(a); // 1.200000 std::ostringstream strStream; strStream << a; // 1.2 if (!strStream.view().empty()) // 判定流有数据// 边框融合 float measureText(std::u8string text, FontTypes::Rectangle &recta…

Marin说PCB之1000-BASE-T1上的共模电感的选型知多少---02

今天刚刚好是立秋的第一天&#xff0c;天气还是有点炎热的。不知道诸位老铁们有没有买今年秋天的第一杯奶茶&#xff0c;反正小编我是下班到家吃饭的时候买了一杯伯牙绝弦&#xff0c;喝起来味道还是不错的&#xff0c;而且奶茶店里今天几乎爆满&#xff0c;我足足等了30分钟才…

计算机网络面试-核心概念-问题理解

目录 1.计算机网络OSI协议七层结构功能分别是什么&#xff1f;如何理解这些功能 2.物理层、数据链路层、网络层、传输层和应用层&#xff0c;这五个层之间功能的关系&#xff0c;或者说是否存在协调关系 3. 数据链路层功能理解 4.MAC地址和以太网协议 5.以太网协议中的CSMA…