经典卷积神经网络 - NIN

news2024/9/21 19:01:46

网络中的网络,NIN。

AlexNet和VGG都是先由卷积层构成的模块充分抽取空间特征,再由全连接层构成的模块来输出分类结果。但是其中的全连接层的参数量过于巨大,因此NiN提出用1*1卷积代替全连接层,串联多个由卷积层和“全连接”层构成的小网络来构建⼀个深层网络。

AlexNet和VGG对LeNet的改进主要在于如何扩大和加深这两个模块。
或者,可以想象在这个过程的早期使用全连接层。然而,如果使用了全连接层,可能会完全放弃表征的空间结构。

网络中的网络NiN)提供了一个非常简单的解决方案:在每个像素的通道上分别使用多层感知机。也就是使用了多个1*1的卷积核。同时他认为全连接层占据了大量的内存,所以整个网络结构中没有使用全连接层。

NIN块

image-20231023194711049

一个卷积层后跟两个全连接层。

  • 步幅为1,无填充,输出形状跟卷积层输出一样。
  • 起到全连接层的作用。

NIN网络结构
在这里插入图片描述
image-20231024090401226

  • 无全连接层

  • 交替使用NIN块和步幅为2的最大池化层

    逐步减小高宽和增大通道数

  • 最后使用全局平均池化层得到输出

    其输入通道数是类别数

此网络结构总计4层: 3mlpconv + 1global_average_pooling

优点:

  1. 提供了网络层间映射的一种新可能;
  2. 增加了网络卷积层的非线性能力。

总结:

  • NIN块使用卷积层加上个 1 × 1 1\times 1 1×1卷积,后者对每个像素增加了非线性性
  • NIN使用全局平均池化层来替代VGG和AlexNet中的全连接层,不容易过拟合,更少的参数个数

代码实现

使用CIFAR-10数据集。

maxpooling不改变通道数,只改变长和宽

model.py

import torch
from torch import nn

# nin块
def nin_block(in_channels,out_channels,kernel_size,strides,padding):
    return nn.Sequential(
        nn.Conv2d(in_channels,out_channels,kernel_size,strides,padding),
        nn.ReLU(),
        nn.Conv2d(out_channels,out_channels,kernel_size=1),
        nn.ReLU(),
        nn.Conv2d(out_channels,out_channels,kernel_size=1),
        nn.ReLU(),
    )

# 构建网络
class NIN(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.model = nn.Sequential(
            nin_block(3,96,kernel_size=11,strides=4,padding=0),
            nn.MaxPool2d(3,stride=2),
            nin_block(96,256,kernel_size=5,strides=1,padding=2),
            nn.MaxPool2d(3,stride=2),
            nin_block(256,384,kernel_size=3,strides=1,padding=1),
            nn.MaxPool2d(3,stride=2),
            nn.Dropout(0.5),
            nin_block(384,10,kernel_size=3,strides=1,padding=1),
            nn.AdaptiveAvgPool2d((1,1)),
            nn.Flatten()
        )

    def forward(self,x):
        return self.model(x)


# 验证模型正确性
if __name__ == '__main__':
    nin = NIN()
    x = torch.ones((64,3,244,244))
    output = nin(x)
    print(output)

train.py

import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets
from torchvision.transforms import transforms
from model import NIN

# 扫描数据次数
epochs = 3
# 分组大小
batch = 64
# 学习率
learning_rate = 0.01
# 训练次数
train_step = 0
# 测试次数
test_step = 0


# 定义图像转换
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor()
])
# 读取数据
train_dataset = datasets.CIFAR10(root="./dataset",train=True,transform=transform,download=True)
test_dataset = datasets.CIFAR10(root="./dataset",train=False,transform=transform,download=True)
# 加载数据
train_dataloader = DataLoader(train_dataset,batch_size=batch,shuffle=True,num_workers=0)
test_dataloader = DataLoader(test_dataset,batch_size=batch,shuffle=True,num_workers=0)
# 数据大小
train_size = len(train_dataset)
test_size = len(test_dataset)
print("训练集大小:{}".format(train_size))
print("验证集大小:{}".format(test_size))

# GPU
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(device)
# 创建网络
net = NIN()
net = net.to(device)
# 定义损失函数
loss = nn.CrossEntropyLoss()
loss = loss.to(device)
# 定义优化器
optimizer = torch.optim.SGD(net.parameters(),lr=learning_rate)

