37.卷积神经网络(LeNet)的代码实现(在colab上)

news2024/11/24 19:15:04

ps:在教材上直接打开colab,运行原来的代码!pip install git+https://github.com/d2l-ai/d2l-zh@release # installing d2l是会报错的,改成这句代码,可以正确运行:!pip install d2l==0.14.,直接制定了d2l的版本

1. LeNet

总体来看,LeNet(LeNet-5)由两个部分组成:

  • 卷积编码器:由两个卷积层组成;

  • 全连接层密集块:由三个全连接层组成。

实例化一个Sequential块并将需要的层连接在一起。

import torch
from torch import nn
from d2l import torch as d2l

class Reshape(torch.nn.Module): # 这个函数是给输入用的
    def forward(self,x):
        # view函数是为了改变tensor形状为(batchsize,channels,x,y)
        return x.view(-1,1,28,28) # 28*28是输入图片大小
    
net = torch.nn.Sequential(
    # torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)
    Reshape(),nn.Conv2d(1,6,kernel_size=5,padding=2),nn.Sigmoid(), # 为了得到非线性,要在卷积后面加上sigmoid激活函数
    nn.AvgPool2d(kernel_size=2,stride=2), # 也可以写成nn.AvgPool2d(2,stride=2),第一个参数不用写参数名
    nn.Conv2d(6,16,kernel_size=5),nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2,stride=2),nn.Flatten(),# Flatten()将第一维(批量)保持住,其他展平为一个维度,输入到多层感知机
    nn.Linear(16 * 5 * 5 , 120),nn.Sigmoid(), # 线性层,输入时400,输出是120,用sigmoid激活一下
    nn.Linear(120, 84),nn.Sigmoid(), # 再把120 降到 84
    nn.Linear(84,10)) # 输入为 84, 输出为10

可以看到,最后是一个3层的,有两个隐藏层的多层感知机,前面是两个卷积层,每个卷积层后面有一个激活层和一个池化层。

2. 检查模型

检查模型,以确保其操作与我们期望的 图一致。

在这里插入图片描述

X = torch.rand(size=(1,1,28,28), dtype = torch.float32)
for layer in net:
    X = layer(X)
    print(layer.__class__.__name__,'output shape:\t' , X.shape)

运行结果如下:
在这里插入图片描述

在这里,卷积的作用就是把每一层输出变小,通道变多。每一个通道信息可以认为是一个空间的pattern(模式),不断地把空间信息压缩变小,通道数变多,可以把抽出来的压缩的信息放在不同的通道里面,最后MLP就把所有的模式拿出来然后训练到最后的输出。

3. 模型训练

现在我们已经实现了LeNet,让我们看看LeNet在Fashion-MNIST数据集上的表现。

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)

虽然卷积神经网络的参数较少,但与深度的多层感知机相比,它们的计算成本仍然很高,因为每个参数都参与更多的乘法。 通过使用GPU,可以用它加快训练。

# 如果模型已经在gpu上了,计算精度会在gpu上做
def evaluate_accuracy_gpu(net, data_iter, device=None): 
    """使用GPU计算模型在数据集上的精度"""
    # isinstance(object, classinfo):object – 实例对象,classinfo – 可以是直接或间接类名、基本类型或者由它们组成的元组
    # 返回值:如果对象的类型与参数二的类型(classinfo)相同则返回 True,否则返回 False。
    if isinstance(net, nn.Module):
      # 如果net是 troch.nn实现的版本,与之相对的是手写的版本
        net.eval()  # 设置为评估模式,不用计算和更新梯度,eval()模式与之相对的是train()
        if not device: # 如果没有设备
            device = next(iter(net.parameters())).device # 设置为网络层所在的设备
    # 正确预测的数量,总预测的数量
    metric = d2l.Accumulator(2)
    with torch.no_grad():
        for X, y in data_iter:
            if isinstance(X, list): # 如果X是一个list
                # BERT微调所需的(之后将介绍)
                X = [x.to(device) for x in X] # 把每一个x都挪到device中
            else:
                X = X.to(device) # 如果不是list,挪一次就够了
            y = y.to(device) # 把y也挪到device上
            metric.add(d2l.accuracy(net(X), y), y.numel())
    return metric[0] / metric[1] # 分类正确的个数/总个数

