《动手学深度学习 Pytorch版》 6.6 卷积神经网络

news2024/11/19 17:41:12
import torch
from torch import nn
from d2l import torch as d2l

6.6.1 LeNet

LetNet-5 由两个部分组成:

- 卷积编码器:由两个卷积核组成。
- 全连接层稠密块:由三个全连接层组成。

模型结构如下流程图(每个卷积块由一个卷积层、一个 sigmoid 激活函数和平均汇聚层组成):

全连接层(10)

↑ \uparrow

全连接层(84)

↑ \uparrow

全连接层(120)

↑ \uparrow

2 × 2 2\times2 2×2平均汇聚层,步幅2

↑ \uparrow

5 × 5 5\times5 5×5卷积层(16)

↑ \uparrow

2 × 2 2\times2 2×2平均汇聚层,步幅2

↑ \uparrow

5 × 5 5\times5 5×5卷积层(6),填充2

↑ \uparrow

输入图像( 28 × 28 28\times28 28×28 单通道)
net = nn.Sequential(
    nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Flatten(),
    nn.Linear(16 * 5 * 5, 120), nn.Sigmoid(),
    nn.Linear(120, 84), nn.Sigmoid(),
    nn.Linear(84, 10))

X = torch.rand(size=(1, 1, 28, 28), dtype=torch.float32)  # 生成测试数据
for layer in net:
    X = layer(X)
    print(layer.__class__.__name__,'output shape: \t',X.shape)  # 确保模型各层数据正确
Conv2d output shape: 	 torch.Size([1, 6, 28, 28])
Sigmoid output shape: 	 torch.Size([1, 6, 28, 28])
AvgPool2d output shape: 	 torch.Size([1, 6, 14, 14])
Conv2d output shape: 	 torch.Size([1, 16, 10, 10])
Sigmoid output shape: 	 torch.Size([1, 16, 10, 10])
AvgPool2d output shape: 	 torch.Size([1, 16, 5, 5])
Flatten output shape: 	 torch.Size([1, 400])
Linear output shape: 	 torch.Size([1, 120])
Sigmoid output shape: 	 torch.Size([1, 120])
Linear output shape: 	 torch.Size([1, 84])
Sigmoid output shape: 	 torch.Size([1, 84])
Linear output shape: 	 torch.Size([1, 10])

6.6.2 模型训练

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)  # 仍使用经典的 Fashion-MNIST 数据集
def evaluate_accuracy_gpu(net, data_iter, device=None): #@save
    """使用GPU计算模型在数据集上的精度"""
    if isinstance(net, nn.Module):
        net.eval()  # 设置为评估模式
        if not device:
            device = next(iter(net.parameters())).device
    metric = d2l.Accumulator(2)  # 生成一个有两个元素的列表,使用 add 将会累加到对应的元素上
    with torch.no_grad():
        for X, y in data_iter:
            # 为了使用 GPU,需要将数据移动到 GPU 上
            if isinstance(X, list):
                X = [x.to(device) for x in X]
            else:
                X = X.to(device)
            y = y.to(device)
            metric.add(d2l.accuracy(net(X), y), y.numel())  # 累加(正确预测的数量,总预测的数量)
    return metric[0] / metric[1]  # 正确率
