第J2周:ResNet50V2算法实战与解析

news2024/9/27 9:21:36
  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍦 参考文章:365天深度学习训练营-第J2周:ResNet50V2算法实战与解析
  • 🍖 原作者:K同学啊|接辅导、项目定制

目录

    • 一、论文解读
      • 1. ResNetV2结构与ResNet结构对比
      • 2. 关于残差结构的不同尝试
      • 3. 关于激活的尝试
    • 二、模型复现
      • 1. Residual Block
      • 2. 堆叠Residual Block
      • 3. ResNet50V2架构复现
      • 4. 在cifar10上训练

📌 本周任务:
●1.请根据本文 TensorFlow 代码,编写出相应的 Pytorch 代码(建议使用上周的数据测试一下模型是否构建正确)
●2.了解ResNetV2与ResNetV的区别
●3.改进思路是否可以迁移到其他地方呢(自由探索)

一、论文解读

1. ResNetV2结构与ResNet结构对比

在这里插入图片描述

实线表示测试误差(右边的y轴),虚线表示训练损失(左边的y轴),Iterations 表示迭代次数

🧲 改进点:(a)original 表示原始的 ResNet 的残差结构,(b)proposed 表示新的 ResNet 的残差结构。主要差别就是(a)结构先卷积后进行 BN 和激活函数计算,最后执行 addition 后再进行ReLU 计算; (b)结构先进行 BN 和激活函数计算后卷积,把 addition 后的 ReLU 计算放到了残差结构内部。

📌 改进结果:作者使用这两种不同的结构在 CIFAR-10 数据集上做测试,模型用的是 1001层的 ResNet 模型。从图中结果我们可以看出,(b)proposed 的测试集错误率明显更低一些,达到了 4.92%的错误率,(a)original 的测试集错误率是 7.61%。

2. 关于残差结构的不同尝试

在这里插入图片描述
(b-f)中的快捷连接被不同的组件阻碍。为了简化插图,我们不显示BN层,这里所有单位均采用权值层之后的BN层。图中(a-f)都是作者对残差结构的 shortcut 部分进行的不同尝试 ,作者对不同 shortcut 结构的尝试结果如下表所示 。

在这里插入图片描述

使用ResNet-110在CIFAR-10测试集上的分类错误,对所有残差单元应用了不同类型的shortcut connections。当测试误差高于20%时,标注为“fail”。

作者用不同 shortcut 结构的 ResNet-110 在 CIFAR-10 数据集上做测试,发现最原始的(a)original 结构是最好的,也就是 identity mapping 恒等映射是最好的。

3. 关于激活的尝试

在这里插入图片描述
在这里插入图片描述

使用不同激活函数的CIFAR-10测试集上的分类误差(%)。

最好的结果是(e)full pre-activation,其次到(a)original。

二、模型复现

import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision
from torchvision import transforms, datasets
import os, PIL, pathlib, warnings
from torchsummary import summary


#忽略警告信息
warnings.filterwarnings("ignore")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

1. Residual Block

class Block2(nn.Module):
    def __init__(self, in_channel, filters, kernel_size=3, stride=1, conv_shortcut=False):
        super(Block2, self).__init__()
        self.preact = nn.Sequential(
            nn.BatchNorm2d(in_channel),
            nn.ReLU(True)
        )

        self.shortcut = conv_shortcut
        if self.shortcut:
            self.short = nn.Conv2d(in_channel, 4*filters, 1, stride=stride, padding=0, bias=False)
        elif stride>1:
            self.short = nn.MaxPool2d(kernel_size=1, stride=stride, padding=0)
        else:
            self.short = nn.Identity()

        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channel, filters, 1, stride=1, bias=False),
            nn.BatchNorm2d(filters),
            nn.ReLU(True)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(filters, filters, kernel_size, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(filters),
            nn.ReLU(True)
        )
        self.conv3 = nn.Conv2d(filters, 4*filters, 1, stride=1, bias=False)

    def forward(self, x):
        x1 = self.preact(x)
        if self.shortcut:
            x2 = self.short(x1)
        else:
            x2 = self.short(x)
        x1 = self.conv1(x1)
        x1 = self.conv2(x1)
        x1 = self.conv3(x1)
        x = x1 + x2
        return x

2. 堆叠Residual Block

class Stack2(nn.Module):
    def __init__(self, in_channel, filters, blocks, stride=2):
        super(Stack2, self).__init__()
        self.conv = nn.Sequential()
        self.conv.add_module(str(0), Block2(in_channel, filters, conv_shortcut=True))
        for i in range(1, blocks-1):
            self.conv.add_module(str(i), Block2(4*filters, filters))
        self.conv.add_module(str(blocks-1), Block2(4*filters, filters, stride=stride))

    def forward(self, x):
        x = self.conv(x)
        return x

3. ResNet50V2架构复现

在这里插入图片描述

