实现pytorch版的mobileNetV1

news2024/11/18 15:38:15

mobileNet具体细节,在前面已做了分析记录:轻量化网络-MobileNet系列-CSDN博客

这里是根据网络结构,搭建模型,用于图像分类任务。

1. 网络结构和基本组件

2. 搭建组件

(1)普通的卷积组件:CBL = Conv2d + BN + ReLU6;

(2)深度可分离卷积:DwCBL  = Conv dw+ Conv dp;

Conv dw+ Conv dp = {Conv2d(3x3) + BN + ReLU6 }  + {Conv2d(1x1) + BN + ReLU6};

Conv dw是3x3的深度卷积,通过步长控制是否进行下采样;

Conv dp是1x1的逐点卷积,通过控制输出通道数,控制通道维度的变化;

# 普通卷积
class CBN(nn.Module):
    def __init__(self, in_c, out_c, stride=1):
        super(CBN, self).__init__()
        self.conv = nn.Conv2d(in_c, out_c, 3, stride, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(out_c)
        self.relu = nn.ReLU6(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x
# 深度可分离卷积: 深度卷积(3x3x1) + 逐点卷积(1x1xc卷积)
class DwCBN(nn.Module):
    def __init__(self, in_c, out_c, stride=1):
        super(DwCBN, self).__init__()
        # conv3x3x1, 深度卷积,通过步长,只控制是否缩小特征hw
        self.conv3x3 = nn.Conv2d(in_c, in_c, 3, stride, padding=1, groups=in_c, bias=False)
        self.bn1 = nn.BatchNorm2d(in_c)
        self.relu1 = nn.ReLU6(inplace=True)
        # conv1x1xc, 逐点卷积,通过控制输出通道数,控制通道维度的变化
        self.conv1x1 = nn.Conv2d(in_c, out_c, 1, stride=1, padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(out_c)
        self.relu2 = nn.ReLU6(inplace=True)

    def forward(self, x):
        x = self.conv3x3(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.conv1x1(x)
        x = self.bn2(x)
        x = self.relu2(x)
        return x

3. 搭建网络

class MobileNetV1(nn.Module):
    def __init__(self, class_num=1000):
        super(MobileNetV1, self).__init__()
        self.stage1 = torch.nn.Sequential(
            CBN(3, 32, 2),  # 下采样/2
            DwCBN(32, 64, 1)
        )
        self.stage2 = torch.nn.Sequential(
            DwCBN(64, 128, 2),  # 下采样/4
            DwCBN(128, 128, 1)
        )
        self.stage3 = torch.nn.Sequential(
            DwCBN(128, 256, 2),  # 下采样/8
            DwCBN(256, 256, 1)
        )
        self.stage4 = torch.nn.Sequential(
            DwCBN(256, 512, 2),  # 下采样/16
            DwCBN(512, 512, 1),  # 5个
            DwCBN(512, 512, 1),
            DwCBN(512, 512, 1),
            DwCBN(512, 512, 1),
            DwCBN(512, 512, 1),
        )
        self.stage5 = torch.nn.Sequential(
            DwCBN(512, 1024, 2),  # 下采样/32
            DwCBN(1024, 1024, 1)
        )

        # classifier
        self.avg_pooling = torch.nn.AdaptiveAvgPool2d((1, 1))
        self.fc = torch.nn.Linear(1024, class_num, bias=True)

        # self.classifier = torch.nn.Softmax()  # 原始的softmax值
        # torch.log_softmax 首先计算 softmax 然后再取对数,因此在数值上更加稳定。
        # 在分类网络在训练过程中,通常使用交叉熵损失函数(Cross-Entropy Loss)。
        # torch.nn.CrossEntropyLoss 会在内部进行 softmax 操作,因此在网络的最后一层不需要手动加上 softmax 操作。

    def forward(self, x):
        scale1 = self.stage1(x)  # /2
        scale2 = self.stage2(scale1)
        scale3 = self.stage3(scale2)
        scale4 = self.stage4(scale3)
        scale5 = self.stage5(scale4)  # /32. 7x7

        x = self.avg_pooling(scale5)  # (b,1024,7,7)->(b,1024,1,1)
        x = torch.flatten(x, 1)  # (b,1024,1,1)->(b,1024,)
        x = self.fc(x)  # (b,1024,)  -> (b,1000,)
        return x


if __name__ == '__main__':
    m1 = MobileNetV1(class_num=1000)
    input_data = torch.randn(64, 3, 224, 224)
    output = m1.forward(input_data)
    print(output.shape)

待续。。。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1363035.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

光明源:智慧公厕在实际应用中作用

什么是智慧公厕呢? 智慧公厕是一种应用先进科技和智能化技术的公共卫生设施,旨在提高公厕的管理效率、服务水平以及用户体验。这类公厕整合了各种现代技术,包括实时监控系统、智能预约服务、在线反馈机制、卫生自动化技术、导航服务、电子支…

Git 常用命令详解及如何在IDEA中操作

文章目录 前言发现宝藏一、初识Git1.Git概述2. Git的功能3. Git运行图示 二、Git下载安装三、Git 代码托管服务1.常用的 Git 代码托管服务2.使用码云代码托管服务 四、Git 常用命令1.Git 全局设置2.获取Git 仓库3.工作区、暂存区、版本库 概念4.Git 工作区中文件的两种状态5.本…

视频云存储/视频智能分析平台EasyCVR在麒麟系统中无法启动该如何解决?

安防视频监控/视频集中存储/云存储/磁盘阵列EasyCVR平台可拓展性强、视频能力灵活、部署轻快,可支持的主流标准协议有国标GB28181、RTSP/Onvif、RTMP等,以及支持厂家私有协议与SDK接入,包括海康Ehome、海大宇等设备的SDK等。平台既具备传统安…

2024年阿里云、腾讯云、华为云、LightNode、硅云服务器如何选?怎么买最划算?[最新价格表]

很多小伙伴都有一颗上云的心,包括我自己 有事没事的折腾一下自己的小破站,也挺有意思的! 那么,云服务器哪家好?优惠力度哪家大?活动入口哪里进?云服务器如何配置?如何选型&#xf…

时间序列预测 — VMD-LSTM实现单变量多步光伏预测(Tensorflow):单变量转为多变量预测多变量

目录 1 数据处理 1.1 导入库文件 1.2 导入数据集 ​1.3 缺失值分析 2 VMD经验模态分解 2.1 VMD分解实验 2.2 VMD-LSTM预测思路 3 构造训练数据 4 LSTM模型训练 5 LSTM模型预测 5.1 分量预测 5.2 可视化 时间序列预测专栏链接:https://blog.csdn.net/qq_…

前端Web系统架构设计

文章目录 1.目录结构定义2. 路由封装2.1 API路由定义2.2 组件路由定义 3. Axios请求开发4. 环境变量封装5. storage模块封装(sessionStorage, localStorage)6. 公共函数封装(日期,金额,权限..)7. 通用交互定义(删除二次确认,类别,面包屑...)8. 接口全貌概览 1.目录结构定义 2. …

Flume基础知识(十):Flume 聚合实战

1)案例需求: hadoop100上的 Flume-1 监控文件/opt/module/group.log, hadoop101上的 Flume-2 监控某一个端口的数据流, Flume-1 与 Flume-2 将数据发送给 hadoop102 上的 Flume-3,Flume-3 将最终数据打印 到控制台。…

基于Java实现全功能电子商城

🍅文末获取源码联系🍅 👇🏻 精彩项目推荐订阅👇🏻 不然下次找不到哟 基于SpringBoot的旅游网站 基于SpringBoot的MusiQ音乐网站 感兴趣的可以先收藏起来,还有大家在毕设选题,项目以及…

【数据库】视图索引执行计划多表查询面试题

文章目录 一、视图1.1 概念1.2 视图与数据表的区别1.3 优点1.4 语法1.5 实例 二、索引2.1 什么是索引2.2.为什么要使用索引2.3 优缺点2.4 何时不使用索引2.5 索引何时失效2.6 索引分类2.6.1.普通索引2.6.2.唯一索引2.6.3.主键索引2.6.4.组合索引2.6.5.全文索引 三、执行计划3.1…

2024.1.5 关于 二叉平衡树(AVL 树)详解

目录 二叉搜索树 二叉搜索树的简介 二叉搜索树的查找 二叉搜索树的效率 AVL树 AVL 树的简介 AVL 树的实现 AVL树的旋转 右单旋 左单旋 左右双旋 右左双旋 完整 AVL树插入代码 验证 AVL 树 AVL 树的性能 二叉搜索树 要想了解关于二叉平衡树的相关知识,了…

mnn-llm: 大语言模型端侧CPU推理优化

在大语言模型(LLM)端侧部署上,基于 MNN 实现的 mnn-llm 项目已经展现出业界领先的性能,特别是在 ARM 架构的 CPU 上。目前利用 mnn-llm 的推理能力,qwen-1.8b在mnn-llm的驱动下能够在移动端达到端侧实时会话的能力,能够在较低内存…

安全与认证Week3

目录 Key Management 密钥管理 密钥交换、证书 密钥的类别 密钥管理方面 密钥分发问题 密钥分发方案 混合密钥分发 公钥分发 公钥证书 X.509 理解X.509 X.509证书包含 X.509使用过程 X.509身份验证服务 X.509版本3 取消 由X.509引申关于CA 用户认证、身份管理…

手机上下载 Linux 系统

我们首先要下载 Ternux 点击下载以及vnc viewer (提取码:d9sX),需要魔法才行 下载完以后我们打开 Ternux 敲第一个命令 pkg upgrade 这个命令是用来跟新软件的 敲完命令就直接回车,如果遇到需要输入 Y/N 的地方全部输入 Y 下一步 #启动TMOE…

HackTheBox - Medium - Linux - Ambassador

Ambassador Ambassador 是一台中等难度的 Linux 机器,用于解决硬编码的明文凭据留在旧版本代码中的问题。首先,“Grafana”CVE (“CVE-2021-43798”) 用于读取目标上的任意文件。在研究了服务的常见配置方式后,将在其…

原生JS调用OpenAI GPT接口并实现ChatGPT逐字输出效果

效果&#xff1a; 猜你感兴趣&#xff1a;springbootvue实现ChatGPT逐字输出打字效果 附源码&#xff0c;也是小弟原创&#xff0c;感谢支持&#xff01; 没废话&#xff0c;上代码&#xff1a; <!DOCTYPE html> <html lang"en"> <head><me…

CRM的request管理笔记

1 request类型 request有两种&#xff0c;device request和link request。 link request link req是对link进行精确控制。 link req是对每个link的请求&#xff0c;比如某一帧是否需要bubble recovery、某一帧是否需要长曝光等feature。device request 对一个设备进行每帧控制…

Linux系统IO—探索输入输出操作的奥秘

&#x1f3ac;慕斯主页&#xff1a;修仙—别有洞天 ♈️今日夜电波&#xff1a;HEART BEAT—YOASOBI 2:20━━━━━━️&#x1f49f;──────── 5:35 &#x1f504; ◀️ ⏸ ▶️ ☰ …

如何通过绘制【学习曲线】来判断模型是否【过拟合】

学习曲线是一种图形化工具&#xff0c;用于展示模型在训练集和验证集&#xff08;或测试集&#xff09;上的性能随着训练样本数量的增加而如何变化。它可以帮助我们理解模型是否受益于更多的训练数据&#xff0c;以及模型是否可能存在过拟合或欠拟合问题。学习曲线的x轴通常是训…

Win11开始菜单怎么改成经典模式-Win11切换Win10风格开始菜单方法

Win11切换Win10风格开始菜单方法 方法一&#xff1a; 1. 在Win11电脑上下载一个“Startallback”软件&#xff0c;下载安装完成后&#xff0c;在“控制面板”里打开该软件。 2. 打开后&#xff0c;在“欢迎界面”&#xff0c;选择使用“Windows10主题样式”并重启电脑即可。…

CodeWave智能开发平台--03--目标:应用创建--06变量作用域和前后端服务逻辑

摘要 本文是网易数帆CodeWave智能开发平台系列的第08篇&#xff0c;主要介绍了基于CodeWave平台文档的新手入门进行学习&#xff0c;实现一个完整的应用&#xff0c;本文主要完成06变量作用域和前后端服务逻辑 CodeWave智能开发平台的08次接触 CodeWave参考资源 网易数帆Co…