PyTorch梯度直通反传

news2025/1/16 1:56:54

有时我们想在层的输出端放置一个阈值函数。这可能出于多种原因。其中之一是我们想将激活总结为二进制值。这种激活的二值化在自编码器中很有用。

然而,阈值化在反向传播过程中会带来问题:阈值函数的导数为零。这种梯度的缺乏导致我们的网络无法学习任何东西。为了解决这个问题,我们可以使用直通估计器 (STE:Straight Through Estimator)。

NSDT工具推荐: Three.js AI纹理开发包 - YOLO合成数据生成器 - GLTF/GLB在线编辑 - 3D模型格式在线转换 - 可编程3D场景编辑器 - REVIT导出3D模型插件 - 3D模型语义搜索引擎 - Three.js虚拟轴心开发包 - 3D模型在线减面 - STL模型在线切割 

1、什么是直通估计器?

假设我们想使用以下函数将层的激活二值化:

此函数将为每个大于 0 的值返回 1,否则将返回 0。

如前所述,此函数的问题在于其梯度为零。为了解决这个问题,我们将在反向传递中使用直通估计器。

直通估计器顾名思义就是它估计函数的梯度。具体来说,它忽略阈值函数的导数,并将传入的梯度传递,就好像该函数是恒等函数一样。下图有助于更好地解释它:

你可以看到在反向传递中如何绕过阈值函数。就是这样,这就是直通式估计器的作用。它使阈值函数的梯度看起来像恒等函数的梯度。

2、直通估计器的PyTorch 实现

截至目前,PyTorch 的 API 中尚未包含 STE 的实现。因此,我们必须自己实现它。为此,我们需要创建一个 Function 类和一个 Module 类。Function 类将包含 STE 的前向和后向功能。Module 类是创建和使用 STE Function 对象的地方。我们将在我们的神经网络中使用 STE Module。

以下是 STE Function 类的实现:

class STEFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        return (input > 0).float()

    @staticmethod
    def backward(ctx, grad_output):
        return F.hardtanh(grad_output)

PyTorch 让我们可以定义具有前向和后向功能的自定义自动求导函数。这里我们为直通式估算器定义了一个自动求导函数。在前向传递中,我们希望将输入张量中的所有值从浮点转换为二进制。在后向传递中,我们希望传递传入的梯度而不对其进行修改。这是为了模仿恒等函数。不过,这里我们对传入的梯度执行 F.hardtanh 操作。此操作将梯度限制在 -1 和 1 之间。我们这样做是为了让梯度不会变得太大。

现在,让我们实现 STE 模块类:

class StraightThroughEstimator(nn.Module):
    def __init__(self):
        super(StraightThroughEstimator, self).__init__()

    def forward(self, x):
            x = STEFunction.apply(x)
            return x

你可以看到,我们在 forward 函数中使用了我们定义的 STE 函数类。要使用 autograd 函数,我们必须将输入传递给 apply 方法。现在,我们可以在神经网络中使用此模块。

使用 STE 的常见方法是在自编码器的瓶颈层内。以下是此类自编码器的实现:

class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()
        
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),
            nn.ReLU(),
            
            nn.Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            
            nn.Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            
            nn.Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            
            StraightThroughEstimator(),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=(5, 5), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            
            nn.ConvTranspose2d(256, 128, kernel_size=(5, 5), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            
            nn.ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            
            nn.ConvTranspose2d(64, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),
            nn.Tanh(),
        )
        
    def forward(self, x, encode=False, decode=False):
        if encode:
            x = self.encoder(x)
        elif decode:
            x = self.decoder(x)
        else:
            encoding = self.encoder(x)
            x = self.decoder(encoding)
        return x

这个自编码器是为 MNIST 数据集制作的。它将 28x28 图像压缩为具有 512 个通道的 1x1 图像。然后将其解码回 28x28 图像。

我将 STE 放在编码器的末尾。它将把接收到的张量的所有值转换为二进制。你可能已经注意到我使用了一个非常规的前向函数。我添加了两个新参数 encode 和 decrypt,它们要么是 True,要么是 False。如果 encode 设置为 True,网络将返回编码器的输出。同样,如果 decrypt 设置为 True,网络需要有效的编码并将其解码回图像。

我在 MNIST 数据集上对自动编码器进行了 5 个 epoch 的训练,并带有 MSE 损失。以下是测试集上的重建:

如你所见,重建效果非常好。STE 可用于神经网络,且性能不会有太大损失。

完整代码如下:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.autograd as autograd
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# dataset preparation
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, ), (0.5, ))
])
trainset = datasets.MNIST('dataset/', train=True, download=True, transform=transform)
testset = datasets.MNIST('dataset/', train=False, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True)
# defining networks
class STEFunction(autograd.Function):
    @staticmethod
    def forward(ctx, input):
        return (input > 0).float()
    @staticmethod
    def backward(ctx, grad_output):
        return F.hardtanh(grad_output)
