SqueezeNet 一维,二维网络复现 pytorch 小白易懂版

news2025/1/20 13:26:57

SqueezeNet

时隔一年我又开始复现神经网络的经典模型,这次主要复的是轻量级网络全家桶,轻量级神经网络旨在使用更小的参数量,无限的接近大模型的准确率,降低处理时间和运算量,这次要复现的是轻量级网络的非常经典的一个模型SqueezeNet,它由美国加州大学伯克利分校的研究团队开发,并于2016年发布。


文章链接: https://arxiv.org/pdf/1602.07360.pdf?source=post_page---------------------------

看懂这篇文章需要的基础知识

  1. 了解python语法基础
  2. 了解深度学习基本原理
  3. 知道什么是卷积层池化层激活函数层softmanx层
  4. 熟悉卷积层池化层需要的参数
  5. 需要了解pytorch模型的基本构成

我记得去年的这个时候,好像GPT还没被特别广泛的使用,还没到一键就能直接输出写好的模型的这一个步骤,那为什么还要看博客这类的文章呢,应该是因为毕竟GPT他还是靠着已有的资料进行读取,他不能图文并茂的给你写一个一定好用的大型模型,不然直接把论文甩给他让他复现就好了,所以还是打算写一下,然后简单画点图然后给之后的学弟学妹们留一点遗产。

SqueezeNet 的模型结构

下面是原论文给出的模型结构
在这里插入图片描述
原文中给出了三种模型,分别是第一个基础模型,以及第二个和第三个带有残差分支的模型,其中卷积池化分支我们都有了解,这里新的东西就是这个Fire层,那就先从这个Fire层开始介绍

Fire层

作者说他的SqueezeNet网络为什么可以有更小的参数量,主要由于用了下面这个叫Fire层的东西,Fire层分两部分

  • 一部分是Squeeze层其实就是卷积核大小为1×1的一个卷积层
  • 另一部分呢是expend层他实际上是卷积核大小为1×1和卷积核大小为卷积层和3×3输出的一个拼接

下面是原论文中对Fire模型的详细描述
在这里插入图片描述
在这里插入图片描述
那如果要实现一维的那就把3×3的卷积核改成1×3的
加上激活函数,其实现代码应该是这样的,接下来详细介绍里面的参数。

  • in_channels 指Fire模块的输入通道数,也是就每个Fire模块的squeeze卷积层的输入通道数
  • squeeze_channels 指的是squeeze层的输出通道数
  • expand1x1_channels 指的是expand层中卷积核大小为1×1的卷积层的输出通道数
  • expand1x3_channels 指的是expand层中卷积核大小为1×2的卷积层的输出通道数
class FireModule(torch.nn.Module):
    def __init__(self, in_channels, squeeze_channels, expand1x1_channels, expand1x3_channels):
        super(FireModule, self).__init__()
        self.squeeze = torch.nn.Conv1d(in_channels, squeeze_channels, kernel_size=1)
        self.relu = torch.nn.ReLU(inplace=True)
        self.expand1x1 = torch.nn.Conv1d(squeeze_channels, expand1x1_channels, kernel_size=1)
        self.expand1x3 = torch.nn.Conv1d(squeeze_channels, expand1x3_channels, kernel_size=3, padding=1)

    def forward(self, x):
        x = self.squeeze(x)
        x = self.relu(x)
        out1x1 = self.expand1x1(x)
        out1x3 = self.expand1x3(x)
        out = torch.cat([out1x1, out1x3], dim=1)
        return self.relu(out)

基础知识补充: torch.cat 将向量在某一个维度上拼接

import torch
# Create two tensors
out1x1 = torch.tensor([[1, 2, 3], [1, 2, 3]])
out1x3 = torch.tensor([[4, 5, 6], [7, 8, 9]])

# Concatenate the tensors along the second dimension (dim=1)
out = torch.cat([out1x1, out1x3], dim=1)
print(out)
# tensor([[1, 2, 3, 4, 5, 6],
#         [1, 2, 3, 7, 8, 9]])
out = torch.cat([out1x1, out1x3], dim=0)
print(out)
# tensor([[1, 2, 3],
#         [1, 2, 3],
#         [4, 5, 6],
#         [7, 8, 9]])

