PyTorch深度学习框架学习记录(2)--MNIST手写数字识别(续)

news2025/1/20 3:51:17

文章目录

    • 前言
    • MNIST手写数字识别
      • 数据的准备工作
      • 数据的处理
      • 主干网络的定义
      • 损失函数的使用(修改)
      • 训练和预测
      • 运行

前言

这个是我在学习中的其中一种方式实现MNIST手写的识别,思路我觉得比较清晰,后面会把另外的方法代码整理发布。

MNIST手写数字识别

数据的准备工作

非常重要,但是只使用MNIST学习过程,所以并不需要深究,不同的数据集的处理都不一样

因为MNIST数据集很简单,由28×28的灰度图像组成,所以每张图片都是784个灰度数字。

有脚本可以直接下载:

"""
download_mnist.py
下载数据
"""
from pathlib import Path
import requests

DATA_PATH = Path("data")
PATH = DATA_PATH / "mnist"

PATH.mkdir(parents=True, exist_ok=True)

URL = "http://deeplearning.net/data/mnist/"
FILENAME = "mnist.pkl.gz"

if not (PATH / FILENAME).exists():
        content = requests.get(URL + FILENAME).content
        (PATH / FILENAME).open("wb").write(content)

但是经常有网络问题,也可以去找MNIST的pkl格式的数据。

mnist.pkl.gz百度网盘链接

链接:https://pan.baidu.com/s/1nx2k5IPAnP1u6CkRR8NXfw?pwd=zbqy
提取码:zbqy

"""
path_setting.py
设置数据所在路径
"""
import pickle
import gzip
from pathlib import Path

"""保存路径data/mnist/mnist.pkl.gz"""
DATA_PATH = Path("data")
PATH = DATA_PATH / "mnist"
FILENAME = "mnist.pkl.gz"

如果想看下数据集的内容,可以使用下面的方式:

"""
show.py
查看数据内容
"""
import matplotlib.pyplot as plt
import pylab
from path_setting import *

"""读取图像"""
with gzip.open((PATH / FILENAME).as_posix(), "rb") as f:
    ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")  # 读取数据
    """
    x_train: 训练数据
    y_train: 训练标签
    x_valid: 测试数据
    y_valid: 测试标签
    """


print(x_train.shape)  # 查看x_train的形状: (50000, 784)
print(x_valid.shape)  # 查看x_valid的形状: (10000, 784)
print(y_train.shape)  # 查看y_train的形状: (50000, )
print(x_train[0])  # 0号数据(第一个数据)的784个灰度值
print(y_train[0])  # 0号数据(第一个数据)的标签: 5
print(x_train[0].shape)  # 看一下x_train中一个图像的形状: (784,) 784个灰度值
plt.imshow(x_train[0].reshape(28, 28), cmap="gray")  # 更改图像的形状为28 × 28
pylab.show()  # 展示图像

运行结果:

在这里插入图片描述

数据的处理

获取数据之后,数据此时并不是张量,而是数组,需要将数据转换成张量

torch也提供了方法

  • TensorDataset
  • Dataloader
"""
data_process.py
数据处理
"""
import torch
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
from path_setting import *
import pickle
import gzip

bs = 64  # batch_size
with gzip.open((PATH / FILENAME).as_posix(), "rb") as f:
    ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")

x_train, y_train, x_valid, y_valid = map(
    torch.tensor, (x_train, y_train, x_valid, y_valid)  # 将numpy数组转换为tensor
)

train_ds = TensorDataset(x_train, y_train)
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)  # 在所有的训练数据中以bs个样本为单位随机取数据

valid_ds = TensorDataset(x_valid, y_valid)
valid_dl = DataLoader(valid_ds, batch_size=bs * 2)  


主干网络的定义

"""
backbone.py
主干神经网络
"""

import torch.nn as nn
import torch.nn.functional as F


class FCnet(nn.Module):
    """input -> hidden1 -> hidden2 -> out"""
    def __init__(self):
        super().__init__()
        self.hidden1 = nn.Linear(784, 128)
        self.hidden2 = nn.Linear(128, 256)
        self.out = nn.Linear(256, 10)
        self.dropout = nn.Dropout(0.5)  # dropout的概率

    def forward(self, x):
        x = F.relu(self.hidden1(x))
        x = self.dropout(x)
        x = F.relu(self.hidden2(x))
        x = self.dropout(x)
        x = self.out(x)
        return x

使用输入层->全连接层->输出这样的全连接神经网络。

基本的代码格式就按上面的模板来

初始化函数是设置各个层次的结构

forward函数为前向传播,需要手动设置

