动手学深度学习(pytorch)学习记录26-卷积神经网路(LeNet)[学习记录]

news2024/11/10 22:47:54

目录

  • LeNet
  • 模型训练

LeNet

总体来看,LeNet(LeNet-5)由两个部分组成:
卷积编码器:由两个卷积层组成;
全连接层密集块:由三个全连接层组成。
图片来源https://baike.baidu.com/item/LeNet-5/61427772?fr=ge_ala每个卷积块中的基本单元是一个卷积层、一个sigmoid激活函数和平均汇聚层。请注意,虽然ReLU和最大汇聚层更有效,但它们在20世纪90年代还没有出现。每个卷积层使用5×5卷积核和一个sigmoid激活函数。这些层将输入映射到多个二维特征输出,通常同时增加通道的数量。第一卷积层有6个输出通道,而第二个卷积层有16个输出通道。每个2×2池操作(步幅2)通过空间下采样将维数减少4倍。卷积的输出形状由批量大小、通道数、高度、宽度决定。

为了将卷积块的输出传递给稠密块,必须在小批量中展平每个样本。换言之,我们将这个四维输入转换成全连接层所期望的二维输入。这里的二维表示的第一个维度索引小批量中的样本,第二个维度给出每个样本的平面向量表示。LeNet的稠密块有三个全连接层,分别有120、84和10个输出。因为我们在执行分类任务,所以输出层的10维对应于最后输出结果的数量。

import torch
from torch import nn
from d2l import torch as d2l
net = nn.Sequential(
    nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Flatten(),
    nn.Linear(16 * 5 * 5, 120), nn.Sigmoid(),
    nn.Linear(120, 84), nn.Sigmoid(),
    nn.Linear(84, 10))

对原始模型做了一点小改动,去掉了最后一层的高斯激活。除此之外,这个网络与最初的LeNet-5一致
将一个大小为28×28的单通道(黑白)图像通过LeNet。通过在每一层打印输出的形状,可以检查模型。

X = torch.rand(size=(1, 1, 28, 28), dtype=torch.float32)
for layer in net:
    X = layer(X)
    print(layer.__class__.__name__,'output shape: \t',X.shape)
Conv2d output shape: 	 torch.Size([1, 6, 28, 28])
Sigmoid output shape: 	 torch.Size([1, 6, 28, 28])
AvgPool2d output shape: 	 torch.Size([1, 6, 14, 14])
Conv2d output shape: 	 torch.Size([1, 16, 10, 10])
Sigmoid output shape: 	 torch.Size([1, 16, 10, 10])
AvgPool2d output shape: 	 torch.Size([1, 16, 5, 5])
Flatten output shape: 	 torch.Size([1, 400])
Linear output shape: 	 torch.Size([1, 120])
Sigmoid output shape: 	 torch.Size([1, 120])
Linear output shape: 	 torch.Size([1, 84])
Sigmoid output shape: 	 torch.Size([1, 84])
Linear output shape: 	 torch.Size([1, 10])

模型训练

看看LeNet在Fashion-MNIST数据集上的表现

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)

虽然卷积神经网络的参数较少,但与深度的多层感知机相比,它们的计算成本仍然很高,因为每个参数都参与更多的乘法。 通过使用GPU,可以用它加快训练。

def evaluate_accuracy_gpu(net, data_iter, device=None): #@save
    """使用GPU计算模型在数据集上的精度"""
    if isinstance(net, nn.Module):
        net.eval()  # 设置为评估模式
        if not device:
            device = next(iter(net.parameters())).device
    # 正确预测的数量,总预测的数量
    metric = d2l.Accumulator(2)
    with torch.no_grad():
        for X, y in data_iter:
            if isinstance(X, list):
                # BERT微调所需的(之后将介绍)
                X = [x.to(device) for x in X]
            else:
                X = X.to(device)
            y = y.to(device)
            metric.add(d2l.accuracy(net(X), y), y.numel())
    return metric[0] / metric[1]

为了使用GPU,还需要一点小改动。 与学习记录10节中定义的train_epoch_ch3不同,在进行正向和反向传播之前,需要将每一小批量数据移动到我们指定的设备(例如GPU)上。

如下所示,训练函数train_ch6也类似于 学习记录10中定义的train_ch3。 主要使用高级API实现多层神经网络。 以下训练函数假定从高级API创建的模型作为输入,并进行相应的优化。 我们使用Xavier随机初始化模型参数。 与全连接层一样,使用交叉熵损失函数和小批量随机梯度下降。

