NiN详解

news2024/11/22 16:11:08

入门小菜鸟,希望像做笔记记录自己学的东西,也希望能帮助到同样入门的人,更希望大佬们帮忙纠错啦~侵权立删。

✨完整代码在我的github上,有需要的朋友可以康康✨

https://github.com/tt-s-t/Deep-Learning.git

目录

一、NiN网络的背景

二、NiN网络结构

1、NiN块

2、NiN网络架构

三、NiN的亮点

1、使用1*1卷积代替全连接层

2、使用全局平均池化代替最后的全连接层

四、NiN网络的缺点

五、NiN代码实现对FashionMNIST数据的分类


一、NiN网络的背景

AlexNet和VGG都是先由卷积层构成的模块充分抽取空间特征,再由全连接层构成的模块来输出分类结果。但是其中的全连接层的参数量过于巨大,因此NiN提出用1*1卷积代替全连接层,串联多个由卷积层和“全连接”层构成的小网络来构建⼀个深层网络。


二、NiN网络结构

1、NiN块

这里的1*1卷积(即mlpconv结构)就是为了代替全连接层。

 🌳可代替的原因🌳

我们知道全连接层的原理公式是:Y=XW+B,其中X\in R^{n*d}, W\in R^{d*1}

我们现在假设有一个待卷积的结果Z,shape为(N,C,H,W),1*1卷积对应的shape为(1,1,C),这时的1*1卷积相当于W。

这里的Z相当于X,每一个样本(C,H,W)就相当于是n=H*W,d=C,即相当于排成了(H*W,C)后再和W相乘。每个通道下相同位置的像素都和卷积核W(d=C)相乘,即完成了XW,默认B=0。

这么替代的好处是:

(1)灵活放缩通道数:通过控制卷积核的数量达到通道数的放缩。

(2)增加非线性。1×1卷积核的卷积过程相当于全连接层的计算过程,并且还加入了非线性函数,从而可以增加网络的非线性。

(3)计算参数少(简化模型)

(4)不改变图像空间结构

全连接层会破坏图像的空间结构(要先展平),而1*1卷积层不会破坏图像的空间结构。

(5)输入可以是任意尺寸

全连接层的输入尺寸是固定的,因为全连接层的参数个数取决于图像大小。而卷积层的输入尺寸是任意的,因为卷积核的参数个数与图像大小无关。

2、NiN网络架构

由多个NiN块按需堆起来后进行全局池化,最后再得到每个归属类所得的分数。


三、NiN的亮点

1、使用1*1卷积代替全连接层

这么做的优点如上所述

2、使用全局平均池化代替最后的全连接层

这也是为了解决全连接层参数过多的问题。

对于分类问题,在之前通常的解决方法是:在最后一个卷积层的feature map和全连接层连接,最后通过softmax进行分类。但全连接层带来的问题就是参数空间过大,容易过拟合。早期AlexNet采用了Dropout来减轻过拟合,提高网络的泛化能力,但依旧无法解决参数过多的问题。

而全局平均池化的做法是将全连接层去掉,在最后一层,将卷积层数目设为与类别数目一致,然后全局pooling, 从而直接输出属于各个类的结果分数。

优势:

(1)全局平均池化更原生地支持于卷积结构,通过加强特征映射与相应分类的对应关系,特征映射可以更容易解释为分类映射。

(2)全局平均池化一层没有需要优化的参数,减少大量的训练参数有效避免过拟合,因此对输入的空间转换具有更强的鲁棒性


四、NiN网络的缺点

全局平均池化对特征图简单地进行加权取平均操作可能会丢失一些有用信息


五、NiN代码实现对FashionMNIST数据的分类

代码也可以查看https://github.com/tt-s-t/Deep-Learning.git中的NiN文件夹

这里展示模型的搭建

import torch.nn as nn

def nin_block(in_channels, out_channels, kernel_size, stride, padding):
    block = nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding),
        nn.ReLU(),
        nn.Conv2d(out_channels, out_channels, kernel_size=1),
        nn.ReLU(),
        nn.Conv2d(out_channels, out_channels, kernel_size=1),
        nn.ReLU())
    return block

