pytorch-16 复现经典网络:LeNet5与AlexNet

news2024/12/28 22:08:37

一、相关概念

对于(10,3,227,227)数据表示,10张3通道的图,图的大小(特征数)为227*227.
通道数:作为卷积的输入通道数和输出通道数。
特征数:特征图的大小
步长stride和填充padding:线性减小特征图的尺寸
池化pooling:非线性且高效减小特征图的尺寸
计算公式:hout = (hin +2p -k) /s +1

二、LeNet5:现代CNN的奠基者

LeNet的核心思想“卷积+池化+线性”。在PyTorch中实现其架构的代码如下:在这里插入图片描述

import torch
from torch import nn
from torch.nn import functional as F
from torchinfo import summary

data = torch.ones(size=(10,1,32,32))

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1,6,5) #(H+2p-K)/S + 1
        self.pool1 = nn.AvgPool2d(kernel_size=2,stride=2)
        self.conv2 = nn.Conv2d(6,16,5)
        self.pool2 = nn.AvgPool2d(2)
        self.fc1 = nn.Linear(5*5*16,120)
        self.fc2 = nn.Linear(120,84)
    
    def forward(self,x):
        x = F.tanh(self.conv1(x))
        x = self.pool1(x)
        x = F.tanh(self.conv2(x))
        x = self.pool2(x)
        x = x.view(-1,5*5*16) #-1,我不关心-1这个位置上的数是多少,你根据我输入的x的结构帮我自己计算吧
        x = F.tanh(self.fc1(x))
        output = F.softmax(self.fc2(x),dim=1) #(samples, features)

net = Model() #实例化
net(data)  #相当于在执行 net.forward(data) 

net = Model() #实例化
summary(net, input_size=(10,1,32,32))

结果显示:

==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
Model                                    --                        --
├─Conv2d: 1-1                            [10, 6, 28, 28]           156
├─AvgPool2d: 1-2                         [10, 6, 14, 14]           --
├─Conv2d: 1-3                            [10, 16, 10, 10]          2,416
├─AvgPool2d: 1-4                         [10, 16, 5, 5]            --
├─Linear: 1-5                            [10, 120]                 48,120
├─Linear: 1-6                            [10, 84]                  10,164
==========================================================================================
Total params: 60,856
Trainable params: 60,856
Non-trainable params: 0
Total mult-adds (M): 4.22
==========================================================================================
Input size (MB): 0.04
Forward/backward pass size (MB): 0.52
Params size (MB): 0.24
Estimated Total Size (MB): 0.81
==========================================================================================

三、AlexNet:从浅层到深度

AlexNet的架构若用文字来表现,则可以
打包成4个组合:
输入→(卷积+池化)→(卷积+池化)→(卷积x3+池化)→(线性x3)→输出
相对的,LeNet5的架构可以打包成3个组合:
输入→(卷积+池化)→(卷积+池化)→(线性x2)→输出在这里插入图片描述
和只有6层(包括池化层)的LeNet5比起来,AlexNet主要做出了如下改变:
1、相比之下,卷积核更小、网络更深、通道数更多,这代表人们已经认识到了图像数据天生适合于多次
提取特征,“深度”才是卷积网络的未来。LeNet5是基于MNIST数据集创造,MNIST数据集中的图片尺寸
大约只有30*30的大小,LeNet5采用了5x5的卷积核,图像尺寸/核尺寸大约在6:1。而基于ImageNet
数据集训练的AlexNet最大的卷积核只有11x11,且在第二个卷积层就改用5x5,剩下的层中都使用3x3
的卷积核,图像尺寸/核尺寸至少也超过20:1。小卷积核让网络更深,但也让特征图的尺寸变得很小,
为了让信息尽可能地被捕获,AlexNet也使用了更多的通道。小卷积核、多通道、更深的网络,这些都
成为了卷积神经网络后续发展的指导方向。
2、使用了ReLU激活函数,摆脱Sigmoid与Tanh的各种问题。
3、使用了Dropout层来控制模型复杂度,控制过拟合。
4、引入了大量传统或新兴的图像增强技术来扩大数据集,进一步缓解过拟合。
5、使用GPU对网络进行训练,使得“适当的训练“(proper training)成为可能。

