深度学习之残差网络ResNet

news2024/12/24 0:43:57

文章目录

    • 1. 残差网络定义
    • 2. 数学基础====函数类
    • 3. 残差块
    • 4.ResNet模型
    • 5.训练模型
    • 6.小结

1. 残差网络定义

随着我们设计的网络越来越深,深刻理解“新添加的层如何提升神经网络的性能”变得至关重要。更重要的是设计网络的能力。在这种网络中,添加层会使得网络更具表达力,为了取得质的突破,我们需要一些数学基础知识。

2. 数学基础====函数类

在这里插入图片描述

3. 残差块

在这里插入图片描述

import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l


class Residual(nn.Module):  #@save
    def __init__(self, input_channels, num_channels,
                 use_1x1conv=False, strides=1):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, num_channels,
                               kernel_size=3, padding=1, stride=strides)
        self.conv2 = nn.Conv2d(num_channels, num_channels,
                               kernel_size=3, padding=1)
        if use_1x1conv:
            self.conv3 = nn.Conv2d(input_channels, num_channels,
                                   kernel_size=1, stride=strides)
        else:
            self.conv3 = None
        self.bn1 = nn.BatchNorm2d(num_channels)
        self.bn2 = nn.BatchNorm2d(num_channels)

    def forward(self, X):
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        if self.conv3:
            X = self.conv3(X)
        Y += X
        return F.relu(Y)

在这里插入图片描述
在这里插入图片描述

4.ResNet模型

ResNet的前两层跟之前介绍的GoogLeNet中的一样: 在输出通道数为64、步幅为2的
卷积层后,接步幅为2的
的最大汇聚层。 不同之处在于ResNet每个卷积层后增加了批量规范化层。
代码如下:

b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
                   nn.BatchNorm2d(64), nn.ReLU(),
                   nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

GoogLeNet在后面接了4个由Inception块组成的模块。 ResNet则使用4个由残差块组成的模块,每个模块使用若干个同样输出通道数的残差块。 第一个模块的通道数同输入通道数一致。 由于之前已经使用了步幅为2的最大汇聚层,所以无须减小高和宽。 之后的每个模块在第一个残差块里将上一个模块的通道数翻倍,并将高和宽减半。

下面我们来实现这个模块。注意,我们对第一个模块做了特别处理。

def resnet_block(input_channels, num_channels, num_residuals,
                 first_block=False):
    blk = []
    for i in range(num_residuals):
        if i == 0 and not first_block:
            blk.append(Residual(input_channels, num_channels,
                                use_1x1conv=True, strides=2))
        else:
            blk.append(Residual(num_channels, num_channels))
    return blk

接着在ResNet加入所有残差块,这里每个模块使用2个残差块。

b2 = nn.Sequential(*resnet_block(64, 64, 2, first_block=True))
b3 = nn.Sequential(*resnet_block(64, 128, 2))
b4 = nn.Sequential(*resnet_block(128, 256, 2))
b5 = nn.Sequential(*resnet_block(256, 512, 2))

最后,与GoogLeNet一样,在ResNet中加入全局平均汇聚层,以及全连接层输出。

net = nn.Sequential(b1, b2, b3, b4, b5,
                    nn.AdaptiveAvgPool2d((1,1)),
                    nn.Flatten(), nn.Linear(512, 10))

在这里插入图片描述
在这里插入图片描述
图7.6.4 ResNet-18架构
在训练ResNet之前,让我们观察一下ResNet中不同模块的输入形状是如何变化的。 在之前所有架构中,分辨率降低,通道数量增加,直到全局平均汇聚层聚集所有特征。

X = torch.rand(size=(1, 1, 224, 224))
for layer in net:
    X = layer(X)
    print(layer.__class__.__name__,'output shape:\t', X.shape)

自己动手算算怎么得到的输出(n,c,h,w)

Sequential output shape: torch.Size([1, 64, 56, 56]) Sequential
output shape: torch.Size([1, 64, 56, 56]) Sequential output shape:
torch.Size([1, 128, 28, 28]) Sequential output shape:
torch.Size([1, 256, 14, 14]) Sequential output shape:
torch.Size([1, 512, 7, 7]) AdaptiveAvgPool2d output shape:
torch.Size([1, 512, 1, 1]) Flatten output shape: torch.Size([1,
512]) Linear output shape: torch.Size([1, 10])

