pytorch深度学习基础(十一)——常用结构化CNN模型构建

news2025/1/12 4:08:53

结构化CNN模型构建与测试

  • 前言
  • GoogLeNet
    • 结构
    • Inception块
    • 模型构建
  • resNet18
    • 模型结构
    • 残差块
    • 模型构建
  • denseNet
    • 模型结构
    • DenseBlock
    • transition_block
    • 模型构建
  • 结尾

前言

在本专栏的上一篇博客中我们介绍了常用的线性模型,在本文中我们将介绍GoogleNet、resNet、denseNet这类结构化的模型的构建方式。

GoogLeNet

结构

整体的结构似乎有些吓人,但其实他也是用了块的思维,仔细观察可以发现,他中间一段很多层的结构都是相似的
在这里插入图片描述

Inception块

这个块就是其中重复的块,这个块分成了四个分支:1x1卷积、1x1卷积+3x3卷积、1x1卷积+5x5卷积、3x3卷积+1x1卷积,最后将这四个分支通道合并

class Inception(nn.Module):
    def __init__(self, in_channels, c1, c2, c3, c4, **kwargs):
        super(Inception, self).__init__(**kwargs)
        self.p1_1 = nn.Conv2d(in_channels, c1, kernel_size=1)
        
        self.p2_1 = nn.Conv2d(in_channels, c2[0], kernel_size=1)
        self.p2_2 = nn.Conv2d(c2[0], c2[1], kernel_size=3, padding=1)
        
        self.p3_1 = nn.Conv2d(in_channels, c3[0], kernel_size=1)
        self.p3_2 = nn.Conv2d(c3[0], c3[1], kernel_size=5, padding=2)
        
        self.p4_1 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        self.p4_2 = nn.Conv2d(in_channels, c4, kernel_size=1)
        
    def forward(self, x):
        p1 = F.relu(self.p1_1(x))
        p2 = F.relu(self.p2_2(F.relu(self.p2_1(x))))
        p3 = F.relu(self.p3_2(F.relu(self.p3_1(x))))
        p4 = F.relu(self.p4_2(self.p4_1(x)))
        
        return torch.cat((p1, p2, p3, p4), dim=1)

模型构建

b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
                   nn.ReLU(),
                   nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
                  )
b2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1),
                   nn.ReLU(),
                   nn.Conv2d(64, 192, kernel_size=3, padding=1),
                   nn.ReLU(),
                   nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
                  )

b3 = nn.Sequential(Inception(192, 64, (96, 128), (16, 32), 32),
                   Inception(256, 128, (128, 192), (32, 96), 64),
                   nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
                  )

b4 = nn.Sequential(Inception(480, 192, (96, 208), (16, 48), 64),
                   Inception(512, 160, (112, 224), (24, 64), 64),
                   Inception(512, 128, (128, 256), (24, 64), 64),
                   Inception(512, 112, (144, 288), (32, 64), 64),
                   Inception(528, 256, (160, 320), (32, 128), 128),
                   nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
                  )

b5 = nn.Sequential(Inception(832, 256, (160, 320), (32, 128), 128),
                   Inception(832, 384, (192, 384), (48, 128), 128),
                   nn.AdaptiveAvgPool2d((1, 1)),
                   nn.Flatten()
                  )

net = nn.Sequential(b1, b2, b3, b4, b5, nn.Linear(1024, 10))

resNet18

模型结构

这个模型似乎要更加简洁一些,因为这里只有两个分支,但是他有两种分支方式,一种是卷积+残差,另外一种是卷积+经过1x1卷积处理过的残差
在这里插入图片描述

残差块

class Residual(nn.Module):
    def __init__(self, input_channels, num_channels, use_1x1conv=False, strides=1):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, num_channels, kernel_size=3, padding=1, stride=strides)
        self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1)
        if use_1x1conv:
            self.conv3 = nn.Conv2d(input_channels, num_channels, kernel_size=1, stride=strides)
            
        else:
            self.conv3 = None
            
        self.bn1 = nn.BatchNorm2d(num_channels)
        self.bn2 = nn.BatchNorm2d(num_channels)
        
    def forward(self, X):
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        if self.conv3:
            X = self.conv3(X)
            
        Y += X
        return F.relu(Y)

构建时依次使用两种分支方式

def resnet_block(input_channels, num_channels, num_residuals, first_block=False):
    blk = []
    for i in range(num_residuals):
        if i == 0 and not first_block:
            blk.append(Residual(input_channels, num_channels, use_1x1conv=True, strides=2))
        else:
            blk.append(Residual(num_channels, num_channels))
            
    return blk

模型构建

b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
                   nn.BatchNorm2d(64), nn.ReLU(),
                   nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
                  )
b2 = nn.Sequential(*resnet_block(64, 64, 2, first_block=True))
b3 = nn.Sequential(*resnet_block(64, 128, 2))
b4 = nn.Sequential(*resnet_block(128, 256, 2))
b5 = nn.Sequential(*resnet_block(256, 512, 2))

