文章目录
- 引言
- 正文
- 门控残差网络介绍
- 门控残差网络具体实现代码
- 使用pytorch实现
- 总结
引言
- 阅读了pixelSNAIL,很简短,就用了几页,介绍了网络结构,介绍了试验效果就没有了,具体论文学习链接
- 这段时间看他的代码,还是挺痛苦的,因为我对于深度学习的框架尚且不是很熟练 ,而且这个作者很厉害,很多东西都是自己实现的,所以看起来十分费力,本来想逐行分析,结果发现逐行分析不现实,所以这里按照模块进行分析。
- 今天就专门来学习一下他门门控控残差模块如何实现。
正文
门控残差网络介绍
-
介绍
- 通过门来控制每一个残差模块,门通常是由sigmoid函数组成
- 作用:有效建模复杂函数,有助于缓解梯度消失和爆炸的问题
-
基本步骤
- 卷积操作:对输入矩阵执行卷积操作
- 非线性激活:应用非线性激活函数,激活卷积操作的输出
- 第二次卷积操作:对上一个层的输出进行二次卷积
- 门控操作:将二次卷积的输出分为a和b两个部分,并且通过sigmoid函数进行门控
a
,
b
=
S
p
l
i
t
(
c
2
)
G
a
t
e
:
g
=
a
×
s
i
g
m
o
i
d
(
b
)
a,b = Split(c_2) \\ Gate:g = a \times sigmoid(b)
a,b=Split(c2)Gate:g=a×sigmoid(b)
- 这里一般是沿着最后一个通道,将原来的矩阵拆解成a和b,然后在相乘,确保每一个矩阵有一个门控参数
- 将门控输出 g g g和原始输入 x x x相加
-
具体流程图如下
- x: 输入
- c1: 第一次卷积操作(Conv1)
- a1: 非线性激活函数(例如 ReLU)
- c2: 第二次卷积操作(Conv2),输出通道数是输入通道数的两倍
- split: 将c2 分为两部分 a 和 b
- a, b: 由 c2 分割得到的两部分
- sigmoid: 对b 应用 sigmoid 函数
- gated: 执行门控操作 a×sigmoid(b)
- y: 输出,由原始输入 x 和门控输出相加得到
- 这里参考一下论文中的图片,可以看到和基本的门控神经网络是近似的,只不过增加了一些辅助输入还有条件矩阵
门控残差网络具体实现代码
-
具体和上面描述的差不多,这里增加了两个额外的参数,分别是辅助输入a和条件矩阵b
-
注意,这里的二维卷积就是加上了简单的权重归一化的普通二维卷积。
-
辅助输入a
- 用途:提供额外的信息,帮助网络更好地执行任务,比如说在多模态场景或者多任务学习中,会通过a提供主输入x相关联的信息
- 操作:如果提供了a,那么在第一次卷积之后,会经过全连接层与c1相加
-
条件矩阵h
- 用途:主要用于条件生成任务,因为条件生成任务的网络行为会受到某些条件和上下文影响。比如,在文本生成图像中,h会是一个文本描述的嵌入
- 操作:如果提供了 h,那么 h 会被投影到一个与 c2 具有相同维度的空间中,并与 c2 相加。这是通过一个全连接层实现的,该层的权重是 hw。
def gated_resnet(x, a=None, h=None, nonlinearity=concat_elu, conv=conv2d, init=False, counters={}, ema=None, dropout_p=0., **kwargs):
xs = int_shape(x)
num_filters = xs[-1]
# 执行第一次卷积
c1 = conv(nonlinearity(x), num_filters)
# 查看是否有辅助输入a
if a is not None: # add short-cut connection if auxiliary input 'a' is given
c1 += nin(nonlinearity(a), num_filters)
# 执行非线性单元
c1 = nonlinearity(c1)
if dropout_p > 0:
c1 = tf.nn.dropout(c1, keep_prob=1. - dropout_p)
# 执行第二次卷积
c2 = conv(c1, num_filters * 2, init_scale=0.1)
# add projection of h vector if included: conditional generation
# 如果有辅助输入h,那么就将h投影到c2的维度上
if h is not None:
with tf.variable_scope(get_name('conditional_weights', counters)):
hw = get_var_maybe_avg('hw', ema, shape=[int_shape(h)[-1], 2 * num_filters], dtype=tf.float32,
initializer=tf.random_normal_initializer(0, 0.05), trainable=True)
if init:
hw = hw.initialized_value()
c2 += tf.reshape(tf.matmul(h, hw), [xs[0], 1, 1, 2 * num_filters])
# Is this 3,2 or 2,3 ?
a, b = tf.split(c2, 2, 3)
c3 = a * tf.nn.sigmoid(b)
return x + c3
使用pytorch实现
- tensorflow的模型定义过程和pytorch的定义过程就是不一样,tensorflow中的conv2d只需要给出输出的channel,直接输入需要卷积的部分即可。但是使用pytorch,需要进行给定输入的 channel,然后在给出输出的filter_size,很麻烦。
- 除此之外,在定义模型的层的过程中,我们不能在forward中定义层,只能在init函数中定义层。
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm
class GatedResNet(nn.Module):
def __init__(self, num_filters, nonlinearity=F.elu, dropout_p=0.0):
super(GatedResNet, self).__init__()
self.num_filters = num_filters
self.nonlinearity = nonlinearity
self.dropout_p = dropout_p
# 第一卷积层
self.conv1 = nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1)
# self.conv1 = weight_norm(self.conv1)
# 第二卷积层,输出通道是 2 * num_filters,用于门控机制
self.conv2 = nn.Conv2d(num_filters, 2 * num_filters, kernel_size=3, padding=1)
# self.conv2 = weight_norm(self.conv2)
# 条件权重用于 h,初始化在前向传播过程中
self.hw = None
def forward(self, x, a=None, h=None):
c1 = self.conv1(self.nonlinearity(x))
# 检查是否有辅助输入 'a'
if a is not None:
c1 += a # 或使用 NIN 使维度兼容
c1 = self.nonlinearity(c1)
if self.dropout_p > 0:
c1 = F.dropout(c1, p=self.dropout_p, training=self.training)
c2 = self.conv2(c1)
print('the shape of c2',c2.shape)
# 如果有辅助输入 h,则加入 h 的投影
if h is not None:
if self.hw is None:
self.hw = nn.Parameter(torch.randn(h.size(1), self.num_filters) * 0.05)
print(self.hw.shape)
c2 += (h @ self.hw).view(h.size(0), 1, 1, self.num_filters)
# 将通道分为两组:'a' 和 'b'
a, b = c2.chunk(2, dim=1)
c3 = a * torch.sigmoid(b)
return x + c3
# 测试
x = torch.randn(16, 32, 32, 32) # [批次大小,通道数,高度,宽度]
a = torch.randn(16, 32, 32, 32) # 和 x 维度相同的辅助输入
h = torch.randn(16, 64) # 可选的条件变量
model = GatedResNet(32)
out = model(x, a , h)
总结
- 遇到了很多问题,是因为经验不够,而且很多东西都不了解,然后改的很痛苦,而且现在完全还没有跑起来,完整的组件都没有搭建完成,这里还需要继续努力。
- 关于门控残差网络这里,这里学到了很多,知道了具体的运作流程,也知道他是专门针对序列数据,防止出现梯度爆炸的。以后可以多用用看。