残差网络,解决梯度消失

news2024/11/22 15:55:29

残差网络

1. 引言

在深度学习的快速发展中,模型的深度和复杂性不断增加。然而,随着网络层数的增加,训练过程中的一些问题逐渐显现出来,尤其是梯度消失和梯度爆炸问题。这些问题导致了深层神经网络的性能下降,限制了模型的表达能力。为了解决这一问题,Kaiming He 等人在 2015 年提出了残差网络(ResNet),该架构通过引入残差学习的概念,显著提高了深层神经网络的训练效果。

2. 残差网络的背景

2.1 深度学习的挑战

在传统的深层网络中,随着层数的增加,网络的训练变得更加困难。训练过程中,梯度在反向传播时可能会逐渐消失,导致前面的层无法有效更新。这种现象被称为梯度消失(vanishing gradient),使得深层网络难以学习到有效的特征。

2.2 残差学习的提出

残差网络的核心思想是通过引入跳跃连接(skip connections)来缓解深层网络中的梯度消失问题。具体来说,网络的每一层不仅学习输入 x x x 到输出 F ( x ) F(x) F(x) 的映射,还学习输入与输出之间的残差(即差异):

y = F ( x ) + x y = F(x) + x y=F(x)+x

其中:

  • y y y 是残差块的输出。
  • F ( x ) F(x) F(x) 是通过多个层(如卷积、激活函数等)计算得到的结果。
  • x x x 是输入。

这种结构允许网络在需要时选择不更新某些层的权重,从而实现恒等映射。

2.3 残差网络的公式推导

考虑一个深度网络的输出为 y y y,如果我们希望网络学习到某个目标函数 H ( x ) H(x) H(x),则可以将其表示为:

H ( x ) = F ( x ) + x H(x) = F(x) + x H(x)=F(x)+x

在这种情况下, F ( x ) F(x) F(x) 是需要学习的残差。通过这种方式,网络可以更容易地学习到恒等映射。假设 H ( x ) H(x) H(x) 是一个恒等映射(即 H ( x ) = x H(x) = x H(x)=x),那么我们可以选择 F ( x ) = 0 F(x) = 0 F(x)=0,这样网络可以直接将输入传递到输出。

3. 简单的计算例子

为了更好地理解残差连接的作用,考虑以下简单的计算例子:

假设我们有一个简单的输入矩阵 x x x

x = [ 1 2 3 4 ] x = \begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix} x=[1324]

我们希望网络学习到恒等映射。假设网络的结构如下:

  1. 第一层权重 W 1 = [ 0.5 0.5 0.5 0.5 ] W_1 = \begin{bmatrix} 0.5 & 0.5 \\ 0.5 & 0.5 \end{bmatrix} W1=[0.50.50.50.5],偏置 b 1 = [ 0.5 0.5 ] b_1 = \begin{bmatrix} 0.5 \\ 0.5 \end{bmatrix} b1=[0.50.5]
  2. 第二层权重 W 2 = [ 0.5 0.5 0.5 0.5 ] W_2 = \begin{bmatrix} 0.5 & 0.5 \\ 0.5 & 0.5 \end{bmatrix} W2=[0.50.50.50.5],偏置 b 2 = [ − 0.5 − 0.5 ] b_2 = \begin{bmatrix} -0.5 \\ -0.5 \end{bmatrix} b2=[0.50.5]

