PyTorch模型容器与AlexNet构建

news2024/11/24 15:56:50

文章和代码已经归档至【Github仓库:https://github.com/timerring/dive-into-AI 】或者公众号【AIShareLab】回复 pytorch教程 也可获取。

文章目录

  • 模型容器与AlexNet构建
    • nn.Sequetial
      • 总结
    • nn.ModuleList
    • nn.ModuleDict
    • 容器总结
    • AlexNet实现

模型容器与AlexNet构建

除了上述的模块之外,还有一个重要的概念是模型容器 (Containers),常用的容器有 3 个,这些容器都是继承自nn.Module

  • nn.Sequetial:按照顺序包装多个网络层
  • nn.ModuleList:像 python 的 list 一样包装多个网络层,可以迭代
  • nn.ModuleDict:像 python 的 dict 一样包装多个网络层,通过 (key, value) 的方式为每个网络层指定名称。

nn.Sequetial

深度学习中,特征提取和分类器这两步被融合到了一个神经网络中。在卷积神经网络中,前面的卷积层以及池化层可以认为是特征提取部分,而后面的全连接层可以认为是分类器部分。比如 LeNet 就可以分为特征提取分类器两部分,这 2 部分都可以分别使用 nn.Seuqtial 来包装。

代码如下:

class LeNetSequetial(nn.Module):
    def __init__(self, classes):
        super(LeNet2, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 6, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.classifier = nn.Sequential(
            nn.Linear(16*5*5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size()[0], -1)
        x = self.classifier(x)
        return x

在初始化时,nn.Sequetial会调用__init__()方法,将每一个子 module 添加到 自身的_modules属性中。这里可以看到,我们传入的参数可以是一个 list,或者一个 OrderDict。如果是一个 OrderDict,那么则使用 OrderDict 里的 key,否则使用数字作为 key。

    def __init__(self, *args):
        super(Sequential, self).__init__()
        if len(args) == 1 and isinstance(args[0], OrderedDict):
            for key, module in args[0].items():
                self.add_module(key, module)
        else:
            for idx, module in enumerate(args):
                self.add_module(str(idx), module)

网络初始化完成后有两个子 modulefeaturesclassifier

features中的子 module 如下,每个网络层以序号作为 key:

在进行前向传播时,会进入 LeNet 的forward()函数,首先调用第一个Sequetial容器:self.features,由于self.features也是一个 module,因此会调用__call__()函数,里面调用

result = self.forward(*input, **kwargs),进入nn.Seuqetialforward()函数,在这里依次调用所有的 module。上一个module的输出是下一个module的输入。

  def forward(self, input):
        for module in self:
            input = module(input)
        return input

在上面可以看到在nn.Sequetial中,里面的每个子网络层 module 是使用序号来索引的,即使用数字来作为key。

一旦网络层增多,难以查找特定的网络层,这种情况可以使用 OrderDict (有序字典)。可以与上面的代码对比一下

class LeNetSequentialOrderDict(nn.Module):
    def __init__(self, classes):
        super(LeNetSequentialOrderDict, self).__init__()

        self.features = nn.Sequential(OrderedDict({
            'conv1': nn.Conv2d(3, 6, 5),
            'relu1': nn.ReLU(inplace=True),
            'pool1': nn.MaxPool2d(kernel_size=2, stride=2),

            'conv2': nn.Conv2d(6, 16, 5),
            'relu2': nn.ReLU(inplace=True),
            'pool2': nn.MaxPool2d(kernel_size=2, stride=2),
        }))

        self.classifier = nn.Sequential(OrderedDict({
            'fc1': nn.Linear(16*5*5, 120),
            'relu3': nn.ReLU(),

            'fc2': nn.Linear(120, 84),
            'relu4': nn.ReLU(inplace=True),

            'fc3': nn.Linear(84, classes),
        }))
        ...
        ...
        ...

总结

nn.Sequetialnn.Module的容器,用于按顺序包装一组网络层,有以下两个特性。

  • 顺序性:各网络层之间严格按照顺序构建,我们在构建网络时,一定要注意前后网络层之间输入和输出数据之间的形状是否匹配
  • 自带forward()函数:在nn.Sequetialforward()函数里通过 for 循环依次读取每个网络层,执行前向传播运算。这使得我们我们构建的模型更加简洁

nn.ModuleList

nn.ModuleListnn.Module的容器,用于包装一组网络层,以迭代的方式调用网络层,主要有以下 3 个方法:

  • append():在 ModuleList 后面添加网络层
  • extend():拼接两个 ModuleList
  • insert():在 ModuleList 的指定位置中插入网络层

下面的代码通过列表生成式来循环迭代创建 20 个全连接层,非常方便,只是在 forward()函数中需要手动调用每个网络层。

class ModuleList(nn.Module):
    def __init__(self):
        super(ModuleList, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(20)])

    def forward(self, x):
        for i, linear in enumerate(self.linears):
            x = linear(x)
        return x


net = ModuleList()

print(net)

fake_data = torch.ones((10, 10))

output = net(fake_data)

print(output)

nn.ModuleDict

nn.ModuleDictnn.Module的容器,用于包装一组网络层,以索引的方式调用网络层,主要有以下 5 个方法:

  • clear():清空 ModuleDict
  • items():返回可迭代的键值对 (key, value)
  • keys():返回字典的所有 key
  • values():返回字典的所有 value
  • pop():返回一对键值,并从字典中删除

下面的模型创建了两个ModuleDictself.choicesself.activations,在前向传播时通过传入对应的 key 来执行对应的网络层。

class ModuleDict(nn.Module):
    def __init__(self):
        super(ModuleDict, self).__init__()
        self.choices = nn.ModuleDict({
            'conv': nn.Conv2d(10, 10, 3),
            'pool': nn.MaxPool2d(3)
        })

        self.activations = nn.ModuleDict({
            'relu': nn.ReLU(),
            'prelu': nn.PReLU()
        })

    def forward(self, x, choice, act):
        x = self.choices[choice](x)
        x = self.activations[act](x)
        return x


net = ModuleDict()

fake_img = torch.randn((4, 10, 32, 32))

output = net(fake_img, 'conv', 'relu')
# output = net(fake_img, 'conv', 'prelu')
print(output)

容器总结

  • nn.Sequetial:顺序性,各网络层之间严格按照顺序执行,常用于 block 构建,在前向传播时的代码调用变得简洁
  • nn.ModuleList:迭代行,常用于大量重复网络构建,通过 for 循环实现重复构建
  • nn.ModuleDict:索引性,常用于可选择的网络层

AlexNet实现

AlexNet 特点如下:

  • 采用 ReLU 替换饱和激活函数,减轻梯度消失
  • 采用 LRN (Local Response Normalization) 对数据进行局部归一化,减轻梯度消失
  • 采用 Dropout 提高网络的鲁棒性,增加泛化能力
  • 使用 Data Augmentation,包括 TenCrop 和一些色彩修改

AlexNet 的网络结构可以分为两部分:features 和 classifier。

可以在计算机视觉库torchvision.models中找到 AlexNet 的代码,通过看可知使用了nn.Sequential来封装网络层。

class AlexNet(nn.Module):

    def __init__(self, num_classes=1000):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/744053.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

解决固态硬盘只显示一半容量的好方法,解放隐藏的存储空间!

硬盘只显示一半容量”! “几天前,我的闪迪固态硬盘出现了一些奇怪的事情,这是个500GB的硬盘,但系统没有显示全部容量,只显示了250GB。这是什么原因?我该怎么办呢?如果大家有解决过类似问题,请…

使用SpringBoot+React搭建一个Excel报表平台

摘要:本文由葡萄城技术团队于CSDN原创并首发。转载请注明出处:葡萄城官网,葡萄城为开发者提供专业的开发工具、解决方案和服务,赋能开发者。 前言 Excel报表平台是一款功能强大、操作简单的系统平台,可以帮助用户上传…

海量倾斜摄影模型数据web端上传发布,在线浏览、在线分享,你还不知道吗?

倾斜摄影模型突出的特点就是数据量较大,这是由其高精度、对地表全覆盖的真实影像所决定的。如何将海量倾斜摄影模型数据加载到地图中并进行在线浏览是行业用户一直关心的内容,现在通过「四维轻云」就可以实现地理空间数据的在线管理、编辑及分享。 1、倾…

青岛大学_王卓老师【数据结构与算法】Week05_04_案例引入_学习笔记

本文是个人学习笔记,素材来自青岛大学王卓老师的教学视频。 一方面用于学习记录与分享, 另一方面是想让更多的人看到这么好的《数据结构与算法》的学习视频。 如有侵权,请留言作删文处理。 课程视频链接: 数据结构与算法基础…

第二章:在html中使用javascript

1、在html页面中插入js的主要方法就是使用<script>元素 2、html4.01为<script>定义了以下6个属性&#xff1a;【language已经废弃&#xff0c;其他5个属性都是可选的】 async 表示应该立即下载脚本&#xff0c;但不应该妨碍页面中的其他操作&#xff0c;比如下载…

.NET Core 数据库DB First自动生成,Sqlite,sql server,Mysql

文章目录 前言数据库ORM代码自动添加前期准备安装Nuget Sql serverMysqlSqlite查询结果 前言 .NET Core是C# .NET 未来发展的必然趋势&#xff0c;C# 要像Java一样跨平台运行。这里解决一个.NET core 会遇到的问题&#xff0c;如何添加ORM框架。 ORM是数据库对象映射关系模型…

Anaconda的安装和配置

对于自学Python的小伙伴来说&#xff0c;在刚开始&#xff0c;我们就得要安装Python以及python的库&#xff0c;但是我们可以通过安装Anaconda很好地解决这一难题&#xff0c;给我们初学者节省很多令人头疼的环境安装问题&#xff0c;今天我就为大家分享下Anaconda的介绍&#…

什么是加密领域的 Web 3.0?

随着科技的不断进步和互联网的发展&#xff0c;我们正逐渐迈入数字经济时代。在这个时代中&#xff0c;加密领域的Web 3.0成为了一个备受关注的话题。从区块链技术到加密货币&#xff0c;从去中心化应用程序到智能合约&#xff0c;这些新兴技术正在改变着我们对互联网的认知。本…

仓库24代拣货标签——功能特点

1. 通过无线方式快速刷新屏幕&#xff1b; 2. 移动式功能用法&#xff08;自动切换基站进行通信&#xff09;&#xff1b; 3. 电量低于 50%的情况下&#xff0c;提供外接供电&#xff0c;可以对电池进行充电&#xff0c;充电时会亮红灯&#xff0c;充满后亮绿灯&#xff08;如果…

Vue3挂载全局方法及组件中如何使用

文章目录 前言一、在mian.ts&#xff08;mian.js&#xff09;中配置全局变量1、如何封装 二、如何调用1.template中调用2.在script标签中如何拿到 前言 在Vue3项目中&#xff0c;需要频繁使用某一个方法。配置到全局感觉会方便很多。 例如&#xff1a;因为很多页面都需要对时…

openEuler 22.03 LTS登录AWS Marketplace

openEuler 22.03 LTS镜像正式登录AWS Marketplace&#xff0c;目前在亚太及欧洲15个Region开放使用&#xff0c;后续将开放更多版本和区域&#xff0c;openEuler 22.03 LTS AMI(Amazon Machine Images)由openEuler社区提供支持。 点击查看具体使用指导&#xff1a;https://www…

软件高效自动化部署:华为云部署服务CodeArts Deploy

随着互联网、数字化的发展&#xff0c;公司机构与各类企业往往需要进行大量频繁的软件部署&#xff0c;部署设备类型多样&#xff0c;如&#xff1a;本地机器、云上裸金属服务器、云上虚拟机与容器等。 面对多种部署模式、分布式复杂运行环境&#xff0c;该如何用最短时间、高…

华为战略方法论:BLM模型之战略意图(限制版)

目录 前言 案例 BLM模型 专栏列表 CSDN学院 个人简介 前言 对于任何一家企业来说&#xff0c;即便你没有清晰的战略规划。 一般也都是会有战略意图的。 战略意图具体是指你主观想要达成什么样的期望或者是状态。 换句话说&#xff0c;如果没有这种期望&#xff0c;你…

记录--盘点前端实现文件下载的几种方式

这里给大家分享我在网上总结出来的一些知识&#xff0c;希望对大家有所帮助 前端涉及到的文件下载还是很多应用场景的&#xff0c;那么前端文件下载有多少种方式呢&#xff1f;每种方式有什么优缺点呢&#xff1f;下面就来一一介绍。 1. 使用 a 标签下载 通过a标签的download属…

STM32 Proteus仿真空气质量检测环境监测苯PM2.5 MQ135温度湿度 -0068

STM32 Proteus仿真空气质量检测环境监测苯PM2.5 MQ135温度湿度 -0068 Proteus仿真小实验&#xff1a; STM32 Proteus仿真空气质量检测环境监测苯PM2.5 MQ135温度湿度 -0068 功能&#xff1a; 硬件组成&#xff1a;STM32F103R6单片机 LCD1602显示器DHT11温度湿度多个按键蜂鸣…

简要介绍 | 通信感知一体化:探索信息与通信技术的新边界

注1&#xff1a;本文系“简要介绍”系列之一&#xff0c;仅从概念上对通信感知一体化技术进行非常简要的介绍&#xff0c;不适合用于深入和详细的了解。 通信感知一体化&#xff1a;探索信息与通信技术的新边界 通信感知一体化&#xff08;ISAC&#xff09;&#xff1a;从入门到…

H3C端口镜像

端口镜像简介 端口镜像通过将指定端口或CPU的报文复制到与数据监测设备相连的端口&#xff0c;使用户可以利用数据监测设备分析这些复制过来的报文&#xff0c;以进行网络监控和故障排除。 基本概念 镜像源镜像源是指被监控的对象&#xff0c;该对象可以是端口或单板上的CPU&am…

Vue-cli脚手架

文章目录 前言搭建Vue-Cli脚手架安装npm可能出现的报错及解决办法国内淘宝镜像服务器 全局安装 vue-cli创建 Vue-Cli工程创建 Vue 的基本模板 总结终端打开/关闭操作创建Vue-Cli工程过程 前言 提示&#xff1a;这里可以添加本文要记录的大概内容&#xff1a; Vue CLI是一个基…

数字工厂管理系统如何解决汽配企业的管理痛点

在现代汽车产业中&#xff0c;汽车配件企业扮演着至关重要的角色。然而&#xff0c;许多汽配企业面临着管理痛点&#xff0c;如生产效率低下、库存管理困难、供应链不透明等。为了解决这些问题&#xff0c;越来越多的汽配企业转向数字工厂管理系统。本文将探讨数字工厂管理系统…

如何在Microsoft Word中快速对齐名字

对齐方式决定段落边缘的外观和方向:左对齐文本、右对齐文本、居中文本或对齐文本,这些文本沿左右边距均匀对齐。例如,在左对齐(最常见的对齐方式)的段落中,段落的左边缘与左边距齐平。 在 Microsoft Word 中,还有不少的人用敲空格的方式来对齐的目的,特别是两个字的姓名…