pytorch深度学习实战lesson14

news2024/10/2 12:24:13

第十四课 丢弃法(Dropout)

目录

理论部分

实践部分

从零开始实现:

简洁实现:


理论部分

这节课很重要,因为沐神说这个丢弃法比上节课的权重衰退效果更好!

为什么期望没变?

如上图所示,使用dropout后,是将隐藏层的的某几个神经元变成零,然后没有变成0的神经元会相应增大以保证总的期望不变。当然保留和置零的神经元不是一定的。这个是训练才会使用,测试的时候不用。

实践部分

从零开始实现:

代码:

#我们实现 dropout_layer 函数,该函数以dropout的概率丢弃张量输入X中的元素
import torch
from torch import nn
from d2l import torch as d2l
import matplotlib.pyplot as plt
def dropout_layer(X, dropout):#dropout相当于丢弃法公式中的p也就是概率。
    assert 0 <= dropout <= 1
    if dropout == 1:#当丢弃的概率p为1时,也就是100%丢弃,那么就返回全0的阵。
        return torch.zeros_like(X)
    if dropout == 0:#当丢弃的概率p为0时,也就是0%丢弃,那么就保持X。
        return X
    #生成0-1之间的均匀分布,如果其中的值大于dropout,mask就置一,否则置零。
    # 这里的mask就相当于掩膜矩阵,有置零和置一的作用。
    mask = (torch.randn(X.shape) > dropout).float()
    return mask * X / (1.0 - dropout)#做乘法比“X[mask]=0”这样的选元素要快,省资源。
#测试dropout_layer函数
X = torch.arange(16, dtype=torch.float32).reshape((2, 8))
print(X)
print(dropout_layer(X, 0.))#都不变
print(dropout_layer(X, 0.5))#百分之五十概率变成0,每次都是随机变的,这样才有效果
print(dropout_layer(X, 1.))#全变成0
#定义具有两个隐藏层的多层感知机,每个隐藏层包含256个单元
num_inputs, num_outputs, num_hiddens1, num_hiddens2 = 784, 10, 256, 256
dropout1, dropout2 = 0.2, 0.5
class Net(nn.Module):
    def __init__(self, num_inputs, num_outputs, num_hiddens1, num_hiddens2,
                 is_training=True):
        super(Net, self).__init__()
        self.num_inputs = num_inputs
        self.training = is_training
        self.lin1 = nn.Linear(num_inputs, num_hiddens1)#输入层
        self.lin2 = nn.Linear(num_hiddens1, num_hiddens2)#第一个隐藏层
        self.lin3 = nn.Linear(num_hiddens2, num_outputs)#第二个隐藏层
        self.relu = nn.ReLU()#输出层,以relu激活函数输出
    def forward(self, X):
        H1 = self.relu(self.lin1(X.reshape((-1, self.num_inputs))))#第一个隐藏层
        if self.training == True:           #如果是在训练的话
            H1 = dropout_layer(H1, dropout1)#就dropout
        H2 = self.relu(self.lin2(H1))       #第二个隐藏层
        if self.training == True:           #如果在训练的话
            H2 = dropout_layer(H2, dropout2)#就dropout
        out = self.lin3(H2)                 #输出层不dropout
        return out
net = Net(num_inputs, num_outputs, num_hiddens1, num_hiddens2)
#训练和测试
num_epochs, lr, batch_size = 10, 0.5, 256
loss = nn.CrossEntropyLoss()
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
trainer = torch.optim.SGD(net.parameters(), lr=lr)
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)
plt.show()

简洁实现:

代码:

#简洁实现
import torch
from torch import nn
from d2l import torch as d2l
import matplotlib.pyplot as plt
num_inputs, num_outputs, num_hiddens1, num_hiddens2 = 784, 10, 256, 256
dropout1, dropout2= 0.2, 0.5
num_epochs, lr, batch_size = 10, 0.5, 256
net = nn.Sequential(nn.Flatten(), nn.Linear(784, 256), nn.ReLU(),
                    nn.Dropout(dropout1), nn.Linear(256, 256), nn.ReLU(),
                    #nn.Dropout(dropout2), nn.Linear(256, 256), nn.ReLU(),
                    nn.Dropout(dropout2), nn.Linear(256, 10))
loss = nn.CrossEntropyLoss()
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)
net.apply(init_weights)
#对模型进行训练和测试
trainer = torch.optim.SGD(net.parameters(), lr=lr)
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)
plt.show()


拓展:

在简洁实现的基础上,调整参数看有什么变化。

1、把两个dropout的概率都设为0(代码第7行)

和简洁实现比,没什么大区别。

2、把两个dropout的概率分别设为0.7和0.9(代码第7行)

Traceback (most recent call last):

  File "D:\Python\pythonProject\11.14.2丢弃法简洁实现.py", line 20, in <module>

    d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

  File "D:\anaconda9.20\lib\site-packages\d2l\torch.py", line 343, in train_ch3

    assert train_loss < 0.5, train_loss

