动手学深度学习—卷积神经网络LeNet(代码详解)

news2025/1/23 15:06:05

1. LeNet

LeNet由两个部分组成:

  • 卷积编码器:由两个卷积层组成;
  • 全连接层密集块:由三个全连接层组成。

在这里插入图片描述

  1. 每个卷积块中的基本单元是一个卷积层、一个sigmoid激活函数和平均汇聚层;
  2. 每个卷积层使用5×5卷积核和一个sigmoid激活函数;
  3. 这些层将输入映射到多个二维特征输出,通常同时增加通道的数量;
  4. 每个4×4池操作(步幅2)通过空间下采样将维数减少4倍。
import torch
from torch import nn
from d2l import torch as d2l

# 定义模型net
net = nn.Sequential(
    nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Flatten(),
    nn.Linear(16 * 5 * 5, 120), nn.Sigmoid(),
    nn.Linear(120, 84), nn.Sigmoid(),
    nn.Linear(84, 10))

该模型去掉了最后一层的高斯激活,下面将一个大小为28×28的单通道(黑白)图像通过LeNet,打印每一层输出的形状。

# 观察各层的输入输出通道数,宽度和高度
X = torch.rand(size=(1, 1, 28, 28), dtype=torch.float32)
for layer in net:
    X = layer(X)
    print(layer.__class__.__name__,'output shape:\t', X.shape)

在这里插入图片描述

  1. 第一个卷积层使用2个像素的填充,来补偿5×5卷积核导致的特征减少;
  2. 第二个卷积层没有填充,因此高度和宽度都减少了4个像素;
  3. 随着层叠的上升,通道的数量从输入时的1个,增加到第一个卷积层之后的6个,再到第二个卷积层之后的16个;
  4. 每个汇聚层的高度和宽度都减半;
  5. 每个全连接层减少维数,最终输出一个维数与结果分类数相匹配的输出。

2. 模型训练

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)
"""
    定义精度评估函数:
    1、将数据集复制到显存中
    2、通过调用accuracy计算数据集的精度
"""
def evaluate_accuracy_gpu(net, data_iter, device=None): #@save
    # 判断net是否属于torch.nn.Module类
    if isinstance(net, nn.Module):
        net.eval()
        
        # 如果不在参数选定的设备,将其传输到设备中
        if not device:
            device = next(iter(net.parameters())).device
    
    # Accumulator是累加器,定义两个变量:正确预测的数量,总预测的数量。
    metric = d2l.Accumulator(2)
    with torch.no_grad():
        for X, y in data_iter:
            # 将X, y复制到设备中
            if isinstance(X, list):
                # BERT微调所需的(之后将介绍)
                X = [x.to(device) for x in X]
            else:
                X = X.to(device)
            y = y.to(device)
            
            # 计算正确预测的数量,总预测的数量,并存储到metric中
            metric.add(d2l.accuracy(net(X), y), y.numel())
    return metric[0] / metric[1]
"""
    定义GPU训练函数:
    1、为了使用gpu,首先需要将每一小批量数据移动到指定的设备(例如GPU)上;
    2、使用Xavier随机初始化模型参数;
    3、使用交叉熵损失函数和小批量随机梯度下降。
"""
#@save
def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):
    """用GPU训练模型(在第六章定义)"""
    # 定义初始化参数,对线性层和卷积层生效
    def init_weights(m):
        if type(m) == nn.Linear or type(m) == nn.Conv2d:
            nn.init.xavier_uniform_(m.weight)
    net.apply(init_weights)
    
    # 在设备device上进行训练
    print('training on', device)
    net.to(device)
    
    # 优化器:随机梯度下降
    optimizer = torch.optim.SGD(net.parameters(), lr=lr)
    
    # 损失函数:交叉熵损失函数
    loss = nn.CrossEntropyLoss()
    
    # Animator为绘图函数
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],
                            legend=['train loss', 'train acc', 'test acc'])
    
    # 调用Timer函数统计时间
    timer, num_batches = d2l.Timer(), len(train_iter)
    
    for epoch in range(num_epochs):
        
        # Accumulator(3)定义3个变量:损失值,正确预测的数量,总预测的数量
        metric = d2l.Accumulator(3)
        net.train()
        
        # enumerate() 函数用于将一个可遍历的数据对象
        for i, (X, y) in enumerate(train_iter):
            timer.start() # 进行计时
            optimizer.zero_grad() # 梯度清零
            X, y = X.to(device), y.to(device) # 将特征和标签转移到device
            y_hat = net(X)
            l = loss(y_hat, y) # 交叉熵损失
            l.backward() # 进行梯度传递返回
            optimizer.step()
            with torch.no_grad():
                # 统计损失、预测正确数和样本数
                metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])
            timer.stop() # 计时结束
            train_l = metric[0] / metric[2] # 计算损失
            train_acc = metric[1] / metric[2] # 计算精度
            
            # 进行绘图
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,
                             (train_l, train_acc, None))
                
        # 测试精度
        test_acc = evaluate_accuracy_gpu(net, test_iter) 
        animator.add(epoch + 1, (None, None, test_acc))
        
    # 输出损失值、训练精度、测试精度
    print(f'loss {train_l:.3f}, train acc {train_acc:.3f},'
          f'test acc {test_acc:.3f}')
    
    # 设备的计算能力
    print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec'
          f'on {str(device)}')

