【机器学习】037_暂退法

news2025/1/11 9:58:49

一、实现原理

具有输入噪音的训练,等价于Tikhonov正则化

核心方法:在前向传播的过程中,计算每一内部层的同时注入噪声

· 从作用上来看,表面上来说是在训练过程中丢弃一些神经元

· 假设x是某一层神经网络层的输出,是下一层的输入,我们希望对x加入一些噪音,使得:

E[x^`]=x

  ※x`的期望为x,也就是说平均上来说输出值还是x

· 暂退法对每个元素进行了如下扰动:

        有p的概率下取值:x^`_i=0

        其它情况(1-p概率):x^`_i = \frac{x_i}{1-p}

实践中使用暂退法:

· 通常将暂退法作用在全连接隐藏层的输出上

如图所示,在第一个隐藏层的输出上,有些神经元有p的概率使输出值置零。

非置零的输出值,即有1-p的概率被施加了一个较小的扰动值使其略微增大。

※暂退法只在训练中使用,dropout是正则项,在推理过程中不会使用,这样也会保证输出值确定

※每次执行暂退法的时候,实际上是每次随机采样了一些子神经网络

总结:

①暂退法将一些输出项随机置零来控制模型的复杂度

②暂退法的作用效果和正则化等价

③常应用在多层感知机的隐藏层输出上

④丢弃概率p是控制模型复杂度的超参数

二、代码实现

从零实现代码:

import torch
from torch import nn
from d2l import torch as d2l

def dropout_layer(X, dropout):
    # assert用于选择dropout符合范围的情况,不符合则报错
    assert 0 <= dropout <= 1, "不符合范围!"
    # 在本情况中,所有元素都被丢弃
    if dropout == 1:
        return torch.zeros_like(X)
    # 在本情况中,所有元素都被保留
    if dropout == 0:
        return X
    # 在这一步操作中,首先定义一个和X张量形状相同但元素值均为随机数的张量
    # 将这个张量里每个元素与dropout比较,如果大于就置为True,小于等于就置为False
    # 再调用float将True和False转化为1和0
    # 这样,mask就是一个仅含1与0的张量了
    # 最后将mask里的每个元素与X里的每个元素做数乘
    mask = (torch.rand(X.shape) > dropout).float()
    return mask * X / (1.0 - dropout)

# 生成X来测试暂退法
X= torch.arange(16, dtype = torch.float32).reshape((2, 8))
print(X)
print(dropout_layer(X, 0.))
print(dropout_layer(X, 0.5))
print(dropout_layer(X, 1.))

# 定义模型参数
num_inputs, num_outputs, num_hiddens1, num_hiddens2 = 784, 10, 256, 256
# 定义模型
dropout1, dropout2 = 0.2, 0.5
# is_training用来表示当前是在测试还是在训练
class Net(nn.Module):
    def __init__(self, num_inputs, num_outputs, num_hiddens1, num_hiddens2,
                 is_training = True):
        super(Net, self).__init__()
        self.num_inputs = num_inputs
        self.training = is_training
        self.lin1 = nn.Linear(num_inputs, num_hiddens1)
        self.lin2 = nn.Linear(num_hiddens1, num_hiddens2)
        self.lin3 = nn.Linear(num_hiddens2, num_outputs)
        self.relu = nn.ReLU()

    def forward(self, X):
        H1 = self.relu(self.lin1(X.reshape((-1, self.num_inputs))))
        # 只有在训练模型时才使用dropout
        if self.training == True:
            # 在第一个全连接层之后添加一个dropout层
            H1 = dropout_layer(H1, dropout1)
        H2 = self.relu(self.lin2(H1))
        if self.training == True:
            # 在第二个全连接层之后添加一个dropout层
            H2 = dropout_layer(H2, dropout2)
            # 输出不需要dropout作用
        out = self.lin3(H2)
        return out

net = Net(num_inputs, num_outputs, num_hiddens1, num_hiddens2)

# 训练、测试模型
num_epochs, lr, batch_size = 10, 0.5, 256
loss = nn.CrossEntropyLoss(reduction='none')
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
trainer = torch.optim.SGD(net.parameters(), lr=lr)
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

简洁实现代码:

import torch
from torch import nn
from d2l import torch as d2l

# 定义概率参数
dropout1, dropout2 = 0.2, 0.5

net = nn.Sequential(nn.Flatten(),
        nn.Linear(784, 256),
        nn.ReLU(),
        # 在第一个全连接层之后添加一个dropout层
        nn.Dropout(dropout1),
        nn.Linear(256, 256),
        nn.ReLU(),
        # 在第二个全连接层之后添加一个dropout层
        nn.Dropout(dropout2),
        nn.Linear(256, 10))

def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)