那有了Fire层模块之后就可以开始搭建我们的模型,那在搭建的过程中,各个层的参数如何设置呢,原文中给了如下表
在这里插入图片描述

  • 第一列Layer name/type 指的是层的名称和类型
  • 第二列Output size 指的是输出尺寸
  • 第三列是filter size/stride (if not a fire layer)滤波器(卷积核/池化核)的大小(不包含Fire层)
  • 第四列depth 卷积层的深度,可以无视掉,没什么用
  • 第五-第七 给的就是Fire 层的参数了

再后面的是稀疏性字节大小还有修剪前后的参数大小,这部分不用过于关注,可能要多提一下的就是这个稀疏性sparsity,他指的是卷积层里选择多少参数一直为0,但是并没有详细说具体是怎么实现的,然后我也去搜了一下,需要用一些正则化的东西才可以,这个问题我打算再详细理解一下,暂时我们都默认稀疏性是100,不再为了稀疏性降低参数量实现额外复杂的工作.

根据参数和结构实现代码

一维

import torch
from torchsummary import summary
class FireModule(torch.nn.Module):
    def __init__(self, in_channels, squeeze_channels, expand1x1_channels, expand1x3_channels):
        super(FireModule, self).__init__()
        self.squeeze = torch.nn.Conv1d(in_channels, squeeze_channels, kernel_size=1)
        self.relu = torch.nn.ReLU(inplace=True)
        self.expand1x1 = torch.nn.Conv1d(squeeze_channels, expand1x1_channels, kernel_size=1)
        self.expand1x3 = torch.nn.Conv1d(squeeze_channels, expand1x3_channels, kernel_size=3, padding=1)

    def forward(self, x):
        x = self.squeeze(x)
        x = self.relu(x)
        out1x1 = self.expand1x1(x)
        out1x3 = self.expand1x3(x)
        out = torch.cat([out1x1, out1x3], dim=1)
        return self.relu(out)


