d2l动手学深度学习】 Lesson 13 Dropout层 老板随机丢掉一些做项目的程序员‍,项目的效果会更好!(bushi)

news2024/11/16 15:25:59

文章目录

  • 1. 什么是Dropout
    • 老板随机丢掉一些做项目的程序员🧑‍💻,项目的效果会更好!
  • 2. 代码实现(不用torch)
  • 3. 代码实现(使用torch)
  • 3. 调节实验
    • 3.1 老师上课所设置的dropout1, dropout2 = 0.2, 0.5
      • 动手实现版
      • 简介torch版
    • 3.2 dropout1, dropout2 = 0, 0
    • 3.3 dropout1, dropout2 = 1, 1(全部扔掉?🤔)
    • 3.4 dropout1, dropout2 = 0.9, 0.9(几乎全部扔掉?)
    • 3.5 dropout1, dropout2 = 0.6, 0.8
    • 3.5 dropout1, dropout2 = 0.8, 0.6
  • 4. 整理一些有趣的Q&A 🤔
    • dropout 随机为0?
    • 理解dropout
    • 可重复性问题
  • 写在最后


1. 什么是Dropout

老板随机丢掉一些做项目的程序员🧑‍💻,项目的效果会更好!

dropout图片来自讲课PPT

Dropout,顾名思义,就是丢弃,是在多层感知机(MLP)中经常用到的一种用于防止过拟合的一种训练技巧,如上图所示,就是在中间层将一些神经元变为0,然后输出

需要注意的是:在实作中并不会像上面这张图片这样直接删除神经元,而是通过生成一个含有0的Mask去和原来输入的结果作点积(维持输入形状不改变,被去掉的神经元对应的位置乘以0)

李沐老师还提到,Dropout也可以看成是另一种形式的正则化方法(Regulation),也可以用来防止模型过拟合


2. 代码实现(不用torch)

def dropout_layer(X, dropout):
    # input, dropout rate
    assert 0 <= dropout <= 1
    if dropout == 1:
        return torch.zeros_like(X) # 等于1变全0了 全丢了
    if dropout == 0:
        return X
    # 比较得到布尔矩阵
    mask = (torch.randn(X.shape) > dropout).float()
    # 做矩阵乘法比使用数组索引index的运算速度快 X[mask] = 0
    return mask * X /(1.0 - dropout)

3. 代码实现(使用torch)

net = nn.Sequential(nn.Flatten(),
                    nn.Linear(784, 256),
                    nn.ReLU(),
                    nn.Dropout(dropout1),
                    nn.Linear(256, 256),
                    nn.ReLU(),
                    nn.Dropout(dropout2),
                    nn.Linear(256, 10))
def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)
        
net.apply(init_weights)

模型结构

Sequential(
  (0): Flatten(start_dim=1, end_dim=-1)
  (1): Linear(in_features=784, out_features=256, bias=True)
  (2): ReLU()
  (3): Dropout(p=0.0, inplace=False)
  (4): Linear(in_features=256, out_features=256, bias=True)
  (5): ReLU()
  (6): Dropout(p=0.0, inplace=False)
  (7): Linear(in_features=256, out_features=10, bias=True)
)

3. 调节实验

老师的代码主要用到了两个Dropout层,因此在模型中对应两个dropout rate,👇下面我们主要对者两个参数进行调节并观察对应的实验结果。dropout rate 后面将简写为(DR)

⚠️注意!!:在运行代码时候,记得修改loss = nn.CrossEntropyLoss(reduction='none'),里面reduction,不然显示不出来loss!

3.1 老师上课所设置的dropout1, dropout2 = 0.2, 0.5

动手实现版

动手实现版

简介torch版

在这里插入图片描述

两种版本的实现的训练效果都差不多(这里假设没有其他优化计算的因素影响模型最后的训练效果),接下来我们就用简洁Torch版本来讨论。

3.2 dropout1, dropout2 = 0, 0

在这里插入图片描述