1、AlexNet的架构复现

在PyTorch中来复现AlexNet的架构:

import torch
from torch import nn
from torch.nn import functional as F

data = torch.ones(size=(10,3,227,227)) #224 x 224

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        
        # 大卷积核、较大的步长、较多的通道
        # 为了处理尺寸较大的原始图片,先使用11x11的卷积核和较大的步长来快速降低特征图的尺寸
        # 同时,使用比较多的通道数,来弥补降低尺寸造成的数据损失
        self.conv1 = nn.Conv2d(3,96, kernel_size=11, stride=4)
        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2) #overlap pooling
        
        #卷积核、步长恢复正常大小,进一步扩大通道
        # 已经将特征图尺寸缩小到27x27,计算量可控,可以开始进行特征提取了
        # 卷积核、步长恢复到业界常用的大小,进一步扩大通道来提取数据
        self.conv2 = nn.Conv2d(96,256,kernel_size=5,padding=2)
        self.pool2 = nn.MaxPool2d(kernel_size=3, stride=2)
        
        # 疯狂提取特征,连续用多个卷积层
        # kernel 5, padding 2, kernel 3, padding 1 可以维持住特征图的大小
        self.conv3 = nn.Conv2d(256,384,kernel_size=3, padding =1) 
        self.conv4 = nn.Conv2d(384,384,kernel_size=3, padding =1)
        self.conv5 = nn.Conv2d(384,256,kernel_size=3, padding =1)
        self.pool3 = nn.MaxPool2d(kernel_size=3, stride=2)
        
        # 进入全连接层,进行信息汇总
        self.fc1 = nn.Linear(6*6*256,4096) #上层所有特征图上的所有像素
        self.fc2 = nn.Linear(4096,4096)
        self.fc3 = nn.Linear(4096,1000)
    
    def forward(self,x):
        x = F.relu(self.conv1(x))
        x = self.pool1(x)

        x = F.relu(self.conv2(x))
        x = self.pool2(x)

        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = F.relu(self.conv5(x))
        x = self.pool3(x)

        x = x.view(-1,6*6*256)  #将数据拉平

        x = F.dropout(x,p=0.5)                      #dropout:随机让50%的权重为0
        x = F.relu(F.dropout(self.fc1(x),p=0.5))    #dropout:随机让50%的权重为0 
        x = F.relu(self.fc2(x))
        output = F.softmax(self.fc3(x),dim=1)

net = Model()
net(data)

from torchinfo import summary
summary(net,input_size=(10,3,227,227))

结果显示:

==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
Model                                    --                        --
├─Conv2d: 1-1                            [10, 96, 55, 55]          34,944
├─MaxPool2d: 1-2                         [10, 96, 27, 27]          --
├─Conv2d: 1-3                            [10, 256, 27, 27]         614,656
├─MaxPool2d: 1-4                         [10, 256, 13, 13]         --
├─Conv2d: 1-5                            [10, 384, 13, 13]         885,120
├─Conv2d: 1-6                            [10, 384, 13, 13]         1,327,488
├─Conv2d: 1-7                            [10, 256, 13, 13]         884,992
├─MaxPool2d: 1-8                         [10, 256, 6, 6]           --
├─Linear: 1-9                            [10, 4096]                37,752,832
├─Linear: 1-10                           [10, 4096]                16,781,312
├─Linear: 1-11                           [10, 1000]                4,097,000
==========================================================================================
Total params: 62,378,344
Trainable params: 62,378,344
Non-trainable params: 0
Total mult-adds (G): 11.36
==========================================================================================
Input size (MB): 6.18
Forward/backward pass size (MB): 52.74
Params size (MB): 249.51
Estimated Total Size (MB): 308.44
==========================================================================================

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1710998.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

文章结尾,铺垫下一章带来的期待

