ResNet网络详解

news2025/4/16 16:46:47

ResNet

ResNet在2015年由微软实验室提出,斩获当年lmageNet竞赛中分类任务第一名,目标检测第一名。获得coco数据集中目标检测第一名,图像分割第一名。

ResNet亮点

1.超深的网络结构(突破1000层)
2.提出residual模块
3.使用Batch Normalization加速训练(丢弃dropout)

网络一定是越深越好吗?

一般来说,网络越深,咱们能获取的信息越多,而且特征也越丰富。我们思考一个问题,我们一直加深我们的网络,取得的效果一定越好吗?

在这里插入图片描述

看到这两幅图我们看到,网络并不是深度越深效果越好,那么是什么原因导致的呢?在ResNet论文中,提出了两个问题,一个是随着网络深度的增加,梯度消失或梯度爆炸问题;另一个是退化问题。为了让更深的网络也能训练出好的效果,何恺明提出了ResNet网络。

在这里插入图片描述

残差结构

在这里插入图片描述

左边的残差结构用于ResNet-34,右边的残差结构用于ResNet-50/101/152。残差结构可以用表达式F(x)=f(x)+x表示。左边的残差结构很简单,主线f(x)是两个3x3的卷积,最后与x合并。右边多了两个1x1的卷积,用来降维和升维(减少参数量)。

ResNet网络结构

在这里插入图片描述

我们着重分析一下ResNet-34结构,可以看到是由多个残差结构堆叠而成:

在这里插入图片描述

为什么会有实线和虚线之分呢?从网络参数图上可以看到,经过一个类型的第一个残差结构后,图像的大小为原来的一半,我们拿第一条虚线举例:输入图像为56x56x64,输出为28x28x128,所以分支需要一个1x1卷积来将输入图处理成输出大小。

请添加图片描述

右侧虚线残差结构的主分支上第一个1x1卷积层的步距是2,第二个3x3卷积层步距是1。但在pytorch官方实现过程中是第一个1x1卷积层的步距是1,第二个3x3卷积层步距是2,这样能够在imagenet的top1上提升大概0.5%的准确率。

Batch Normalization

Batch Normalization是google团队在2015年论文《Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift》提出的。通过该方法能够加速网络的收敛并提升准确率。我们在图像预处理过程中通常会对图像进行标准化处理,这样能够加速网络的收敛,假设对于Conv1来说输入的就是满足某一分布的特征矩阵,但对于Conv2而言经过Conv1卷积后输入的feature map就不一定满足某一分布规律了。而我们Batch Normalization的目的就是使我们的feature map满足均值为0,方差为1的分布规律。下图是一个示例:

在这里插入图片描述

我们需要知道的是将bn层放在卷积层(Conv)和激活层(例如Relu)之间,且卷积层不要使用偏置bias,如图所示:

请添加图片描述

ResNet实现

1.建立模型:
import torch.nn as nn
import torch


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_channel, out_channel, stride=1, downsample=None,
                 groups=1, width_per_group=64):
        super(Bottleneck, self).__init__()

        width = int(out_channel * (width_per_group / 64.)) * groups

        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=width,
                               kernel_size=1, stride=1, bias=False)  # squeeze channels
        self.bn1 = nn.BatchNorm2d(width)
        # -----------------------------------------
        self.conv2 = nn.Conv2d(in_channels=width, out_channels=width, groups=groups,
                               kernel_size=3, stride=stride, bias=False, padding=1)
        self.bn2 = nn.BatchNorm2d(width)
        # -----------------------------------------
        self.conv3 = nn.Conv2d(in_channels=width, out_channels=out_channel*self.expansion,
                               kernel_size=1, stride=1, bias=False)  # unsqueeze channels
        self.bn3 = nn.BatchNorm2d(out_channel*self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self,
                 block,
                 blocks_num,
                 num_classes=1000,
                 include_top=True,
                 groups=1,
                 width_per_group=64):
        super(ResNet, self).__init__()
        self.include_top = include_top
        self.in_channel = 64

        self.groups = groups
        self.width_per_group = width_per_group

        self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,
                               padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, blocks_num[0])
        self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
        self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)
        if self.include_top:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)
            self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def _make_layer(self, block, channel, block_num, stride=1):
        downsample = None
        if stride != 1 or self.in_channel != channel * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(channel * block.expansion))

        layers = []
        layers.append(block(self.in_channel,
                            channel,
                            downsample=downsample,
                            stride=stride,
                            groups=self.groups,
                            width_per_group=self.width_per_group))
        self.in_channel = channel * block.expansion

        for _ in range(1, block_num):
            layers.append(block(self.in_channel,
                                channel,
                                groups=self.groups,
                                width_per_group=self.width_per_group))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        if self.include_top:
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.fc(x)

        return x


def resnet34(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet34-333f7ec4.pth
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)