net = nn.Sequential(b1, b2, b3, b4, b5, 
                    nn.AdaptiveAvgPool2d((1, 1)),
                    nn.Flatten(),
                    nn.Linear(512, 10)
                   )

denseNet

模型结构

这个模型的主要思路就是把每一次的输出和输入合并起来,同时作为下一层的输入,具体的细节还要结合代码解释
在这里插入图片描述

DenseBlock

我们看下面的代码,注意init中的循环以及forward时对于X的处理。我们可以发现,每经过一个conv_block都会将conv_block的输出并入到输入,以此作为下一层的输入

def conv_block(input_channels, num_channels):
    return nn.Sequential(nn.BatchNorm2d(input_channels), nn.ReLU(),
                         nn.Conv2d(input_channels, num_channels, kernel_size=3, padding=1)
                        )
                        
class DenseBlock(nn.Module):
    def __init__(self, num_convs, input_channels, num_channels):
        super(DenseBlock, self).__init__()
        layer = []
        for i in range(num_convs):
            layer.append(conv_block(
                num_channels * i + input_channels, num_channels
            ))
        self.net = nn.Sequential(*layer)
        
    def forward(self, X):
        for blk in self.net:
            Y = blk(X)
            X = torch.cat((X, Y), dim=1)
        
        return X

transition_block

这个块的主要作用是减少通道,因为在前面的块中,通道数会持续的增长,考虑到计算量,需要在中间加入减少通道的块

def transition_block(input_channels, num_channels):
    return nn.Sequential(nn.BatchNorm2d(input_channels), nn.ReLU(),
                         nn.Conv2d(input_channels, num_channels, kernel_size=1),
                         nn.AvgPool2d(kernel_size=2, stride=2)
                        )

模型构建

b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
                   nn.BatchNorm2d(64), nn.ReLU(),
                   nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
                  )
                  
