神经网络-MNIST数据集训练

news2024/12/26 21:10:58

文章目录

  • 一、MNIST数据集
    • 1.数据集概述
    • 2.数据集组成
    • 3.文件结构
    • 4.数据特点
  • 二、代码实现
    • 1.数据加载与预处理
    • 2. 模型定义
    • 3. 训练和测试函数
    • 4.训练和测试结果
  • 三、总结

一、MNIST数据集

MNIST数据集是深度学习和计算机视觉领域非常经典且基础的数据集,它包含了大量的手写数字图片,通常用于训练各种图像处理系统,也被广泛用于机器学习领域的训练和测试。

1.数据集概述

  • 来源:MNIST数据集由Yann LeCun等人于1994年创建,它是NIST(美国国家标准与技术研究所)数据集的一个子集。
  • 内容:数据集主要包含手写数字(0~9)的图片及其对应的标签。
  • 用途:作为深度学习和计算机视觉领域的入门级数据集,它适合初学者练习建立模型、训练和预测。

2.数据集组成

MNIST数据集总共包含两个子数据集:训练数据集和测试数据集。

训练数据集:

  • 包含了60,000张28x28像素的灰度图像。
  • 对应的标签文件包含了60,000个标签,每个标签对应一张图像中的手写数字。

测试数据集:

  • 包含了10,000张28x28像素的灰度图像。
  • 对应的标签文件包含了10,000个标签。

3.文件结构

MNIST数据集包含四个文件,分别是训练集图像、训练集标签、测试集图像和测试集标签。这些文件以gzip格式压缩,并且不是标准的图像格式,需要通过专门的编程方式读取。

  • 训练集图像:train-images-idx3-ubyte.gz
  • 训练集标签:train-labels-idx1-ubyte.gz)
  • 测试集图像:t10k-images-idx3-ubyte.gz
  • 测试集标签:t10k-labels-idx1-ubyte.gz

4.数据特点

  • 图像大小:每张图像的大小为28x28像素,是一个灰度图像,位深度为8(灰度值范围为0~255)。
  • 数据来源:手写数字来自250个不同的人。
  • 数据格式:图像数据以字节的形式存储在二进制文件中,标签文件则存储了每张图像对应的数字标签。

二、代码实现

1.数据加载与预处理

import torch
from torch import nn  # 导入神经网络模块
from torch.utils.data import DataLoader  # 数据包管理工具,打包数据
from torchvision import datasets  # 封装了很多与图像相关的模型,数据集
from torchvision.transforms import ToTensor

"""下载训练集数据(包含训练图片和标签)"""
training_data = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),  # 张量,图片是不能直接传入神经网络模型
)

"""下载测试集数据(包括训练图片和标签)"""
test_data = datasets.MNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)
train_dataloader = DataLoader(training_data, batch_size=64)  # 64张图片为一个包
test_dataloader = DataLoader(test_data, batch_size=64)
  • 下载数据集:使用torchvision.datasets.MNIST下载并加载MNIST数据集。数据集分为训练集和测试集,train=True为训练集数据,train=False为测试集数据。
  • 数据转换:数据通过transform=ToTensor()进行预处理,将图片转换为PyTorch张量(Tensor),并自动将像素值归一化到[0,1]区间。
  • 数据封装:使用DataLoader将数据集封装成批次(batch)形式,便于后续的训练和测试过程。

2. 模型定义

class NeuralNetwork(nn.Module):  # 通过调用类的形式来使用神经网络,神经网络的模型,nn.module
    def __init__(self):  # python基础关于类,self类自己本身
        super().__init__()  # 继承的父类初始化
        self.flatten = nn.Flatten()  # 展开,创建一个展开对象flatten
        self.hidden1 = nn.Linear(28 * 28, 128)  # 第1个参数:有多少个神经元传入进来,第2个参数:有多少个数据传出去前一层神经元的个数,当前本层神经元个数
        self.hidden2 = nn.Linear(128, 256)
        self.hidden3 = nn.Linear(256, 128)
        self.out = nn.Linear(128, 10)

    def forward(self, x):  # 前向传播,告诉它,数据的流向。
        x = self.flatten(x)  # 图像进行展开
        x = self.hidden1(x)
        x = torch.sigmoid(x) 
        x = self.hidden2(x)
        x = torch.sigmoid(x)
        x = self.hidden3(x)
        x = torch.sigmoid(x)
        x = self.out(x)
        return x