class StraightThroughEstimator(nn.Module):
    def __init__(self):
        super(StraightThroughEstimator, self).__init__()
    def forward(self, x):
        x = STEFunction.apply(x)
        return x
class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()
        
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),
            nn.ReLU(),
            
            nn.Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            
            nn.Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            
            nn.Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            
            StraightThroughEstimator(),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=(5, 5), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            
            nn.ConvTranspose2d(256, 128, kernel_size=(5, 5), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            
            nn.ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            
            nn.ConvTranspose2d(64, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),
            nn.Tanh(),
        )
        
    def forward(self, x, encode=False, decode=False):
        if encode:
            x = self.encoder(x)
        elif decode:
            x = self.decoder(x)
        else:
            encoding = self.encoder(x)
            x = self.decoder(encoding)
        return x
net = Autoencoder().to(device)
optimizer = optim.Adam(net.parameters(), lr=0.001, betas=(0.5, 0.999))
criterion_MSE = nn.MSELoss().to(device)
# train loop
epoch = 5
for e in range(epoch):
    print(f'Starting epoch {e} of {epoch}')
    for X, y in tqdm(trainloader):
        optimizer.zero_grad()
        X = X.to(device)
        reconstruction = net(X)
        loss = criterion_MSE(reconstruction, X)
        loss.backward()
        optimizer.step()
    print(f'Loss: {loss.item()}')
# test loop
i = 1
fig = plt.figure(figsize=(10, 10))
for X, y in testloader:
    X_in = X.to(device)
    recon = net(X_in).detach().cpu().numpy()
    if i >= 10:
      break
    fig.add_subplot(5, 2, i).set_title('Original')
    plt.imshow(X[0].reshape((28, 28)), cmap="gray")
    fig.add_subplot(5, 2, i+1).set_title('Reconstruction')
    plt.imshow(recon[0].reshape((28, 28)), cmap="gray")
    i += 2
fig.tight_layout()
plt.show()

原文链接:梯度反传直通图解 - BimAnt

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1844964.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

CSDN图片居中、左对齐、右对齐、大小设置

