droppath

news2024/12/5 3:33:51

DropPath 是一种用于正则化深度学习模型的技术,它在训练过程中随机丢弃路径(或者说随机让某些部分的输出变为零),从而增强模型的鲁棒性和泛化能力。

代码解释:

import torch
import torch.nn as nn

# 定义 DropPath 类
class DropPath(nn.Module):
    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

def drop_path(x, drop_prob: float = 0., training: bool = False):
#drop_path(输入,将drop_prob初始化为0., 判断是否为训练模式)
    if drop_prob == 0. or not training:
        return x
#如果drop_prob等于0或者不是训练模式直接将输入输出
    keep_prob = 1 - drop_prob
#保留的概率
    shape = (x.shape[0],) + (1,) * (x.ndim - 1) 
# 形状:(batch_size, 1, 1, ...)
# x.shape[0]获取xshape的第一维也就是batch_size
# (1,) * (x.ndim - 1) 将shape用1填充和x的形状一样
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
# torch.rand(shape, dtype=x.dtype, device=x.device)生成随机数(生成均值为0,标准差为1的正态# 分布随机数)形状和shape一致的也就是和x一致,数据类型,设备都和x一致
# 将随机数和keep_prob相加得到随机数(范围[keep_prob,1+keep_prob])
    random_tensor.floor_() 
# 二值化,生成 0 或 1 的 mask
# 也就是将随机数向下取整
    output = x.div(keep_prob) * random_tensor
#x.div(keep_prob)将输入张量x的所有值除以keep_prob,目的是 放大保留下来的部分

#* random_tensor根据0 或 1 的 mask决定哪些路径会被保留(1)或丢弃(0)
    return output

为什么要放大保留下来的部分:

  • 丢弃路径会导致部分值被置为零,模型整体输出的总期望值会下降。
  • 为了补偿这种下降,需要对保留下来的部分放大,使得丢弃路径后的总期望值和丢弃前一致。

因为只是补偿所以并不一定等与原期望

数学解释:

假设输入张量是 x=\begin{bmatrix} x_{1,}&x_{2,} & ... &, x_{n} \end{bmatrix},其中每个元素 xi表示特征。

期望:E=\frac{1}{n}\sum_{1}^{n}x_{i}

丢弃之后:E=\frac{1}{n}\sum_{1}^{n}{keepprob}\cdot x_{i}

放大之后:E=\frac{1}{n}\sum_{1}^{n}\frac{​{keepprob}\cdot x_{i}}{keepprob}=\frac{1}{n}\sum_{1}^{n}x_{i}

实例:

import torch
import torch.nn as nn

# 定义 DropPath 类
class DropPath(nn.Module):
    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

def drop_path(x, drop_prob: float = 0., training: bool = False):
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # 形状:(batch_size, 1, 1, ...)
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # 二值化,生成 0 或 1 的 mask
    print(f'mask: {random_tensor}')
    output = x.div(keep_prob) * random_tensor
    return output

# 定义简单模型
class SimpleModel(nn.Module):
    def __init__(self, drop_prob):
        super().__init__()
        self.linear = nn.Linear(4, 4)  # 简单的线性层
        self.drop_path = DropPath(drop_prob)  # 使用 DropPath
        self.activation = nn.ReLU()  # ReLU 激活

    def forward(self, x):
        print("输入数据:")
        print(x)

        x = self.linear(x)  # 线性层
        print("线性层输出:")
        print(x)

        x = self.activation(x)  # ReLU 激活
        print("激活后输出:")
        print(x)

        x = self.drop_path(x)  # DropPath
        print("DropPath 后输出:")
        print(x)

        return x

# 创建模型
model = SimpleModel(drop_prob=0.5)
model.train()  # 设置为训练模式以启用 DropPath

# 输入数据
input_data = torch.tensor([[1.0, 2.0, 3.0, 4.0],
                           [5.0, 6.0, 7.0, 8.0]], dtype=torch.float32)

# 运行模型
output = model(input_data)

输出: 简单理解就是根据mask的1,0值对每个样本进行保留或置零

输入数据:
tensor([[1., 2., 3., 4.],
        [5., 6., 7., 8.]])
线性层输出:
tensor([[ 1.2836, -1.4602,  2.2660, -1.7250],
        [ 1.3035, -4.1391,  4.5453, -2.5738]], grad_fn=<AddmmBackward0>)
