Pytorch 暂退法(Dropout)

news2024/11/25 16:52:14

在2014年,斯里瓦斯塔瓦等人 (Srivastava et al., 2014) 就如何将毕晓普的想法应用于网络的内部层提出了一个想法: 在训练过程中,他们建议在计算后续层之前向网络的每一层注入噪声。 因为当训练一个有多层的深层网络时,注入噪声只会在输入-输出映射上增强平滑性。这个想法被称为暂退法(dropout)。

目录

1、暂退法

2、使用Fashion-MNIST数据集动手实践暂退法

3、简单实践暂退法


1、暂退法

 

在2012年,Hinton在其论文《Improving neural networks by preventing co-adaptation of feature detectors》中提出Dropout。当一个复杂的前馈神经网络被训练在小的数据集时,容易造成过拟合。为了防止过拟合,可以通过阻止特征检测器的共同作用来提高神经网络的性能。

在2012年,Alex、Hinton在其论文《ImageNet Classification with Deep Convolutional Neural Networks》中用到了Dropout算法,用于防止过拟合。并且,这篇论文提到的AlexNet网络模型引爆了神经网络应用热潮,并赢得了2012年图像识别大赛冠军,使得CNN成为图像分类上的核心算法模型。

随后,又有一些关于Dropout的文章《Dropout:A Simple Way to Prevent Neural Networks from Overfitting》、《Improving Neural Networks with Dropout》、《Dropout as data augmentation》。

Dropout可以作为训练深度神经网络的一种trick供选择。在每个训练批次中,通过忽略一半的特征检测器(让一半的隐层节点值为0),可以明显地减少过拟合现象。这种方式可以减少特征检测器(隐层节点)间的相互作用,检测器相互作用是指某些检测器依赖其他检测器才能发挥作用。

Dropout说的简单一点就是:我们在前向传播的时候,让某个神经元的激活值以一定的概率p停止工作,这样可以使模型泛化性更强,因为它不会太依赖某些局部的特征。

2、使用Fashion-MNIST数据集动手实践暂退法

①载入数据集

%matplotlib inline
import torch
import torchvision
from torch.utils import data
from torch import nn
from torchvision import transforms
from d2l import torch as d2l
import warnings
warnings.filterwarnings("ignore")
# 用SVG清晰度高
d2l.use_svg_display()
from IPython import display
def get_dataloader_workers():  #@save
    """使用4个进程来读取数据"""
    return 4
def load_data_fashion_mnist(batch_size, resize=None):  #@save
    """下载Fashion-MNIST数据集,然后将其加载到内存中"""
    trans = [transforms.ToTensor()]
    if resize:
        trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)
    mnist_train = torchvision.datasets.FashionMNIST(
        root="data/", train=True, transform=trans, download=True)
    mnist_test = torchvision.datasets.FashionMNIST(
        root="data/", train=False, transform=trans, download=True)
    return (data.DataLoader(mnist_train, batch_size, shuffle=True,
                            num_workers=get_dataloader_workers()),
            data.DataLoader(mnist_test, batch_size, shuffle=False,
                            num_workers=get_dataloader_workers()))
train_iter, test_iter = load_data_fashion_mnist(256)

②、定义dropout

我们实现 dropout_layer 函数, 该函数以dropout的概率丢弃张量输入X中的元素, 如上所述重新缩放剩余部分:将剩余部分除以1.0-dropout

def dropout_layer(X, dropout):
    assert 0 <= dropout <= 1
    # 在本情况中,所有元素都被丢弃
    if dropout == 1:
        return torch.zeros_like(X)
    # 在本情况中,所有元素都被保留
    if dropout == 0:
        return X
    mask = (torch.rand(X.shape) > dropout).float()
    return mask * X / (1.0 - dropout)

几个例子来测试dropout_layer函数。 我们将输入X通过暂退法操作,暂退概率分别为0、0.5和1。

X= torch.arange(16, dtype = torch.float32).reshape((2, 8))
print(X)
print(dropout_layer(X, 0.))
print(dropout_layer(X, 0.5))
print(dropout_layer(X, 1.))

③、定义模型参数

num_inputs, num_outputs, num_hiddens1, num_hiddens2 = 784, 10, 256, 256

 ④、定义模型

dropout1, dropout2 = 0.2, 0.5