net.apply(init_weights);

# 训练、测试模型
num_epochs, lr, batch_size = 10, 0.5, 256
loss = nn.CrossEntropyLoss(reduction='none')
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

trainer = torch.optim.SGD(net.parameters(), lr=lr)
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1232283.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

​EMNLP 2023 findings | 生成式框架下解决输入扰动槽填充任务

©PaperWeekly 原创 作者 | 回亭风 单位 | 北京邮电大学 研究方向 | 自然语言理解 论文标题&#xff1a; DemoNSF: A Multi-task Demonstration-based Generative Framework for Noisy Slot Filling Task 论文链接&#xff1a; https://arxiv.org/abs/2310.10169 代码链接…

【Linux】套接字编程

目录 套接字 IP PORT TCP和UDP的介绍 TCP UDP 网络字节序 转换接口 UDP服务器的编写 服务器的初始化 socket bind sockaddr 结构 服务器的运行 数据的收发 业务处理 客户端的编写 运行效果 拓展 套接字 &#x1f338;首先&#xff0c;我们先思考一个问题…

11.4MyBatis(基础)

一.搭环境 1.创建完SSM项目,添加MySQL和MyBatis后,项目启动一定会报错,这是正常情况. 2.配置文件 properties: server.port9090 spring.datasource.urljdbc:mysql://127.0.0.1:3306/test1?characterEncodingutf8&useSSLfalse spring.datasource.usernameroot spring.d…

Linux内核的安装

1.通过tftp 加载内核和根文件系统 即sd内存卡启动&#xff1a; SD卡的存储以扇区为单位,每个扇区的大小为512Byte, 其中零扇区存储分区表&#xff08;即分区信息&#xff09;,后续的扇区可自行分区和格式化&#xff1b; 若选择SD卡启动&#xff0c;处理器上电后从第一个扇区开…

开发仿抖音APP遇到的问题和解决方案

uni-app如何引入阿里矢量库图标/uniapp 中引入 iconfont 文件报错文件查找失败 uni-app如何引入阿里矢量库图标 - 知乎 uniapp 中引入 iconfont 文件报错文件查找失败&#xff1a;‘./iconfont.woff?t1673007495384‘ at App.vue:6_宝马金鞍901的博客-CSDN博客 将课件中的cs…

企业微信将应用安装到工作台

在上篇中介绍了配置小程序应用及指令、数据回调获取第三方凭证&#xff1b; 本篇将介绍如何将应用安装到企业工作台。 添加测试企业 通过【应用管理】->【测试企业配置】添加测试企业。 通过企业微信扫描二维码添加测试企业。 注意&#xff1a;需要扫描的账号为管理员权限…

让别人访问电脑本地

查看本地IP地址&#xff1a; 使用ipconfig&#xff08;Windows&#xff09;或ifconfig&#xff08;Linux/macOS&#xff09;命令来查看你的计算机本地网络的IP地址。确保*****是你的本地IP地址。 防火墙设置&#xff1a; 确保你的防火墙允许从外部访问*****。你可能需要在防火…

万字解析设计模式之代理模式

一、代理模式 1.1概述 代理模式是一种结构型设计模式&#xff0c;它允许通过创建代理对象来控制对其他对象的访问。这种模式可以增加一些额外的逻辑来控制对原始对象的访问&#xff0c;同时还可以提供更加灵活的访问方式。 代理模式分为静态代理和动态代理两种。静态代理是在编…

【机器学习】032_多种神经网络层类型

一、密集层 每一层神经元都是上一层神经元的函数&#xff0c;每层每个神经元都从前一层获得所有激活的输入。 整个神经网络前一层与后一层连接在一起&#xff0c;构造的网络密集。 二、卷积层 假设有一张大小为axb像素的图片&#xff0c;上面标着一些手写数字&#xff0c…

Apache Airflow (十二) :PythonOperator

