基于CNN的FashionMNIST数据集识别6——ResNet模型

news2025/3/31 0:47:08

前言

之前我们在cnn已经搞过VGG和GoogleNet模型了,这两种较深的模型出现了一些问题:

梯度传播问题

在反向传播过程中,梯度通过链式法则逐层传递。对于包含 L 层的网络,第 l 层的梯度计算为:

其中 a(k) 表示第 k层的激活值。当多个雅可比矩阵 ∂a(k+1)/∂a(k) 的乘积中出现大量小于1的特征值时(例如使用Sigmoid激活函数),梯度会指数级衰减(‌梯度消失‌);反之若特征值大于1,则梯度爆炸式增长(‌梯度爆炸‌)。 

实验证明,VGG-19的训练损失曲线在后期趋于平缓,参数更新停滞。

网络退化问题

当网络深度超过某个阈值时(例如20层),VGG会出现以下矛盾现象:

  • 训练误差不降反升(与过拟合无关)
  • 测试集准确率显著低于更浅的网络

网络退化问题通常是过深的网络的表达力下降导致的,原始像素信息需经过所有层的非线性变换,关键特征可能在传递过程中被破坏。

计算代价问题

以VGG-16为例:

  • 全连接层占总参数量的90%以上(约1.38亿参数中的1.22亿)
  • 最后三个全连接层(4096→4096→1000)产生巨大计算开销(我在训练的时候不得不减少前两个全连接层的神经元数量来尽快完成训练)。

 单张224×224图像前向传播的浮点运算量(FLOPs):

其中l是神经网络层数, Kl 为卷积核尺寸,Cin​、Cout 为输入/输出通道数。VGG-16的FLOPs高达15.5G 。训练起来太费劲了。

过拟合问题

模型复杂度应与训练数据规模匹配。VGG-16的1.38亿参数需要极大训练集(ImageNet的120万图像勉强足够),但在小数据集上,测试集准确率显著低于更紧凑的网络。

这些问题说明单纯叠深度不是万能的,甚至有副作用。这里我们使用ResNet来一定程度解决上面的问题。

源码

import torch
from torch import nn
from torchsummary import summary

class Residual(nn.Module):
    def __init__(self, in_channels, out_channels, use_1conv=False, strides = 1):
        super().__init__()
        self.Rulu = nn.ReLU()
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1, stride=strides)
        self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        if use_1conv:
            self.conv3 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=strides)
        else:
            self.conv3 = None

    def forward(self, x):
        y = self.Rulu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        if self.conv3:
            x = self.conv3(x)
        y = self.Rulu(x + y)
        return y

class ResNet18(nn.Module):
    def __init__(self, Residual):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=64, kernel_size=7, stride=3, padding=3),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        self.block2 = nn.Sequential(
            Residual(64, 64, use_1conv=False, strides=1),
            Residual(64, 64, use_1conv=False, strides=1)
        )
        self.block3 = nn.Sequential(
            Residual(64, 128, use_1conv=True, strides=2),
            Residual(128, 128, use_1conv=False, strides=1)
        )
        self.block4 = nn.Sequential(
            Residual(128, 256, use_1conv=True, strides=2),
            Residual(256, 256, use_1conv=False, strides=1)
        )
        self.block5 = nn.Sequential(
            Residual(256, 512, use_1conv=True, strides=2),
            Residual(512, 512, use_1conv=False, strides=1)
        )
        self.block6 = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)
        x = self.block6(x)
        return x

if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ResNet18(Residual).to(device)
    print(summary(model, (1, 224, 224)))


 残差块

resnet的核心就是多个残差块,从图片中可以看到,残差块有两个显著特征:

  • 每次卷积后都加入了一个批量规范化层。
  • 每个残差块在计算完毕后都会将原始输入x叠加到输出里。

这种跳跃连接允许梯度直接回传至浅层,一定程度缓解了梯度爆炸/梯度消失问题。可以同时学习新特征和保留原始特征。

假设理想映射为 H(x),传统网络直接拟合 H(x),而残差网络拟合残差 F(x)=H(x)−x。当 H(x)≈x 时,后者优化目标 F(x)→0 比前者 H(x)→x 的优化难度更低。

批量规范化层

批量规范化(Batch Normalization,BN)是深度学习中革命性的技术之一,神经网络里加上这个效果嘎嘎好。

bn要解决的问题

内部协变量偏移(Internal Covariate Shift)

网络参数更新导致各层输入分布不断变化,迫使后续层需要持续适应新的数据分布,显著降低训练速度。分布偏移迫使网络使用更低的学习率来维持稳定性。

梯度传播障碍

梯度幅度在各层差异巨大,在链式传导时容易导致梯度消失/爆炸。此外,参数初始值对训练结果影响很明显。

bn的数学理论

对于每个batch B={x1,...,xm}:

从公式可以看出,bn的本质上是对每层神经网络的输入做了标准化处理。首先计算当前批次的平均值和方差,再进行归一化处理,消除物理量纲上的差异 。