激活后输出:
tensor([[1.2836, 0.0000, 2.2660, 0.0000],
        [1.3035, 0.0000, 4.5453, 0.0000]], grad_fn=<ReluBackward0>)
mask: tensor([[1.],
        [0.]])
DropPath 后输出:
tensor([[2.5672, 0.0000, 4.5321, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000]], grad_fn=<MulBackward0>)

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2252751.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

机器学习算法(六)---逻辑回归

常见的十大机器学习算法&#xff1a; 机器学习算法&#xff08;一&#xff09;—决策树 机器学习算法&#xff08;二&#xff09;—支持向量机SVM 机器学习算法&#xff08;三&#xff09;—K近邻 机器学习算法&#xff08;四&#xff09;—集成算法 机器学习算法&#xff08;五…

Ubuntu24.04初始化教程(包含基础优化、ros2)

将会不断更新。但是所有都是基础且必要的操作。 为重装系统之后的环境配置提供便捷信息来源。记录一些错误的解决方案。 目录 构建系统建立系统备份**Timeshift: 系统快照和备份工具****安装 Timeshift****使用 Timeshift 创建快照****还原快照****自动创建快照** 最基本配置换…

【Maven】Nexus私服

6. Maven的私服 6.1 什么是私服 Maven 私服是一种特殊的远程仓库&#xff0c;它是架设在局域网内的仓库服务&#xff0c;用来代理位于外部的远程仓库&#xff08;中央仓库、其他远程公共仓库&#xff09;。一些无法从外部仓库下载到的构件&#xff0c;如项目组其他人员开发的…

Gradle vs. Maven: 到底哪个更适合java 项目?

ApiHug ApiHug - API Design & Develop New Paradigm.ApiHug - API Design & Develop New Paradigm.https://apihug.com/ 首先 ApiHug 整个工具链是基于 gradle 构建,包括项目模版&#xff0c; 插件&#xff1b; 说到 Java 项目管理&#xff0c;有两个巨头脱颖而出&a…

Dubbo的集群容错策略有哪些?它们的工作原理是什么?

大家好&#xff0c;我是锋哥。今天分享关于【Dubbo的集群容错策略有哪些&#xff1f;它们的工作原理是什么&#xff1f;】面试题。希望对大家有帮助&#xff1b; Dubbo的集群容错策略有哪些&#xff1f;它们的工作原理是什么&#xff1f; 1000道 互联网大厂Java工程师 精选面试…

分治的思想(力扣965、力扣144、牛客KY11)

引言 分治思想是将问题分解为更小子问题&#xff0c;分别解决后再合并结果。二叉树中常用此思想&#xff0c;因其结构递归&#xff0c;易分解为左右子树问题&#xff0c;递归解决后合并结果。 这篇文章会讲解用分治的思想去解决二叉树的一些题目&#xff0c;顺便会强调在做二…

中国电信张宝玉:城市数据基础设施建设运营探索与实践

11月28日&#xff0c;2024新型智慧城市发展创新大会在山东青岛召开&#xff0c;中国电信数字政府研究院院长张宝玉在大会发表主旨演讲《城市数据基础设施运营探索与实践》。报告内容包括城市数据基础设施的概述、各地典型做法及发展趋势建议三个方面展开。 篇幅限制&#xff0…

【论文阅读】Federated learning backdoor attack detection with persistence diagram

目的&#xff1a;检测联邦学习环境下&#xff0c;上传上来的模型是不是恶意的。 1、将一个模型转换为|L|个PD,&#xff08;其中|L|为层数&#xff09; 如何将每一层转换成一个PD&#xff1f; 为了评估第&#x1d457;层的激活值&#xff0c;我们需要&#x1d450;个输入来获…

深度学习案例:ResNet50模型+SE-Net

本文为为&#x1f517;365天深度学习训练营内部文章 原作者&#xff1a;K同学啊 一 回顾ResNet模型 ResNet&#xff0c;即残差网络&#xff0c;是由微软研究院的Kaiming He及其合作者于2015年提出的一种深度卷积神经网络架构。该网络架构的核心创新在于引入了“残差连接”&…

js高级-ajax封装和跨域

