ResNet网络分析与demo实例

news2025/2/4 18:00:32

参考自 

  • up主的b站链接:霹雳吧啦Wz的个人空间-霹雳吧啦Wz个人主页-哔哩哔哩视频
  • 这位大佬的博客 Fun'_机器学习,pytorch图像分类,工具箱-CSDN博客

 ResNet 详解

原论文地址 [1512.03385] Deep Residual Learning for Image Recognition (arxiv.org)

ResNet 网络是在 2015年 由微软实验室提出,斩获当年ImageNet竞赛中分类任务第一名,目标检测第一名。获得COCO数据集中目标检测第一名,图像分割第一名。

在ResNet网络的创新点:

  • 提出 Residual 结构(残差结构),并搭建超深的网络结构(可突破1000层)
  • 使用 Batch Normalization 加速训练(丢弃dropout)

下图是ResNet34层模型的结构简图:

在ResNet网络提出之前,传统的卷积神经网络都是通过将一系列卷积层与池化层进行堆叠得到的。

一般我们会觉得网络越深,特征信息越丰富,模型效果应该越好。但是实验证明,当网络堆叠到一定深度时,会出现两个问题

梯度消失或梯度爆炸

退化问题

如下图所示,20层网络 反而比 56层网络 的误差更小:

对于梯度消失或梯度爆炸问题,ResNet论文提出通过数据的预处理以及在网络中使用

 BN(Batch Normalization)层来解决。

对于退化问题,ResNet论文提出了 residual结构残差结构)来减轻退化问题,下图是使用residual结构的卷积网络,可以看到随着网络的不断加深,效果并没有变差,而是变的更好了。(虚线是train error,实线是test error)

为了解决深层网络中的退化问题,可以人为地让神经网络某些层跳过下一层神经元的连接,隔层相连,弱化每层之间的强联系。这种神经网络被称为 残差网络 (ResNets)。

残差网络由许多隔层相连的神经元子模块组成,我们称之为 残差块 Residual block。单个残差块的结构如下图所示:

原文的表注中已说明,conv3_x, conv4_x, conv5_x所对应的一系列残差结构的第一层残差结构都是虚线残差结构。因为这一系列残差结构的第一层都有调整输入特征矩阵shape的使命(将特征矩阵的高和宽缩减为原来的一半,将深度channel调整成下一层残差结构所需要的channel)

需要注意的是,对于ResNet50/101/152,其实conv2_x所对应的一系列残差结构的第一层也是虚线残差结构,因为它需要调整输入特征矩阵的channel。根据表格可知通过3x3的max pool之后输出的特征矩阵shape应该是[56, 56, 64],但conv2_x所对应的一系列残差结构中的实线残差结构它们期望的输入特征矩阵shape是[56, 56, 256](因为这样才能保证输入输出特征矩阵shape相同,才能将捷径分支的输出与主分支的输出进行相加)。所以第一层残差结构需要将shape从[56, 56, 64] --> [56, 56, 256]。注意,这里只调整channel维度,高和宽不变(而conv3_x, conv4_x, conv5_x所对应的一系列残差结构的第一层虚线残差结构不仅要调整channel还要将高和宽缩减为原来的一半)。

下面是 ResNet 18/34 和 ResNet 50/101/152 具体的实线/虚线残差结构图:
 

ResNet 18/34

ResNet 50/101/152s

在迁移学习中,我们希望利用源任务(Source Task)学到的知识帮助学习目标任务 (Target Task)。例如,一个训练好的图像分类网络能够被用于另一个图像相关的任务。再比如,一个网络在仿真环境学习的知识可以被迁移到真实环境的网络。迁移学习一个典型的例子就是载入训练好VGG网络,这个大规模分类网络能将图像分到1000个类别,然后把这个网络用于另一个任务,如医学图像分类。

为什么可以这么做呢?如下图所示,神经网络逐层提取图像的深层信息,这样,预训练网络就相当于一个特征提取器。

model.py

  • 定义ResNet18/34的残差结构,即BasicBlock
  • 定义ResNet50/101/152的残差结构,即Bottleneck
  • 定义ResNet网络结构
  • 定义resnet18/34/50/101/152

import torch.nn as nn
import torch


# ResNet18/34的残差结构,用的是2个3x3的卷积
class BasicBlock(nn.Module):
    expansion = 1  # 残差结构中,主分支的卷积核个数是否发生变化,不变则为1

    def __init__(self, in_channel, out_channel, stride=1, downsample=None):  # downsample对应虚线残差结构
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:  # 虚线残差结构,需要下采样
            identity = self.downsample(x)  # 捷径分支 short cut

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity
        out = self.relu(out)

        return out

