41. 使用块的网络(VGG)代码实现

news2025/1/11 18:35:20

1. VGG块

在下面的代码中,我们定义了一个名为vgg_block的函数来实现一个VGG块。

该函数有三个参数,分别对应于卷积层的数量num_convs、输入通道的数量in_channels 和输出通道的数量out_channels.

import torch
from torch import nn
from d2l import torch as d2l


def vgg_block(num_convs, in_channels, out_channels):
    layers = []
    for _ in range(num_convs):
        layers.append(nn.Conv2d(in_channels, out_channels,
                                kernel_size=3, padding=1))
        layers.append(nn.ReLU())
        in_channels = out_channels # 下一层的输入通道数等于本层的输出通道数
    layers.append(nn.MaxPool2d(kernel_size=2,stride=2))
    return nn.Sequential(*layers) # *layers是解包,把数组拆成一个个元素

2. VGG网络

VGG神经网络连接 几个VGG块(在vgg_block函数中定义)。其中有超参数变量conv_arch。该变量指定了每个VGG块里卷积层个数和输出通道数。全连接模块则与AlexNet中的相同。

原始VGG网络有5个卷积块,其中前两个块各有一个卷积层,后三个块各包含两个卷积层。 第一个模块有64个输出通道,每个后续模块将输出通道数量翻倍,直到该数字达到512。

由于该网络使用8个卷积层和3个全连接层,因此它通常被称为VGG-11

# ps:这个块数是不能改变的,也就是conv_arch中的元素个数只能是5,
# 如果变成6个的话,那么224/64 = 3...32 ,图片大小都不为整数了
# 但是通道数可以改变,卷积层个数可以改变
conv_arch = ((1, 64), (1, 128), (2, 256), (2, 512), (2, 512)) 

下面的代码实现了VGG-11。可以通过在conv_arch上执行for循环来简单实现。

def vgg(conv_arch):
    conv_blks = []
    in_channels = 1 # 是因为用的数据集fashion_mnist的图片都是灰度图
    # 卷积层部分
    for (num_convs, out_channels) in conv_arch:
        # conv_blks 是一个由很多vgg_block组成的数组
        conv_blks.append(vgg_block(num_convs, in_channels, out_channels))
        in_channels = out_channels # 下一层的输入通道数等于本层的输出通道数

    return nn.Sequential(
        *conv_blks, nn.Flatten(),
        # 全连接层部分
        # 为什么是7*7呢?因为有5个卷积块,每个卷积块的最后一层都是一个strdie=2且大小为2的pooling层,
        # 就意味着每经过一个卷积块大小会减半,经过5个卷积块后,224/32 = 7,最后图片大小变成7*7
        nn.Linear(out_channels * 7 * 7, 4096), nn.ReLU(), nn.Dropout(0.5),
        nn.Linear(4096, 4096), nn.ReLU(), nn.Dropout(0.5),
        nn.Linear(4096, 10))

net = vgg(conv_arch)

接下来,我们将构建一个高度和宽度为224的单通道数据样本,以观察每个层输出的形状。

X = torch.randn(size=(1, 1, 224, 224))
for blk in net:
    X = blk(X)
    print(blk.__class__.__name__,'output shape:\t',X.shape)

在这里插入图片描述

正如从代码中所看到的,我们在每个块的高度和宽度减半,最终高度和宽度都为7。最后再展平表示,送入全连接层处理。

3. 训练模型

由于VGG-11比AlexNet计算量更大,因此我们构建了一个通道数较少的网络,足够用于训练Fashion-MNIST数据集。