model = NeuralNetwork().to(device)  # 把刚刚创建的模型传入到gpu
print(model)

定义类:定义了一个名为NeuralNetwork的类,该类继承自nn.Module,用于构建神经网络模型。
模型结构:模型包含输入层,输出层,隐藏层,其中隐藏层使用了Sigmoid激活函数,最后输出10个类别的得分(对应0-9的数字)
打印模型结构:打印了模型的结构,有助于理解模型的架构。
在这里插入图片描述

3. 训练和测试函数

def train(dataloader, model, loss_fn, optimizer):
    model.train()
    batch_size_num = 1
    for X, y in dataloader:  # 其中batch为每一个数据的编号
        X, y = X.to(device), y.to(device)  # 把训练数据集和标签传入cpu或GPU
        pred = model.forward(X)  # .forward可以被省略,父类中已经对次功能进行了设置。自动初始化w权值
        loss = loss_fn(pred, y)  # 通过交叉熵损失函数计算损失值loss
        # Backpropaqation 进来-个bqtch的数据,计算一次梯度,更新一次网络
        optimizer.zero_grad()  # 梯度值清零
        loss.backward()  # 反向传播计算得到每个参数的梯度值w
        optimizer.step()  # 根据梯度更新网络w参数

        loss_value = loss.item()  # 从tensor数据中提取数据出来,tensor获取损失值
        if batch_size_num % 100 == 0:
            print(f"loss:{loss_value:>7f}  [number:{batch_size_num}]")
        batch_size_num += 1


def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()  # 测试,w就不能再更新。
    test_loss, correct = 0, 0
    with torch.no_grad():  # 一个上下文管理器,关闭梯度计算。
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model.forward(X)
            test_loss += loss_fn(pred, y).item()  # test loss是会自动累加每一个批次的损失值
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
            a = (pred.argmax(1) == y)  # dim=1表示每一行中的最大值对应的索引号,dim=0表示每一列中的最大值对应的索引号
            b = (pred.argmax(1) == y).type(torch.float)  # 把预测值Ture、False 转换为01
    test_loss /= num_batches  # 评判模型的好坏
    correct /= size  # 平均的准确率
    print(f"Test result:\n Accuracy:{(100 * correct)}%,Avg loss:{test_loss}")
  • train函数负责训练模型。它遍历训练数据集的每个批次,计算模型的预测、损失,并执行反向传播和参数更新。
  • test函数用于评估模型在测试集上的性能。它遍历测试数据集的每个批次,计算模型的预测和损失,但不进行反向传播或参数更新。
  • 在训练和测试过程中,都使用了torch.no_grad()上下文管理器来关闭梯度计算,这可以节省内存和计算资源。

4.训练和测试结果

loss_fn = nn.CrossEntropyLoss()

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)  # 创建一个优化器,S6D为随机梯度下降算法