不用dropout的模型准确率反而上升了?
弹幕里面说是因为有可能模型是过拟合的,因为这里的Loss变的非常小
李沐老师说,现在256是一个很大的模型(对于我们这个小的MNIST数据集来说

3.3 dropout1, dropout2 = 1, 1(全部扔掉?🤔)

报错🙅,这个是运行不了的
在这里插入图片描述

3.4 dropout1, dropout2 = 0.9, 0.9(几乎全部扔掉?)

在这里插入图片描述
这个也会出现很问题

3.5 dropout1, dropout2 = 0.6, 0.8

在这里插入图片描述

3.5 dropout1, dropout2 = 0.8, 0.6

在这里插入图片描述

这里推测之前运行不了的原因也有可能是第一层的神经网络扔得太多了


综上所述,不管怎么调节dropout rate,还是比不过不用drop的策略,有可能模型还是不够大,应该出现overfitting的情况再使用dropout策略会好一点?


4. 整理一些有趣的Q&A 🤔

dropout 随机为0?

  • 在求梯度时,设置随机为0,在BackProp的时候对应的梯度也是0,所以为啥Hinton说Dropout更像是在训练的过程中,将一些小网络逐一拿出来(不同的子网络),将各个子网络完成训练以后再融和在一起

理解dropout

  • DR太小了,和太大了都不合适,太小没有作用,相反,太大就变成限制模型参数拟合的性能发挥 没有正确性🙆可言,一般就只有正确率
  • 在作模型推理的时候,不需要使用Drop,因为不会再改变模型参数,如果用也可以,就会引入一些随机性,需要多算几遍,预测会丢掉东西,第一次是猫🐱,第二次可能就是狗🐶
  • Drop在MLP全连接层用的比较多,但是weight decay则全部都在用,包括CNN,RNN这些

可重复性问题

  • 神经网络训练的可重复性确实是一个问题,不过可以通过把random seed设定
  • 但是李沐老师提到一个问题🙋,就是使用加速⏩CUDA中的Cudnn的话可能会导致计算结果不能🔁重复,这是因为并行计算的加法问题,100个数相加的先后顺序不同的话,得到的结果也会不同(精度不够),想重复的话需要固定住CuDNN
  • 随机性会使得整个神经网络的收敛域变的平滑
  • 每个batch丢进去之后,都要丢弃一次
  • 老师,dropout每次随机选几个子网络,最后做平均的做法是不是类似于随机森林多决策树做投票的这种思想?(是的)
  • 深度学习:我需要模型够强,但是我需要通过正则化来保证不要学偏

写在最后

各位看官,都看到这里了,麻烦动动手指头给博主来个点赞8,您的支持作者最大的创作动力哟!
才疏学浅,若有纰漏,恳请斧正
本文章仅用于各位作为学习交流之用,不作任何商业用途,若涉及版权问题请速与作者联系,望悉知

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1377708.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

探索 hasOwnProperty:处理对象属性的关键(下)

&#x1f90d; 前端开发工程师&#xff08;主业&#xff09;、技术博主&#xff08;副业&#xff09;、已过CET6 &#x1f368; 阿珊和她的猫_CSDN个人主页 &#x1f560; 牛客高级专题作者、在牛客打造高质量专栏《前端面试必备》 &#x1f35a; 蓝桥云课签约作者、已在蓝桥云…

问界又“翻车”了? 新能源电池“怕冷”成短板

文 | AUTO芯球 作者 | 李欣 2023年12月17日&#xff0c;蔚来创始人李斌亲自下场&#xff01;驾驶ET7从上海出发&#xff0c;经过超14小时的行驶后&#xff0c;达成一块电池行驶超过1000公里的成绩&#xff0c;这一直播引起外界的广泛关注。 这不禁让人与”懂车帝冬测“联想到…

Vue3 的基本开发+新特性

Vue3 1.Vue3 1. Vue2 选项式 API vs Vue3 组合式API <script> export default {data(){return {count:0}},methods:{addCount(){this.count}} } </script> <script setup> import { ref } from vue const count ref(0) const addCount ()> count.val…

文件操作(你真的会读写文件吗?)

文章目录 一、为什么使用文件&#xff1f;二、什么是文件&#xff1f;2.1 程序文件2.2 数据文件2.3 文件名 三、二进制文件和文本文件3.1 二进制文件3.2 文本文件 四、文件的打开和关闭4.1 流和标准流4.1.1 流4.1.2 标准流 4.2 文件指针4.3 fopen和fclose 五、文件的顺序读写5.…

代码随想录刷题第四十八天| 198.打家劫舍 ● 213.打家劫舍II ● 337.打家劫舍III

代码随想录刷题第四十八天 今天是打家劫舍三部曲&#xff0c;最后一题树形dp有点难&#xff0c;其他还好 打家劫舍 (LC 198) 题目思路&#xff1a; 代码实现&#xff1a; class Solution:def rob(self, nums: List[int]) -> int:dp [0 for _ in range(len(nums)1)]dp[1…

Open3D 截取感兴趣的点云部分

import time import open3d as o3d; import numpy as np; import matplotlib.pyplot as plt from scipy.signal import find_peaks#坐标 mesh_coord_frame o3d.geometry.TriangleMesh.create_coordinate_frame(size355, origin[0, 0, 0]) #mesh_coord_frame mesh_coord_frame…

机器学习_实战框架

文章目录 介绍机器学习的实战框架1.定义问题2.收集数据和预处理(1).收集数据(2).数据可视化(3).数据清洗(4).特征工程(5).构建特征集和标签集(6).拆分训练集、验证集和测试集。 3.选择算法并建立模型4.训练模型5.模型的评估和优化 介绍机器学习的实战框架 一个机器学习项目从开…

UVa1308/LA2572 Viva Confetti

题目链接 本题是2002年ICPC亚洲区域赛金沢(日本)赛区的H题 题意 我已经把n个圆盘依次放到了桌面上。现按照放置顺序依次给出各个圆盘的圆心位置和半径&#xff0c;问最后有多少圆盘可见&#xff1f;如下图所示。 分析 《训练指南》的题解&#xff1a; 题目说“保证在对输入数据…

87.乐理基础-记号篇-反复记号(一)反复、跳房子

内容参考于&#xff1a;三分钟音乐社 上一个内容&#xff1a;86.乐理基础-记号篇-速度记号-CSDN博客 首先是反复记号表总结图&#xff1a; 当前是写前两个记号&#xff0c;其余记号后面写&#xff1a;这些反复记号最主要的目的很简单&#xff0c;还是为了节约纸张&#xff0c…

使用Linux安装Mysql Community Server 8.0.35

一、下载Mysql 官网&#xff1a;https://www.mysql.com/ 第一步&#xff1a;进入Linux官网&#xff0c;点击下载 第二步&#xff1a;点击MySQL Community (GPL) Downloads 第三步&#xff1a;进入页面&#xff0c;选择 MySQL Community Server 第四步&#xff1a;根据自己服务…

SpringBoot集成RabbitMq,RabbitMq消费与生产,消费失败重发机制,发送签收确认机制

RabbitMq消费与生产&#xff0c;消费失败重发机制&#xff0c;发送确认机制&#xff0c;消息发送结果回执 1. RabbitMq集成spring bootRabbitMq集成依赖RabbitMq配置RabbitMq生产者&#xff0c;队列&#xff0c;交换通道配置&#xff0c;消费者示例 2. RabbitMq消息确认机制消息…

LangChain 72 reference改变结果 字符串评估器String Evaluation

LangChain系列文章 LangChain 60 深入理解LangChain 表达式语言23 multiple chains链透传参数 LangChain Expression Language (LCEL)LangChain 61 深入理解LangChain 表达式语言24 multiple chains链透传参数 LangChain Expression Language (LCEL)LangChain 62 深入理解Lang…

Nginx——基础配置

和大多数软件一样&#xff0c;Nginx也有自己的配置文件&#xff0c;但它又有很多与众不同的地方&#xff0c;本帖就来揭开Nginx基础配置的面纱。 1、Nginx指令和指令块 了解指令和指令块有助于大家了解配置的上下文&#xff0c;下面是一个配置模板示例&#xff1a; 在这个配…

Transformer详解【学习笔记】

文章目录 1、Transformer绪论2、Encoders和Decoder2.1 Encoders2.1.1 输入部分2.1.2 多头注意力机制2.1.3 残差2.1.4 LayNorm&#xff08;Layer Normalization&#xff09;2.1.5 前馈神经网路 2.2 Decoder2.2.1 多头注意力机制2.2.2 交互层 1、Transformer绪论 Transformer在做…

第11章 GUI Page495~496 步骤三十一:另存为别的文件

当前的TrySaveFile(bool hint_on_dirty true)有两个特征无法满足“另存”的需求&#xff1a; 一&#xff0c;TrySaveFile仅在数据为“新”的时候才提问用户输入文件名。而“另存”总是要求用户输入一个文件名&#xff0c;多以它总应该弹出一个文件选择对话框&#xff0c;这也…

从零到一的方法:学习视频剪辑与嵌套合并技巧

随着社交媒体和数字技术的快速发展&#xff0c;视频制作已是常见的工作。那么如何批量嵌套合并视频呢&#xff1f;下面一起来看云炫AI智剪如何批量合并的方法。 嵌套合并后的视频截图查看。 批量嵌套合并的操作&#xff1a; 操作1、在云炫AI智剪上选择“嵌套合并”功能&#…

PHP版学校教务管理系统源码带文字安装教程

PHP版学校教务管理系统源码带文字安装教程 运行环境 服务器宝塔面板 PHP 7.0 Mysql 5.5及以上版本 Linux Centos7以上 系统介绍&#xff1a; 后台权限控制&#xff1a;支持多个管理员&#xff0c;学生管理&#xff0c;学生成绩&#xff0c;教师管理&#xff0c;文章管理&#x…

CSS3简单运用过渡元素(transition)

CSS3过渡 概念&#xff1a;在CSS3中&#xff0c;我们可以使用transition属性将元素的某一个属性从“一个属性值”在指定的时间内平滑地过渡到“另一个属性值”&#xff0c;从而实现动画效果。 CSS3变形&#xff08;transform)呈现的仅仅是一个结果&#xff0c;而CSS过渡&…

方波 离散傅里叶级数 MATLAB

%方波 离散时间傅里叶变换 L 5; N 10; k [-N/2:1:N/2]; %占空比 基本周期 离散时间的参数 xn [ones(1,L),zeros(1,N-L)]; %生成方波序列 XK dfs(xn,N); magXK abs([XK(N/21:N),XK(1:N/21)]); subplot(2,2,3); stem(k,magXK); axis([-N/2,N/2,-0.5,5.5]); xlabel(k); y…

麒麟OS + DM8数据库(Graalvm for JDK17) 测试

1、添加依赖 implementation com.dameng:DmJdbcDriver18:8.1.3.62 implementation com.baomidou:mybatis-plus-boot-starter:3.5.4 2、application.yml 数据源配置 spring: datasource: driver-class-name: dm.jdbc.driver.DmDriver #com.mysql.cj.jdbc.Driver url: jdbc:d…