class ResNet50V2(nn.Module):
    def __init__(self,
                 include_top=True,  # 是否包含位于网络顶部的全链接层
                 preact=True,  # 是否使用预激活
                 use_bias=True,  # 是否对卷积层使用偏置
                 input_shape=[224, 224, 3],
                 classes=1000,
                 pooling=None):  # 用于分类图像的可选类数
        super(ResNet50V2, self).__init__()

        self.conv1 = nn.Sequential()
        self.conv1.add_module('conv', nn.Conv2d(3, 64, 7, stride=2, padding=3, bias=use_bias, padding_mode='zeros'))
        if not preact:
            self.conv1.add_module('bn', nn.BatchNorm2d(64))
            self.conv1.add_module('relu', nn.ReLU())
        self.conv1.add_module('max_pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

        self.conv2 = Stack2(64, 64, 3)
        self.conv3 = Stack2(256, 128, 4)
        self.conv4 = Stack2(512, 256, 6)
        self.conv5 = Stack2(1024, 512, 3, stride=1)

        self.post = nn.Sequential()
        if preact:
            self.post.add_module('bn', nn.BatchNorm2d(2048))
            self.post.add_module('relu', nn.ReLU())
        if include_top:
            self.post.add_module('avg_pool', nn.AdaptiveAvgPool2d((1, 1)))
            self.post.add_module('flatten', nn.Flatten())
            self.post.add_module('fc', nn.Linear(2048, classes))
        else:
            if pooling=='avg':
                self.post.add_module('avg_pool', nn.AdaptiveAvgPool2d((1, 1)))
            elif pooling=='max':
                self.post.add_module('max_pool', nn.AdaptiveMaxPool2d((1, 1)))

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.post(x)
        return x


model = ResNet50V2().to(device)
summary(model, (3, 224, 224))

4. 在cifar10上训练

import torchvision.datasets as datasets
import torchvision.transforms as transforms

# 定义预处理转换
transform = transforms.Compose([
    transforms.ToTensor(),  # 将 PIL 图像转换为张量
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 归一化
])

# 加载 CIFAR-10 数据集
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# 创建数据加载器
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)
def train(dataloader,model,loss_fn,optimizer):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    train_loss,train_acc = 0,0
    for x,y in dataloader:
        x,y = x.to(device),y.to(device)
        pred = model(x)
        loss = loss_fn(pred,y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        train_acc += (pred.argmax(1)==y).type(torch.float).sum().item()
    train_loss /= num_batches
    train_acc /= size
    return train_loss,train_acc
def test(dataloader,model,loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss,test_acc = 0,0
    with torch.no_grad():
        for x,y in dataloader:
            x,y = x.to(device),y.to(device)
            pred = model(x)
            loss = loss_fn(pred,y)
            test_loss += loss.item()
            test_acc += (pred.argmax(1)==y).type(torch.float).sum().item()
    test_loss /= num_batches
    test_acc /= size
    return test_loss,test_acc

参考文章:http://t.csdn.cn/VhPbf

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/805356.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【需求响应】一种新的需求响应机制DR-VCG研究

目录 1 主要内容 2 部分代码 3 程序结果 4 程序链接 1 主要内容 该程序对应文章《Contract Design for Energy Demand Response》,电力系统需求响应(DR)用来调节用户对电能的需求,即在预测的需求高于电能供应时,希…

VS Code调试Darknet

一、安装插件 二、连接服务器 三、调试darknet工程 {"version": "2.0.0","options": {"cwd": "${workspaceFolder}"},"tasks": [{"label": "clean","type": "shell",&qu…

数据结构之动态顺序表(附带完整程序)

🎈基本概念 🌈一.线性表、顺序表的定义 ☀️(1)线性表: 是n个具有相同特性的数据元素的有限序列。线性表在逻辑上是线性结构,但在物理上存储时,通常以数组和链式结构的形式存储。 ☀️&…

C# 关于使用newlife包将webapi接口寄宿于一个控制台程序、winform程序、wpf程序运行

C# 关于使用newlife包将webapi接口寄宿于一个控制台程序、winform程序、wpf程序运行 安装newlife包 Program的Main()函数源码 using ConsoleApp3; using NewLife.Log;var server new NewLife.Http.HttpServer {Port 8080,Log XTrace.Log,SessionLog XTrace.Log }; serv…

hcip——ospf综合

要求 1. 搭建toop 2.地址规划 协议范围路由器地址 RIP 172.16.0.0 17 R12 loop0:172.16.0.0 24 loop1:172.16.1.0 24 OSPF 172.16.128.0 17 area1 172.16.144.0 20 R1 g0:172.16.144.1 24 loop0:172.16.145.1 24 R2 g0:172.16.144.2 24 loop:172…

iOS - Apple开发者账户添加新测试设备

获取UUID 首先将设备连接XCode,打开Window -> Devices and Simulators,通过下方位置查看 之后登录(苹果开发者网站)[https://developer.apple.com/account/] ,点击设备 点击加号添加新设备 填写信息之后点击Continue,并一路继续…

Golang Devops项目开发(1)

1.1 GO语言基础 1 初识Go语言 1.1.1 开发环境搭建 参考文档:《Windows Go语言环境搭建》 1.2.1 Go语言特性-垃圾回收 a. 内存自动回收,再也不需要开发人员管理内存 b. 开发人员专注业务实现,降低了心智负担 c. 只需要new分配内存,…

Android系统服务之AMS

目录 概述 重点和难点问题 启动方式 main入口: run方法: BootstrapSevices 小结: 与其他线程的通信原理 参考文档: 概述 AMS是Android系统主要负责四大组件的启动,切换,调度以及应用程序进程管理和调度等工…

watch避坑,使用computed进行处理数据

业务场景:在vue中监听el-input 中的字数有没有超过60,如果超过60字时将60后面的字变为 “>>” 符号,以此实现预览苹果手机推送摘要场景。 错误:开始的逻辑是使用watch监听,检查length超过60直接 加上符号&#x…

选好NAS网络储存解决方案,是安全储存的关键

随着网络信息的发展,NAS也越来越受到企业的关注,NAS网络存储除了提供简单的存储服务外,还可以提供更好的数据安全性、更方便的文件共享方式。但市面上的产品种类繁多,我们该如何选择合适的产品,通过企业云盘&#xff0…

spring5源码篇(12)——spring-mvc请求流程

spring-framework 版本:v5.3.19 文章目录 一、请求流程1、处理器映射器1.1、 RequestMappingHandlerMapping1.2、获取对应的映射方法1.3、添加拦截器 2、获取合适的处理器适配器3、通过处理器适配器执行处理器方法3.1、拦截器的前置后置3.2、处理器的执行3.2.1 参数…

Unity 性能优化二:内存问题

目录 策略导致的内存问题 GFX内存 纹理资源 压缩格式 Mipmap 网格资源 Read/Write 顶点数据 骨骼 静态合批 Shader资源 Reserved Memory RenderTexture 动画资源 音频资源 字体资源 粒子系统资源 Mono堆内存 策略导致的内存问题 1. Assetbundle 打包的时候…

【C++】C++ STL标准模板库知识点总结(秋招篇)

文章目录 前言STL的六大组件是?容器(container) 算法(algorithm) 迭代器(iterator) 三者的关系?容器分为几种?分别有哪些?关联性容器和非关联性容器有什么区别?Vector容器是怎么调整大小的?(内存…

VirtualEnv 20.24.0 发布

导读VirtualEnv 20.24.0 现已发布,VirtualEnv 用于在一台机器上创建多个独立的 Python 运行环境,可隔离项目之间的第三方包依赖,为部署应用提供方便,把开发环境的虚拟环境打包到生产环境即可,不需要在服务器上再折腾一…

RAM明明断电会丢失数据,为什么初始化的全局变量存储在RAM?详细分析程序的存储

前言 (1)之前因为一个字符指针和字符数组指针引发的bug,折磨了我一个下午才发现问题。之后我就打算研究一下系统是如何发现野指针乱访问问题。后面就一直深入到微机系统中的内存管理了。 (2)这些其实都是基础知识&…

SpringBoot房屋租赁系统【附ppt|万字文档(LW)和搭建文档】

主要功能 前台登录: ①首页:公告信息、房屋信息展示、查看更多等 ②房屋信息、房屋类型、我要当房主、公告信息、留言反馈等 ③个人中心:可以查看自己的信息、更新图片、更新信息、退出登录、我的收藏 后台登录: ①首页、个人中心…

Day 69-70:矩阵分解

代码: package dl;import java.io.*; import java.util.Random;/** Matrix factorization for recommender systems.*/public class MatrixFactorization {/*** Used to generate random numbers.*/Random rand new Random();/*** Number of users.*/int numUsers…

使用贝叶斯算法完成文档分类问题

贝叶斯原理 贝叶斯原理(Bayes theorem)是一种用于计算条件概率的数学公式。它是以18世纪英国数学家托马斯贝叶斯(Thomas Bayes)的名字命名的。贝叶斯原理表达了在已知某个事件发生的情况下,另一个事件发生的概率。具体…

【Golang系统开发】搜索引擎(1) 如何快速判断网页是否已经被爬取

文章目录 1. 写在前面2. 数组存储3. 位图存储3.1 位图简介3.2 链表法3.3 开放寻址法 1. 写在前面 在实际工作中,我们经常需要判断一个对象是否存在,比如判断用户注册登陆时候,需要判断用户是否存在,再比如搜索引擎中的爬虫&#x…

大数据面试题之Elasticsearch:每日三题(七)

大数据面试题之Elasticsearch:每日三题 1.Elasticsearch索引文档的流程?2.Elasticsearch更新和删除文档的流程?3.Elasticsearch搜索的流程? 1.Elasticsearch索引文档的流程? 协调节点默认使用文档ID参与计算(也支持通过routing)&a…