ajax简介及相关知识 原生ajax AJAX 简介 AJAX 全称为 Asynchronous JavaScript And XML&#xff0c;就是异步的 JS 和 XML。 通过 AJAX 可以在浏览器中向服务器发送异步请求&#xff0c;最大的优势&#xff1a;无刷新获取数据。 按需请求&#xff0c;可以提高网站的性能 AJ…

【AI】Sklearn

长期更新&#xff0c;建议关注、收藏、点赞。 友情链接&#xff1a; AI中的数学_线代微积分概率论最优化 Python numpy_pandas_matplotlib_spicy 建议路线&#xff1a;机器学习->深度学习->强化学习 目录 预处理模型选择分类实例&#xff1a; 二分类比赛 网格搜索实例&…

如何让控件始终处于父容器的居中位置(父容器可任意改变大小)

前言&#xff1a; 大家好&#xff0c;我是上位机马工&#xff0c;硕士毕业4年年入40万&#xff0c;目前在一家自动化公司担任软件经理&#xff0c;从事C#上位机软件开发8年以上&#xff01;我们在C#开发winform程序的时候&#xff0c;有时候需要将一个控件居中显示&#xff0c…

Python 调用 Umi-OCR API 批量识别图片/PDF文档数据

目录 一、需求分析 二、方案设计&#xff08;概要/详细&#xff09; 三、技术选型 四、OCR 测试 Demo 五、批量文件识别完整代码实现 六、总结 一、需求分析 市场部同事进行采购或给客户报价时&#xff0c;往往基于过往采购合同数据&#xff0c;给出现在采购或报价的金额…

【QT】背景,安装和介绍

TOC 目录 背景 GUI技术 QT的安装 使用流程 QT程序介绍 main.cpp​编辑 Wiget.h Widget.cpp form file .pro文件 临时文件 C作为一门比较古老的语言&#xff0c;在人们的认知里始终是以底层&#xff0c;复杂和高性能著称&#xff0c;所以在很多高性能需求的场景之下…

Linux内核编译流程(Ubuntu24.04+Linux Kernel 6.8.12)

万恶的拯救者&#xff0c;使用Ubuntu没有声音&#xff0c;必须要自己修改一下Linux内核中的相关驱动逻辑才可以&#xff0c;所以被迫学习怎么修改内核&编译内核&#xff0c;记录如下 准备工作 下载Linux源码&#xff1a;在Linux发布页下载并使用gpg签名验证 即&#xff1a…

【阅读笔记】Android广播的处理流程

关于Android的解析&#xff0c;有很多优质内容&#xff0c;看了后记录一下阅读笔记&#xff0c;也是一种有意义的事情&#xff0c; 今天就看看“那个写代码的”这位大佬关于广播的梳理&#xff0c; https://blog.csdn.net/a572423926/category_11509429.html https://blog.c…

linux下Qt程序部署教程

文章目录 [toc]1、概述2、静态编译安装Qt1.1 安装依赖1.2 静态编译1.3 报错1.4 添加环境变量1.5 下载安装QtCreator 3、配置linuxdeployqt环境1.1 在线安装依赖1.2 使用linuxdeployqt提供的程序1.3 编译安装linuxdeployqt 4、使用linuxdeployqt打包依赖1.1 linuxdeployqt使用选…

【PHP】部署和发布PHP网站到IIS服务器

欢迎来到《小5讲堂》 这是《PHP》系列文章&#xff0c;每篇文章将以博主理解的角度展开讲解。 温馨提示&#xff1a;博主能力有限&#xff0c;理解水平有限&#xff0c;若有不对之处望指正&#xff01; 目录 前言安装PHP稳定版本线程安全版解压使用 PHP配置配置文件扩展文件路径…

视觉经典神经网络学习01_CNN(1)

一、概述 卷积神经网络&#xff08;Convolutional Neural Network&#xff0c;CNN&#xff09;是一种专门用于处理具有网格状结构数据的深度学习模型。最初&#xff0c;CNN主要应用于计算机视觉任务&#xff0c;但它的成功启发了在其他领域应用&#xff0c;如自然语言处理等。…

【golang】单元测试,以及出现undefined时的解决方案

单元测试 要对某一方法进行测试时&#xff0c;例如如下这一简单减法函数&#xff0c;选中函数名后右键->转到->测试 1&#xff09;Empty test file 就是一个空文件&#xff0c;我们可以自己写测试的逻辑 但是直接点绿色箭头运行会出问题&#xff1a; 找不到包。我们要在…