在python 深度学习Keras中计算神经网络集成模型

news2024/11/15 19:53:51

 神经网络的训练过程是一个挑战性的优化过程,通常无法收敛。最近我们被客户要求撰写关于深度学习的研究报告,包括一些图形和统计输出。

这可能意味着训练结束时的模型可能不是稳定的或表现最佳的权重集,无法用作最终模型。

解决此问题的一种方法是使用在训练运行结束时多个模型的权重平均值。  

平均模型权重 

学习深度神经网络模型的权重需要解决高维非凸优化问题。

解决此优化问题的一个挑战是,有许多“ 好的 ”解决方案,学习算法可能会反弹而无法稳定。

解决此问题的一种方法是在训练过程即将结束时合并所收集的权重。通常,这可以称为时间平均,并称为Polyak平均或Polyak-Ruppert平均,以该方法的原始开发者命名。

Polyak平均通过优化算法访问的参数空间将轨迹中的几个点平均在一起。

多类别分类问题

我们使用一个小的多类分类问题作为基础来证明模型权重集合。

该问题有两个输入变量(代表点的xy坐标),每组中点的标准偏差为2.0。

# 生成2D分类数据集
X, y = make_blobs(n_samples=1000, centers=3, n_features=2, cluster_std=2, random_state=2)

结果是我们可以建模的数据集的输入和输出元素。

为了了解问题的复杂性,我们可以在二维散点图上绘制每个点,并通过类值对每个点进行着色。

# 数据集的散点图


# 生成2D分类数据集

# 每个类值的散点图

for class_value in range(3):

# 选择带有类别标签的点的索引

row_ix = where(y == class_value)

# 不同颜色点的散点图

pyplot.scatter(X[row_ix, 0], X[row_ix, 1])

# 显示图

运行示例将创建整个数据集的散点图。我们可以看到2.0的标准偏差意味着类不是线性可分离的(由线分隔),从而导致许多歧义点。

 多层感知器模型

在定义模型之前,我们需要设计一个集合的问题。

在我们的问题中,训练数据集相对较小。具体来说,训练数据集中的示例与保持数据集的比例为10:1。这模仿了一种情况,在这种情况下,我们可能会有大量未标记的示例和少量带有标记的示例用于训练模型。

该问题是多类分类问题,我们 在输出层上使用softmax激活函数对其进行建模。这意味着该模型将预测一个具有三个元素的向量,并且该样本属于三个类别中的每个类别。因此,我们必须先对类值进行编码,然后再将行拆分为训练和测试数据集。


# 分为训练和测试


trainX, testX = X[:n_train, :], X[n_train:, :]

trainy, testy = y[:n_train], y[n_train:]

接下来,我们可以定义和编译模型。

该模型将期望具有两个输入变量的样本。然后,该模型具有一个包含25个节点的隐藏层和一个线性激活函数,然后是一个具有三个节点的输出层(用于预测三种类别中每个类别的概率)和一个softmax激活函数。

# 定义模型


model.add(Dense(25, input_dim=2, activation='relu'))

model.add(Dense(3, activation='softmax'))

opt = SGD(lr=0.01, momentum=0.9)

model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=['accuracy'])

最后,我们将在训练和验证数据集上的每个训练时期绘制模型准确性的学习曲线。

# 学习模型精度曲线

pyplot.plot(history.history['acc'], label='train')

pyplot.plot(history.history['val_acc'], label='test')

pyplot.legend()

pyplot.show()

 在这种情况下,我们可以看到该模型在训练数据集上达到了约86%的准确度 。

Train: 0.860, Test: 0.812

 显示了在每个训练时期的训练和测试集上模型精度的学习曲线。

 

在每个训练时期的训练和测试数据集上模型精度的学习曲线

将多个模型保存到文件

模型权重集成的一种方法是在内存中保持模型权重的运行平均值。

 另一种选择是第一步,是在训练过程中将模型权重保存到文件中,然后再组合保存的模型中的权重以生成最终模型。

#模型拟合

n_epochs, n_save_after = 500, 490

for i in range(n_epochs):

# 适合单个模型

model.fit(trainX, trainy, epochs=1, verbose=0)

# 检查我们是否应该保存模型

if i >= n_save_after:

model.save('model_' + str(i) + '.h5')

将模型保存到文件中。

pip install h5py

将10个模型保存到当前工作目录中。

具有平均模型权重的新模型

 首先,我们需要将模型加载到内存中。 

# 从文件加载模型



def load_all_models(n_start, n_end):

all_models = list()

for epoch in range(n_start, n_end):


# 从文件加载模型



model = load_model(filename)



print('>loaded %s' % filename)

return all_models

我们可以调用该函数来加载所有模型。

# 按顺序加载模型



members = load_all_models(490, 500)

print('Loaded %d models' % len(members))

加载后,我们可以使用模型权重的加权平均值创建一个新模型。

将这些元素捆绑在一起,我们可以加载10个模型并计算平均加权平均值(算术平均值)。 

首先运行示例将从文件中加载10个模型。

从这10个模型中创建一个模型权重集合,为每个模型赋予相等的权重,并报告模型结构的摘要。

_________________________________________________________________

Layer (type)                 Output Shape              Param #

=================================================================

dense_1 (Dense)              (None, 25)                75

_________________________________________________________________

dense_2 (Dense)              (None, 3)                 78

=================================================================

Total params: 153

Trainable params: 153

Non-trainable params: 0

_________________________________________________________________

使用平均模型权重集合进行预测

既然我们知道如何计算模型权重的加权平均值,我们就可以使用生成的模型评估预测。

一个问题是,我们不知道要结合多少模型才能获得良好的性能。我们可以通过评估最近n个模型的模型权重平均合集来解决此问题,并改变n以查看有多少个模型产生良好的性能。


# 反向加载模型,所以我们首先用最后一个模型来构建整体


members = list(reversed(members))

# 选择一个成员的子集


subset = members[:n_members]

# 准备一个权重相等的数组


weights = [1.0/n_members for i in range(1, n_members+1)]

# 用所有模型权重的加权平均值创建一个新的模型



model = model_weight_ensemble(subset, weights)

# 作出预测和评价精度



_, test_acc = model.evaluate(testX, testy, verbose=0)

return test_acc

然后,我们可以评估从从最后1个模型到最后10个模型的训练运行中保存的最近n个模型的不同数量创建的模型。除了评估组合的最终模型外,我们还可以评估测试数据集上每个保存的独立模型以比较性能。

# 计算等待集合上不同数量的集合



single_scores, ensemble_scores = list(), list()

for i in range(1, len(members)+1):

可以绘制收集的分数,蓝色点表示单个保存的模型的准确性,橙色线表示组合了最后n个模型的权重的模型的测试准确性。

#绘制得分和集合模型数量

x_axis = [i for i in range(1, len(members)+1)]

pyplot.plot(x_axis, single_scores, marker='o', linestyle='None')

pyplot.plot(x_axis, ensemble_scores, marker='o')

pyplot.show()

首先运行示例将加载10个保存的模型。

报告每个单独保存的模型的性能以及整体模型的权重,该模型的权重是从所有模型(包括每个模型)开始平均计算的,并且从训练运行的末尾开始向后工作。

结果表明,最后两个模型的最佳测试精度约为81.4%。我们可以看到模型权重集合的测试准确性使性能达到平衡,并且表现也一样。

 我们可以看到,对模型权重求平均值确实可以使最终模型的性能达到平衡,至少与运行的最终模型一样好。

 

线性和指数递减加权平均值

我们可以更新示例,并评估集合中模型权重的线性递减权重。

权重可以计算如下:

# 准备一个权值线性递减的数组



weights = [i/n_members for i in range(n_members, 0, -1)]

运行示例将再次报告每个模型的性能,这一次是每个平均模型权重集合的测试准确性,模型的贡献呈线性下降。

我们可以看到,至少在这种情况下,该集合的性能比任何独立模型都小,达到了约81.5%的精度。

 

我们还可以对模型的贡献进行指数衰减的实验。这要求指定衰减率(α)。下面的示例为指数衰减创建权重,其下降率为2。

# 准备一个按指数递减的权重数组



alpha = 2.0

weights = [exp(-i/alpha) for i in range(1, n_members+1)]

下面列出了模型对集合模型中平均权重的贡献呈指数衰减的完整示例。

