6.6 实现卷积神经网络LeNet训练并预测手写体数字

news2025/1/9 1:58:35

模型架构

在这里插入图片描述
在这里插入图片描述

代码实现

import torch
from torch import nn
from d2l import torch as d2l
net = nn.Sequential(
    nn.Conv2d(1,6,kernel_size=5,padding=2),nn.Sigmoid(),#padding=2补偿5x5卷积核导致的特征减少。
    nn.AvgPool2d(kernel_size=2,stride=2),
    nn.Conv2d(6,16,kernel_size=5),nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2,stride=2),
    nn.Flatten(),
    nn.Linear(16*5*5,120),nn.Sigmoid(),
    nn.Linear(120,84),nn.Sigmoid(),
    nn.Linear(84,10)
)
'''定义X,并打印模型的形状'''
# 第一个参数是样本
X = torch.rand(size=(1,1,28,28),dtype=torch.float32)
for layer in net:
    X = layer(X)
    print(layer.__class__.__name__,'output shape: \t',X.shape)
# 输出如下:
Conv2d output shape: 	 torch.Size([1, 6, 28, 28])
Sigmoid output shape: 	 torch.Size([1, 6, 28, 28])
AvgPool2d output shape: 	 torch.Size([1, 6, 14, 14])
Conv2d output shape: 	 torch.Size([1, 16, 10, 10])
Sigmoid output shape: 	 torch.Size([1, 16, 10, 10])
AvgPool2d output shape: 	 torch.Size([1, 16, 5, 5])
Flatten output shape: 	 torch.Size([1, 400])
Linear output shape: 	 torch.Size([1, 120])
Sigmoid output shape: 	 torch.Size([1, 120])
Linear output shape: 	 torch.Size([1, 84])
Sigmoid output shape: 	 torch.Size([1, 84])
Linear output shape: 	 torch.Size([1, 10])
'''定义训练批次并加载训练集和测试集'''
batch_size = 256
# 按照batch_size把数据集取出来。取出来之后是放到内存中的,后面要把它加载到GPU中
train_iter,test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)
# 计算预测正确的个数
def accuracy(y_hat,y):
    '''计算预测正确的数量'''
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
        # y_hat是下标表示类别,值是该类别的概率。模型结果是预测10个类的概率,谁的概率最大,就取谁的下标
        y_hat = y_hat.argmax(axis=1)
    #  y_hat.type(y.dtype):因为==对数据类型很敏感,因此我们将y_hat的数据类型转换为与y的数据类型一致。
    #  y_hat.type(y.dtype) == y,将预测值y_hat与真实值y比较,返回一个包含 0和1的张量,赋值给cam,最后求和会得到正确预测的数量。
    cam = y_hat.type(y.dtype) == y
    return float(cam.type(y.dtype).sum())
def evaluate_accuracy_gpu(net,data_iter,device=None):
    if isinstance(net,nn.Module):
        net.eval() # 将模型设置为评估模式
        if not device:
            '''
                iter(net.parameters())是将参数集合转换为迭代器,并获取其中的第一个元素
                next(iter(net.parameters())).device ,指取到net.parameters()的第一个元素,获取该元素的设备。
            '''
            device = next(iter(net.parameters())).device
        # Accumulator用于对多个变量进行累加,d2l.Accumulator(2) 是在Accumulator实例中创建了2个变量,分别用于存储正确预测的数量和预测的总数量。当我们遍历数据集时,两者都随着时间的推移而累加。
        metric = d2l.Accumulator(2) # 正确预测数,预测总数
        with torch.no_grad():
            for X,y in data_iter:
                if isinstance(X,list): # 详见文章最下面的补充内容
                    X = [x.to(device) for x in X] # 令X使用设备device
                else:
                    X = X.to(device)
                y = y.to(device)
                # y.numel()是批次中样本的数量,accuracy(net(X),y)是用于计算模型在输入数据X上的输出结果与标签Y之间的准确率。
                # metric.add函数将正确预测的数量 和 样本数量作为参数传递进去,用于记录和累计这些指标的值。
                metric.add(accuracy(net(X),y),y.numel())
        return metric[0]/metric[1] # 返回准确率,其中metric[0]存放的是正确预测的个数,metric[1]存放的是样本数量,
