深度学习中的 Dropout:原理、公式与实现解析

news2024/11/7 14:36:53

8. dropout

深度学习中的 Dropout:原理、公式与实现解析

在神经网络训练中,模型往往倾向于“记住”训练数据的细节甚至噪声,导致模型在新数据上的表现不佳,即过拟合。为了解决这一问题,Dropout 应运而生。通过在训练过程中随机丢弃一部分神经元,Dropout 能减少模型对特定神经元的依赖,从而提升泛化能力,今天我们将深入讲解 Dropout 的原理,并用代码实现它!


为什么需要 Dropout?

在没有正则化的情况下,神经网络可能会过于依赖于某些特定的神经元,这种现象容易导致过拟合。Dropout 通过随机丢弃神经元,避免模型过度依赖某些特征,使得模型在新数据上表现更好。


Dropout 的工作原理

1. Dropout 的训练过程

假设我们有一个输入向量 x = [ x 1 , x 2 , … , x n ] x = [x_1, x_2, \dots, x_n] x=[x1,x2,,xn]Dropout 在训练时会遵循以下步骤:

  1. 设置丢弃概率 p p p :通常在 0.1 到 0.5 之间,表示每个神经元被丢弃的概率。
  2. 生成随机掩码 m m m
    • 对每个神经元生成一个随机值。
    • 如果随机值小于 p p p ,该神经元输出置为 0(即丢弃)。
    • 如果随机值大于等于 p p p ,该神经元输出保持不变。
  3. 应用掩码:将掩码与输入相乘,丢弃部分神经元输出。

在测试时,我们不再随机丢弃神经元,而是将每个神经元的输出缩小 1 − p 1 - p 1p 倍,以保持与训练时相同的输出期望值。


Dropout 的数学公式

在训练时,Dropout 可以用以下公式表示:

output = x ⋅ m \text{output} = x \cdot m output=xm

其中 m m m 是随机掩码,0 表示丢弃,1 表示保留。训练时,为了保持输出一致性,我们会将结果除以 1 − p 1 - p 1p

output = x ⋅ m 1 − p \text{output} = \frac{x \cdot m}{1 - p} output=1pxm

在测试时,我们不再随机丢弃,而是将每个神经元的输出乘以 1 − p 1 - p 1p

output = x ⋅ ( 1 − p ) \text{output} = x \cdot (1 - p) output=x(1p)

这样可以确保训练和测试时的输出分布一致。


自己实现一个 Dropout 类

为了帮助大家理解 Dropout 的实现原理,我们可以用 Python 和 PyTorch 实现一个简单的 Dropout 类。

import torch
import torch.nn as nn

class CustomDropout(nn.Module):
    def __init__(self, p=0.5):
        super(CustomDropout, self).__init__()
        self.p = p  # 丢弃概率

    def forward(self, x):
        if self.training:
            # 生成与 x 形状相同的随机掩码
            mask = (torch.rand_like(x) > self.p).float()
            return x * mask / (1 - self.p)
        else:
            # 推理时,直接缩放输出
            return x * (1 - self.p)


代码解析

  • 初始化:我们定义了 p 表示丢弃的概率。p 越大,丢弃的神经元越多。
  • 前向传播
    • 在训练模式下:生成一个与输入张量形状相同的随机掩码,对每个神经元随机保留或丢弃。
    • 在测试模式下:不再随机丢弃,而是将输出乘以 1 − p 1 - p 1p ,确保输出分布一致。

测试代码

我们可以使用以下代码测试自定义 Dropout 的效果。

# 输入张量 x
x = torch.ones(5, 5)  # 一个简单的 5x5 全 1 张量

# 实例化自定义 Dropout
dropout = CustomDropout(p=0.5)

# 训练模式
dropout.train()
output_train = dropout(x)
print("训练模式下的输出:\\n", output_train)

# 推理模式
dropout.eval()
output_eval = dropout(x)
print("推理模式下的输出:\\n", output_eval)

解释测试结果

  • 训练模式:输出中会有一部分元素被随机置为 0,其余的值会放大(除以 1 − p 1 - p 1p )。
  • 推理模式:所有元素值会被缩小到 1 − p 1 - p 1p 倍,以确保训练和推理阶段输出分布一致。