运行该示例显示出性能的微小改进,就像在保存的模型的加权平均值中使用线性衰减一样。

测试准确性得分的线图显示了使用指数衰减而不是模型的线性或相等权重的较强稳定效果。

  

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/56068.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

MyBatis-Plus

MyBatis-Plus 1、简介 MyBatis-Plus(简称 MP)是一个 MyBatis 的增强工具,在 MyBatis 的基础上只做增强不做改变,为简化开发、提高效率而生。 润物无声 只做增强不做改变,引入它不会对现有工程产生影响,…

数据结构学习笔记(Ⅷ):排序

目录 1 排序基础 1.1 排序的基本概念 2 排序算法 2.1 插入排序 1.思想 2.实现 3.效率分析 4.优化 2.2 希尔排序 1.定义 2.实现 3.效率分析 3 交换排序 3.1 冒泡排序 1.定义 2.实现 3.效率分析 3.2 快速排序 1.算法思想 2.实现 3.效率分析 4 选择排序 4.…

第4章 SpringBoot与Web应用

文章目录第4章 SpringBoot与Web应用4.1 配置Tomcat运行4.2 https安全访问4.3 数据验证4.4 配置错误页4.5 全局异常处理4.6 文件上传4.6.1 基础上传4.6.2 上传文件限制4.6.3 上传多个文件4.7 拦截器4.8 AOP拦截器4.9 本章小结4.9 本章小结第4章 SpringBoot与Web应…

[附源码]计算机毕业设计病人跟踪治疗信息管理系统Springboot程序

项目运行 环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术: SSM mybatis Maven Vue 等等组成,B/S模式 M…

Linux系统移植二:生成fsbl引导文件并制作BOOT.bin

前情提要 对于ZYNQ而言,在引导过程中,先运行FSBL来设置PS,然后运行U-Boot用于加载Linux内核映像并引导Linux Linux系统移植一:移植U-BOOT 添加自己的板子并编译(非petalinux版) 一文中已成功生成了u-boot…

基于MPPT的PV光伏发电simulink建模和仿真

目录 1.算法描述 2.matlab算法仿真效果 3.MATLAB核心程序 4.完整MATLAB 1.算法描述 MPPT控制器的全称是“最大功率点跟踪”(Maximum Power Point Tracking)太阳能控制器,是传统太阳能充放电控制器的升级换代产品。MPPT控制器能够实时侦测…

ManiSkill 2022机器学习顶会ICLR上的世界顶尖机械臂大赛赛题解读,演示轨迹转换,点云查看

1.赛事相关信息 点击查看 2.赛题分析 软体对GPU要求较高,环境配置复杂,选择刚体环境先以模仿学习/强化学习的刚体环境为基础,后期再考虑无限制刚体环境部分任务(如将物块移动到指定位置),存在相机之外的…

Day818.电商系统的分布式事务调优 -Java 性能调优实战

电商系统的分布式事务调优 Hi,我是阿昌,今天学习记录的是关于电商系统的分布式事务调优。 一个线上事故,在一次 DBA 完成单台数据库线上补丁后,系统偶尔会出现异常报警,开发工程师很快就定位到了数据库异常问题。 具…

SQL通用语法与DDL操作

学习笔记 sql通用语法 1 sql语句可以单行或多行书写,以分号结尾; 2 sql语句可以使用空格/缩进来增强语句的可读性; 3 mysql数据库的sql语句不区分大小写 4 单行注释:-- 内容 或 # 内容 多行注释: /* 内容 */ sql语句…

【地图之vue-baidu-map】点击获取坐标(点Marker)、坐标集(多边形polygon)

点击获取坐标&#xff08;点Marker&#xff09; 官网链接&#xff1a;Vue Baidu Map 需求 1.点击某点设置该点为中心点 2.获取点的经纬度 3.确定选取成功&#xff0c;取消就不赋值。 实现步骤 第一步&#xff1a;设置打开弹窗的地方 <el-button click"clickAdd…

c# .net 树莓派/香橙派用到物联网包Iot.Device.bindings 支持设备说明文档

