【动手学习深度学习--逐行代码解析合集】10Dropout暂退法

news2024/9/28 9:32:36

【动手学习深度学习】逐行代码解析合集

10Dropout暂退法


视频链接:动手学习深度学习–Dropout暂退法
课程主页:https://courses.d2l.ai/zh-v2/
教材:https://zh-v2.d2l.ai/

1、暂退法原理

在这里插入图片描述
在这里插入图片描述

2、从零开始实现暂退法

import torch
from torch import nn
from d2l import torch as d2l

import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

# 该函数以dropout的概率丢弃张量输入X中的元素
def dropout_layer(X, dropout):
    assert 0 <= dropout <= 1
    # 在本情况中,所有元素都被丢弃
    if dropout == 1:
        return torch.zeros_like(X)
    # 在本情况中,所有元素都被保留
    if dropout == 0:
        return X
    # torch.rand(X.shape)生成0-1之间的均匀随机分布,大于dropout的返回1,小于的返回0
    mask = (torch.rand(X.shape) > dropout).float()
    # mask随机生成0或1
    return mask * X / (1.0 - dropout)
# 测试dropout_layer函数,暂退概率分别为0、0.5和1。
X=  torch.arange(16, dtype = torch.float32).reshape((2, 8))
print(X)
print(dropout_layer(X, 0.))
print(dropout_layer(X, 0.5))
print(dropout_layer(X, 1.))

运行结果
在这里插入图片描述

2.1 定义模型参数

# 定义具有两个隐藏层的多层感知机,每个隐藏层包含256个单元。
num_inputs, num_outputs, num_hiddens1, num_hiddens2 = 784, 10, 256, 256

2.2 定义模型

我们可以将暂退法应用于每个隐藏层的输出(在激活函数之后), 并且可以为每一层分别设置暂退概率: 常见的技巧是在靠近输入层的地方设置较低的暂退概率。 下面的模型将第一个和第二个隐藏层的暂退概率分别设置为0.2和0.5, 并且暂退法只在训练期间有效。

# 定义具有两个隐藏层的多层感知机,每个隐藏层包含256个单元。
num_inputs, num_outputs, num_hiddens1, num_hiddens2 = 784, 10, 256, 256
# 模型将第一个和第二个隐藏层的暂退概率分别设置为0.2和0.5
dropout1, dropout2 = 0.2, 0.5

class Net(nn.Module):
    # is_training = True:给程序标注是在训练
    def __init__(self, num_inputs, num_outputs, num_hiddens1, num_hiddens2,
                 is_training = True):
        super(Net, self).__init__()
        self.num_inputs = num_inputs
        self.training = is_training
        self.lin1 = nn.Linear(num_inputs, num_hiddens1)  # 第一个隐藏层
        self.lin2 = nn.Linear(num_hiddens1, num_hiddens2)  # 第二个隐藏层
        self.lin3 = nn.Linear(num_hiddens2, num_outputs)  # 输出层
        self.relu = nn.ReLU()  # 激活函数

    def forward(self, X):
        # 对第一个隐藏层作非线性激活后,再使用dropout
        H1 = self.relu(self.lin1(X.reshape((-1, self.num_inputs))))
        # 只有在训练模型时才使用dropout
        if self.training == True:
            # 在第一个全连接层之后添加一个dropout层
            H1 = dropout_layer(H1, dropout1)
        # 对第二个隐藏层作非线性激活
        H2 = self.relu(self.lin2(H1))
        if self.training == True:
            # 在第二个全连接层之后添加一个dropout层
            H2 = dropout_layer(H2, dropout2)
        # 输出层不作用dropout
        out = self.lin3(H2)
        return out

net = Net(num_inputs, num_outputs, num_hiddens1, num_hiddens2)

2.3 训练和测试

# 训练和测试
num_epochs, lr, batch_size = 10, 0.5, 256
loss = nn.CrossEntropyLoss(reduction='none')
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
trainer = torch.optim.SGD(net.parameters(), lr=lr)
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)
d2l.plt.show()

在这里插入图片描述

若不使用dropout对比结果(此处将dropout1, dropout2 = 0.0, 0.0)
在这里插入图片描述

3、暂退法的简洁实现