AssertionError: 2.3032064454396566

3、把两个dropout的概率都设为1(代码第7行)

相当于把隐藏层神经元都给换成0了。

报错了

Traceback (most recent call last):

  File "D:\Python\pythonProject\11.14.2丢弃法简洁实现.py", line 20, in <module>

    d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

  File "D:\anaconda9.20\lib\site-packages\d2l\torch.py", line 343, in train_ch3

    assert train_loss < 0.5, train_loss

AssertionError: 2.3032064454396566

4、把两个dropout的概率分别设为1和0(代码第7行)

Traceback (most recent call last):

  File "D:\Python\pythonProject\11.14.2丢弃法简洁实现.py", line 20, in <module>

    d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

  File "D:\anaconda9.20\lib\site-packages\d2l\torch.py", line 343, in train_ch3

    assert train_loss < 0.5, train_loss

AssertionError: 2.3032064454396566

5、把两个dropout的概率分别设为0和1(代码第7行)

Traceback (most recent call last):

  File "D:\Python\pythonProject\11.14.2丢弃法简洁实现.py", line 20, in <module>

    d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

  File "D:\anaconda9.20\lib\site-packages\d2l\torch.py", line 343, in train_ch3

    assert train_loss < 0.5, train_loss

AssertionError: 2.3032064454396566

上面这三种情况都是dropout1和dropout2里面有一个设的比较大的时候才会报错,报错的原因是说torch.py里面有一句assert语句出问题了。这块我还没想到怎么解决。我也不是很明白torch.py里面的三组assert语句有什么用。。如果有大佬明白咋回事的话可以指点一下。谢谢。

6、加个隐藏层。

感觉效果不是很好啊。

后来把隐藏层的输入输出维度改了一下就好点了:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/7495.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

java main方法控制日志级别

背景&#xff1a; 今天想用main方法去调用http请求&#xff0c;结果已经没什么问题了&#xff0c;但是打印了一大堆Http业务内部的日志信息&#xff0c;特别挡路&#xff0c;导致想看到的业务输出看不到&#xff0c;所以经过多方求证&#xff0c;进行了日志等级处理。 默认情…

【Pytorch with fastai】第 5 章 :图像分类

&#x1f50e;大家好&#xff0c;我是Sonhhxg_柒&#xff0c;希望你看完之后&#xff0c;能对你有所帮助&#xff0c;不足请指正&#xff01;共同学习交流&#x1f50e; &#x1f4dd;个人主页&#xff0d;Sonhhxg_柒的博客_CSDN博客 &#x1f4c3; &#x1f381;欢迎各位→点赞…

电商项目缓存问题的解决方案(初步)

内容分类 容量规化 架构设计 数据库设计 缓存设计 框架选型 数据迁移方案 性能压测 监控报警 领域模型 回滚方案 高并发 分库分表 优化策略 负载均衡 软件负载 nginx&#xff1a;它自身的高可用是用lvs去保证。 下单需要登录 > 需要Session > 分布式Ses…

好书赠送丨海伦·尼森鲍姆著:《场景中的隐私——技术、政治和社会生活中的和谐》,王苑等译

开放隐私计算 收录于合集#书籍分享1个 开放隐私计算 开放隐私计算OpenMPC是国内第一个且影响力最大的隐私计算开放社区。社区秉承开放共享的精神&#xff0c;专注于隐私计算行业的研究与布道。社区致力于隐私计算技术的传播&#xff0c;愿成为中国 “隐私计算最后一公里的服…

[附源码]java毕业设计基于javaweb电影购票系统

项目运行 环境配置&#xff1a; Jdk1.8 Tomcat7.0 Mysql HBuilderX&#xff08;Webstorm也行&#xff09; Eclispe&#xff08;IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持&#xff09;。 项目技术&#xff1a; SSM mybatis Maven Vue 等等组成&#xff0c;B/S模式 M…

公众号运营建议与反思分享,建议收藏

正所谓有总结才会有成长&#xff0c;公众号运营也是如此。 公众号运营不是一朝一夕的事情&#xff0c;经过岁月的洗礼和千锤百炼&#xff0c;也总归是有了自己的一套经验和技巧。 对于公众号运营有什么建议&#xff1f;值得大家反思什么&#xff1f;今天伯乐网络传媒就来给大…

Boost升压电路调试

背景&#xff1a; 项目用到了一款升压电路&#xff0c;将12V升压到32V&#xff0c;电流要求有12A&#xff0c;最大18A。 设计的方案是使用Boost Controller 外置MOS来实现。 选定的Controller芯片为Maxim的MAX25203。 问题&#xff1a; 回板后进行调试&#xff0c;在不使能…

活动预告|“构建新安全格局”专家研讨会即将开幕