def train_ch6(net,train_iter,test_iter,num_epochs,lr,device):
    def init_weights(m):
        if type(m) == nn.Linear or type(m) == nn.Conv2d: # 对神经网络中的线性层和卷积层的权重进行初始化
            nn.init.xavier_uniform_(m.weight) #用于初始化权重的函数,
    net.apply(init_weights)
    print('training on',device)
    net.to(device) # 设置模型使用device
    optimizer = torch.optim.SGD(net.parameters(),lr=lr)
    loss = nn.CrossEntropyLoss()
    '''
        该代码创建了一个名为animator的动画器,用于在训练过程中可视化损失函数和准确率的变化情况
    '''
    animator = d2l.Animator(xlabel='epoch',xlim=[1,num_epochs],legend=['train loss','train acc','test acc'])
    timer,num_batches = d2l.Timer(),len(train_iter)
    for epoch in range(num_epochs):
        # 创建 Accumulator类,统计训练损失之和,正确预测个数之和,样本数
        metric = d2l.Accumulator(3)
        net.train()
        for i,(X,y) in enumerate(train_iter):
            timer.start()
            optimizer.zero_grad()
            X,y = X.to(device),y.to(device)
            y_hat = net(X)
            l = loss(y_hat,y)
            l.backward()
            optimizer.step()
            with torch.no_grad():
                metric.add(l*X.shape[0], d2l.accuracy(y_hat,y), X.shape[0])
            timer.stop()
            train_l = metric[0] / metric[2] # 损失之和 / 样本数
            train_acc = metric[1] / metric[2] # 正确预测个数 / 样本数
            if (i+1) % (num_batches//5)==0 or i == num_epochs-1:
                animator.add(epoch + (i+1)/num_epochs,(train_l,train_acc,None))
        test_acc = evaluate_accuracy_gpu(net,test_iter)
        animator.add(epoch+1,(None,None,test_acc))
    print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, '
          f'test acc {test_acc:.3f}')
    print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec '
          f'on {str(device)}')
# 定义学习率和批次 开始训练
lr,num_epochs = 0.9,10
train_ch6(net,train_iter,test_iter,num_epochs,lr,d2l.try_gpu())

在这里插入图片描述

练习

把平均汇聚层改为最大汇聚层

在这里插入图片描述

把平均池化改为最大池和把激活函数改为RelU之后的效果

net = nn.Sequential(
    nn.Conv2d(1,6,kernel_size=5,padding=2),nn.ReLU(),#padding=2补偿5x5卷积核导致的特征减少。
    nn.MaxPool2d(kernel_size=2,stride=2),
    nn.Conv2d(6,16,kernel_size=5),nn.ReLU(),
    nn.MaxPool2d(kernel_size=2,stride=2),
    nn.Flatten(),
    nn.Linear(16*5*5,120),nn.ReLU(),
    nn.Linear(120,84),nn.Sigmoid(), #注意,此处不能改为RelU,此处的sigmoid是把预测结果映射成概率
    nn.Linear(84,10)
)

在这里插入图片描述

使用训练好的模型进行预测

y_hat = net(x)

补充:

isinstance(net,nn.Module)

isinstance(net,nn.Module)是Python的内置函数,用于判断一个对象是否属于制定类或其子类的实例。如果net是nn.Module类或子类的实例,那么表达式返回True,否则返回False. nn.Module是pytorch中用于构建神经网络模型的基类,其他神经网络都会继承它,因此使用 isinstance(net,nn.Module),可以确定Net对象是否为一个有效的神经网络模型。

