【有作图代码】Highway Network与ResNet:skip connection如何解决深层网络欠拟合问题
关键词:
#Highway Network
#ResNet
#skip connection
#深层网络
#欠拟合问题
具体实例与推演
假设我们有一个深层神经网络,其层数为L,每一层的输入和输出分别为 x l x_l xl和 y l y_l yl,传统的神经网络每一层的输出可以表示为:
y l = f ( W l x l + b l ) y_l = f(W_lx_l + b_l) yl=f(Wlxl+bl)
其中, f f f是激活函数, W l W_l Wl和 b l b_l bl分别是第l层的权重和偏置。
第一节:Highway Network与ResNet的类比与核心概念
Highway Network和ResNet就像是给深层神经网络修建了一条“高速公路”,让信息可以直接从浅层传递到深层,避免了信息在传递过程中的“堵塞”和“丢失”,从而解决了深层网络的欠拟合问题。
这就像是在城市中修建了高架桥,让车辆可以直接通过,避免了地面交通的拥堵。
第二节:Highway Network与ResNet的核心概念与应用
2.1 核心概念
核心概念 | 定义 | 比喻或解释 |
---|---|---|
Highway Network | 一种通过引入skip connection来允许信息直接跨层传递的神经网络结构。 | 像是给神经网络修建了一条“高速公路”,让信息可以直接从浅层传递到深层。 |
ResNet | 一种具有残差连接的深层神经网络,通过skip connection解决深层网络退化问题。 | 同样是修建了“高速公路”,但其在结构上更加简洁,应用更加广泛。 |
skip connection | 一种允许信息直接跨层传递的连接方式,可以避免信息在传递过程中的损失。 | 就像是城市中的高架桥,让信息可以直接通过,避免了“交通拥堵”。 |
2.2 优势与劣势
方面 | 描述 |
---|---|
优势 | 能够解决深层网络的欠拟合问题,提高网络的训练效果和泛化能力。 |
劣势 | 可能会增加网络的复杂度和计算量,需要合理设计网络结构。 |
2.3 与深层网络训练的类比
Highway Network和ResNet在深层网络训练中扮演着“疏通者”的角色,它们通过修建“高速公路”,让信息可以更加顺畅地传递,从而避免了网络的“堵塞”和“退化”,提高了网络的训练效果和泛化能力。
第三节:公式探索与推演运算
3.1 Highway Network的基本形式
Highway Network的每一层可以表示为:
y l = H ( x l , W l H ) ⋅ T ( x l , W l T ) + x l ⋅ ( 1 − T ( x l , W l T ) ) y_l = H(x_l, W_lH) \cdot T(x_l, W_lT) + x_l \cdot (1 - T(x_l, W_lT)) yl=H(xl,WlH)⋅T(xl,WlT)+xl⋅(1−T(xl,WlT))
其中, H H H是非线性变换, T T T是变换门(transform gate), 1 − T 1-T 1−T是携带门(carry gate)。
3.2 ResNet的基本形式
ResNet的每一层可以表示为:
y l = f ( x l + F ( x l , W l ) ) y_l = f(x_l + F(x_l, W_l)) yl=f(xl+F(xl,Wl))
其中, F F F是残差函数,通常是一个两层或三层的卷积神经网络。
3.3 具体实例与推演
以Highway Network为例,假设我们有一个简单的两层网络,其输入为 x x x,输出为 y y y,第一层和第二层的权重分别为 W 1 W_1 W1和 W 2 W_2 W2,偏置分别为 b 1 b_1 b1和 b 2 b_2 b2,激活函数为ReLU。
- 第一层输出:
y 1 = H ( x , W 1 H ) ⋅ T ( x , W 1 T ) + x ⋅ ( 1 − T ( x , W 1 T ) ) y_1 = H(x, W_1H) \cdot T(x, W_1T) + x \cdot (1 - T(x, W_1T)) y1=H(x,W1H)⋅T(x,W1T)+x⋅(1−T(x,W1T))
其中, H ( x , W 1 H ) = ReLU ( W 1 H x + b 1 H ) H(x, W_1H) = \text{ReLU}(W_1Hx + b_1H) H(x,W1H)=ReLU(W1Hx+b1H), T ( x , W 1 T ) = σ ( W 1 T x + b 1 T ) T(x, W_1T) = \sigma(W_1Tx + b_1T) T(x,W1T)=σ(W1Tx+b1T), σ \sigma σ是sigmoid函数。
- 第二层输出:
y = H ( y 1 , W 2 H ) ⋅ T ( y 1 , W 2 T ) + y 1 ⋅ ( 1 − T ( y 1 , W 2 T ) ) y = H(y_1, W_2H) \cdot T(y_1, W_2T) + y_1 \cdot (1 - T(y_1, W_2T)) y=H(y1,W2H)⋅T(y1,W2T)+y1⋅(1−T(y1,W2T))
其中, H ( y 1 , W 2 H ) = ReLU ( W 2 H y 1 + b 2 H ) H(y_1, W_2H) = \text{ReLU}(W_2Hy_1 + b_2H) H(y1,W2H)=ReLU(W2Hy1+b2H), T ( y 1 , W 2 T ) = σ ( W 2 T y 1 + b 2 T ) T(y_1, W_2T) = \sigma(W_2Ty_1 + b_2T) T(y1,W2T)=σ(W2Ty1+b2T)。
通过引入skip connection,Highway Network允许信息直接从输入层传递到输出层,避免了信息在传递过程中的损失。
第四节:相似公式比对
公式/网络结构 | 共同点 | 不同点 |
---|---|---|
Highway Network | 都引入了skip connection来允许信息跨层传递。 | Highway Network使用了变换门和携带门来控制信息的传递。 |
ResNet | 都解决了深层网络的退化问题。 | ResNet通过残差连接实现信息的跨层传递,结构更加简洁。 |
LSTM(时间维度展开) | 都在某种程度上实现了信息的“跨层”传递(时间维度上的展开)。 | LSTM是在时间维度上展开,用于处理序列数据,与空间维度上的skip connection不同。 |
第五节:核心代码与可视化
由于Highway Network和ResNet的实现涉及复杂的神经网络结构和训练过程,这里我们提供一个简化的Python代码示例,用于演示skip connection的基本概念。请注意,这只是一个示意性的代码,并不直接对应于具体的Highway Network或ResNet实现。
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
# 定义激活函数ReLU和sigmoid
def relu(x):
return np.maximum(0, x)
def sigmoid(x):
return 1 / (1 + np.exp(-x))
# 模拟一个简单的两层Highway Network
def highway_network(x, W1H, b1H, W1T, b1T, W2H, b2H, W2T, b2T):
# 第一层
H1 = relu(np.dot(x, W1H) + b1H)
T1 = sigmoid(np.dot(x, W1T) + b1T)
y1 = H1 * T1 + x * (1 - T1)
# 第二层
H2 = relu(np.dot(y1, W2H) + b2H)
T2 = sigmoid(np.dot(y1, W2T) + b2T)
y = H2 * T2 + y1 * (1 - T2)
return y
# 初始化权重和偏置(随机初始化)
np.random.seed(0)
W1H, b1H = np.random.randn(3, 3), np.random.randn(3)
W1T, b1T = np.random.randn(3, 1), np.random.randn(1)
W2H, b2H = np.random.randn(3, 3), np.random.randn(3)
W2T, b2T = np.random.randn(3, 1), np.random.randn(1)
# 输入数据
x = np.array([1, 2, 3])
# 通过Highway Network传递
y = highway_network(x, W1H, b1H, W1T, b1T, W2H, b2H, W2T, b2T)
# 可视化结果
sns.set_theme(style="whitegrid")
fig, ax = plt.subplots()
# 输入数据可视化
ax.bar(['x1', 'x2', 'x3'], x, label='Input x', alpha=0.6)
# 输出数据可视化
ax.bar(['y1', 'y2', 'y3'], y, label='Output y', alpha=0.6, bottom=np.array([4, 5, 6])) # bottom用于错位显示
# 添加图例和标签
ax.set_title('Highway Network with Skip Connection')
ax.set_xlabel('Neurons')
ax.set_ylabel('Values')
ax.legend()
# 添加注释
for i, (xi, yi) in enumerate(zip(x, y)):
ax.annotate(f'x{i+1}={xi:.1f}\ny{i+1}={yi:.1f}', xy=(i, yi+4), xytext=(i, yi+5), # 错位显示注释
arrowprops=dict(facecolor='black', shrink=0.05))
plt.show()