pytorch学习——正则化技术——丢弃法(dropout)

news2025/1/12 15:56:23

一、概念介绍

        在多层感知机(MLP)中,丢弃法(Dropout)是一种常用的正则化技术,旨在防止过拟合。(效果一般比前面的权重衰退好)

        在丢弃法中,随机选择一部分神经元并将其输出清零,被清零的神经元在该轮训练中不会被激活。这样,其他神经元就需要学习代替这些神经元的功能,从而促进了神经元之间的独立性和鲁棒性。

1.1思想原理

        丢弃法的基本思想是,在每一次训练中,随机选择一些神经元不参与训练,从而减少神经元之间的相互依赖关系,使得模型对于训练数据的过拟合程度降低。这样在测试时,所有神经元都参与,可以取得更好的泛化性能。

        丢弃法可以被应用到多层感知机的任意层中,包括输入层和输出层。在实际应用中,通常会在每一层都添加丢弃法,以充分发挥其正则化作用。

丢弃法特性:在层之间加入噪声,而不是在数据输入时加入。

 对Xi中的元素,以p概率变成0,1-p概率变大,最后期望值不变。

1.2应用场景

        通常将丢弃法作用在隐藏全连接层的输出上。

 如图所示,丢弃法可将一些中间结点丢弃,对剩余节点进行一定的增强。

 注:dropout是正则项,仅在训练中使用,不用于预测。

 二、示例演示

2.1实现dropout_layer 函数        

        该函数以dropout的概率丢弃张量输入X中的元素, 如上所述重新缩放剩余部分:将剩余部分除以1.0-dropout

import torch
from torch import nn
from d2l import torch as d2l


def dropout_layer(X, dropout):
    assert 0 <= dropout <= 1
    # 在本情况中,所有元素都被丢弃
    if dropout == 1:
        return torch.zeros_like(X)
    # 在本情况中,所有元素都被保留
    if dropout == 0:
        return X
    mask = (torch.rand(X.shape) > dropout).float()
    return mask * X / (1.0 - dropout)

2.2测试dropout_layer函数

X=torch.arange(16,dtype=torch.float32).reshape((2,8))
print(X)
#暂退概率是0,0.5,1
print(dropout_layer(X,0.))
print(dropout_layer(X,0.5))
print(dropout_layer(X,1.))
#结果
tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11., 12., 13., 14., 15.]])
tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11., 12., 13., 14., 15.]])
tensor([[ 0.,  0.,  4.,  6.,  0.,  0.,  0., 14.],
        [16.,  0., 20., 22., 24.,  0., 28.,  0.]])
tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.]])

2.3定义模型 

        引入Fashion-MNIST数据集。 我们定义具有两个隐藏层的多层感知机,每个隐藏层包含256个单元。将暂退法应用于每个隐藏层的输出(在激活函数之后), 并且可以为每一层分别设置暂退概率: 常见的技巧是在靠近输入层的地方设置较低的暂退概率。 下面的模型将第一个和第二个隐藏层的暂退概率分别设置为0.2和0.5, 并且暂退法只在训练期间有效。

num_inputs,num_outputs,num_hiddens1,num_hiddens2=784,10,256,256
#定义两个隐藏层,每个隐藏层有256个单元
dropout1, dropout2 = 0.2, 0.5  # 为每个隐藏层设置一个 dropout 概率

class Net(nn.Module):
    def __init__(self, num_inputs, num_outputs, num_hiddens1, num_hiddens2, is_training=True):
        super(Net, self).__init__()
        self.num_inputs = num_inputs
        self.training = is_training
        self.lin1 = nn.Linear(num_inputs, num_hiddens1)
        self.lin2 = nn.Linear(num_hiddens1, num_hiddens2)
        self.lin3 = nn.Linear(num_hiddens2, num_outputs)
        self.relu = nn.ReLU()

    def forward(self, X):
        # 应用第一个全连接层和 ReLU 激活函数
        H1 = self.relu(self.lin1(X.reshape((-1, self.num_inputs))))
        # 如果处于训练模式,对第一个隐藏层应用 dropout 操作
        if self.training == True:
            H1 = dropout_layer(H1, dropout1)
        # 应用第二个全连接层和 ReLU 激活函数
        H2 = self.relu(self.lin2(H1))
        # 如果处于训练模式,对第二个隐藏层应用 dropout 操作
        if self.training == True:
            H2 = dropout_layer(H2, dropout2)
        # 应用第三个全连接层,得到输出张量
        out = self.lin3(H2)
        return out

# 创建一个神经网络模型实例
net = Net(num_inputs, num_outputs, num_hiddens1, num_hiddens2)

2.4训练和测试

# 设置训练的轮数、学习率和批次大小
num_epochs, lr, batch_size = 10, 0.5, 256

