【Pytorch】16.使用ImageFolder加载自定义MNIST数据集训练手写数字识别网络(包含数据集下载)

news2024/7/30 7:29:40

数据集下载

MINST_PNG_Training在github的项目目录中的datasets中有MNIST的png格式数据集的压缩包

用于训练的神经网络模型

在这里插入图片描述

自定义数据集训练

在前文【Pytorch】13.搭建完整的CIFAR10模型我们已经知道了基本搭建神经网络的框架了,但是其中的数据集使用的torchvision中的CIFAR10官方数据集进行训练的

train_dataset = torchvision.datasets.CIFAR10('../datasets', train=True, download=True,
                                             transform=torchvision.transforms.ToTensor())
test_dataset = torchvision.datasets.CIFAR10('../datasets', train=False, download=True,
                                            transform=torchvision.transforms.ToTensor())

在这里插入图片描述

本文将用图片格式的数据集进行训练
在这里插入图片描述
我们通过

# Dataset CIFAR10
#     Number of datapoints: 60000
#     Root location: ../datasets
#     Split: Train
#     StandardTransform
# Transform: ToTensor()
print(train_dataset)

可以看到我们下载的数据集是这种格式的,所以我们的主要问题就是如何将自定义的数据集获取,并且转化为这种形式,剩下的步骤就和上文相同了

数据类型进行转化

我们的首要目的是,根据数据集的地址,分别将数据转化为train_datasettest_dataset
我们需要调用ImageFolder方法来进行操作

from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from torchvision.datasets import ImageFolder
from model import *

# 训练集地址
train_root = "../datasets/mnist_png/training"
# 测试集地址
test_root = '../datasets/mnist_png/testing'

# 进行数据的处理,定义数据转换
data_transform = transforms.Compose([transforms.Resize((28, 28)),
                                     transforms.Grayscale(),
                                     transforms.ToTensor()])


# 加载数据集
train_dataset = ImageFolder(train_root, transform=data_transform)
test_dataset = ImageFolder(test_root, transform=data_transform)

首先我们需要将数据进行处理,通过transforms.Compose获取对象data_transform
其中进行了三步操作

  • 将图片大小变为28*28像素便于输入网络模型
  • 将图片转化为灰度格式,因为手写数字识别不需要三通道的图片,只需要灰度图像就可以识别,而png格式的图片是四通道
  • 将图片转化为tensor数据类型

然后通过ImageFolder给出图片的地址与转化类型,就可以实现与我们在官方下载数据集相同的格式

# Dataset ImageFolder
#     Number of datapoints: 60000
#     Root location: ../datasets/mnist_png/training
#     StandardTransform
# Transform: Compose(
#                Resize(size=(28, 28), interpolation=bilinear, max_size=None, antialias=True)
#                ToTensor()
#            )
print(train_dataset)

其他与前文【Pytorch】13.搭建完整的CIFAR10模型基本相同

完整代码

网络模型