&#x1f3e1; 个人主页&#xff1a;IT贫道_大数据OLAP体系技术栈,Apache Doris,Clickhouse 技术-CSDN博客 &#x1f6a9; 私聊博主&#xff1a;加入大数据技术讨论群聊&#xff0c;获取更多大数据资料。 &#x1f514; 博主个人B栈地址&#xff1a;豹哥教你大数据的个人空间-豹…

Java-equals方法

1.package com.msb.test02; 2. 3./** 4. * Auther: msb-zhaoss 5. */ 6.public class Phone {//手机类&#xff1a; 7. //属性&#xff1a; 8. private String brand;//品牌型号 9. private double price;//价格 10. private int year ;//出产年份 11. //方法&a…

Java基础-----正则表达式

文章目录 1.简介2.目的3.学习网站4.常用匹配字符5.String类中用到正则表达式的方法 1.简介 又叫做规则表达式。是一种文本模式&#xff0c;包括普通字符和特殊字符&#xff08;元字符&#xff09;。正则使用单个字符来描述、匹配一系列某个句法规则的字符串&#xff0c;通常用…

投资黄金:如何选择正确的黄金品种增加收益?

黄金一直以来都是备受投资者青睐的避险资产&#xff0c;然而&#xff0c;在庞大的黄金市场中&#xff0c;选择适合自己的黄金品种成为影响收益的关键因素。黄金投资并不只有一种方式&#xff0c;而是有很多种不同的黄金品种可以选择。每种黄金品种都有其独特的特点和风险&#…

Linux本地WBO创作白板部署与远程访问

文章目录 前言1. 部署WBO白板2. 本地访问WBO白板3. Linux 安装cpolar4. 配置WBO公网访问地址5. 公网远程访问WBO白板6. 固定WBO白板公网地址 前言 WBO在线协作白板是一个自由和开源的在线协作白板&#xff0c;允许多个用户同时在一个虚拟的大型白板上画图。该白板对所有线上用…

启动dubbo消费端过程提示No provider available for the service的问题定位与解决

文/朱季谦 某次在启动dubbo消费端时&#xff0c;发现无法从zookeeper注册中心获取到所依赖的消费者API&#xff0c;启动日志一直出现这样的异常提示 Failed to check the status of the service com.fte.zhu.api.testService. No provider available for the service com.fte…

使用Python的turtle模块绘制玫瑰花图案(含详细Python代码与注释)

1.1引言 turtle模块是Python的标准库之一&#xff0c;它提供了一个绘图板&#xff0c;让我们可以在屏幕上绘制各种图形。通过使用turtle&#xff0c;我们可以创建花朵、叶子、复杂的图案等等。本博客将介绍如何使用turtle模块实现绘制图形的过程&#xff0c;并展示最终结果。 …

初始环境配置

目录 一、JDK1、简介2、配置步骤 二、Redis1、简介2、配置步骤 三、MySQL1、简介2、配置步骤 四、Git1、简介2、配置步骤 五、NodeJS1、简介2、配置步骤 六、Maven1、简介2、配置步骤 七、Tomcat1、简介2、配置步骤 一、JDK 1、简介 JDK 是 Oracle 提供的 Java 开发工具包&…

Java基础-----StringBuffer和StringBuilder

文章目录 1.StringBuffer1.1 构造方法1.2 常用方法 2.StringBuilder3.String、StringBuffer、StringBuilder的区别 1.StringBuffer 内容可变的字符串类&#xff0c;适应StringBuffer来对字符串的内容进行动态操作&#xff0c;不会产生额外的对象。StringBuffer在初始时&#x…

机器学习笔记 - Ocr识别中的CTC算法原理概述

一、文字识别 在文本检测步骤中,分割出了文本区域。现在需要识别这些片段中存在哪些文本。 机器学习笔记 - Ocr识别中的文本检测EAST网络概述-CSDN博客文章浏览阅读300次。在 EAST 网络的这个分支中,它合并了 VGG16 网络不同层的特征输出。现在,该层之后的特征大小将等于 p…

【计算机网络笔记】路由算法之链路状态路由算法

系列文章目录 什么是计算机网络&#xff1f; 什么是网络协议&#xff1f; 计算机网络的结构 数据交换之电路交换 数据交换之报文交换和分组交换 分组交换 vs 电路交换 计算机网络性能&#xff08;1&#xff09;——速率、带宽、延迟 计算机网络性能&#xff08;2&#xff09;…