# 简洁实现
net = nn.Sequential(nn.Flatten(),
        nn.Linear(784, 256),  # 第一个隐藏层
        nn.ReLU(),  # Dropout放在ReLU前后均可
        # 在第一个全连接层之后添加一个dropout层
        nn.Dropout(dropout1),
        nn.Linear(256, 256),  # 第二个隐藏层
        nn.ReLU(),
        # 在第二个全连接层之后添加一个dropout层
        nn.Dropout(dropout2),
        nn.Linear(256, 10))   # 输出层

# 初始化权重,此处不懂可看05softmax回归的简洁实现
def init_weights(m):
    if type(m) == nn.Linear:
        # m.weight默认为0,以均值为0方差为0.01来随机初始化权重
        nn.init.normal_(m.weight, std=0.01)
# net.apply(init_weights)会递归地将函数init_weights应用到父模块的每个子模块submodule,也包括model这个父模块自身。
net.apply(init_weights);

# 参数更新
trainer = torch.optim.SGD(net.parameters(), lr=lr)
# 训练画图
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)
d2l.plt.show()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/730867.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

微服务网关技术选型:Zuul2、Gateway、OpenResty、Kong

1、简介 当使用单体应用程序架构时&#xff0c;客户端&#xff08;Web 或移动端&#xff09;通过向后端应用程序发起一次 REST 调用来获取数据。负载均衡器将请求路由给 N 个相同的应用程序实例中的一个。然后应用程序会查询各种数据库表&#xff0c;并将响应返回给客户端。微…

missing-semester————1

文章目录 shell概述echoshell如何知道去哪寻找date或echo呢&#xff1f;$PATHlsman流根用户 shell概述 root1test:~$ $表示身份不是root用户 ~表示当前所在位置是"home" root1test:~$ date Sat Jul 8 02:57:44 UTC 2023输入命令&#xff0c;会被shell解析 上述执行…

静态路由配置——Cisco Packet Tracer

这里放一个用Packet Tracer 8.0实现的配置好的静态路由文件&#xff0c;配置如下 下载链接如下&#xff1a; https://wwix.lanzoue.com/ifp5T11ksnla

内嵌tomcat报错

严重: Unable to process Jar entry [module-info.class] from Jar [jar:file:/D:/javaTools/apache-maven-bin/apache-maven-3.6.1/maven-repo/com/fasterxml/jackson/core/jackson-databind/2.10.5/jackson-databind-2.10.5.jar!/] for annotations org.apache.tomcat.util.b…

NI采集卡USB-6361多通道模拟输入采集报错解决方案

文章目录 前言一、现有例程1、前面板2、程序框图 二、采集测试1、单通道采集2、多通道采集①、错误的做法②、正确的做法1&#xff09;前面板2&#xff09;程序框图3&#xff09;运行测试 总结 前言 折腾一块 USB-6361 采集卡很久了&#xff0c;之前都是单通道采集模拟信号&am…

云原生(第六篇)k8s-kubeadmin部署

master&#xff08;2C/4G&#xff0c;cpu核心数要求大于2&#xff09; 192.168.169.10 docker、kubeadm、kubelet、kubectl、flannel node01&#xff08;2C/2G&#xff09; 192.168.169.30 docker、kubeadm、kubelet、kubect…

汇总:FlatLaf-intellij-themes皮肤效果一览

关于主题包&#xff1a; FlatLaf 是一个跨平台的 Java Swing 外观库&#xff0c;提供现代化的平面化用户界面。 导包 <dependency><groupId>com.formdev</groupId><artifactId>flatlaf</artifactId><version>3.1.1</version><sco…

机器学习28:《推荐系统-I》概述

在互联网领域&#xff0c;推荐系统&#xff08;Recommendation Systems&#xff09;的应用非常广泛。在音视频方面&#xff0c;如抖音、快手、哔哩等&#xff1b;在电商平台方面&#xff0c;如京东、淘宝、拼多多等。推荐有助于帮助用户快速发现潜在感兴趣的内容&#xff08;音…

RS485或RS232转ETHERCAT连接安川ethercat总线伺服

最近&#xff0c;生产管理设备中经常会遇到两种协议不相同的情况&#xff0c;这严重阻碍了设备之间的通讯&#xff0c;串口设备的数据不能直接传输给ETHERCAT。这可怎么办呢&#xff1f; 别担心&#xff0c;远创智控YC-ECT-RS485/232来了&#xff01;这是一款自主研发的ETHER…

使用vue ui创建vue项目失败原因

