李沐23_LeNet——自学笔记

news2024/11/27 20:36:14

手写的数字识别

知名度最高的数据集:MNIST
1.训练数据:50000

2.测试数据:50000

3.图像大小:28✖28

4.10类

总结

1.LeNet是早期成功的神经网络

2.先使用卷积层来学习图片空间信息

3.使用全连接层来转换到类别空间

代码实现

LeNet由两部分组成:卷积编码器和全连接层密集块

import torch
from torch import nn
from d2l import torch as d2l

class Reshape(torch.nn.Module):
  def forward(self,x):
    return x.view(-1,1,28,28) # 原图:28✖28,填充后是32✖32

net = nn.Sequential(
    nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Sigmoid(), # 6个通道
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Flatten(),
    nn.Linear(16 * 5 * 5, 120), nn.Sigmoid(),
    nn.Linear(120, 84), nn.Sigmoid(),
    nn.Linear(84, 10))

我们将一个大小为28✖28的单通道(黑白)图像通过LeNet。

X = torch.rand(size=(1, 1, 28, 28), dtype=torch.float32)
for layer in net:
    X = layer(X)
    print(layer.__class__.__name__,'output shape: \t',X.shape)
Conv2d output shape: 	 torch.Size([1, 6, 28, 28])
Sigmoid output shape: 	 torch.Size([1, 6, 28, 28])
AvgPool2d output shape: 	 torch.Size([1, 6, 14, 14])
Conv2d output shape: 	 torch.Size([1, 16, 10, 10])
Sigmoid output shape: 	 torch.Size([1, 16, 10, 10])
AvgPool2d output shape: 	 torch.Size([1, 16, 5, 5])
Flatten output shape: 	 torch.Size([1, 400])
Linear output shape: 	 torch.Size([1, 120])
Sigmoid output shape: 	 torch.Size([1, 120])
Linear output shape: 	 torch.Size([1, 84])
Sigmoid output shape: 	 torch.Size([1, 84])
Linear output shape: 	 torch.Size([1, 10])

在整个卷积块中,与上一层相比,每一层特征的高度和宽度都减小了。

第一个卷积层使用2个像素的填充,来补偿5✖5卷积核导致的特征减少。

相反,第二个卷积层没有填充,因此高度和宽度都减少了4个像素。

随着层叠的上升,通道的数量从输入时的1个,增加到第一个卷积层之后的6个,再到第二个卷积层之后的16个。

同时,每个汇聚层的高度和宽度都减半。最后,每个全连接层减少维数,最终输出一个维数与结果分类数相匹配的输出。

LeNet在Fashion-MNIST数据集上的表现

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)
# 下载fashion_MNIST数据集
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ../data/FashionMNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 26421880/26421880 [00:02<00:00, 9258920.33it/s] 


Extracting ../data/FashionMNIST/raw/train-images-idx3-ubyte.gz to ../data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ../data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 29515/29515 [00:00<00:00, 171125.91it/s]


Extracting ../data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ../data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ../data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 4422102/4422102 [00:01<00:00, 3169968.34it/s]


Extracting ../data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ../data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ../data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 5148/5148 [00:00<00:00, 4336669.41it/s]

Extracting ../data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/FashionMNIST/raw




/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py:558: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
  warnings.warn(_create_warning_msg(
def evaluate_accuracy_gpu(net, data_iter, device=None): #计算模型精度
    """使用GPU计算模型在数据集上的精度"""
    if isinstance(net, nn.Module):
        net.eval()  # 设置为评估模式
        if not device:
            device = next(iter(net.parameters())).device
    # 正确预测的数量,总预测的数量
    metric = d2l.Accumulator(2)
    with torch.no_grad():
        for X, y in data_iter:
            if isinstance(X, list):
                # BERT微调所需的(之后将介绍)
                X = [x.to(device) for x in X]
            else:
                X = X.to(device)
            y = y.to(device)
            metric.add(d2l.accuracy(net(X), y), y.numel())
    return metric[0] / metric[1]

训练函数

1.训练函数train_ch6也类似于3.6节中定义的train_ch3。

2.使用高级API创建的模型作为输入,并进行相应的优化。

3.Xavier随机初始化模型参数。
4.使用交叉熵损失函数和小批量随机梯度下降。

#用GPU训练模型,比第三章多了device
def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):
    """用GPU训练模型(在第六章定义)"""
    def init_weights(m):
        if type(m) == nn.Linear or type(m) == nn.Conv2d:
            nn.init.xavier_uniform_(m.weight)
    net.apply(init_weights)
    print('training on', device)
    net.to(device)
    optimizer = torch.optim.SGD(net.parameters(), lr=lr)
    loss = nn.CrossEntropyLoss()
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], # 动画效果
                            legend=['train loss', 'train acc', 'test acc'])
    timer, num_batches = d2l.Timer(), len(train_iter)
    for epoch in range(num_epochs):
        # 训练损失之和,训练准确率之和,样本数
        metric = d2l.Accumulator(3)
        net.train()
        for i, (X, y) in enumerate(train_iter):
            timer.start()
            optimizer.zero_grad()
            X, y = X.to(device), y.to(device)
            y_hat = net(X)
            l = loss(y_hat, y)
            l.backward()
            optimizer.step()
            with torch.no_grad():
                metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])
            timer.stop()
            train_l = metric[0] / metric[2]
            train_acc = metric[1] / metric[2]
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,
                             (train_l, train_acc, None))
        test_acc = evaluate_accuracy_gpu(net, test_iter)
        animator.add(epoch + 1, (None, None, test_acc))
    print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, '
          f'test acc {test_acc:.3f}')
    print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec '
          f'on {str(device)}')