为了使用GPU,我们还需要一点小改动,在进行正向和反向传播之前,我们需要将每一小批量数据移动到我们指定的设备(例如GPU)上

如下所示,训练函数train_ch6也类似于之前定义的train_ch3。 由于我们将实现多层神经网络,因此我们将主要使用高级API。 以下训练函数假定从高级API创建的模型作为输入,并进行相应的优化。 我们使用Xavier随机初始化模型参数。 与全连接层一样,我们使用交叉熵损失函数小批量随机梯度下降

def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):
    """用GPU训练模型(在第六章定义)"""
    def init_weights(m): # 初始化权重
        if type(m) == nn.Linear or type(m) == nn.Conv2d: # 如果是全连接层或者卷积层
            nn.init.xavier_uniform_(m.weight) # 使用Xavier初始化:nn.init.xavier_uniform_
    net.apply(init_weights) # 使net的每一层都应用一下初始化权重函数
    print('training on', device) # 打印一下在哪个device上训练
    net.to(device) # 把net挪到device上
    optimizer = torch.optim.SGD(net.parameters(), lr=lr) # 使用了 SGD优化器
    loss = nn.CrossEntropyLoss() # 使用交叉熵损失函数
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],
                            legend=['train loss', 'train acc', 'test acc'])# animator动画效果
    timer, num_batches = d2l.Timer(), len(train_iter) # d2l.Timer()是一个计时器
    for epoch in range(num_epochs): # 对每一次数据做迭代
        # 训练损失之和,训练准确率之和,样本数
        metric = d2l.Accumulator(3)
        net.train() # 训练模式
        # enumerate函数用来遍历一个集合对象,它在遍历的同时还可以得到当前元素的索引位置
        for i, (X, y) in enumerate(train_iter): # 每一次迭代都拿到索引i,以及 一个batch的X和y
            timer.start() # 开始训练
            optimizer.zero_grad() # 优化器梯度清零
            X, y = X.to(device), y.to(device) #把输入X和输出y挪到gpu上
            y_hat = net(X) # 得到X的预测值y_hat
            l = loss(y_hat, y) # 计算损失:预测值和真实值之间的差距
            l.backward() # 反向传播计算得到每个参数的梯度值
            optimizer.step() # 通过梯度下降执行一步参数更新
            with torch.no_grad(): # l * X.shape[0] 样本数乘以平均损失=总的样本损失
                metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])
            timer.stop() # 一个批量的训练结束
            train_l = metric[0] / metric[2] # 训练的损失
            train_acc = metric[1] / metric[2] # 训练的精度
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,
                             (train_l, train_acc, None)) # 绘制每5次迭代的训练损失和训练精度
        test_acc = evaluate_accuracy_gpu(net, test_iter) # 测试精度
        animator.add(epoch + 1, (None, None, test_acc)) # 绘制每一次迭代的测试精度
    print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, '
          f'test acc {test_acc:.3f}')
    print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec ' 
          f'on {str(device)}') # metric[2] * num_epochs:样本数,timer.sum():训练完成需要的时间

现在,我们[训练和评估LeNet-5模型]

lr, num_epochs = 0.9, 10
train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())

在这里插入图片描述

4. 小结

  • 卷积神经网络(CNN)是一类使用卷积层的网络。
  • 在卷积神经网络中,我们组合使用卷积层、非线性激活函数和汇聚层。
  • 为了构造高性能的卷积神经网络,我们通常对卷积层进行排列,逐渐降低其表示的空间分辨率,同时增加通道数。
  • 在传统的卷积神经网络中,卷积块编码得到的表征在输出之前需由一个或多个全连接层进行处理。
  • LeNet是最早发布的卷积神经网络之一。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/115410.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

利用Bat打开exe程序并传入值