计算过程:

  1. 输入 x x x 经过第一层:
    F 1 ( x ) = W 1 ⋅ x + b 1 = [ 0.5 0.5 0.5 0.5 ] ⋅ [ 1 2 ] + [ 0.5 0.5 ] = [ 0.5 ⋅ 1 + 0.5 ⋅ 2 + 0.5 0.5 ⋅ 1 + 0.5 ⋅ 2 + 0.5 ] = [ 2.5 2.5 ] F_1(x) = W_1 \cdot x + b_1 = \begin{bmatrix} 0.5 & 0.5 \\ 0.5 & 0.5 \end{bmatrix} \cdot \begin{bmatrix} 1 \\ 2 \end{bmatrix} + \begin{bmatrix} 0.5 \\ 0.5 \end{bmatrix} = \begin{bmatrix} 0.5 \cdot 1 + 0.5 \cdot 2 + 0.5 \\ 0.5 \cdot 1 + 0.5 \cdot 2 + 0.5 \end{bmatrix} = \begin{bmatrix} 2.5 \\ 2.5 \end{bmatrix} F1(x)=W1x+b1=[0.50.50.50.5][12]+[0.50.5]=[0.51+0.52+0.50.51+0.52+0.5]=[2.52.5]

  2. 输入 F 1 ( x ) F_1(x) F1(x) 经过第二层:
    F 2 ( F 1 ( x ) ) = W 2 ⋅ F 1 ( x ) + b 2 = [ 0.5 0.5 0.5 0.5 ] ⋅ [ 2.5 2.5 ] + [ − 0.5 − 0.5 ] = [ 0.5 ⋅ 2.5 + 0.5 ⋅ 2.5 − 0.5 0.5 ⋅ 2.5 + 0.5 ⋅ 2.5 − 0.5 ] = [ 2.5 − 0.5 2.5 − 0.5 ] = [ 2.0 2.0 ] F_2(F_1(x)) = W_2 \cdot F_1(x) + b_2 = \begin{bmatrix} 0.5 & 0.5 \\ 0.5 & 0.5 \end{bmatrix} \cdot \begin{bmatrix} 2.5 \\ 2.5 \end{bmatrix} + \begin{bmatrix} -0.5 \\ -0.5 \end{bmatrix} = \begin{bmatrix} 0.5 \cdot 2.5 + 0.5 \cdot 2.5 - 0.5 \\ 0.5 \cdot 2.5 + 0.5 \cdot 2.5 - 0.5 \end{bmatrix} = \begin{bmatrix} 2.5 - 0.5 \\ 2.5 - 0.5 \end{bmatrix} = \begin{bmatrix} 2.0 \\ 2.0 \end{bmatrix} F2(F1(x))=W2F1(x)+b2=[0.50.50.50.5][2.52.5]+[0.50.5]=[0.52.5+0.52.50.50.52.5+0.52.50.5]=[2.50.52.50.5]=[2.02.0]

  3. 通过残差连接:
    y = F 2 ( F 1 ( x ) ) + x = [ 2.0 2.0 ] + [ 1 2 ] = [ 3.0 4.0 ] y = F_2(F_1(x)) + x = \begin{bmatrix} 2.0 \\ 2.0 \end{bmatrix} + \begin{bmatrix} 1 \\ 2 \end{bmatrix} = \begin{bmatrix} 3.0 \\ 4.0 \end{bmatrix} y=F2(F1(x))+x=[2.02.0]+[12]=[3.04.0]

在这个例子中,网络通过学习残差使得输出更接近于输入。

4. 卷积自编码器与残差连接的实现

在实际应用中,残差连接不仅可以用于分类任务,还可以应用于自编码器等结构中。以下是一个卷积自编码器(ConvAutoencoder)和一个带有残差连接的卷积自编码器(ResidualConvAutoencoder)的实现代码:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils.data import DataLoader

# 检查设备
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

# 数据准备
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # 归一化
])

# 下载 MNIST 数据集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)


# 定义卷积自编码器(不带残差连接)
class ConvAutoencoder(nn.Module):
    def __init__(self):
        super(ConvAutoencoder, self).__init__()
        # 编码器
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1),  # 16 x 14 x 14
            nn.ReLU(),
            nn.Conv2d(16, 4, kernel_size=3, stride=2, padding=1),  # 4 x 7 x 7
            nn.ReLU()
        )
        # 解码器
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(4, 16, kernel_size=3, stride=2, padding=1, output_padding=1),  # 16 x 14 x 14
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, kernel_size=3, stride=2, padding=1, output_padding=1),  # 1 x 28 x 28
            nn.Sigmoid()  # 使用 Sigmoid 确保输出在 [0, 1] 范围内
        )

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded


