PixelSNAIL论文代码学习(2)——门控残差网络的实现

news2025/1/16 2:51:59

文章目录

    • 引言
    • 正文
      • 门控残差网络介绍
      • 门控残差网络具体实现代码
      • 使用pytorch实现
    • 总结

引言

  • 阅读了pixelSNAIL,很简短,就用了几页,介绍了网络结构,介绍了试验效果就没有了,具体论文学习链接
  • 这段时间看他的代码,还是挺痛苦的,因为我对于深度学习的框架尚且不是很熟练 ,而且这个作者很厉害,很多东西都是自己实现的,所以看起来十分费力,本来想逐行分析,结果发现逐行分析不现实,所以这里按照模块进行分析。
  • 今天就专门来学习一下他门门控控残差模块如何实现。

正文

门控残差网络介绍

  • 介绍

    • 通过门来控制每一个残差模块,门通常是由sigmoid函数组成
    • 作用:有效建模复杂函数,有助于缓解梯度消失和爆炸的问题
  • 基本步骤

    • 卷积操作:对输入矩阵执行卷积操作
    • 非线性激活:应用非线性激活函数,激活卷积操作的输出
    • 第二次卷积操作:对上一个层的输出进行二次卷积
    • 门控操作:将二次卷积的输出分为a和b两个部分,并且通过sigmoid函数进行门控 a , b = S p l i t ( c 2 ) G a t e : g = a × s i g m o i d ( b ) a,b = Split(c_2) \\ Gate:g = a \times sigmoid(b) a,b=Split(c2)Gate:g=a×sigmoid(b)
      • 这里一般是沿着最后一个通道,将原来的矩阵拆解成a和b,然后在相乘,确保每一个矩阵有一个门控参数
    • 将门控输出 g g g和原始输入 x x x相加
  • 具体流程图如下

    • x: 输入
    • c1: 第一次卷积操作(Conv1)
    • a1: 非线性激活函数(例如 ReLU)
    • c2: 第二次卷积操作(Conv2),输出通道数是输入通道数的两倍
    • split: 将c2 分为两部分 a 和 b
    • a, b: 由 c2 分割得到的两部分
    • sigmoid: 对b 应用 sigmoid 函数
    • gated: 执行门控操作 a×sigmoid(b)
    • y: 输出,由原始输入 x 和门控输出相加得到

在这里插入图片描述

  • 这里参考一下论文中的图片,可以看到和基本的门控神经网络是近似的,只不过增加了一些辅助输入还有条件矩阵

在这里插入图片描述

门控残差网络具体实现代码

  • 具体和上面描述的差不多,这里增加了两个额外的参数,分别是辅助输入a和条件矩阵b

  • 注意,这里的二维卷积就是加上了简单的权重归一化的普通二维卷积。

  • 辅助输入a

    • 用途:提供额外的信息,帮助网络更好地执行任务,比如说在多模态场景或者多任务学习中,会通过a提供主输入x相关联的信息
    • 操作:如果提供了a,那么在第一次卷积之后,会经过全连接层与c1相加
  • 条件矩阵h

    • 用途:主要用于条件生成任务,因为条件生成任务的网络行为会受到某些条件和上下文影响。比如,在文本生成图像中,h会是一个文本描述的嵌入
    • 操作:如果提供了 h,那么 h 会被投影到一个与 c2 具有相同维度的空间中,并与 c2 相加。这是通过一个全连接层实现的,该层的权重是 hw。
def gated_resnet(x, a=None, h=None, nonlinearity=concat_elu, conv=conv2d, init=False, counters={}, ema=None, dropout_p=0., **kwargs):
    xs = int_shape(x)
    num_filters = xs[-1]

    # 执行第一次卷积
    c1 = conv(nonlinearity(x), num_filters)

    # 查看是否有辅助输入a
    if a is not None:  # add short-cut connection if auxiliary input 'a' is given
        c1 += nin(nonlinearity(a), num_filters)

    # 执行非线性单元
    c1 = nonlinearity(c1)
    if dropout_p > 0:
        c1 = tf.nn.dropout(c1, keep_prob=1. - dropout_p)

    # 执行第二次卷积
    c2 = conv(c1, num_filters * 2, init_scale=0.1)

    # add projection of h vector if included: conditional generation
    # 如果有辅助输入h,那么就将h投影到c2的维度上
    if h is not None:
        with tf.variable_scope(get_name('conditional_weights', counters)):
            hw = get_var_maybe_avg('hw', ema, shape=[int_shape(h)[-1], 2 * num_filters], dtype=tf.float32,
                                   initializer=tf.random_normal_initializer(0, 0.05), trainable=True)
        if init:
            hw = hw.initialized_value()
        c2 += tf.reshape(tf.matmul(h, hw), [xs[0], 1, 1, 2 * num_filters])

    # Is this 3,2 or 2,3 ?
    a, b = tf.split(c2, 2, 3)
    c3 = a * tf.nn.sigmoid(b)
    return x + c3

