从官网下载/处理 MNIST 数据集,并构造CNN网络训练

news2024/12/23 8:49:55

这里写自定义目录标题

  • MNIST 网络 测试用
    • 1. 导入所需要的模块
    • 2. 下载 MNIST 数据集
    • 3. 读取 MNIST 数据集

MNIST 网络 测试用

1. 导入所需要的模块

import sys
sys.path.append('../../')
from zfdplearn import fdutils, fdtorch_net, fddata
import os
import os.path as path
import gzip

from typing import Dict, List, Tuple, AnyStr

import torch
from torch.utils.data import DataLoader, Dataset
from torch import nn
from torchvision import transforms

import numpy as np
import matplotlib.pyplot as plt


from tqdm import tqdm

2. 下载 MNIST 数据集

2.1 下载地址: http://yann.lecun.com/exdb/mnist/
2.1.1 下载的文件有 4 个,分别是:

train-images-idx3-ubyte.gz ==> 训练集的图片
train-label-idx1-ubyte.gz ==> 训练集的标签
t10k-images-idx3-ubyte.gz ==> 测试集的图片
t10k-label-idx1-ubyte.gz ==> 测试集的标签

下载的数据集格式为 .gz,因此需要使用到 python 的 gzip 包

# 下载地址: http://yann.lecun.com/exdb/mnist/
dataset_folder = '../datasets/mnist'
files_name = {
    'train_img': 'train-images-idx3-ubyte.gz',
    'train_label': 'train-labels-idx1-ubyte.gz',
    'vali_img': 't10k-images-idx3-ubyte.gz',
    'vali_label': 't10k-labels-idx1-ubyte.gz'
}

3. 读取 MNIST 数据集

3.1 下载的数据集格式为 .gz,因此需要使用 gzip 中的 open 函数打开。
3.2 打开模式设置为 mode=‘rb’,以字节流的方式打开。因为下载的数据集的格式为字节方式封装
3.3 由于使用字节流打开,因此需要使用 torch.frombuffer() 或者 np.frombuffer() 函数打开。
3.3 根据 MNIST 数据集官网可知,读取数据集需要 offset,因为,在数据头部的数据存储了数据集的一些信息
3.4.1 training set label file: 前 4 个字节为 魔术数,第 4-7 字节为数据的条数(number of items),因此需要 offset 8
trainSetLable
3.4.2 training set images file: 前 4 个字节为 魔术数,第 4-7 字节为数据的条数(number of items),第 8-11 是每张图片的行数,第 12-15 是每张图片的列数, 因此需要 offset 16
trainSetImg
3.4.2 test set label file: 前 4 个字节为 魔术数,第 4-7 字节为数据的条数(number of items),因此需要 offset 8
testSetLable
3.4.3 test set images file: 前 4 个字节为 魔术数,第 4-7 字节为数据的条数(number of items),第 8-11 是每张图片的行数,第 12-15 是每张图片的列数,因此需要 offset 16
testSetImg

PS: torch/np. frombuffer()

# 加载训练集 图片
def load_mnist_data(files_name) -> Tuple:
    with gzip.open(path.join(dataset_folder, files_name['train_img']), mode='rb') as data:
        train_img = torch.frombuffer(data.read(), dtype=torch.uint8, offset=16).reshape(-1, 1, 28, 28)
    # 加载训练集 标签
    with gzip.open(path.join(dataset_folder, files_name['train_label']), mode='rb') as label:
        train_label = torch.frombuffer(label.read(), dtype=torch.uint8, offset=8)
    # 加载验证集 图片
    with gzip.open(path.join(dataset_folder, files_name['vali_img']), mode='rb') as data:
        vali_img = torch.frombuffer(data.read(), dtype=torch.uint8, offset=16).reshape(-1, 1, 28, 28)
    # 加载验证集 label
    with gzip.open(path.join(dataset_folder, files_name['vali_label']), mode='rb') as label:
        vali_label = torch.frombuffer(label.read(), dtype=torch.uint8, offset=8)
    return (train_img, train_label),(vali_img, vali_label)