在这里插入图片描述

lr, num_epochs = 0.9, 10
train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())

在这里插入图片描述

3. 小结

  1. 卷积神经网络(CNN)是一类使用卷积层的网络;
  2. 卷积神经网络中,可以组合使用卷积层、非线性激活函数和汇聚层;
  3. 为了构造高性能的卷积神经网络,通常对卷积层进行排列,逐渐降低其表示的空间分辨率,同时增加通道数;
  4. 在传统的卷积神经网络中,卷积块编码得到的表征在输出之前需由一个或多个全连接层进行处理。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/883257.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于libevent的tcp服务器

libevent使用教程_evutil_make_socket_nonblocking_易方达蓝筹的博客-CSDN博客 一、准备 centos7下安装libevent库 yum install libevent yum install -y libevent-devel 二、代码 server.cpp /** You need libevent2 to compile this piece of code Please see: http://li…

分类预测 | MATLAB实现MTBO-CNN多输入分类预测

分类预测 | MATLAB实现MTBO-CNN多输入分类预测 目录 分类预测 | MATLAB实现MTBO-CNN多输入分类预测预测效果基本介绍程序设计参考资料 预测效果 基本介绍 1.MATLAB实现MTBO-CNN多输入分类预测 2.代码说明:基于登山队优化算法(MTBO)、卷积神经…

android 12系统加上TTS引擎

系统层修改&#xff1a; 1.frameworks/base/packages/SettingsProvider/res/values/defaults.xml <string name"def_tts"></string> 2.frameworks/base/packages/SettingsProvider/src/com/android/providers/settings/DatabaseHelper.java loadString…

​五金件机器视觉定位​并获取外观轮廓软硬件视觉方案

【检测目的】 五金件机器视觉定位&#xff0c;视觉检测五金件轮廓并矫正五金件位置进行涂油 【客户要求】 FOV:540*400mm 【拍摄与处理效图一】 【拍摄与处理效图二】 【实验原理及说明】 【方案评估】 根据目前的图像和处理结果来看&#xff0c;可以检测出产品轮廓并进行位置…

Unity小项目__小球吃零食

// Player脚本文件源代码 public class Player : MonoBehaviour {public Rigidbody rd; // 定义了一个刚体组件public int score 0; // 定义了一个计分器public Text scoreText; // 定义了一个文本组件public GameObject winText; // 定义了一个游戏物体用于检验游戏结束// S…

征战2023跨境旺季,独立站如何实现新增长?

2023年出海赛道万象更新&#xff0c;行业重回正轨&#xff0c;跨境行业在经历过多轮洗牌过后&#xff0c;中国企业在全球化产业链中表现出了更强大的增长韧性。而随着跨境模式与消费需求的多样化与精细化。单一渠道的出海布局已经不能满足企业实现品牌出海的转型需求。 DTC独立…

前端工具的选择

目录 前端常见开发者工具 浏览器 开发者工具 VScode开发者工具快捷键 前端常见开发者工具 浏览器 浏览器是我们最重要的合作伙伴 关于浏览器的选择&#xff0c;我目前主要用主要是谷歌浏览器&#xff0c;我个人觉得谷歌浏览器使用起来比较方便、简洁&#xff0c;没有太多…

【TypeScript】tsc -v 报错 —— 在此系统上禁止运行脚本

在 VS Code 终端中执行 tsc -v &#xff0c;报错 —— 在此系统上禁止运行脚本 然后 windows x &#xff0c;打开终端管理员&#xff0c;出现同样的问题 解决方法&#xff1a; 终端&#xff08;管理员&#xff09;执行以下命令&#xff1a; 出现 RemoteSigned 则代表更改成功…

Flask-SQLAlchemy

认识Flask-SQLAlchemy Flask-SQLAlchemy 是一个为 Flask 应用增加 SQLAlchemy 支持的扩展。它致力于简化在 Flask 中 SQLAlchemy 的使用。SQLAlchemy 是目前python中最强大的 ORM框架, 功能全面, 使用简单。 ORM优缺点 优点 有语法提示, 省去自己拼写SQL&#xff0c;保证SQL…

誉天HCIP-Datacom课程简介

HCIP-Datacom课程介绍&#xff1a;HCIP-Datacom分为一个核心技术方向&#xff1a;HCIP-Datacom-Core Technology H12-821 &#xff08;核心技术&#xff09;六个可选子方向&#xff1a;HCIP-Datacom-Advanced Routing & Switching Technology H12-831 &#xff08;高级路…

ubuntu设置共享文件夹成功后却不显示找不到(已解决)

1.首先输下面命令查看是否真的设置成功共享文件夹 vmware-hgfsclient如果确实已经设置过共享文件夹将输出window下共享文件夹名字 2.确认自己已设置共享文件夹后输入下面的命令 //如果之前没有命令包则先执行sudo apt-get install open-vm-tools sudo vmhgfs-fuse .host:/ /mn…

Ubuntu18.04.4裸机配置

下载虚拟机Ubuntu18.04.4 链接&#xff1a;https://pan.baidu.com/s/1jyucyUSXa9-Fw9ctuU87hA 提取码&#xff1a;o42a –来自百度网盘超级会员V5的分享 VMware选择镜像安装 设置你的用户名&#xff0c;就像windows上登录用户一样简单 下一步……下一步……如此简单 下载…

《Effects of Graph Convolutions in Multi-layer Networks》阅读笔记

一.文章概述 本文研究了在XOR-CSBM数据模型的多层网络的第一层以上时&#xff0c;图卷积能力的基本极限&#xff0c;并为它们在数据中信号的不同状态下的性能提供了理论保证。在合成数据和真实世界数据上的实验表明a.卷积的数量是决定网络性能的一个更重要的因素&#xff0c;而…

TiDB数据库从入门到精通系列之一:TiDB数据库的软硬件环境要求和系统配置检查

TiDB数据库从入门到精通系列之一&#xff1a;TiDB数据库的软硬件环境要求和系统配置检查 一、软件和硬件配置要求1.操作系统及平台要求2.服务器建议配置3.网络要求4.磁盘空间要求 二、TiDB 环境与系统配置检查1.在 TiKV 部署目标机器上添加数据盘 EXT4 文件系统挂载参数2.设置 …

ECOLOGY9实现正文文档按发布范围授权查看

需求&#xff1a;E9流程需要流程归档后&#xff0c;正文文档按发布范围授权查看&#xff0c;不需要是流程参与者。 解决&#xff1a;表单中定义发布范围是人力资源条件 在流程基础设置-功能设置中设置启用按人力资源 条件字段赋权。 实现的效果。

司徒理财:8.15黄金美盘多空最新操作建议

黄金一直没能跌破1902的支撑&#xff0c;司徒理财依旧维持低多看涨的思路不变&#xff0c;早盘1905多单继续持有中&#xff0c;静待美盘拉升&#xff01;黄金现在的下跌力度已经衰竭&#xff0c;并且日线上已经跌至200日均线的支撑位置&#xff0c;大周期的均线支撑&#xff0c…

WSL2 ubuntu子系统OpenCV调用本机摄像头的RTSP视频流做开发测试

文章目录 前言一、Ubuntu安装opencv库二、启动 Windows 本机的 RTSP 视频流下载解压 EasyDarwin查看本机摄像头设备开始推流 三、在ubuntu 终端编写代码创建目录及文件创建CMakeLists.txt文件启动 cmake 配置并构建 四、结果展示启动图形界面在图形界面打开终端找到 rtsp_demo运…

阿里云与中国中医科学院合作,推动中医药行业数字化和智能化发展

据相关媒体消息&#xff0c;阿里云与中国中医科学院的合作旨在推动中医药行业的数字化和智能化发展。随着互联网的进步和相关政策的支持&#xff0c;中医药产业受到了国家的高度关注。这次合作将以“互联网 中医药”为载体&#xff0c;致力于推进中医药文化的传承和创新发展。…

PDB Database - 高质量 RCSB PDB 蛋白质结构筛选与过滤

欢迎关注我的CSDN&#xff1a;https://spike.blog.csdn.net/ 本文地址&#xff1a;https://spike.blog.csdn.net/article/details/132307119 Protein Data Bank (PDB) 是一个收集和存储三维结构数据的公共数据库&#xff0c;主要包括蛋白质和核酸分子。PDB 由美国、欧洲和日本三…

机器学习深度学习——机器翻译(序列生成策略)

&#x1f468;‍&#x1f393;作者简介&#xff1a;一位即将上大四&#xff0c;正专攻机器学习的保研er &#x1f30c;上期文章&#xff1a;机器学习&&深度学习——seq2seq实现机器翻译&#xff08;详细实现与原理推导&#xff09; &#x1f4da;订阅专栏&#xff1a;机…