c# .net 树莓派&#xff08;进口&#xff0c;贵&#xff09;/香橙派&#xff08;国产&#xff0c;功能相同&#xff0c;性价比高&#xff09;用到物联网包Iot.Device.bindings 支持设备说明文档 我们c# .net 开发树莓派/香橙派都需要用到Iot.Device.bindings和System.Device.G…

阿里巴巴 Github 星标 57.9KJava 面试突击汇总(全彩版)首次公开

现在互联网大环境不好&#xff0c;互联网公司纷纷裁员并缩减 HC&#xff0c;更多程序员去竞争更少的就业岗位&#xff0c;整的 IT 行业越来越卷。身为 Java 程序员的我们就更不用说了&#xff0c;上班 8 小时需要做好本职工作&#xff0c;下班后还要不断提升技能、技术栈&#…

命令注入漏洞解析

漏洞简介 Atlassian Bitbucket Server 和 Data Center 是 Atlassian 推出的一款现代化代码协作平台&#xff0c;支持代码审查、分支权限管理、CICD 等功能。 受影响的Bitbucket Server 和 Data Center版本存在使用环境变量的命令注入漏洞&#xff0c;具有控制其用户名权限的攻…

代码随想录刷题记录day34 动态规划理论基础+斐波那契数+爬楼梯+使用最小花费爬楼梯

代码随想录刷题记录day34 动态规划理论基础斐波那契数爬楼梯使用最小花费爬楼梯 动态规划理论基础 解决的问题 由前一个状态决定了后一个的状态&#xff0c;可以用动态规划来解决。贪心是没有状态推导的。 解题步骤 确定dp数组&#xff08;dp table&#xff09;以及下标的…

一键集成 SQL 审核到你的 GitLab 和 GitHub CI/CD

本文以 GitLab 为例&#xff0c;GitHub 方式类似。 操作步骤 事先准备 开启 Bytebase 团队版&#xff08;从 v1.8.0 开始&#xff0c;你可以直接开启 14 天的团队版免费试用&#xff09;。 为你的 Bytebase workspace 和项目开启 VCS 工作流&#xff1a;https://www.bytebas…

基于钉钉通讯录,同步构建本地LDAP服务

上一篇《利用飞书通讯录同步搭建本地LDAP》方案发出后&#xff0c;引起不少企业 IT 人员共鸣。本次&#xff0c;宁盾针对使用了钉钉社交应用的企业推出基于钉钉通讯录&#xff08;组织架构和用户信息&#xff09;同步搭建本地 LDAP的方案。 钉钉已经成为很多企业日常处理工作的…

基于FPGA的智能小车系统

目 录 前 言 1 第1章 系统总体方案设计 4 1.1 系统任务描述 4 1.2 控制系统要求 4 1.3 方案设计与论证 4 1.3.1 小车载体选择 4 1.3.2 主控制器选择 5 1.3.3 传感器选择 5 1.3.4 电机驱动选择 6 1.3.5 稳压电源选择 7 1.3.6 智能小车系统最终方案 7 1.4 系统总体设计 8 1.4.…

【Java开发】 Spring 07 :Spring AOP 实践详解(通过 AOP 打印数据访问层)

AOP 指是面向切面编程&#xff08;通过预编译方式和运行期间动态代理实现程序功能的统一维护的一种技术&#xff09;&#xff0c;利用AOP可以对业务逻辑的各个部分进行隔离&#xff0c;从而使得业务逻辑各部分之间的耦合度降低&#xff0c;提高程序的可重用性&#xff0c;同时提…

SpringBoot 3.0 新特性,内置声明式 HTTP 客户端

http interface 从 Spring 6 和 Spring Boot 3 开始&#xff0c;Spring 框架支持将远程 HTTP 服务代理成带有特定注解的 Java http interface。类似的库&#xff0c;如 OpenFeign 和 Retrofit 仍然可以使用&#xff0c;但 http interface 为 Spring 框架添加内置支持。 什么是…

RabbitMQ之集群方案原理

对于无状态应用&#xff08;如普通的微服务&#xff09;很容易实现负载均衡、高可用集群。而对于有状态的系统&#xff08;如数据库等&#xff09;就比较复杂。 1、业界实践 主备模式&#xff1a;单活&#xff0c;容量对等&#xff0c;可以实现故障转移。使用独立存储时需要借…