def resnet50(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet50-19c8e357.pth
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)


def resnet101(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet101-5d3b4d8f.pth
    return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, include_top=include_top)


def resnext50_32x4d(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth
    groups = 32
    width_per_group = 4
    return ResNet(Bottleneck, [3, 4, 6, 3],
                  num_classes=num_classes,
                  include_top=include_top,
                  groups=groups,
                  width_per_group=width_per_group)


def resnext101_32x8d(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth
    groups = 32
    width_per_group = 8
    return ResNet(Bottleneck, [3, 4, 23, 3],
                  num_classes=num_classes,
                  include_top=include_top,
                  groups=groups,
                  width_per_group=width_per_group)

2.训练模型:
import os
import sys
import json

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from tqdm import tqdm

from model import resnet34


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([
                                     transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                                     # transforms.ConvertImageDtype('RGB')
                                     ]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

    train_dataset = datasets.ImageFolder(root='./train',
                                         transform=data_transform["train"])
    train_num = len(train_dataset)
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    json_str = json.dumps(cla_dict, indent=4)
    with open(
            'class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 16
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root='./val',
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)

    print("using {} images for training, {} images for validation.".format(train_num, val_num))
    
    net = resnet34(num_classes=36)
    net.to(device)

    # define loss function
    loss_function = nn.CrossEntropyLoss()

    # construct an optimizer
    params = [p for p in net.parameters() if p.requires_grad]
    optimizer = optim.Adam(params, lr=0.0001)

    epochs = 15
    best_acc = 0.0
    save_path = './resNet34.pth'
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            logits = net(images.to(device))
            loss = loss_function(logits, labels.to(device))
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()

            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)

        # validate
        net.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                # loss = loss_function(outputs, test_labels)
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

                val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1,
                                                           epochs)

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')


if __name__ == '__main__':
    main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/19182.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

java项目-第147期ssm社区生活超市管理系统_(spring+springmvc+mybatis+jsp)_java毕业设计_计算机毕业设计

java项目-第147期ssm社区生活超市管理系统_(springspringmvcmybatisjsp)_java毕业设计_计算机毕业设计 【源码请到资源专栏下载】 今天分享的项目是《ssm社区生活超市管理系统》 该项目分为3个角色,管理员、用户、供应商角色。 用户可以浏览前台商品,进行…

[附源码]java毕业设计软件项目过程管理系统

项目运行 环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术: SSM mybatis Maven Vue 等等组成,B/S模式 M…

牛客小白月赛 61 E 排队

题目大意: n个数,共有n!种排列方式,记Pi(a)表示序列a的第i种排队方式,cnt(Pi(a))表示P(i)的逆序对个数,PLMM想知道这n!种排列方式共有多少对逆序对 给定一个 nnn 个数,在所有排列顺序…

Windows 11 Insider Preview Build 25247.1000(rs_prerelease)更新内容

微软于今日推出了新的Dev预览版25247.1000,引入了电源设置的新能源建议,“帐户”页面现在会在 OneDrive 存储空间不足时显示警告。下面就和小编一起来看看详细的更新内容吧。 更新内容 此版本包括一些新功能,包括能源建议、任务管理器的一些改…

MySQL8.0优化 - 锁 - 从数据操作的类型划分:读锁、写锁

文章目录学习资料锁的不同角度分类锁的分类图如下从数据操作的类型划分:读锁、写锁读锁写锁锁定读MySQL8.0新特性写操作学习资料 【MySQL数据库教程天花板,mysql安装到mysql高级,强!硬!-哔哩哔哩】 【阿里巴巴Java开…

【21-业务开发-基础业务-商品模块-分类管理-商品系统三级分类的新增类别前后端代码实现-商品系统三级分类的更新类别前后端代码实现-之前错误的Bug修正】

一.知识回顾 【0.三高商城系统的专题专栏都帮你整理好了,请点击这里!】 【1-系统架构演进过程】 【2-微服务系统架构需求】 【3-高性能、高并发、高可用的三高商城系统项目介绍】 【4-Linux云服务器上安装Docker】 【5-Docker安装部署MySQL和Redis服务】…

2022 全网最全最新 Java 面试题 - 独家内部教材

怎样才能拿到大厂的 offer,没有掌握绝对的技术,那么就要不断的学习 从疫情破局而出,又在毕业季一路过关斩将,我是如何笑面试官,拿到阿里,腾讯等八家大厂的 offer 的呢,在这里分享我的秘密武器&…

kubernetes(K8S)学习笔记P3:集群 YAML 文件(部署)

集群 YAML 文件(部署)4.集群 YAML 文件(部署)4.1 YAML 文件概述4.2YAML 文件书写格式4.2.1YAML 介绍4.2.2YAML 基本语法4.2.3YAML 支持的数据结构4.3资源清单描述方法4.3.1常用字段4.3.2字段解释4.4快速编写yml-->kubdectl cre…

数据结构由中序序列和后序序列构造二叉树

2022.11.19 由中序序列和后序序列构造二叉树任务描述相关知识编程要求测试说明C/C代码任务描述 本关任务要求采用中序遍历序列和后序遍历序列构造二叉树。 相关知识 给定一棵二叉树的中序遍历序列和后序遍历序列可以构造出这棵二叉树。例如后序序列是DEBFGCA,中序…

MySQL8.0优化 - 锁 - 从对待锁的态度划分:乐观锁、悲观锁

文章目录学习资料锁的不同角度分类锁的分类图如下从对待锁的态度划分:乐观锁、悲观锁悲观锁(Pessimistic Locking)乐观锁(Optimistic Locking)两种锁的适用场景学习资料 【MySQL数据库教程天花板,mysql安装…

Ajax笔记

Ajax笔记资源的请求方式一、概念1、Ajax作用2、jQuery中的Ajax二、$.get()函数的语法$.get()发起不带参数的请求$.get()发起带参数的请求三、$.post()函数的语法$.post()向服务器提交数据<font colorred>四、$.ajax()函数的语法使用$.ajax()发起GET请求使用$.ajax()发起P…

JSP使用

目录 简介 作用 创建 结构 常用脚本 声名脚本 表达式脚本 代码脚本 注释 九大内置对象 四大域对象 out与response.getWriter 静态引入 动态引入 EL表达式 作用 语法 取值顺序 获取指定参数 输出指定对象的数据 运算符 算数运算符 关系比较 逻辑运算符…

【Vue】使用 axios 发送ajax 请求

在 Vue 里面我们如何去发送一些 Ajax(阿贾克斯)请求 从远程的网站上获取一些数据。 假如我们有这样的接口的地址&#xff1a; https://www.xxxx.site 假设它是一个能跨域访问的接口。‍‍‍‍ 如果我们想去在我们的代码里面发这种请求&#xff0c;我该怎么做&#xff1f; 首…

Ubuntu 桌面系统升级

本文介绍 Ubuntu 桌面系统升级的两种方式&#xff0c;通过 UI 或命令行的方式&#xff0c;演示为 20.04 升级为 22.04。并介绍了 windows 的 Linux 子系统 wsl 的升级注意事项。 背景 之前在学习 ROS2 时&#xff0c;安装 ros-humble-desktop 出现依赖错误&#xff1a;无法修正…

[附源码]java毕业设计食材采购平台论文

项目运行 环境配置&#xff1a; Jdk1.8 Tomcat7.0 Mysql HBuilderX&#xff08;Webstorm也行&#xff09; Eclispe&#xff08;IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持&#xff09;。 项目技术&#xff1a; SSM mybatis Maven Vue 等等组成&#xff0c;B/S模式 M…

AI智工作室11.19练习题解

F CodeForces - 710A 训练1(共12题) - Virtual Judge 我的代码 #include<iostream> using namespace std; int main() {char arr[10][10],a;int b,c,k0;cin>>a>>c; ba-a1; // cout<<b<<" "<<c<<endl; for(int i0;i<9…

使用VSCode编辑与编译WSL2下源代码

1. 安装WSL2 2. windows下安装VSCode 3. VSCode安装插件Remote Development 北京时间2019年5月3日&#xff0c;在 PyCon 2019 大会上&#xff0c;微软发布了 VS Code Remote&#xff0c;开启了远程开发的新时代&#xff01;这次发布包含了三款核心的全新插件&#xff0c;它们…

Java文件操作【教你用Java运行微信】

文章目录01 创建文件02 获取文件信息03 目录操作和文件删除04 运行可执行文件01 创建文件 new File(String pathname) //根据路径创建一个File对象&#xff1b;new File(File parent,String child) //根据父目录文件子路径创建&#xff1b;new File(String parent,String chil…

使用mqtt.fx向EMQX服务器发送消息

摘要&#xff1a;本文介绍如何使用mqtt.fx向mqtt服务器&#xff08;EMQX&#xff09;发送消息。顺便介绍一下labview与EMQX连接成功的实现效果。 上一篇文章介绍了如何在ubuntu下安装emqx服务器&#xff0c;以及如何使用mqtt.fx订阅服务器上的一个主题。 ubuntu系统下搭建本地…

为什么ArcGIS添加的TIFF栅格数据是一片纯色

下面来介绍一下今天的正式内容&#xff1a;为什么你添加的tiff栅格数据明明有数据&#xff0c;为什么却在GIS中显示一片颜色。 即使你去拉伸后 他还是显示这样&#xff1f; 那如何才能让他正常显示呢&#xff1f; 逻辑其实是简单的。我们检查数据会发现&#xff0c;这份数据等…