【PyTorch】 暂退法(dropout)

news2024/11/19 23:23:12

文章目录

  • 1. 理论介绍
  • 2. 实例解析
    • 2.1. 实例描述
    • 2.2. 代码实现
      • 2.2.1. 主要代码
      • 2.2.2. 完整代码
      • 2.2.3. 输出结果

1. 理论介绍

  • 线性模型泛化的可靠性是有代价的,因为线性模型没有考虑到特征之间的交互作用,由此模型灵活性受限。
  • 泛化性和灵活性之间的基本权衡被描述为偏差-方差权衡
    • 线性模型有很高的偏差,因此它们只能表示一小类函数,但其方差很低,因此它们在不同的随机数据样本上可以得出相似的结果。
    • 神经网络不局限于单独查看每个特征,而是学习特征之间的交互,但即使我们有比特征多得多的样本,深度神经网络也有可能过拟合。
  • 经典泛化理论认为,为了缩小训练和测试性能之间的差距,应该以简单的模型为目标。 简单性以较小维度的形式展现,简单性的另一个角度是平滑性,即函数不应该对其输入的微小变化敏感
  • 暂退法在前向传播过程中,计算每一内部层的同时注入噪声。 因为当训练一个有多层的深层网络时,注入噪声只会在输入-输出映射上增强平滑性。从表面上看是在训练过程中丢弃(drop out)一些神经元。 在整个训练过程的每一次迭代中,标准暂退法包括在计算下一层之前将当前层中的一些节点置零。
  • 神经网络过拟合与每一层都依赖于前一层激活值相关,这种情况称为共适应性,而暂退法会破坏共适应性。
  • 可以以一种无偏向的方式注入噪声,在固定住其他层时,每一层的期望值等于没有噪音时的值。
  • 标准暂退正则化通过按未丢弃的节点的分数进行规范化来消除每一层的偏差,换言之,每个中间活性值 h h h以暂退概率 p p p由随机变量 h ′ h' h替换,即:
    h ′ = { 0  概率为  p h 1 − p  其他情况 \begin{aligned} h' = \begin{cases} 0 & \text{ 概率为 } p \\ \frac{h}{1-p} & \text{ 其他情况} \end{cases} \end{aligned} h={01ph 概率为 p 其他情况
  • 通常,我们在测试时不用暂退法。 给定一个训练好的模型和一个新的样本,我们不会丢弃任何节点,因此不需要标准化。 然而也有一些例外:一些研究人员在测试时使用暂退法, 用于估计神经网络预测的不确定性: 如果通过许多不同的暂退法遮盖后得到的预测结果都是一致的,那么我们可以说网络发挥更稳定。
  • 我们可以将暂退法应用于每个隐藏层的输出(在激活函数之后), 并且可以为每一层分别设置暂退概率,常见的技巧是在靠近输入层的地方设置较低的暂退概率。

2. 实例解析

2.1. 实例描述

使用具有两个隐藏层的多层感知机和暂退法,拟合Fashion-MNIST数据集。

2.2. 代码实现

2.2.1. 主要代码

net = nn.Sequential(
        nn.Flatten(),
        nn.Linear(784, 256),
        nn.ReLU(),
        nn.Dropout(0.2),	# 暂退概率为0.2
        nn.Linear(256, 256),
        nn.ReLU(),
        nn.Dropout(0.5),	# 暂退概率为0.5
        nn.Linear(256, 10)
).to(device)

2.2.2. 完整代码

import os
from tensorboardX import SummaryWriter
from rich.progress import track
from torchvision.transforms import Compose, ToTensor
from torchvision.datasets import FashionMNIST
import torch
from torch.utils.data import DataLoader
from torch import nn, optim

def load_dataset():
    """加载数据集"""
    root = "./dataset"
    transform = Compose([ToTensor()])
    mnist_train = FashionMNIST(root, True, transform, download=True)
    mnist_test = FashionMNIST(root, False, transform, download=True)

    dataloader_train = DataLoader(mnist_train, batch_size, shuffle=True, 
        num_workers=num_workers,
    )
    dataloader_test = DataLoader(mnist_test, batch_size, shuffle=False,
        num_workers=num_workers,
    )
    return dataloader_train, dataloader_test


if __name__ == '__main__':
    # 全局参数设置
    num_epochs = 10
    batch_size = 256
    num_workers = 3
    device = torch.device('cuda:0')
    lr = 0.5

    # 创建记录器
    def log_dir():
        root = "runs"
        if not os.path.exists(root):
            os.mkdir(root)
        order = len(os.listdir(root)) + 1
        return f'{root}/exp{order}'
    writer = SummaryWriter(log_dir=log_dir())

    # 数据集配置
    dataloader_train, dataloader_test = load_dataset()

    # 定义模型
    net = nn.Sequential(
        nn.Flatten(),
        nn.Linear(784, 256),
        nn.ReLU(),
        nn.Dropout(0.2),
        nn.Linear(256, 256),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Linear(256, 10)
    ).to(device)
    def init_weights(m):
        if type(m) == nn.Linear:
            nn.init.normal_(m.weight, std=0.01)
    net.apply(init_weights)
    criterion = nn.CrossEntropyLoss(reduction='none')
    optimizer = optim.SGD(net.parameters(), lr=lr)

    # 训练循环
    for epoch in track(range(num_epochs), description='dropout'):
        for X, y in dataloader_train:
            X, y = X.to(device), y.to(device)
            loss = criterion(net(X), y)
            optimizer.zero_grad()
            loss.mean().backward()
            optimizer.step()

        with torch.no_grad():
            train_loss, train_acc, num_samples = 0.0, 0.0, 0
            for X, y in dataloader_train:
                X, y = X.to(device), y.to(device)
                y_hat = net(X)
                loss = criterion(y_hat, y)
                train_loss += loss.sum()
                train_acc += (y_hat.argmax(dim=1) == y).sum()
                num_samples += y.numel()
            train_loss /= num_samples
            train_acc /= num_samples

            test_acc, num_samples = 0.0, 0
            for X, y in dataloader_test:
                X, y = X.to(device), y.to(device)
                y_hat = net(X)
                test_acc += (y_hat.argmax(dim=1) == y).sum()
                num_samples += y.numel()
            test_acc /= num_samples

            writer.add_scalars('metrics', {
                'train_loss': train_loss,
                'train_acc': train_acc,
                'test_acc': test_acc
            }, epoch)

    writer.close()

2.2.3. 输出结果

dropout

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1294392.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

JavaSE基础50题:20. 创建一个int类型的数组,元素为100,并把每个元素依次设置为1 - 100

概述 数组练习题 创建一个int类型的数组,元素为100,并把每个元素依次设置为1 - 100. 代码 public static void main(String[] args) {int[] array new int[100]; //没有赋值的情况下&#xff0c;里面存的是100个0//获取数组的长度,依次赋值for (int i 0; i < array.len…

P1 Qt的认识及环境配置

目录 前言 01 下载Qt Creator windows下载安装包拷贝到Linux Linux直接下载 02 Linux 安装Qt 前言 &#x1f3ac; 个人主页&#xff1a;ChenPi &#x1f43b;推荐专栏1: 《C_ChenPi的博客-CSDN博客》✨✨✨ &#x1f525; 推荐专栏2: 《Linux C应用编程&#xff08;概念类…

HarmonyOS4.0开发应用(三)【ArkUI组件使用】

ArkUI组件使用 这里会详细演示以下组件使用: ImageTextTextInputButtonSliderColumn&&RowList自定义组件以及相关函数使用 Image 可以是网络图片、可以是本地图片、也可以是像素图 Image("https://ts1.cn.mm.bing.net/th?idOIP-C.cYA-_PINA-ND9OeBaolDTwHaHa&…

查看电脑cuda版本

1.找到NVODIA控制面板 输入NVIDIA搜索即可 出现NVIDIA控制面板 点击系统信息 2.WINR 输入nvidia-smi 检查了一下&#xff0c;电脑没用过GPU&#xff0c;连驱动都没有 所以&#xff0c;装驱动…… 选版本&#xff0c;下载 下载后双击打开安装 重新输入nvidia-smi 显示如下…

【Lidar】Python实现点云CSF布料滤波算法提取地面点

这两天会持续更新一下Python处理点云数据的教程&#xff0c;大家可以点个关注。今天给大家分享一下点云的经典算法&#xff1a;CSF布料模拟算法。 1 CSF算法简介 CSF算法&#xff0c;全称为Cloth Simulation Filtering&#xff0c;是一种基于欧几里得空间中最小生成树思想的聚类…

iPaaS架构深入探讨

在数字化时代全面来临之际&#xff0c;企业正面临着前所未有的挑战与机遇。技术的迅猛发展与数字化转型正在彻底颠覆各行各业的格局&#xff0c;不断推动着企业迈向新的前程。然而&#xff0c;这一数字化时代亦衍生出一系列复杂而深奥的难题&#xff1a;各异系统之间数据孤岛、…

生成式人工智能在采购场景落地 隆道AI产品亮相甲子引力年终盛典

11月30日&#xff0c;以“致追风赶月的你”为主题的“2023甲子引力年终盛典”在北京举办。北京隆道网络科技有限公司总裁吴树贵受邀出席本次会议&#xff0c;并在企业数字化专场会议中做主题发言&#xff0c;探讨了人工智能在采购业务场景中的应用&#xff0c;重点向参会人员介…

老电脑重置后能连上WIFI但是打开360网页老是提示该网址不是私密连接

看了一下可以忽略这次提示&#xff0c;能够上网&#xff0c;但是每次打开新网页都会有“该网址不是私密连接”提示&#xff0c;这个提示非常大&#xff0c;严重影响上网。 强行下载了谷歌浏览器并打开后&#xff0c;提示“您的时钟慢了”&#xff0c;然后看了一下电脑右下角日期…

UDP实现群聊

代码&#xff1a; import java.awt.*; import java.awt.event.*; import javax.swing.*; import java.net.*; import java.io.IOException; import java.lang.String;public class liaotian extends JFrame{private static final int DEFAULT_PORT8899;private JLabel stateLB…

JAVA实现敏感词高亮或打码过滤:sensitive-word

练手项目中实现发表文章时检测文章是否带有敏感词&#xff0c;以及对所有敏感词的一键过滤功能 文章目录 效果预览实现步骤 效果预览 随便复制一篇内容到输入框 机器审核文章存在敏感词&#xff0c;弹消息提示并进入人工审核阶段&#xff08;若机器审核通过&#xff0c;则无需审…

helm或者k8s部署pod时遇到pod一直处于pending状态

问题 我在使用helm部署一个mysql主从复制的数据库时遇到了一些问题&#xff0c;具体如下&#xff1a; 这是我使用的配置文件 # ruoyi-mysql.yaml auth:rootPassword: "123456"database: ry-vueinitdbScriptsConfigMap: ruoyi-init-sqlprimary:persistence:size: 2G…

CleanMyMac X2024最新版本软件实用性测评

信大多数MAC用户都较为了解&#xff0c;Mac虽然有着许多亮点的性能&#xff0c;但是让用户叫苦不迭的还其硬盘空间小的特色&#xff0c;至于很多人因为文件堆积以及软件缓存等&#xff0c;造成系统空间内存不够使用的情况。于是清理工具就成为了大多数MAC用户使用频率较高的实用…

使用Caliper对Fabric地basic链码进行性能测试

如果你需要对fabric网络中地合约进行吞吐量、延迟等性能进行评估&#xff0c;可以使用Caliper来实现&#xff0c;会返回给你一份网页版的直观测试报告。下面是对test-network网络地basic链码地测试过程。 目录 1. 建立caliper-workspace文件夹2. 安装npm等3. calipe安装4. 创建…

[wp]“古剑山”第一届全国大学生网络攻防大赛 Web部分wp

“古剑山”第一届全国大学生网络攻防大赛 群友说是原题杯 哈哈哈哈 我也不懂 我比赛打的少 Web Web | unse 源码&#xff1a; <?phpinclude("./test.php");if(isset($_GET[fun])){if(justafun($_GET[fun])){include($_GET[fun]);}}else{unserialize($_GET[…

Rust的eBFP框架Aya(一) - Linux内核网络基础

前言 在我的Rust入门及实战系列文章中已经说明&#xff0c; Rust是一门内存安全的高性能编程语言&#xff0c;从它的这些优秀特性来看&#xff0c;就是一门专为系统开发而诞生的语言。至于很多使用Rust来进行web开发的行为&#xff0c;不能说它们不好&#xff0c;只能说是杀鸡…

CPU密集型和IO密集型与CPU内核之间的关系

CPU密集型和IO密集型与CPU内核之间的关系 一、CPU密集型 介绍 CPU密集型&#xff0c;也叫计算密集型&#xff0c;是指需要大量CPU计算资源&#xff0c;例如大量的数学运算、图像处理、加密解密等。这种类型的任务主要依赖于CPU的计算能力&#xff0c;会占用大量的CPU时间片&am…

ALNS的MDP模型| 还没整理完12-08

有好几篇论文已经这样做了&#xff0c;先摆出一篇&#xff0c;然后再慢慢更新 第一篇 该篇论文提出了一种称为深增强ALNS&#xff08;DR-ALNS&#xff09;的方法&#xff0c;它利用DRL选择最有效的破坏和修复运营商&#xff0c;配置破坏严重性参数施加在破坏算子上&#xff0c…

备忘录模式 rust和java的实现

文章目录 备忘录模式介绍实现javarustrust仓库 备忘录模式 备忘录&#xff08;Memento&#xff09;模式的定义&#xff1a;在不破坏封装性的前提下&#xff0c;捕获一个对象的内部状态&#xff0c;并在该对象之外保存这个状态&#xff0c;以便以后当需要时能将该对象恢复到原先…

解决Flutter运行报错Could not run build/ios/iphoneos/Runner.app

错误场景 更新了IOS的系统版本为最新的17.0, 运行报以下错误 Launching lib/main.dart on iPhone in debug mode... Automatically signing iOS for device deployment using specified development team in Xcode project: GN3DCAF71C Running Xcode build... Xcode build d…

C++_命名空间(namespace)

目录 1、namespace的重要性 2、 namespace的定义及作用 2.1 作用域限定符 3、命名空间域与全局域的关系 4、命名空间的嵌套 5、展开命名空间的方法 5.1 特定展开 5.1 部分展开 5.2 全部展开 结语&#xff1a; 前言&#xff1a; C作为c语言的“升级版”&#xff0c;其在…