你是否容易在阅读时打瞌睡? 是否有很多买回来的书,放在书架上一年甚至几年都未读完,积满了灰尘? 但是,对于小说和电视剧,你却完全停不下来。每集片尾的预告激发了你持续观看下一集的渴望,带来了无限的期待…… 当你撰写文章或编写工具书时,内容可能呈现出乏味的面貌…

二叉树习题精讲-单值二叉树

单值二叉树 965. 单值二叉树 - 力扣(LeetCode)https://leetcode.cn/problems/univalued-binary-tree/description/ 判断这里面的所有数值是不是一样 方案1:遍历 方案2:拆分子问题 /*** Definition for a binary tree node.* struc…

数据库自动化管理的六大等级

什么是数据库自动化管理? 数据库自动化管理是指通过使用工具和流程,在尽量减少人为干预的情况下,管理和执行与数据库相关的任务。主要目的当然是提高效率,减少人为错误,确保一致性,并解放 DBA 和开发者&am…

系统思考—决策

风险来自于你不知道你在做什么。——沃伦巴菲特 今天和一个合作伙伴的创始人交流,她提出了一个引人深思的问题:“策略性陪伴和战略复盘,什么原因不由客户自己来做?”这个问题让我深入思考了第三方策略性陪伴顾问的独特价值和重要…

《征服数据结构》块状链表

摘要: 1,块状链表的介绍 2,块状链表的代码实现(Java和C) 1,块状链表的介绍 前面我们讲过数组和链表,数组具有 O(1)的查询时间,O(N)的删除,O(N)的插入,而链表具…

java 对接农行支付相关业务(二)