writer = SummaryWriter("logs")
# 训练
for epoch in range(epochs):
    print("-------------------第 {} 轮训练开始-------------------".format(epoch))
    net.train()
    for data in train_dataloader:
        train_step = train_step + 1
        images,targets = data
        images = images.to(device)
        targets = targets.to(device)
        outputs = net(images)
        loss_out = loss(outputs,targets)
        optimizer.zero_grad()
        loss_out.backward()
        optimizer.step()

        if train_step%100==0:
            writer.add_scalar("Train Loss",scalar_value=loss_out.item(),global_step=train_step)
            print("训练次数:{},Loss:{}".format(train_step,loss_out.item()))

    # 测试
    net.eval()
    total_loss = 0
    total_accuracy = 0
    with torch.no_grad():
        for data in test_dataloader:
            test_step = test_step + 1
            images, targets = data
            images = images.to(device)
            targets = targets.to(device)
            outputs = net(images)
            loss_out = loss(outputs, targets)
            total_loss = total_loss + loss_out
            accuracy = (targets == torch.argmax(outputs,dim=1)).sum()
            total_accuracy = total_accuracy + accuracy
        # 计算精确率
        print(total_accuracy)
        accuracy_rate = total_accuracy / test_size

        print("第 {} 轮,验证集总损失为:{}".format(epoch+1,total_loss))
        print("第 {} 轮,精确率为:{}".format(epoch+1,accuracy_rate))
        writer.add_scalar("Test Total Loss",scalar_value=total_loss,global_step=epoch+1)
        writer.add_scalar("Accuracy Rate",scalar_value=accuracy_rate,global_step=epoch+1)
    torch.save(net,"./model/net_{}.pth".format(epoch+1))
    print("模型net_{}.pth已保存".format(epoch+1))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1132443.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C语言:杨氏矩阵、杨氏三角、单身狗1与单身狗2

