Pytorch-以数字识别更好地入门深度学习

news2025/1/11 12:01:11

目录

一、数据介绍

二、下载数据 

三、可视化数据

四、模型构建

五、模型训练

六、模型预测


一、数据介绍

MNIST数据集是深度学习入门的经典案例,因为它具有以下优点:

1. 数据量小,计算速度快。MNIST数据集包含60000个训练样本和10000个测试样本,每张图像的大小为28x28像素,这样的数据量非常适合在GPU上进行并行计算。

2. 标签简单,易于理解。MNIST数据集的标签只有0-9这10个数字,相比其他图像分类数据集如CIFAR-10等更加简单易懂。

3. 数据集已标准化。MNIST数据集中的图像已经被归一化到0-1之间,这使得模型可以更快地收敛并提高准确率。

4. 适合初学者练习。MNIST数据集是深度学习入门的最佳选择之一,因为它既不需要复杂的数据预处理,也不需要大量的计算资源,可以帮助初学者快速上手深度学习。

综上所述,MNIST数据集是深度学习入门的经典案例,它具有数据量小、计算速度快、标签简单、数据集已标准化、适合初学者练习等优点,因此被广泛应用于深度学习的教学和实践中。

手写数字识别技术的应用非常广泛,例如在金融、保险、医疗、教育等领域中,都有很多应用场景。手写数字识别技术可以帮助人们更方便地进行数字化处理,提高工作效率和准确性。此外,手写数字识别技术还可以用于机器人控制、智能家居等方面  。

使用torch.datasets.MNIST下载到指定目录下:./data,当download=True时,如果已经下载了不会再重复下载,同train选择下载训练数据还是测试数据

官方提供的类:

class MNIST(
    root: str,
    train: bool = True,
    transform: ((...) -> Any) | None = None,
    target_transform: ((...) -> Any) | None = None,
    download: bool = False
)
Args:
    root (string): Root directory of dataset where MNIST/raw/train-images-idx3-ubyte
        and MNIST/raw/t10k-images-idx3-ubyte exist.
    train (bool, optional): If True, creates dataset from train-images-idx3-ubyte,
        otherwise from t10k-images-idx3-ubyte.
    download (bool, optional): If True, downloads the dataset from the internet and
        puts it in root directory. If dataset is already downloaded, it is not downloaded again.
    transform (callable, optional): A function/transform that takes in an PIL image
        and returns a transformed version. E.g, transforms.RandomCrop
    target_transform (callable, optional): A function/transform that takes in the
        target and transforms it.

二、下载数据 

# 导入数据集
# 训练集
import torch
from torchvision import datasets,transforms
from torch.utils.data import Dataset
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root="./data",
                   train=True,
                   download=True,
                   transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,),(0.3081,))])),
                    batch_size=64,
                    shuffle=True)

# 测试集
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST("./data",train=False,transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,),(0.3081,))])),
    batch_size=64,shuffle=True
)

pytorch也提供了自定义数据的方法,根据自己数据进行处理

使用PyTorch提供的Dataset和DataLoader类来定制自己的数据集。如果想个性化自己的数据集或者数据传递方式,也可以自己重写子类。

以下是一个简单的例子,展示如何创建一个自定义的数据集并将其传递给模型进行训练:

import torch
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        x = self.data[index]
        y = self.labels[index]
        return x, y

data = torch.randn(100, 3, 32, 32)
labels = torch.randint(0, 10, (100,))

my_dataset = MyDataset(data, labels)
my_dataloader = DataLoader(my_dataset, batch_size=4, shuffle=True)

详细完整流程可参考: Pytorch快速搭建并训练CNN模型?

三、可视化数据

mport matplotlib.pyplot as plt
import numpy as np
import torchvision
def imshow(img):
    img = img / 2 + 0.5 # 逆归一化
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg,(1,2,0)))
    plt.title("Label")
    plt.show()

# 得到batch中的数据
dataiter = iter(train_loader)
images,labels = next(dataiter)
# 展示图片
imshow(torchvision.utils.make_grid(images))

四、模型构建

定义模型类并继承nn.Module基类

