pytorch学习3(pytorch手写数字识别练习)

news2025/1/12 1:58:42

网络模型

设置三层网络,一般最后一层激活函数不选择relu
在这里插入图片描述

任务步骤

手写数字识别任务共有四个步骤:
1、数据加载--Load Data
2、构建网络--Build Model
3、训练--Train
4、测试--Test

实战

1、导入各种需要的包

import torch
from torch import nn
from torch.nn import functional as F
from torch import optim

import torchvision

from matplotlib import pyplot as plt

from minist_utils import plot_image, plot_curve, one_hot ##自写文件

minist_utils:
在这里插入图片描述
在这里插入图片描述在这里插入图片描述

2、加载数据

batch_size = 512

train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data', train=True, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081, ))
                               ])),
    batch_size=batch_size, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data', train=False, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081, ))
                               ])),
    batch_size=batch_size, shuffle=False

取一些样本看数据的shape以及图片内容

x, y = next(iter(train_loader))
print(x.shape, y.shape, x.min(), x.max())
plot_image(x, y, 'image sample')

在这里插入图片描述在这里插入图片描述

注:经过load加载处理后的数据集包含x(图像信息)和y(标签信息)
next(iter())的用法是取一组样本,重复运行可以依次顺序取样,直到样本被取完
可在csdn自行搜索学习了解

3、网络构建

按之前设想的三层线性模型嵌套的思想搭建模型,为了模型简单,第三层不加激活函数。

class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()

        # xw+b
        self.fc1 = nn.Linear(28*28, 256) #输入特征数,输出特征数
        self.fc2 = nn.Linear(256, 64)  #256,64是根据经验判断
        self.fc3 = nn.Linear(64, 10)  #最开始的28*28和输出的10是一定的

    def forward(self, x):
        # x: [b, 1, 28, 28]
        # h1 = relu(xw1 + b1)
        x = F.relu(self.fc1(x)) #输入x后第一次线性模型得到H1作第二层输入
        # h2 = relu(h1w2 + b2)
        x = F.relu(self.fc2(x)) #输入H1得到H2作第三层输入
        # h3 = h2w3 + b3
        x = self.fc3(x)	#输入H3得到最终结果,维度为10

        return x

4、模型训练

net = Net()

# [w1, b1, w2, b2, w3, b3]
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

train_loss = []

for epoch in range(3):

    for batch_idx, (x, y) in enumerate(train_loader):

        # x: [b, 1, 28, 28], y: [512]
        # [b, 1, 28, 28] => [b, feature] 全连接层只能接受这样的数据
        x = x.view(x.size(0), 28*28)
        # => [b, 10]
        out = net(x)
        # [b, 10]
        y_onehot = one_hot(y)
        # loss = mse(out, y_onehot)
        loss = F.mse_loss(out, y_onehot)

        optimizer.zero_grad()
        loss.backward() # 梯度计算过程
        # w` = w - lr * grad
        optimizer.step() # 优化更新w,b

        train_loss.append(loss.item())

        if batch_idx % 10 == 0:
            print(epoch, batch_idx, loss.item())

plot_curve(train_loss)

在这里插入图片描述

5、测试

1、计算准确率acc

total_correct = 0
for x, y in test_loader:
    x = x.view(x.size(0), 28*28)
    out = net(x)
    # out: [b, 10] => pred: [b]
    pred = out.argmax(dim=1)
    correct = pred.eq(y).sum().float().item()
    total_correct += correct

total_num = len(test_loader.dataset)
acc = total_correct / total_num
print(("acc:", acc))

在这里插入图片描述
2、展示部分测试样本原图以及预测标签结果

x, y =next(iter(test_loader))
out = net(x.view(x.size(0), 28*28))
pred = out.argmax(dim=1)
plot_image(x, pred, 'test')

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1030389.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Matlab图像处理-区域特征

凹凸性 设P是图像子集S中的点,若通过的每条直线只与S相交一次,则称S为发自P的星形,也就是站在P点能看到S的所有点。 满足下列条件之一,称此为凸状的: 1.从S中每点看,S都是星形的; 2.对S中任…

软件设计师笔记系列(四)

😀前言 随着技术的快速发展,软件已经成为我们日常生活中不可或缺的一部分。从智能手机应用到大型企业系统,软件都在为我们提供便利、增强效率和创造价值。然而,随之而来的是对软件质量的日益增长的关注。软件的质量不仅关乎其功能…

C语言中的虚拟地址

虚拟地址 虚拟地址空间 对于操作系统而言,每个进程所得到的虚拟地址都在一个独立的固定的范围内,不会超过这个范围,我们把这个范围称为虚拟地址空间。所谓的虚拟地址空间本质就是一个地址范围,表示程序的寻址能力。对于32位系统…

Python 在 JMeter 中如何使用?

要在JMeter中使用Python,需要使用JSR223 Sampler元素来执行Python脚本。使用JSR223 Sampler执行Python脚本时,需要确保已在JMeter中配置了Python解释器,并设置了正确的环境路径。 1、确保JMeter已安装Python解释器,并将解释器的路…

时序预测 | MATLAB实现POA-CNN-BiGRU鹈鹕算法优化卷积双向门控循环单元时间序列预测

时序预测 | MATLAB实现POA-CNN-BiGRU鹈鹕算法优化卷积双向门控循环单元时间序列预测 目录 时序预测 | MATLAB实现POA-CNN-BiGRU鹈鹕算法优化卷积双向门控循环单元时间序列预测预测效果基本介绍程序设计参考资料 预测效果 基本介绍 MATLAB实现POA-CNN-BiGRU鹈鹕算法优化卷积双向…