def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):
    """用GPU训练模型(在第六章定义)"""
    def init_weights(m):
        if type(m) == nn.Linear or type(m) == nn.Conv2d:
            nn.init.xavier_uniform_(m.weight)
    net.apply(init_weights)
    print('training on', device)
    net.to(device)
    optimizer = torch.optim.SGD(net.parameters(), lr=lr)
    loss = nn.CrossEntropyLoss()
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],
                            legend=['train loss', 'train acc', 'test acc'])
    timer, num_batches = d2l.Timer(), len(train_iter)
    for epoch in range(num_epochs):
        # 训练损失之和,训练准确率之和,样本数
        metric = d2l.Accumulator(3)
        net.train()
        for i, (X, y) in enumerate(train_iter):
            timer.start()
            optimizer.zero_grad()
            X, y = X.to(device), y.to(device)
            y_hat = net(X)
            l = loss(y_hat, y)
            l.backward()
            optimizer.step()
            with torch.no_grad():
                metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])
            timer.stop()
            train_l = metric[0] / metric[2]
            train_acc = metric[1] / metric[2]
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,
                             (train_l, train_acc, None))
        test_acc = evaluate_accuracy_gpu(net, test_iter)
        animator.add(epoch + 1, (None, None, test_acc))
    print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, '
          f'test acc {test_acc:.3f}')
    print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec '
          f'on {str(device)}')

训练和评估LeNet-5模型。

lr, num_epochs = 0.9, 10
train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())

在这里插入图片描述

封面图片来源
欢迎点击我的主页查看更多文章。
本人学习地址https://zh-v2.d2l.ai/
恳请大佬批评指正。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2118976.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Docker 清理和查看镜像与容器占用情况

查看容器占用磁盘大小 docker system df 查看单个image、container大小: docker system df -v 清理所有废弃镜像与Build Cache docker system prune -a

【解决内存泄漏的问题】 Qt 框架中的父子对象关系会自动管理内存,父对象会在其销毁时自动销毁所有子对象。

修改前的代码 这段代码可能会出现内存泄漏问题,主要原因是构造函数中创建的 LoginDialog 和 RegisterDialog 对象未在合适的地方被正确释放。具体分析如下: 1. 构造函数中的问题 _login_dlg new LoginDialog(); setCentralWidget(_login_dlg); _login…

【北京迅为】《STM32MP157开发板使用手册》- 第十二章 编译Linux内核

iTOP-STM32MP157开发板采用ST推出的双核cortex-A7单核cortex-M4异构处理器,既可用Linux、又可以用于STM32单片机开发。开发板采用核心板底板结构,主频650M、1G内存、8G存储,核心板采用工业级板对板连接器,高可靠,牢固耐…

运算放大器中的反馈

运算放大器中的反馈:原理、类型与应用 运算放大器(Operational Amplifier, 简称Op-Amp)是现代电子电路中的重要组成部分,被广泛应用于信号处理、放大、滤波等场合。而反馈技术则是运算放大器电路的核心之一,直接影响其…

代码随想录算法训练营第二十二天| 491. 递增子序列、46. 全排列、47. 全排列Ⅱ

今日内容 Leetcode. 491 递增子序列Leetcode. 46 全排列Leetcode. 47 全排列Ⅱ Leetcode. 491 递增子序列 文章链接:代码随想录 (programmercarl.com) 题目链接:491. 非递减子序列 - 力扣(LeetCode) 本题也是一个子集问题&#…

【AI绘画】Midjourney后置指令--seed、--tile、--q、--chaos、--w、--no详解