#@save
def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):
    """用GPU训练模型(在第六章定义)"""
    def init_weights(m):  # 使用 Xavier 初始化权重
        if type(m) == nn.Linear or type(m) == nn.Conv2d:
            nn.init.xavier_uniform_(m.weight)
    net.apply(init_weights)
    print('training on', device)
    net.to(device)  # 移动数据到GPU
    optimizer = torch.optim.SGD(net.parameters(), lr=lr)
    loss = nn.CrossEntropyLoss()
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],
                            legend=['train loss', 'train acc', 'test acc'])
    timer, num_batches = d2l.Timer(), len(train_iter)
    for epoch in range(num_epochs):
        # 训练损失之和,训练准确率之和,样本数
        metric = d2l.Accumulator(3)
        net.train()
        for i, (X, y) in enumerate(train_iter):
            timer.start()
            optimizer.zero_grad()
            X, y = X.to(device), y.to(device)
            y_hat = net(X)
            l = loss(y_hat, y)
            l.backward()
            optimizer.step()
            with torch.no_grad():
                metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])
            timer.stop()
            train_l = metric[0] / metric[2]
            train_acc = metric[1] / metric[2]
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,
                             (train_l, train_acc, None))
        test_acc = evaluate_accuracy_gpu(net, test_iter)
        animator.add(epoch + 1, (None, None, test_acc))
    print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, '
          f'test acc {test_acc:.3f}')
    print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec '
          f'on {str(device)}')
lr, num_epochs = 0.9, 10
train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
loss 0.471, train acc 0.820, test acc 0.815
40056.7 examples/sec on cuda:0

在这里插入图片描述

练习

(1)将平均汇聚层替换为最大汇聚层,会发生什么?

