23/76-LeNet

news2024/10/6 6:51:00

LeNet
早期成功的神经网络。
先使用卷积层来学习图片空间信息。
然后使用全连接层转换到类别空间。

在这里插入图片描述

#In[]
'''
LeNet,上世纪80年代的产物,最初为了手写识别设计
'''
from d2l import torch as d2l
import torch 
from torch import nn
from torch.nn.modules.loss import CrossEntropyLoss

from torch.utils import data
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt
import Common_functions


'''
LeNet:
两个卷积层,两个池化层,三个线性层
假定为MNIST设计,输入为(batch_size,1,28,28)
'''

class Reshape(torch.nn.Module):
    def forward(self,x):
        return x.view(-1,1,28,28)

net = nn.Sequential(
    nn.Conv2d(in_channels=1,out_channels=6,kernel_size=(5,5),padding=2),nn.Sigmoid(), #输出:(6,28,28)
    nn.AvgPool2d(kernel_size=(2,2)), #不指定stride默认不重叠 输出(6,14,14)
    nn.Conv2d(6,16,kernel_size=(5,5)),nn.Sigmoid(),#输出(16,10,10)
    nn.AvgPool2d(kernel_size=(2,2)),#输出(16,5,5)
    nn.Flatten(),
    nn.Linear(16*5*5,120),nn.Sigmoid(),#
    nn.Linear(120,84),nn.Sigmoid(),
    nn.Linear(84,10)
)


X=torch.rand(size=(1,1,28,28),dtype=torch.float32)
for layer in net:
    X=layer(X)
    print(layer.__class__.__name__,'output shape: \t',X.shape)

#In[]


batch_size = 256
train_iter,test_iter=d2l.load_data_fashion_mnist(batch_size=batch_size)



#对evaluate_accuracy函数进行轻微修改
#使用GPU计算模型在数据集上的精度
#计算网络在测试数据集上面的准确率
#由于完整的测试数据集位于内存中,因此在模型使用GPU预测测试数据集之前,我们需要将其复制到显存中。
def evaluate_accuracy_gpu(net,data_iter,device=None):
    if isinstance(net,nn.Module):
        net.eval() #网络用于测试数据
        if not device:
            device = next(iter(net.parameters())).device #如果没有指定device设备,device设备则使用第一层网络参数的设备
    accumulator = d2l.Accumulator(2) #累加器里面包含两个元素
    for X,y in data_iter:
        if isinstance(X,list):
            X = [x.to(device) for x in X] #X为list类型时,需要加X里面每个元素都复制到device设备上面来
        else:
            X = X.to(device)
        y = y.to(device)
        accumulator.add(d2l.accuracy(net(X),y),y.numel()) #累加器第一个元素为在每一个batch_size中预测准确的个数,第二个元素为每一个batch_size中样本总数目,然后依次循环累加,得到测试数据集上面预测准确的总数目,以及数据集总数目
    return accumulator[0]/accumulator[1] #算出模型预测准确率