# 构建模型
import torch.nn as nn
import torch
import torch.nn.functional as F
class MyNet(nn.Module):
    def __init__(self):
        super(MyNet,self).__init__()
        # 输入图像为单通道,输出为六通道,卷积核大小为5×5
        self.conv1 = nn.Conv2d(1,6,5)
        self.conv2 = nn.Conv2d(6,16,5)
        # 将16×4×4的Tensor转换为一个120维的Tensor,因为后面需要通过全连接层
        self.fc1 = nn.Linear(16*4*4,120)
        self.fc2 = nn.Linear(120,84)
        self.fc3 = nn.Linear(84,10)
    
    def forward(self,x):
        # 在(2,2)的窗口上进行池化
        x = F.max_pool2d(F.relu(self.conv1(x)),2)
        x = F.max_pool2d(F.relu(self.conv2(x)),2)
        # 将维度转换成以batch为第一维,剩余维数相乘为第二维
        x = x.view(-1,self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    
    def num_flat_features(self,x):
        size = x.size()[1:]
        num_features = 1
        for s in size:
            num_features *= s
        return num_features
    
net = MyNet()
print(net)

输出: 

MyNet(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=256, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)

简单的前向传播

# 前向传播
print(len(images))
image = images[:2]
label = labels[:2]
print(image.shape)
print(image.size())
print(label)
out = net(image)
print(out)

输出: 

16
torch.Size([2, 1, 28, 28])
torch.Size([2, 1, 28, 28])
tensor([6, 0])
tensor([[ 1.5441e+00, -1.2524e+00,  5.7165e-01, -3.6299e+00,  3.4144e+00,
          2.7756e+00,  1.1974e+01, -6.6951e+00, -1.2850e+00, -3.5383e+00],
        [ 6.7947e+00, -7.1824e+00,  8.8787e-01, -5.2218e-01, -4.1045e+00,
          4.6080e-01, -1.9258e+00,  1.8958e-01, -7.7214e-01, -6.3265e-03]],
       grad_fn=<AddmmBackward0>)

计算损失:

# 计算损失
# 因为是多分类,所有采用CrossEntropyLoss函数,二分类用BCELoss
image = images[:2]
label = labels[:2]
out = net(image)
criterion = nn.CrossEntropyLoss()
loss = criterion(out,label)
print(loss)

输出:

tensor(2.2938, grad_fn=<NllLossBackward0>)

五、模型训练

# 开始训练
model = MyNet()
# device = torch.device("cuda:0")
# model = model.to(device)
import torch.optim as optim
optimizer = optim.SGD(model.parameters(),lr=0.01) # lr表示学习率
criterion = nn.CrossEntropyLoss()
def train(epoch):
    # 设置为训练模式:某些层的行为会发生变化(dropout和batchnorm:会根据当前批次的数据计算均值和方差,加速模型的泛化能力)
    model.train()
    running_loss = 0.0
    for i,data in enumerate(train_loader):
        # 得到输入和标签
        inputs,labels = data
        # 消除梯度
        optimizer.zero_grad()
        # 前向传播、计算损失、反向传播、更新参数
        outputs = model(inputs)
        loss = criterion(outputs,labels)
        loss.backward()
        optimizer.step()
        # 打印日志
        running_loss += loss.item()
        if i % 100 == 0:
            print("[%d,%5d] loss: %.3f"%(epoch+1,i+1,running_loss/100))
            running_loss = 0

train(10)

输出:

[11,    1] loss: 0.023
[11,  101] loss: 2.302
[11,  201] loss: 2.294
[11,  301] loss: 2.278
[11,  401] loss: 2.231
[11,  501] loss: 1.931
[11,  601] loss: 0.947
[11,  701] loss: 0.601
[11,  801] loss: 0.466
[11,  901] loss: 0.399

六、模型预测

# 模型预测结果
correct = 0
total = 0
with torch.no_grad():
    for data in test_loader:
        images,labels = data
        outputs = model(images)
        # 最大的数值及最大值对应的索引
        value,predicted = torch.max(outputs.data,1)
        total += labels.size(0)
        # 对bool型的张量进行求和操作,得到所有预测正确的样本数,采用item将整数类型的张量转换为python中的整型对象
        correct += (predicted == labels).sum().item()
    print("predicted:{}".format(predicted[:10].tolist()))
    print("label:{}".format(labels[:10].tolist()))
    print("Accuracy of the network on the 10 test images: %d %%"% (100*correct/total))

imshow(torchvision.utils.make_grid(images[:10],nrow=len(images[:10])))

输出:

predicted:[1, 0, 7, 6, 5, 2, 4, 3, 2, 6]
label:[1, 0, 7, 6, 5, 2, 4, 8, 2, 6]
Accuracy of the network on the 10 test images: 91 %

对应类别的准确率:

class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
classes = [i for i in range(10)]

with torch.no_grad():# model.eval()
    for data in test_loader:
        images,labels = data
        outputs = model(images)
        value,predicted = torch.max(outputs,1)
        c = (predicted == labels).squeeze()
        # 对所有labels逐个进行判断
        for i in range(len(labels)):
            label = labels[i]
            class_correct[label] += c[i].item()
            class_total[label] += 1
    print("class_correct:{}".format(class_correct))
    print("class_total:{}".format(class_total))

# 每个类别的指标
for i in range(10):
    print('Accuracy of -> class %d : %2d %%'%(classes[i],100*class_correct[i]/class_total[i]))

输出:

class_correct:[958.0, 1119.0, 948.0, 938.0, 901.0, 682.0, 913.0, 918.0, 748.0, 902.0]
class_total:[980.0, 1135.0, 1032.0, 1010.0, 982.0, 892.0, 958.0, 1028.0, 974.0, 1009.0]
Accuracy of -> class 0 : 97 %
Accuracy of -> class 1 : 98 %
Accuracy of -> class 2 : 91 %
Accuracy of -> class 3 : 92 %
Accuracy of -> class 4 : 91 %
Accuracy of -> class 5 : 76 %
Accuracy of -> class 6 : 95 %
Accuracy of -> class 7 : 89 %
Accuracy of -> class 8 : 76 %
Accuracy of -> class 9 : 89 %

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/945629.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Java实现根据商品ID获取1688商品详情跨境属性数据,1688商品重量数据接口,1688API接口封装方法

要通过1688的API获取商品详情跨境属性数据&#xff0c;您可以使用1688开放平台提供的接口来实现。以下是一种使用Java编程语言实现的示例&#xff0c;展示如何通过1688开放平台API获取商品详情属性数据接口&#xff1a; 首先&#xff0c;确保您已注册成为1688开放平台的开发者…

网工内推 | IT网工,华为、华三认证优先,15k*13薪

01 广东善能科技发展股份有限公司 招聘岗位&#xff1a;IT网络工程师 职责描述&#xff1a; 1、负责公司项目售后技术支持工作&#xff1b; 2、负责项目交付实施&#xff0c;配置调试、运维等&#xff1b; 3、参加合作厂商产品技术知识培训&#xff1b; 4、参加合作厂商工程师…

运维Shell脚本小试牛刀(二)

运维Shell脚本小试牛刀(一) 运维Shell脚本小试牛刀(二) 一: if---else.....fi 条件判断演示 [rootwww shelldic]# cat checkpass.sh #!/bin/bash - # # # # FILE: checkpass.sh # USAGE: ./checkpass.sh # DESCRI…

贝叶斯人工智能大脑与 ChatGPT

文章目录 一、前言二、主要内容 &#x1f349; CSDN 叶庭云&#xff1a;https://yetingyun.blog.csdn.net/ 一、前言 论文地址&#xff1a;https://arxiv.org/abs/2308.14732 这篇论文旨在研究 Chat Generative Pre-trained Transformer&#xff08;ChatGPT&#xff09;在贝叶斯…

【AI】数学基础——高数(积分部分)

高数&#xff08;函数&微分部分&#xff09; 文章目录 1.4 微积分1.4.1 基本思想1.4.2 定积分定义定义计算定积分定积分性质定理N-L公式泰勒公式麦克劳林公式 1.5 求极值1.5.1 无条件极值1.5.2 条件极值1.5.3 多条件极值1.5.4 凹函数与凸函数 1.4 微积分 用于求解速度、面积…

开始MySQL之路——MySQL 事务(详解分析)

MySQL 事务概述 MySQL 事务主要用于处理操作量大&#xff0c;复杂度高的数据。比如说&#xff0c;在人员管理系统中&#xff0c;你删除一个人员&#xff0c;你即需要删除人员的基本资料&#xff0c;也要删除和该人员相关的信息&#xff0c;如信箱&#xff0c;文章等等&#xf…

JS小球绕着椭圆形的轨迹旋转并且近大远小

在ivx中案例如下&#xff1a; VxEditor 效果如下&#xff0c;近大远小 主要代码如下&#xff1a; const centerX 360 / 2; // 椭圆中心的X坐标 const centerY 120 / 2; // 椭圆中心的Y坐标 const a 100; // 长半轴 const b 60; // 短半轴const elementsWithClassName d…

uniapp项目实战系列(1):导入数据库,启动后端服务,开启代码托管

目录 前言前期准备1.数据库的导入2.运行后端服务2.1数据库的后端配置2.2后端服务下载依赖&#xff0c;第三方库2.3启动后端服务 3.开启gitcode代码托管 ✨ 原创不易&#xff0c;还希望各位大佬支持一下&#xff01; &#x1f44d; 点赞&#xff0c;你的认可是我创作的动力&…

虚拟化技术:云计算发展的核心驱动力

文章目录 虚拟化技术的概念和作用虚拟化技术的优势虚拟化技术对未来发展的影响结论 &#x1f389;欢迎来到AIGC人工智能专栏~虚拟化技术&#xff1a;云计算发展的核心驱动力 ☆* o(≧▽≦)o *☆嗨~我是IT陈寒&#x1f379;✨博客主页&#xff1a;IT陈寒的博客&#x1f388;该系…

RabbitMQ---Spring AMQP

Spring AMQP 1. 简介 Spring有很多不同的项目&#xff0c;其中就有对AMQP的支持&#xff1a; Spring AMQP的页面&#xff1a;http://spring.io/projects/spring-amqp 注意这里一段描述&#xff1a; Spring-amqp是对AMQP协议的抽象实现&#xff0c;而spring-rabbit 是对协…

如何在有或没有WiF适配器的情况下把台式机接入WiFi

Wi-Fi在台式电脑中越来越普遍,但并不是所有的台式电脑都有。添加Wi-Fi,你就可以无线连接到互联网,并为其他设备托管Wi-Fi热点。 这是一个简单、廉价的过程。买一个合适的小适配器,你甚至可以随身携带,通过将一个小设备插入USB端口,可以快速将Wi-Fi添加到你遇到的任何桌面…

Particle Life粒子生命演化的MATLAB模拟

Particle Life粒子生命演化的MATLAB模拟 0 前言1 基本原理1.1 力影响-吸引排斥行为1.2 距离rmax影响 2 多种粒子相互作用2.1 双种粒子作用2.1 多种粒子作用 3 代码 惯例声明&#xff1a;本人没有相关的工程应用经验&#xff0c;只是纯粹对相关算法感兴趣才写此博客。所以如果有…

Android屏幕显示 android:screenOrientation configChanges 处理配置变更

显示相关 屏幕朝向 https://developer.android.com/reference/android/content/res/Configuration.html#orientation 具体区别如下&#xff1a; activity.getResources().getConfiguration().orientation获取的是当前设备的实际屏幕方向值&#xff0c;可以动态地根据设备的旋…

Windows10 系统安装教程

多虚不如少实。 一、 下载安装包 下载前景&#xff1a;网上下载的 windows10 系统一般都有捆绑软件&#xff0c;用户体验不爽&#xff0c;所以建议到 正规渠道下载 windows10 系统的不同版本。另外网上也有一些 windows10 系统的镜像文件 可以直接一键安装&#xff0c;…

OpenShift 4 - 用 Prometheus 和 Grafana 监视用户应用定制的观测指标(视频)

《OpenShift / RHEL / DevSecOps 汇总目录》 说明&#xff1a;本文已经在 OpenShift 4.13 的环境中验证 文章目录 OpenShift 的监控功能构成部署被监控应用用 OpenShift 内置功能监控应用用 Grafana 监控应用安装 Grafana 运行环境配置 Grafana 数据源定制监控 Dashboard 演示视…

Python之动态规划

序言 最近在学习python语言&#xff0c;语言有通用性&#xff0c;此文记录复习动态规划并练习python语言。 动态规划&#xff08;Dynamic Programming&#xff09; 动态规划是运筹学的一个分支&#xff0c;是求解决策过程最优化的过程。20世纪50年代初&#xff0c;美国数学家…

上滑动导航栏手势桌面最近任务可见解密-千里马手把手带你搞定framework车载车机系统开发

建议先看另一篇blog&#xff1a; https://blog.csdn.net/learnframework/article/details/123032419 系统如何让桌面执行对应的onStart方法呢&#xff1f; 具体的堆栈显示如下&#xff1a; makeActiveIfNeeded:5788, ActivityRecord (com.android.server.wm) makeVisibleIfNe…

MOS管开关电路栅极为什么要串接电阻

在MOS管开关电路或者驱动电路中&#xff0c;常常会在MOS管的栅极串接一个电阻。 这个电阻阻值一般是几十欧姆&#xff0c;那么这个电阻有什么作用呢&#xff1f; 第一个作用就是可以限制驱动电流 &#xff0c;防止瞬间驱动电流过大导致驱动芯片驱动能力不足或者损坏。 MOS管的…

CANOCO5.0实现冗余分析(RDA)最详细步骤

在地理及生态领域会常使用RDA分析&#xff0c;RDA的实现路径也有很多&#xff0c;今天介绍一下CANOCO软件的实现方法。 1.软件安装 时间调整到2010年 2.数据处理 得有不同的物种或者样点数值&#xff0c;再加上环境因子数据。 3.软件运行 4.结果解读 结果解读主要把握这几点…

1、[春秋云镜]CVE-2022-32991

文章目录 一、相关信息二、解题思路&#xff08;手注&#xff09;三、通关思路&#xff08;sqlmap&#xff09; 一、相关信息 靶场提示&#xff1a;该CMS的welcome.php中存在SQL注入攻击。 NVD关于漏洞的描述&#xff1a; 注入点不仅在eid处&#xff01;&#xff01;&#xff…