5.训练模型

同之前一样,我们在Fashion-MNIST数据集上训练ResNet。

lr, num_epochs, batch_size = 0.05, 10, 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())

loss 0.012, train acc 0.997, test acc 0.893
5032.7 examples/sec on cuda:0

在这里插入图片描述

6.小结

学习嵌套函数(nested function)是训练神经网络的理想情况。在深层神经网络中,学习另一层作为恒等映射(identity
function)较容易(尽管这是一个极端情况)。

残差映射可以更容易地学习同一函数,例如将权重层中的参数近似为零。

利用残差块(residual blocks)可以训练出一个有效的深层神经网络:输入可以通过层间的残余连接更快地向前传播。

残差网络(ResNet)对随后的深层神经网络设计产生了深远影响。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2217373.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

单例模式:为何继承无法保证子类的单例特性

这里写目录标题 一、引言二、背景描述三、单例模式的规范边界全局访问点与静态工厂方法代码示例与注意事项 四、单例实现继承遇到的问题五、结论与替代方案结论替代方案特殊的想法🌸🌸源码阶段验证编译阶段验证运行阶段验证总结 一、引言 在软件设计中&a…

实时语音转文字(基于NAudio+Whisper+VOSP+Websocket)

今天花了大半天时间研究一个实时语音转文字的程序,目的还包括能够唤醒服务,并把命令提供给第三方。 由于这方面的材料已经很多,我就只把过程中遇到的和解决方案简单说下。源代码开源在AudioWhisper: 实时语音转文字(基于NAudioWhisperVOSPWe…

基于SSM的个性化商铺系统【附源码】

基于SSM的个性化商铺系统 效果如下: 用户登录界面 app首页界面 商品信息界面 店铺信息界面 用户功能界面 我的订单界面 后台登录界面 管理员功能界面 用户管理界面 商家管理界面 店铺信息管理界面 商家功能界面 个人中心界面 研究背景 研究背景 科学技术日新月异…

Leetcode 每日温度

class Solution {public int[] dailyTemperatures(int[] temperatures) {int n temperatures.length;Stack<Integer> stack new Stack<>();//默认将数组中的所有元素初始化为 0int[] results new int[n];for(int i 0; i < n; i) {while(!stack.isEmpty() &a…

leaflet前端JS实现高德地图POI兴趣点批量分类下载(附源码下载)

前言 leaflet 入门开发系列环境知识点了解&#xff1a; leaflet api文档介绍&#xff0c;详细介绍 leaflet 每个类的函数以及属性等等leaflet 在线例子leaflet 插件&#xff0c;leaflet 的插件库&#xff0c;非常有用 内容概览 leaflet前端JS实现高德地图POI兴趣点批量分类下载…

小猿口算炸鱼脚本

目录 写在前面&#xff1a; 一、关于小猿口算&#xff1a; 二、代码逻辑 1.数字识别 2.答题部分 三、代码分享&#xff1a; 补充&#xff1a;软件包下载 写在前面&#xff1a; 最近小猿口算已经被不少大学生攻占&#xff0c;小学生直呼有挂。原本是以为大学生都打着本…

【Python爬虫】看电影还在用VIP?一个python代码让你实现电影自由!附源码

今日主题 如何用Python解析vip电影。 什么是vip电影&#xff1f; 这些vip电影啊&#xff0c;想要观看的话&#xff0c;必须充值会员&#xff0c;否则没法看。 比如这个&#xff1a; 这些vip电影解析后呢&#xff1f; 不需要会员&#xff0c;不需要登录&#xff0c;可以直接…

Java-类与对象

一、面向对象 在了解类与对象前&#xff0c;我们需要先知道"面向对象"这个词的概念&#xff1a; 在Java语言中&#xff0c;我们的主要思想就是"面向对象"&#xff0c;而在之前我们所学习的C语言中大部分时候的思想是"面向过程"。 那么什么是&…

MySQL-10.DML-添加数据insert

一.DML(INSERT) -- DDL&#xff1a;数据操作语言 -- DML&#xff1a;插入数据 - insert -- 1.为tb_emp表的username&#xff0c;name&#xff0c;gender字段插入值 insert into tb_emp (username,name,gender) values (wuji,无忌,1); -- 这样会报错&#xff0c;因为create_ti…

DS堆的实际应用(10)

文章目录 前言一、堆排序建堆排序 二、TopK问题原理实战创建一个有一万个数的文件读取文件并将前k个数据创建小堆用剩余的N-K个元素依次与堆顶元素来比较将前k个数据打印出来并关闭文件 测试 三、堆的相关习题总结 前言 学完了堆这个数据结构的概念和特性后&#xff0c;我们来看…

限时设计ui

ctrl-------放大缩小 空格-----画面移动 alt------复制 页面<画板<图层 添加交互事件 原型 点击蓝色的圆&#xff0c;从1跳转到2 点击绿色的圆&#xff0c;从2跳转到1

基于SSM+Vue+MySQL的健身房管理系统

系统展示 系统背景 随着人们生活水平的提高和健康意识的增强&#xff0c;越来越多的人选择去健身房锻炼。传统的健身房管理方式往往依赖于纸质记录和人工操作&#xff0c;这种方式不仅效率低下&#xff0c;而且容易出错。为了提高健身房的管理效率和服务质量&#xff0c;开发一…

python项目实战——下载美女图片

python项目实战——下载美女图片 文章目录 python项目实战——下载美女图片完整代码思路整理实现过程使用xpath语法找图片的链接检查链接是否正确下载图片创建文件夹获取一组图片的链接获取页数 获取目录页的链接 完善代码注意事项 完整代码 import requests import re import…

图文检索综述(2):Deep Multimodal Data Fusion

Deep Multimodal Data Fusion 摘要1 引言2 基于编码器-解码器融合2.1 数据级别融合2.2 分层特征融合2.3 决策级别融合 3 基于注意力融合3.1 模态内的自注意力3.2 模态间的交叉注意力3.3 基于transformer的方法 4 基于图神经网络融合4.1 单个模态的表示学习4.2 融合数据的表示学…

【数据结构】宜宾大学-计院-实验三

线性表的应用——实现两多项式的相加 课前准备&#xff1a;实验学时&#xff1a;2实验目的&#xff1a;实验内容&#xff1a;实验结果&#xff1a;实验报告:&#xff08;及时撰写实验报告&#xff09;实验测试结果&#xff1a;代码实现&#xff1a;&#xff08;C/C&#xff09;…

Java 小游戏《超级马里奥》

文章目录 一、效果展示二、代码编写1. 素材准备2. 创建窗口类3. 创建常量类4. 创建动作类5. 创建关卡类6. 创建障碍物类7. 创建马里奥类8. 编写程序入口 一、效果展示 二、代码编写 1. 素材准备 首先创建一个基本的 java 项目&#xff0c;并将本游戏需要用到的图片素材 image…

华为 HCIP-Datacom H12-821 题库 (38)

&#x1f423;博客最下方微信公众号回复题库,领取题库和教学资源 &#x1f424;诚挚欢迎IT交流有兴趣的公众号回复交流群 &#x1f998;公众号会持续更新网络小知识&#x1f63c; 1.请对 2001:0DB8:0000:C030:0000:0000:09A0:CDEF 地址进行压缩。&#xff08; &#xff09;&…

阻塞I/O与非阻塞I/O

目录 一、基本概念 二、阻塞I/O的实现机制 —— 等待队列 一、基本概念 阻塞&#xff1a;在执行单元进行操作时&#xff0c;如果不能获得申请的资源&#xff0c;则执行单元挂起直至资源可用后再进行操作。 非阻塞&#xff1a;在执行单元进行操作时&#xff0c;如果不能获得申…

UDP反射放大攻击防范手册

UDP反射放大攻击是一种极具破坏力的恶意攻击手段。 一、UDP反射放大攻击的原理 UDP反射放大攻击主要利用了UDP协议的特性。攻击者会向互联网上大量的开放UDP服务的服务器发送伪造的请求数据包。这些请求数据包的源IP地址被篡改为目标受害者的IP地址。当服务器收到这些请求后&…

爬虫实战(黑马论坛)

1.定位爬取位置内容&#xff1a; # -*- coding: utf-8 -*- import requests import time import re# 请求的 URL 和头信息 url https://bbs.itheima.com/forum-425-1.html headers {user-agent: Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like…