class MNIST_dataset(Dataset):
    def __init__(self, data: List, label: List):
        self.__data = data
        self.__label = label

    def __getitem__(self, item):
        if not item < self.__len__():
            return f'Error, index {item} is out of range'
        return self.__data[item], self.__label[item]

    def __len__(self):
        return len(self.__data)

# 读取数据
train_data, vali_data = load_mnist_data(files_name)
# 将数据封装为 MNIST 类
train_dataset = MNIST_dataset(*train_data)
vali_dataset = MNIST_dataset(*vali_data)
len(train_dataset), len(vali_dataset)
(60000, 10000)
class YLMnistNet(nn.Module):
    def __init__(self):
        super(YLMnistNet, self).__init__()
        self.conv0 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=(5, 5))
        self.conv1 = nn.Conv2d(6, 16, kernel_size=(5, 5))
        self.pool0 = nn.AvgPool2d(kernel_size=(2, 2))
        self.pool1 = nn.AvgPool2d(kernel_size=(2, 2))
        self.linear0 = nn.Linear(16*4*4, 120)
        self.linear1 = nn.Linear(120, 84)
        self.linear2 = nn.Linear(84, 10)
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()
        self.layers = [self.conv0, self.pool0, self.conv1, self.pool1, self.flatten, self.linear0, self.relu, self.linear1, self.relu, self.linear2, self.relu]

    def forward(self, x):
        output = self.conv0(x)
        output = self.pool0(output)
        output = self.conv1(output)
        output = self.pool1(output)
        output = self.flatten(output)
        output = self.linear0(output)
        output = self.relu(output)
        output = self.linear1(output)
        output = self.relu(output)
        output = self.linear2(output)
        output = self.relu(output)
        return output

    # get depth of MNIST Net
    def __len__(self):
        return len(self.layers)

    # get specified layer
    def __getitem__(self, item):
        return self.layers[item]

    def __name__(self):
        return 'YNMNISTNET'

net = YLMnistNet()

def train(net, loss, train_iter, vali_iter, optimizer, epochs, device) -> fdutils.Accumulator:
    net = net.to(device)
    one_hot_f = nn.functional.one_hot
    accumulator = fdutils.Accumulator(['train_loss', 'vali_loss', 'train_acc', 'vali_acc'])
    epoch_loss = []
    for epoch in range(epochs):
        len_train =  0
        len_vali = 0

        net.train()
        epoch_loss.clear()
        correct_num = 0
        for img, label in train_iter:
            img, label = img.to(device, dtype=torch.float), label.to(device)
            oh_label = one_hot_f(label.long(), num_classes=10)
            optimizer.zero_grad()
            y_hat = net(img)
            l = loss(y_hat, oh_label.to(dtype=float))
            l.backward()
            optimizer.step()
            epoch_loss.append(l.item())
            correct_num += (y_hat.argmax(dim=1, keepdim=True) == label.reshape(-1, 1)).sum().item()
            len_train += len(label)
        accumulator['train_loss'].append(sum(epoch_loss)/len(epoch_loss))
        accumulator['train_acc'].append(correct_num/len_train)
        print(f'-----------epoch: {epoch+1} start --------------')
        print(f'epoch: {epoch+1} train loss: {accumulator["train_loss"][-1]}')
        print(f'epoch: {epoch+1} train acc: {accumulator["train_acc"][-1]}')

        # validation
        epoch_loss.clear()
        correct_num = 0
        with torch.no_grad():
            net.eval()
            for img, label in vali_iter:
                img, label = img.to(device, dtype=torch.float), label.to(device)
                # print(img.dtype)
                oh_label = one_hot_f(label.long(), num_classes=10)
                vali_y_hat = net(img)
                l = loss(vali_y_hat, oh_label.to(dtype=float))
                epoch_loss.append(l.item())
                correct_num += (vali_y_hat.argmax(dim=1, keepdim=True) == label.reshape(-1, 1)).sum().item()
                len_vali += len(label)
            accumulator['vali_loss'].append(sum(epoch_loss)/len(epoch_loss))
            accumulator['vali_acc'].append(correct_num / len_vali)
            print(f'epoch: {epoch+1} vali loss: {accumulator["vali_loss"][-1]}')
            print(f'epoch: {epoch+1} vali acc: {accumulator["vali_acc"][-1]}')
            print(f'-----------epoch: {epoch+1} end --------------')
    return accumulator

