经典模型LeNet跑Fashion-MNIST 代码解析

news2024/9/23 11:25:18

测试6.6. 卷积神经网络(LeNet) — 动手学深度学习 2.0.0 documentation

 

import torch
from torch import nn
from d2l import torch as d2l

net = nn.Sequential(
    #输入通道1表示黑白 输出通道6表示6组取不同特征的卷积核 因为卷积核是5*5,原始图片单通道黑白28*28,看图 目的是输出6组28*28的特征图,故需要四条边都padding补上2条边 左右补4 上下补4,故卷积后可以达成28*28,之后用Sigmoid使结果非线性防止最终模型可以简化成一个线性模型   
    #用Sigmoid是因为relu当时还没有发明出来
    nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Sigmoid(),
    #池化平均.使28*28按kernel_size=2 -> 2*2的特征做平均为14*14大小的特征图,步长stride=2使得汇聚窗口不会互相重叠
    nn.AvgPool2d(kernel_size=2, stride=2),
    #以输入通道为6 输出通道为16模式,卷积核有16组的5*5做卷积运算,之后激活函数Sigmoid使之非线性
    nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),
    #继续/2平均后
    nn.AvgPool2d(kernel_size=2, stride=2),
    #此时是16组5*5的特征图,  Flatten拉成一条的向量16*5*5
    nn.Flatten(),
    #做以16*5*5做输入,120输出做一次线性 变成隐藏层1再nn.Sigmoid()非线性
    nn.Linear(16 * 5 * 5, 120), nn.Sigmoid(),
    #做以120做输入,84输出再做一次线性 变成隐藏层2再nn.Sigmoid()非线性
    nn.Linear(120, 84), nn.Sigmoid(),
    #最后以84输入,10输出做一次线性得到结果Y, 此时Y是10组不同分类中的特征值
    nn.Linear(84, 10))

总体来说就是
1.一张单通道28*28的图片,通过第一次卷积增大特征量->变成6组不同特征量的28*28的C1特征图(6@28*28),之后浓缩特征,进行池化平均->变成6组14*14的S2特征图(6@14*14)->以6组不同特征的图当做6个输入通道,以16组输出通道和5*5的特征和输出成16组10*10的C3特征图(16@10*10)->池化平均16@5*5的S4特征图
2.16@5*5的S4特征图开始做拉成一组大向量做线性的梯度下降,先拉成长度120做一次(120,84)的线性->再做(84,10)的线性->最后输出10种类别的特征值

开始测试 网络结构

X = torch.rand(size=(1, 1, 28, 28), dtype=torch.float32)
for layer in net:
    X = layer(X)
    print(layer.__class__.__name__,'output shape: \t',X.shape)
#测试每一层是否有错误

#
Conv2d output shape:         torch.Size([1, 6, 28, 28])
Sigmoid output shape:        torch.Size([1, 6, 28, 28])
AvgPool2d output shape:      torch.Size([1, 6, 14, 14])
Conv2d output shape:         torch.Size([1, 16, 10, 10])
Sigmoid output shape:        torch.Size([1, 16, 10, 10])
AvgPool2d output shape:      torch.Size([1, 16, 5, 5])
Flatten output shape:        torch.Size([1, 400])
Linear output shape:         torch.Size([1, 120])
Sigmoid output shape:        torch.Size([1, 120])
Linear output shape:         torch.Size([1, 84])
Sigmoid output shape:        torch.Size([1, 84])
Linear output shape:         torch.Size([1, 10])

读数据

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)

评估方法test 

