百度飞将 paddle ,实现贝叶斯神经网络 bayesue neure network bnn,aistudio公开项目 复现效果不好

news2024/11/14 0:49:20

论文复现赛:贝叶斯神经网络 - 飞桨AI Studio星河社区

https://github.com/hrdwsong/BayesianCNN-Paddle

论文复现:Weight Uncertainty in Neural Networks

本项目复现时遇到一个比较大的问题,用pytorch顺利跑通源代码后,修改至paddle框架下再次训练,发现模型不收敛,训练准确率一直维持在0.1附近(随机挑选概率), 模型完全没有学到东西。

针对此问题,我依次对dataset、dataloader、模型参数初始化、优化器、loss函数,甚至沿着整个计算图跟踪了梯度是否正确传递。 最终定位为paddle.max函数,使用该函数后,问题出现;屏蔽该函数后,问题消失。 经分析,应该是该函数不连续,不支持梯度传递。而pytorch版本的max函数则没有此问题。

 

一、简介

本文将不确定性引入神经网络,将确定性参数的神经网络改造为具有随机特性的概率神经网络(也成贝叶斯神经网络)。本文是贝叶斯神经网络的奠基作之一,具有很高的引用量。

具体地,在传统神经网络中,各网络节点的参数为确定值;通过本文方法引入不确定性后,各网络节点的参数转变为满足概率分布的随机变量。 每次正向推理时,网络会根据概率分布对参数值进行采样,并以采样到的值作为本次正向推理的参数值。传统神经网络与贝叶斯神经网络的异同点如下图所示:

训练贝叶斯神经网络时,通过本文方法,可将loss函数反向传递到网络节点的概率分布参数上,从而动态调优该网络。

论文链接:Weight Uncertainty in Neural Networks

二、复现精度

基于paddlepaddle深度学习框架,对文献算法进行复现后,本项目达到的测试精度,如下表所示。 参考文献的最高精度为98.68%

模型和方法本项目精度
lenet-bbb98.75%
alexnet-bbb98.73%
3conv3fc-bbb99.07%
lenet-lrt98.76%
alexnet-lrt98.82%
3conv3fc-lrt99.29%

超参数配置如下:

超参数名设置值
lr0.01
batch_size256
epochs200

三、数据集

本项目使用的是MNIST数据集。该数据集为美国国家标准与技术研究所(National Institute of Standards and Technology (NIST))发起整理,一共统计了来自250个不同的人手写数字图片,其中50%是高中生,50%来自人口普查局的工作人员。该数据集的收集目的是希望通过算法,实现对手写数字的识别。

  • 数据集大小:
    • MNIST数据集是机器学习领域中非常经典的一个数据集,由60000个训练样本和10000个测试样本组成,每个样本都是一张28 * 28像素的灰度手写数字图片。
  • 数据格式:它包含了四个部分
    • (1)Training set images: train-images-idx3-ubyte.gz (9.9 MB, 解压后 47 MB, 包含 60,000 个样本)
    • (2)Training set labels: train-labels-idx1-ubyte.gz (29 KB, 解压后 60 KB, 包含 60,000 个标签)
    • (3)Test set images: t10k-images-idx3-ubyte.gz (1.6 MB, 解压后 7.8 MB, 包含 10,000 个样本)
    • (4)Test set labels: t10k-labels-idx1-ubyte.gz (5KB, 解压后 10 KB, 包含 10,000 个标签)

数据集链接:MNIST

四、环境依赖

  • 硬件:

    • x86 cpu
    • NVIDIA GPU
  • 框架:

    • PaddlePaddle = 2.1.2
  • 其他依赖项:

    • numpy==1.19.3
    • matplotlib==3.3.4
    • pandas==1.2.4
    • pytest==6.2.4
    • paddle==1.0.2
    • Pillow==8.3.1

五、快速开始

1、执行以下命令启动训练:

python train.py --net_type 3conv3fc --dataset MNIST

训练贝叶斯神经网络,运行完毕后,模型参数文件保存在./checkpoints/MNIST/bayesian目录下。

