【深度学习】:《PyTorch入门到项目实战》(十六):卷积神经网络:NiN(Network in Network)和1×1卷积(附Pytorch源码)

news2024/11/27 20:39:49

专栏介绍

  • ✨本文收录于【深度学习】:《PyTorch入门到项目实战》专栏,此专栏主要记录如何使用PyTorch实现深度学习算法及其项目实战,目前pytorch基础计算已经更新完,正在更新CNN,接下来会陆续更新RNN、CV、NLP、搜推广项目实战,尽量坚持每周持续更新,欢迎大家订阅!
  • 🌸个人主页:JOJO数据科学
  • 📝个人介绍:某985统计硕士在读
  • 💌如果文章对你有帮助,欢迎✌关注、👍点赞、✌收藏、👍订阅专栏
  • 参考资料:《动手学深度学习》

在这里插入图片描述

文章目录

  • 一、引言
  • 二、1×1卷积网络
      • 1️⃣介绍
      • 2️⃣计算逻辑
      • 3️⃣主要作用
  • 二、NiN架构
      • 1️⃣NiN块
      • 2️⃣全局平均池化层
  • 三、 Pytorch代码实现
      • 1️⃣定义NiN块
      • 2️⃣定义NiN网络
      • 3️⃣加载数据集
      • 4️⃣初始化模型
      • 5️⃣模型训练与评估
  • 四、总结

一、引言

我们之前介绍了LeNet,AlexNet,VGG。在我们用卷积层提取特征后,全连接层的参数如下:
image-20230716130241885

可以看出,全连接层的参数很大,很占内存。因此,如果可以不使用全连接层,或者说减少全连接层的个数,可以减少参数,减少过拟合。下面我们来讨论这一章要介绍的内容NiN

二、1×1卷积网络

1️⃣介绍

在架构内容设计方面,其中一个比较有帮助的想法是使用1×1卷积。如下所示

image-20230717104555545

也许你会好奇,1×1的卷积能做什么呢?不就是乘以数字么?似乎没有什么用,我们来具体看看它如何工作的。假设一个1×1卷积,这里是数字2,输入一张6×6×1的图片,然后对它做卷积,卷积层大小为1×1×1,结果相当于把这个图片乘以2,所以前三个单元格分别是2、4、6等等。

在这里插入图片描述

用1×1的过滤器进行卷积,似乎用处不大,只是对输入矩阵乘以某个数字。但这仅仅是对于6×6×1的一个通道的图片来说,1×1卷积效果不佳

如果是一张6×6×32的图片,那么使用1×1过滤器进行卷积效果更好。具体来说,1×1卷积所实现的功能是遍历这36个单元格,计算左图中32个数字和过滤器中32个数字的元素积之和,然后应用ReLU非线性函数。

image-20230717105302660

2️⃣计算逻辑

上述1×1×32过滤器中的32可以这样理解,一个神经元的输入是32个数字(输入图片中32个通道中的数字),即相同高度和宽度上某一切片上的32个数字,这32个数字具有不同通道,乘以32个权重(将过滤器中的32个数理解为权重)。所以1×1卷积可以从根本上理解为对这32个不同的位置都应用一个全连接层。和传统的CNN在卷积层之后接全连接层相比,全连接层会将特征图展平为一个向量,并进行线性变换。而用1×1卷积核替代全连接层,将空间上的每个像素点作为一组特征进行卷积操作,从而保留了空间结构信息,避免展平为向量,提高了网络的表达能力。

此外,当有多个卷积层时,我们可以更改输出通道数,如下图所示。

image-20230717110114634

输入图片为4×4×3

  • 第一个1×1卷积是增加通道数(通道从3→6)
    原始图像 (4×4×3) → Conv 1 (6个1×1 ×3 kernel) → Conv1 输出图像 (4×4×6)
  • 第二个1×1卷积是减少通道数(通道从6→2)。
    Conv1输出图片(4×4×6)→Conv 2(2个1×1×6 kernel)→Conv2输出图片(4×4×2)

3️⃣主要作用

  • 通道数调整:通过1×1卷积,将卷积核的通道数设置为所需的输出通道数,就可以实现通道数的调整。这样就能够控制特征图的维度,使其适应后续层的输入要求。
  • 特征融合:1×1卷积通过调整卷积核的通道数,将不同通道的特征图相加,从而实现特征的融合。
  • 非线性映射:尽管1×1卷积没有类似3×3或5×5卷积核的局部感知视野,但它仍然引入了非线性映射。由于卷积操作中存在激活函数,1×1卷积能够对特征图进行非线性变换,并增强网络的表达能力。