目录 一、分清楚exe接收值的方式 1、打开exe时提示输入1、2、3... 2、知道exe形参(程序主函数中定义的argv[]) 二、call和start的区别 一、分清楚exe接收值的方式 1、打开exe时提示输入1、2、3... 如图: 这种是程序运行时接收用户输入…

SuperMap GIS 三维硬件设置优化

目录一、简介二、查看硬件显卡三、显卡设置1、NVIDA显卡设置2、AMD显卡设置一、简介 我们都知道为了体验更好的大型3D游戏,那么好的显卡是必不可少的。但是有了好的显卡当配置不对时,此时体验感也会大打折扣。同样的道理,在SuperMap中也需要…

Redis原理篇—通信协议

Redis原理篇—通信协议 笔记整理自 b站_黑马程序员Redis入门到实战教程 RESP协议 Redis 是一个 CS 架构的软件,通信一般分两步(不包括 pipeline 和 PubSub): 客户端(client)向服务端(server&a…

VMC证书是什么?获取认证标志证书步骤是怎样的?

VMC证书是什么? VMC证书(全称:Verified Mark Certificate),也称认证标志证书,是由权威CA机构颁发,用于验证商标所有权的数字证书。 VMC 通过提供验证机制与 BIMI 协同工作,BIMI标准可以在电子邮件中的“发…

OSCS开源安全周报第23期:Foxit PDF Reader/Editor 任意代码执行漏洞

本周安全态势综述 OSCS 社区共收录安全漏洞10个&#xff0c;其中公开漏洞值得关注的是 Apache Airflow Hive Provider <5.0.0 存在操作系统命令注入漏洞&#xff08;CVE-2022-46421&#xff09;vm2 < 3.9.10 存在任意代码执行漏洞&#xff08;CVE-2022-25893&#xff0…

湖南软件工程自考本科总结

本人情况 在湖南长沙考试&#xff0c;从2021年初开始备考&#xff0c;社会考生&#xff0c;自己复习&#xff0c;从2021-4月到2022-10月&#xff0c;理论每次都考了4门课程&#xff0c;前3次每次挂了1门课程&#xff0c;刚刚好在4次考试完成了所有的理论考试。 经验分享 复习重…

2022 re:Invent 凌云驭势 重塑未来

2022年11月29日&#xff0c;一年一度的亚马逊 re:Invent全球大会在拉斯维加斯再度上演&#xff0c;这是亚马逊云科技第11年举办re:Invent&#xff0c;来自全球的5万多客户和合作伙伴参加了此次线下盛会&#xff0c;还有超过30万人线上参会。在此次大会上&#xff0c;亚马逊云科…

喜报 | 云畅科技再次入榜湖南省互联网企业50强

12月21日&#xff0c;湖南省互联网协会在国家网络安全产业园区&#xff08;长沙&#xff09;发布了2022年湖南省互联网企业综合实力30强榜单、互联网成长型企业10强榜单、互联网创新型企业10强榜单和《2022年湖南省互联网企业50强发展报告》。 湖南云畅网络科技有限公司&#x…

burpsuite靶场——CSRF

文章目录什么是CSRF&#xff1f;CSRF 攻击的影响是什么&#xff1f;CSRF 是如何工作的&#xff1f;没有防御的 CSRF 漏洞常见的 CSRF 漏洞Token验证取决于请求方法Token的验证取决于Token是否存在CSRF Token未绑定到用户会话Token未与会话 cookie绑定什么是CSRF&#xff1f; 跨…

【HTML】动画合集--跟着pink老师学习

1.奔跑小熊 奔跑小熊<!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta http-equiv"X-UA-Compatible" content"IEedge"><meta name"viewport" content"widthdevice…

【免杀前置课——shellcode】二十、初识shellcode——配合栈溢出漏洞利用shellcode在代码中返回MessageBox函数

初识shellcode栈溢出漏洞反弹shellcodeshellcode取出shellcode栈溢出漏洞反弹shellcode shellcode shellcode&#xff1a; shellcode是一段用于利用软件漏洞而执行的代码&#xff0c;shellcode为16进制的机器码&#xff0c;因为经常让攻击者获得shell而得名。.shellcode常常使…

自动控制原理笔记-线性系统的稳定性分析

目录 稳定的概念及定义&#xff1a; 系统稳定的充要条件——闭环极点全部落在虚轴左边&#xff1a; 系统的稳定性判据&#xff1a; 劳斯判据(充要性)判据&#xff1a; 劳斯表特殊情况例(出现计算过程分母为0)&#xff1a; 劳斯表特殊情况例(出现全0行)&#xff1a; 稳定的…

【Java】花费数十小时,带你体验Java文档搜索引擎的实现过程

Java文档搜索引擎项目运行效果一、简述搜索引擎概念二、搜索引擎实现思路2.1倒排索引介绍2.2项目目标2.3获取java文档2.4模块划分2.5创建项目2.6认识分词2.7分词的原理2.8使用第三方分词库三、实现索引模块-parser类3.1 实现索引模块-递归枚举文件3.2 排除非HTML文件3.3 实现索…

旁瓣消隐技术在雷达中应用

电子对抗在现代战争中的作用日趋重要&#xff0c;没有雷达抗干扰技术的雷达完全失去其发现测定敌人目标的功能。从降低天线旁瓣干扰方面考虑&#xff0c;雷达抗干扰技术主要包括旁瓣对消技术和旁瓣消隐技术&#xff0c;旁瓣对消器在有一个辅助天线的情况下抑制一个干扰源的效果…

正式入职开发工程师工作近半年有感

一、前言 博主是毕业于集美大学的一枚软件工程本科生&#xff0c;不知不觉已经毕业近半年了&#xff0c;由于工作繁忙 个人的懒惰&#xff0c;对CSDN的博客记录频率已经大不如之前。说起这里也是惭愧&#xff0c;之后博主会尽量抽出时间&#xff0c;继续保持各种学习&#xf…

代码随想录算法训练营第43天 | 1049. 最后一块石头的重量 II 494. 目标和 474.一和零

一、Leetcode 1049. 最后一块石头的重量 II 这几个题都很不好给转成01问题。本题一开始我以为怎么撞都行&#xff0c;其实不是&#xff0c;相当于给每项前面加1&#xff0c; 就是说有时候不能浪费小石头&#xff0c;得跟大石头碰。 那么问题就很明显了&#xff0c;类似于分割…

AC自动机

AC自动机 AC自动机是干嘛的&#xff1f; 我有一个敏感词数组&#xff0c;里面装的是所有的敏感词&#xff0c;还有一篇大文章&#xff0c;我要求出大文章里面所有的敏感词。 敏感词数组本身的组织是一颗前缀树。 AC自动机就是在前缀树的基础上做升级。 流程 我们在前缀树的…

已来到 “后云原生时代” 的我们,如何规模化运维?

文&#xff5c;李大元 &#xff08;花名&#xff1a;达远&#xff09; Kusion 项目负责人 来自蚂蚁集团 PaaS 核心团队&#xff0c;PaaS IaC 基础平台负责人。 本文 4331 字 阅读 11 分钟 PART. 1 后云原生时代 距离 Kubernetes 第一个 commit 已经过去八年多了&#xff0c…

chrome extensions mv3与mv2比较 执行eval

文章目录背景1、mv3版本与mv2版本之间的一些区别2、解决mv3版本DOM交互 & JS执行问题2.1、关于引入eval52.2、关于在background.js执行script脚本3、background执行fetch调用URL参考背景 老的扩展项目使用的是mv2版本的API&#xff0c;计划升级mv3版本的时候遇到了下面的问…

MySQL索引为什么使用B+树,而不用二叉树、红黑树、哈希表、B树?

索引是帮助MySQL高效获取数据的排好序的数据结构。 索引数据结构&#xff1a; 1.二叉树 2.红黑树 3.Hash表 4.B-Tree 1. 二叉查找树&#xff08;Binary Search Trees&#xff09; 左节点比父节点要小&#xff0c;右节点比父节点要大。它的高度决定的查找效率。 如果某一列数…