抑制过拟合——Dropout原理

news2025/1/15 19:34:28

抑制过拟合——Dropout原理

  • Dropout的工作原理
  • 实验观察

  在机器学习领域,尤其是当我们处理复杂的模型和有限的训练样本时,一个常见的问题是过拟合。简而言之,过拟合发生在模型对训练数据学得太好,以至于它捕捉到了数据中的噪声和误差,而不仅仅是底层模式。具体来说,这在神经网络训练中尤为常见,表现为在训练数据上表现优异(例如损失函数值很小,预测准确率高)而在未见过的数据(测试集)上表现不佳。

  过拟合不仅是机器学习新手容易遇到的问题,即使是经验丰富的从业者也会面临这一挑战。一个典型的解决方案是采用模型集成技术,这涉及训练多个模型并将它们的预测结合起来。但这种方法的缺点是显而易见的:它既耗时又昂贵,不仅在训练阶段,而且在模型评估和部署时也是如此。

  在这种背景下,Dropout 作为一种有效的正则化技术,可以显著减轻过拟合问题。它的基本原理是在每次训练迭代中随机“丢弃”(即暂时移除)网络中的一部分神经元。这种方法不仅简单,而且被证明在许多情况下都非常有效。

Dropout的工作原理

  在 PyTorch 中,Dropout 层的使用相当直观。通常,它被添加到神经网络的各个层之间,如下所示:

torch.nn.Dropout(p=0.5, inplace=False)

  p:这是一个关键参数,代表着每个神经元被丢弃的概率。

  在实践中,这意味着对于网络中的每个神经元,它在每次训练迭代中都有 1 − p 1-p 1p 的概率被保留, p p p 的概率被丢弃。值得注意的是,这种随机性确保了每个mini-batch都在对不完全相同的网络进行训练,从而减少过拟合的风险。

  在训练期间,对于每个训练样本,网络中的每个神经元都有概率 1 − p 1-p 1p 被保留,概率 p p p 被丢弃。如果神经元被保留,则其输出乘以 1 1 − p \frac{1}{1-p} 1p1​(这样做是为了保持该层输出的总期望值不变)。设 r j r_j rj​ 为一个随机变量,它对应于第 j j j 个神经元,且服从伯努利分布(即 r j = 1 r_j = 1 rj=1 的概率为 1 − p 1-p 1p r j = 0 r_j = 0 rj=0 的概率为 p p p)。那么在训练时,神经元的输出 y j y_j yj变为 r j × y j / ( 1 − p ) r_j \times y_j / (1-p) rj×yj/(1p)

为什么需要保持期望不变? 举个简单的例子,假设某层有两个神经元,它们的输出在没有dropout时都是1。在应用了50%的dropout后,期望只有一个神经元被激活,输出为1,另一个被丢弃,输出为0。这样,这层的平均输出变成了0.5。为了保持输出的总期望值不变,激活的神经元的输出应该乘以2,即 1 1 − p \frac{1}{1-p} 1p1​,这样平均输出才能保持为1,与没有应用dropout时相同。这样的处理有助于保持整个网络的稳定性和一致性。

  在模型预测(或测试)阶段,所有的神经元都保持激活(即不进行dropout)。因为在训练阶段,神经元的输出已经被放大了 1 1 − p \frac{1}{1-p} 1p1 倍,所以在预测时不需要进行任何调整,直接使用网络进行前向传播即可。

在这里插入图片描述

实验观察

  为了更深入地理解 Dropout 的影响,我们可以通过一个实验来观察不同的 Dropout 设置对训练过程的影响。比如,可以比较 Dropout = 0.1Dropout = 0 在训练过程中的表现差异,相关代码实现如下:

import torch
from tensorboardX import SummaryWriter
from torch import optim, nn
import time


class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linears = nn.Sequential(
            nn.Linear(2, 20),

            nn.Linear(20, 20),
            nn.Dropout(0.1),

            nn.Linear(20, 20),

            nn.Linear(20, 20),

            nn.Linear(20, 1),
        )

    def forward(self, x):
        _ = self.linears(x)
        return _

lr = 0.01
iteration = 1000


x1 = torch.arange(-10, 10).float()
x2 = torch.arange(0, 20).float()
x = torch.cat((x1.unsqueeze(1), x2.unsqueeze(1)), dim=1)
y = 2*x1 - x2**2 + 1

