LeNet-5(fashion-mnist)

news2024/11/15 18:06:30

文章目录

  • 前言
  • LeNet
  • 模型训练

前言

LeNet是最早发布的卷积神经网络之一。该模型被提出用于识别图像中的手写数字。

LeNet

LeNet-5由以下两个部分组成

  • 卷积编码器(2)
  • 全连接层(3)
    卷积块由一个卷积层、一个sigmoid激活函数和一个平均汇聚层组成。
    第一个卷积层有6个输出通道,第二个卷积层有16个输出通道。采用2×2的汇聚操作,且步幅为2.
    3个全连接层分别有120,84,10个输出。
    此处对原始模型做出部分修改,去除最后一层的高斯激活。
net=nn.Sequential(nn.Conv2d(1,6,kernel_size=5,padding=2),nn.Sigmoid(),
                  nn.AvgPool2d(kernel_size=2,stride=2),
                  nn.Conv2d(6,16,kernel_size=5),nn.Sigmoid(),
                  nn.AvgPool2d(kernel_size=2,stride=2),
                  nn.Flatten(),
                  nn.Linear(16*5*5,120),nn.Sigmoid(),
                  nn.Linear(120,84),nn.Sigmoid(),
                  nn.Linear(84,10))

模型训练

为了加快训练,使用GPU计算测试集上的精度以及训练过程中的计算。
此处采用xavier初始化模型参数以及交叉熵损失函数和小批量梯度下降。

batch_size=256
train_iter,test_iter=data_iter.load_data_fashion_mnist(batch_size)

将数据送入GPU进行计算测试集准确率

def evaluate_accuracy_gpu(net,data_iter,device=None):
    """使用GPU计算模型在数据集上的精度"""
    if isinstance(net,torch.nn.Module):
        net.eval()
        if not device:
            device=next(iter(net.parameters())).device
    # 正确预测的数量,预测的总数
    eva = 0.0
    y_num = 0.0
    with torch.no_grad():
        for X,y in data_iter:
            if isinstance(X,list):
                X=[x.to(device) for x in X]
            else:
                X=X.to(device)
            y=y.to(device)
            eva += accuracy(net(X), y)
            y_num += y.numel()
    return eva/y_num

训练过程同样将数据送入GPU计算

def train_epoch_gpu(net, train_iter, loss, updater,device):

    # 训练损失之和,训练准确数之和,样本数
    train_loss_sum = 0.0
    train_acc_sum = 0.0
    num_samples = 0.0
    # timer = d2l.torch.Timer()
    for i, (X, y) in enumerate(train_iter):
        # timer.start()
        updater.zero_grad()
        X, y = X.to(device), y.to(device)
        y_hat = net(X)
        l = loss(y_hat, y)
        l.backward()
        updater.step()
        with torch.no_grad():
            train_loss_sum += l * X.shape[0]
            train_acc_sum += evaluation.accuracy(y_hat, y)
            num_samples += X.shape[0]
        # timer.stop()
    return train_loss_sum/num_samples,train_acc_sum/num_samples


def train_gpu(net,train_iter,test_iter,num_epochs,lr,device):
    def init_weights(m):
        if type(m)==torch.nn.Linear or type(m)==torch.nn.Conv2d:
            torch.nn.init.xavier_uniform_(m.weight)

    net.apply(init_weights)
    net.to(device)
    print('training on',device)
    optimizer=torch.optim.SGD(net.parameters(),lr=lr)
    loss=torch.nn.CrossEntropyLoss()
    # num_batches=len(train_iter)
    tr_l=[]
    tr_a=[]
    te_a=[]
    for epoch in range(num_epochs):
        net.train()
        train_metric=train_epoch_gpu(net,train_iter,loss,optimizer,device)
        test_accuracy = evaluation.evaluate_accuracy_gpu(net, test_iter)
        train_loss, train_acc = train_metric
        train_loss = train_loss.cpu().detach().numpy()
        tr_l.append(train_loss)
        tr_a.append(train_acc)
        te_a.append(test_accuracy)
        print(f'epoch: {epoch + 1}, train_loss: {train_loss}, train_acc: {train_acc}, test_acc:{test_accuracy}')
    x = torch.arange(num_epochs)
    plt.plot((x + 1), tr_l, '-', label='train_loss')
    plt.plot(x + 1, tr_a, '--', label='train_acc')
    plt.plot(x + 1, te_a, '-.', label='test_acc')
    plt.legend()
    plt.show()
    print(f'on {str(device)}')