应急管理承担着防范化解重大风险、及时应对处置各类突发事件的重要职责&#xff0c;担负保护人民群众生命财产安全和维护社会稳定的重要使命。过去一年是我国应急管理体系和能力建设经受严峻考验的一年&#xff0c;也是实现大发展的一年。 11月17日&#xff0c;由中央党校科研部…

Python简单实现人脸识别检测, 对某平台美女主播照片进行评分排名

前言 嗨喽~大家好呀&#xff0c;这里是魔王呐 ❤ ~! 开发环境: Python 3.8 Pycharm 2021.2 模块使用: 第三方模块 requests >>> pip install requests tqdm >>> pip install tqdm 简单实现进度条效果 自带模块 os base64 采集代码 导入模块 # 数…

vue封装的echarts组件被同一个页面多次引用无法正常显示问题(已解决)

问题&#xff1a;第二张图显示空白&#xff0c;折线图并没有展示出来 当我们在封装了echarts组件之后&#xff0c;需要在同一个页面中引入多次时&#xff0c;会出现数据覆盖等一系列问题 当时我是修改了id也无济于事&#xff0c;达不到我需要的效果 解决方案 将我们封装的组件…

HTML5简明教程系列之HTML5 表格与表单(二)

HTML的第二弹也来了&#xff0c;最近高产似母猪&#xff0c;状态也不错&#xff0c;代码来源为实验课。本期主要内容为&#xff1a;HTML表格与DIV应用、HTML表单。上期基础部分的传送门&#xff1a; HTML5简明教程系列之HTML5基础&#xff08;一&#xff09;_Thomas_Lbw的博客-…

【进程复制】

目录地址偏移量fork函数fork练习地址偏移量 PCB结构体&#xff1a; struct task_struct { PID ststus ; … } 页面的内存大小是固定的&#xff0c;不足一页会给一页&#xff0c;大于一页会给一个整页数 比如一页大小为4K&#xff0c;地址除4K商是页号&#xff0c;余数是在该页…

Vue(六)——使用脚手架(3)

目录 webStorage localStorage sessionStorage todolist案例中使用 组件自定义事件 绑定 解绑 总结 全局事件总线 消息发布与订阅 nextTick 过渡与动画 webStorage 这不是vue团队开发的&#xff0c;不需要写在xx.vue当中&#xff0c;只需写在xx.html当中即可。 什…

Linux下C++开发笔记--g++命令

目录 1--前言 2--开发环境搭建 3--g重要编译参数 4--实例 1--前言 最近学习在linux环境下进行C开发的基础知识&#xff0c;参考的教程是基于VSCode和CMake实现C/C开发 | Linux篇&#xff0c;非常适合小白入门学习。 2--开发环境搭建 ①安装gcc、g和gdb&#xff1a; sud…

深度学习入门(三十七)计算性能——硬件(TBC)

深度学习入门&#xff08;三十七&#xff09;计算性能——硬件&#xff08;CPU、GPU&#xff09;前言计算性能——硬件&#xff08;CPU、GPU&#xff09;课件电脑提升CPU利用率①提升CPU利用率②CPU VS GPU提升GPU利用率CPU/GPU带宽更多的CPU和GPUCPU/GPU高性能计算编程总结教材…

SpringBoot整合dubbo(一)

第一次整合&#xff0c;使用无注册中心方式 一、首先&#xff0c;项目分为三个模块&#xff0c;如下图&#xff0c;dubbo-interface&#xff08;要发布的接口&#xff09;、dubbo-provider&#xff08;接口的具体实现&#xff0c;服务提供者&#xff09;、dubbo-consumer&#…

【LeetCode-中等】63. 不同路径 II(详解)

题目 一个机器人位于一个 m x n 网格的左上角 &#xff08;起始点在下图中标记为 “Start” &#xff09;。 机器人每次只能向下或者向右移动一步。机器人试图达到网格的右下角&#xff08;在下图中标记为 “Finish”&#xff09;。 现在考虑网格中有障碍物。那么从左上角到…

VScode

VScode 下载 VScode&#xff1a;https://code.visualstudio.com/安装 汉化 Chinese (Simplified) 设置 背景色 Atom One Light Theme Color Theme 护眼色 "workbench.colorCustomizations": { // 设置背景颜色// "foreground": "#75a478",&…

List详解

一、List&#xff08;列表&#xff09; 基本的数据类型&#xff0c;列表 在redis中&#xff0c;通过相应操作可以让list变成栈、队列、阻塞队列&#xff01; 在redis中所有的list命令都是以 l 开头的 添加值 将一个值或多个值&#xff0c;插入到列表尾部&#xff08;右&…

深度学习之语义分割算法(入门学习)

>>>深度学习Tricks&#xff0c;第一时间送达<<< 目录 &#x1f4a1; 写在前面 一、前言 二、深度学习的图像分割分类 1.语义分割 2.实例分割 3.全景分割 三、语义分割的基本原理 四、语义分割的常用运算及评价指标 关于算法改进及论文投稿可关注并留…