2、执行以下命令进行评估

python test.py --net_type 3conv3fc --dataset MNIST 用于测试贝叶斯神经网络,测试前,将已训练好的最优参数模型从./results/3CONV3FC拷贝至./checkpoints/MNIST/bayesian

In [ ]

# 解压项目文件夹
!unzip -o Paddle-BayesianCNN-V1.zip
%cd Paddle-BayesianCNN

In [7]

# config_bayesian.py文件中修改训练方法,选择'bbb'或'lrt'
# 训练模型
!python train.py --net_type 3conv3fc --dataset MNIST

In [ ]

# 测试模型精度
!python test.py --net_type 3conv3fc --dataset MNIST

六、代码结构与详细说明

6.1 代码结构

├── onfig_bayesian.py               # 配置
├── metrics.py                      # 度量相关
├── README.md                       # readme
├── requirements.txt                # 依赖
├── test                            # 测试
├── train                           # 启动训练入口
├── utils.py                        # 公共调用
├── checkpoints                     # 保存
│   ├── MNIST                        # 数据集名称
│      ├── bayesian
│      ├── best
├── data
│   ├── data.py
├── layers
│   ├── misc.py
│   ├── BBB
│       ├── BBBConv.py
│       ├── BBBLinear.py
├── models
│   ├── BayesianModels
│       ├── BayesianOriginNet.py
│       ├── BayesianLeNet.py

6.2 参数说明

可以在 train.py 中设置训练与评估相关参数,具体如下:

参数默认值说明其他
--net_type3conv3fc, 可选选择模型可选择lenet/alexnet/3conv3fc/originet
--datasetMNIST, 可选选择数据集本项目目前仅支持MNIST

6.3 训练流程

可参考快速开始章节中的描述

训练输出

执行训练开始后,将得到类似如下的输出。每一轮epoch训练将会打印当前training loss、training acc、val loss、val acc以及训练kl散度。

Epoch: 0 	Training Loss: 957661.3024 	Training Accuracy: 0.5314 	Validation Loss: 6048323.2596 	Validation Accuracy: 0.8872 	train_kl_div: 108218176.5714
Validation loss decreased (inf --> 6048323.259558).  Saving model ...
Epoch: 1 	Training Loss: 620338.8870 	Training Accuracy: 0.7838 	Validation Loss: 4819156.8720 	Validation Accuracy: 0.8885 	train_kl_div: 90394454.2449
Validation loss decreased (6048323.259558 --> 4819156.872046).  Saving model ...
Epoch: 2 	Training Loss: 483882.8229 	Training Accuracy: 0.8268 	Validation Loss: 3822200.3844 	Validation Accuracy: 0.8913 	train_kl_div: 71920784.3061
Validation loss decreased (4819156.872046 --> 3822200.384351).  Saving model ...
Epoch: 3 	Training Loss: 390434.6679 	Training Accuracy: 0.8332 	Validation Loss: 2554367.9053 	Validation Accuracy: 0.9018 	train_kl_div: 48361270.8571
Validation loss decreased (3822200.384351 --> 2554367.905322).  Saving model ...
Epoch: 4 	Training Loss: 275255.5825 	Training Accuracy: 0.8434 	Validation Loss: 1809289.2232 	Validation Accuracy: 0.9172 	train_kl_div: 34525619.3469

6.4 测试流程

可参考快速开始章节中的描述

此时的输出为:

Testing Accuracy: 0.9907

七、实验数据比较及复现心得

7.1 实验数据比较

在不同的超参数配置下,模型的收敛效果、达到的精度指标有较大的差异,以下列举不同超参数配置下,实验结果的差异性,便于比较分析:

(1)学习率:

原文献采用的优化器与本项目一致,为Adam优化器,原文献学习率设置为0.001,本项目经调参发现, 学习率设置为0.01或0.0001时,网络有时会不收敛,该模型的稳定性存在可改进空间。

(2)epoch轮次