文章目录 农行掌银集成第三方APP1:掌银支付对接快e通的流程1.1 在农行网站上注册我们的app信息([网址](https://openbank.abchina.com/Portal/index/index.html))1.2:java整合农行的jar包依赖1.3:把相关配置信息整合到项目中1.4:前端获取授权码信息1.5:后端根据授权码信…

Unity【入门】环境搭建、界面基础、工作原理

Unity环境搭建、界面基础、工作原理 Unity环境搭建 文章目录 Unity环境搭建1、Unity引擎概念1、什么是游戏引擎2、游戏引擎对于我们的意义3、如何学习游戏引擎 2、软件下载和安装3、新工程和工程文件夹 Unity界面基础1、Scene场景和Hierarchy层级窗口1、窗口布局2、Hierarchy层…

Spring Cloud Alibaba-06-Sleuth链路追踪

Lison <dreamlison163.com>, v1.0.0, 2024.4.03 Spring Cloud Alibaba-06-Sleuth链路追踪 文章目录 Spring Cloud Alibaba-06-Sleuth链路追踪为什么使用链路追踪常见链路追踪解决方案Sleuth概述概述Sleuth术语 Sleuth Zipkin 原理Sleuth原理简述Zipkin 原理简述 Sleut…

剪画小程序:自媒体创作的第一步:如何将视频中的文案提取出来?

自媒体创作第一步&#xff0c;文案提取无疑是至关重要的一环。 做自媒体之所以要进行文案提取&#xff0c;有以下重要原因&#xff1a; 首先&#xff0c;提高效率。通过文案提取&#xff0c;可以快速获取关键信息&#xff0c;避免在无关紧要的内容上浪费时间&#xff0c;从而…

OpenEuler安装

1.下载镜像文件 2.新建虚拟机 版本选替他linux5.x内核64位 内存选4G 磁盘大小选40GB 内存和磁盘大小不能按默认&#xff0c;不然会很卡甚至没反应 优先使用英语 安装目的地一开始会有警告标志&#xff0c;点进去点完成 输入密码时不能太短还要保证拥有至少三种字符类型 等待安…

【数据结构】AVL树——平衡二叉搜索树

个人主页&#xff1a;东洛的克莱斯韦克-CSDN博客 祝福语&#xff1a;愿你拥抱自由的风 目录 二叉搜索树 AVL树概述 平衡因子 旋转情况分类 左单旋 右单旋 左右双旋 右左双旋 AVL树节点设计 AVL树设计 详解单旋 左单旋 右单旋 详解双旋 左右双旋 平衡因子情况如…

基于微信小程序+ JAVA后端实现的【微信小程序跑腿平台】设计与实现 (内附设计LW + PPT+ 源码+ 演示视频 下载)

项目名称 项目名称&#xff1a; 《微信小程序跑腿平台的设计与实现》 项目技术栈 该项目采用了以下核心技术栈&#xff1a; 后端框架/库&#xff1a; Java, SSM框架数据库&#xff1a; MySQL前端技术&#xff1a; 微信小程序, HTML…&#xff08;其它相关技术&#xff09; …

.BFS.

BFS &#xff08;Breadth-First Search&#xff09;是一种用于遍历或搜索树&#xff08;tree&#xff09;或图&#xff08;graph&#xff09;的算法。 这个算法从根&#xff08;或某个任意节点&#xff09;开始&#xff0c;并探索最近的邻居节点&#xff0c; 然后再探索那些节点…

adb的常见操作和命令

最近学习adb的时候&#xff0c;整理了一些adb的使用场景&#xff0c;如&#xff1a;adb与设备交互&#xff0c;adb的安装、卸载&#xff0c;adb命令启动&#xff0c;通过命令清除缓存&#xff0c;文件传输和日志操作。 adb的两大作用&#xff1a;在app测试的时候可以提供监控日…

如何高效测试防火墙的NAT64与ALG应用协议转换能力

在本文开始介绍如何去验证防火墙&#xff08;DUT&#xff09;支持NAT64 ALG应用协议转换能力之前&#xff0c;我们先要简单了解2个比较重要的知识点&#xff0c;即&#xff0c;NAT64和ALG这两个家伙到底是什么&#xff1f; 网络世界中的“翻译官” - NAT64技术 简而言之&…

【Linux安全】iptables防火墙(二)

目录 一.iptables规则的保存 1.保存规则 2.还原规则 3.保存为默认规则 二.SNAT的策略及应用 1.SNAT策略的典型应用环境 2.SNAT策略的原理 2.1.未进行SNAT转换后的情况 2.2.进行SNAT转换后的情况 3.SNAT策略的应用 3.1.前提条件 3.2.实现方法 三.DNAT策略及应用 1…

学习笔记——数据通信基础——数据通信网络(网络工程师)

网络工程师 网络工程&#xff0c;就是围绕着网络进行的一系列的活动&#xff0c;包括∶网络规划、设计、实施、调试、排错等。网络工程设计的知识领域很宽广&#xff0c;其中路由和交换是计算机网络的基本。 网络工程师∶是在网络工程领域&#xff0c;掌握专业的网络技术&…

Go 使用 RabbitMQ---------------之一

RabbitMQ 是一种消息代理。消息代理的主要目的是接收、存储并转发消息。在复杂的系统设计和微服务架构中,RabbitMQ 经常被用作中间件来处理和转发系统之间的消息,以确保数据的一致性和可靠性。正是因为提供了可靠的消息机制、跟踪机制和灵活的消息路由,常常被用于排队算法、…

SAP PP学习笔记 - 错误 CX_SLD_API_EXCEPTION - Job dump is not fully saved (too big)

我这个错误是跑完MRP&#xff0c;然后在MD04查看在库/所有量一览&#xff0c; 点计划手配&#xff08;Planned order 计划订单&#xff09;生成 制造指图&#xff08;Production order 生产订单&#xff09;&#xff0c; 到目前这几步都OK&#xff0c;然后在制造指图界面点保…

View->Bitmap缩放到自定义ViewGroup的任意区域

Bitmap缩放和平移 加载一张Bitmap可能为宽高相同的正方形&#xff0c;也可能为宽高不同的矩形缩放方向可以为中心缩放&#xff0c;左上角缩放&#xff0c;右上角缩放&#xff0c;左下角缩放&#xff0c;右下角缩放Bitmap中心缩放&#xff0c;包含了缩放和平移两个操作&#xf…