net_Max = nn.Sequential(
    nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Sigmoid(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Flatten(),
    nn.Linear(16 * 5 * 5, 120), nn.Sigmoid(),
    nn.Linear(120, 84), nn.Sigmoid(),
    nn.Linear(84, 10))

lr, num_epochs = 0.9, 10
train_ch6(net_Max, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
loss 0.422, train acc 0.844, test acc 0.671
31151.6 examples/sec on cuda:0

在这里插入图片描述

几乎无区别


(2)尝试构建一个基于 LeNet 的更复杂网络,以提高其精准性。

a. 调节卷积窗口的大小。
b. 调整输出通道的数量。
c. 调整激活函数(如 ReLU)。
d. 调整卷积层的数量。
e. 调整全连接层的数量。
f. 调整学习率和其他训练细节(例如,初始化和轮数)。
net_Best = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=5, padding=2), nn.ReLU(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Conv2d(8, 16, kernel_size=3, padding=1), nn.ReLU(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Conv2d(16, 32, kernel_size=3, padding=1), nn.ReLU(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Flatten(),
    nn.Linear(32 * 3 * 3, 128), nn.ReLU(),
    nn.Linear(128, 64), nn.ReLU(),
    nn.Linear(64, 32), nn.ReLU(),
    nn.Linear(32, 10)
)
lr, num_epochs = 0.4, 10
train_ch6(net_Best, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
loss 0.344, train acc 0.869, test acc 0.854
32868.3 examples/sec on cuda:0

在这里插入图片描述


(3)在 MNIST 数据集上尝试以上改进后的网络。

import torchvision
from torch.utils import data
from torchvision import transforms

trans = transforms.ToTensor()
mnist_train = torchvision.datasets.MNIST(
    root="../data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.MNIST(
    root="../data", train=False, transform=trans, download=True)
train_iter2 = data.DataLoader(mnist_train, batch_size, shuffle=True,
                             num_workers=d2l.get_dataloader_workers())
test_iter2 = data.DataLoader(mnist_test, batch_size, shuffle=True,
                            num_workers=d2l.get_dataloader_workers())

lr, num_epochs = 0.4, 5  # 大约 6 轮往后直接就爆炸
train_ch6(net_Best, train_iter2, test_iter2, num_epochs, lr, d2l.try_gpu())
loss 0.049, train acc 0.985, test acc 0.986
26531.1 examples/sec on cuda:0

在这里插入图片描述


(4)显示不同输入(例如,毛衣和外套)时 LetNet 第一层和第二层的激活值。

for X, y in test_iter:
        break

x_first_Sigmoid_layer = net[0:2](X)[0:9, 1, :, :]
d2l.show_images(x_first_Sigmoid_layer.reshape(9, 28, 28).cpu().detach(), 1, 9)
x_second_Sigmoid_layer = net[0:5](X)[0:9, 1, :, :]
d2l.show_images(x_second_Sigmoid_layer.reshape(9, 10, 10).cpu().detach(), 1, 9)
d2l.plt.show()


在这里插入图片描述

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1017765.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

分布式电源接入对配电网影响分析(Matlab代码实现)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…

汽车电子——产品标准规范汇总和梳理(电磁兼容)

文章目录 前言 一、辐射发射 二、传导发射 三、沿电源线瞬态传导发射 四、辐射抗扰 五、大电流注入抗扰 六、磁场抗扰 七、便携式发射机抗扰 八、沿电源线瞬态抗扰 九、非电源线瞬态抗扰 十、电快速脉冲群抗扰 十一、静电抗扰 总结 前言 见《汽车电子——产品标准规…

Oracle数据库安装卸载

文章目录 一、数据库版本说明二、软件的下载和安装1.下载软件2.安装数据3.创建数据库4.PLSQL 三、数据库的卸载1.关闭相关服务3.卸载软件3.删除注册信息 四、用户和权限 一、数据库版本说明 1998年Oracle8i:i指internet,表示oracle向互联网发展&#xf…

bootstrap柵格

.col-xs- 超小屏幕 手机 (<768px) .col-sm- 小屏幕 平板 (≥768px) .col-md- 中等屏幕 桌面显示器 (≥992px) .col-lg- 大屏幕 大桌面显示器 (≥1200px) 分为12个格子 -后面的1代表占12分子1也就是一份 1.中等屏幕 <div class"container-fluid a">&l…

ARM Linux DIY(十二)NES 游戏

文章目录 前言交叉编译工具链使能 Cnes 游戏模拟器移植游戏手柄调试 前言 很多小伙伴为了不让自己的 V3s 吃灰&#xff0c;进而将其打造成游戏机。 我们 DIY 的板子具备屏幕、扬声器、USB Host&#xff08;可以接游戏手柄&#xff09;&#xff0c;当然也要凑一凑热闹。 交叉编…

SpringBootTest

SpringBootTest注解 --基于SpringBoot2.5.7版本 可以在运行基于Spring Boot的测试的测试类上指定的注释。在常规的Spring TestContext框架之上提供了以下特性: 默认提供SpringBootContextLoader作为ContextLoader&#xff0c;也通过 ContextConfiguration(loader…)来自定义 若…

docker安装zookeeper(单机版)

第一步&#xff1a;拉取镜像 docker pull zookeeper第二步&#xff1a;启动zookeeper docker run -d -e TZ"Asia/Shanghai" -p 2181:2181 -v /home/sunyuhua/docker/zookeeper:/data --name zookeeper --restart always zookeeper

基于改进莱维飞行和混沌映射的粒子群优化BP神经网络预测股票价格研究(Matlab代码实现)

&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜密&#xff0c;逻辑清晰&#xff0c;为了方便读者。 ⛳️座右铭&a…

1990-2022年上市公司常用控制变量数据(1300+变量)

1990-2022年上市公司常用控制变量数据&#xff08;1300变量&#xff09; 1、时间&#xff1a;1990-2022年 2、指标&#xff1a;共计1391个变量&#xff0c;包括但不限于以下类别&#xff1a;、基本信息&#xff08;性质、行业、地址&#xff09;、股票市场&#xff08;发行、…

阿里云无影云电脑详细介绍_无影优势价格和使用

什么是阿里云无影云电脑&#xff1f;无影云电脑&#xff08;原云桌面&#xff09;是一种快速构建、高效管理桌面办公环境&#xff0c;无影云电脑可用于远程办公、多分支机构、安全OA、短期使用、专业制图等使用场景&#xff0c;阿里云百科分享无影云桌面的详细介绍、租用价格、…

通过指针变量存取一维数组元素

任务描述 本关任务&#xff1a;编写程序用指针变量遍历一个数组。 相关知识 为了完成本关任务&#xff0c;你需要掌握&#xff1a; 1.指向数组元素的指针的运算 2.如何遍历数组。 视频1 指针与一维数组 指向数组元素的指针的运算 指针与整数的加减运算 有数组和指针变量的…

区块链(3):区块链去中心化

1 点对点同步区块链的流程 流程图如下&#xff1a; 流程讲解&#xff1a; &#xff08;1&#xff09;连接节点 &#xff08;2&#xff09;向该节点请求最新区块 &#xff08;3&#xff09;请求到区块以后&#xff0c;根据返回的最新区块前置hash是否和我本身的区块hash相等…

抖音矩阵系统-60+自媒体平台一键发布-短视频获客系统

当老板做企业的&#xff0c;想在抖音上做生意的&#xff0c;一定一定要开通蓝V企业号线索版来做矩阵。看看我每天的线索都是999&#xff0c;客服都接待不过来&#xff0c;现在还有300个客资未分配&#xff0c;抖音都在推了企业员工号的玩法&#xff0c;你还在质疑什么&#xff…

ubuntu 18.04 中 eBPF samples/bpf 编译

1. history 信息 一次成功编译 bpf 后执行 history 得到的信息&#xff1a; yingzhiyingzhi-Host:~/ex/ex_kernel/linux-5.4$ history1 ls2 mkdir ex3 cd ex4 mkdir ex_kernel5 ls /boot/6 sudo apt install linux-source7 ls /usr/src/8 uname -r9 cd ex_kernel/10…

Nginx配置最佳实践

&#x1f337;&#x1f341; 博主猫头虎&#xff08;&#x1f405;&#x1f43e;&#xff09;带您 Go to New World✨&#x1f341; &#x1f984; 博客首页——&#x1f405;&#x1f43e;猫头虎的博客&#x1f390; &#x1f433; 《面试题大全专栏》 &#x1f995; 文章图文…

双向链表的实现(增删查改)——最好理解的链表

双向链表的实现 一&#xff0c;双向链表的特点二&#xff0c;双向链表的结构三&#xff0c;双向链表的内容实现3.1创建node节点3.2初始化3.3打印3.4插入3.4.1尾插3.4.2头插3.4.3在pos位置上插入 3.5删除3.5.1尾删3.5.2头删3.5.3删除pos位置上的数据 四&#xff0c;调试技巧&…

基于Xml方式Bean的配置-Bean的延时加载

SpringBean的配置详解 Bean的延时加载 当lazy-init设置为true时为延时加载&#xff0c;也就是当Spring容器创建的时候&#xff0c;不会立即创建Bean实例&#xff0c;等待用到时再创建Bean实例并储存到单例池中&#xff0c;后续使用该Bean时直接从单例池中获取即可&#xff0c;…

文心一言 VS 讯飞星火 VS chatgpt (94)-- 算法导论9.2 3题

三、用go语言&#xff0c;给出 RANDOMIZED-SELECT的一个基于循环的版本。 文心一言&#xff1a; 在循环中实现 RANDOMIZED-SELECT 的一个版本可以如下所示。这个版本使用 Go 语言编写&#xff0c;利用随机化来选择一个元素&#xff0c;并在循环中不断地调整选择的元素&#x…

红黑树封装实现map和set

全文目录 map和set红黑树的封装红黑树的迭代器begin() 和 end()operator()operator--() map完整代码set完整代码改造后的红黑树代码 map和set map和set 是STL中的容器&#xff0c;两者底层的实现都是参考文档&#xff1a; map参考文档 set参考文档 两个容器底层都是去重的二叉…

Spring复杂对象的3中创建方法

复杂对象是相对于简单对象可以直接 new 出的对象。这种对象在 Spring 中不可以通过简单对象的创建方式来创建。下面我们将通过实现 FactoryBean 接口、实例工厂、静态工厂三种方法来创建。 FactoryBean 接口 Spring 提供 FactoryBean 接口并且提供了 getObject 方法是为了支持…