使用pytorch实现

  • tensorflow的模型定义过程和pytorch的定义过程就是不一样,tensorflow中的conv2d只需要给出输出的channel,直接输入需要卷积的部分即可。但是使用pytorch,需要进行给定输入的 channel,然后在给出输出的filter_size,很麻烦。
  • 除此之外,在定义模型的层的过程中,我们不能在forward中定义层,只能在init函数中定义层。
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm

class GatedResNet(nn.Module):
    def __init__(self, num_filters, nonlinearity=F.elu, dropout_p=0.0):
        super(GatedResNet, self).__init__()
        self.num_filters = num_filters
        self.nonlinearity = nonlinearity
        self.dropout_p = dropout_p

        # 第一卷积层
        self.conv1 = nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1)
#         self.conv1 = weight_norm(self.conv1)

        # 第二卷积层,输出通道是 2 * num_filters,用于门控机制
        self.conv2 = nn.Conv2d(num_filters, 2 * num_filters, kernel_size=3, padding=1)
#         self.conv2 = weight_norm(self.conv2)

        
        # 条件权重用于 h,初始化在前向传播过程中
        self.hw = None

    def forward(self, x, a=None, h=None):
        c1 = self.conv1(self.nonlinearity(x))

        # 检查是否有辅助输入 'a'
        if a is not None:
            c1 += a  # 或使用 NIN 使维度兼容

        c1 = self.nonlinearity(c1)
        if self.dropout_p > 0:
            c1 = F.dropout(c1, p=self.dropout_p, training=self.training)

        c2 = self.conv2(c1)
        print('the shape of c2',c2.shape)

        # 如果有辅助输入 h,则加入 h 的投影
        if h is not None:
            if self.hw is None:
                self.hw = nn.Parameter(torch.randn(h.size(1),  self.num_filters) * 0.05)
            print(self.hw.shape)
            c2 +=  (h @ self.hw).view(h.size(0), 1, 1, self.num_filters)
            

        # 将通道分为两组:'a' 和 'b'
        a, b = c2.chunk(2, dim=1)
        c3 = a * torch.sigmoid(b)

        return x + c3

# 测试
x = torch.randn(16, 32, 32, 32)  # [批次大小,通道数,高度,宽度]
a = torch.randn(16, 32, 32, 32)  # 和 x 维度相同的辅助输入
h = torch.randn(16, 64)  # 可选的条件变量
model = GatedResNet(32)
out = model(x, a , h)

在这里插入图片描述

总结

  • 遇到了很多问题,是因为经验不够,而且很多东西都不了解,然后改的很痛苦,而且现在完全还没有跑起来,完整的组件都没有搭建完成,这里还需要继续努力。
  • 关于门控残差网络这里,这里学到了很多,知道了具体的运作流程,也知道他是专门针对序列数据,防止出现梯度爆炸的。以后可以多用用看。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/968148.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【数学建模竞赛】Matlab逻辑规则,结构基础及函数

逻辑基础 逻辑变量 在Matlab中,逻辑变量是一种特殊类型的变量,用于表示逻辑值。逻辑变量只有两个可能的值:true(真)和false(假)。在Matlab中,我们可以使用0和1来表示逻辑变量的值。…

数据结构(Java实现)-字符串常量池与通配符

字符串常量池 在Java程序中,类似于:1, 2, 3,3.14,“hello”等字面类型的常量经常频繁使用,为了使程序的运行速度更快、更节省内存,Java为8种基本数据类型和String类都提供了常量池。…

Excel_VBA程序文件的加密及解密说明

VBA应用技巧及疑难解答 Excel_VBA程序文件的加密及解密 在您看到这个文档的时候,请和我一起念:“唵嘛呢叭咪吽”“唵嘛呢叭咪吽”“唵嘛呢叭咪吽”,为自己所得而感恩,为付出者赞叹功德。 本不想分享之一技术,但众多学…

【Java核心知识】JUC包相关知识

文章目录 JUC包主要内容Java内置锁为什么会有线程安全问题Synchronize锁Java对象结构Synchronize锁优化线程间通信Synchronize与wait原理 CAS和JUC原子类CAS原理JUC原子类ABA问题 可见性和有序性为什么会有可见性参考链接 显式锁Lock接口常用方法显式锁分类显式锁实现原理参考链…

数据结构(Java实现)-排序

排序的概念 排序:所谓排序,就是使一串记录,按照其中的某个或某些关键字的大小,递增或递减的排列起来的操稳定性:假定在待排序的记录序列中,存在多个具有相同的关键字的记录,若经过排序&#xff…

chatgpt谈论日本排放污水事件

W...Y的主页 😊 代码仓库分享 💕 近日,世界发生了让人义愤填膺的时间——日本排放核污水。这件事情是那么的突然且不计后果,海洋是我们全人类共同的财产,而日本却想用自己一己私欲将全人类的安全置之度外&#xff0c…

攻防世界-Caesar

原题 解题思路 没出现什么特殊字符,可能是个移位密码。凯撒密码加密解密。偏移12位就行。

Spring Cloud--从零开始搭建微服务基础环境【三】