最后有一步仿射变换,引入可学习的参数 ‌γ(缩放因子)‌ 和 ‌β(平移因子)‌,恢复数据表达能力。

恢复数据表达能力是咋回事呢?

首先标准化是一种数据处理方法,其目的是将数据调整到一个标准范围内,从而使得不同特征具有相同的尺度。标准化适用于特征的分布呈正态分布或接近正态分布的情况。

将每层强制标准化,可能某些特征的重要性被抑制,假设某层理想输出应为 ‌N(2, 0.5)‌,但标准化后变为 ‌N(0,1)‌,直接使用会损失原有分布的信息。

BN的仿射变换等价于在原模型上叠加了一个线性变换层,使网络能自主选择是否保留标准化效果‌。

自适应平均池化

nn.AdaptiveAvgPool2d((1, 1)),

自适应平均池化在resnet里替代了一部分全连接层。自适应平均池化和普通平均池化的区别是:

  • 普通平均池化‌:需手动指定窗口大小(如 kernel_size=3)和步幅(如 stride=2),输出尺寸由输入尺寸和参数共同决定。
  • 自适应平均池化‌:直接指定输出尺寸(如 (1,1)),PyTorch自动计算所需的窗口大小和步幅。

自适应平均池化的输入与输出‌:

输入‌:形状为 (batch_size, channels, H, W) 的4D张量。

输出‌:形状为 (batch_size, channels, 1, 1) 的4D张量。

其数学原理是对每个通道的特征图,计算所有元素求平均值,每个通道的特征图被池化为单个数值(平均值)。

自适应平均池化的主要优点是减少参数,防止过拟合,常用于较深的神经网络中。

思考:x+y和torch.cat

之前做GoogleNet的时候,前向传播里使用torch.cat做多路径计算融合,那么resNet里前向传播里的x+y融合是否也可以使用torch.cat代替?

答案是不行,因为resNet里的相加和GoogleNet里的相加有区别。

  • x + y是‌逐元素相加‌,要求两个张量的形状完全相同(通道数、尺寸一致)。
  • torch.cat([x, y], dim)是‌沿指定维度拼接‌,会改变输出张量的形状(通道数翻倍)。

所以resNet里残差块做相加时,两个张量的通道数,宽,高,要完全相同才行,相加之后宽,高,通道数都没有变化;而googleNet的inception块里做相加时,多个张量的通道数可以不同,宽,高完全相同,相加之后宽,高不变,而通道数是所有张量的通道数之和。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2322924.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

0323-B树、B+树

多叉树---->B树(磁盘)、B树 磁盘由多个盘片组成,每个盘片分为多个磁道和扇区。数据存储在这些扇区中,扇区之间通过指针链接,形成链式结构。 内存由连续的存储单元组成,每个单元有唯一地址,数…

【工作记录】F12查看接口信息及postman中使用

可参考 详细教程:如何从前端查看调用接口、传参及返回结果(附带图片案例)_f12查看接口及参数-CSDN博客 1、接口信息 接口基础知识2:http通信的组成_接口请求信息包括-CSDN博客 HTTP类型接口之请求&响应详解 - 三叔测试笔记…

2024年认证杯SPSSPRO杯数学建模B题(第二阶段)神经外科手术的定位与导航全过程文档及程序

2024年认证杯SPSSPRO杯数学建模 B题 神经外科手术的定位与导航 原题再现: 人的大脑结构非常复杂,内部交织密布着神经和血管,所以在大脑内做手术具有非常高的精细和复杂程度。例如神经外科的肿瘤切除手术或血肿清除手术,通常需要…

Android 12系统源码_系统启动(二)Zygote进程

前言 Zygote(意为“受精卵”)是 Android 系统中的一个核心进程,负责 孵化(fork)应用进程,以优化应用启动速度和内存占用。它是 Android 系统启动后第一个由 init 进程启动的 Java 进程,后续所有…

MOSN(Modular Open Smart Network)-05-MOSN 平滑升级原理解析

