动手学深度学习—使用块的网络VGG(代码详解)

news2024/10/6 20:38:36

目录

  • 1. VGG块
  • 2. VGG网络
  • 3. 训练模型

1. VGG块

经典卷积神经网络的基本组成部分是下面的这个序列:
1.带填充以保持分辨率的卷积层;
2.非线性激活函数,如ReLU;
3.汇聚层,如最大汇聚层。

定义网络块,便于我们重复构建某些网络架构,不仅利于代码编写与阅读也利于后面参数的优化

"""
    定义了一个名为vgg_block的函数来实现一个VGG块:
    1、卷积层的数量num_convs
    2、输入通道的数量in_channels 
    3、输出通道的数量out_channels
"""
import torch
from torch import nn
from d2l import torch as d2l


# 定义vgg块,(卷积层数,输入通道,输出通道)
def vgg_block(num_convs, in_channels, out_channels):
    # 创建空网络结果,之后通过循环操作使用append函数进行添加
    layers = []
    
    # 循环操作,添加卷积层和非线性激活层
    for _ in range(num_convs):
        layers.append(nn.Conv2d(in_channels, out_channels,
                                kernel_size=3, padding=1))
        layers.append(nn.ReLU())
        in_channels = out_channels
        
    # 最后添加最大值汇聚层
    layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
    return nn.Sequential(*layers)

2. VGG网络

在这里插入图片描述
由于会重复用到卷积层、激活函数ReLU和汇聚层,我们将这三个组合成一个块,每次引用这个块来构建网络模型。
通过定义VGG块,使得重复的网络结构实现起来更加容易,也利于代码阅读。

# 原VGG网络有5个卷积块,前两个有一个卷积层,后三个块有两个卷积层
# 该网络使用8个卷积层和3个全连接层,因此它通常被称为VGG-11

# (卷积层数,输出通道数)
conv_arch = ((1, 64), (1, 128), (2, 256), (2, 512), (2, 512))

实现VGG-11:使用8个卷积层和3个全连接层

# 通过for循环实现VGG-11
def vgg(conv_arch):
    # 定义空网络结构
    conv_blks = []
    in_channels = 1
    # 卷积层部分
    for (num_convs, out_channels) in conv_arch:
        # 添加vgg块
        conv_blks.append(vgg_block(num_convs, in_channels, out_channels))
        # 下一层输入通道数=当前层输出通道数
        in_channels = out_channels
        
    return nn.Sequential(
        *conv_blks, nn.Flatten(),
        # 全连接层部分
        nn.Linear(out_channels * 7 * 7, 4096), nn.ReLU(), nn.Dropout(0.5),
        nn.Linear(4096, 4096), nn.ReLU(), nn.Dropout(0.5),
        nn.Linear(4096, 10))

net = vgg(conv_arch)

构建一个高度和宽度为224的单通道数据样本,以观察每个层输出的形状

# 构建一个高度和宽度为224的单通道数据样本,以观察每个层输出的形状
X = torch.randn(size=(1, 1, 224, 224))
for blk in net:
    X = blk(X)
    print(blk.__class__.__name__, 'output shape:\t', X.shape)

每一层的输出形状
在这里插入图片描述

3. 训练模型

构建了一个通道数较少的网络,足够用于训练Fashion-MNIST数据集