损失函数的使用(修改)

PyTorch中有两种快速使用一些常用损失函数的方法:

  • torch.nn.functional
  • nn.Module

损失函数传入的参数一般为两个,伪代码如下:

import torch.nn.functional as F

loss_func = F.cross_entropy
loss_func(预测值, 实际值)
"""
loss_batch.py
计算loss和梯度更新
"""
from torch import optim


def loss_batch(module, loss_func, data, label, opt=None):

    loss = loss_func(module(data), label)  # 计算当前损失

    if opt is not None:
        loss.backward()  # 每一层的权重参数运算,得出梯度
        opt.step()  # 根据执行参数更新
        opt.zero_grad()  # 重新将梯度置为0

    return loss.item(), len(data)

训练和预测

"""
fit.py
训练和验证过程
"""
import numpy as np
import torch
from loss_batch import loss_batch


def fit(epochs, module, loss_func, opt, train_dl, valid_dl):
    for epoch in range(epochs):
        module.train()
        for train_data, train_label in train_dl:
            loss_batch(module, loss_func, train_data, train_label, opt)

        module.eval()
        with torch.no_grad():
            losses, nums = zip(
                *[loss_batch(module, loss_func, valid_data, valid_label) for valid_data, valid_label in valid_dl]
            )

        val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)
        print(f"step: {str(epoch)} , loss in valid : {str(val_loss)}")
    correct = 0
    total = 0
    for vd, vl in valid_dl:
        outputs = module(vd)  # (128, 10)一个batch64,两个128,每一个的结果是隶属于十个分类的概率
        _, predicted = torch.max(outputs.data, 1)  # 1代表行,每个样本中取最大的
        total += vl.size(0)
        correct += (predicted == vl).sum().item()
    print(f"accuracy: {correct / total}")

运行

"""
run.py
"""
from torch import optim
from backbone import FCnet
from data_process import *
from fit import fit
from data_process import *
import torch.nn.functional as F

net = FCnet()
loss_func = F.cross_entropy
opt = optim.SGD(net.parameters(), lr=0.001)
fit(50, net, loss_func, opt, train_dl, valid_dl)

运行结果:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/93046.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

是否有 API 可供云对接?

涂鸦提供了两种维度供开发者拉取:应用维度、产品维度。 应用维度:设备绑定的用户属于开发者在涂鸦云应用中的用户,开发者间接拥有权限操作属于自己应用用户的设备; 产品维度:设备所属产品属于开发者在涂鸦云产品中的设…

2022年总结:道阻且长,行则将至

前言 今年是第四个年头写总结了,直到这个时候,我仍未想出今年的标题是什么。 2019年总结,平凡的我仍在平凡的生活 2020年总结,所有努力只为一份期待 2021年总结:前路有光,初心莫忘 如果非得用一句话来…

(Java)SpringMVC学习笔记(二)