博客主页: [小ᶻZ࿆] 本文专栏: AI绘画 | Midjourney 文章目录 💯前言💯Midjourney后置指令--seed测试1测试2如何获取未指定种子图片的随机种子注意点 💯Midjourney后置指令--tile测试 💯Midjourney后置指令--q(or-…

Retrieval-based-Voice-Conversion-WebUI模型构建指南

一、模型介绍 Retrieval-based-Voice-Conversion-WebUI(简称 RVC)模型是一个基于 VITS(Variational Inference with adversarial learning for end-to-end Text-to-Speech)的简单易用的语音转换框架。 具有以下特点 简单易用&a…

chrome浏览器如何设置自动播放音视频

使用场景: 有些场景需要打开页面后,自动播放视频或者视频流,这时候发现无法播放,打开浏览器控制台发现有下面的错误提示:NotAllowedError: play() failed because the user didnt interact with the document first 。…

顶级出图效果!免费在线使用FLux.1 模型,5s出图无限制!

最近发现一个可以在线免费使用 FLux.1 模型 生成图片的AI工具。 先看效果图: 工具不需要登录即可使用,目前还是完全免费的,国内可以直接使用。 在提示词输入框直接输入提示词即可,选择图片比例之后,直接生图。 出图的…

安全运营之浅谈SIEM告警疲劳

闲谈: 刚开始学习SIEM、态势感知这类产品的时,翻阅老外们的文章总是谈什么真阳性,假阳性告警、告警疲劳,当时在国内资料中没找到很合理的解释,慢慢就淡忘这件事了。随着慢慢深入工作,感觉大概理解了这些概念…

‌技术人必看!如何科学规划,从需求出发打造完美技术方案

引言 在互联网架构师的角色中,我们面临的挑战不仅仅是编写代码,更重要的是深入理解需求、设计系统,并确保我们的解决方案能够稳定、高效地运行。本文将详细介绍从新需求提出到技术方案发布的全过程。 1. 理解现有需求和场景 在开始一个新的…

信息学奥赛初赛天天练-87-NOIP2014普及组-完善程序-矩阵、子矩阵、最大子矩阵和、前缀和、打擂台求最大值

1 完善程序 最大子矩阵和 给出 m行 n列的整数矩阵,求最大的子矩阵和(子矩阵不能为空)。 输入第一行包含两个整数 m和 n,即矩阵的行数和列数。之后 m行,每行 n个整数,描述整个矩阵。程序最终输出最大的子矩阵和。 (最…

SAP中mmpv自动过账—附带源码

想省事儿的直接拖到后面查看代码 思路分析 实现逻辑:初版 前台测试:选择屏幕确认公司代码。必要情况手动开账勾选前台执行按钮 1.1去marv表找公司代码的当前账期,简单运算获取下一个账期。1.2执行bdc,模拟前台手动开账期1.3执行的必要信息存日志表。例:修改人(开账期的人…

FastAPI 进阶:使用 BackgroundTasks 处理长时间运行的任务

在 FastAPI 中,BackgroundTasks 是一个功能,它允许你在发送响应给客户端之后执行后台任务。这些任务对于不需要客户端等待的操作非常有用,比如发送电子邮件通知或处理数据。然而,当服务器重启时,由于 BackgroundTasks …

C++: set与map容器的介绍与使用

本文索引 前言1. 二叉搜索树1.1 概念1.2 二叉搜索树操作1.2.1 查找与插入1.2.2 删除1.2.3 二叉搜索树实现代码 2. 树形结构的关联式容器2.1 set的介绍与使用2.1.1 set的构造函数2.1.2 set的迭代器2.1.3 set的容量2.1.4 set的修改操作 2.2 map的介绍与使用2.2.1 map的构造函数2.…

基于python的mediapipe姿态识别 动作识别 人体关健点 实现跳绳状态判别 计数功能

基于Python的MediaPipe姿态识别实现跳绳状态判别与计数功能 项目概述 本项目旨在利用Google的MediaPipe库,结合姿态识别技术,实现对跳绳动作的实时检测与计数功能。通过识别人体关键点,系统能够准确判断跳绳动作的状态,并实时统…

Java入门:07.Java中的面向对象03

11 this关键字 this关键字有两个作用 第一个作用,用来调用重载的构造方法 public class Test3{public static void main(String[] args){new User();new User("ls");new User("ls","女");} } ​ class User{String name ;String sex…

Autosar工程师必读:ETAS工具链自动化实战指南<三>

----自动化不仅是一种技术,更是一种思维方式,它将帮助我们在快节奏的工作环境中保持领先! 目录 往期推荐 自动化命令--generate 命令语法 参数说明 命令使用前提 场景1:BSW代码生成 场景2:RTE代码生成 场景3&a…

对非洲33国免关税!非洲市场不容错过

2024年9月5日中非合作论坛峰会在北京隆重召开,会议后宣布对非洲33个国家实行0关税的优惠政策,并且在未来三年,推动中国企业对非投资不少于700亿元人民币。 自然而然的,中非友好关系必然会带动中国对非洲市场的出口,近…

云计算实训44——K8S及pod相关介绍

一、K8S基本概念 1、k8s是什么 K8S是Kubernetes的 缩写,由于k 和 s 之间有⼋个字符,所以因此得名。 Kubernetes 是⼀个可移植的、可扩展的开源平台,⽤于管理容器化 的⼯作负载和服务,可促进声 明式配置和⾃动化。 2、k8s的功能…