Dropout层的个人理解和具体使用

news2024/11/26 19:29:20

Dropout层的作用

        dropout 能够避免过拟合,我们往往会在全连接层这类参数比较多的层中使用dropout;在训练包含dropout层的神经网络中,每个批次的训练数据都是随机选择,实质是训练了多个子神经网络,因为在不同的子网络中随机忽略的权重的位置不同,最后在测试的过程中,将这些小的子网络组合起来,类似一种投票的机制来作预测,有点类似于集成学习的感觉。

  关于dropout,有nn.Dropoutnn.functional.dropout两种。推荐使用nn.Dropout,因为一般情况下只有训练train时才用dropout,在eval不需要dropout。使用nn.Dropout,在调用model.eval()后,模型的dropout层和批归一化(batchnorm)都关闭,但用nn.functional.dropout,在没有设置training模式下调用model.eval()后不会关闭dropout。
  这里关闭dropout等的目的是为了测试我们训练好的网络。在eval模式下,dropout层会让所有的激活单元都通过,而batchnorm层会停止计算和更新mean和var,直接使用在train训练阶段已经学出的mean和var值。同时我们在用模型做预测的时候也应该声明model.eval()。

注⚠️:为了进一步加速模型的测试,我们可以设置with torch.no_grad(),主要是用于停止autograd模块的工作,以起到加速和节省显存的作用,具体行为就是停止梯度gradient计算和储存,从而节省了GPU算力和显存,但是并不会影响dropout和batchnorm层的行为,这样我们可以使用更大的batch进行测试。

model.eval()下不启用 Batch Normalization 和 Dropout。
如果模型中有BN层(Batch Normalization)和Dropout,在测试时添加model.eval()。model.eval()是保证BN层能够用全部训练数据的均值和方差,即测试过程中要保证BN层的均值和方差不变。对于Dropout,model.eval()是利用到了所有网络连接,即不进行随机舍弃神经元。

下面一段话是关于BN层的解释:

在训练过程中,Dropout的实现是让神经元以超参数p的概率停止工作或者激活被置为0,未被置为0的进行缩放,缩放比例为1/(1-p)

Dropout正则化

  • 概念

    • 每次迭代过程中按照层,随机选择某些节点删除前向和后向连接

    • 注意:对隐藏层的某一个节点给删掉 (dropout的时候,神经网络的输入和输出节点个数没有发生变化)

  • 看成机器学习中的集成方法(ensemble technique)

    • 删除一些神经元, 减少模型复杂度, 也减少了过拟合

  • 被失活的神经元输出为0 其他神经元被放大 1/(1-rate)倍

  • 只在网络训练时有效, 模型预测时无效

  • 神经网络中独有的方法

  • 重中之重:随机置零的是神经元,也就是forward里的self.linear(x)里的x

注意:Dropout只针对隐藏层来说的!!!

代码实现:

单纯调用Dropout:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


def dropout_test():
    torch.random.manual_seed(0)
    inputs = torch.randn(2, 5)
    dropout = nn.Dropout(p=0.4)
    print("dropout之前=>>>inputs", inputs)
    out = dropout(inputs)
    print("dropout之后=>>>inputs", out)
    print("各个元素乘以 1 / (1-p)=>>>", inputs * 1 / (1 - 0.4))


if __name__ == '__main__':
    dropout_test()

运行结果:

上图展示了“在训练过程中,Dropout的实现是让神经元以超参数p的概率停止工作或者激活被置为0,未被置为0的进行缩放,缩放比例为1/(1-p)”这句话的体现!!!

nn.Module类型的例子:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class LinearModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(in_features=3, out_features=5)
        self.linear2 = nn.Linear(in_features=5, out_features=20)
        self.out = nn.Linear(in_features=20, out_features=2)
        self.dropout = nn.Dropout(p=0.2)

    def forward(self, x):
        x = self.linear1(x)
        print("Dropout前的权重=>>>", x)
        x = self.dropout(x)
        print("Dropout后的权重=>>>", x)
        x = torch.relu(x)
        x = self.linear2(x)
        x = self.dropout(x)
        x = torch.relu(x)
        out = self.out(x)
        return F.log_softmax(out, dim=1)