每个人的失败原因都不相同&#xff0c;因为下载NodeJS文件时&#xff0c;默认下载到c盘中&#xff0c;我改变盘符到了D盘&#xff0c;因此要删除c盘中隐藏的文件&#xff0c;注意是c盘中的.npmrc文件。具体位置如下&#xff1a; 点击查看显示隐藏文件才能看到该文件 最后创建项…

磁性材料在使用时需要注意什么

为了不引起人身损伤及磁体性能不良&#xff0c;请遵循以下注意事项&#xff1a; 1、 磁体在使用过程中应确保工作场所干净&#xff0c;否则容易吸附铁屑等磁性小颗粒影响使用。 2、 磁体在充磁时&#xff0c;磁体必须固定&#xff0c;且充磁场必须大于磁体材料矫顽力的2.5倍&…

CEC2023动态多目标优化算法:基于自适应启动策略的混合交叉动态约束多目标优化算法(MC-DCMOEA)求解CEC2023(提供MATLAB代码)

一、动态多目标优化问题 1.1问题定义 1.2 动态支配关系定义 二、 基于自适应启动策略的混合交叉动态多目标优化算法 基于自适应启动策略的混合交叉动态多目标优化算法&#xff08;Mixture Crossover Dynamic Constrained Multi-objective Evolutionary Algorithm Based on Se…

【ElasticSearch】JavaRestClient实现文档查询、排序、分页、高亮

文章目录 1、入门案例2、全文检索3、精确查询4、复合查询-boolean query5、排序和分页6、高亮 1、入门案例 先初始化JavaRestClient对象&#xff1a; SpringBootTest public class HotelSearchTest {private RestHighLevelClient client;Testvoid testInit() {System.out.pri…

uniapp如何给空包进行签名操作

这里给大家分享我在网上总结出来的一些知识&#xff0c;希望对大家有所帮助 首先安装sdk https://www.oracle.com/java/technologies/downloads/ 正常下一步即可~安装完毕后&#xff0c;进入在sdk根目录执行cmd C:\Program Files\Java\jdk-18.0.1.1\bin生成keystore 例&#xf…

数据结构--线索二叉树找前驱后继

数据结构–线索二叉树找前驱后继 中序线索二叉树找中序后继 在中序线索二叉树中找到指定结点*p的 中序后继 \color{red}中序后继 中序后继next ①若p->rtag 1&#xff0c;则next p->rchild ②若p->rtag 0 中序遍历――左根右 左根(左根右) 左根((左根右)根右) next …

PVT、OCV、工艺偏差、CPPRCRPR、ld漏级电流计算

文章目录 PVT&OCV(local variation)Sources of variation1) Etching2) Oxide Thickness propagation delay、ld、drain currentCPPR&CRPRsetup checkHold check 芯片的delay由两部分影响因素构成 cell delay&#xff1a;library set pvt_cornernet delay: rc tech fil…

电风扇自动温控调速器电路设计

这是一个电风扇自动温控调速器&#xff0c;可根据温度变化情况自动调节电风扇的转速&#xff0c;电路加以调整&#xff0c;也可用于其它电气设备的控制。 一、电路工作原理 电路原理如图 37 所示。 图中 IC 是 555 时基电路&#xff0c;与R2、R3 和 C2 等元件构成多谐振荡器…

前端Vue自定义暂无数据组件nodata 用于页面请求无数据时展示

随着技术的发展&#xff0c;开发的复杂度也越来越高&#xff0c;传统开发方式将一个系统做成了整块应用&#xff0c;经常出现的情况就是一个小小的改动或者一个小功能的增加可能会引起整体逻辑的修改&#xff0c;造成牵一发而动全身。通过组件化开发&#xff0c;可以有效实现单…

【Cesium 安装+Cesium 加载b3dm】

Cesium 安装 一、安装的方式大致有三种&#xff1a; 1、引入ceisum源码包使用&#xff1b; 2、安装cesium插件&#xff1b; 3、安装Vue-cesium插件 我这里只尝试了第一种和第二种。 引入ceisum源码包使用 可以使用直接下载官方压缩包来引入也可以npm i cesium包&#xff0c;把…

Socket error Event: 32 Error: 10053.

Socket error Event: 32 Error: 10053. 一、报错 &#xff1a;二、问题&#xff1a;三、原因&#xff1a;四、解决方案&#xff1a; 一、报错 &#xff1a; Socket error Event: 32 Error: 10053. 二、问题&#xff1a; xshell连接虚拟机断连 三、原因&#xff1a; 虚拟机…