model = Model()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=0.01)
loss_function = torch.nn.MSELoss()

start_time = time.time()
writer = SummaryWriter(comment='_随机失活')

for iter in range(iteration):
    y_pred = model(x)
    loss = loss_function(y, y_pred.squeeze())
    loss.backward()

    for name, layer in model.named_parameters():
        writer.add_histogram(name + '_grad', layer.grad, iter)
        writer.add_histogram(name + '_data', layer, iter)
    writer.add_scalar('loss', loss, iter)

    optimizer.step()
    optimizer.zero_grad()

    if iter % 50 == 0:
        print("iter: ", iter)

print("Time: ", time.time() - start_time)

这里我们使用 TensorBoardX 进行结果的可视化展示。

  通过观察模型训练1000轮后的线性层梯度分布,可以发现,应用 Dropout 后的模型梯度通常会更加分散和多样化。这种梯度的多样性有助于防止模型过于依赖训练数据中的特定模式,从而减轻过拟合。

在这里插入图片描述

  同样值得注意的是,模型的损失曲线也会受到影响。加入 Dropout 通常会使损失曲线出现更多的波动(例如,图中的蓝色曲线),这反映了模型在学习过程中的不稳定性。然而,这种不稳定性通常是可接受的,因为它反映了模型正在学习更多的泛化模式而不是简单地记住训练数据。

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1272236.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

热烈欢迎省工信厅毛郑建处长莅临公司调研指导工作

2023年11月28日,河南省工信厅信息化和软件服务业处毛郑建处长莅临郑州埃文计算机科技有限公司(以下简称“埃文科技”)调研考察工作。河南省工业信息安全产业发展联盟理事长任传军陪同调研。 首先,埃文科技董事长王永向毛处长介绍埃…

开源运维监控系统-Nightingale(夜莺)应用实践(未完)

一、前言 某业务系统因OS改造,原先的Zabbix监控系统推倒后未重建,本来计划用外部企业内其他监控系统接入,后又通知需要自建才能对接,考虑之前zabbix的一些不便,本次计划采用一个类Prometheus的监控系统,镜调研后发现Nightingale兼容Prometheus,又有一些其他功能增强,又…

JDK 动态代理从入门到掌握

快速入门 本文介绍 JDK 实现的动态代理及其原理,通过 ProxyGenerator 生成的动态代理类字节码文件 环境要求 要求原因JDK 8 及以下在 JDK 9 之后无法使用直接调用 ProxyGenerator 中的方法,不便于将动态代理类对应的字节码文件输出lombok为了使用 Sne…

孩子都能学会的FPGA:第十七课——用FPGA实现定点数的乘法