本项目训练时,采用的epoch轮次为200。LOSS和准确率在110个epoch附近已趋于稳定,模型处于收敛状态,下图为3CONV3FC-BBB的训练曲线。

7.2 复现心得

本项目复现时遇到一个比较大的问题,用pytorch顺利跑通源代码后,修改至paddle框架下再次训练,发现模型不收敛,训练准确率一直维持在0.1附近(随机挑选概率), 模型完全没有学到东西。

针对此问题,我依次对dataset、dataloader、模型参数初始化、优化器、loss函数,甚至沿着整个计算图跟踪了梯度是否正确传递。 最终定位为paddle.max函数,使用该函数后,问题出现;屏蔽该函数后,问题消失。 经分析,应该是该函数不连续,不支持梯度传递。而pytorch版本的max函数则没有此问题。

八、模型信息

训练完成后,模型保存在checkpoints目录下。

训练和测试日志保存在results目录下。

信息说明
发布者hrdwsong
时间2021.08
框架版本Paddle 2.1.2
应用场景贝叶斯神经网络
支持硬件GPU、CPU
repo地址https://github.com/hrdwsong/BayesianCNN-Paddle

请点击此处查看本环境基本用法.
Please click here for more detailed instructions.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2101482.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【每日一练】python之tkinter的Label标签基础用法