下面我们来看一下NiN架构

二、NiN架构

NiN(Network in Network) 由Min Lin等人在2013年提出。它的设计目标是通过引入多层感知机结构(MLPConv)来提高卷积神经网络(CNN)的表达能力。

NiN框架的核心思想是在卷积层内嵌套一个小型MLP网络,用于增强特征表达能力。与传统的CNN不同,NiN框架在每个卷积层中使用1×1的卷积核,这样可以引入非线性变换和参数共享,从而提高特征的非线性表示能力。具体而言,NiN框架包含了以下几个关键组件:

1️⃣NiN块

一个NiN块由1个卷积层和2个1×1卷积层构成。其中,第一个卷积层负责提取空间特征,第2个1×1卷积层将通道数降低,第3个1×1卷积层则将通道数增加。这样的设计可以增加网络的非线性表示能力,并且通过1×1卷积层调整通道数可以灵活控制特征图的维度。

image-20230717104119621

2️⃣全局平均池化层

在NiN网络的最后,通过全局平均池化层将特征图的空间维度降为1×1,得到一个通道数等于类别数的特征图。然后,通过Softmax函数进行分类。

主要结构如下

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-3YJDYZTe-1689582257905)(https://raw.githubusercontent.com/19973466719/jojo-pic/main/20230717155327.png#)]

下面我们用Pytorch来实现基于NiN架构对Fashion-MNIST数据集识别

三、 Pytorch代码实现

1️⃣定义NiN块

这里和原始的nin块有两个1×1卷积不同,我这里只使用了1个1×1卷积,因为数据集比较小,所以使用1个1×1卷积层效果更好,并且也大大节省了训练时间。

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

def nin_block(in_channels, out_channels, kernel_size, strides, padding):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size, strides, padding),
        nn.ReLU(),
        nn.Conv2d(out_channels, out_channels, kernel_size=1), nn.ReLU())

2️⃣定义NiN网络

net = nn.Sequential(
    nin_block(1, 96, kernel_size=11, strides=4, padding=0),
    nn.MaxPool2d(3, stride=2),
    nin_block(96, 256, kernel_size=5, strides=1, padding=2),
    nn.MaxPool2d(3, stride=2),
    nin_block(256, 384, kernel_size=3, strides=1, padding=1),
    nn.MaxPool2d(3, stride=2),
    nn.Dropout(0.5),
    # 标签类别数是10
    nin_block(384, 10, kernel_size=3, strides=1, padding=1),
    nn.AdaptiveAvgPool2d((1, 1)),
    # 将四维的输出转成二维的输出,其形状为(批量大小,10)
    nn.Flatten())

3️⃣加载数据集

# 加载Fashion-MNIST数据集
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224,224)),
    transforms.Normalize((0.5,), (0.5,))
])

trainset = torchvision.datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)

testset = torchvision.datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)

4️⃣初始化模型

# Xavier初始化:
def init_weights(m):
    if type(m) == nn.Linear or type(m) == nn.Conv2d: #对全连接层和卷积层初始化
        nn.init.xavier_uniform_(m.weight)
net.apply(init_weights)
# 检查是否有可用的GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = net.to(device)

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)

5️⃣模型训练与评估

# 训练模型
num_epochs = 10
train_losses = []
test_losses = []

for epoch in range(num_epochs):
    train_loss = 0.0
    test_loss = 0.0
    correct = 0
    total = 0

    # 训练模型
    model.train()
    for images, labels in trainloader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    # 测试模型
    model.eval()
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            test_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    avg_train_loss = train_loss / len(trainloader)
    avg_test_loss = test_loss / len(testloader)
    train_losses.append(avg_train_loss)
    test_losses.append(avg_test_loss)

    print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}, Test Loss: {avg_test_loss:.4f}, Acc: {correct/total*100:.2f}%")