(原创声明:该文是作者的原创,面向对象是FPGA入门者,后续会有进阶的高级教程。宗旨是让每个想做FPGA的人轻松入门,作者不光让大家知其然,还要让大家知其所以然!每个工程作者都搭建了全自动化的仿…

SpringBoot-监听Nacos动态修改日志级别

目录 一、pom文件 二、项目配置文件 三、日志配置文件 四、日志监听类 五、日志动态修改服务类 线上系统的日志级别一般都是 INFO 级别,有时候需要查看 WARN 级别的日志,所以需要动态修改日志级别。微服务项目中使用 Nacos 作为注册中心&#xff0c…

什么是网络攻击?阿里云服务器可以避免被攻击吗?

网络攻击是指:损害网络系统安全属性的任何类型的进攻动作。进攻行为导致网络系统的机密性、完整性、可控性、真实性、抗抵赖性等受到不同程度的破坏。 网络攻击有很多种,网络上常见的攻击有DDOS攻击、CC攻击、SYN攻击、ARP攻击以及木马、病毒等等,所以再…

CTO对生活和工作一点感悟

陌生人,你好啊。 感谢CSDN平台让我们有了隔空认识,交流的机会。 我是谁? 我呢,毕业快11年,在网易做了几年云计算,后来追风赶上了大数据的浪潮,再到后来混迹在AI、智能推荐等领域。 因为有一颗…

SS8847T 双通道 H 桥驱动芯片 替代DRV8847

SS8847E是一款双桥电机驱动器,具有两个H桥驱动器,可以驱动两个直流有刷电机,一个双极步进电机,螺线管或其他感性负载。该器件的工作电压范围为 2.7V 至 15V,每通道可提供高达 1.0A 的负载电流。每个H桥的输出驱动器模块…

2023年安全员-A证证模拟考试题库及安全员-A证理论考试试题

题库来源:安全生产模拟考试一点通公众号小程序 2023年安全员-A证证模拟考试题库及安全员-A证理论考试试题是由安全生产模拟考试一点通提供,安全员-A证证模拟考试题库是根据安全员-A证最新版教材,安全员-A证大纲整理而成(含2023年…

navigator.clipboard is undefined in JavaScript issue [Fixed]

navigator.clipboard 在不安全的网站是无法访问的。 在本地开发使用localhost或127.0.0.1没有这个问题。因为它不是不安全网站。 在现实开发中,可能遇到测试环境为不安全网站。 遇到这个问题,就需要将不安全网站标记为非不安全网站即可。 外网提供了3…

python动态加载内容抓取问题的解决实例

问题背景 在网页抓取过程中,动态加载的内容通常无法通过传统的爬虫工具直接获取,这给爬虫程序的编写带来了一定的技术挑战。腾讯新闻(https://news.qq.com/)作为一个典型的动态网页,展现了这一挑战。 问题分析 动态…

【开源视频联动物联网平台】视频接入网关的用法

视频接入网关是一种功能强大的视频网关设备,能够解决各种视频接入、视频输出、视频转码和视频融合等问题。它可以在应急指挥、智慧融合等项目中发挥重要作用,与各种系统进行对接,解决视频能力跨系统集成的难题。 很多视频接入网关在接入协议…

Go 语言输出文本函数详解

Go语言拥有三个用于输出文本的函数: Print()Println()Printf() Print() 函数以其默认格式打印其参数。 示例 打印 i 和 j 的值: package mainimport "fmt"func main() {var i, j string "Hello", "World"fmt.Print(…

【力扣:526】优美的排列

状态压缩动态规划 原理如下: 遍历位图可以得到所有组合序列,将这些序列的每一位看作一个数,取序列中1总量的值作为每轮遍历的位,此时对每个这样的位都能和所有数进行匹配,因为一开始就取的是全排列,并且我们…

MySQL表的查询、更新、删除

查询 全列查询 指定列查询 查询字段并添加自定义表达式 自定义表达式重命名 查询指定列并去重 select distinct 列名 from 表名 where条件 查询列数据为null的 null与 (空串)是不同的! 附:一般null不参与查询。 查询列数据不为null的 查询某列数据指定…

陈嘉庚慈善践行与卓顺发的大爱传承

陈嘉庚慈善践行,了解陈嘉庚后人与卓顺发的大爱传承。 2023年11月25日,卓顺发太平绅士以及陈家后人在分享他们对慈善领域见解的过程中,特别强调了慈善在促进社会和谐以及推动社会进步方面的关键作用。同时,他们深入探讨了如何在当今社会中继续传扬和实践家国情怀以及…

C++ CryptoPP使用AES加解密

Crypto (CryptoPP) 是一个用于密码学和加密的 C 库。它是一个开源项目,提供了大量的密码学算法和功能,包括对称加密、非对称加密、哈希函数、消息认证码 (MAC)、数字签名等。Crypto 的目标是提供高性能和可靠的密码学工具,以满足软件开发中对…

什么是木马

木马 1. 定义2. 木马的特征3. 木马攻击流程4. 常见木马类型5. 如何防御木马 1. 定义 木马一名来源于古希腊特洛伊战争中著名的“木马计”,指可以非法控制计算机,或在他人计算机中从事秘密活动的恶意软件。 木马通过伪装成正常软件被下载到用户主机&…

strstr 的使用和模拟实现

就位了吗?如果坐好了的话,那么我就要开始这一期的表演了哦! strstr 的使用和模拟实现: char * strstr ( const char * str1, const char * str2); Returns a pointer to the first occurrence of str2 in str1, or a null pointer if str2 i…

030 - STM32学习笔记 - ADC(四) 独立模式多通道DMA采集

030 - STM32学习笔记 - ADC(四) 独立模式多通道DMA采集 中断模式和DMA模式进行单通道模拟量采集,这节继续学习独立模式多通道DMA采集,使用到的引脚有之前使用的PC3(电位器),PA4(光敏…