前言 继续学习SpringMVC视频教程,争取今明后三天把设定的目标完成 SpringMVC 框架搭建 这一步花了我一上午时间,报了个404错误,没办法,只能从头开始创建项目(心态差点整爆炸,第一次是自认不懂&#xff0…

格力博通过创业板注册:上半年营收32亿 拟募资34.56亿

雷递网 雷建平 12月15日格力博(江苏)股份有限公司(简称:“格力博”)日前通过注册,准备在深交所创业板上市。格力博计划募资34.56亿元,其中,11.69亿元用于年产500万件新能源园林机械智…

【从零开始学习深度学习】25.卷积神经网络之LeNet模型介绍及其Pytorch实现【含完整代码】

目录1. LeNet模型介绍与实现2. 输入为Fashion-MNIST时各层输出形状3. 获取Fashion-MNIST数据和并使用LeNet模型进行训练4.完整代码之前我们对Fashion-MNIST数据集中的图像进行分类时,是将28*28图像中的像素逐行展开,得到长度为784的向量,并输…

Spring Cloud基于JWT创建统一的认证服务

认证服务肯定要有用户信息,不然怎么认证是否为合法用户?因为是内部的调用认证,可以简单一点,用数据库管理就是一种方式。或者可以配置用户信息,然后集成分布式配置管理就完美了。 表结构 本教程中的案例把查数据库这…

2022-年终总结

2022年已经到了尾声,后半年度过的太漫长了,也是自己这两年来成长速度最快的一次了(后文揭晓) 今年的年中总结链接 上半年我沉浸在读各类技术书籍中,但是后半年的我几乎放弃了读书,转而投身到另外一个学习渠…

Linux Phy 驱动解析

文章目录1. 简介2. phy_device2.1 mdio bus2.2 mdio device2.3 mdio driver2.4 poll task2.4.1 自协商配置2.4.2 link 状态读取2.4.3 link 状态通知3. phylink3.1 phylink_create()3.2 phylink_connect_phy()3.3 phylink_start()3.3 poll task参考资料1. 简介 在调试网口驱动的…

从另外一个角度解释AUC

AUC到底代表什么呢,我们从另外一个角度解释AUC,我们先看看一个auc曲线 蓝色曲线下的面积(我的模型的AUC)比红线下的面积(理论随机模型的AUC)大得多,所以我的模型一定更好。 我的模型比随机模型好多少呢?理论随机模型只是对角线,…

加密与认证技术

加密与认证技术密码技术概述密码算法与密码体制的基本概念加密算法与解密算法秘钥的作用什么是密码密钥长度对称密码体系对称加密的基本概念典型的对称加密算法DES加密算法3DES加密算法非对称密码体系非对称加密基本概念密码技术概述 密码技术是保证网络安全的核心技术之一&am…

【windows Server 2019系列】 构建IIS服务器

个人名片: 对人间的热爱与歌颂,可抵岁月冗长🌞 Github👨🏻‍💻:念舒_C.ying CSDN主页✏️:念舒_C.ying 个人博客🌏 :念舒_C.ying Web服务器也称为WWW(World W…

电子厂测试题——难倒众多主播——大司马也才90分

一、选择题 1、1-2 ( ) A.1 B.3 C.-1 D.-3 2、|1-2|( ) A.1 B.3 C. -1 D.-3 3、1x2x3( ) A.5 B.6 C.7 D.8 4、3643( ) A.29 B.16 C.8 D.3 5、55x5( ) A.15 B.30 C.50 D.125 二、填空题(请填写阿拉伯数字) 6、110100 1000_______ 7、一个三角形砍去1个角&#…

Feign的两种最佳实践方式介绍

何谓最佳实践呢?就是企业中各种踩坑,最后总结出来的相对比较好的使用方式; 下面给大家介绍两种比较好的实践方案: 方式一(继承):给消费者的FeignClient和提供着的Controller定义一个统一的父接…

在逆变器中驱动和保护IGBT

在逆变器中驱动和保护IGBT 介绍 ACPL-339J是一款先进的1.0 A双输出,易于使用,智能的手机IGBT门驱动光耦合器接口。专为支持而设计MOSFET制造商的各种电流评级,ACPL-339J使它更容易为系统工程师支持不同的系统额定功率使用一个硬件平台通过…

全面解析若依框架(springboot-vue前后分离--后端部分)

1、 若依框架分解 - 启动配置 前端启动 # 进入项目目录 cd ruoyi-ui# 安装依赖 npm install# 强烈建议不要用直接使用 cnpm 安装,会有各种诡异的 bug,可以通过重新指定 registry 来解决 npm 安装速度慢的问题。 npm install --registryhttps://regist…

算法刷题打卡第47天:排序数组---归并排序

排序数组 难度:中等 给你一个整数数组 nums,请你将该数组升序排列。 示例 1: 输入:nums [5,2,3,1] 输出:[1,2,3,5]示例 2: 输入:nums [5,1,1,2,0,0] 输出:[0,0,1,1,2,5]归并排…

用CSS给健身的侣朋友做一个喝水记录本

前言 事情是这样的,由于七八月份的晚上时不时就坐在地摊上开始了喝酒撸串的一系列放肆的长肉肉项目。 这不,前段时间女朋友痛下决心(心血来潮)地就去报了一个健身的私教班,按照教练给的饮食计划中,其中有一…

卵巢早衰与微生物群,营养治疗新进展

卵巢早衰 卵巢早衰(premature ovarian insufficiency,简称POI)在生殖系统疾病中位居首位,这些疾病可能会损害多个功能系统,降低生活质量,最终剥夺女性患者的生育能力。 目前的激素替代疗法不能改善受孕或降…

NR PDSCH(七) DL SPS

非动态调度,除了PUSCH configured grant type 1和2的传输,还有PDSCH SPS 传输,两者的流程基本类似,也有些小区别。在实网并没有见过配置DL SPS PDSCH传输的log,但还是按顺序理一遍相关内容。 RRC/MAC 先看下MAC 38.32…

文件上传,还存储在应用服务器?

一般项目开发中都会有文件、图片、视频等文件上传并能够访问的场景。要实现这样的场景,要么把文件存储在应用服务器上,要么搭建文件服务来存储。但是这两种方式也有不少的缺点,增加运维的成本。 因此,追求用户体验的项目可能会考…