def train():
    learning_rate = 0.1
    inputs = torch.randn(5, 3, dtype=torch.float32)
    label = torch.tensor([1, 0, 1, 0, 0], dtype=torch.int64)

    model = LinearModule()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.NLLLoss()

    y_pred = model(inputs)
    loss = criterion(y_pred, label)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


if __name__ == '__main__':
    train()

运行结果:

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/587249.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Linux常用命令——gzip命令

在线Linux命令查询工具 gzip 用来压缩文件 补充说明 gzip命令用来压缩文件。gzip是个使用广泛的压缩程序,文件经它压缩过后,其名称后面会多处“.gz”扩展名。 gzip是在Linux系统中经常使用的一个对文件进行压缩和解压缩的命令,既方便又好…

Python过滤信息,如省位中包含广东、安徽、浙江这3个省份的话,就pass,怎么破?...

点击上方“Python爬虫与数据挖掘”,进行关注 回复“书籍”即可获赠Python从入门到进阶共10本电子书 今 日 鸡 汤 但令心似金钿坚,天上人间会相见。 大家好,我是皮皮。 一、前言 前几天遇到了一个小问题,在做资料的时候&#xff0c…

7位专家齐聚openGauss Developer Day 2023云和恩墨专题论坛,共论数据库自主创新改造与技术发展...

5月26日,云和恩墨在「openGauss Developer Day 2023」主论坛上大放异彩(←点此回顾主论坛精彩时刻),更是通过举办一场数据库技术创新与应用实践分论坛,力邀7位重量级嘉宾就数据库创新能力构建、行业应用实践和迁移替代…

SpringBoot配置文件的注入和读取

目录 1. 配置文件的作用 2. 两种配置文件的格式: 2.1 properties 基本语法: 2.1.1 写入 2.1.2 读取 执行原理 2.1.3 缺点分析 2.2 yml 基本语法: 2.2.1 写入(非对象) 2.2.3 配置对象 2.2.4 配置集合 多个配…

【Linux初阶】基础IO - 文件管理(深入理解文件描述符) | 重定向

文章目录 一、文件管理引入二、理解文件描述符三、文件描述符表四、文件描述符的分配规则五、重定向六、使用 dup2 系统调用实现重定向1.模拟实现 >&#xff08;输出&#xff09;2.模拟实现 >>&#xff08;追加&#xff09;3.模拟实现 <&#xff08;输入&#xff0…

【观察】浪潮信息:自研液环式真空CDU技术,将被动应对变为主动防御

毫无疑问&#xff0c;在“双碳”战略的大环境下&#xff0c;数据中心走向绿色低碳和可持续发展已成为“不可逆”的大趋势&#xff0c;特别是随着全国一体化大数据中心、新型数据中心等政策文件的出台、“东数西算”工程的正式启动&#xff0c;数据中心的建设规模和数量呈现出快…

老胡周刊QA微信机器人(基于ChatGPT)

背景 先做个介绍吧&#xff0c;老胡的信息周刊是我从2021-08-16创立的周刊&#xff0c;截止到目前(2023-05-29)将近两年时间&#xff0c;目前已经有92期周刊&#xff0c;中间基本没有断更过&#xff0c;一共发布资源统计如下&#xff1a; &#x1f3af; 项目 288&#x1f916; …

Ae:稳定运动

使用跟踪器 Tracker面板的稳定运动 Stabilize Motion功能&#xff0c;可通过手动添加和设置跟踪点来跟踪对象的运动&#xff0c;将获得的跟踪数据对图层本身进行反向变换&#xff0c;从而达到稳定画面的目的。 Ae菜单&#xff1a;窗口/跟踪器 Tracker 点击跟踪器面板上的“稳定…

长文教你如何正确使用ChatGPT提高学习效率!

最近 Chat GPT 很&#x1f525;&#xff0c;被大家评为无所不能的最强AI。据说&#xff0c;有百分之八十的留学生已经在用ChatGPT 来写作业了&#xff0c;因为ChatGPT真的是有问必答&#xff0c;光速回复&#xff0c;复制粘贴都没有它回答的快。 目录 Part.1 什么是ChatGPT&a…

驱动开发:内核读写内存浮点数