class NiN(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nin_block(1, 96, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(kernel_size=3, stride=2,padding=1),
            nn.Dropout(0.5),
            nin_block(96, 256, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(kernel_size=3, stride=2,padding=1),
            nn.Dropout(0.5),
            # 标签类别数是10
            nin_block(256, 10, kernel_size=3, stride=1, padding=1),
            #全局平均代替最后的全连接层
            nn.AdaptiveAvgPool2d((1,1))
            )
    
    def forward(self,input):
        x = self.net(input)
        x = x.view(x.size(0), 10)
        #print(x.shape)
        return x

调用网络进行训练与测试

import torch
import torch.nn as nn
import torchvision
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from nin import NiN

data_train = torchvision.datasets.FashionMNIST(
        root="FashionMNIST", train=True,transform=transforms.Compose([
        transforms.Resize((32,32)),
        transforms.ToTensor()
    ]), download=True)
data_test = torchvision.datasets.FashionMNIST(
    root="FashionMNIST", train=False, transform=transforms.Compose([
        transforms.Resize((32,32)),
        transforms.ToTensor()
    ]), download=True)
data_train_loader = DataLoader(data_train, batch_size=32, shuffle=True, num_workers=4)#数据加载器加载训练数据
data_test_loader = DataLoader(data_test, batch_size=16, num_workers=4)#数据加载器加载测试数据

model = NiN()

device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
model.to(device)

# config
epochs = 12#迭代次数
lr = 0.0001#学习率

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)

def train():
    print('start training')
    # 训练模型
    for epoch in range(epochs):
        model.train()#训练模式
        epoch_loss = 0
        epoch_accuracy = 0
        for _, (data, label) in enumerate(data_train_loader):
            data = data.to(device)
            label = label.to(device)
            output = model(data)#输出
            loss = criterion(output, label)#计算loss

            optimizer.zero_grad()#清空过往梯度(因为每次循环都是一次完整的训练)
            loss.backward()#反向传播
            optimizer.step()#更新参数

            acc = (output.argmax(dim=1) == label).float().mean()
            epoch_accuracy += acc / len(data_train_loader)#当前训练平均准确率
            epoch_loss += loss / len(data_train_loader)#累计loss

        print(f'EPOCH:{epoch:2}, train loss:{epoch_loss:.4f}, train acc:{epoch_accuracy:.4f}')

def test():
    best_accuracy = 0
    model.eval() #加与不加都行
    total_correct = 0 #记录正确数目
    avg_loss = 0.0 #记录平均错误
    for _, (images, labels) in enumerate(data_test_loader):
        images = images.to(device)
        labels = labels.to(device)
        output = model(images)
        avg_loss += criterion(output, labels).sum() #将损失累加起来
        pred = output.detach().max(1)[1] #max(1)得到每行最大值的第一个(得到概率最大的那个),.detach()指这个tensor永远不需要计算其梯度
        total_correct += pred.eq(labels.view_as(pred)).sum() #累加与pred同类型的labels(即为正确)的数值,即记录正确分数(如果预测对了对应的位置就是1)

    avg_loss /= len(data_test) #平均误差
    if(float(total_correct) / len(data_test) > best_accuracy):
        torch.save(model.cpu().state_dict(), 'model.pth')
    best_accuracy = max(best_accuracy,float(total_correct) / len(data_test))
    print('bestaccuracy is %f' % best_accuracy)
    print('Test Avg. Loss: %f, Accuracy: %f' % (avg_loss.detach().cpu().item(), float(total_correct) / len(data_test))) #输出信息

def main(): #开始训练和测试
    train()
    test()

if __name__ == '__main__':
    main()

欢迎大家在评论区批评指正,谢谢大家~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/138923.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【C语言开源库】 一个只有500行代码的开源http服务器:Tinyhttpd学习

项目搬运,带中文翻译:https://github.com/nengm/Tinyhttpd在嵌入式中,我们HTTP服务器用得最多的就是boa还有就是goahead,但是这2个代码量比较大,而Tinyhttpd只有几百行,比较有助于我们学习。一、编译及运行直接make之后…

用Python让奇怪的想法变成现实,2023年继续创作

2023年继续写作,用文章记录生活 时间过得真快,一下就到2023年了。 由于疫情肆虐,在网络的游弋的实现也长了,写作的自然也多了。 回想一下,2018-2021年这三年时间里一篇文章也没写过为0,哈哈,没…

【EHub_tx1_tx2_E100】Ubuntu18.04 + ROS_ Melodic + NVISTAR VP300 激光雷达 评测

简介:介绍NVISTAR 的二维DTOF激光雷达 在EHub_tx1_tx2_E100载板,TX1核心模块环境(Ubuntu18.04)下测试ROS驱动,打开使用RVIZ 查看点云数据,本文的前提条件是你的TX1里已经安装了ROS版本:Melodic。…

滴滴前端一面经典手写面试题

实现bind 实现bind要做什么 返回一个函数,绑定this,传递预置参数bind返回的函数可以作为构造函数使用。故作为构造函数时应使得this失效,但是传入的参数依然有效 // mdn的实现 if (!Function.prototype.bind) {Function.prototype.bind f…

Kuberneters(2)- Pod详解

第四章 实战入门 本章节将介绍如何在kubernetes集群中部署一个nginx服务,并且能够对其进行访问。 Namespace ​ Namespace是kubernetes系统中的一种非常重要资源,它的主要作用是用来实现多套环境的资源隔离或者多租户的资源隔离。 ​ 默认情况下&…

路由跳转同一个界面,但是params不同。页面不刷新?(路由的key)

文章目录引入知识点:路由的key值思路:结论:解决方法:效果:应用场景:引入知识点:路由的key值 如果不设置路由的key值,默认情况下是根据路径判断的,就是不包括params值 例子&#xff…

MySQL5-数据类型

目录 1.数值类型(分为整型和浮点型) 2.字符串类型 3.日期类型 MySQL和Java编程一样,创建表时要考虑数据类型。 MySQL表组成:列名/列数据类型;数据。 1.数值类型(分为整型和浮点型) 数据类型…

天工开物 #4 构建一个受保护的网站

前段时间,我出于兴趣试着做了一个需要登录鉴权才能访问的个人网站,最终以 Docusaurus[1] 为内容框架,Next.js[2] 做中间件,Vercel[3] 托管网站,再加上 Auth0[4] 作为鉴权解决方案,实现了一个基本免费的方案…

数位DP入门笔记(1)HUD-2089

题目: 题目理解和思路: 1.此题是给一个6位车牌号,正着不能含有连着的62,不能有4。 2.判断车牌号可能会采用dfs,因为每增加一位数就包含带4,或者形成62两种不合法情况(事实上没有用到&#xf…

java学习day67(乐友商城)商品详情及静态化

1.商品详情 当用户搜索到商品,肯定会点击查看,就会进入商品详情页,接下来我们完成商品详情页的展示, 1.1.Thymeleaf 在商品详情页中,我们会使用到Thymeleaf来渲染页面,所以需要先了解Thymeleaf的语法。 …

带你深度剖析《数据在内存中的存储》——C语言

文章目录 一、数据类型介绍 二、整型在内存中的存储方式 2、1 原码、反码、补码的讲解 2、2 大小端介绍 2、2、1 大小端的概念 2、2、2 为什么要区分大小端存储呢? 2、2、3 大小端判断练习 三、浮点数在内存中的存储方式 3、1 浮点数在内存中的存储例题 3、2 浮点数…

TensorFlow2.0实战:Cats vs Dogs

数据集准备 在本文中,我们使用“Cats vs Dogs”的数据集。这个数据集包含了23,262张猫和狗的图像 你可能注意到了,这些照片没有归一化,它们的大小是不一样的 但是非常棒的一点是,你可以在Tensorflow Datasets中获取这个数据集 …

梦在远方路在脚下,社科院与杜兰大学金融管理硕士项目与你一路相伴

梦想是指引我们飞翔的翅膀,梦想是远方的灯塔指引着我们前进的方向。梦想距离我们很远,但路在脚下,只要朝着梦想前进,终有一天梦想会照进现实。就像拥有读研梦想的我们,在社科院杜兰金融管理硕士项目汲取能量&#xff0…

【Android OpenGL开发】OpenGL ES与EGL介绍

什么是OpenGL ES OpenGL(Open Graphics Library)是一个跨编程语言、跨平台的编程图形程序接口,主要用于图像的渲染。 Android提供了简化版的OpenGL接口,即OpenGL ES。 早先定义 OpenGL ES 是 OpenGL 的嵌入式设备版本&#xff…

Mac上超实用的6款软件,老用户都知道!

今天为大家带来的是6款超实用的Mac软件,让你不再走弯路。第一款:Amphetamine 防休眠的利器Amphetamine for mac是应用在Mac上的一款防休眠工具,可以自定义哪些程序运行时不休眠,做到自定义Mac睡眠时间,可以通过超级简单…

【数据结构】链式存储:链表(无头双向链表实现)

目录 🥇一:无头双向链表 🎒二、无头双向链表的实现 📘1.创建节点类 📒2.创建链表 📗3.打印链表 📕4.查找是否包含关键字key是否在单链表当中 📙5.得到单链表的长度 &#x1…

PCL中常用的高级采样方法

0. 简介 我们在使用PCL时候,常常不满足于常用的降采样方法,这个时候我们就想要借鉴一些比较经典的高级采样方法。这一讲我们将对常用的高级采样方法进行汇总,并进行整理,来方便读者完成使用 1. 基础下采样 1.1 点云随机下采样 …

代码随想录拓展day6 N皇后

代码随想录拓展day6 N皇后 只有这一个内容。一刷的时候也没弄太明白,二刷的时候补上。还有部分内容来自牛客网左老师的算法课程。 总体思路不容易想明白,优化也有很大难度。这要是面试能碰上基本就是故意不给过了吧。 思路 首先来看一下皇后们的约束…

Flink 容错恢复 2.0 2022 最新进展

摘要:本文整理自阿里云 Flink 存储引擎团队负责人,Apache Flink 引擎架构师 & PMC 梅源在 FFA 核心技术专场的分享。主要介绍在 2022 年度,Flink 容错 2.0 这个项目在社区和阿里云产品的进展,内容包括:Flink 容错恢…

基于ssm的个人健康管理系统

项目描述 临近学期结束,还是毕业设计,你还在做java程序网络编程,期末作业,老师的作业要求觉得大了吗?不知道毕业设计该怎么办?网页功能的数量是否太多?没有合适的类型或系统?等等。这里根据疫情当下,你想解决的问…