😀前言 本篇博文是关于Spring Cloud–从零开始搭建微服务基础环境【三】,希望你能够喜欢 🏠个人主页:晨犀主页 🧑个人简介:大家好,我是晨犀,希望我的文章可以帮助到大家,…

全网都在用的nnUNet V2版本改进了啥,怎么安装?(一)

nnUNet,这个医学领域的分割巨无霸!在论文和比赛中随处可见他的身影。大家对于nnUNet v1版本的教程都赞不绝口,因为它简单易懂、详细全面,让很多朋友都轻松掌握了使用方法。 最近,我也抽出时间仔细研究了nnUNet v2,并全…

vue声明周期

1.在created中发送数据 async created(){ const resawait axios.get("url) this.listres.data.data } 2.在mounted中获取焦点 mounted(){ document.querySelector(#inp).focus()

分类预测 | MATLAB实现GRNN广义回归神经网络多特征分类预测

分类预测 | MATLAB实现GRNN广义回归神经网络多特征分类预测 目录 分类预测 | MATLAB实现GRNN广义回归神经网络多特征分类预测分类效果基本介绍模型描述预测过程程序设计参考资料分类效果 基本介绍 MATLAB实现GRNN广义回归神经网络多特

企业架构LNMP学习笔记9

nginx配置文件定义php-fpm服务&#xff1a; 编写测试文件&#xff1a; vim /usr/local/nginx/html/index.php 内容&#xff1a; <?phpphpinfo(); 在nginx的配置文件中配置&#xff1a; 修改配置文件&#xff0c;告知nginx如果收到.php结尾的请求&#xff0c;交由给php-…

「MySQL-02」数据库的操纵、备份、还原和编码规则

目录 一、库操作 1. 创建数据库 2. 查看所有数据库 3. 删除数据库 4. 修改数据库 5. 进入一个数据库 二、查看和设置数据库的编码规则 1. MySQL的两个编码规则&#xff1a;字符集和校验规则 2. 查看MySQL当前使用的字符集以及校验规则 3. 查看MySQL支持的所有字符集 4. 查看MyS…

nnUNet v2数据准备及格式转换 (二)

如果你曾经使用过nnUNet V1&#xff0c;那你一定明白数据集的命名是有严格要求的&#xff0c;必须按照特定的格式来进行命名才能正常使用。 这一节的学习需要有数据&#xff0c;如果你有自己的数据&#xff0c;可以拿自己的数据来实验&#xff0c;如果没有&#xff0c;可以用十…

JVM类的加载过程

加载过程 JVM的类的加载过程分为五个阶段&#xff1a;加载、验证、准备、解析、初始化。 加载   加载阶段就是将编译好的的class文件通过字节流的方式从硬盘或者通过网络加载到JVM虚拟机当中来。&#xff08;我们平时在Idea中书写的代码就是放在磁盘中的&#xff0c;也可以通…

Kubernetes可视化管理工具Kuboard部署使用及k8s常用命令梳理记录

温故知新 &#x1f4da;第一章 前言&#x1f4d7;背景&#x1f4d7;目的&#x1f4d7;总体方向 &#x1f4da;第二章 安装 Kubernetes 多集群管理工具 - Kuboard v3&#x1f4d7;部署方式&#x1f4d7;通过Kuboard v3 - Kubernetes安装&#xff08;在master节点执行)&#x1f4…

大学生攻略:正确的购买和使用你的电脑

笔者是计算机专业在读大学生&#xff0c;从小学开始接触电脑&#xff0c;进行过各种操作(更换硬件维修&#xff0c;换系统&#xff0c;系统命令行&#xff0c;管理员权限&#xff0c;无视风险继续安装&#xff0c;没有这条 )&#xff0c;相对大学生有一定参考价值。 购买 1.买…

【Java并发】聊聊AQS原理机制

什么是AQS AbstractQueuedSynchronizer是一个抽象队列同步器&#xff0c;主要是实现并发工具类的基石。 是用来构建锁或者其它同步器组件的重量级基础框架及整个JUC体系的基石&#xff0c; 通过内置的FIFO队列来完成资源获取线程的排队工作&#xff0c;并通过一个int类变量表示…

仿京东 项目笔记1

目录 项目代码1. 项目配置2. 前端Vue核心3. 组件的显示与隐藏用v-if和v-show4. 路由传参4.1 路由跳转有几种方式&#xff1f;4.2 路由传参&#xff0c;参数有几种写法&#xff1f;4.3 路由传参相关面试题4.3.1 路由传递参数&#xff08;对象写法&#xff09;path是否可以结合pa…

MyBatis-Plus —— 初窥门径

前言 在前面的文章中荔枝梳理了MyBatis及相关的操作&#xff0c;作为MyBatis的增强工具&#xff0c;MyBatis-Plus无需再在xml中写sql语句&#xff0c;在这篇文章中荔枝将梳理MyBatis-Plus的基础知识并基于SpringBoot梳理MyBatis-Plus给出的两个接口&#xff1a;BaseMapper和ISe…