下面介绍四道题目和解法 1.杨氏矩阵 算法:右上角计算 题目:有一个数字矩阵,矩阵的每行从左到右是递增的,矩阵从上到下是递增的,请编写程序在这样的矩阵中查找某个数字是否存在。 要求:时间复杂度小于O(N…

react笔记基础部分(组件生命周期路由)

注意点&#xff1a; class是一个关键字&#xff0c; 类。 所以react 写class, 用classname &#xff0c;会自动编译替换class 点击方法&#xff1a; <button onClick {this.sendData}>给父元素传值</button>常用的插件&#xff1a; 需要引入才能使用的&#xf…

ubuntu执行普通用户或root用户执行apt-get update时报错Couldn‘t create temporary file /tmp/...

apt-get update无法更新&#xff0c;报错&#xff1a; Couldnt create temporary file /tmp/apt.conf.GSzv74 for passing config to&#xff0c;&#xff0c;&#xff0c; 这是由于/tmp目录没有权限导致的&#xff0c;解决办法&#xff1a; chmod 777 /tmp

额定电压输出电流:电源性能测试指标之一

额定电压和额定电流是电源设计生产时需要考虑的两个重要参数&#xff0c;额定电压是电源输出的电压标准&#xff0c;额定电流是电源能够提供的最大电流容量。这两个参数是评估电源性能的重要指标之一&#xff0c;指导着电气设备的正常工作运行。 额定电压输出电流测试方法 额定…

上门家政维修多城市代理多商户师傅入驻小程序开源版开发

上门家政维修多城市代理多商户师傅入驻小程序开源版开发 用户登录/注册&#xff1a;用户可以使用手机号或第三方账号登录或注册小程序。 服务分类&#xff1a;在主页上显示不同的服务分类&#xff0c;例如电器维修、家具拆装、管道疏通、清洁保洁等。 城市选择&#xff1a;用…

C++反转链表递归

文章目录 题目描述解题思路代码复杂度分析 题目描述 LCR 024. 反转链表 - 力扣&#xff08;LeetCode&#xff09; 给定单链表的头节点 head &#xff0c;请反转链表&#xff0c;并返回反转后的链表的头节点。 解题思路 这里我们采用递归的思路来解决首先我们分为两个视角来查看…

竞赛选题 深度学习卫星遥感图像检测与识别 -opencv python 目标检测

文章目录 0 前言1 课题背景2 实现效果3 Yolov5算法4 数据处理和训练5 最后 0 前言 &#x1f525; 优质竞赛项目系列&#xff0c;今天要分享的是 &#x1f6a9; **深度学习卫星遥感图像检测与识别 ** 该项目较为新颖&#xff0c;适合作为竞赛课题方向&#xff0c;学长非常推荐…

超好用的数据可视化工具推荐,小白也适用!

Excel、Tableau……可以做数据可视化的工具不少&#xff0c;但简单、好用又高效&#xff0c;甚至连无SQL基础的小白也能轻松使用的就真没几个。奥威BI数据可视化工具是少有的操作难度低、成本支出低、灵活自助分析能力强的BI工具。 1、操作难度低 奥威BI数据可视化工具的操作…

图片放大镜效果

安装&#xff1a; vueuse 插件 npm i vueuse/core 搜索&#xff1a; useMouseInElement 方法 <template><div ref"target"><h1>Hello world</h1></div> </template><script> import { ref } from vue import { useM…

图纸管理制度《三》

一、目的和使用范围 为了更好的规范设备及设计图纸的保管、发放和使用&#xff0c;根据业主仅提供四套图纸的实际情况&#xff0c;本着施工图纸服务施工的第一原则&#xff0c;合理利用有限的图纸资源&#xff0c;将《管理制度汇编》中的图纸管理制度进行细化&#xff0c;制定本…

视频与png图片批量分类技巧:轻松管理文件

在我们的日常工作中&#xff0c;经常会遇到需要处理大量文件的情况&#xff0c;其中就包括视频和png图片。这些文件数量繁多&#xff0c;如果一个个手动分类&#xff0c;不仅耗时而且容易出错。因此&#xff0c;掌握批量分类技巧成为了高效管理文件的关键。本文将为您运用云炫文…

地面文物古迹保护方案,用科技为文物古迹撑起“智慧伞”

一、行业背景 当前&#xff0c;文物保护单位的安防系统现状存在各种管理弊端&#xff0c;安防系统没有统一的平台&#xff0c;系统功能不足、建设标准不同&#xff0c;产品和技术多样&#xff0c;导致各系统独立&#xff0c;无法联动&#xff0c;形成了“信息孤岛”。地面文物…

64从零开始学Java之关于日期时间的新特性

作者&#xff1a;孙玉昌&#xff0c;昵称【一一哥】&#xff0c;另外【壹壹哥】也是我哦 千锋教育高级教研员、CSDN博客专家、万粉博主、阿里云专家博主、掘金优质作者 前言 在上一篇文章中&#xff0c;壹哥给大家讲解了Java里的格式化问题&#xff0c;这样我们就可以个性化设…

《排错》Python重新安装后,执行yum命令报错

安装完新的python以后&#xff0c;发现yum命令没法用 以下是报错信息&#xff1a; [rootmaster ~]# yum There was a problem importing one of the Python modules required to run yum. The error leading to this problem was:No module named yumPlease install a packag…

ResNet中文翻译(Deep Residual Learning for Image Recognition)

Deep Residual Learning for Image Recognition 用于图像识别的深度残差学习 原文&#xff1a;https://arxiv.org/abs/1512.03385 摘要 更深层次的神经网络更难训练。我们提出了一个残差学习框架&#xff0c;以简化比以前使用的网络更深的网络训练。我们明确地将层重新表示为参…

考研专业课程管理系统 JAVA开源项目 毕业设计

1. 项目介绍 基于JAVAVueSpringBootMySQL 的考研专业课程管理系统&#xff0c;包含了考研课程、考研专业、考研注册、考研院校和高校教师模块&#xff0c;还包含系统自带的用户管理、部门管理、角色管理、菜单管理、日志管理、数据字典管理、文件管理、图表展示等基础模块&…

React 图片瀑布流

思路&#xff1a; 根据浏览器宽度&#xff0c;确定列数&#xff0c;请求的图片列表数据是列数的10倍&#xff0c;按列数取数据渲染 Index.js: import React from react import { connect } from react-redux import { withRouter } from react-router-dom import { SinglePag…

22 行为型模式-状态模式

1 状态模式介绍 2 状态模式结构 3 状态模式实现 代码示例 //抽象状态接口 public interface State {//声明抽象方法,不同具体状态类可以有不同实现void handle(Context context); }

PyQt5入门4——给目标检测算法构建一个简单的界面

PyQt5入门4——给目标检测算法构建一个简单的界面 学习前言要构建怎么样的界面实例使用1、窗口构建a、按钮&#xff1a;获取图片b、Inputs、Outputs文本提示c、Inputs、Outputs图片显示d、箭头显示e、整点祝福 2、主程序运行 全部代码 学习前言 搞搞可视化界面哈&#xff0c;虽…

外汇天眼:CySEC向塞浦路斯投资公司的董事会成员发出警告

塞浦路斯证券与交易委员会&#xff08;CySEC&#xff09;已警告塞浦路斯投资公司&#xff08;CIFs&#xff09;的董事会成员&#xff0c;提醒他们加强履行职责&#xff0c;推动诚信和高道德标准的文化&#xff0c;此前监管行动揭示了合规方面的差距。 CySEC已经在加强监管措施…