class Net(nn.Module):
    def __init__(self, num_inputs, num_outputs, num_hiddens1, num_hiddens2,
                 is_training = True):
        super(Net, self).__init__()
        self.num_inputs = num_inputs
        self.training = is_training
        self.lin1 = nn.Linear(num_inputs, num_hiddens1)
        self.lin2 = nn.Linear(num_hiddens1, num_hiddens2)
        self.lin3 = nn.Linear(num_hiddens2, num_outputs)
        self.relu = nn.ReLU()

    def forward(self, X):
        H1 = self.relu(self.lin1(X.reshape((-1, self.num_inputs))))
        # 只有在训练模型时才使用dropout
        if self.training == True:
            # 在第一个全连接层之后添加一个dropout层
            H1 = dropout_layer(H1, dropout1)
        H2 = self.relu(self.lin2(H1))
        if self.training == True:
            # 在第二个全连接层之后添加一个dropout层
            H2 = dropout_layer(H2, dropout2)
        out = self.lin3(H2)
        return out


net = Net(num_inputs, num_outputs, num_hiddens1, num_hiddens2)

⑤、进行模型训练

num_epochs, lr, batch_size = 10, 0.5, 256
loss = nn.CrossEntropyLoss()
trainer = torch.optim.SGD(net.parameters(), lr=lr)
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

3、简单实践暂退法

net = nn.Sequential(nn.Flatten(),
        nn.Linear(784, 256),
        nn.ReLU(),
        # 在第一个全连接层之后添加一个dropout层
        nn.Dropout(dropout1),
        nn.Linear(256, 256),
        nn.ReLU(),
        # 在第二个全连接层之后添加一个dropout层
        nn.Dropout(dropout2),
        nn.Linear(256, 10))

def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)

net.apply(init_weights);

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/154159.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

八、Gtk4-GtkBuilder and UI file

1 New, Open and Save button 在上一节中&#xff0c;我们制作了一个非常简单的编辑器。它在程序开始时读取文件&#xff0c;在程序结束时将文件写出来。它可以工作&#xff0c;但不是很好。如果我们有“新建”、“打开”、“保存”和“关闭”按钮就更好了。本节介绍如何在窗口…

【年度总结】2022不忘初心,砥砺前行 2023纵有疾风起,人生不言弃。

2022 工作 砥砺前行 跨境电商跑路 2022年3月&#xff0c;从一家小跨境电商被动跑路&#xff0c;果断处于迷茫期&#xff0c;被动跑路的最好优势就是有赔偿&#xff0c;哈哈&#xff0c;怎么还有一丢丢高兴呢。简单总结了下在该公司的经历&#xff0c;是之前的老大带我去的&am…

如何通过产品帮助中心减轻客服压力,提高内部人员工作效率

客户的问题一直在重复&#xff0c;客服人员压力山大 客户不愿接听客服电话&#xff0c;产品问题难以解决 下班时间休息日&#xff0c;产品问题找谁问&#xff1f; 这些关于客户服务的老问题们困扰着许多产品方多年 要解决以上问题&#xff0c;更好地满足顾客需求。 搭建帮…

基于javaweb(springboot+mybatis)生活美食分享平台管理系统设计和实现以及文档报告

基于javaweb(springbootmybatis)生活美食分享平台管理系统设计和实现以及文档报告 博主介绍&#xff1a;5年java开发经验&#xff0c;专注Java开发、定制、远程、文档编写指导等,csdn特邀作者、专注于Java技术领域 作者主页 超级帅帅吴 Java毕设项目精品实战案例《500套》 欢迎…

详解函数指针(●‘◡‘●)☞

本文紧接于http://t.csdn.cn/78wbF 这篇一.函数指针数组\ ( >O< ) /1.书写形式&#xff1a;由函数指针内部*变量名>*变量名[n]&#xff1b; 2.使用&#xff1a;函数指针数组的用途&#xff1a;转移表 例如&#xff1a;模拟计算器&#xff1a;#include<stdio.h> …

用了这么久 IDEA,你还没用过 Live Templates 吗?

大家好&#xff0c;我是风筝&#xff0c;公众号「古时的风筝」&#xff0c;专注于 Java技术 及周边生态。 Live Templates 是什么&#xff0c;听上去感觉挺玄乎的。有的同学用过之后觉得简直太好用了&#xff0c;不能说大大提高了开发效率吧&#xff0c;至少也是小小的提高一下…

Qt创建项目:手把手创建第一个Qt项目

上一节介绍了QtCreator编辑器的页面长什么样子&#xff0c;以及都有哪些功能区&#xff0c;每个功能区都是用来做什么的。这一节我就手把手带大家创建一个Qt项目。 创建项目 点击新建按钮 创建项目有两个入口&#xff0c;一个是在欢迎页面的projects中点击New(新建)按钮&…

未来,勒索软件会呈现何种发展态势?

尽管过去一年里&#xff0c;全世界大约花费了1500亿美元在网络安全领域上&#xff0c;却无法真正阻止黑客攻击。在过去一年里&#xff0c;针对医院、学校、政府的勒索软件越来越多&#xff1b;加密货币领域也有无休止的黑客盗窃事件&#xff1b;还有针对微软、英伟达、Rockstar…

