PyTorch Learning Notes, Day 4: Training on the MNIST Dataset and a First Read-Through


Here it comes at last, haha. Basically no machine learning beginner can avoid this example: it is open, the data quality is high, the image sizes are uniform, and the problem is simple, which makes it ideal material for beginners.

Today I got the code running; I will use the weekend to work through the details carefully.

Code implementation

First, a shout-out to Baidu AI: the generated code ran straight away, which was a joy. Around nine years ago, in my sophomore year, getting this small amount of code meant sitting through an hours-long English video first. The network looks very, very shallow, yet it already reaches fairly good prediction accuracy.

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
 
# Define the model
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)   # 1 input channel, 10 output channels, 5x5 kernel
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.fc = nn.Linear(20 * 4 * 4, 10)

    def forward(self, x):
        batch_size = x.size(0)         # first dimension is the batch; with 1*28*28 images the input is 64*1*28*28
        x = torch.relu(self.conv1(x))  # input 64*1*28*28, output 64*10*24*24
        x = torch.max_pool2d(x, 2, 2)  # input 64*10*24*24, output 64*10*12*12 (max pooling)
        x = torch.relu(self.conv2(x))  # input 64*10*12*12, output 64*20*8*8
        x = torch.max_pool2d(x, 2, 2)  # input 64*20*8*8, output 64*20*4*4
        x = x.view(batch_size, -1)     # input 64*20*4*4, output 64*320 (flatten)
        x = self.fc(x)                 # input 64*320, output 64*10
        return x

if __name__=="__main__":
    # Define hyperparameters
    batch_size = 64
    epochs = 10
    learning_rate = 0.01
     
    # Data preprocessing
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    # Load the training/test data. batch_size: samples per batch; shuffle: whether to reshuffle the data every epoch
    train_dataset = datasets.MNIST('data', train=True, download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_dataset = datasets.MNIST('data', train=False, transform=transform)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)  # no need to shuffle the test set
     
    # Instantiate the model, loss function and optimizer
    model = Net()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()
     
    # Train the model
    for epoch in range(epochs):
        for batch_idx, (data, target) in enumerate(train_loader): # the DataLoader yields batches automatically
            optimizer.zero_grad()   # standard training steps: zero grads, forward, loss, backward, step
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            if batch_idx % 100 == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss.item()))
     
    # Evaluate the model
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            test_loss += criterion(output, target).item()  # criterion returns the per-batch mean loss
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader)  # average over batches, since each term is already a batch mean
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset), 100. * correct / len(test_loader.dataset)))

The output is as follows:

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to data\MNIST\raw\train-images-idx3-ubyte.gz
100%|██████████| 9912422/9912422 [02:41<00:00, 61401.03it/s]
Extracting data\MNIST\raw\train-images-idx3-ubyte.gz to data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to data\MNIST\raw\train-labels-idx1-ubyte.gz
100%|██████████| 28881/28881 [00:00<00:00, 97971.03it/s]
Extracting data\MNIST\raw\train-labels-idx1-ubyte.gz to data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to data\MNIST\raw\t10k-images-idx3-ubyte.gz
100%|██████████| 1648877/1648877 [00:29<00:00, 56423.58it/s]
Extracting data\MNIST\raw\t10k-images-idx3-ubyte.gz to data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to data\MNIST\raw\t10k-labels-idx1-ubyte.gz
100%|██████████| 4542/4542 [00:00<00:00, 4339528.19it/s]
Extracting data\MNIST\raw\t10k-labels-idx1-ubyte.gz to data\MNIST\raw

Train Epoch: 0 [0/60000 (0%)]	Loss: 2.275243
Train Epoch: 0 [6400/60000 (0%)]	Loss: 0.200208
Train Epoch: 0 [12800/60000 (0%)]	Loss: 0.064670
Train Epoch: 0 [19200/60000 (0%)]	Loss: 0.066074
Train Epoch: 0 [25600/60000 (0%)]	Loss: 0.115960
Train Epoch: 0 [32000/60000 (0%)]	Loss: 0.171170
Train Epoch: 0 [38400/60000 (0%)]	Loss: 0.041663
Train Epoch: 0 [44800/60000 (0%)]	Loss: 0.179172
Train Epoch: 0 [51200/60000 (0%)]	Loss: 0.014898
Train Epoch: 0 [57600/60000 (0%)]	Loss: 0.035095
Train Epoch: 1 [0/60000 (0%)]	Loss: 0.016566
Train Epoch: 1 [6400/60000 (0%)]	Loss: 0.008371
Train Epoch: 1 [12800/60000 (0%)]	Loss: 0.006069
Train Epoch: 1 [19200/60000 (0%)]	Loss: 0.009995
Train Epoch: 1 [25600/60000 (0%)]	Loss: 0.020422
Train Epoch: 1 [32000/60000 (0%)]	Loss: 0.155348
Train Epoch: 1 [38400/60000 (0%)]	Loss: 0.059595
Train Epoch: 1 [44800/60000 (0%)]	Loss: 0.038654
Train Epoch: 1 [51200/60000 (0%)]	Loss: 0.084179
Train Epoch: 1 [57600/60000 (0%)]	Loss: 0.147250
Train Epoch: 2 [0/60000 (0%)]	Loss: 0.040161
Train Epoch: 2 [6400/60000 (0%)]	Loss: 0.147080
Train Epoch: 2 [12800/60000 (0%)]	Loss: 0.037228
Train Epoch: 2 [19200/60000 (0%)]	Loss: 0.257872
Train Epoch: 2 [25600/60000 (0%)]	Loss: 0.052811
Train Epoch: 2 [32000/60000 (0%)]	Loss: 0.005805
Train Epoch: 2 [38400/60000 (0%)]	Loss: 0.092318
Train Epoch: 2 [44800/60000 (0%)]	Loss: 0.084066
Train Epoch: 2 [51200/60000 (0%)]	Loss: 0.000331
Train Epoch: 2 [57600/60000 (0%)]	Loss: 0.011482
Train Epoch: 3 [0/60000 (0%)]	Loss: 0.042851
Train Epoch: 3 [6400/60000 (0%)]	Loss: 0.004001
Train Epoch: 3 [12800/60000 (0%)]	Loss: 0.008942
Train Epoch: 3 [19200/60000 (0%)]	Loss: 0.045065
Train Epoch: 3 [25600/60000 (0%)]	Loss: 0.099309
Train Epoch: 3 [32000/60000 (0%)]	Loss: 0.054098
Train Epoch: 3 [38400/60000 (0%)]	Loss: 0.059155
Train Epoch: 3 [44800/60000 (0%)]	Loss: 0.016098
Train Epoch: 3 [51200/60000 (0%)]	Loss: 0.114458
Train Epoch: 3 [57600/60000 (0%)]	Loss: 0.231477
Train Epoch: 4 [0/60000 (0%)]	Loss: 0.003781
Train Epoch: 4 [6400/60000 (0%)]	Loss: 0.068822
Train Epoch: 4 [12800/60000 (0%)]	Loss: 0.103501
Train Epoch: 4 [19200/60000 (0%)]	Loss: 0.002396
Train Epoch: 4 [25600/60000 (0%)]	Loss: 0.174503
Train Epoch: 4 [32000/60000 (0%)]	Loss: 0.027796
Train Epoch: 4 [38400/60000 (0%)]	Loss: 0.013167
Train Epoch: 4 [44800/60000 (0%)]	Loss: 0.011576
Train Epoch: 4 [51200/60000 (0%)]	Loss: 0.000726
Train Epoch: 4 [57600/60000 (0%)]	Loss: 0.069251
Train Epoch: 5 [0/60000 (0%)]	Loss: 0.006919
Train Epoch: 5 [6400/60000 (0%)]	Loss: 0.015165
Train Epoch: 5 [12800/60000 (0%)]	Loss: 0.117820
Train Epoch: 5 [19200/60000 (0%)]	Loss: 0.031030
Train Epoch: 5 [25600/60000 (0%)]	Loss: 0.031566
Train Epoch: 5 [32000/60000 (0%)]	Loss: 0.046268
Train Epoch: 5 [38400/60000 (0%)]	Loss: 0.055709
Train Epoch: 5 [44800/60000 (0%)]	Loss: 0.021299
Train Epoch: 5 [51200/60000 (0%)]	Loss: 0.004246
Train Epoch: 5 [57600/60000 (0%)]	Loss: 0.014340
Train Epoch: 6 [0/60000 (0%)]	Loss: 0.056358
Train Epoch: 6 [6400/60000 (0%)]	Loss: 0.104084
Train Epoch: 6 [12800/60000 (0%)]	Loss: 0.097005
Train Epoch: 6 [19200/60000 (0%)]	Loss: 0.009379
Train Epoch: 6 [25600/60000 (0%)]	Loss: 0.078417
Train Epoch: 6 [32000/60000 (0%)]	Loss: 0.217889
Train Epoch: 6 [38400/60000 (0%)]	Loss: 0.079795
Train Epoch: 6 [44800/60000 (0%)]	Loss: 0.052873
Train Epoch: 6 [51200/60000 (0%)]	Loss: 0.127716
Train Epoch: 6 [57600/60000 (0%)]	Loss: 0.087016
Train Epoch: 7 [0/60000 (0%)]	Loss: 0.045884
Train Epoch: 7 [6400/60000 (0%)]	Loss: 0.087923
Train Epoch: 7 [12800/60000 (0%)]	Loss: 0.164549
Train Epoch: 7 [19200/60000 (0%)]	Loss: 0.111163
Train Epoch: 7 [25600/60000 (0%)]	Loss: 0.300172
Train Epoch: 7 [32000/60000 (0%)]	Loss: 0.045357
Train Epoch: 7 [38400/60000 (0%)]	Loss: 0.087294
Train Epoch: 7 [44800/60000 (0%)]	Loss: 0.110581
Train Epoch: 7 [51200/60000 (0%)]	Loss: 0.001932
Train Epoch: 7 [57600/60000 (0%)]	Loss: 0.066714
Train Epoch: 8 [0/60000 (0%)]	Loss: 0.047415
Train Epoch: 8 [6400/60000 (0%)]	Loss: 0.106327
Train Epoch: 8 [12800/60000 (0%)]	Loss: 0.016832
Train Epoch: 8 [19200/60000 (0%)]	Loss: 0.013452
Train Epoch: 8 [25600/60000 (0%)]	Loss: 0.035256
Train Epoch: 8 [32000/60000 (0%)]	Loss: 0.026502
Train Epoch: 8 [38400/60000 (0%)]	Loss: 0.011809
Train Epoch: 8 [44800/60000 (0%)]	Loss: 0.171943
Train Epoch: 8 [51200/60000 (0%)]	Loss: 0.209570
Train Epoch: 8 [57600/60000 (0%)]	Loss: 0.047113
Train Epoch: 9 [0/60000 (0%)]	Loss: 0.126423
Train Epoch: 9 [6400/60000 (0%)]	Loss: 0.016720
Train Epoch: 9 [12800/60000 (0%)]	Loss: 0.210951
Train Epoch: 9 [19200/60000 (0%)]	Loss: 0.072410
Train Epoch: 9 [25600/60000 (0%)]	Loss: 0.042366
Train Epoch: 9 [32000/60000 (0%)]	Loss: 0.002912
Train Epoch: 9 [38400/60000 (0%)]	Loss: 0.074261
Train Epoch: 9 [44800/60000 (0%)]	Loss: 0.004673
Train Epoch: 9 [51200/60000 (0%)]	Loss: 0.074964
Train Epoch: 9 [57600/60000 (0%)]	Loss: 0.040360

Test set: Average loss: 0.0011, Accuracy: 9795/10000 (98%)

Partial walkthrough

The call below defines a 2D convolutional layer; its full signature is:

nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros')

For a more detailed explanation of the parameters, see this blog post: https://blog.csdn.net/qq_60245590/article/details/135856418
Baidu AI also gave an explanation of the parameters.
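As a quick sanity check on the shape comments in the forward pass above, here is a minimal sketch (the dummy input tensor is made up, but the layer sizes match the model in this post) showing how a 5x5 kernel with stride 1 and no padding shrinks a 28x28 image to 24x24, and how 2x2 max pooling halves it again:

import torch
import torch.nn as nn

conv1 = nn.Conv2d(1, 10, kernel_size=5)   # output size = 28 - 5 + 1 = 24
pool = nn.MaxPool2d(2, 2)                 # 2x2 pooling halves the spatial size

x = torch.randn(64, 1, 28, 28)            # dummy batch shaped like one MNIST batch
print(conv1(x).shape)                     # torch.Size([64, 10, 24, 24])
print(pool(conv1(x)).shape)               # torch.Size([64, 10, 12, 12])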
The training data is downloaded on the fly by Python. Opening the data folder, there is quite a bit inside; the main contents should be the training data and the test data. But in that case, why does the code download a train_dataset and a test_dataset separately? That puzzled me a bit.
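My later understanding, for what it is worth: MNIST itself is distributed as separate training and test files (the train-* and t10k-* archives in the download log above), and the train=True/False flag simply selects which split to load. A quick sketch to confirm the split sizes (it reuses the same 'data' folder, so nothing is downloaded twice):

from torchvision import datasets, transforms

transform = transforms.ToTensor()
train_dataset = datasets.MNIST('data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('data', train=False, transform=transform)

print(len(train_dataset), len(test_dataset))   # 60000 10000
img, label = train_dataset[0]
print(img.shape, label)                        # torch.Size([1, 28, 28]) 5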
We do not even have to batch the data ourselves? Does MindSpore have this feature too? Can data I build myself be batched automatically like this? If so, that would be very convenient (see the sketch below).
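To answer my own question on the PyTorch side: any object that implements the Dataset interface (__len__ and __getitem__) can be handed to DataLoader, which then handles batching and shuffling. A minimal sketch with made-up tensors (I have not checked the MindSpore equivalent):

import torch
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    def __init__(self, features, labels):
        self.features = features
        self.labels = labels

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        return self.features[idx], self.labels[idx]

# 1000 fake samples with 8 features each, random 0/1 labels
dataset = MyDataset(torch.randn(1000, 8), torch.randint(0, 2, (1000,)))
loader = DataLoader(dataset, batch_size=64, shuffle=True)

for x, y in loader:
    print(x.shape, y.shape)   # torch.Size([64, 8]) torch.Size([64]); the last batch may be smaller
    break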
Alright! Today is the Honkai: Star Rail version preview stream, so off to play some games!