# 定义带残差连接的卷积自编码器
class ResidualConvAutoencoder(nn.Module):
    def __init__(self):
        super(ResidualConvAutoencoder, self).__init__()
        # 编码器
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1),  # 16 x 14 x 14
            nn.ReLU(),
            nn.Conv2d(16, 4, kernel_size=3, stride=2, padding=1),  # 4 x 7 x 7
            nn.ReLU()
        )
        # 解码器
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(4, 16, kernel_size=3, stride=2, padding=1, output_padding=1),  # 16 x 14 x 14
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, kernel_size=3, stride=2, padding=1, output_padding=1)  # 1 x 28 x 28
        )

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)

        # 添加跳跃连接
        residual = x  # 原始输入
        decoded += residual  # 将输入添加到解码器的输出
        decoded = torch.sigmoid(decoded)  # 使用 Sigmoid 确保输出在 [0, 1] 范围内
        return decoded


# 训练和获取梯度的函数
def train_and_get_gradients(model, train_loader, criterion, optimizer, epochs=1):
    model.train()
    gradients = []

    for epoch in range(epochs):
        for images, _ in train_loader:  # 不需要标签
            images = images.to(device)  # 移动到设备
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, images)  # 自编码器的损失是重构损失
            loss.backward()

            # 收集梯度
            grads = [param.grad.clone() for param in model.parameters() if param.grad is not None]
            gradients.append(grads)

            optimizer.step()

    return gradients


# 初始化模型、损失函数和优化器
model_no_residual = ConvAutoencoder().to(device)
model_with_residual = ResidualConvAutoencoder().to(device)
criterion = nn.MSELoss()  # 使用均方误差损失
optimizer_no_residual = optim.Adam(model_no_residual.parameters(), lr=0.001)
optimizer_with_residual = optim.Adam(model_with_residual.parameters(), lr=0.001)

# 训练并获取梯度
print("Training ConvAutoencoder without residual connections...")
gradients_no_residual = train_and_get_gradients(model_no_residual, train_loader, criterion, optimizer_no_residual,
                                                epochs=1)

print("Training ResidualConvAutoencoder with residual connections...")
gradients_with_residual = train_and_get_gradients(model_with_residual, train_loader, criterion, optimizer_with_residual,
                                                  epochs=1)


# 计算并输出梯度的统计信息
def print_gradient_statistics(gradients, model_name):
    first_layer_grad = gradients[0][0]  # 获取第一个 batch 的第一个层的梯度
    l2_norm = torch.norm(first_layer_grad).item()
    min_grad = first_layer_grad.min().item()
    max_grad = first_layer_grad.max().item()
    mean_grad = first_layer_grad.mean().item()

    print(f"\nGradient statistics for {model_name} (first layer):")
    print(f"L2 Norm: {l2_norm:.4f}")
    print(f"Min: {min_grad:.4f}")
    print(f"Max: {max_grad:.4f}")
    print(f"Mean: {mean_grad:.4f}")


print_gradient_statistics(gradients_no_residual, "ConvAutoencoder without residual connections")
print_gradient_statistics(gradients_with_residual, "ResidualConvAutoencoder with residual connections")

5. 结果分析

根据训练结果,我们得到了以下梯度统计信息:

  • 不带残差连接的自编码器

    • L2 Norm: 0.0038
    • Min: -0.0007
    • Max: 0.0008
    • Mean: 0.0000
  • 带残差连接的自编码器

    • L2 Norm: 0.0083
    • Min: -0.0014
    • Max: 0.0017
    • Mean: 0.0001

结果分析

  1. 梯度的大小

    • 带有残差连接的自编码器的 L2 范数(0.0083)大于不带残差连接的自编码器(0.0038)。这表明带残差连接的模型在训练过程中对参数的更新幅度更大。
  2. 梯度的分布

    • 带有残差连接的自编码器的梯度范围(从 -0.0014 到 0.0017)相对较宽,说明模型在学习过程中对参数的调整更加积极。
    • 梯度的均值为 0.0001,表明在该层的权重更新上,正负梯度几乎相抵消,显示出模型在训练过程中的稳定性。
  3. 残差连接的影响

    • 残差连接的主要作用是提供一个直接的路径,使得梯度可以在网络中更有效地流动。尽管带有残差连接的自编码器的梯度较大,但这并不一定意味着学习效果更好。较大的梯度可能会导致不稳定的学习过程,特别是在深层网络中。