为什么训练和测试阶段需要缩放?

在训练时,Dropout 随机丢弃一部分神经元,使得实际参与计算的神经元变少。这样训练时的输出总量会降低,因此我们需要对保留下来的神经元进行缩放(除以 1 − p 1 - p 1p )。在测试时,我们则对输出进行整体缩放(乘以 1 − p 1 - p 1p ),以确保训练和测试阶段的输出期望值一致,从而保证模型在不同阶段表现一致。


总结

  • Dropout 是一种防止过拟合的正则化方法,通过随机丢弃神经元来提升模型的泛化能力。
  • 在训练时,随机丢弃神经元并缩放剩余神经元的输出。
  • 在推理时,直接缩放整个输出,以保持训练和推理的分布一致。

希望这篇文章能帮助你理解 Dropout 的工作原理和实现过程。如果有任何疑问,欢迎留言讨论!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2235078.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

stable diffusion 大模型

本节内容,给大家带来的是stable diffusion的基础模型课程。基础模型,我们有时候也称之为大模型。在之前的课程中,我们已经多次探讨过大模型,并且也见识过一些大模型绘制图片的独特风格,相信大家对stable diffusion大模…

ChatPaper.ai:研究生文献阅读的AI助手利器

为什么选择 ChatPaper.ai? 作为研究生,我们每天都面临着大量文献阅读的挑战。一篇关键论文往往需要反复阅读数小时,还要做笔记、提取要点、理解创新点,这个过程既耗时又费力。ChatPaper.ai(ChatPaper.ai - Chat with …

python-docx -- 读取word图片

文章目录 概念介绍形状对象读取图片自定义图形 概念介绍 从概念上来讲,word文档分为两层,一个文本层,一个绘画层; 文本层,从上到下,从左到右,流式排版,本页填满则开启新页面&#…

Python邮差:如何用代码精确投递商品快递费用的密信

目录 一、准备工作 二、编写API请求脚本 三、解析与处理快递费用数据 四、案例应用:模拟电商平台的快递费用计算 五、自动化邮件通知 六、总结 在电子商务的广阔天地里,精确计算并快速传递商品快递费用是一项至关重要的任务。作为Python邮差&#…

swoole扩展安装--入门篇

对于php来说,swoole是个强大的补充扩展。这是我第3次写swoole扩展安装,这次基于opencloudos8系统,php使用8.2。 安装swoole扩展首先想到的是用宝塔来安装,毕竟安装方便,还能统一管理。虽然获得swoole版本不是最新的&am…

Linux信号_信号的保存

我们知道向进程发送信号,进程并不是立即处理,而是等合适的时机进行处理。那么就需要保存信号。在信号的产生中说过信号保存在进程PCB里面的信号位图里,那信号位图到底是什么? 一.信号保存 我们先补充一些概念 1.阻塞 忽略概念 实…

如何使用示波器测量信号强度

示波器是一种用于观察和分析电信号的电子测试仪器。它可以显示信号的波形、幅度、频率和其他特性,是工程师和技术人员进行电路设计、调试和故障排除的重要工具。本文将详细介绍如何使用示波器测量信号强度。 一、认识示波器的基本组成部分 显示屏:用于显…

Axure设计之三级联动选择器教程(中继器)

使用Axure设计三级联动选择器(如省市区选择器)时,可以利用中继器的数据存储和动态交互功能来实现。下面介绍中继器三级联动选择器设计的教程: 一、效果展示: 1、在三级联动选择器中,首先选择省份&#xff…

K8S篇(基本介绍)

目录 一、什么是Kubernetes? 二、Kubernetes管理员认证(CKA) 1. 简介 2. 考试难易程度 3. 考试时长 4. 多少分及格 5. 考试费用 三、Kubernetes整体架构 Master Nodes 四、Kubernetes架构及和核心组件 五、Kubernetes各个组件及功…

卖模版还能赚到钱吗?

说到赚钱,我想大部分人都会感兴趣。但如果告诉大家现阶段卖模板也能赚钱,可能还是有人不信。我要说说我的观察了。 本文可在公众号「德育处主任」免费阅读 我是一只临期程序猿,我最早接触到“模板能卖钱”这个概念是在模板王里。模板王平台上…

超萌!HTMLCSS:打造趣味动画卡通 dog

这段HTML与CSS代码实现了一个超萌的动画卡通dog。 HTML <div class"dog"><div class"dog-body"><div class"dog-tail"><div class"dog-tail"><div class"dog-tail"><div class"do…

Elasticsearch Interval 查询:为什么它们是真正的位置查询,以及如何从 Span 转换

作者&#xff1a;来自 Elastic Mayya Sharipova 解释 span 查询如何成为真正的位置查询以及如何从 span 查询过渡到它们。 长期以来&#xff0c;Span 查询一直是有序和邻近搜索的工具。这些查询对于特定领域&#xff08;例如法律或专利搜索&#xff09;尤其有用。但相对较新的 …

【YOLOv11[基础]】实例分割Seg | 导出ONNX模型 | ONN模型推理以及检测结果可视化 | python

本文将导出YOLO-Seg.pt模型对应的ONNX模型,并且使用ONNX模型推理以及结果的可视化。话不多说,先看看效果图吧!!! 目录 一 导出ONNX模型 二 推理及检测结果可视化 1 代码 2 效果图

手搓AI大模型应用获25万用户,果断辞职创业,结果收入不如摆摊

我开发的 AI 应用有 25 万用户&#xff0c;我感觉要起飞了&#xff0c;于是辞掉工作&#xff0c;准备大干一番。 结果没想到开局即巅峰&#xff0c;突然就完蛋了。 这几天&#xff0c;一个悲催的程序员创业故事在社交网络上流传&#xff0c;引发了人们的深思。 故事的主人公&…

品质生活新选择:看三星AI神黑钻衣物护理机,如何为用户打造精致日常

屠格涅夫曾说&#xff0c;一个人应当好好地安排生活&#xff0c;要使每一刻的时光都有意义。这不仅是对个人生活的深刻洞察&#xff0c;也是对生活品质的不懈追求。实际上&#xff0c;在追求品质生活的道路上&#xff0c;无关乎年龄和阶层&#xff0c;其核心精髓往往潜藏于那些…

ios打包文件上传App Store windows工具

在苹果开发者中心上架IOS APP的时候&#xff0c;在苹果开发者中心不能直接上传打包文件&#xff0c;需要下载mac的xcode这些工具进行上传&#xff0c;但这些工具无法安装在windows或linux电脑上。 这里&#xff0c;我们可以不用xcode这些工具来上传&#xff0c;可以用国内的香…

Nginx(编译)+Lua脚本+Redis 实现自动封禁访问频率过高IP

1.安装lua 1.1安装LuaJIT yum install readline-devel mkdir -p lua-file cd lua-file/ wget http://luajit.org/download/LuaJIT-2.0.5.tar.gz tar -zxvf LuaJIT-2.0.5.tar.gz cd LuaJIT-2.0.5 make && make install PREFIX/usr/local/luajit 1.2配置LuaJIT环境变量…

OA项目 python + vue3

准备工作 创建django项目 在setting.py进行数据库的配置&#xff1a; DATABASES {default: {ENGINE: django.db.backends.mysql,NAME: , #数据库名字USER: , #连接的数据库的用户名PASSWORD: ,HOST: 127.0.0.1,PORT: 3306,} }安装app&#xff1a; rest_framwork: 关闭csrf…

内网渗透-信息收集篇

通过webshell或其他方式拿下一台机器&#xff0c;并且存在内网环境&#xff0c;这个时候就在准备进行内网渗透&#xff0c;而在内网渗透之前需要对本地机器进行信息收集&#xff0c;才能够更好的进行内网渗透。 目录 Windows本地基础信息收集 权限查看 判断域存在 查看防火…

斯坦福团队研发:手机运行的超GPT-4大模型一夜爆红,下载量突破2000次

在大模型落地应用的过程中&#xff0c;端侧 AI 是非常重要的一个方向。 近日&#xff0c;斯坦福大学研究人员推出的 Octopus v2 火了&#xff0c;受到了开发者社区的极大关注&#xff0c;模型一夜下载量超 2k。 20 亿参数的 Octopus v2 可以在智能手机、汽车、个人电脑等端侧…