# 定义损失函数为交叉熵损失,并设置reduction='none'以便获得单个样本的损失值
loss = nn.CrossEntropyLoss(reduction='none')

# 加载Fashion-MNIST数据集,并设置批次大小
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

# 定义优化器为随机梯度下降(SGD),并设置学习率
trainer = torch.optim.SGD(net.parameters(), lr=lr)

# 使用d2l.train_ch3函数进行模型训练,其中包括训练数据迭代器、测试数据迭代器、损失函数、训练轮数和优化器等参数
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

2.5结果

 三 、简洁实现

from torch import nn
import torch
from d2l import  torch as d2l
dropout1, dropout2 = 0.2, 0.5  # 为每个隐藏层设置一个 dropout 概率
num_epochs, lr, batch_size = 10, 0.5, 256
net = nn.Sequential(nn.Flatten(),
        nn.Linear(784, 256),
        nn.ReLU(),
        # 在第一个全连接层之后添加一个dropout层
        nn.Dropout(dropout1),
        nn.Linear(256, 256),
        nn.ReLU(),
        # 在第二个全连接层之后添加一个dropout层
        nn.Dropout(dropout2),
        nn.Linear(256, 10))

def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)

net.apply(init_weights)
loss = nn.CrossEntropyLoss(reduction='none')
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
trainer = torch.optim.SGD(net.parameters(), lr=lr)
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/819176.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

HCIP中期实验

1、该拓扑为公司网络&#xff0c;其中包括公司总部、公司分部以及公司骨干网&#xff0c;不包含运营商公网部分。 2、设备名称均使用拓扑上名称改名&#xff0c;并且区分大小写。 3、整张拓扑均使用私网地址进行配置。 4、整张网络中&#xff0c;运行OSPF协议或者BGP协议的设备…

Hadoop 之 Hive 4.0.0-alpha-2 搭建(八)

Hadoop 之 Hive 搭建与使用 一.Hive 简介二.Hive 搭建1.下载2.安装1.解压并配置 HIVE2.修改 hive-site.xml3.修改 hadoop 的 core-site.xml4.启动 三.Hive 测试 一.Hive 简介 Hive 是基于 Hadoop 的数据仓库工具&#xff0c;可以提供类 SQL 查询能力 二.Hive 搭建 1.下载 H…

linux 安装FTP

检查是否已经安装 $] rpm -qa |grep vsftpd vsftpd-3.0.2-29.el7_9.x86_64出现 vsftpd 信息表示已经安装&#xff0c;无需再次安装 yum安装 $] yum -y install vsftpd此命令需要root执行或有sudo权限的账号执行 /etc/vsftpd 目录 ftpusers # 禁用账号列表 user_list # 账号列…

【Ajax】笔记-设置CORS响应头实现跨域

CORS CORS CORS是什么&#xff1f; CORS(Cross-Origin Resource Sharing),跨域资源共享。CORS是官方的跨域解决方案&#xff0c;它的特点是不需要在客户端做任何特殊的操作&#xff0c;完全在服务器中进行处理&#xff0c;支持get和post请求。跨域资源共享标准新增了一组HTTP首…

VSCode格式化shell脚本

安装格式化插件&#xff1a;shell-format 用VSCode打开shell脚本之后&#xff0c;按格式化快捷键CtrlAltF&#xff0c;会提示没有格式化shell的工具&#xff0c;然后安装插件&#xff0c;我装的是这个插件&#xff1a;shell-format。 介绍&#xff1a;https://marketplace.vis…

2.C语言数据类型

常量与变量 1.**常量&#xff1a;**程序运行中&#xff0c;值不改变的量 变量&#xff1a;int num5&#xff1b;值可以变的量 2 C语言三种简单数据类型&#xff1a;整型&#xff0c;实型&#xff0c;字符型 %c %d %ld %s %f整型-进制的转换 **1.十进制&#xff1a;**默认的进…

SpringBoot使用JKS或PKCS12证书实现https

SpringBoot使用JKS或PKCS12证书实现https 生成JKS类型的证书 可以利用jdk自带的keytool工具来生成证书文件&#xff0c; 默认生成的是JKS证书 cmd命令如下: 执行如下命令&#xff0c;并按提示填写证书内容&#xff0c;最后会生成server.keystore文件 keytool -genkey tomcat…

ChatGPT在商业世界中的创新应用:颠覆传统营销与客户关系管理

&#x1f337;&#x1f341; 博主 libin9iOak带您 Go to New World.✨&#x1f341; &#x1f984; 个人主页——libin9iOak的博客&#x1f390; &#x1f433; 《面试题大全》 文章图文并茂&#x1f995;生动形象&#x1f996;简单易学&#xff01;欢迎大家来踩踩~&#x1f33…

14、容器初始化器和配置环境后处理器