# ResNet50/101/152的残差结构,用的是1x1+3x3+1x1的卷积
class Bottleneck(nn.Module):
    expansion = 4  # 残差结构中第三层卷积核个数是第一/二层卷积核个数的4倍

    def __init__(self, in_channel, out_channel, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=1, stride=1, bias=False)  # squeeze channels
        self.bn1 = nn.BatchNorm2d(out_channel)
        # -----------------------------------------
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, bias=False, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channel)
        # -----------------------------------------
        self.conv3 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel * self.expansion,
                               kernel_size=1, stride=1, bias=False)  # unsqueeze channels
        self.bn3 = nn.BatchNorm2d(out_channel * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)  # 捷径分支 short cut

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    # block = BasicBlock or Bottleneck
    # block_num为残差结构中conv2_x~conv5_x中残差块个数,是一个列表
    def __init__(self, block, blocks_num, num_classes=1000, include_top=True):
        super(ResNet, self).__init__()
        self.include_top = include_top
        self.in_channel = 64

        self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,
                               padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, blocks_num[0])             # conv2_x
        self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)  # conv3_x
        self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)  # conv4_x
        self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)  # conv5_x
        if self.include_top:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)
            self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    # channel为残差结构中第一层卷积核个数
    def _make_layer(self, block, channel, block_num, stride=1):
        downsample = None

        # ResNet50/101/152的残差结构,block.expansion=4
        if stride != 1 or self.in_channel != channel * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(channel * block.expansion))

        layers = []
        layers.append(block(self.in_channel, channel, downsample=downsample, stride=stride))
        self.in_channel = channel * block.expansion

        for _ in range(1, block_num):
            layers.append(block(self.in_channel, channel))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        if self.include_top:
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.fc(x)

        return x


def resnet34(num_classes=1000, include_top=True):
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)


def resnet101(num_classes=1000, include_top=True):
    return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, include_top=include_top)

train.py

由于ResNet网络较深,直接训练的话会非常耗时,因此用迁移学习的方法导入预训练好的模型参数:
在pycharm中输入import torchvision.models.resnet,ctrl+左键resnet跳转到pytorch官方实现resnet的源码中,下载预训练的模型参数:
 

model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}

import torch
from model import resnet34
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import json