图片居中、左对齐、右对齐 ![在这里插入图片描述](https://img-blog.csdnimg.cn/99dc1072e8f1471990b700e1c85d301a.jpeg#pic_center) 大小设置 空格400x150 空格30%x # 长400 宽200 ![在这里插入图片描述](https://img-blog.csdnimg.cn/99dc1072e8f1471990b700e1c85d301a.…

hive on spark 的架构和常见问题 - hive on spark 使用的是 yarn client 模式还是 yarn cluster 模式?

hive on spark 的架构和常见问题 - hive on spark 使用的是 yarn client 模式还是 yarn cluster 模式? 1. 回顾下 spark 的架构图和部署模式 来自官方的经典的 spark 架构图如下: 上述架构图,从进程的角度来讲,有四个角色/组件&…

opencascade AIS_InteractiveContext源码学习3 highlighting management 对象高亮管理

AIS_InteractiveContext 前言 交互上下文(Interactive Context)允许您在一个或多个视图器中管理交互对象的图形行为和选择。类方法使这一操作非常透明。需要记住的是,对于已经被交互上下文识别的交互对象,必须使用上下文方法进行…

TugraphDB:探索图数据库新境界

TugraphDB:释放图数据的全部潜能- 精选真开源,释放新价值。 概览 TugraphDB是支付宝背后的分布式图数据库。该项目是由蚂蚁集团和清华大学共同研发的高性能分布式图数据库,支持事务处理、TB 级大容量、低延迟查找和快速图分析等功能。专为处…

安卓系统安装linux搭建随手服务器termux平替软件介绍

引言 旧手机丢可惜,可以用ZeroTermux(一款代替termux)的超级终端,来模拟Linux(甚至你可以模拟Win,只要性能够用) ps:此软件只是termux的增强版,相当于增加右边菜单&…

第N5周:调用Gensim库训练Word2Vec模型

🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 | 接辅导、项目定制🚀 文章来源:K同学的学习圈子 目录 本周任务: 1.安装Gensim库 2.对原始语料分词 3.停用词 4.训练Woed2Vec模型 …

单阶段目标检测--NMS

目录 一、概念: 二、算法过程 三、代码实现 一、概念: 在目标检测的初始结果中,同一个物体,可能对应有多个边界框 (bounding box,bb),这些边界框通常相互重叠。如何从中选择一个最合适 的(也就…

【2024最新华为OD-C/D卷试题汇总】[支持在线评测] 披萨大作战(100分) - 三语言AC题解(Python/Java/Cpp)

🍭 大家好这里是清隆学长 ,一枚热爱算法的程序员 ✨ 本系列打算持续跟新华为OD-C/D卷的三语言AC题解 💻 ACM银牌🥈| 多次AK大厂笔试 | 编程一对一辅导 👏 感谢大家的订阅➕ 和 喜欢💗 &#x1f…

1 UC

1 UC 1、环境变量2、环境变量表3、错误处理4、库文件4.1 静态库4.2 动态库4.3 动态库的动态加载 5、虚拟地址 1、环境变量 什么是环境变量? 每个进程都有一张自己的环境变量表,表中的每个条目都是形如“键值”形式的环境变量。进程可以通过环境变量访问…

opencascade AIS_InteractiveContext源码学习4 object local transformation management

AIS_InteractiveContext 前言 交互上下文(Interactive Context)允许您在一个或多个视图器中管理交互对象的图形行为和选择。类方法使这一操作非常透明。需要记住的是,对于已经被交互上下文识别的交互对象,必须使用上下文方法进行…

数据结构4---串

一、字符串暴力匹配 要注意的就是i与j的回溯&#xff0c;通过不断移动主串的指针&#xff0c;时间复杂度高 #include <stdio.h> #include <stdlib.h>typedef struct String {char* data;int len; }String;String* initString() {String* s (String*)malloc(sizeo…

分布式理论与设计 四、分布式系统设计策略

在分布式环境下&#xff0c;有几个问题是普遍关心的&#xff1a; 如何检测当前节点还活着&#xff1f;如何保障高可用&#xff1f;容错处理负载均衡 1.心跳检测 在分布式环境中&#xff0c;我们提及过存在非常多的节点&#xff08;Node&#xff09;。那么就有一个非常重要的…

c++ 编译过程杂记等

开篇一张图。 编译器 把我们的代码翻译成机器语言 ​ gcc编译程序的过程 gcc编译程序主要经过四个过程&#xff1a; 四个过程说明&#xff1a; ​ 预处理实际上是将头文件、宏进行展开。 编译阶段&#xff0c;gcc调用不同语言的编译器&#xff0c;例如c语言调用编译器ccl…

OpenTenBase入门

什么是OpenTenBase OpenTenBase 是一个提供写可靠性&#xff0c;多主节点数据同步的关系数据库集群平台。你可以将 OpenTenBase 配置一台或者多台主机上&#xff0c; OpenTenBase 数据存储在多台物理主机上面。数据表的存储有两种方式&#xff0c; 分别是 distributed 或者 re…

Android Studio main,xml 视图代码转换

Android Studio main,xml 视图&&代码转换 其实很简单,但是对我们小白来说还是比较蒙的。 废话不多说,直接上图。 我的Android Studio 是 4.0 版的 我刚打开是这个界面,在我想学习如何用代码来布局,可能大家也会找不见代码的位置。 follow me 是不是感觉很简单呢。…

使用Python和BeautifulSoup轻松抓取表格数据

你是否曾经希望可以轻松地从网页上获取表格数据&#xff0c;而不是手动复制粘贴&#xff1f;好消息来了&#xff0c;使用Python和BeautifulSoup&#xff0c;你可以轻松实现这一目标。今天&#xff0c;我们将探索如何使用这些工具抓取中国气象局网站(http://weather.cma.cn)上的…

使用fastapi和pulumi搭建基于Azure云的IAC Restful API服务 — 对外发布

前言 在IAC&#xff08;即Infrastructure As Code&#xff0c;基础设施即代码&#xff09;领域&#xff0c;Terraform 是一个老牌工具&#xff0c;使用HCL&#xff08;HashiCorp Configuration Language&#xff09;语言来编写配置文件。它支持几乎所有主流的云提供商&#xf…

贝锐蒲公英异地组网方案:实现制药设备远程监控、远程运维

公司业务涉及放射性药品的生产与销售&#xff0c;在全国各地拥有20多个分公司。由于药品的特殊性&#xff0c;在日常生产过程中&#xff0c;需要符合药品监管规范要求&#xff0c;对各个分部的气相、液相设备及打印机等进行监管&#xff0c;了解其运行数据及工作情况。 为满足这…

[极客大挑战 2020]Roamphp4-Rceme

rce,rce,rce!!! 右键源代码里给了提示&#xff0c;有备份文件index.php.swp,大伙都做到这来了&#xff0c;应该不用写了吧。看源码 <?php error_reporting(0); session_start(); if(!isset($_SESSION[code])){$_SESSION[code] substr(md5(mt_rand().sha1(mt_rand)),0,5);…

电脑上使用备忘录怎么查看编辑时间?能显示时间的备忘录

在快节奏的生活中&#xff0c;很多人喜欢使用备忘录来记录日常事项和重要信息。备忘录不仅能帮助我们捕捉灵感&#xff0c;还能确保重要任务不被遗漏。然而&#xff0c;有时候我们需要知道某条记录的编辑时间&#xff0c;以便于回溯和整理信息。如果备忘录不能显示编辑时间&…