训练和评估LeNet-5模型。

lr, num_epochs = 0.9, 10
train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
loss 0.461, train acc 0.827, test acc 0.793
35664.6 examples/sec on cuda:0

在这里插入图片描述


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1582095.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Al+医学,用这个中文多模态医学大模型帮你看胸片

随着人工智能技术的飞速发展&#xff0c;AI 在医学领域的应用已经成为现实。特别是在医学影像诊断方面&#xff0c;AI 大模型技术展现出了巨大的潜力和价值&#xff0c;但目前针对中文领域医学大多模态大模型还较少。 今天为大家介绍的这款 XrayGLM&#xff0c;就是由澳门理工…

HackTheBox-Machines--Soccer

文章目录 1 信息收集2 CVE-2021-45010 漏洞利用3 横向移动4 权限提升 Soccer 测试过程 1 信息收集 a.端口扫描&#xff1a;发现22、80、9091端口    b.目录扫描&#xff1a;http://soccer.htb/tiny/    c.子域爆破    d.信息泄漏 nmap -sC -sV 10.129.87.151端口扫描结…

【APUE】网络socket编程温度采集智能存储与上报项目技术------多路复用

作者简介&#xff1a; 一个平凡而乐于分享的小比特&#xff0c;中南民族大学通信工程专业研究生在读&#xff0c;研究方向无线联邦学习 擅长领域&#xff1a;驱动开发&#xff0c;嵌入式软件开发&#xff0c;BSP开发 作者主页&#xff1a;一个平凡而乐于分享的小比特的个人主页…

JVM—jps、jstat、jinfo、jmap、jstack的使用

JVM—jps、jstat、jinfo、jmap、jstack的使用 jps jps全称&#xff1a;Java Virtual Machine Process Status Tool 可以查看Java进程&#xff0c;相当于Linux下的ps命令&#xff0c;只不过它只列出Java进程。 jps:列出Jav程序ID和Main函数名称 jps -q:只输出进程ID jps -m …

【星期计算】蓝桥杯

–> 因为这里是结果填空题&#xff0c;我们直接暴力用java自带的BigInteger类。 /*** 试题 A: 星期计算** 本题总分&#xff1a;5 分* 【问题描述】* 已知今天是星期六&#xff0c;请问20的22次方天后是星期几&#xff1f;* 注意用数字 1 到 7 表示星期一到星期日。* * 【答…

Adobe Photoshop 2024 v25.6强大的图形编辑工具

Adobe Photoshop 2024是一款非常强大的图像处理软件&#xff0c;具有丰富的功能和工具&#xff0c;可以满足各种图像处理需求。 软件下载&#xff1a;Adobe Photoshop 2024 v25.6中文激活版 它不仅支持基本的图像编辑和调整&#xff0c;还具有高级的特性&#xff0c;如智能对象…

自定义类型—结构体

目录 1 . 结构体类型的声明 1.1 结构的声明 1.2 结构体变量的创建与初始化 1.3 结构体的特殊声明 1.4 结构体的自引用 2. 结构体内存对齐 2.1 对齐规则 2.2 为什么存在内存对齐 2.3 修改默认对齐数 3. 结构体传参 4.结构体实现位段 4.1 位段的内存分配 1 . 结构体类…

w1r3s 靶机学习

w1r3s 靶机学习 0x01 IP C for command kali ip 10.10.10.128victim ip 10.10.10.1290x02 开扫 C sudo nmap -sn 10.10.10.0/24-sn 多一步入侵和轻量级侦察 发送四项请求 -sL 列表扫描&#xff0c;多用于探测可用ip&#xff0c;广播扫描 –send-ip 时间戳请求&#xff0…