epochs = 10
for t in range(epochs):
    print(f"Epoch {t + 1}\n-------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
print("Done!")
test(test_dataloader, model, loss_fn)

  • 使用torch.optim.Adam优化器来优化模型的参数,这里的学习率设置为0.01。
  • 定义了训练轮次(epochs),并在每个epoch中调用train函数来训练模型。
  • 最后,使用test函数来评估模型在测试集上的性能,并打印出准确率和平均损失。
    在这里插入图片描述

三、总结

本文为大家介绍了MNIST数据集的组成、文件结构与数据集特点,然后为大家提供了MNIST数据集训练的相关代码,通过对数据集进行处理,训练来得出准确率与损失率,为大家更好的展示。总之,MNIST数据集是深度学习和计算机视觉领域不可或缺的基础数据集之一,对于初学者来说是一个非常好的练手项目,同时也为相关领域的研究和实验提供了宝贵的数据资源。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2143689.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

链表的合并,结点逆置,顺序表的高效划分(数据结构作业02)

目录 链表的合并 链表的结点逆置 顺序表的高效划分 链表的合并 已知两个递增有序的单链表A和B,分别表示两个集合。试设计一个算法,用于求出A与B的交集,并存储在C链表中。例如 : La {2,4,6,8};…

闯关leetcode——28. Find the Index of the First Occurrence in a String

大纲 题目地址内容 解题代码地址 题目 地址 https://leetcode.com/problems/find-the-index-of-the-first-occurrence-in-a-string/description/ 内容 Given two strings needle and haystack, return the index of the first occurrence of needle in haystack, or -1 if …

冒泡排序的C++语言实现(不用std::sort)

自己写一个冒泡排序的代码。 void vSort(std::vector<int> & vec, bool bDescending) {//冒泡排序int iTail vec.size()-1;while(iTail > 0){for(int k 0; k < iTail; k){int f1 vec.at(k);int f2 vec.at(k1);if(f1 < f2){//默认是降序int iTmp vec.a…

为什么大公司不用pandas取代excel?

如果你熟练使用Excel的话&#xff0c;你就会发现有些pandas的功能&#xff0c;在Excel中也可以实现&#xff0c;而且对比下来&#xff0c;Excel操作更简单&#xff0c;从效率上跟pandas更无二致&#xff0c;这样Excel的优势就比较突出了&#xff0c;比如下面使用pandas和Excel分…

【实战篇】幻读是什么,幻读有什么问题?

背景 我们先使用一个小一点儿的表。建表和初始化语句如下&#xff1a; CREATE TABLE t (id int(11) NOT NULL,c int(11) DEFAULT NULL,d int(11) DEFAULT NULL,PRIMARY KEY (id),KEY c (c) ) ENGINEInnoDB; insert into t values(0,0,0),(5,5,5), (10,10,10),(15,15,15),(20,…

2010-2022 CSP-J/普及组T1-T4考点统计

T1考点统计 T2考点统计 T3考点统计 T4考点统计 总结

MOE论文汇总2

TASK-CUSTOMIZED MASKED AUTOENCODER VIA MIXTURE OF CLUSTER-CONDITIONAL Experts 这篇论文提出了一种新颖的自监督学习方法&#xff0c;名为“Mixture of Cluster-conditional Experts (MoCE)”&#xff0c;旨在解决传统Masked Autoencoder (MAE)在不同下游任务中可能遇到的负…

蓝桥杯-STM32G431RBT6(UART解析字符串sscanf和解决串口BUG)

一、C语言常识 printf和sprintf的主要区别在于它们的功能和用途&#xff1a; printf&#xff1a;主要用于将格式化的数据输出到标准输出&#xff08;如屏幕&#xff09;。sprintf&#xff1a;则是将格式化的数据存储到一个指定的字符串缓冲区中&#xff0c;而不是直接输出。 pr…

Docker实操:安装MySQL5.7详解(保姆级教程)

介绍 Docker 中文网址: https://www.dockerdocs.cn Docker Hub官方网址&#xff1a;https://hub.docker.com Docker Hub中MySQL介绍&#xff1a;https://hub.docker.com/_/mysql ​ 切换到“Tags”页面&#xff0c;复制指定的MySQL版本拉取命令&#xff0c;例如 &#xff1a…

LabVIEW提高开发效率技巧----使用LabVIEW工具

LabVIEW为开发者提供了多种工具和功能&#xff0c;不仅提高工作效率&#xff0c;还能确保项目的质量和可维护性。以下详细介绍几种关键工具&#xff0c;并结合实际案例说明它们的应用。 1. VI Analyzer&#xff1a;自动检查代码质量 VI Analyzer 是LabVIEW提供的一款强大的工…

架构师,被严重低估的角色!

在企业数字化转型与变革的壮阔浪潮中&#xff0c;企业架构&#xff08;Enterprise Architecture&#xff0c;EA&#xff09;作为一门高度复杂且跨学科的知识体系&#xff0c;无可争议地成为了驱动组织战略深化与技术创新的核心引擎。尽管市场上充斥着丰富的指导理论与参考资料&…

202409012在飞凌的OK3588-C的核心板上使用Rockchip原厂的Buildroot点MIPI屏【背光篇】

202409012在飞凌的OK3588-C的核心板上使用Rockchip原厂的Buildroot点MIPI屏【背光篇】 2024/9/12 10:44 缘起&#xff0c;拿到一块MIPI屏&#xff0c;需要使用飞凌的OK3588-C的核心板在Android12下点亮。 在飞凌的Linux R4下修改部分屏参之后即可直接点亮。 但是在飞凌的Andro…

Java笔记-MinIO Java SDK的使用

此博文内容为&#xff1a; 使用SDK创建bucket&#xff1b; 使用SDK上传文件&#xff1b; 使用SDK下载文件。 maven添加&#xff1a; <dependency><groupId>io.minio</groupId><artifactId>minio</artifactId><version>8.5.2</versi…

Linux使用Clash,clash-for-linux

文件下载 clash-for-linuxhttps://link.zhihu.com/?targethttps%3A//zywang.lanzn.com/ijE2a1m7h6mb&#xff08;百度和阿里云盘都不支持这个文件分享&#xff09;。 使用须知 - 此项目不提供任何订阅信息&#xff0c;请自行准备Clash订阅地址。 - 运行前请手动更改.env文件…

嵌入式开发—CAN通信协议详解与应用(中)

书接上回&#xff1a;嵌入式开发—CAN通信协议详解与应用&#xff08;上&#xff09; 文章目录 CAN通讯中的位时间和位同步位时间的构成采样点 位时间的计算公式时间量子&#xff08;Time Quantum, TQ&#xff09;位时间的阶段示意图位同步机制 CAN通信中的仲裁规则仲裁规则的…

03-Mac系统PyCharm主题设置

目录 1. 打开PyCharm窗口 2. Mac左上角点击PyCharm&#xff0c;点击Settings 3. 点击第一项Appearance& Behavior 4. 点击Appearance 5. 找到Theme进行设置 1. 打开PyCharm窗口 2. Mac左上角点击PyCharm&#xff0c;点击Settings 3. 点击第一项Appearance& Behavi…

【例题】lanqiao4425 咖啡馆订单系统

样例输入 3 2 2 1 3 1 2样例输出 3 2样例说明 输入的数组为&#xff1a;【3&#xff0c;1&#xff0c;2】 增量序列为&#xff1a;【2&#xff0c;1】 当增量 h2&#xff1a;对于每一个索引 i&#xff0c;我们会将数组元素 arr[i] 与 arr[i−h] 进行比较&#xff0c;并进行可…

Stable Diffusion绘画 | ControlNet应用-IP-Adapter:堪比 Midjourney 垫图

IP-Adapter 是腾讯AI实验室研发的控制器&#xff0c;属于 ControlNet 最强控制器前三之一。 如果想参照图片的风格&#xff0c;生成各种各样类似效果的图片&#xff0c;就可以用到 IP-Adapter。 在 ControlNet 单元中上传一张图片&#xff1a; 不输入任何提示词&#xff0c;出图…

MySQL数据库:掌握备份与恢复的艺术,确保数据安全无忧

作者简介&#xff1a;我是团团儿&#xff0c;是一名专注于云计算领域的专业创作者&#xff0c;感谢大家的关注 座右铭&#xff1a; 云端筑梦&#xff0c;数据为翼&#xff0c;探索无限可能&#xff0c;引领云计算新纪元 个人主页&#xff1a;团儿.-CSDN博客 目录 前言&#…

波导阵列天线学习笔记 馈电网络1 使用X型全公共波导馈网的毫米波大规模天线阵列的带宽提升

摘要&#xff1a; 全公共波导馈网的一次反射等效模型被研究用于提出一种毫米波大规模天线阵列带宽提升的新方法。理论分析显示由馈电网络拓扑造成的指定频率的多级小反射的同相叠加现象是影响大规模阵列的可实现带宽的重要因素&#xff0c;除了包含阵列的独立功分器和反射器的带…