快速入门Torch构建自己的网络模型

news2025/2/25 12:16:16

真有用构建自己的网络模型

    • 读前必看
    • 刚学完Alex网络感觉很厉害的样子,我也要搭建一个
    • 可以看着网络结构实现上面的代码你已经很强了,千万不要再想实现VGG等网络!!!90%你能了解到的模型大佬早已实现好,直接调用就OK
    • 下面是源码用nn.Module实现的AlexNet,和我们实现的区别并不大,将模型print出来能看懂就可以
    • 不忘初心,构建自己的网络模型,将AlexNet输入改为单通道图片:
    • Tips

读前必看

  • 如何用框架复现论文中的模型不重要,重要的是明白网络模型原理!!!
  • 如何用框架复现论文中的模型不重要,重要的是明白网络模型原理!!!
  • 如何用框架复现论文中的模型不重要,重要的是明白网络模型原理!!!

刚学完Alex网络感觉很厉害的样子,我也要搭建一个

在这里插入图片描述

回想一下torch构建网络的几种方法

  • nn.Sequential直接顺序实现
  • nn.Module继承基类构建自定义模型
feature = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=3, stride=2),
    nn.Conv2d(64, 192, kernel_size=5, padding=2),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=3, stride=2),
    nn.Conv2d(192, 384, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.Conv2d(384, 256, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.Conv2d(256, 256, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=3, stride=2),
)

现在需要计算卷积后图像的维度,根据公式 image_shape = (image_shape - kernel_size + 2 * padding) / stride + 1计算

in_shape= 224
conv_size = [11, 5, 3, 3, 3]
padding_size = [2, 2, 1, 1, 1]
stride_size = [4, 1, 1, 1, 1]
# image_shape = (image_shape - kernel_size + 2 * padding) / stride + 1
for i in range(len(conv_size)):
    in_shape = (in_shape - conv_size[i] + 2 * padding_size[i]) / stride_size[i] + 1
    in_shape = math.floor(in_shape)
    if i in [0, 1, 4]:
        in_shape = (in_shape - 3 + 2 * 0) / 2 + 1
        in_shape = math.floor(in_shape)
print(in_shape)

计算结果是6,输出通道是256,所以特征有25666个,将下面代码添加到Sequential中完成自定义AlexNet构建

nn.Linear(256 * 6 * 6, 4096),
nn.ReLU(inplace=True),
nn.Dropout(p=dropout),
nn.Linear(4096, 4096),
nn.ReLU(inplace=True),
nn.Linear(4096, num_classes)

可以看着网络结构实现上面的代码你已经很强了,千万不要再想实现VGG等网络!!!90%你能了解到的模型大佬早已实现好,直接调用就OK

下面是源码用nn.Module实现的AlexNet,和我们实现的区别并不大,将模型print出来能看懂就可以

class AlexNet(nn.Module):
    def __init__(self, num_classes: int = 1000, dropout: float = 0.5) -> None:
        super().__init__()
        # _log_api_usage_once(self)
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )

        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

不忘初心,构建自己的网络模型,将AlexNet输入改为单通道图片:

model = AlexNet()
model.features[0] = nn.Conv2d(1, 64, kernel_size=11, stride=4, padding=2)
print(model)

Tips

Q1: padding是卷积之后还是卷积之前还是卷积之后实现的?
padding是在卷积之前补0,如果愿意的话,可以通过使用torch.nn.Functional.pad来补非0的内容。

Q2:padding补0的默认策略是什么?
四周都补!如果pad输入是一个tuple的话,则第一个参数表示高度上面的padding,第2个参数表示宽度上面的

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1391326.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【python】16.Python语言进阶

Python语言进阶 重要知识点 生成式(推导式)的用法 prices {AAPL: 191.88,GOOG: 1186.96,IBM: 149.24,ORCL: 48.44,ACN: 166.89,FB: 208.09,SYMC: 21.29 } # 用股票价格大于100元的股票构造一个新的字典 prices2 {key: value for key, value in prices…

推荐几个Github高星GoLang管理系统

在Web开发领域,Go语言(Golang)以其高效、简洁、高并发等特性逐渐成为许多开发者的首选语言。有许多优秀的Go语言Web后台管理系统,这些项目星星众多,提供了丰富的功能和良好的代码质量。本文将介绍一些GitHub高星的GoLa…

07-微服务getaway网关详解

一、初识网关 在微服务架构中,一个系统会被拆分为很多个微服务。那么作为客户端要如何去调用这么多的微服务呢?如果没有网关的存在,我们只能在客户端记录每个微服务的地址,然后分别去调用。这样的话会产生很多问题,例…

设计模式之依赖倒转原则

在软件开发的世界里,设计模式一直是提升代码质量、确保软件稳定性以及优化软件可维护性的重要工具。而在这其中,依赖倒转原则无疑是其中最具代表性的设计模式之一。那么,什么是依赖倒转原则?它又为何如此重要?让我们一…

鼎捷软件获评国家级智能制造“AAA级集成实施+AA级咨询设计”供应商

为贯彻落实《“十四五”智能制造发展规划》,健全智能制造系统解决方案供应商(以下简称“供应商”)分类分级体系,推动供应商规范有序发展,智能制造系统解决方案供应商联盟组织开展了供应商分类分级评定(第一批)工作,旨在遴选一批专…