# 绘制测试误差和训练误差曲线
plt.plot(train_losses, label='Training Loss')
plt.plot(test_losses, label='Testing Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()
Epoch [1/10], Train Loss: 2.2827, Test Loss: 2.0189, Acc: 33.98%
Epoch [2/10], Train Loss: 1.7984, Test Loss: 1.2083, Acc: 59.43%
Epoch [3/10], Train Loss: 1.0804, Test Loss: 0.9443, Acc: 65.87%
Epoch [4/10], Train Loss: 1.0075, Test Loss: 0.8990, Acc: 67.74%
Epoch [5/10], Train Loss: 0.8120, Test Loss: 0.8054, Acc: 69.70%
Epoch [6/10], Train Loss: 0.7379, Test Loss: 0.7040, Acc: 73.27%
Epoch [7/10], Train Loss: 0.4918, Test Loss: 0.5636, Acc: 79.59%
Epoch [8/10], Train Loss: 0.4344, Test Loss: 0.4079, Acc: 84.71%
Epoch [9/10], Train Loss: 0.4012, Test Loss: 0.3962, Acc: 85.51%
Epoch [10/10], Train Loss: 0.3833, Test Loss: 0.3757, Acc: 85.74%

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-3mJlH7Ho-1689582257906)(https://raw.githubusercontent.com/19973466719/jojo-pic/main/20230717161928.png)]

从结果来看,和AlexNet相比,精确度还要低一些,可能是我们的数据集太小,把batch_size调大一点可能效果会好一些。

四、总结

NiN框架的主要优点是:

  1. 提高了表达能力:引入了MLP结构,增强了网络的非线性表示能力,有助于更好地捕捉复杂的特征。
  2. 减少参数:使用1×1卷积核和全局平均池化层,减少了网络中的参数数量,降低了过拟合的风险。
  3. 提高计算效率:由于减少了参数数量,NiN框架相对于传统的CNN具有更高的计算效率。

🔎总的来说,NiN框架在许多计算机视觉任务中取得了很好的性能,成为CNN架构设计中的重要思路之一,后续我们要介绍的GoogleNet借用了这种思想。
本章的介绍到此介绍,如果文章对你有帮助,请多多点赞、收藏、评论、订阅支持!!【深度学习】:《PyTorch入门到项目实战》

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/769649.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ceph三个接口的创建

目录 创建 CephFS 文件系统 MDS 接口 服务端操作 客户端操作 创建 Ceph 块存储系统 RBD 接口 1、创建一个名为 rbd-demo 的专门用于 RBD 的存储池 2、将存储池转换为 RBD 模式 ​编辑 3、初始化存储池 4、创建镜像 5、镜像管理 6、Linux客户端使用 客户端使用 RBD …

从MVC跨越到DDD微服务架构是如何演进的

微服务架构演进 领域模型中对象的层次从内到外依次是:值对象、实体、聚合和限界上下文。 实体或值对象的简单变更,一般不会让领域模型和微服务发生大变。但聚合的重组或拆分却可以。因为聚合内业务功能内聚,能独立完成特定业务。那聚合的重组…

嵌入式软件测试笔记12 | 什么是状态转换测试?如何开展?

12 | 什么是状态转换测试?如何开展? 1 状态转换测试简介1.1 基于状态的测试设计技术1.2 系统行为 2 故障类别2.1 状态2.2 防护2.3 转换2.4 事件2.5 其它 3 状态转换测试技术3.1 编写状态-事件表3.2 编写转换树3.3 编写合法测试用例的测试脚本 3.4 编写非…

【100天精通python】Day5:python 基本语句,流程控制语句

目录 1. 条件语句 1.1 if语句 1.2 if-else语句 1.3 if-elif-else语句 2 循环语句 2.1 for循环 2.2 while循环: 3 跳转语句 3.1 break语句 3.2 continue语句 3.3 pass语句 4 异常处理语句(try-except语句) 5 语句嵌套 5.1 条…

将maven库中没有的jar包导入本地库后编译还提示缺这个jar包

Maven本地仓库有对应的jar包但是报找不到 问题原因 第一,你本地仓库对应的包文件夹下有_remote.repositories这个文件; 第二,你的项目现在连接不到下载这个包的仓库; 以上两点就是本地明明有对应的jar包,但项目中还…

SpringBoot项目中WEB页面放哪里--【SB系列之008】

SpringBoot系列文章目录 SpringBoot 的项目编译即报错处理–SB系列之001 —第一部的其它章节可以通过001链接 SpringBoot项目中WEB页面放哪里–【SB系列之008】 SpringBoot项目中WEB与Controller的联系–【SB系列之009】 ———————————————— 文章目录 SpringBoo…

VulnHub靶机-Socnet

文章目录 实验环境信息收集存活主机探测端口服务探测目录扫描反弹shell建立代理 内网探测漏洞发现漏洞利用 权限提升总结 实验环境 靶机地址: https://www.vulnhub.com/entry/boredhackerblog-social-network,454/靶机ip:192.168.56.101 攻击机&#x…

银河麒麟系统挂载的home文件夹无执行权限

银河麒麟系统挂载的home文件夹里放可执行程序,脚本无法运行,最后修改/etc/fstab文件如下所示就可以了 修改完重启电脑就可以执行可执行程序了

Thymeleaf + Layui+快速分页模板(含前后端代码)

发现很多模块写法逻辑太多重复的&#xff0c;因此把分页方法抽取出来记录以下&#xff0c;以后想写分页直接拿来用即可&#xff1a; 1. 首先是queryQrEx.html&#xff1a; <!DOCTYPE html> <html xmlns:th"http://www.w3.org/1999/xhtml"> <head>…

专业信用修复!揭阳市企业信用修复办法,企业修复好处及失信危害,申请条件

本文小编将为大家介绍2023年揭阳市企业信用修复办法指导及信用修复意义、好处等内容&#xff0c;详情如下&#xff0c;如果有广州市、深圳市、江门市、佛山市、汕头市、湛江市、韶关市、中山市、珠海市、茂名市、肇庆市、阳江市、惠州市、潮州市、揭阳市、清远市、河源市、东莞…

Nautlius Chain主网正式上线,模块Layer3时代正式开启

Nautilus Chain是在Vitalik Buterin提出Layer3理念后&#xff0c; 对Layer3领域的全新探索。作为行业内首个模块化Layer3链&#xff0c;我们正在对Layer3架构进行早期的定义&#xff0c;并有望进一步打破公链赛道未来长期的发展格局。 在今年年初&#xff0c;经过我们一系列紧张…

RTD2555T RTD2556T(Typec) eDP屏显示介绍

RTD2555T是新一代HDMI2DP转eDP的IC&#xff0c;主要应该用Typec便携式显示器驱动芯片&#xff0c;搭配LDR6282等PD IC实现2个Typec口正反插&#xff0c;充电等&#xff0c;支持按键菜单操作&#xff0c;支持串口通信控制等功能定制。

【stable diffusion】保姆级入门课程01-Stable diffusion(SD)文生图究竟是怎么一回事

目录 学前视频 0.本章素材 1.什么是文生图 2.界面介绍 2.1切换模型的地方 2.2切换VAE 2.3功能栏 2.4提示词 1.提示词的词性 2.提示词的语法 3.提示词的组成 4.提示词的权重调整 2.5参数调整栏 1.采样方法 2.采样迭代步数 3.面部修复 4.平铺图 5.高清修复 6.…

数据中心机房建设,务必确定这13个关键点

下午好&#xff0c;我的网工朋友。 关于机房、机架的相关内容&#xff0c;给你们说了不少。 今天再给你补充个知识点&#xff0c;机房建设&#xff0c;要怎么做。 熟悉机房建设的网工朋友可能都知道&#xff0c;一个全面的数据中心机房建设工程一般包括&#xff1a; 综合布…

多线程——互斥和同步

多线程—互斥和同步 文章目录 多线程—互斥和同步多线程互斥互斥量mutex互斥量的接口初始化互斥量静态分配动态分配&#xff1a;pthread_mutex_init初始化互斥量 销毁互斥量int pthread_mutex_destroy销毁互斥量 互斥量加锁和解锁pthread_mutex_lock加锁pthread_mutex_trylock非…

IPv4 与 IPv6:网络性能和带宽的比较

网络连接已经成为我们生活中不可或缺的一部分&#xff0c;而IP地址是网络连接中最基本和最重要的部分。IPv4和IPv6是两种常用的IP地址协议&#xff0c;它们之间有着很大的差异。 首先&#xff0c;让我们了解一下IPv4和IPv6的基本概念。IPv4是互联网上使用最广泛的IP地址协议&am…

(栈队列堆) 剑指 Offer 30. 包含min函数的栈 ——【Leetcode每日一题】

❓ 剑指 Offer 30. 包含min函数的栈 难度&#xff1a;简单 定义栈的数据结构&#xff0c;请在该类型中实现一个能够得到栈的最小元素的 min 函数在该栈中&#xff0c;调用 min、push 及 pop 的时间复杂度都是 O ( 1 ) O(1) O(1)。 示例: MinStack minStack new MinStack()…

Hadoop——HDFS的Java API操作(文件上传、下载、删除等)

1、创建Maven项目 2、修改pom.xml文件 <dependencies><!-- Hadoop所需依赖包 --><dependency><groupId>org.apache.hadoop</groupId><artifactId>hadoop-common</artifactId><version>2.7.0</version></dependency&…

linux之Ubuntu系列(四)用户管理 用户和权限 chmod 超级用户root, R、W、X、T、S

r(Read&#xff0c;读取)&#xff1a;对文件而言&#xff0c;具有读取文件内容的权限&#xff1b;对目录来说&#xff0c;具有浏览目 录的权限。 w(Write,写入)&#xff1a;对文件而言&#xff0c;具有新增、修改文件内容的权限&#xff1b;对目录来说&#xff0c;具有删除、移…