ratio = 4
small_conv_arch = [(pair[0], pair[1] // ratio) for pair in conv_arch]
net = vgg(small_conv_arch)

除了使用略高的学习率外,模型训练过程与AlexNet类似。

lr, num_epochs, batch_size = 0.05, 10, 128
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=224)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/118225.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【玩转c++】c++ :string类讲解(万字详解)

目录 🍁1. 为什么要学习string类 🍁2. 标准库中的string类 🍁3. string类各种接口 默认成员函数 Iterators迭代器 capacity容量 Element access:元素访问 Modifiers:修改 字符串操作 成员变量 非成员函数 🍁4. 扩展阅读 本期主题…

stm32f407VET6 系统学习 day07 通用定时器, OLED 屏幕使用 PWM 的使用

1. 通用定时器的知识 1.STM32共有14个定时器,其中12个16位定时器,2个32 位定时器 2. 通用定时器特点 1. 16/32位向上、向下、向上/向下(中心对齐)计数模式,自动装载计数器(TIMXCNT) 。 2. 16位可编程预分频器(TIMx_PSC)&…

-bash: lsof: command not found解决办法

简言 centos系统,检测端口时使用lsof命令发现lsof功能未开启,如下图 [rootiZwz9501p9hnysn92hpx27Z tnt_game]# lsof -bash: lsof: command not found 安装lsof centos系统下可以直接使用yum安装lsof功能,如下图 yum可自动完成安装lsof ls…

gitlab-ci.yml关键字(一)image、variables、include

image 这是一个全局关键字,如果流水线的执行器是使用docker来运行的话,那可以指定docker中的docker镜像。如果执行器是shell的话,那该关键字是无用的,即便机器中已近安装了docker的环境,该关键字可以在全局或者某一个…

NeurIPS2021 | ViTAE+: vision transformer中的归纳偏置探索

参考资料:NeurIPS 2021 | ViTAE: vision transformer中的归纳偏置探索 - 知乎 paper地址:https://openreview.net/pdf?id_RnHyIeu5Y5 论文标题:ViTAE: Vision Transformer Advanced by Exploring Intrinsic Inductive Bias code&#xff…

假设检验之卡方检验

之前我对卡方检验的了解都是一知半解的,即知道作用是对离散变量分布差异的比较,根据期望频数和观察频数的差异计算出来一个卡方值,之后根据自由度和显著性水平查卡方分布对应的临界值,比较大小得出有无明显差异的结论。 一般我们都…

基于FPGA平台实现 ARM Cortex-M0 SOC(一)简介

本系列笔记为基于FPGA平台实现 ARM Cortex-M0 SOC 集创赛作品复盘 Platform: ARM Cortex-M0 Design Srart AT510 XLINX FPGA ARM MDK 5 CM0-Design start 是ARM公司放出的一个免费的ARM 内核学习版本,它比M3还要简单,并且官方把整块代码模糊化…

TFN CK1840B 喇叭天线 定向 18GHz~40GHz

TFN CK1840B 喇叭天线 定向 18GHz~40GHz 产品概述 TFN CK1840B喇叭天线工作频率为 18GHz~40GHz。具有频带宽, 性能可靠, 增益高等优 点, 是理想的 EMC 测试、电子对抗等领域的定向接收、发射天线。 应用领域 ● 电子对抗领域 ● EMC 测试…

基于python多光谱遥感数据处理、图像分类、定量评估及机器学习方法应用

普通数码相机记录了红、绿、蓝三种波长的光,多光谱成像技术除了记录这三种波长光之外,还可以记录其他波长(例如:近红外、热红外等)光的信息。与昂贵、不易获取的高光谱、高空间分辨率卫星数据相比,中等分辨…

Gateway

Gateway—SpringCloud微服务网关组件 一、Spring Cloud Gateway简介 1.为什么要用Gateway? 在微服务架构中,通常一个系统会被拆分为多个微服务,微服务之间的调用可以用OpenFeign,但面对这么多微服务客户端调用会遇到哪些问题呢…

Hudi(3):Hudi之基本概念

目录 0. 相关文章链接 1. 时间轴(TimeLine) 1.1. Instant action:在表上执行的操作类型 1.2. Instant time 1.3. State 1.4. 两个时间概念 2. 文件布局(File Layout) 2.1. Hudi表的文件结构 2.2. Hudi存储的两…

Cocos 引擎生态部负责人李阳:己之所欲,可施于人,希望通过生态促进国内引擎技术发展

前言 “小小的身体,大大的能量,这个应该是我对大表姐最直接的感觉,在她娇小的身躯里蕴含了无限的精力和潜力,很像漫威里的神奇女侠,作为一个具备冒险精神的非典型程序员,大表姐热爱的体育活动都是很具挑战…

大数据系列——什么是ClickHouse?ClickHouse有什么用途?

目录 一、什么是ClickHouse 二、ClickHouse有什么用途 三、ClickHouse的不足 四、适用场景 五、ClickHouse特点 六、ClickHouse VS MySQL 七、类SQL 语句 八、核心概念 一、什么是ClickHouse clickHouse是俄罗斯的 Yandex 公司于 2016 年开源的列式存储数据库&#x…

win11系统用户名称为中文导致文件夹出现繁体字文件夹、系统路径配置错误修改教程(博主亲测,基于win11,系统文件保留)

写在前面:很多人在拿到新电脑激活那会,命名就是简单的中文,但是中文命名电脑系统名称,会导致系统用户文件夹自动命名为中文,在后期使用中会导致c盘系统用户文件夹下面出现不知名繁体字文件夹,甚至有的朋友会…

终难逃一阳

阳了,抗原试剂显示我阳了。每天都带口罩的我还是未能逃过此劫。真是覆巢之下,焉有完卵。 ​ 1.背景 12月初国家逐步放开防疫,随之而来的就是奥秘克戎肆虐全国。身边同事和朋友一个接着一个倒下,朋友圈里更是哀嚎一片。好在专家…

《CSAPP》笔记——链接、异常控制流、虚拟内存

文章目录传送门链接基础链接器的意义编译器驱动程序静态链接ELF目标文件格式可重定位目标文件符号和符号表链接过程符号解析解析规则静态链接库带有静态链接库的解析过程重定位重定位条目重定位节重定位符号引用重定位相对引用重定位绝对引用加载可执行目标文件动态链接共享库库…

Kafka 消费者组开发

Kafka consumer - 消费者组 上一篇文章学习到kafka消费者、消费者组之间处理消息的差异,总结起来就是: 同一个消费组的不同消费实例 共同消费topiic的消息, 一个消息只会消费一次; 也叫做集群消费同一个消息被不同的消费组同时消费&#xf…

机器学习基石1(ML基本概念和VC dimension)

文章目录一、什么是机器学习?二、什么时候可以使用机器学习?三、感知机perceptron四、机器学习的输入形式五、机器真的可以学习吗?六、vc dimension一、什么是机器学习? 其实第一个问题和第二个问题是穿插到一块儿回答的,首先机器学习要解决的是常规的…

RedisTemplate操作redis

目录 Redis Repositories方式 a、启用 Repository 功能 b、注解需要缓存的实体 c、创建一个 Repository 接口 d、测试类中测试 Redis Repositories方式 Spring Data Redis 从 1.7 开始提供 Redis Repositories ,可以无缝的转换并存储 domain objects&#xff0…

TOPSIS法(熵权法)(模型+MATLAB代码)

TOPSIS可翻译为逼近理想解排序法,国内简称为优劣解距离法 TOPSIS法是一种常用的综合评价方法,其能充分利用原始数据的信息,其结果能精确地反映各评价方案之间的距离 一、模型介绍 极大型指标(效益型指标) &#xff…