Jenkins学习笔记6

开发者开发代码一般会使用IDE集成开发工具(比如pycharm这种),那么使用pycharm开发的代码能否直接利用自动发布系统发布到业务服务器上呢? 答案是肯定的。 然后进行下测试: 那说明SSH免密是成功的。 将Pycharm修改为原来的界面,然…

抖音短视频矩阵系统搭建

企业在进行短视频矩阵运营时,搭建一个矩阵号是非常必要的。矩阵号可以绑定多个不同平台的账号,批量制作和定时发布短视频,提高企业的曝光量和粉丝互动。但是,如何搭建一个有效的短视频矩阵号呢?以下是几个关键步骤。 一…

STM32 NVIC中断优先级管理通过结构图快速理解

STM32 NVIC中断优先级管理通过结构图快速理解 📑抢占优先级和响应优先级基本常识 🌿抢占优先级的级别高于响应优先级。🌿抢占优先级数值编号越小,所代表的优先级就越高;同理,响应优先级也是如此。&#x1…

为什么要选择Spring cloud Sentinel

为什么要选择Spring cloud Sentinel 🍎对比Hystrix🍂雪崩问题及解决方案🍂雪崩问题🍂.超时处理🍂仓壁模式🍂断路器🍂限流🍂总结 🍎对比Hystrix 在SpringCloud当中支持多…

使用 FHE 实现加密大语言模型

近来,大语言模型 (LLM) 已被证明是提高编程、内容生成、文本分析、网络搜索及远程学习等诸多领域生产力的可靠工具。 大语言模型对用户隐私的影响 尽管 LLM 很有吸引力,但如何保护好 输入给这些模型的用户查询中的隐私 这一问题仍然存在。一方面&#xf…

《从菜鸟到大师之路 Redis 篇》

《从菜鸟到大师之路 Redis 篇》 (一):Redis 基础理论与安装配置 Nosql 数据库介绍 是一种 非关系型 数据库服务,它能 解决常规数据库的并发能力 ,比如 传统的数据库的IO与性能的瓶颈 ,同样它是关系型数据…

Android 11.0 禁止二次展开QuickQSPanel设置下拉QSPanel高度

1.前言 在11.0的系统定制化需求中,在进行systemui的ui定制开发中,有些产品中有需求对原生systemui下拉状态栏中的二次展开QSPanel修改成 一次展开禁止二次展开,所以就需要修改QuickQSpanel的高度,然后在QuickQsPanel做定制,然后禁止二次展开就可以了 如图: 2.禁止二次展开…

32.3D文本旋转动画效果

特效 源码 index.html <!DOCTYPE html> <html> <head> <title>CSS 3D Text Rotation</title> <link rel="stylesheet" type="text/css" href="style.css"> </head> <body><div class=&quo…

C++实现观察者模式(包含源码)

文章目录 观察者模式一、基本概念二、实现方式三、角色四、过程五、结构图六、构建思路七、完整代码 观察者模式 一、基本概念 观察者模式&#xff08;又被称为模型&#xff08;Model&#xff09;-视图&#xff08;View&#xff09;模式&#xff09;是软件设计模式的一种。在…

5G通信与蜂窝模组之间的关系

5G通信是第五代移动通信技术的简称&#xff0c;它代表了一种新一代的无线通信技术标准。5G通信的主要目标是提供更高的数据传输速度、更低的延迟、更大的网络容量以及更可靠的连接&#xff0c;以支持各种新兴应用和服务&#xff0c;包括高清视频流、虚拟现实、物联网&#xff0…

【软考中级】网络工程师:7.下一代互联网

IPv4问题与改进 IPv4存在以下著名的问题&#xff1a; 网络地址短缺&#xff08;32位&#xff09;以二进制数串表示&#xff0c;v4仅有43亿个地址&#xff0c;而IPv6有128位&#xff0c;且以十六进制数串表示。&#xff08;现在还能用v4得益于NAT地址转换&#xff09;地址分配…

pwn学习(3)BUUCTF-rip

下载文件&#xff0c;查看文件信息 IDA64打开&#xff0c;发现危险函数gets(),可以判断存在栈溢出漏洞 接着查看fun()函数&#xff0c;发现是system函数&#xff0c;system是C语言下的一个可以执行shell命令的函数 接下来思路就清晰了&#xff0c;需要用gets函数获取一个长字符…

电力安全智慧云平台:引领更安全的用电新时

电力能源是人类社会不可或缺的重要资源&#xff0c;其安全稳定供应关系到各行各业的正常运转和千家万户的生活质量。然而&#xff0c;随着电力使用的普及&#xff0c;电力安全问题也日益凸显&#xff0c;一旦发生电力事故&#xff0c;不仅会造成巨大的经济损失&#xff0c;还会…

python随手小练

题目&#xff1a; 使用python做一个简单的英雄联盟商城登录界面 具体操作&#xff1a; print("英雄联盟商城登录界面") print("~ * "*15 "~") #找其规律 a "1、用户登录" b "2、新用户注册" c "3、退出系统&quo…

rv1126-rv1109-test

测试指令 播放音频:aplay aigei.wav 测试时间: 查看系统时间:date 设置时间:date -s "2023-09-21 16:00:00" 设置芯片时间:hwclock -w 查看芯片时间:hwclock 测试背光: echo 0 > sys/class/backlight/backlight/brightness echo 50 > sys/class/backlig…