9、动手学深度学习——使用块的网络(VGG)

news2024/10/5 22:21:25

1、VGG块

虽然AlexNet证明深层神经网络卓有成效,但它没有提供一个通用的模板来指导后续的研究人员设计新的网络。 在下面的几个章节中,我们将介绍一些常用于设计深层神经网络的启发式概念。

与芯片设计中工程师从放置晶体管到逻辑元件再到逻辑块的过程类似,神经网络架构的设计也逐渐变得更加抽象。研究人员开始从单个神经元的角度思考问题,发展到整个层,现在又转向块,重复层的模式

使用块的想法首先出现在牛津大学的 视觉几何组(visual geometry group) 的VGG网络中。通过使用循环和子程序,可以很容易地在任何现代深度学习框架的代码中实现这些重复的架构。

经典卷积神经网络的基本组成部分是下面的这个序列:

  1. 带填充以保持分辨率的卷积层;
  2. 非线性激活函数,如ReLU;
  3. 汇聚层,如最大汇聚层。

一个VGG块与之类似,由一系列卷积层组成,后面再加上用于空间下采样的最大汇聚层。在最初的VGG论文中 (Simonyan and Zisserman, 2014),作者使用了带有卷积核、填充为1(保持高度和宽度)的卷积层,和带有汇聚窗口、步幅为2(每个块后的分辨率减半)的最大汇聚层。在下面的代码中,我们定义了一个名为vgg_block的函数来实现一个VGG块。

import torch
from torch import nn
from d2l import torch as d2l


def vgg_block(num_convs, in_channels, out_channels):
    layers = []
    for _ in range(num_convs):
        layers.append(nn.Conv2d(in_channels, out_channels,
                                kernel_size=3, padding=1))
        layers.append(nn.ReLU())
        in_channels = out_channels
    layers.append(nn.MaxPool2d(kernel_size=2,stride=2))
    # print(layers)                # 用于调试,查看每层信息
    return nn.Sequential(*layers)

该函数有三个参数,分别对应于卷积层的数量num_convs、输入通道的数量in_channels 和输出通道的数量out_channels。

2、VGG网络

与AlexNet、LeNet一样,VGG网络可以分为两部分:第一部分主要由卷积层和汇聚层组成,第二部分由全连接层组成。

在这里插入图片描述

由多个尺寸相同的卷积层和一个池化层构成一个VGG块,这个块来代替AlexNet中11x11卷积层和3x3池化层与5x5卷积层和3x3池化层这一部分,用更加模块化的方式进行替代。

VGG神经网络连接 图7.2.1的几个VGG块(在vgg_block函数中定义)。其中有超参数变量conv_arch。该变量指定了每个VGG块里卷积层个数输出通道数。全连接模块则与AlexNet中的相同。

原始VGG网络有5个卷积块,其中前两个块各有一个卷积层,后三个块各包含两个卷积层。 第一个模块有64个输出通道,每个后续模块将输出通道数量翻倍,直到该数字达到512。由于该网络使用8个卷积层和3个全连接层,因此它通常被称为VGG-11。

conv_arch = ((1, 64), (1, 128), (2, 256), (2, 512), (2, 512))

def vgg(conv_arch):
    conv_blks = []
    in_channels = 1
    # 卷积层部分
    for (num_convs, out_channels) in conv_arch:
        # 查看每层信息
        # print(f'num_convs {num_convs}, in_channels {in_channels}, out_channels {out_channels}')
        conv_blks.append(vgg_block(num_convs, in_channels, out_channels))        
        in_channels = out_channels

    return nn.Sequential(
        *conv_blks, nn.Flatten(),
        # 全连接层部分
        nn.Linear(out_channels * 7 * 7, 4096), nn.ReLU(), nn.Dropout(0.5),
        nn.Linear(4096, 4096), nn.ReLU(), nn.Dropout(0.5),
        nn.Linear(4096, 10))

net = vgg(conv_arch)
net		# 查看net结构