前言 大家好,我是老马。 sofastack 其实出来很久了,第一次应该是在 2022 年左右开始关注,但是一直没有深入研究。 最近想学习一下 SOFA 对于生态的设计和思考。 sofaboot 系列 SOFAStack-00-sofa 技术栈概览 MOSN(Modular O…

Flink介绍与安装

Apache Flink是一个在有界数据流和无界数据流上进行有状态计算分布式处理引擎和框架。Flink 设计旨在所有常见的集群环境中运行,以任意规模和内存级速度执行计算。 一、主要特点和功能 1. 实时流处理: 低延迟: Flink 能够以亚秒级的延迟处理数据流,非常…

【gradio】从零搭建知识库问答系统-Gradio+Ollama+Qwen2.5实现全流程

从零搭建大模型问答系统-GradioOllamaQwen2.5实现全流程(一) 前言一、界面设计(计划)二、模块设计1.登录模块2.注册模块3. 主界面模块4. 历史记录模块 三、相应的接口(前后端交互)四、实现前端界面的设计co…

PowerBI,用度量值实现表格销售统计(含合计)的简单示例

假设我们有产品表 和销售表 我们想实现下面的效果 表格显示每个产品的信息,以及单个产品的总销量 有一个切片器能筛选各个门店的产品销量 还有一个卡片图显示所筛选条件下,所有产品的总销量 实现方法: 1.我们新建一个计算表,把…

26考研——查找_树形查找_二叉排序树(BST)(7)

408答疑 文章目录 三、树形查找二叉排序树(BST)二叉排序树中结点值之间的关系二叉树形查找二叉排序树的查找过程示例 向二叉排序树中插入结点插入过程示例 构造二叉排序树的过程构造示例 二叉排序树中删除结点的操作情况一:被删除结点是叶结点…

【行驶证识别】批量咕嘎OCR识别行驶证照片复印件图片里的文字信息保存表格或改名字,基于QT和腾讯云api_ocr的实现方式

项目背景 在许多业务场景中,如物流管理、车辆租赁、保险理赔等,常常需要处理大量的行驶证照片复印件。手动录入行驶证上的文字信息,像车主姓名、车辆型号、车牌号码等,不仅效率低下,还容易出现人为错误。借助 OCR(光学字符识别)技术,能够自动识别行驶证图片中的文字信…

21.Excel自动化:如何使用 xlwings 进行编程

一 将Excel用作数据查看器 使用 xlwings 中的 view 函数。 1.导包 import datetime as dt import xlwings as xw import pandas as pd import numpy as np 2.view 函数 创建一个基于伪随机数的DataFrame,它有足够多的行,使得只有首尾几行会被显示。 df …

LabVIEW FPGA与Windows平台数据滤波处理对比

LabVIEW在FPGA和Windows平台均可实现数据滤波处理,但两者的底层架构、资源限制、实时性及应用场景差异显著。FPGA侧重硬件级并行处理,适用于高实时性场景;Windows依赖软件算法,适合复杂数据处理与可视化。本文结合具体案例&#x…

【NLP 48、大语言模型的神秘力量 —— ICL:in context learning】

目录 一、ICL的优势 1.传统做法 2.ICL做法 二、ICL的发展 三、ICL成因的两种看法 1.meta learning 2.Bayesian Inference 四、ICL要点 ① 语言模型的规模 ② 提示词prompt中提供的examples数量和顺序 ③ 提示词prompt的形式(format) 五、fine-tune VS I…

vue 中渲染 markdown 格式的文本

文章目录 需求分析第一步:安装依赖第二步:创建 Markdown 渲染组件第三步,使用实例扩展功能1. 代码高亮:2. 自定义渲染规则:需求 渲染 markdown 格式的文本 分析 在Vue 3中实现Markdown渲染的常见方法。通常有两种方式:使用现有的Markdown解析库,或者自己编写解析器…

工业4G路由器赋能智慧停车场高效管理

工业4G路由器作为智慧停车场管理系统通信核心,将停车场内的各个子系统连接起来,包括车牌识别系统、道闸控制系统、车位检测系统、收费系统以及监控系统等。通过4G网络,将这些系统采集到的数据传输到云端服务器或管理中心,实现信息…

企业如何平稳实现从Tableau到FineBI的信创迁移?

之前和大家分享了《如何将Tableau轻松迁移到Power BI》。但小编了解到,如今有些企业更愿意选择国产BI平台。为此,小编今天以Fine BI为例子,介绍如何从Tableau轻松、低成本地迁移到国产BI平台。 在信创政策全面推进的背景下,企业数…

蓝桥与力扣刷题(蓝桥 蓝桥骑士)

题目:小明是蓝桥王国的骑士,他喜欢不断突破自我。 这天蓝桥国王给他安排了 N 个对手,他们的战力值分别为 a1,a2,...,an,且按顺序阻挡在小明的前方。对于这些对手小明可以选择挑战,也可以选择避战。 身为高傲的骑士&a…

前端学习笔记--CSS

HTMLCSSJavaScript 》 结构 表现 交互 如何学习 1.CSS是什么 2.CSS怎么用? 3.CSS选择器(重点,难点) 4.美化网页(文字,阴影,超链接,列表,渐变。。。) 5…

31天Python入门——第15天:日志记录

你好,我是安然无虞。 文章目录 日志记录python的日志记录模块创建日志处理程序并配置输出格式将日志内容输出到控制台将日志写入到文件 logging更简单的一种使用方式 日志记录 日志记录是一种重要的应用程序开发和维护技术, 它用于记录应用程序运行时的关键信息和…

使用ucharts写的小程序,然后让圆环中间的空白位置变大

将ringWidth属性调小 extra: { ring: { ringWidth: 20, activeOpacity: 1.5, activeRadius: 10, offsetAngle: 0, labelWidth: 15, border: true, borderWidth: 0, borderColor: #F…