CORS跨域通信

在上一集的坐牢文章中&#xff0c;我们介绍了非官方的很多中方案&#xff0c;其中不乏一些江湖秘术。今天的这个&#xff0c;绝对的正统&#xff0c;纯正的官方打造。我们赶紧来看看。 1.什么是CORS&#xff1f; CORS 是一个 W3C 标准&#xff0c;全称是“跨域资源共享”&…

Prometheus的使用

Prometheus 是一个开放性的监控解决方案&#xff0c;用户可以非常方便的安装和使用 Prometheus 并且能够非常方便的对其进行扩展。 在Prometheus的架构设计中&#xff0c;Prometheus Server 并不直接服务监控特定的目标&#xff0c;其主要任务负责数据的收集&#xff0c;存储并…

ArcGIS Engine基础(31)之使用仿射变换对矢量数据进行空间校正

在生产数据过程中&#xff0c;因每个工程项目都可能有自己的施工坐标系&#xff0c;难免会产生数据提供方与数据使用方采用的坐标系不一致&#xff0c;造成数据在不同坐标系下存在一定偏移、旋转、缩放等&#xff0c;为了让数据能够在新坐标系准确定位&#xff0c;需要进行空间…

kitti数据集理解及可视化

kitti数据集简介 kitti数据集是比较早出来的3D检测方面的数据集&#xff0c;相对来说数据结构简单&#xff0c;适合做单目检测的工作&#xff0c;目前也是业界和学术界常用的公开数据集。 自己最近也在做单目3D检测的工作&#xff0c;所以也分享一些理解&#xff0c;希望能给到…

微服务自动化管理【Docker跨主机集群之Flannel】

环境说明 CentOS7 etcd-v3.4.3-linux-amd64.tar.gz flannel-v0.11.0-linux-amd64.tar.gz 官方文档&#xff1a;https://github.com/coreos/flannel 下载地址&#xff1a;https://github.com/coreos/flannel/releases/download/v0.11.0/flannel-v0.11.0-linux-amd64.tar.gz 1.…

verilog学习笔记- 9)流水灯实验

目录 简介&#xff1a; 实验任务&#xff1a; 硬件设计&#xff1a; 程序设计&#xff1a; 下载验证&#xff1a; 简介&#xff1a; LED&#xff0c;又名发光二极管。LED 灯工作电流很小&#xff08;有的仅零点几毫安即可发光&#xff09;&#xff0c;抗冲击和抗震性能好&…

工业互联网安全漏洞分析

声明 本文是学习github5.com 网站的报告而整理的学习笔记,分享出来希望更多人受益,如果存在侵权请及时联系我们 研究背景 在政策与技术的双轮驱动下&#xff0c;工业控制系统正在越来越多地与企业内网和互联网相连接&#xff0c;并与新型服务模式相结合&#xff0c;逐步形成…

day01--Python学习笔记之安装及测试

目录 官网下载 1、Python自带简单的开发环境 2、Python的交互式命令行程序 3、官方技术文档 4、安装模块文档 1、Python自带简单的开发环境 在当前中编辑代码 在新文件中编辑代码 2、P…

springboot 使用tomcat详解

1.使用内嵌tomcat启动 创建tomcat对象设置端口设置Context设置servlet 和 路径 2.spring中单独注册servlet和地址的映射关系 Beanpublic ServletRegistrationBean getServletRegistrationBean() {ServletRegistrationBean bean new ServletRegistrationBean(apiServlet);bean…

免费GPU攻略

白嫖kaggle kaggle每周会送38个小时&#xff0c;16GB显存。 验证手机号 根据 kaggle问答可知&#xff0c;需要手机号短信验证&#xff0c;账号才能用GPU。国内的手机号都是可用的。操作如下&#xff1a; 点击右上角头像&#xff0c;点击Account。 到了个人详情页后&#xf…

基于JAVA springboot+mybatis智慧生活分享平台设计和实现

基于JAVA springbootmybatis智慧生活分享平台设计和实现 博主介绍&#xff1a;5年java开发经验&#xff0c;专注Java开发、定制、远程、文档编写指导等,csdn特邀作者、专注于Java技术领域 作者主页 超级帅帅吴 Java毕设项目精品实战案例《500套》 欢迎点赞 收藏 ⭐留言 文末获取…

【MySQL进阶】执行一条 sql 语句,期间会发生什么

【MySQL进阶】执行一条 sql 语句&#xff0c;期间会发生什么? 文章目录【MySQL进阶】执行一条 sql 语句&#xff0c;期间会发生什么?MySQL 执行流程是怎样的&#xff1f;第一步&#xff1a;连接器第二步&#xff1a;查询缓存第三步&#xff1a;解析 SQL解析器第四步&#xff…