lr,num_epochs=0.9,10
Train.train_gpu(net,train_iter,test_iter,num_epochs,lr,device='cuda')

在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1372681.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

数据结构及单链表例题(下)

上次我们已经了解了单链表的数据结构定义以及创建单链表的两种方法,这节介绍几道例题. 文章目录 前言 一、已知L为带头结点的单链表,请依照递归思想实现下列运算 二、单链表访问第i个数据节点 三、在第i个元素前插入元素e 四、删除第i个结点 五、查找带头结点单链表倒数第…

Vivado开发FPGA使用流程、教程 verilog(建立工程、编译文件到最终烧录的全流程)

目录 一、概述 二、工程创建 三、添加设计文件并编译 四、线上仿真 五、布局布线 六、生成比特流文件 七、烧录 一、概述 vivado开发FPGA流程分为创建工程、添加设计文件、编译、线上仿真、布局布线(添加约束文件)、生成比特流文件、烧录等步骤&a…

公司内部核心文件数据\资料防泄密软件系统,防止未经授权文件、文档、图纸、源代码、音视频...等数据资料外泄,自动智能透明加密保护!

为了保护公司内部的核心文件和数据资料,防止未经授权的外泄,使用自动智能透明加密保护软件系统是非常重要的。 这样的系统可以通过以下方式实现防泄密: 自动智能加密:该系统可以对公司内部的核心文件和数据资料进行自动智能加密&…

夏目友人帐OVA:和猫咪老师的初次跑腿、曾几何时下雪之日 2013.12.15

夏目友人帐OVA 1、和猫咪老师的初次跑腿 / ニャンコ先生とはじめてのおつかい2、曾几何时下雪之日 / いつかゆきのひに 1、和猫咪老师的初次跑腿 / ニャンコ先生とはじめてのおつかい 和夏目一起外出的途中,猫咪老师因追蜻蜓遇到了一对迷路的龙凤胎兄妹。猫咪老师不…

僵尸毁灭工程手动存档工具

介绍 这是一个可以对僵毁游戏存档进行备份的小工具,其基本原理是对僵毁存档中数以万计的小文件做哈希值计算并保存下来,下一次备份时再对存档文件进行哈希值计算,每次备份只对两次计算结果中存在差异的文件进行复制与替换从而忽略掉大部分未…

汽车IVI中控开发入门及进阶(十二):手机投屏

前言: 汽车座舱有车载中控大屏、仪表/HUD多屏的显示能力,有麦克风/喇叭等车载环境更好的音频输入输出能力,有方控按键、旋钮等方便的反向控制输入能力,还有高精度的车辆数据等。但汽车座舱中控主机硬件计算能力升级迭代周期相对较长,汽车的应用和服务不够丰富。现在很多汽…

电脑USB接口不同颜色的含义

当你看到笔记本电脑或台式机的USB端口时,你会发现USB端口的颜色很多;这些颜色可不只是为了好看,实际上不同颜色代表着不同的性能,那么这些带颜色的USB端口都是什么含义呢,下面就具体介绍下不同颜色代表的含义。-----吴…

Oracle regexp_substr

select regexp_substr(123|456|789, [^|], 1, 2) from dual;

高性能、可扩展、支持二次开发的企业电子招标采购系统源码

在数字化时代,企业需要借助先进的数字化技术来提高工程管理效率和质量。招投标管理系统作为企业内部业务项目管理的重要应用平台,涵盖了门户管理、立项管理、采购项目管理、采购公告管理、考核管理、报表管理、评审管理、企业管理、采购管理和系统管理等…