num_channels, growth_rate = 64, 32
num_convs_in_dense_blocks = [4, 4, 4, 4]
blks = []
for i, num_convs in enumerate(num_convs_in_dense_blocks):
    blks.append(DenseBlock(num_convs, num_channels, growth_rate))
    num_channels += num_convs * growth_rate
    if i != len(num_convs_in_dense_blocks) - 1:
        blks.append(transition_block(num_channels, num_channels // 2))
        num_channels = num_channels // 2

net = nn.Sequential(b1, *blks,
                    nn.BatchNorm2d(num_channels), nn.ReLU(),
                    nn.AdaptiveAvgPool2d((1, 1)),
                    nn.Flatten(),
                    nn.Linear(num_channels, 10)
                   )

结尾

我们到现在模型就已经构建好了,测试的过程可以参照本专栏的上一篇博客
pytorch深度学习基础(十)——常用线性CNN模型的结构与训练

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/179279.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

APT之木马静态免杀

前言 这篇文章主要是记录手动编写代码进行木马免杀,使用工具也可以免杀,只不过太脚本小子了,而且工具的特征也容易被杀软抓到,指不定哪天就用不了了,所以要学一下手动去免杀木马,也方便以后开发一个只属于…

blender导入骨骼动画方法[psa动作]

先导入女性的psk文件 然后调整缩放大小和人物一样,包括角度朝向. ctrla应用所有改变 然后选择psk文件以及其他人物模型的全部 ,然后 在Layout-物体-父级 -附带空顶相点组 image.png之后会发现所有人物多了修改器,点击其中一个修改器 点添加修改器 -数据传递 勾选顶点数据-选择顶…

人员动作行为AI分析系统 yolov5

人员动作行为AI分析系统通过pythonyolo系列网络学习模型,对现场画面人员行为进行实时分析监测,自动识别出人的各种异常行为动作,立即抓拍存档预警同步回传给后台。 我们使用YOLO算法进行对象检测。YOLO是一个聪明的卷积神经网络(CNN)&#xf…

带滤波器的PID控制仿真-1

采用低通滤波器可有效地滤掉噪声信号,在控制系统的设计中是一种常用的方法。基于低通滤波器的信号处理实例设低通滤波器为:采样时间为1ms,输入信号为带有高频正弦噪声( 100Hz)的低频(0.2Hz)正弦信号。采用低…

离散数学与组合数学-05树

文章目录离散数学与组合数学-05树5.1 认识树5.1.1 树的模型5.1.2 树的应用5.2 无向树5.2.1 定义5.2.2 树的性质5.2.3 性质应用5.3 生成树5.3.1 引入5.3.2 定义5.3.3 算法5.3.4 应用5.4 最小生成树5.4.1 引入5.4.2 定义5.4.3 算法5.5 根树5.5.1 根数定义5.5.2 倒置法5.5.3 树的家…

【编程入门】开源记事本(SwiftUI版)

背景 前面已输出多个系列: 《十余种编程语言做个计算器》 《十余种编程语言写2048小游戏》 《17种编程语言10种排序算法》 《十余种编程语言写博客系统》 《十余种编程语言写云笔记》 本系列对比云笔记,将更为简化,去掉了网络调用&#xff0…

C++模板进阶

这篇文章是对模板初阶的一些补充,让大家在进行深一层的理解。 文章目录1. 非类型模板参数2. 模板的特化2.1 概念2.2 函数模板特化2.3 类模板特化2.3.1 全特化2.3.2 偏特化2.4 类模板特化应用示例3 模板分离编译3.1 什么是分离编译3.2 模板的分离编译3.3 解决方法4.…

【各种**问题系列】什么是 LTS 长期支持

目录 🍁 什么是长期支持(LTS)版本? 🍂 LTS 版本的优点: 🍁 什么是 Ubuntu LTS? 🍂 Ubuntu LTS 软件更新包括什么? 在 Linux 的世界里,特别是谈…

【Java开发】Spring Cloud 08 :链路追踪

任何一个架构难免会出现bug,微服务相比于单体架构日志查询更为困难,因此spring cloud推出了Sleuth等组件的链路追踪技术来实现报错信息的定位及查询。项目源码:尹煜 / coupon-yinyu GitCode1 调用链追踪我们可以想象这样一个场景&#xff0c…

单一数字评估指标、迁移学习、多任务学习、端到端的深度学习

目录1.单一数字评估指标(a single number evaluation metric)有时候要比较那个分类器更好,或者哪个模型更好,有很多指标,很难抉择,这个时候就需要设置一个单一数字评估指标。例1:比较A,B两个分类器的性能&a…

Android MVVM的实现

Android MVVM的实现 前言: 在我们写一些项目的时候,通常会对一些常用的一些常用功能进行抽象封装,简单例子:比如BaseActivity,BaseFragment等等…一般这些Base会去承载一些比如标题栏,主题之类的工作&…

提权漏洞和域渗透历史漏洞整理

Windows提权在线辅助工具 https://i.hacking8.com/tiquan/🌴Kernel privilege escalation vulnerability collection, with compilation environment, demo GIF map, vulnerability details, executable file (提权漏洞合集) https://github.com/Ascotbe/Kernelhu…

恶意代码分析实战 13 反调试技术

13.1 Lab16-01 首先,将可执行文件拖入IDA中。 我们可以看到有三处都调用了sub_401000函数,并且代码都在哪里停止执行。由于没有一条线从这些方框中引出,这就意味着函数可能终止了程序。 右侧每一个大框中都包含一个检查,这个检查…

Makefile学习②:Makefile基本语法

Makefile学习②:Makefile基本语法 Makefile基本语法 目标: 依赖 (Tab)命令 目标:一般是指要编译的目标,也可以是一个动作 依赖:指执行当前目标所要依赖的先项,包括其他目标&#xf…

neural collaborative filtering 阅读笔记

本文主要介绍了一种一种基于神经网络的技术,来解决在含有隐形反馈的基础上进行推荐的关键问题————协同过滤。 2.1 Learning from Implicit Data yui1,(ifinteraction(useru,itemi)isobserved)y_{ui} 1,(if interaction (user u, item i) is observed)yui​1,(…

还在为ElementUI的原生校验方式苦恼吗,快用享受element-ui-verify插件的快乐吧(待续)

element-ui-verify 本文章意在介绍element-ui-verify插件使用&#xff0c;以及对比elementUI原生校验方式&#xff0c;突显该插件用少量代码也能实现原生的校验效果甚至更好。 1.先观察一个示例 <template><d2-container><el-form :model"ruleForm&qu…

二叉树超级经典OJ题

目录1.根据二叉树创建字符串2.二叉树的层序遍历3.二叉树的层序遍历II4.二叉树的最近公共祖先5.二叉搜索树与双向链表6.从前序与中序遍历序列构造二叉树1.根据二叉树创建字符串 根据二叉树创建字符串 给你二叉树的根节点root&#xff0c;请你采用前序遍历的方式&#xff0c;将二…

编码器M法测速仿真(Simulink)

编码器M法和T法测速的详细讲解可以参看下面的文章链接,这里不再赘述,这里主要介绍Simulink里建模仿真,带大家从另一个角度理解编码器测速原理。 PLC通过编码器反馈值计算速度的推荐做法(算法解析+ST代码)_RXXW_Dor的博客-CSDN博客_编码器计算速度程序实例PLC如何测量采集编…

Power BI中类似Vlookup的查询筛选功能如何实现

一、问题描述 在Excel中有一个非常经典的函数Vlookup&#xff0c;可以通过首列查找&#xff0c;返回相对应的其他列的值。这种功能&#xff0c;在Power BI中没有Vlookup函数&#xff0c;那么该如何实现这一功能呢&#xff1f;下面通过一个实例做分析演示。 二、数据源 已知某…

厚积薄发打卡Day114:Debug设计模式:设计原则(二)<接口隔离原则、迪米特法则>

厚积薄发打卡Day114&#xff1a;Debug设计模式&#xff1a;设计原则&#xff08;二&#xff09;<接口隔离原则、迪米特法则> 接口隔离原则 定义 用多个专门的接口&#xff0c;而不使用单一的总接口&#xff0c;客户端不应该依赖它不需要的接口 一个类对一个类的依赖应…