急!开具数电票,提示风险预警怎么办?

随着数电票试点基本落地全国&#xff0c;越来越多的企业需要开具数电票。但一些财务伙伴在开具数电票时&#xff0c;却收到了风险预警弹窗&#xff0c;这是什么意思呢&#xff1f;财务遇到了该如何处理&#xff1f;今天&#xff0c;百小望和大家聊一聊。 1、什么是红黄蓝预警&a…

Java如何实现的跨平台

其实Java能够实现跨平台主要是依赖于虚拟机。 源代码 首先Java的源代码存在于.java文件中&#xff0c;这些源代码是与平台无关的&#xff0c;这就意味着这些源代码可以在任何一个平台上进行编写。 编译成字节码 通过Java编译器将这些源代码编译成字节码&#xff0c;字节码是…

基于Linux定时任务实现的MySQL周期性备份

1、创建备份目录 sudo mkdir -p /var/backups/mysql/database_name2、创建备份脚本 sudo touch /var/backups/mysql/mysqldump.sh# 用VIM编辑脚本文件&#xff0c;写入备份命令 sudo vim /var/backups/mysql/mysqldump.sh# 内如如下 #!/bin/bash mysqldump -uroot --single-…

【IC前端虚拟项目】验证阶段开篇与知识预储备

【IC前端虚拟项目】数据搬运指令处理模块前端实现虚拟项目说明-CSDN博客 从这篇开始进入验证阶段&#xff0c;因为很多转方向的小伙伴是转入芯片验证工程师方向的&#xff0c;所以有必要先做一个知识预储备的说明&#xff0c;或者作为验证入门的一个小指导吧。 在最开始&#…

2024 EasyRecovery易恢复 帮你轻松找回回收站删除的视频

随着数字化时代的到来&#xff0c;我们的生活和工作中越来越依赖于电子设备。然而&#xff0c;电子设备中的数据丢失问题也随之而来。数据丢失可能是由各种原因引起的&#xff0c;如硬盘故障、病毒感染、误删除等。面对这种情况&#xff0c;一个高效、可靠的数据恢复工具变得尤…

力扣—2024 春招冲刺百题计划

矩阵 1. 螺旋矩阵 代码实现&#xff1a; /** param matrix int整型二维数组 * param matrixRowLen int matrix数组行数* param matrixColLen int* matrix数组列数* return int整型一维数组* return int* returnSize 返回数组行数 */ int* spiralOrder(int **matrix, int matri…

C语言 | Leetcode C语言题解之第19题删除链表的倒数第N个结点

题目&#xff1a; 题解&#xff1a; struct ListNode* removeNthFromEnd(struct ListNode* head, int n) {struct ListNode* dummy malloc(sizeof(struct ListNode));dummy->val 0, dummy->next head;struct ListNode* first head;struct ListNode* second dummy;f…

蓝桥杯嵌入式(G431)备赛笔记——UART

printf的重定向 为了方便使用&#xff0c;通过keil中的Help功能的帮助&#xff0c;做一个printf的重定向 搜索fputc&#xff0c;复制这段 将复制的那段放入工程中&#xff0c;并添加串口发送的函数 关键代码 u8 rx_buff[30]; // 定义一个长度为30的接收缓冲区数组rx_buff u8…

C++初阶:6.string类

string类 string不属于STL,早于STL出现 看文档 C非官网(建议用这个) C官网 文章目录 string类一.为什么学习string类&#xff1f;1.C语言中的字符串2. 两个面试题(暂不做讲解) 二.标准库中的string类1. string类(了解)2. string类的常用接口说明&#xff08;注意下面我只讲解…

数组与链表:JavaScript中的数据结构选择

&#x1f90d; 前端开发工程师、技术日更博主、已过CET6 &#x1f368; 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 &#x1f560; 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 &#x1f35a; 蓝桥云课签约作者、上架课程《Vue.js 和 E…

单片机之蜂鸣器

目录 蜂鸣器介绍 蜂鸣器的分类 发声原理分类 按有源无源分类 三极管驱动 蜂鸣器原理 音符与频率对照表 蜂鸣器播放130.8Hz的声音 仿真案例 蜂鸣器发声 电路图 keil文件 蜂鸣器播放音乐 歌曲数据获得 使用的频率 keil文件 蜂鸣器介绍 前言&#xff1a;蜂鸣器是…

SpringBoot中注册Bean的方式汇总

文章目录 ComponentScan Componet相关注解BeanImportspring.factories总结Configuration和Component的主要区别&#xff1f;Bean是不是必须和Configuration一起使用&#xff1f;Import导入配置类有意义&#xff1f;出现异常&#xff1a;java.lang.NoClassDefFoundError: Could…