如前所述&#xff0c;在前几章内容中笔者简单介绍了内存读写的基本实现方式&#xff0c;这其中包括了CR3切换读写&#xff0c;MDL映射读写&#xff0c;内存拷贝读写&#xff0c;本章将在如前所述的读写函数进一步封装&#xff0c;并以此来实现驱动读写内存浮点数的目的。内存浮…

centos安装KVM

文章目录 一、centos安装KVM步骤 1. 检查硬件支持 2. 安装 KVM 相关软件包 3. 启动 libvirtd 服务 4. 设置 libvirtd 服务自启动 5. 验证 KVM 安装 二、出现问题的解决方法 1. 检查网络连接 2. 检查 DNS 解析 3. 检查软件源设置 4. 禁用 IPv6 前言 本篇主要介绍cen…

教育最大的失败,是普通家庭富养孩子

作者| Mr.K 编辑| Emma 来源| 技术领导力(ID&#xff1a;jishulingdaoli) 著名教育家马卡连柯曾说&#xff1a;“一切都给孩子&#xff0c;牺牲一切&#xff0c;甚至牺牲自己的幸福&#xff0c;这是父母给孩子最可怕的礼物。”前些天刷到一个挺扎心的视频&#xff0c;不知道算…

商业智能 (BI) 对企业中每个员工的 5 大好处

本文由葡萄城技术团队于博客园原创并首发。转载请注明出处&#xff1a;葡萄城官网&#xff0c;葡萄城为开发者提供专业的开发工具、解决方案和服务&#xff0c;赋能开发者。 众所周知&#xff0c;商业智能 (BI) 是探索企业数据价值的强大工具&#xff0c;能够帮助企业做出明智…

全网最全2W字-基于Java+SpringBoot+Vue+Element实现小区生活保障系统(建议收藏)

博主介绍&#xff1a;✌全网粉丝30W,CSDN特邀作者、博客专家、新星计划导师、java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精彩专栏推…

破局35岁危机:技术人如何做好职业规划?

见字如面&#xff0c;我是军哥。 最近有一位读者工作 8 年了&#xff0c;后端做了 3 年&#xff0c;算法做了 5 年&#xff0c;换了 6/7 家公司&#xff0c;基本上每一家公司只干 1 年左右&#xff0c;换了 N 个行业&#xff0c;现在工作出现瓶颈&#xff0c;也不知道未来的路怎…

十、Git代码仓库

一、Git概述 Git是一个开源的分布式版本控制系统&#xff0c;可以有效、高速地处理从很小到非常大的项目版本管理。 也是为了帮助管理Linux内核开发而开发的一个开放源码的版本控制软件。 二、Git常用命令 查看git配置 git config -l设置用户名和邮箱 git config --global u…

带电更换柱上变压器(综合不停电作业法)

一、现场复勘 1.核对工作线路双重名称、杆号及设备双重名称 2.检查杆身质量 3.检查线路装置是否符合带电作业要求 4.检查待更换变压器容量 满足旁路作业要求 5.检查气象条件 作业前进行湿度和风速的测量&#xff0c;风力大于5级或湿度大于80%时&#xff0c;不宜带电作业&…

开源“模仿”ChatGPT,居然效果行?UC伯克利论文,劝退,还是前进?

原创&#xff1a;谭婧ChatGPT 从“古”至今&#xff0c;AI的世界&#xff0c;是一个开源引领发展的世界。 虽然Stable Diffusion作为开源的图像生成模型&#xff0c;将图像生成提到了全新境界&#xff0c;但是ChatGPT的出现&#xff0c;似乎动摇了一些人的信念。 因为ChatGPT是…

16. Vue-element-template记住密码

Vue-element-template 记住密码 1. 在登录页面添加记住密码按钮 新增参数 rememberMe # resources/src/views/login/index.vueloginForm: {username: admin,password: 123456,rememberMe: false},添加复选框 # resources/src/views/login/index.vue<div style"margin-…

一、STM32开发环境的搭建(Keil+STM32CubeMX)

1、STM32开发环境所需的东西 (1)KeilMDK安装包。 (2)STM32CubeMX。 (3)Keil软件对应的单片机pack包。 (4)STM32Cube MCU包。 2、Keil简介及安装 略 3、CubeMX简介及安装 3.1、CubeMX简介 (1)STM32CubeMX是一种图形工具&#xff0c;通过分步过程可以非常轻松地配置STM3…