class SqueezeNet(torch.nn.Module):
    def __init__(self,in_channels,classes):
        super(SqueezeNet, self).__init__()
        self.features = torch.nn.Sequential(
            # conv1
            torch.nn.Conv1d(in_channels, 96, kernel_size=7, stride=2),
            torch.nn.ReLU(inplace=True),
            # maxpool1
            torch.nn.MaxPool1d(kernel_size=3, stride=2),
            # Fire2
            FireModule(96, 16, 64, 64),
            # Fire3
            FireModule(128, 16, 64, 64),
            # Fire4
            FireModule(128, 32, 128, 128),
            # maxpool4
            torch.nn.MaxPool1d(kernel_size=3, stride=2),
            # Fire5
            FireModule(256, 32, 128, 128),
            # Fire6
            FireModule(256, 48, 192, 192),
            # Fire7
            FireModule(384, 48, 192, 192),
            # Fire8
            FireModule(384, 64, 256, 256),
            # maxpool8
            torch.nn.MaxPool1d(kernel_size=3, stride=2),
            # Fire9
            FireModule(512, 64, 256, 256)
        )
        self.classifier = torch.nn.Sequential(
            # conv10
            torch.nn.Conv1d(512, classes, kernel_size=1),
            torch.nn.ReLU(inplace=True),
            # avgpool10
            torch.nn.AdaptiveAvgPool1d((1))
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        x = torch.flatten(x, 1)
        return x




if __name__ == "__main__":
    # 创建一个SqueezeNet实例
    model = SqueezeNet(in_channels=3,classes=10)
    # model = FireModule(96,16,64,64)
    # 打印模型结构
    summary(model=model, input_size=(3, 224), device='cpu')

二维

import torch
from torchsummary import summary
class FireModule(torch.nn.Module):
    def __init__(self, in_channels, squeeze_channels, expand1x1_channels, expand3x3_channels):
        super(FireModule, self).__init__()
        self.squeeze = torch.nn.Conv2d(in_channels, squeeze_channels, kernel_size=1)
        self.relu = torch.nn.ReLU(inplace=True)
        self.expand1x1 = torch.nn.Conv2d(squeeze_channels, expand1x1_channels, kernel_size=1)
        self.expand3x3 = torch.nn.Conv2d(squeeze_channels, expand3x3_channels, kernel_size=3, padding=1)

    def forward(self, x):
        x = self.squeeze(x)
        x = self.relu(x)
        out1x1 = self.expand1x1(x)
        out3x3 = self.expand3x3(x)
        out = torch.cat([out1x1, out3x3], dim=1)
        return self.relu(out)


class SqueezeNet(torch.nn.Module):
    def __init__(self,in_channels,classes):
        super(SqueezeNet, self).__init__()
        self.features = torch.nn.Sequential(
            # conv1
            torch.nn.Conv2d(in_channels, 96, kernel_size=7, stride=2),
            torch.nn.ReLU(inplace=True),
            # maxpool1
            torch.nn.MaxPool2d(kernel_size=3, stride=2),
            # Fire2
            FireModule(96, 16, 64, 64),
            # Fire3
            FireModule(128, 16, 64, 64),
            # Fire4
            FireModule(128, 32, 128, 128),
            # maxpool4
            torch.nn.MaxPool2d(kernel_size=3, stride=2),
            # Fire5
            FireModule(256, 32, 128, 128),
            # Fire6
            FireModule(256, 48, 192, 192),
            # Fire7
            FireModule(384, 48, 192, 192),
            # Fire8
            FireModule(384, 64, 256, 256),
            # maxpool8
            torch.nn.MaxPool2d(kernel_size=3, stride=2),
            # Fire9
            FireModule(512, 64, 256, 256)
        )
        self.classifier = torch.nn.Sequential(
            # conv10
            torch.nn.Conv2d(512, classes, kernel_size=1),
            torch.nn.ReLU(inplace=True),
            # avgpool10
            torch.nn.AdaptiveAvgPool2d((1,1))
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        x = torch.flatten(x, 1)
        return x




if __name__ == "__main__":
    # 创建一个SqueezeNet实例
    model = SqueezeNet(in_channels=3,classes=10)
    # model = FireModule(96,16,64,64)
    # 打印模型结构
    summary(model=model, input_size=(3, 224, 224), device='cpu')

结束

对于SqueezeNet的第二个和第三个模型,我先把其他的轻量级网络都复现完之后我再回来写一下,对于入门来说先实现个基础版本就够用了

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1114138.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

WebSocket 入门案例

目录 WebSocket入门案例WebSocket-server新增项目:添加依赖:yml:启动类: frontend-server前端项目:添加依赖:添加yml:启动类:前端引入JS:前端页面:后端代码:测试: WebSocket 入门案…

众和策略:地产板块发力走高,荣盛发展涨停,碧桂园等大幅拉升

地产板块20日盘中发力走高,到发稿,金科股份、荣盛展开涨停,中南建造、富丽家族涨超7%,华夏夸姣涨逾6%。 港服方面,内资地产股亦走强,到发稿,珠光控股涨超20%,碧桂园涨近10%&#xf…

PBA.常用人工智能预测分析算法

相同的数据型态,利用不同的方法分析,就可以解决不同的课题。例如目前已相当纯熟的人脸识别技术,在国防应用可以进行安保工作;企业可做员工门禁系统;可结合性别、年龄辨识让卖场进行市调分析,或结合追踪技术…

聚焦于先进电池技术等领域的前沿研究和应用,龙讯旷腾出席中国化学会第二届能源化学青年论坛

成都站电催化培训 2023年龙讯团队线下培训已走过北京、西安等城市,前几期均以定向邀请非公开的形式培训,应大家的积极号召,本期电催化成都站的培训我们将以公开招募的形式举办,并且保留前几期的优惠(前30位免费&#…

html网页代码块高亮加行号

程序示例精选 html网页代码块高亮加行号 如需安装运行环境或远程调试,见文章底部个人QQ名片,由专业技术人员远程协助! 前言 这篇博客针对《html网页代码块高亮加行号》编写代码,代码整洁,规则,易读。 学习…

shopee平台现在好做吗

Shopee 是一家知名的电子商务平台,特别在东南亚地区非常流行。是否在 Shopee 平台做生意是否好做取决于多种因素,包括你的产品、市场竞争、营销策略和运营能力等。 以下是一些考虑因素: 1、产品选择:选择畅销的产品或具有市场需求…

需要在 MySQL 服务器中监控的重要指标

MySQL是一个开源的关系数据库管理系统,它基于客户端-服务器模型运行,使用SQL作为其通信模式。它具有灵活性和可扩展性、高安全性、易用性以及无缝处理大型数据集的能力,由于其广泛的功能,MySQL 被用作数据库管理系统的一部分。 什…

初识Java 14-2 测试

目录 测试驱动开发(TDD) 日志 调试 使用JDB进行调试 基准测试 微基准测试 Java微基准测试工具(JMH) 分析和优化 重构 本笔记参考自: 《On Java 中文版》 测试驱动开发(TDD) 测试驱动开…

高博学子参加第二届火焰杯软件测试高校就业选拔赛喜获佳绩

近日,高博软件学院软件工程教研室组织指导全院近80名学生参加第二届火焰杯软件测试高校就业选拔赛。经过初赛、决赛,共有13名同学获优秀奖。获奖名单如下:(排名不分先名):滕美妙、陈虹霖、陆春媚、陈媛、周…

厚壁菌门/拟杆菌门——肠道菌群的阴阳面,代表什么

在研究肠道菌群或复杂微生物样本构成时,“门"(Phylum)是细菌分类的高级分类单位之一。 细菌分类依次为门纲目科属种亚种,最大的分类层面是门,以前写过人群肠道菌群构成主要是以拟杆菌门和厚壁菌门为主&#xff0c…

基于YOLOv5[n/s/m/l/x]全系列参数模型开发构建小麦麦穗颗粒智能化精准检测识别计数系统

小麦麦穗颗粒或者是其他农作物颗粒计数本身是一件很繁琐枯燥的事情,这种事情交给程序来做是最好不过的了,最近正好在做课题项目,导师给的题目就是跟农业相关的,这里想的就是基于目标检测模型来开发构建一套智能化的精准检测计数系…

影像科室图像存储与传输系统源码 智能化影像报告系统源码

影像科室信息管理系统源码 RIS/PACS系统源码 PACS三维影像处理系统源码 影像科室信息管理系统,它包括RIS系统、PACS工作站和PACS服务器系统。提供强大的结构和智能化的影像报告系统、支持各种图象操作,以及实现图像的路由、预取、多级多层次存储。 系统特…

Selenium实战教程----Selenium中的动作

Selenium中针对元素进行的动作在代码中可以分为两类: Selenium::WebDriver::ActionBuilder类中的动作方法Selenium::WebDriver::Element类中的动作方法 其中ActionBuilder类中的动作方法比较丰富,基本涵盖了所有可以进行的操作。 而Element类的动作比较…

ERR_PNPM_JSON_PARSE Unexpected end of JSON input while parsing empty string in

终端报错:  ERR_PNPM_JSON_PARSE  Unexpected end of JSON input while parsing empty string in   报错原因:依赖没有删除干净  解决办法:  ①删除node_modules  ②在package.json的dependencies删除不需要依赖  ③重新pnpm i

干货 | 锁向环到底是什么?是怎么进行倍频的?

你们有没有这样一个疑问,就是CPU的主频怎么做到几个GHz呢? 每一秒要给处理器几亿个脉冲,就拿11代I7处理器来说,它的基本频率就可达2.5GHz,但在我们常规的认知中,频率的大小取决于晶振的频率,比…

速成offer收割机,接口自动化测试面试题,精准打击面试...

目录:导读 前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结(尾部小惊喜) 前言 1、我们测试的接口…

虚拟机灾备建设中NFS存储直接访问技术应用

NFS(Network File System)是一种网络文件系统,允许不同计算机之间共享文件和目录。在Linux系统中,可以使用NFS协议来访问网络存储。 当新服务器硬盘不足时,旧的服务器硬盘容量大,不拔硬盘的情况下&#xf…

5VUSB输入双节磷酸铁锂电池串联应用升压充电管理IC-YB5081

5VUSB输入双节磷酸铁锂电池串联应用升压充电管理IC 概要: YB5081是一款5V输入,支持双节磷酸铁锂电池的升压充电管理IC.YB5081集成功率Mos,采用异步开关架构。使其在应用时仅需极少的外圈器件,可有效减少整体方案尺寸,降低BOM成本…

2023年起重机司机(限门式起重机)证考试题库及起重机司机(限门式起重机)试题解析

题库来源:安全生产模拟考试一点通公众号小程序 2023年起重机司机(限门式起重机)证考试题库及起重机司机(限门式起重机)试题解析是安全生产模拟考试一点通结合(安监局)特种作业人员操作证考试大纲和(质检局)特种设备作…

2023年施工升降机司机(建筑特殊工种)证模拟考试题库及施工升降机司机(建筑特殊工种)理论考试试题

题库来源:安全生产模拟考试一点通公众号小程序 2023年施工升降机司机(建筑特殊工种)证模拟考试题库及施工升降机司机(建筑特殊工种)理论考试试题是由安全生产模拟考试一点通提供,施工升降机司机(建筑特殊工种)证模拟考试题库是根据施工升降机司机(建筑特…