def evaluate_accuracy_gpu(net, data_iter, device=None): #@save 评估方法test
    """使用GPU计算模型在数据集上的精度"""
    if isinstance(net, nn.Module):
        net.eval()  # 设置为评估模式
        if not device:#调用gpu跑
            device = next(iter(net.parameters())).device
    # 正确预测的数量,总预测的数量
    metric = d2l.Accumulator(2)#累加器用来放置数据 累加
    with torch.no_grad():
        for X, y in data_iter:
            if isinstance(X, list):
                # BERT微调所需的(之后将介绍)
                X = [x.to(device) for x in X]#X放gpu
            else:
                X = X.to(device)
            y = y.to(device)#y放gpu
            metric.add(d2l.accuracy(net(X), y), y.numel())#[0]放每次跑的预测准确的数,[1]放每次跑的总数
    return metric[0] / metric[1] #返回准确的数/总数 即准确率

训练函数 train

#@save
def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):
    """用GPU训练模型(在第六章定义)"""
    def init_weights(m):#定义init_weights函数用于初始化数据
        if type(m) == nn.Linear or type(m) == nn.Conv2d:
            nn.init.xavier_uniform_(m.weight)#可以初始化我们的权重 使之符合均值0,方差1,否则会出现梯度爆炸或者消失等严重问题 跑不出结果
    net.apply(init_weights)#选定自定义的参数初始化函数 
    print('training on', device)
    net.to(device)
    optimizer = torch.optim.SGD(net.parameters(), lr=lr)#优化函数用sgd,对net.parameters()进行梯度下降
    loss = nn.CrossEntropyLoss()#分类问题 采用softmax的交叉熵函数
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],
                            legend=['train loss', 'train acc', 'test acc'])#画图内容
    timer, num_batches = d2l.Timer(), len(train_iter)
    for epoch in range(num_epochs):
        # 训练损失之和,训练准确率之和,样本数
        metric = d2l.Accumulator(3)#3位置的累加器
        net.train()#将模型设置为训练模式:默认参数是Train。model.train()会启动drop 和 BN,但是model.train(False)不会
        for i, (X, y) in enumerate(train_iter): #enumerate() 函数用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据和数据下标,一般用在 for 循环当中。
            timer.start()
            optimizer.zero_grad()#对不同的train_iter中出来的X,y 每次清除梯度
            X, y = X.to(device), y.to(device)
            y_hat = net(X)
            l = loss(y_hat, y)
            l.backward()#反向传播计算各参数当前梯度比如y对w求梯度
            optimizer.step()#step()之后才会更新参数值 比如x=x−lr∗x.grad
            with torch.no_grad():#累加器加入数据的时候不算入梯度计算范围
                #加入的分别是 损失*批量大小, *这个看下面解析* ,批量大小
                metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])

            timer.stop()
            train_l = metric[0] / metric[2]#损失率
            train_acc = metric[1] / metric[2]#准确率
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:#画图
                animator.add(epoch + (i + 1) / num_batches,
                             (train_l, train_acc, None))
        test_acc = evaluate_accuracy_gpu(net, test_iter)#训练完每一个train_iter的数据后再拿完整的训练好的参数的net验证准确率
        animator.add(epoch + 1, (None, None, test_acc))
    print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, '
          f'test acc {test_acc:.3f}')
    print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec '
          f'on {str(device)}')

设置学习率开跑 

lr, num_epochs = 0.9, 10
train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())

附录

 交叉熵有点忘了看:

交叉熵损失函数(Cross Entropy Loss)_SongGu1996的博客-CSDN博客

1.解析该accuracy函数:

看不懂的时候 把函数从源文件复制出来调试 方法d2l.那里删了

def accuracy(y_hat, y):
    """Compute the number of correct predictions.

    Defined in :numref:`sec_softmax_scratch`"""
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
        y_hat = d2l.argmax(y_hat, axis=1)
    cmp = d2l.astype(y_hat, y.dtype) == y #强制换成某类型
    return float(d2l.reduce_sum(d2l.astype(cmp, y.dtype))) #d2l.astype(cmp, y.dtype)全bool转0,1 转为数字
    #reduce_sum 全部数加起来再转float

感觉有图能看懂了

 

 