固乔快递查询助手:批量、快速、全面的快递信息查询软件

在快递行业飞速发展的今天,如何高效、准确地掌握快递信息成为了很多人的需求。而固乔快递查询助手正是解决这一难题的利器。 固乔快递查询助手是一款专注于快递信息查询的软件,支持多家主流快递公司查询。用户只需输入单号,即可快速查询到实时…

BitMap解析之RoaringBitMap

文章目录 BitMap计算的问题Roaring BitMap原理解析Container 介绍1. ArrayContainer2. BitmapContainer3. RunContainer小结 Roaring BitMap的集合运算1. 两个ArrayContainer的and过程2. 两个BitmapContainer的and过程时空分析 Container的创建与转换为什么说Roaring Bitmap压缩…

关于Django字段类型DateTimeField默认是不是不能为空的测试【DateTimeField学习过程记录】

chatgpt说这个字段默认情况下: blankFalse,nullFalse,也就是在数据库层面和表单层面都不能为空,我们不妨来测试一下在数据库层面是不是这样。 沿用博文 https://blog.csdn.net/wenhao_ir/article/details/135499962的配置。 模型…

科大讯飞星火大模型加持数字员工系列产品发布

面对时代浪潮,基业长青的企业总会率先拥抱变化,在时代交替中创造新的增长空间。当数字化浪潮涌入千行百业,企业掌舵者如何选择转型? 从数字员工到灯塔工厂,愈发成熟的人工智能技术已深入企业管理,持续提高…

Python3 安装教程(windows)

Python (官网)是这两年来比较流行的一门编程语言。相对简单的语法以及丰富的第三方库。 步骤有三步: 1.下载 Python 安装包 2.安装 Python 3.查验是否安装成功 一.下载 Python 安装包 (Python、Sublime 官方下载地址是外国的服…

「超级细菌」魔咒或将打破,MIT 利用深度学习发现新型抗生素

作者:加零 编辑:李宝珠、三羊 MIT 利用图神经网络 Chemprop 识别潜在抗生素,特异性杀死鲍曼不动杆菌。 自然界中充满了各种各样的微生物,例如结核杆菌(导致肺结核)、霍乱弧菌(导致霍乱&#…

1991-2022年A股上市公司股价崩盘风险指标数据

1991-2022年A股上市公司股价崩盘风险指标数据 1、时间:1991-2022年 2、来源:整理自csmar 3、指标:证券代码、交易年度、NCSKEW(分市场等权平均法)、NCSKEW(分市场流通市值平均法)、NCSKEW(分市场总市值平均法); NCSKEW(综合市…

Android开发基础(一)

Android开发基础(一) 本篇主要是从Android系统架构理解Android开发。 Android系统架构 Android系统的架构采用了分层的架构,共分为五层,从高到低分别是Android应用层(System Apps)、Android应用框架层&a…

最新GPT4、AI绘画、DALL-E3文生图模型教程,GPT语音对话使用,ChatFile文档对话总结

一、前言 ChatGPT3.5、GPT4.0、GPT语音对话、Midjourney绘画,文档对话总结DALL-E3文生图,相信对大家应该不感到陌生吧?简单来说,GPT-4技术比之前的GPT-3.5相对来说更加智能,会根据用户的要求生成多种内容甚至也可以和…

Vue 自定义仿word表单录入之单选按钮组件

因项目需要&#xff0c;要实现仿word方式录入数据&#xff0c;要实现鼠标经过时才显示编辑组件&#xff0c;预览及离开后则显示具体的文字。 鼠标经过时显示 正常显示及离开时显示 组件代码 <template ><div class"pager-input flex border-box full-width fl…

006集 正则表达式 re 应用实例—python基础入门实例

正则表达式指预先定义好一个 “ 字符串模板 ” &#xff0c;通过这个 “ 字符串模 板” 可以匹配、查找和替换那些匹配 “ 字符串模板 ” 的字符串。 Python的中 re 模块&#xff0c;主要是用来处理正则表达式&#xff0c;还可以利用 re 模块通过正则表达式来进行网页数据的爬取…