# 构建了一个通道数较少的网络,足够用于训练Fashion-MNIST数据集
ratio = 4
# //为整除
small_conv_arch = [(pair[0], pair[1] // 4) for pair in conv_arch]
net = vgg(small_conv_arch)

定义精度评估函数

"""
    定义精度评估函数:
    1、将数据集复制到显存中
    2、通过调用accuracy计算数据集的精度
"""
def evaluate_accuracy_gpu(net, data_iter, device=None): #@save
    # 判断net是否属于torch.nn.Module类
    if isinstance(net, nn.Module):
        net.eval()
        
        # 如果不在参数选定的设备,将其传输到设备中
        if not device:
            device = next(iter(net.parameters())).device
    
    # Accumulator是累加器,定义两个变量:正确预测的数量,总预测的数量。
    metric = d2l.Accumulator(2)
    with torch.no_grad():
        for X, y in data_iter:
            # 将X, y复制到设备中
            if isinstance(X, list):
                # BERT微调所需的(之后将介绍)
                X = [x.to(device) for x in X]
            else:
                X = X.to(device)
            y = y.to(device)
            
            # 计算正确预测的数量,总预测的数量,并存储到metric中
            metric.add(d2l.accuracy(net(X), y), y.numel())
    return metric[0] / metric[1]

定义GPU 训练函数

"""
    定义GPU训练函数:
    1、为了使用gpu,首先需要将每一小批量数据移动到指定的设备(例如GPU)上;
    2、使用Xavier随机初始化模型参数;
    3、使用交叉熵损失函数和小批量随机梯度下降。
"""
#@save
def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):
    """用GPU训练模型(在第六章定义)"""
    # 定义初始化参数,对线性层和卷积层生效
    def init_weights(m):
        if type(m) == nn.Linear or type(m) == nn.Conv2d:
            nn.init.xavier_uniform_(m.weight)
    net.apply(init_weights)
    
    # 在设备device上进行训练
    print('training on', device)
    net.to(device)
    
    # 优化器:随机梯度下降
    optimizer = torch.optim.SGD(net.parameters(), lr=lr)
    
    # 损失函数:交叉熵损失函数
    loss = nn.CrossEntropyLoss()
    
    # Animator为绘图函数
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],
                            legend=['train loss', 'train acc', 'test acc'])
    
    # 调用Timer函数统计时间
    timer, num_batches = d2l.Timer(), len(train_iter)
    
    for epoch in range(num_epochs):
        
        # Accumulator(3)定义3个变量:损失值,正确预测的数量,总预测的数量
        metric = d2l.Accumulator(3)
        net.train()
        
        # enumerate() 函数用于将一个可遍历的数据对象
        for i, (X, y) in enumerate(train_iter):
            timer.start() # 进行计时
            optimizer.zero_grad() # 梯度清零
            X, y = X.to(device), y.to(device) # 将特征和标签转移到device
            y_hat = net(X)
            l = loss(y_hat, y) # 交叉熵损失
            l.backward() # 进行梯度传递返回
            optimizer.step()
            with torch.no_grad():
                # 统计损失、预测正确数和样本数
                metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])
            timer.stop() # 计时结束
            train_l = metric[0] / metric[2] # 计算损失
            train_acc = metric[1] / metric[2] # 计算精度
            
            # 进行绘图
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,
                             (train_l, train_acc, None))
                
        # 测试精度
        test_acc = evaluate_accuracy_gpu(net, test_iter) 
        animator.add(epoch + 1, (None, None, test_acc))
        
    # 输出损失值、训练精度、测试精度
    print(f'loss {train_l:.3f}, train acc {train_acc:.3f},'
          f'test acc {test_acc:.3f}')
    
    # 设备的计算能力
    print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec'
          f'on {str(device)}')

在这里插入图片描述

进行训练

# 学习率略高
lr, num_epochs, batch_size = 0.05, 10, 128
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=224)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())

在这里插入图片描述
块的使用导致网络定义的非常简洁。使用块可以有效地设计复杂的网络。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1118367.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【LeetCode】67. 二进制求和

1 问题 给你两个二进制字符串 a 和 b ,以二进制字符串的形式返回它们的和。 示例 1: 输入:a “11”, b “1” 输出:“100” 示例 2: 输入:a “1010”, b “1011” 输出:“10101” 2 答案 自己写…

【分类讨论】CF1834D

Problem - D - Codeforces 题意: 思路: 这是个分类讨论题,一开始还以为是枚举什么的,发现根本枚举不了 注意到最终的答案就两种情况:区间包含 or 区间不包含 对于第一种情况,贡献的最大值就是mxlen - m…

【Arduino TFT】基于 ESP32 S7789 320x240 TFT实现的SD2 天气时钟

忘记过去,超越自己 ❤️ 博客主页 单片机菜鸟哥,一个野生非专业硬件IOT爱好者 ❤️❤️ 本篇创建记录 2023-10-21 ❤️❤️ 本篇更新记录 2023-10-21 ❤️🎉 欢迎关注 🔎点赞 👍收藏 ⭐️留言📝&#x1f64…

使用langchain-chatchat里,faiss库中报错: AssertionError ,位置:assert d == self.d

发生报错: AssertionError,发生位置:class_wrappers.py里 assert d self.d,假如输出语句,查看到是因为d和self.d维度不匹配造成,解决方式: 删除langchain-chatchat/knowledge_base里的info.db…

初识树结构和二叉树

一,树概念及结构 1.1树结构的概念 树是一种非线性的数据结构,它是由n(n>0)个有限结点组成一个具有层次关系的集合。把它叫做树是因为它看起来像一棵倒挂的树,也就是说它是根朝上,而叶朝下的。 注意&a…

嵌入式Linux中内存管理详解分析

Linux中内存管理 内存管理的主要工作就是对物理内存进行组织,然后对物理内存的分配和回收。但是Linux引入了虚拟地址的概念。 虚拟地址的作用 如果用户进程直接操作物理地址会有以下的坏处: 1、 用户进程可以直接操作内核对应的内存,破坏…

linux任务优先级

这篇笔记记录了linux任务(指线程而非进程)优先级相关的概念,以及用户态可以用来操作这些优先级的系统调用。 基本概念 调度策略 linux内核中的调度器为任务定义了调度策略,也叫调度类,每个任务同一时刻都有唯一的调…