结论

残差网络通过引入跳跃连接,有效地缓解了深层网络中的梯度消失问题。通过在卷积自编码器中实现残差连接,我们能够观察到模型在训练过程中梯度的变化情况。虽然带有残差连接的模型在某些情况下梯度更大,但这并不一定意味着它的学习效果更好。实际应用中,仍需结合其他指标(如损失、准确率等)来全面评估模型的性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2245410.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Sqlite: Java使用、sqlite-devel

这里写目录标题 一、简介二、使用1. Java项目中(1)引入驱动(2)工具类(3)调用举例 2. sqlite-devel in linuxsqlite-devel使用 三、更多应用1. 数据类型2. 如何存储日期和时间3. 备份 一、简介 非常轻量级&…

MySQL深入:B+树的演化、索引和索引结构

提示:内容是读《MySQL技术内幕:InnoDB存储引擎》,笔记摘要 文章目录 二叉查找树平衡二叉树(AVL) B树(BTree)B树(BTree)InnoDB B树索引索引结构(InnoDB B树)B树存放的数据量 二叉查找树 在二叉查找树中,左子…

C语言-11-18笔记

1.C语言数据类型 类型存储大小值范围char1 字节-128 到 127 或 0 到 255unsigned char1 字节0 到 255signed char1 字节-128 到 127int2 或 4 字节-32,768 到 32,767 或 -2,147,483,648 到 2,147,483,647unsigned int2 或 4 字节0 到 65,535 或 0 到 4,294,967,295short2 字节…

“乐鑫组件注册表”简介

当启动一个新的开发项目时,开发者们通常会利用库和驱动程序等现有的代码资源。这种做法不仅节省时间,还简化了项目的维护工作。本文将深入探讨乐鑫组件注册表的概念及其核心理念,旨在指导您高效地使用和贡献组件。 概念解析 ESP-IDF 的架构…

【人工智能】PyTorch、TensorFlow 和 Keras 全面解析与对比:深度学习框架的终极指南

文章目录 PyTorch 全面解析2.1 PyTorch 的发展历程2.2 PyTorch 的核心特点2.3 PyTorch 的应用场景 TensorFlow 全面解析3.1 TensorFlow 的发展历程3.2 TensorFlow 的核心特点3.3 TensorFlow 的应用场景 Keras 全面解析4.1 Keras 的发展历程4.2 Keras 的核心特点4.3 Keras 的应用…

Sigrity SPEED2000 TDR TDT Simulation模式如何进行时域阻抗仿真分析操作指导-差分信号

Sigrity SPEED2000 TDR TDT Simulation模式如何进行时域阻抗仿真分析操作指导-差分信号 Sigrity SPEED2000 TDR TDT Simulation模式如何进行时域阻抗仿真分析操作指导-单端信号详细介绍了单端信号如何进行TDR仿真分析,下面介绍如何对差分信号进行TDR分析,还是以下图为例进行分…

Django一分钟:django中收集关联对象关联数据的方法

场景:我有一个模型,被其它多个模型关联,我配置了CASCADE级联删除,我想要告知用户删除该实例之后,哪些关联数据将会被一同删除。 假设我们当前有这样一组模型: class Warehouse(models.Model):""…

Flink学习连载第二篇-使用flink编写WordCount(多种情况演示)

使用Flink编写代码,步骤非常固定,大概分为以下几步,只要牢牢抓住步骤,基本轻松拿下: 1. env-准备环境 2. source-加载数据 3. transformation-数据处理转换 4. sink-数据输出 5. execute-执行 DataStream API开发 //n…

利用开源的低代码表单设计器FcDesigner高效管理和渲染复杂表单结构

FcDesigner 是一个强大的开源低代码表单设计器组件,支持快速拖拽生成表单。提供丰富的自定义及扩展功能,FcDesigner支持多语言环境,并允许开发者进行二次开发。通过将表单设计输出为JSON格式,再通过渲染器进行加载,实现…

【三合黑马指标】指标操盘技术图文教程,三线粘合抓黑马,短线买点持股辅助,通达信炒股软件指标