Python--GIL(全局解释器锁)

在Python中,GIL(全局解释器锁)是一个非常重要的概念,它对Python的多线程编程有着深远的影响。GIL是Python解释器级别的锁,用于保证任何时刻只有一个线程在执行Python字节码。这意味着即使在多核处理器上,Py…

[HTML]Web前端开发技术13(HTML5、CSS3、JavaScript )横向二级导航菜单 Web页面设计实例——喵喵画网页

希望你开心,希望你健康,希望你幸福,希望你点赞! 最后的最后,关注喵,关注喵,关注喵,佬佬会看到更多有趣的博客哦!!! 喵喵喵,你对我真的…

离线数据仓库-关于增量和全量

数据同步策略 数据仓库同步策略概述一、数据的全量同步二、数据的增量同步三、数据同步策略的选择 数据仓库同步策略概述 应用系统所产生的业务数据是数据仓库的重要数据来源,我们需要每日定时从业务数据库中抽取数据,传输到数据仓库中,之后…

探索Redis特殊数据结构:Bitmaps(位图)在实际中的应用

一、概述 Redis官方提供了多种数据类型,除了常见的String、Hash、List、Set、zSet之外,还包括Stream、Geospatial、Bitmaps、Bitfields、Probabilistic(HyperLogLog、Bloom filter、Cuckoo filter、t-digest、Top-K、Count-min sketch、Confi…

一文掌握SpringBoot注解之@Async知识文集(1)

🏆作者简介,普修罗双战士,一直追求不断学习和成长,在技术的道路上持续探索和实践。 🏆多年互联网行业从业经验,历任核心研发工程师,项目技术负责人。 🎉欢迎 👍点赞✍评论…

手把手教你搭建一个数据可视化看板

前言 俗话说的好,“字不如表,表不如图”、“有图有真相,一图胜千言”。 数据可视化就是用图的形式把基础数据直观,简洁的,高效的展示出来,今天为大家介绍一下如何使用葡萄城公司的嵌入式BI工具——Wyn商业…

Unity3d C#实现场景编辑/运行模式下3D模型XYZ轴混合一键排序功能(含源码工程)

前言 在部分场景搭建中需要整齐摆放一些物品(如仓库中的货堆、货架等),因为有交互的操作在单个模型上,每次总是手动拖动模型操作起来也是繁琐和劳累。 在这背景下,我编写了一个在运行或者编辑状态下都可以进行一键排序…

Day12 C基础(指针进阶)

文章目录 指针修饰1.const 修饰2.void 大小端二级指针指针和数组1.指针和一维数组直接访问:间接访问: 2.指针和二维数组直接访问:间接访问: 数组指针 指针修饰 1.const 修饰 1)const int num 10; const int num 10;num 3; i…

【面试合集】说说微信小程序的实现原理?

面试官:说说微信小程序的实现原理? 一、背景 网页开发,渲染线程和脚本是互斥的,这也是为什么长时间的脚本运行可能会导致页面失去响应的原因,本质就是我们常说的 JS 是单线程的 而在小程序中,选择了 H…

Mac系统下,保姆级Jenkins自动化部署Android

一、Jenkins自动化部署 1、安装jenkins 官网:macOS Installers for Jenkins LTS 选择macOS brew install jenkins-lts 安装最新: brew install jenkins-lts 启动jenkins服务: brew services start jenkins-lts 重启jenkins服务: brew services restart jenkin…

YOLOv5改进系列(27)——添加SCConv注意力卷积(CVPR 2023|即插即用的高效卷积模块)

【YOLOv5改进系列】前期回顾: YOLOv5改进系列(0)——重要性能指标与训练结果评价及分析 YOLOv5改进系列(1)——添加SE注意力机制 YOLOv5改进系列(2)——添加CBAM注意力机制 YOLOv5改进系列&…

Netty-Netty源码分析

Netty线程模型图 Netty线程模型源码剖析图 Netty高并发高性能架构设计精髓 主从Reactor线程模型NIO多路复用非阻塞无锁串行化设计思想支持高性能序列化协议零拷贝(直接内存的使用)ByteBuf内存池设计灵活的TCP参数配置能力并发优化 无锁串行化设计思想 在大多数场景下&#…

如何用GPT进行论文润色与改写?

详情点击链接:如何用GPT/GPT4进行论文润色与改写?一OpenAI 1.最新大模型GPT-4 Turbo 2.最新发布的高级数据分析,AI画图,图像识别,文档API 3.GPT Store 4.从0到1创建自己的GPT应用 5. 模型Gemini以及大模型Claude2二…

1.16 day3 IO网络编程

用udp实现tftp下载功能 #include <myhead.h> #define PORT 69 #define IP "192.168.122.24" int xiazai(int sfd,struct sockaddr_in sin,int fd,socklen_t socklen) {char buf[516]"";char ack[4];short *p1(short *)buf;*p1htons(1);char *p2buf2…

数字化转型:为何失败率居高不下,以及如何避免重蹈覆辙

在当今快速发展的数字化时代&#xff0c;许多企业纷纷投身于数字化转型的浪潮中&#xff0c;以期通过技术革新提升竞争力、优化运营、提高效率。然而&#xff0c;尽管数字化转型的潜在益处巨大&#xff0c;但失败率却居高不下&#xff0c;甚至导致企业陷入困境。 本文将深入探讨…