Android Framework系列---输入法服务

Android Framework系列之输入法服务 本文基于Android R(11),从Framework角度介绍Android输入法框架流程及常用调试方法。 写在前面 车载项目需要定制输入法,也有一些POC演示的项目使用原生比如LatinIME(源码路径为/packages/inputmethods…

CVE-2019-9766漏洞实战

1.利用msf生成反向连接的shellcode 2.构造具有反弹shell的MP3文件 将上一步标记的部分替换脚本中的shellcode 3.运行脚本,生成恶意mp3文件 4.msf设置监听并运行exploit 5.打开恶意文件 6.攻击机已经获得shell 文笔生疏,措辞浅薄,望各位大佬不吝赐教…

运行原理:eBPF 是一个新的虚拟机吗?

目录 背景 eBPF 虚拟机是如何工作的? BPF 指令是什么样的? eBPF 程序是什么时候执行的? 小结 背景 前面,我们从最简单的 Hello World 开始,带你借助 BCC 库从零开发了一个跟踪 openat() 系统调用的 eBPF 程序。…

Leetcode1839. 所有元音按顺序排布的最长子字符串

Every day a Leetcode 题目来源:1839. 所有元音按顺序排布的最长子字符串 解法1:滑动窗口 要找的是最长美丽子字符串的长度,我们可以用滑动窗口解决。 设窗口内的子字符串为 window,每当 word[right] > window.back() 时&…

喜讯!持安科技入选2023年北京市知识产权试点单位!

近日,北京市知识产权局发布了“2023年度北京市知识产权试点示范单位及2020年度北京市知识产权试点示范单位复审通过名单”名单。 经过严格的初审、形式审核和专家评审,北京持安科技有限公司入选“2023年北京市知识产权试点单位”。 北京市知识产权试点示…

A预测蛋白质结构

基于AlphaFold2进行蛋白质结构预测的文章解析 RoseTTAFold: Tunyasuvunakool, K., Adler, J., Wu, Z. et al. Highly accurate protein structure prediction for the human proteome. Nature 596, 590–596 (2021) AlphaFold2: Accurate prediction of protein structures a…

git commit报错:running pre-commit hook: lint-staged

报错截图: 报错信息: running pre-commit hook: lint-staged 解决方式: 在项目(vue)的package.json文件中,查找 “husky” 部分,并确认其下的 “pre-commit” 钩子是否正确地引用了 lint-staged。 其中配置示例如下&a…

2023年中国自动排气阀产业链、市场规模及存在问题分析]图[

自动排气阀是一种用于排除管道、容器或设备中累积的空气或气体的装置。在液体流动系统中,气体或空气可能会积聚在管道或容器中,影响流体流动、导致气锁和能效降低。自动排气阀的作用是在系统中的气体达到一定压力时,自动地释放气体&#xff0…

LeetCode_并查集_DFS_中等_2316.统计无向图中无法互相到达点对数

目录 1.题目2.思路3.代码实现(Java) 1.题目 给你一个整数 n ,表示一张 无向图 中有 n 个节点,编号为 0 到 n - 1 。同时给你一个二维整数数组 edges ,其中 edges[i] [ai, bi] 表示节点 ai 和 bi 之间有一条无向边。请…

如何解决电脑出现msvcp140.dll丢失问题,msvcp140.dll丢失的最全解决方法

首先,我们需要了解什么是“msvcp140.dll”。这是一个动态链接库文件,它是Microsoft Visual C 2015 Redistributable的一部分。当计算机运行某些程序时,这个文件会被调用,以支持程序的正常运行。因此,当这个文件丢失时&…

3dmax中导出模型到unity注意事项

从3dmax中导出 1. 注意单位,根据需要,选英寸还是选厘米 2. 不能导出有错误的骨骼,否则导入后模型网格里出现 Skinned Mesh Renderer ,对网格变换移动有影响,正常情况下都应该是 Mesh Renderer 3. 导出一般不带光源和…

【LeetCode刷题】:仅仅反转字母(双指针+字符串)

给你一个字符串 s ,根据下述规则反转字符串: 所有非英文字母保留在原有位置 所有英文字母(小写或大写)位置反转 返回反转后的 s 示例 1: 输入:s “ab-cd” 输出:“dc-ba” 示例 2: …

【sqlserver】配置管理器打不开

问题描述 无法连接到 WMI 提供程序。您没有权限或者该服务器无法访问。请注意,您只能使用SQL Server 配置管理器来管理 SQL Server 2005 和更高版本的服务 器。无效类[0x80041010] 解决方式: 命令提示符-右键-以管理员身份运行,再把以下代码执行一遍&…