容器初始化器和配置环境后处理器 容器初始化器 与传统Spring的容器后处理器&#xff08;也可对容器做一些定制&#xff09;对比&#xff0c;此处的容器初始化器的执行时机要更早一些。 容器初始化器负责可对Spring容器执行初始化定制。 就是在启动项目的时候&#xff0c;在容…

eclipse版本与jdk版本对应关系

官网&#xff1a;Eclipse/Installation - Eclipsepedia eclipse历史版本&#xff08;2007-&#xff09;&#xff1a;Older Versions Of Eclipse - Eclipsepedia Eclipse Packaging Project (EPP) Releases | Eclipse Packages

springboot疾病查询网站【纯干货分享,免费领源码01548】

spring boot疾病查询网站 摘 要 随着互联网时代的到来&#xff0c;同时计算机网络技术高速发展&#xff0c;网络管理运用也变得越来越广泛。因此&#xff0c;建立一个B/S结构的疾病查询网站&#xff0c;会使疾病查询工作系统化、规范化&#xff0c;也会提高医院形象&#xff0c…

红外雨量计(光学雨量传感器)检测降雨量,预防内涝

红外雨量计&#xff08;光学雨量传感器&#xff09;检测降雨量&#xff0c;预防内涝 随着城市化进程的加快&#xff0c;城市内涝成为一个愈发严峻的问题。短时间内大量的降雨&#xff0c;不仅会给城市交通带来困难&#xff0c;也会对城市的基础设施和居民的生活造成很大的影响…

ESP32cam系列教程003:ESP32cam实现远程 HTTP_OTA 自动升级

文章目录 1.什么是 OTA2. ESP32cam HTTP_OTA 本地准备2.1 HTTP OTA 升级原理2.2 开发板本地基准程序&#xff08;程序版本&#xff1a;1_0_0&#xff09;2.3 开发板升级程序&#xff08;程序版本&#xff1a;1_0_1&#xff09;2.4 本地 HTTP_OTA 升级测试2.4.1 本地运行一个 HT…

Spring系列二:基于注解配置bean

文章目录 &#x1f497;通过注解配置bean&#x1f35d;基本介绍&#x1f35d;快速入门&#x1f35d;注意事项和细节 &#x1f497;自己实现Spring注解配置Bean机制&#x1f35d;需求说明&#x1f35d;思路分析&#x1f35d;注意事项和细节 &#x1f497;自动装配 Autowired&…

谷粒商城第八天-商品服务之品牌管理的整体实现(直接使用逆向生成的代码;含oss文件上传)

目录 一、总述 二、前端部分 2.1 创建好品牌管理菜单 2.2 复制组件 ​编辑2.3 复制api ​​​编辑 2.4 查看效果 ​编辑2.5 需要优化的地方 2.6 具体优化实现 2.6.1 优化一&#xff1a;将表格的状态列&#xff08;这里是是否显示列&#xff09;修改为开关&#xff…

Day15-作业(Maven高级)

作业1&#xff1a;完成苍穹外卖项目目录结构创建 说明&#xff1a; 1.sky-take-out 是父工程 2.sky-common/sky-pojo/sky-server 是子工程 3.sky-server模块引用了sky-common/sky-pojo模块 作业2&#xff1a;完成汇客CRM项目目录结构创建 说明&#xff1a; huike-paren…

判断变量是否为标量(没有维度,单个数值)numpy.isscalar()

【小白从小学Python、C、Java】 【计算机等级考试500强双证书】 【Python-数据分析】 判断变量是否为标量 &#xff08;没有维度,单个数值&#xff09; numpy.isscalar() [太阳]选择题 请问关于以下代码表述错误的是&#xff1f; import numpy as np a 0 b np.array([0]) pri…

六、初始化和清理(3)

本章概要 成员初始化构造器初始化 初始化的顺序静态数据的初始化显式的静态初始化非静态实例初始化 成员初始化 Java 尽量保证所有变量在使用前都能得到恰当的初始化。对于方法的局部变量&#xff0c;这种保证会以编译时错误的方式呈现&#xff0c;所以如果写成&#xff1a…

初识Java - 概念与准备

本笔记参考自&#xff1a; 《On Java 中文版》 目录 写在第一行 Java的迭代与发展 Java的迭代 Java的参考文档 对象的概念 抽象 接口 访问权限 复用实现 继承 基类和子类 A是B和A像B 多态 单根层次结构 集合 参数化类型 对象的创建和生命周期 写在第一行 作为一…

Charles抓包工具使用(一)(macOS)

Fiddler抓包 | 竟然有这些骚操作&#xff0c;太神奇了&#xff1f; Fiddler响应拦截数据篡改&#xff0c;实现特殊场景深度测试&#xff08;一&#xff09; 利用Fiddler抓包调试工具&#xff0c;实现mock数据特殊场景深度测试&#xff08;二&#xff09; 利用Fiddler抓包调试工…