""" 什么是tkinter窗口?tkinter是python中一个标准的库,用于创建图形界面(GUI)应用程序,它提供了一组工具和组件,使开发者能够在Python中创建窗口、按钮、标签、文本框、菜单等各种界面元素…

基于协同过滤的电影推荐系统

推荐系统已经成为当今互联网平台不可或缺的一部分,尤其是在电影、音乐和电子商务等领域。本文将带您深入探讨如何利用协同过滤算法,构建一个功能齐全的电影推荐系统。我们将结合Python、Django框架以及协同过滤算法,逐步实现这一目标。 完整…

Go父类调用子类方法(虚函数调用)

前言 在Go语言中,支持组合而不是继承。网上都说可以通过接口和结构体内嵌来模拟面向对象编程中的子类和父类关系。但给的例子或写法感觉都不是很好,难以达到我的目的(比如通过模板模式实现代码的重用等)。因此调查了一下实现方式…

内裤洗衣机需要一人一台吗?快来围观2024年五大好货集合

随着市面上的内衣抑菌产品越来越多,内衣洗衣机的质量也是参差不齐,一些网红跨界品牌内衣洗衣机的用料和做工品质较差,使用过程中出现清洗不干净和稳定性不足等问题。那么选购内衣洗衣机需要注意什么呢?我作为一名小家电测评博主&a…

pikachu文件包含漏洞靶场(本地文件包含+远程文件包含关卡)

本地文件包含 1.来到关卡随便点击一个提交 可以发现这里可以读取文件 这是1.txt内容 能读取到上一级文件那么也就可以读取本地文件 上传一个jpg文件 拿去连就ok了 远程包含 包含写木马的文件 该文件内容如下,其作用是在fi_remote.php文件的同级目录下新建一个文…

Java 基于微信小程序的小区服务管理系统,附源码

博主介绍:✌stormjun、8年大厂程序员经历。全网粉丝15w、csdn博客专家、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取源码联系🍅 👇🏻 精彩专栏推荐订阅👇&…

【赵渝强老师】MongoDB的WiredTiger存储引擎

WiredTiger提供文档级别(Document-Level)的并发控制,检查点(CheckPoint),数据压缩和本地数据加密( Native Encryption)等功能。从MongoDB 3.2 版本开始,WiredTiger成为Mo…

大带宽服务器推流延迟怎么回事

大带宽服务器推流延迟的原因可能涉及多个方面,包括编码解码的延迟、网络传输延迟、CDN分发延迟以及播放端的缓冲处理等。下面将详细解释各个影响因素,并提出相应的优化建议: 1. 编码解码的延迟 视频编码格式的影响:不同的编码格式…

net、udp、tcp

Makefile的main.c文件中的全局变量SONG song,要在fun.c文件里面写成extern SONG song 编译方法 第一次编写 或 网络编程 物理层的网线规定有八根,颜色不一样,功能不一样,光猫把光信号转换成电信号,光纤10Gb WiFi叫无线局域网,一般也就50米左右,手机流量叫蜂窝网络,…

无限延展(倒推法)

本题不妨逆推。 对于长度为的字符串 ,若要求第位的延展,考虑 在最后一次延展前的位置。 若延展结束后的长度为,每次考虑以下内容: 若 ​,说明本次伸展无效, ,

CTFHub技能树-备份文件下载-bak文件

当开发人员在线上环境中对源代码进行了备份操作,并且将备份文件放在了 web 目录下,就会引起网站源码泄露。 使用dirsearch扫描出index.php.bak 有些时候网站管理员可能为了方便,会在修改某个文件的时候先复制一份,将其命名为xxx.b…

没关系,会一手Git版本控制就行(全)

Git版本控制 文章目录 Git版本控制1. 版本控制1.1 概述1.2 版本控制优点1.3 本地版本控制系统(离线版)1.4 集中化的版本控制系统(联网版)1.5 分布式版本控制系统(离线联网版) 2. Git概述2.1 Git基本工作流程…

BUUCTF PWN wp--ciscn_2019_n_8

第一步 checksec一下,本题为32位。 分析一下保护机制: 一、RELRO: Partial RELRO Partial RELRO 提供了一定程度的保护。在这种情况下,部分重定位表在程序启动时被设置为只读。这可以防止一些针对重定位表的攻击,比如通过篡改重…

Elasticsearch 介绍

1、课程介绍 1.1 ES 8.x 演化进程 版本号发布日期多少个次要版本迭代历时8.02022年2月11日?至今7.02019年4月11日17个次要版本34个月6.02017年11月15日8个次要版本17个月5.02016年10月27日6个次要版本13个月 2、Elasticsearch 是什么 2.1 概念 2.1.1 标准定义 …

QLineEdit中文本显示不全,部分字符显示空白问题

环境 QT5.14.2 Windows 7 现象 触发某个条件后,使用QLineEdit的setText方法设置文本后,文本部分字符缺失,现象如下(小数点后都是4位): 解决办法 设置根据显示器的像素密度进行自动缩放;再主…

VMware启动报错-Intel VT-x处于禁用状态,进入BIOS更改CPU设置

问题描述 VMware启动虚拟机失败,报错显示Intel VT-x处于禁用状态。 原因分析 打开主机的任务管理器,找到CPU,发现虚拟化处于禁用状态,查阅资料之后发现,进入BIOS模式将CPU虚拟化禁用开启即可解决。 解决步骤 1、…

KAN 学习 Day1 —— 模型框架解析及 HelloKAN

说明 最近了解到了一个新东西——KAN,我的毕设导师给推荐的船新框架。我看过很多剖析其原理的文章,发现大家对其持有的观点都各不相同,有的说可以颠覆传统MLP,有的说可以和Transformer同等地位,但是也有人说它训练速度…

15chatGLM3半精度微调

1 模型准备 数据依然使用之前的数据,但是模型部分我们使用chatglb-3,该模型大小6B,如果微调的话需要24*4 96GB,硬件要求很高,那么我们使用半精度微调策略进行调试,半精度微调有很多坑啊,注意别踩到…

只会SQL语句,可以做什么工作?

1、SQL是什么 首先简单介绍一下SQL(Structured Query Language),是一种可以进行数据提取、聚合、分析,并对数据库进行构建和修改的编程语言。 相对来说,SQL上手非常容易,因为语法结构比较固定&#xff0c…

第一性原理计算从定义到场景到硬件配置详细讲解

第一性原理计算,又称为从头计算(The Ab initio Calculation),是一种基于量子力学原理,通过计算机模拟来预测材料、分子、固体等体系性质的方法。这种方法的核心思想是不依赖于实验数据或经验参数,而是直接从…