Sequential(
  (0): Sequential(
    (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (1): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (2): Sequential(
    (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (3): Sequential(
    (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (4): Sequential(
    (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (5): Flatten(start_dim=1, end_dim=-1)
  (6): Linear(in_features=25088, out_features=4096, bias=True)
  (7): ReLU()
  (8): Dropout(p=0.5, inplace=False)
  (9): Linear(in_features=4096, out_features=4096, bias=True)
  (10): ReLU()
  (11): Dropout(p=0.5, inplace=False)
  (12): Linear(in_features=4096, out_features=10, bias=True)
)

可以看出,主要VGG块分为三部分。
第一部分是只含有一个卷积层和一个池化层操作的层(0层和1层):这一部分会让图片的高宽减半,下一层的输出通道数变为上一层的两倍。
第二部分是含有两个卷积层和一个池化层操作的层(2层和3层):第一个卷积层让通道数翻倍,第二个卷积层相当于是进行了全连接操作,最后池化层让高宽减半。
第三部分也是含有两个卷积层和一个池化层的层:与上述第二部分不同的是,这一部分里会保持输出通道数不变,然后和第二层卷积层进行全连接操作,最后进行池化操作。

然后,与AlexNet相同,最后进行了都展开全连接层,直至类别输出预测。

接下来,我们将构建一个高度和宽度为224的单通道数据样本,以观察每个层输出的形状。

X = torch.randn(size=(1, 1, 224, 224))
for blk in net:
    X = blk(X)
    print(blk.__class__.__name__,'output shape:\t',X.shape)

Sequential output shape:	 torch.Size([1, 64, 112, 112])
Sequential output shape:	 torch.Size([1, 128, 56, 56])
Sequential output shape:	 torch.Size([1, 256, 28, 28])
Sequential output shape:	 torch.Size([1, 512, 14, 14])
Sequential output shape:	 torch.Size([1, 512, 7, 7])
Flatten output shape:	 torch.Size([1, 25088])
Linear output shape:	 torch.Size([1, 4096])
ReLU output shape:	 torch.Size([1, 4096])
Dropout output shape:	 torch.Size([1, 4096])
Linear output shape:	 torch.Size([1, 4096])
ReLU output shape:	 torch.Size([1, 4096])
Dropout output shape:	 torch.Size([1, 4096])
Linear output shape:	 torch.Size([1, 10])

正如从代码中所看到的,我们在每个块的高度和宽度减半,最终高度和宽度都为7。最后再展平表示,送入全连接层处理。

3、训练

由于VGG-11比AlexNet计算量更大,因此我们构建了一个通道数较少的网络,足够用于训练Fashion-MNIST数据集。

ratio = 4
small_conv_arch = [(pair[0], pair[1] // ratio) for pair in conv_arch]
net = vgg(small_conv_arch)

除了使用略高的学习率外,模型训练过程与 7.1节中的AlexNet类似。

lr, num_epochs, batch_size = 0.05, 10, 128
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=224)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())

loss 0.220, train acc 0.918, test acc 0.900
2578.4 examples/sec on cuda:0
在这里插入图片描述

参考文章:7.2. 使用块的网络(VGG)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/720998.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

element input,一个中文占3个字符

思路&#xff1a;标记字符的下标&#xff0c;截取&#xff0c;重新赋值 代码如下&#xff0c;可直接复制预览 <template><div class"form-item"><el-inputv-model"testValue":maxlength"maxlength"input"handleInput"…

Kafka入门, 消费者组案例(十九)

pom 文件 <dependencies><dependency><groupId>org.apache.kafka</groupId><artifactId>kafka-clients</artifactId><version>3.0.0</version></dependency></dependencies>独立消费者案例&#xff08;订阅主语&a…

简单认识LVS-DR负载群集和部署实例

文章目录 一、LVS-DR负载群集简介1、DR模式数据包流向分析2、DR 模式的特点 二、DR模式 LVS负载均衡群集部署 一、LVS-DR负载群集简介 1、DR模式数据包流向分析 1、客户端发送请求到 Director Server&#xff08;负载均衡器&#xff09;&#xff0c;请求的数据报文&#xff0…

放大器的基本知识

文章目录 1.反向输入&#xff08;引出&#xff1a;反向器&#xff09;1.反向输入例子 2.同向输入&#xff08;引出&#xff1a;电压跟随器&#xff09;2.同向输入例子 3.加法运算 1.反向输入&#xff08;引出&#xff1a;反向器&#xff09; 1.反向输入例子 —————————…

基于Java网上药品售卖系统设计实现(源码+lw+部署文档+讲解等)

博主介绍&#xff1a;✌全网粉丝30W,csdn特邀作者、博客专家、CSDN新星计划导师、Java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精彩专…

【MYSQL基础】基础命令介绍

基础命令 MYSQL注释方式 -- 单行注释/* 多行注释 哈哈哈哈哈 哈哈哈哈 */连接数据库 mysql -u root -p12345678退出数据库连接 使用exit;命令可以退出连接 查询MYSQL版本 mysql> select version(); ----------- | version() | ----------- | 8.0.27 | ----------- 1…

HA使用Node-RED推送消息到手机

目录 1.一个简单示例(1).注入使用个inject节点&#xff0c;用来触发(2).手机通知是call service节点(3).保存-部署&#xff0c;看效果 2.消息-添加变量 1.一个简单示例 (1).注入使用个inject节点&#xff0c;用来触发 (2).手机通知是call service节点 Node-RED需要提前和Home A…

spring boot + Apache tika 实现文档内容解析

Apache tika是Apache开源的一个文档解析工具。Apache Tika可以解析和提取一千多种不同的文件类型(如PPT、XLS和PDF)的内容和格式&#xff0c;并且Apache Tika提供了多种使用方式&#xff0c;既可以使用图形化操作页面&#xff08;tika-app&#xff09;&#xff0c;又可以独立部…

Dockerfile 基本命令

本文目录 1. 什么是 Dockerfile2. Dockerfile 基本命令2.1 FROM 指定基础镜像2.2 RUN 执行命令2.3 COPY 复制文件2.4 ADD 更高级的复制文件2.5 CMD2.6 ENTRYPOINT2.7 ENV 设置环境变量2.8 ARG2.9 VOLUME 定义匿名卷2.10 EXPOSE2.11 WORKDIR 指定工作目录2.12 USER 指定当前用户…

【洛谷】P1073 [NOIP2009 提高组] 最优贸易(dp+搜索)

接下来讲具体解法。第一、输入。存邻接表第二、我们需要做深搜。可以用递归来做&#xff0c;同时做动规&#xff1a;函数如下&#xff08;贴了注释&#xff09;.void dfs(int x,int minx,int pre) { //x表示当前访问的节点编号&#xff0c;minx表示目…

添加白名单 gcc/g++【Linux系统编程】

目录 一、添加白名单 二、gcc和g的使用 1、背景知识 一、添加白名单 如何让普通用户可以执行sudo&#xff08;以root的身份&#xff09;指令&#xff1f; 添加白名单 用root身份在/etc/sudoers目录添加 vim /etc/sudoers二、gcc和g的使用 1、背景知识 &#xff08;1&#…

【FFmpeg实战】ffplay整体框架

原文地址&#xff1a;https://segmentfault.com/a/1190000042611796 本文使用的ffplay.c的版本是搭配ffmpeg5.0的版本。 ffplay代码大致架构 关于fplay的架构很难三言两语说得清楚&#xff0c;而且本人对它的理解也不是很深&#xff0c;加上行笔比较啰嗦&#xff0c;可能就更…

springboot配置多个mongo数据源

yaml配置文件&#xff1a; spring:data:mongodb:uri: mongodb://admin:密码ip:27017/paasoo?authSourceadminother:uri: mongodb://admin:密码ip:27017/conversation?authSourceadmin java config文件&#xff1a; package com.paasoo.quartz.config.mongo;import org.spr…

VR数字乡村激活乡土文化生命力,助力乡村振兴

民俗节庆、传统技艺等蕴含着中华五千年以来的传统文化&#xff0c;乡村文化建设在为文化留住血脉的同时&#xff0c;也为高质量发展创造更多可能。找准乡村文化与产业的结合点&#xff0c;有利于激发产业发展的潜力&#xff0c;激活乡土文化的生命力&#xff0c;为乡村振兴注入…

baichuan-7B模型

文章目录 baichuan-7B介绍baichuan-7B 推理baichuan-7B 微调 baichuan-7B介绍 2023年6月15日&#xff0c;搜狗创始人王小川创立的百川智能公司&#xff0c;发布了70 亿参数量的中英文预训练大模型——baichuan-7B。 baichuan-7B 基于 Transformer 结构&#xff0c;在大约 1.2…

【Ubuntu学习MySQL——安装MySQL】

首先得su&#xff0c;然后输入密码&#xff0c;进入到root模式下&#xff0c;以下命令均在root用户模式下进行 1.在这里我们使用RPM包来安装Mysql&#xff0c;所以首先安装RPM包 apt install rpm2.安装完RPM包之后&#xff0c;检测系统是否自带安装MySQL&#xff0c;如果没有…

最小年龄仅5岁!盘点全球最“天才”少年黑客 TOP 10

你还能想起自己8岁的时候&#xff0c;每天都在玩什么吗&#xff1f;可能是在楼下和小朋友一起捉迷藏&#xff1f;在家追一本连载的漫画书&#xff1f;又或者在电脑上玩种菜偷菜的小游戏&#xff1f; 当同龄人还在沉迷于这些比较“基础”的小游戏时&#xff0c;有这样一批和互联…

ARM_uart_发送接收字符 and 发送接收字符串

include/uart4.h #ifndef __UART4_H__ #define __UART4_H__#include "stm32mp1xx_gpio.h" #include "stm32mp1xx_rcc.h" #include "stm32mp1xx_uart.h"//初始化相关操作 void hal_uart4_init();//发送一个字符 void hal_put_char(const char st…

逆波兰表达式

思路 变量 String[] arr Stack 代码 public class Test1 {public static void main(String[] args) {String s "3 40 5 * 6 -";Stack numArr new Stack(10);int num1 0;int num2 0;int res 0;int index 0;String[] arr s.split(" ");for(String…

Flink 读写Kafka总结

前言 总结Flink读写Kafka Flink 版本 1.15.4 Table API 本文主要总结Table API的使用&#xff08;SQL&#xff09;&#xff0c;官方文档&#xff1a;https://nightlies.apache.org/flink/flink-docs-release-1.17/zh/docs/connectors/table/kafka/ kerberos认证相关配置 …