如上图,副图指标【三合黑马指标】,三条线彩线1-2-3,四条虚线代表四种短线技术做多信号,最底部的凸起形态线短线做多确认信号 。 黑马牛股选股技巧,可以选择周线三线粘合状态,在粘合时选股关注,如…

nwjs崩溃复现、 nwjs-控制台手动操纵、nwjs崩溃调用栈解码、剪切板例子中、nwjs混合模式、xdotool显示nwjs所有进程窗口列表

-1. nwjs在低版本ubuntu运行情况 ubuntu16.04运行nw-v0.93或0.89报错找不到NSS_3.30、GLIBC_2.25 uname -a #Linux Asus 4.15.0-112-generic #113~16.04.1-Ubuntu SMP Fri Jul 10 04:37:08 UTC 2020 x86_64 x86_64 x86_64 GNU/Linux cat /etc/issue #Ubuntu 16.04.7 LTS \n \l…

DICOM图像解析:深入解析DICOM格式文件的高效读取与处理

引言 在医学影像领域,DICOM(Digital Imaging and Communications in Medicine)标准已成为信息交换和存储的核心规范。掌握DICOM文件的读取与解析,对于开发医学影像处理软件至关重要。本文将系统地解析DICOM文件的结构、关键概念,并提供高效的读取与显示方法,旨在为开发者…

npm上传自己封装的插件(vue+vite)

一、npm账号及发包删包等命令 若没有账号,可在npm官网:https://www.npmjs.com/login 进行注册。 在当前项目根目录下打开终端命令窗口,常见命令如下: 1、登录命令:npm login(不用每次都重新登录&#xff0…

SpringAOP模拟实现

文章目录 1_底层切点、通知、切面2_切点匹配3_从 Aspect 到 Advisor1_代理创建器2_代理创建时机3_Before 对应的低级通知 4_静态通知调用1_通知调用过程2_模拟 MethodInvocation 5_动态通知调用 1_底层切点、通知、切面 注意点: 底层的切点实现底层的通知实现底层的…

Scala学习记录,全文单词统计

全文单词统计: 可分为以下几个步骤: 1.读取文件,得到很长的字符串 2.把字符串拆分成一个一个的单词 3.统计每个单词出现的次数 4.排序 5.把结果写入到一个文件中 完整代码如下: import java.io.PrintWriter import scala.io.So…

【UE5】使用基元数据对材质传参,从而避免新建材质实例

在项目中,经常会遇到这样的需求:多个模型(例如 100 个)使用相同的材质,但每个模型需要不同的参数设置,比如不同的颜色或随机种子等。 在这种情况下,创建 100 个实例材质不是最佳选择。正确的做…

电子应用设计方案-16:智能全屋灯光系统方案设计

智能全屋灯光系统方案设计 一、系统概述 本智能全屋灯光系统旨在为用户提供便捷、舒适、节能且个性化的照明体验,通过智能化的控制方式实现对全屋灯光的集中管理和灵活调控。 二、系统组成 1. 智能灯具 - 包括吸顶灯、吊灯、壁灯、台灯、筒灯、射灯等多种类型&#…

逆向题(23):nss:2956(花指令)

nss:2956(花指令) 打开主程序后,我们发现在这里有问题。而且跟之前学长讲的不一样。 我们学学长那样,先分解成数据,然后一步步从上往下按c去做,看看最后还会不会报错, 很显然没有…

28.<Spring博客系统⑤(部署的整个过程(CentOS))>

引入依赖 Spring-boot-maven-plugin 用maven进行打包的时候必须用到这个插件。看看自己pom.xml中有没有这个插件 并且看看配置正确不正常。 注&#xff1a;我们这个项目打的jar包在30MB左右。 <plugin><groupId>org.springframework.boot</groupId><artif…

力扣力扣力:860柠檬水找零

860. 柠檬水找零 - 力扣&#xff08;LeetCode&#xff09; 需要注意的是&#xff0c;我们一开始是没有任何钱的&#xff0c;也就是说我们需要拿着顾客的钱去找零。如果第一位顾客上来就是要找零那么我们无法完成&#xff0c;只能返回false。 分析&#xff1a; 上来我们先不分…