data_transform = transforms.Compose(
    [transforms.Resize(256),
     transforms.CenterCrop(224),
     transforms.ToTensor(),
     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

# load image
img = Image.open("../tulip.jpg")
plt.imshow(img)
# [N, C, H, W]
img = data_transform(img)
# expand batch dimension
img = torch.unsqueeze(img, dim=0)

# read class_indict
try:
    json_file = open('./class_indices.json', 'r')
    class_indict = json.load(json_file)
except Exception as e:
    print(e)
    exit(-1)

# create model
model = resnet34(num_classes=5)
# load model weights
model_weight_path = "./resNet34.pth"
model.load_state_dict(torch.load(model_weight_path))
model.eval()
with torch.no_grad():
    # predict class
    output = torch.squeeze(model(img))
    predict = torch.softmax(output, dim=0)
    predict_cla = torch.argmax(predict).numpy()
print(class_indict[str(predict_cla)], predict[predict_cla].numpy())
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1337303.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

短视频矩阵系统:赋予用户创造与分享的力量

在如今快节奏的社交网络时代,人们对于信息获取和娱乐方式的需求也逐渐发生了变化。作为当下最受欢迎的短视频平台之一,抖音短视频矩阵系统正以其独特的魅力和吸引力,深深地打动着亿万用户。 抖音短视频矩阵系统是一种基于移动端的短视频分享…

基于STC89C52RC的温湿度显示与按键可调的时钟显示

大学时候的课程设计项目,本人只负责软件设计。 课题摘要 摘 要 温湿度参数的检测已经成为人们日常生产生活中的一个重要的参数指标。温度和湿度是两个最基本的环境参数,人们生活与温湿度息息相关。在工农业生产、环保、科研、化工业、制药业等地方&…

客户跟进没效果?这三招请收好!!

在现代商业环境中,与客户进行有效的跟进至关重要。但是,有时候不论我们多么努力地跟进,却依然无法获得预期的结果。 今天就给大家分享三个高效跟进客户的方法,帮助大家提高效率! 首先,了解客户需求是关键…

【MySQL变更】gh-ost原理解读

gh-ost简介 gh-ost是处理MySQL在线表结构变更的工具,与pt-osc 不同,gh-ost不会使用触发器。 gh-ost 可以进行测试,暂停,动态控制和重新配置,审计还有其他许多操作perks。 命名 最初它被命名为gh-osc:Git…

从入门到精通,30天带你学会C++【第九天:排序合集】

Everyday English Never put off what you can do until tomorrow. 今日事,今日毕。 前言 首先跟大家说声抱歉,我已经25天没写博客了,我知道我掉了很多粉丝,但是还有很多人坚持关注着我,在这里我表示感谢…

C++ Qt开发:QSqlDatabase数据库组件

Qt 是一个跨平台C图形界面开发库,利用Qt可以快速开发跨平台窗体应用程序,在Qt中我们可以通过拖拽的方式将不同组件放到指定的位置,实现图形化开发极大的方便了开发效率,本章将重点介绍QSqlDatabase数据库模块的常用方法及灵活运用…

【软件工程大题】白盒测试

给出一个简单的测试样例,然后再进行白盒测试的讲解 if A and B then action1 if C or D then action2 1.语句覆盖 每个语句执行一次 也就是,样例中的每个语句执行一次,至于ABCD取值,要满足IF条件,让四个语句都执行一次 A and B -> T ⇒ AT …

蓝桥杯的学习规划

c语言基础: Python语言基础 学习路径:画框的要着重学习

2023年教程汇总 | 《小杜的生信笔记》

2023年总结 2023年即将结束,我们即将迎来2024年。2023年,我们做了什么呢??这个是个值得深思的问题…? 12月份是个快乐且痛苦时间节点。前一段时间,单位需要提交2023年工作总结,真的是憋了好久才可以下笔…

【数据结构】无向图的最小生成树(Prime,Kruskal算法)

文章目录 前言一、最小生成树二、Kruskal算法1.方法:2.判断是否成环3.代码实现 三、 Prim算法1.方法:2.代码 四、源码 前言 连通图:在无向图中,若从顶点v1到顶点v2有路径,则称顶点v1与顶点v2是连通的。如果图中任意一对…

医院信息化-6 大模型与医疗

之前写了一系列跟医疗信息化相关的内容,其中有提到人工智能,但是写的都是原先的一些AI算法基础上的医疗应用。现在大模型出现的涌现推理能力确实让人惊讶,并且出现可商用化的可能性,因此最近一年关于大模型在医疗的应用也开始出现…

ComfyUI如何中文汉化

comfyui中文地址如下: https://github.com/AIGODLIKE/AIGODLIKE-ComfyUI-Translationhttps://github.com/AIGODLIKE/AIGODLIKE-ComfyUI-Translation如何安装? 1. git安装 进入项目目录下的custom_nodes目录下,然后进入控制台,运…

Java——基本数据类型

Java基本数据类型 一、 整型1. byte2. short3. int4. long 二、浮点型1. float2. double 三、 字符型(char)四、 布尔型(boolean) 总结 算下刚转Java到现在也有三个多月了,所以打算对Java的知识进行汇总一下,本篇文章介绍一下Java…

Linux之用户/组 管理

关机&重启命令 shutdown -h now立刻进行关机shutdown -h 11分钟后关机(shutdown默认等于shutdown -h 1) -h即halt shutdown -r now现在重新启动计算机 -r即reboot halt关机reboot重新启动计算机sync把内存数据同步到磁盘 再进行shutdown/reboot/halt命令在执行…

【支持向量机】SVM线性可分支持向量机学习算法——硬间隔最大化支持向量机及例题详解

支特向量机(support vector machines, SVM)是一种二类分类模型。它的基本模型是定义在特征空间上的间隔最大的线性分类器。包含线性可分支持向量机、 线性支持向量机、非线性支持向量机。 当训练数据线性可分时,通过硬间隔最大化学习线性分类器, 即为线性…

提升泵类设备性能的解决方案:基于AI的预测性维护

随着工业的智能化和数字化发展,设备维护的方式得到不断优化。人工智能(AI)、机器学习和云计算等先进技术的引入,使得设备健康管理系统的数据采集、实时分析、故障预警与智能诊断等能力得到提升。借助这些设备预测性维护手段&#…

LISN到底是啥?干啥用的?

LISN是在EMC测试的时候,会被使用的设备,如下图所示: 双路V型电源阻抗稳定网络。它完全符合CISPR16-1-2、MIL-STD 461F、VDE 0876、FCC Part 15标准的要求,其等效电路为50Ω||(5Ω50μH),频率范围…

vue3使用mixins

<template><div>{{ num }}___{{ fav }}</div><button click"favBtn">改变值</button> </template><script setup lang"ts"> import mixin from "../mixins/mixin"; let { num, fav, favBtn } mixin(…

【微服务核心】Spring Boot

Spring Boot 文章目录 Spring Boot1. 简介2. 开发步骤3. 配置文件4. 整合 Spring MVC 功能5. 整合 Druid 和 Mybatis6. 使用声明式事务7. AOP整合配置8. SpringBoot项目打包和运行 1. 简介 SpringBoot&#xff0c;开箱即用&#xff0c;设置合理的默认值&#xff0c;同时也可以…

【MySQL】数据库之日志管理、备份与恢复

目录 一、MySQL的日志管理 二、MySQL的完全备份与恢复 物理冷备份&#xff08;完全备份&#xff09;与恢复 数据库上云迁移的方案&#xff1f; 逻辑热备份&#xff08;完全备份&#xff09;与恢复 三、MySQL的增量备份与恢复 1、手动增量备份 2、脚本增量备份 3、增量备…