import torch
from torch import nn


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2, stride=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2, stride=2)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(3136, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.pool2(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.fc2(x)
        return x


if __name__ == "__main__":
    model = Net()
    input = torch.ones((1, 1, 28, 28))
    output = model(input)
    print(output.shape)

训练过程

from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from torchvision.datasets import ImageFolder
from model import *

# 训练集地址
train_root = "../datasets/mnist_png/training"
# 测试集地址
test_root = '../datasets/mnist_png/testing'

# 进行数据的处理,定义数据转换
data_transform = transforms.Compose([transforms.Resize((28, 28)),
                                     transforms.Grayscale(),
                                     transforms.ToTensor()])


# 加载数据集
train_dataset = ImageFolder(train_root, transform=data_transform)
test_dataset = ImageFolder(test_root, transform=data_transform)

# Dataset ImageFolder
#     Number of datapoints: 60000
#     Root location: ../datasets/mnist_png/training
#     StandardTransform
# Transform: Compose(
#                Resize(size=(28, 28), interpolation=bilinear, max_size=None, antialias=True)
#                ToTensor()
#            )
# print(train_dataset)

# print(train_dataset[0])


train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=True)


device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

model = Net().to(device)
loss_fn = nn.CrossEntropyLoss().to(device)
learning_rate = 0.01
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

epoch = 10

writer = SummaryWriter('../logs')
total_step = 0

for i in range(epoch):
    model.train()
    pre_step = 0
    pre_loss = 0
    for data in train_loader:
        images, labels = data
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
        pre_loss = pre_loss + loss.item()
        pre_step += 1
        total_step += 1
        if pre_step % 100 == 0:
            print(f"Epoch: {i+1} ,pre_loss = {pre_loss/pre_step}")
            writer.add_scalar('train_loss', pre_loss / pre_step, total_step)

    model.eval()
    pre_accuracy = 0
    with torch.no_grad():
        for data in test_loader:
            images, labels = data
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            pre_accuracy += outputs.argmax(1).eq(labels).sum().item()
    print(f"Test_accuracy: {pre_accuracy/len(test_dataset)}")
    writer.add_scalar('test_accuracy', pre_accuracy / len(test_dataset), i)
    torch.save(model, f'../models/model{i}.pth')

writer.close()

参考文章

【CNN】搭建AlexNet网络——并处理自定义的数据集(猫狗分类)
How to download MNIST images as PNGs

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1689528.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

vue实现加入购物车动效

实现 实现逻辑: 点击添加购物车按钮时,获取当前点击位置event的clientX 、clientY;动态创建移动的小球,动态计算小球需要移动到的位置(通过ref 的getBoundingClientRect获取统计元素按钮位置)&#xff1b…

JS 网页密码框验证信息

<!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>Document</title><style>/* 当没有密码…

【错误解决】使用HuggingFaceInstructEmbeddings时的一个错误

起因&#xff1a;使用huggingface构建一个问答程序时出现的问题。 错误内容&#xff1a; 分析&#xff1a; 查看代码发现&#xff0c;HuggingFaceInstructEmbeddings和sentence-transformers模块版本不兼容导致。 可以明显看到方法参数不同。 解决&#xff1a; 安装sentenc…

MySQL主从复制(docker搭建)

文章目录 1.MySQL主从复制配置1.主服务器配置1.拉取mysql5.7的镜像2.启动一个主mysql&#xff0c;进行端口映射和目录挂载3.进入/mysql5.7/mysql-master/conf中创建my.cnf并写入主mysql配置1.进入目录2.执行命令写入配置 4.重启mysql容器&#xff0c;使配置生效5.进入主mysql&a…

c#自动生成缺陷图像-添加新功能(可从xml直接提取目标数据,然后进行数据离线增强)--20240524

在进行深度学习时,数据集十分重要,尤其是负样本数据。 故设计该软件进行深度学习数据预处理,最大可能性获取较多的模拟工业现场负样本数据集。 该软件基于VS2015、.NETFrameWork4.7.2、OpenCvSharp1.0.0.0、netstandard2.0.0.0、SunnyUI3.2.9.0、SunnyUI.Common3.2.9.0及Ope…

Android studio的Gradle出问题

Gradle sync failed: Plugin [id: com.android.application, version: 7.1.1, apply: false] was not found in any of the following sources: 在src里面的build.gradle中 plugins { id ‘com.android.application’ } 的上面加上 buildscript {repositories {jcenter()}depen…

数字驱动,教育先行——低代码揭秘教育机构管理数字化转型

数字化时代为教育带来了许多变革和挑战&#xff0c;同时也为教育创新提供了无限可能。数字化转型可以帮助教育机构应对这些变革和挑战&#xff0c;提高教育效率和质量&#xff0c;满足学生个性化需求&#xff0c;优化教育管理和服务&#xff0c;并提高教育机构的竞争力。 并且…

【译】MySQL复制入门: 探索不同类型的MySQL复制解决方案

原文地址&#xff1a;An Introduction to MySQL Replication: Exploring Different Types of MySQL Replication Solutions 在这篇博文中&#xff0c;我将深入介绍 MySQL 复制&#xff0c;回答它是什么、如何工作、它的优势和挑战&#xff0c;并回顾作为 MySQL 环境&#xff0…

131. 面试中关于架构设计都需要了解哪些内容?

文章目录 一、社区系统架构组件概览1. 系统拆分2. CDN、Nginx静态缓存、JVM本地缓存3. Redis缓存4. MQ5. 分库分表6. 读写分离7. ElasticSearch 二、商城系统-亿级商品如何存储三、对账系统-分布式事务一致性四、统计系统-海量计数六、系统设计 - 微软1、需求收集2、顶层设计3、…

开源大模型与闭源大模型:谁主沉浮?

目录 &#x1f349;引言 &#x1f349;数据隐私 &#x1f348;开源大模型的优势与挑战 &#x1f34d;优势&#xff1a; &#x1f34d;挑战&#xff1a; &#x1f348;闭源大模型的优势与挑战 &#x1f34d;优势&#xff1a; &#x1f34d;挑战&#xff1a; &#x1f34…

【设计模式深度剖析】【2】【创建型】【工厂方法模式】

&#x1f448;️上一篇:单例模式 | 下一篇:抽象工厂模式&#x1f449;️ 目录 工厂方法模式概览工厂方法模式的定义英文原话直译 工厂方法模式的4个角色抽象工厂&#xff08;Creator&#xff09;角色具体工厂&#xff08;Concrete Creator&#xff09;角色抽象产品&#x…

2001-2022年全国31省份互联网发展47个指标合集各省电信业务信息化软件信息技术服务业

全国31省份互联网发展47个指标合集各省电信业务信息化软件信息技术服务业&#xff08;2001-2022年&#xff09;插值填补无缺失 整理了各省电信业务、从业人员、电信通信、互联网发展、企业信息化、软件和信息技术服务业等47个互联网主要发展指标&#xff0c;内含原始数据、线性…

Web前端一套全部清晰 ⑨ day5 CSS.4 标准流、浮动、Flex布局

我走我的路&#xff0c;有人拦也走&#xff0c;没人陪也走 —— 24.5.24 一、标准流 标准流也叫文档流&#xff0c;指的是标签在页面中默认的排布规则&#xff0c;例如:块元素独占一行&#xff0c;行内元素可以一行显示多个。 二、浮动 作用: 让块级元素水平排列。 属性名:floa…

LeetCode1161最大内层元素和

题目描述 给你一个二叉树的根节点 root。设根节点位于二叉树的第 1 层&#xff0c;而根节点的子节点位于第 2 层&#xff0c;依此类推。请返回层内元素之和 最大 的那几层&#xff08;可能只有一层&#xff09;的层号&#xff0c;并返回其中 最小 的那个。 解析 在上一题&…

微信小程序报错:notifyBLECharacteristicValueChange:fail:nodescriptor的解决办法

文章目录 一、发现问题二、分析问题二、解决问题 一、发现问题 微信小程序报错&#xff1a;notifyBLECharacteristicValueChange:fail:nodescriptor 二、分析问题 这个提示有点问题&#xff0c;应该是该Characteristic的Descriptor有问题&#xff0c;而不能说nodescriptor。 …

【传知代码】Modnet 人像抠图-论文复现

文章目录 概述原理介绍核心逻辑ModNet 的结构 环境配置WebUI 小结 论文地址 论文GitHub 本文涉及的源码可从Modnet 人像抠图该文章下方附件获取 概述 人像抠图技术在多个领域有着广泛的应用场景&#xff0c;包括但不限于&#xff1a; 展馆互动拍照&#xff1a;展馆中使用的抠…

二叉树的递归实现及例题

目录 遍历方式 示例 原理 前序遍历示例 二叉树的节点个数 原理 层序遍历 原理 这样做的目的是 判断完全二叉树 例题 ​编辑 思路 代码 遍历方式 二叉树的遍历方式可分为&#xff1a; 前序遍历&#xff1a;先访问根&#xff0c;访问左子树&#xff0c;在访问右子…

2024.05.24 学习记录

1、面经复习&#xff1a; js基础、知识深度、js垃圾回收 2、代码随想录刷题&#xff1a;动态规划 完全背包 all 3、rosebush 完成 Tabs、Icon、Transition组件

C++中获取int最大与最小值

不知道大家有没有遇到过这种要求&#xff1a;“返回值必须是int&#xff0c;如果整数数超过 32 位有符号整数范围 [−2^31, 2^31 − 1] &#xff0c;需要截断这个整数&#xff0c;使其保持在这个范围内。例如&#xff0c;小于 −2^31 的整数应该被固定为 −2^31 &#xff0c;大…

【Ubuntu查看硬盘和网络配置信息】解决方案

1.查看硬盘序列号 搜索系统中有应用&#xff0c;打开“磁盘”&#xff0c;因为我本人是只设置了一个分区&#xff0c;所以打开如下&#xff0c;从上面可以得到硬盘序列号 2.查看ip地址和MAC地址 ctrl alt t 打开命令行 输入ifconfig, 如下图所示&#xff0c;MAC地址即为&…