# from torch.utils.data import DataLoader
net = YLMnistNet()
batch_size = 32
train_iter = DataLoader(train_dataset, batch_size=batch_size)
vali_iter = DataLoader(vali_dataset, batch_size=batch_size)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
num_epoch = 1
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
accumulator = train(net, loss, train_iter, vali_iter, optimizer, num_epoch, device)

epoch: 1 train loss: nan
epoch: 1 train acc: 0.09871666666666666
epoch: 1 vali loss: nan
epoch: 1 vali acc: 0.098

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/404316.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

# 数据完整性算法在shell及python中的实践

数据完整性算法在shell及python中的实践 文章目录数据完整性算法在shell及python中的实践1 预备知识1.1 摘要算法1.2 报文&#xff08;数据&#xff09;完整性校验1.3 python byte类型字符串与普通字符串区别2 传统方法&#xff08;散列函数&#xff09;2.1 在shell中实践2.2 在…

python调试模块ipdb

1. 调试python ipdb是用来python中用以交互式debug的模块&#xff0c;可以直接利用pip安装; 其功能类似于pycharm中 python控制台&#xff0c; 而使用ipdb 的优点&#xff0c;便是直接在代码中调试&#xff0c; 避免了在python控制台&#xff0c;或者重新设置一些简单变量。…

Web前端开发--自用

第一章 1.1 时间&#xff1a;1980 人物&#xff1a;Tim Berners-Lee 地点&#xff1a;欧洲核子研究组织中最大的欧洲核子物理实验室 事件&#xff1a;与Robert Cailliau建立ENQUIRE系统 1984年&#xff0c;世界上第一个客户端浏览器&#xff08;World Wide Web&#xff09;和第…

软考高项——配置管理

配置管理配置管理配置管理6个主要活动配置项配置基线配置项的状态配置库配置库权限管理配置审计配置管理 配置管理的总线索包括&#xff1a; 1&#xff09;配置管理6个主要活动 2&#xff09;配置项 3&#xff09;配置基线 4&#xff09;配置项的状态 5&#xff09;配置库 6&a…

SAP SQVI快速报表的使用

SQVI快速报表 一、说明 对数据表进行查询通常使用SE16&#xff0c;但只限于单张表&#xff0c;对于多表联动的查询&#xff0c;则需要通过创建Query的方式&#xff0c;方法有多种&#xff0c;而SQVI是一种简洁快速的工具。SQVI全称是Quick Viewer&#xff0c;可以快速生成多表…

动态规划回文子串

647. 回文子串方法&#xff1a;双指针回文子串有长度为奇数和偶数两种&#xff0c;extend(s, i, i, n); extend(s, i, i 1, n);就分别对应长度为奇数和偶数的情况class Solution { private:int extend(const string& s, int i, int j, int n) {int res 0;while (i > 0…

前端——8.超链接标签

这篇文章&#xff0c;我们来讲一下超链接标签 目录 1.超链接标签介绍 1.1链接的分类 2.具体案例讲解 2.1外部链接 2.2 内部标签 2.3 空链接 2.4下载连接 2.5网页元素链接 2.6锚点标签 3.小结 1.超链接标签介绍 超链接标签是HTML中一个十分重要的标签&#xff0c;下…

案例18-面向对象之开门小例子

目录 一&#xff1a;背景介绍 二&#xff1a;思路&方案 1.面向过程 2.面向对象 3.面向对象(反射) 三&#xff1a;过程 1.面向过程&#xff1a;原本何老师的作用交给我了米老师来完成。 2.面向对象&#xff1a;把开门的方法完全交个何老师&#xff0c;米老师不需要有…

k8s 部署 skywalking 并持久化到es

1、k8s中安装部署 skywalking skywalking集群情况下需要保证用同一数据源&#xff0c;这里我们存储方式改为es 1.1 部署elasticsearch docker run -it -d -p 9200:9200 -p 9300:9300 -e ES_JAVA_OPTS"-Xms256m -Xmx256m" -e "discovery.typesingle-node"…

DSRC技术