2.打点调试  查看卷积核的值学习过程  #net[0].weight   net[3].weight

初始值:初始创建net的时候自动生成全部数据

初始化的时候init再打点

 跳入初始化

初始化跳回去之后发现

 成功初始化为均值0 方差1的卷积核了

再接着跑epoch的时候调试爆了 不知道是因为cpu还是内存不够爆的 重新取消了其他断点,只点了一个

重新的一个数据  与上面的不一样

  可以看出梯度更新卷积核的参数第一行第一个一看好像没变,看看第二行1198->1197 是变了

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/392503.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

面向对象设计模式:行为型模式之模板方法模式

一、模板方法引入:泡茶与冲咖啡 泡茶 烧水泡茶倒入杯子加入柠檬 冲咖啡 烧水冲咖啡倒入杯子加入牛奶和糖 二、模板方法,TemplateMethod 2.1 Intent 意图 Define the skeleton of an algorithm in an operation, deferring some steps to lets subclas…

【深度学习】BERT变体—BERT-wwm

1.BERT-wwm 1-1 Whole Word Masking Whole Word Masking (wwm)是谷歌在2019年5月31日发布的一项BERT的升级版本,主要更改了原预训练阶段的训练样本生成策略。 原有基于WordPiece的分词方式会把一个完整的词切分成若干个子词,在生成训练样本时&#xff…

路由传参含对象数据刷新页面数据丢失

目录 一、问题描述 二、 解决办法 一、问题描述 【1】众所周知,在veu项目开发过程中,我们常常会用到通过路由的方式在页面中传递数据。但是用到this.$route.query.ObjectData的页面,刷新后会导致this.$route.query.ObjectData数据丢失。 …

(小甲鱼python)函数笔记合集七 函数(IX)总结 python实现汉诺塔详解