def train_ch6(net,train_iter,test_iter,num_epochs,lr,device):
    def init_weights(m):#手动初始化模型参数
        if type(m) == nn.Linear or type(m) == nn.Conv2d:
            nn.init.xavier_uniform_(m.weight) #使用xavier_uniform分布初始化参数
    net.apply(init_weights)
    net.to(device)#将模型复制到gpu上面
    print('training on',device)
    loss = nn.CrossEntropyLoss() #定义loss
    optim = torch.optim.SGD(net.parameters(),lr=lr) #定义优化器
    animator = d2l.Animator(xlabel='epoch',xlim=[1,num_epochs],legend=['train_loss','train_acc','test_acc'])
    timer = d2l.Timer()
    num_batches = len(train_iter)
    for epoch in range(num_epochs):
        net.train()#模型开始训练,需要放在第一层循环里面,因为后面evaluate_accuracy_gpu()函数里面有net.eval(),将模型改变为测试状态,因此需要在每一个循环epoch后面手动再加上模型开始处于训练状态
        accumulator = d2l.Accumulator(3) #累加器
        for i,(X,y) in enumerate(train_iter):
           timer.start()
           optim.zero_grad()
           X = X.to(device)#将X复制到gpu上面
           y = y.to(device) #将y复制到gpu上面
           y_hat = net(X) #得到模型训练后的输出标签y_hat
           l = loss(y_hat,y)#计算每一个batch_size的loss
           l.backward() #计算梯度
           optim.step() #使用优化器更新模型参数
           with torch.no_grad():#不需要模型梯度
               accumulator.add(l*X.shape[0],d2l.accuracy(y_hat,y),X.shape[0])
           timer.stop()
           train_loss = accumulator[0]/accumulator[2] #从累加器里面获得所有训练集的loss之和
           train_acc = accumulator[1]/accumulator[2] #从累加器里面获得所有训练集的准确数之和
           if (i+1) % (num_batches // 5) == 0 or i == num_batches-1:
               animator.add(epoch+(i+1)/num_batches,(train_loss,train_acc,None))
        test_accuracy = evaluate_accuracy_gpu(net,test_iter) #每次训练完一个epoch后的模型用于测试数据集上面计算测试精确度
        animator.add(epoch+1,(None,None,test_accuracy))
    print(f'模型训练完最后一轮时 train_loss:{train_loss},train_acc:{train_acc},test_acc:{test_accuracy}')
    print(f'{num_epochs*accumulator[2]/timer.sum()}examples/second on {str(device)}')#打印出模型每秒能处理多少个样本数

lr,num_epochs= 0.9,10
train_ch6(net,train_iter=train_iter,test_iter=test_iter,lr=lr,num_epochs=num_epochs,device=d2l.try_gpu())
'''
输出结果:
模型训练完最后一轮时 train_loss:0.4322478462855021,train_acc:0.8396666666666667,test_acc:0.8163
55954.65804440994examples/second on cuda:0
'''








#训练
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
device = torch.device(device)

Common_functions.train_device(net,train_iter,test_iter,lr=0.9,device=device)
# %%

plt.show()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1391024.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Git中,版本库和远程库有什么区别

✅作者简介:大家好,我是Leo,热爱Java后端开发者,一个想要与大家共同进步的男人😉😉 🍎个人主页:Leo的博客 💞当前专栏:每天一个知识点 ✨特色专栏&#xff1a…

【jupyter添加虚拟环境内核(pytorch、tensorflow)- 实操可行】

jupyter添加虚拟环境内核(pytorch、tensorflow)- 实操可行 1、查看当前状态(winR,cmd进入之后)2、激活虚拟环境并进入3、安装ipykernel5、完整步骤代码总结6、进入jupyter 添加pytorch、tensorflow内核操作相同,以下内容默认已经安…

IP定位技术在网络安全行业的探索

随着互联网的普及和深入生活,网络安全问题日益受到人们的关注。作为网络安全领域的重要技术,IP定位技术正逐渐成为行业研究的热点。本文将深入探讨IP定位技术在网络安全行业的应用和探索。 一、IP定位技术的概述 IP定位技术是通过IP地址来确定设备地理位…

muduo网络库剖析——通道Channel类

muduo网络库剖析——通道Channel类 前情从muduo到my_muduo 概要事件种类channel 框架与细节成员函数细节实现使用方法 源码结尾 前情 从muduo到my_muduo 作为一个宏大的、功能健全的muduo库,考虑的肯定是众多情况是否可以高效满足;而作为学习者&#x…

RK3399平台入门到精通系列讲解(硬件篇)常用的硬件工具介绍

🚀返回总目录 文章目录 一、万⽤表1.1、测量交流和直流电压1.2、测量交流和直流电流二、逻辑分析仪三、示波器作为⼀名嵌⼊式开发⼯程师,是有必要对各类常⽤的硬件⼯具有⼀定了解的,你可以不懂怎么使⽤它,但你必须知道它是什么,有什么⽤,在什么时候可以⽤得上。 一、万…

自动驾驶中的坐标系

自动驾驶中的坐标系 自动驾驶中的坐标系 0.引言1.相机传感器坐标系2.激光雷达坐标系3.车体坐标系4.世界坐标系4.1.地理坐标系4.2.投影坐标系4.2.1.投影方式4.2.2.墨卡托(Mercator)投影4.2.3.高斯-克吕格(Gauss-Kruger)投影4.2.4.通用横轴墨卡托UTM(UniversalTransve…

Linux命令之用户账户管理whoami,useradd,passwd,chage,usermod,userdel的使用

1、查看当前用户账户 2、切换用户为root用户 3、新建用户user1,给用户user1设置密码为password123 4、新建用户user2,UID为510,指定其所属的私有组为group1(group1组的标识符为500),用户的主目录为/home/us…

人类的逻辑常常是演绎、归纳和溯因推理混合

人类的逻辑推理往往是一种综合运用不同推理方式的能力。 演绎推理是从已知的前提出发,推断出必然的结论。通过逻辑规则的应用,人们可以从一些已知的事实或前提出发,得出一个必然成立的结论。演绎推理是一种严密的推理方式,它能够保…

网络安全笔记-SQL注入

文章目录 前言一、数据库1、Information_schema2、相关函数 二、SQL注入分类1、联合查询注入(UNION query SQL injection)语法 2、报错注入(Error-based SQL injection)报错注入分类报错函数报错注入原理 3、盲注布尔型盲注&#…

ROS第 2 课 ROS 系统安装和环境搭建

文章目录 方法一:一键安装(推荐)方法二:逐步安装(常规安装方式)1.版本选择2.检查 Ubuntu 的软件和更新源3.设置 ROS 的下载源3.1 设置国内下载源3.2 设置公匙3.3 更新软件包 4. 安装 ROS5. 设置环境变量6. …

HBase 基础

HBase 基础 HBase1. HBase简介1.1 HBase定义1.2 HBase数据模型1.2.1 HBase逻辑结构1.2.2 HBase物理存储结构1.2.3 数据模型 1.3 HBase基本架构 2. HBase环境安装2.1 HBase 安装部署2.1.1 HBase 本地按照2.1.2 HBase 伪分布模式安装2.1.3 HBase 集群安装 2.2 HBase Shell操作2.2…

springcloud Alibaba中gateway和sentinel联合使用

看到这个文章相信你有一定的sentinel和gateway基础了吧。 官网的gateway和sentinel联合使用有些过时了,于是有了这个哈哈,给你看看官网的: 才sentinel1.6,现在都几了啊,所以有些过时。 下面开始讲解: 首先…

JAVAEE初阶 文件IO(一)

这里写目录标题 一. 计算机中存储数据的设备1.1 CPU1.2 内存1.3 硬盘1.4 三种存储的区别 二.文件系统2.1 相对路径2.2 绝对路径2.3 .和..的含义2.4 例子2.5 everything工具 三.文件3.1 文本文件3.2 二进制文件 四. JAVA对于文件的API4.1 getParent getName getPath getAbsolute…

Dubbo服务降级:保障稳定性的终极指南【六】

欢迎来到我的博客,代码的世界里,每一行都是一个故事 Dubbo服务降级:保障稳定性的终极指南【六】 前言服务降级概述服务降级配置服务降级最佳实践 前言 在构建分布式系统时,不可避免地会面临高流量、网络故障和服务不可用等问题。…

Python | 三、函数

函数的形参和实参(对应卡码网11题句子缩写) 除非实参是可变对象,如列表、字典和集合,则此时形参会复制实参的地址,即此时二者指向同一个地址,因此在函数内对形参的操作会影响到实参除这种情况外&#xff0…

FlinkAPI开发之处理函数

案例用到的测试数据请参考文章: Flink自定义Source模拟数据流 原文链接:https://blog.csdn.net/m0_52606060/article/details/135436048 概述 之前所介绍的流处理API,无论是基本的转换、聚合,还是更为复杂的窗口操作&#xff0c…

Kafka-RecordAccumulator分析

前面介绍过,KafkaProducer可以有同步和异步两种方式发送消息,其实两者的底层实现相同,都是通过异步方式实现的。 主线程调用KafkaProducer.send方法发送消息的时候,先将消息放到RecordAccumulator中暂存,然后主线程就…

HCIA—— 16每日一讲:HTTP和HTTPS、无状态和cookie、持久连接和管线化、(初稿丢了,这是新稿,请宽恕我)

学习目标: HTTP和HTTPS、无状态和cookie、持久连接和管线化、HTTP的报文、URI和URL(初稿丢了,这是新稿,请宽恕我😶‍🌫️) 学习内容: HTTP无状态和cookieHTTPS持久连接和管线化 目…

Angular系列教程之MVC模式和MVVM模式

文章目录 MVC模式MVVM模式MVC与MVVM的区别Angular如何实现MVVM模式总结 在讨论Angular的时候,我们经常会听到MVC和MVVM这两种设计模式。这两种模式都是为了将用户界面(UI)和业务逻辑分离,使得代码更易于维护和扩展。在这篇文章中,我们将详细介…

[Python练习]使用Python爬虫爬取豆瓣top250的电影的页面源码

1.安装requests第三方库 在终端中输入以下代码(直接在cmd命令提示符中,不需要打开Python) pip install requests -i https://pypi.douban.com/simple/ 从豆瓣网提供的镜像网站下载requests第三方库 pip install requests 是从国外网站下…