DSRC(Dedicated Short Range Communication)专用短程通信 定位 是V2X领域存在的两大通信技术之一&#xff08;另一个为LTE-V2X&#xff09;。 所属技术路线 与这两大技术相对应&#xff0c;是V2X无线通信技术的两大技术路线&#xff1a; IEEE 802.11p 本是04年指定的一个通…

一文入门HTML+CSS+JS(样例后续更新)

一文入门HTMLCSSJS&#xff08;样例后续更新&#xff09;前言HTML&#xff0c;CSS和JS的关系HTMLhead元素titlelinkmetabody元素设置网页正文颜色与背景颜色添加网页背景图片设置网页链接文字颜色设置网页边框文字与段落标记普通文字的输入对文字字体的设置 font使用文字的修饰…

代码随想录刷题-数组总结篇

文章目录数组二分查找原理习题题目1思路和代码题目-2移除元素习题我的想法暴力解法双指针有序数组的平方习题暴力排序双指针长度最小的子数组习题暴力解法滑动窗口螺旋矩阵 II习题我的解法别人的解法总结数组 二分查找 本节对应代码随想录中&#xff1a;代码随想录-二分查找 …

java教程--函数式接口--lambda表达式--方法引用

函数式接口 介绍 jdk8新特性&#xff0c;只有一个抽象方法的接口我们称之为函数接口。 FunctionalInterface ​ JDK的函数式接口都加上了FunctionalInterface 注解进行标识。但是无论是否加上该注解只要接口中只有一个抽象方法&#xff0c;都是函数式接口。 如在Comparato…

Makefile的概述

什么是makemake 是个命令&#xff0c;是个可执行程序&#xff0c;用来解析Makefile文件的命令&#xff0c;这个命令存放在 /usr/binmake概述1.GUN make是一种代码维护工具2.make 工具会根据makefile文件定义的规则和步骤&#xff0c;完成整个软件项目的代码维护工作。3.一般用来…

解决Win10图片/文件右键单击自动退出并刷新桌面问题

问题描述 这两天开始不知道怎么回事儿&#xff0c;右键选择图片时候&#xff0c;电脑黑屏且资源管理器自动重启。然后我就开始找很多方法去解决。 我试了很多种复杂的简单的方法&#xff0c;但是只有一种解决了我的问题。 解决方案【解决我的问题】 这个方法如下&#xff1…

VirtualBox的克隆与复制

快照太多&#xff0c;想整合成1个文件怎么办&#xff1f; 最近&#xff0c;我就遇到一个问题。快照太多了。比较占用空间怎么办&#xff1f; 错误做法 一开始&#xff0c;我是这么操作的&#xff0c;选中某个快照&#xff0c;然后选择删除…然后我登录虚拟机后&#xff0c;发…

为什么程序员喜欢这些键盘?

文章目录程序员的爱介绍个人体验程序员的爱 程序员是长时间使用计算机的群体&#xff0c;他们需要一款高品质的键盘来保证舒适的打字体验和提高工作效率。在键盘市场上&#xff0c;有很多不同类型的键盘&#xff0c;但是对于程序员来说&#xff0c;机械键盘是他们最钟爱的选择…

原来CSS 也可以节流啊

Ⅰ、前言 「节流」 是为了减少请求的触发频率&#xff0c;不让用户点的太快&#xff0c;达到节省资源的目的 &#xff1b;通常 我们采用 JS 的 定时器 setTimeout &#xff0c;来控制点击多少秒才能在触发&#xff1b;其实 通过 CSS 也能达到 「节流」 的目的&#xff0c;下面…

LeetCode598. 范围求和 II(python)

题目 给你一个 m x n 的矩阵 M &#xff0c;初始化时所有的 0 和一个操作数组 op &#xff0c;其中 ops[i] [ai, bi] 意味着当所有的 0 < x < ai 和 0 < y < bi 时&#xff0c; M[x][y] 应该加 1。 提示: 1 < m, n < 4 * 104 0 < ops.length < 104 o…

硅谷银行倒闭的几点启示

摘要&#xff1a;本文从公开资料分析一下硅谷银行对信息科技行业的我们有一些什么启示。硅谷银行“拔网线”了&#xff0c;想创业的您&#xff0c;该注意了。1.硅谷银行是谁我们从其官网的说明来看看。The financial partner of the innovation economy.&#xff08;翻译成中文…