一、基础复习 函数的基本用法 创建和调用函数 函数的形参与实参等等函数的几种参数 位置参数、关键字参数、默认参数等函数的收集参数*args **args 解包参数详解函数中参数的作用域 局部作用域 全局作用域 global语句 嵌套函数 nonlocal语句等详解函数的闭包(工厂函…

【LeetCode每日一题】——1323.6 和 9 组成的最大数字

文章目录一【题目类别】二【题目难度】三【题目编号】四【题目描述】五【题目示例】六【解题思路】七【题目提示】八【时间频度】九【代码实现】十【提交结果】一【题目类别】 贪心算法 二【题目难度】 简单 三【题目编号】 1323.6 和 9 组成的最大数字 四【题目描述】 …

【mediasoup】RtpStreamRecv 对rtp 序号的验证

mediasoup 接收到rtp包D:\XTRANS\soup\mediasoup_offical\worker\src\RTC\RtpStreamRecv.cpp代码竟然跟 https://tools.ietf.org/html/rfc3550#appendix-A.1 stuff. 一样的。RtpStreamRecv的 ReceivePacket(RTC::RtpPacket* packet) 处理收到的rtp包 可能会丢弃 判断丢帧 回卷后…

项目团队沟通管理 5大沟通原则

1、沟通内外有别 沟通需要区分团队内和团队外,在团队对外进行沟通时,团队作为一个整体,对外意见需要一致,一个团队需用一种声音说话。 沟通管理5大原则:沟通内外有别​ 2、重视非正式沟通 非正式的沟通有助于关…

FUNIT

无监督图像到图像转换方法学习将给定类中的图像映射到不同类中的类似图像,使用非结构化(非注册)图像数据集。虽然非常成功,但目前的方法需要在训练时访问源类和目标类中的许多图像。我们认为这极大地限制了它们的使用。从人类从少量示例中提取新对象的本…

用报废耳机自制助听器

平常戴着跑步的外挂耳机被洗衣机洗了,不是进水马上捞起来的那种洗,就是全机包括充电盒都在洗衣机里走了一遍洗衣的流程,在晾晒衣服时才发现衣兜里的这付耳机,再进行啥挽救处理都已为时过晚,好在这付耳机并不贵&#xf…

HBuilderX无线连接真机

说明 安装的是HBuilderX,不是HBuilder,adb.exe所在目录是 x:\HBuilderX\plugins\launcher\tools\adbs\ 里面可能有其他版本,用哪个都,建议使用最新的 配置 首先,将真机使用USB连接到电脑上。 在adb目录中启动命令…

iOS设备管理器有人推荐iTunes,有人推荐iMazing,到底如何选择

一说到iTunes软件,想必苹果用户都不会感觉陌生,它为我们在iPhone、iPad等iOS设备和电脑之间进行文件传输提供了便利,但它并没有那么好用,有时甚至让人抓狂。那我们今天就来分享一款可以取代iTunes的良心好软——iMazing&#xff0…

u盘扫描并修复后文件消失了怎么办?2种方法帮助找回

演示机型:技嘉 H310M HD22.0系统版本:Windows 10 专业版软件版本:云骑士数据恢复软件3.21.0.17案例分享:“我的u盘每次插电脑都会弹出要不要扫描并修复的提示窗口,不懂,然后不小心选择了“扫描并修复”&…

并发编程——CAS

如果有兴趣了解更多相关内容的话,可以来我的个人网站看看:耶瞳空间 一:前言 首先看一个案例:我们开发一个网站,需要对访问量进行统计,用户每发送一次请求,访问量1,如何实现&#x…

前端都在聊什么 - 第 4 期

Hello 小伙伴们早上、中午、下午、晚上、深夜好,我是爱折腾的 jsliang~「前端都在聊什么」是 jsliang 日常写文章/做视频/玩直播过程中,小伙伴们的提问以及我的解疑整理。本文章视频同步:TODO:本期对应 2023.01.28 当天直播间的粉丝互动。主要…

关于Scipy的概念和使用方法及实战

关于scipy的概念和使用方法 什么是Scipy Scipy是一个基于Python的科学计算库,它提供了许多用于数学、科学、工程和技术计算的工具和函数。Scipy的名称是“Scientific Python”的缩写。 Scipy包含了许多子模块,其中一些主要的子模块包括: …

eBPF(内核态)和WebAssembly

1 什么是eBPF 无需修改内核,也不用加载内核模块,程序员就可以在内核中执行执行自定义的字节码。 eBPF,它的全称是“Extended Berkeley Packet Filter”, 网络数据包过滤模块。我们很熟悉的 tcpdump 工具,它就是利用了…

Bellman-ford和SPFA算法

目录 一、前言 二、Bellman-ford算法 1、算法思想 2、算法复杂度 3、判断负圈 4、出差(2022第十三届国赛,lanqiaoOJ题号2194) 三、SPFA算法:改进的Bellman-Ford 1、随机数据下的最短路问题(lanqiaoOJ题号1366&…

xss靶场绕过

目录 第一关 原理 payload 第二关 原理 payload 第三关 原理 payload 第四关 原理 payload 第五关 原理 payload 第六关 原理 payload 第七关 原理 payload 第八关 原理 payload 第九关 原理 payload 第十关 原理 payload 第十一关 原理 payl…

Ubuntu 虚拟机 安装nvidia驱动失败,进不了系统

VMware 安装的 Ubuntu 1804 安装 英伟达显卡失败后,启动出现:在上面那个页面,直接使用组合键:Ctrl Alt F3 便可以进入命令行模式。如果可以成功进入,则说明ubantu系统确实起来了,只是界面相关的模块没有成…

Win32api学习之常见编码格式(一)

ASCII编码 ASCII编码是一种最早出现的字符编码方案,它是由美国标准化协会(ASA)于1963年制定的标准,用于在计算机系统中表示英语文本字符集。ASCII编码仅使用7位二进制数(共128个),用于表示英文…