`nn.init.xavier_uniform_(m.weight)

nn.init.xavier_uniform_(m.weight) 是一个用于初始化权重的函数,采用的是 Xavier 均匀分布初始化方法。

在神经网络中,权重的初始化非常重要,合适的初始化可以帮助网络更好地学习和收敛。Xavier 初始化方法是一种常用的权重初始化方法之一,旨在使权重在前向传播过程中保持方差不变。

具体而言,nn.init.xavier_uniform_() 函数会对输入的权重张量 m.weight 进行操作,将其初始化为一个均匀分布中的随机值。这个均匀分布的范围根据权重张量的形状进行调整,以保持前向传播过程中特征的方差稳定。

通过使用 Xavier 初始化方法,可以加速神经网络的训练过程,并且有助于避免梯度消失或梯度爆炸等问题。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/842353.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

[疑难杂症2023-007]multiprocessing.Process使用时遇到的几个棘手问题

本文由Markdown编辑器编辑完成。 1. 背景 近日,为了解决自己负责的一个组件,在处理大量数据时,由于内存释放不及时,而导致整个组件占用了较高的内存。 这主要是因为目前我们在使用python的一个采用多进程的框架——Celery. 关于…

解决Map修改key的问题

需求 现在返回json数据带有分页的数据,将返回data属性数据变更为content,数据不变,key发生变化 实现1,源数据比较复杂,组装数据比较麻烦 说明:如果使用这种方式完成需求,需要创建对象&#xff0…

搭建简易syslog日志中转服务器

在某种场景下,无法接入日志审计设备,本文提供一种方式,可通过搭建简易日志中转服务器,收集到该环境下的日志后,再将其导入日志审计设备中。 0x1 开启服务 rsyslog守护进程来自于当前的linux发布版本的预装模块&#x…

【vue】vue基础知识

1、插值表达式&属性绑定 <!--template展示给用户&#xff0c;相当于MVVM模式中的V--> <template><div class"first_div">//插值表达式<p>{{ message }}</p>//这里的参数是从父组件的template里传过来的<p>{{data_1}}</p…

SIP协议之呼叫保持(HOLD)

呼叫保持(HOLD)是SIP协议应用中的一个重要功能&#xff0c;用于实现不挂断电话而达到暂停媒体&#xff08;常见于音频&#xff0c;视频很少用&#xff09;的目的&#xff0c;而解保持操作会恢复通话。 一、保持/解保持实现机制 1.1 保持 保持发起方&#xff08;保持方&#x…

Flutter 自定义view

带进度动画的圆环。没gif&#xff0c;效果大家自行脑补。 继承CustomPainter&#xff0c;paint()方法中拿到canvas&#xff0c;绘制API和android差不多。 import package:flutter/material.dart;class ProgressRingPainter extends CustomPainter {double strokeWidth 20;Col…

2023应急指挥系统总体架构方案 PPT

导读&#xff1a;原文《应急指挥系统总体架构方案PPT》&#xff08;获取来源见文尾&#xff09;&#xff0c;本文精选其中精华及架构部分&#xff0c;逻辑清晰、内容完整&#xff0c;为快速形成售前方案提供参考。 完整版领取方式 完整版领取方式&#xff1a; 如需获取完整的电…

如何找到死锁的线程?_java都学什么

在Java中&#xff0c;死锁是指两个或多个线程被无限地阻塞&#xff0c;等待彼此持有的资源&#xff0c;从而导致程序无法继续执行的情况。死锁通常是由于线程之间循环等待资源而产生的。要找到死锁的线程&#xff0c;可以采用以下方法&#xff1a; 1.线程转储(Thread Dump) 通过…

【Vant Weapp】van-icon 图标

<van-icon name"play-circle" color"rgba(255,255,255,.6)" size"40"/>

关于策略模式的注入问题

上面抄别人的 当在实现策略方法时&#xff0c;报null&#xff0c;排查后发现是接口实现有多个&#xff0c;需要添加别名 注入时添加Qeualifier&#xff0c;指定名称&#xff0c;如下图&#xff1b;如图上修改&#xff0c; 测试类中不用new具体行为策略了&#xff0c;注入别名即…

OpenStreetMap数据转3D场景【Python + PostgreSQL】

很长一段时间以来&#xff0c;我对 GIS 和渲染感兴趣&#xff0c;在分别尝试这两者之后&#xff0c;我决定最终尝试以 3D 方式渲染 OpenStreetMap 中的地理数据&#xff0c;重点关注不超过城市的小规模。 在本文中&#xff0c;我将介绍从建筑形状生成三角形网格、以适合 Blend…

认识Vue;vue使用和安装;声明式和命令式编程;MVVM模型;data属性;methods属性

目录 1_认识Vue2_vue使用和安装3_声明式和命令式编程4_MVVM模型5_data属性6_methods属性 1_认识Vue Vue (读音 /vjuː/&#xff0c;类似于 view) 是一套用于构建用户界面的渐进式 JavaScript框架。 全称是Vue.js或者Vuejs&#xff1b; 它基于标准 HTML、CSS 和 JavaScript 构建…

SpringBoot开发环境热部署

目录 开发热部署 添加dev-tools依赖 在application.properties中配置devtools 在IDEA中添加设置 开发热部署 在实际的项目开发调试过程中会频繁地修改后台类文件&#xff0c;导致需要重新编译、 重新启动&#xff0c;整个过程非常麻烦&#xff0c;影响开发…

密码攻击与ADSelfService Plus的保护

密码攻击是当前网络安全面临的严峻挑战之一。黑客通过不断演进的技术手段&#xff0c;试图入侵用户账户&#xff0c;窃取敏感信息&#xff0c;从而对个人和组织造成严重损害。为了应对密码攻击的威胁&#xff0c;ManageEngine推出了ADSelfService Plus&#xff0c;这是一款功能…

【数据结构】链表(一)

链表&#xff08;一&#xff09; 文章目录 链表&#xff08;一&#xff09;01 引入02 概念及结构03 单向不带头不循环链表实现3.1 创建节点类型3.2 简易创建一个链表3.3 遍历链表每个节点3.4 获取链表长度3.5 查找是否包含关键字key是否在单链表当中3.6 头插法3.7 尾插法3.8 任…

无涯教程-Perl - delete函数

描述 此函数从哈希中删除指定的键和关联的值,或从数组中删除指定的元素。该操作适用于单个元素或切片。 语法 以下是此函数的简单语法- delete LIST返回值 如果键不存在,并且与已删除的哈希键或数组索引关联的值,则此函数返回undef。 Perl 中的 delete函数 - 无涯教程网无…

Java spring boot 全解Camunda 7,从 0 到 1 构建工作流平台——第二节:Spring boot 简单集成

目录 1. 成果展示2. 环境准备3. 项目构建3.1 项目结构3.2 引入Camunda 依赖3.3 启动spring boot 程序3.4 启动 web app 程序 引言&#xff1a;当今技术发展迅猛&#xff0c;企业对于业务流程的高效管理和自动化需求也日益增长。在这个背景下&#xff0c;Spring Boot和Camunda7成…

【网络基础实战之路】基于MGRE多点协议的实战详解

系列文章传送门&#xff1a; 【网络基础实战之路】设计网络划分的实战详解 【网络基础实战之路】一文弄懂TCP的三次握手与四次断开 【网络基础实战之路】基于MGRE多点协议的实战详解 【网络基础实战之路】基于OSPF协议建立两个MGRE网络的实验详解 PS&#xff1a;本要求基于…

Jupyter Notebook 未授权访问远程命令执行漏洞

漏洞描述 Jupyter是一个开源的交互式计算环境&#xff0c;它支持多种编程语言&#xff0c;包括Python、R、Julia等。Jupyter的名称来源于三种编程语言的缩写&#xff1a;Ju(lia)、Py(thon)和R。 Jupyter的主要特点是它以笔记本&#xff08;Notebook&#xff09;的形式组织代码…

Python基础教程——贪吃蛇、连连看小游戏(完整版,附源码)

一、贪吃蛇 1. 案例介绍 贪吃蛇是一款经典的益智游戏&#xff0c;简单又耐玩。该游戏通过控制蛇头方向吃蛋&#xff0c;从而使得蛇变得越来越长。 通过上下左右方向键控制蛇的方向&#xff0c;寻找吃的东西&#xff0c;每吃一口就能得到一定的积分&#xff0c;而且蛇的身子会…