深度学习应用技巧4-模型融合:投票法、加权平均法、集成模型法

news2024/11/17 11:48:32

大家好,我是微学AI,今天给大家介绍一下,深度学习中的模型融合。它是将多个深度学习模型或其预测结果结合起来,以提高模型整体性能的一种技术。

深度学习中的模型融合技术,也叫做集成学习,是指同时使用多个模型来进行预测或分类,将它们的结果结合起来,从而获得更准确、更鲁棒的结果。这种方法能够弥补单个模型的不足之处,提高模型的性能。 

常见的深度学习模型包括卷积神经网络(CNN)、循环神经网络(RNN)等,在实际应用中,通常会使用多个模型来解决同一个任务。然而,单独使用每个模型可能会存在过拟合、欠拟合、训练时间长等一些问题。这时,模型融合技术就派上用场了。 对于模型融合技术,其主要的思想是结合多个模型的优点,减少缺点,从而提高整体的性能。

一、模型融合技巧主要包括以下几个方面

1. 投票法: 对多个相同类型的模型进行训练,最后通过投票的方式选择输出结果最多的类别作为最终的预测结果。在实践中,通常会使用奇数个模型,以避免出现相同数量的投票结果。

2. 加权平均法: 对多个相同类型的模型的输出结果进行加权平均。采用加权平均法融合的模型,可以根据效果不同,分配不同的权重。

3. 集成多种不同类型的模型: 在深度学习中,常常会使用不同类型的模型,如 CNN、 RNN 、 LSTM 等,将它们进行集成,综合利用不同模型的优点,进一步提高系统的性能。

4. 提前停止模型训练: 训练多个模型时,如果其中一个模型已经达到了最优的状态,可以停止继续训练,以达到快速训练过程和提高融合效果的目的。

模型融合主要提高了深度学习模型的表现力和泛化性能,更好的解决了过拟合等问题。在选择模型融合技巧时,可以根据具体实际应用选择不同的融合方法,灵活运用各种方法,从而得到更好的模型效果。

应用场景想象:

想象你正在参加一个模型比赛,需要对一些数据进行预测。你仅使用了一个模型进行预测,但是你发现这个模型可能存在某些缺陷,导致预测结果不够准确,于是你想到使用模型融合技术。 你开始集成三个不同的模型,每一个模型都有自己的特点和优缺点。为了得到集成模型的预测结果,你可以采用堆叠法,即把三个模型的预测结果作为输入特征,再训练一个新的模型进行预测。通过堆叠,你得到了一个更加准确的预测结果。 如果你觉得模型之间存在评估的差异,你也可以使用加权平均法来集成模型。你可以根据每一个模型的表现,为它们分配一定的权重,然后根据这些权重,对它们输出的结果进行加权平均,从而得到更加精确的预测结果。 最后,你也可以使用投票法,这种方法是将多个模型的预测结果进行投票,选择获得最多票数的结果作为最终预测结果。它会适用于模型数量较多的情况,即使其中某个模型出现了不准确的情况,也不会对最终结果有太大的影响。 通过这些集成模型的方法,你可以在深度学习中获得更为准确、稳定的预测结果,从而提升模型的性能和可靠性。

二、模型融合代码样例

下面我将利用投票法、加权平均法和集成模型法三种方式对CNN网络进行融合。

import numpy as np
from sklearn.metrics import accuracy_score
from tensorflow import keras
from tensorflow.keras import layers

# 加载 MNIST 数据集
(train_images, train_labels), (test_images, test_labels) = keras.datasets.mnist.load_data()

# 将像素值缩放到 0-1 范围内
train_images = train_images.astype('float32') / 255.
test_images = test_images.astype('float32') / 255.

# 将标签转换为 one-hot 编码
train_labels = keras.utils.to_categorical(train_labels, 10)
test_labels = keras.utils.to_categorical(test_labels, 10)

# 定义 CNN 模型
def create_model():
    model = keras.Sequential()
    model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(64, (3, 3), activation='relu'))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(64, (3, 3), activation='relu'))
    model.add(layers.Flatten())
    model.add(layers.Dense(64, activation='relu'))
    model.add(layers.Dense(10, activation='softmax'))
    return model

# 创建多个 CNN 模型
models = []
for i in range(3):
    model = create_model()
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    model.fit(train_images, train_labels, epochs=5, batch_size=128, verbose=1)
    models.append(model)

1.投票法:对测试集进行预测,并取预测结果的众数作为最终预测结果

predictions = []
for model in models:
    predictions.append(model.predict(test_images))
y_pred = np.argmax(np.round(np.mean(predictions, axis=0)), axis=1)
y_true = np.argmax(test_labels, axis=1)

# 计算准确率
acc = accuracy_score(y_true, y_pred)
print("使用投票法进行模型融合的准确率:", acc)

2.加权平均法:为每个模型定义一个权重,并将预测结果加权平均

weights = [0.2, 0.3, 0.5]
predictions = []
for i, model in enumerate(models):
    prediction = model.predict(test_images)
    predictions.append(weights[i] * prediction)
y_pred = np.argmax(np.sum(predictions, axis=0), axis=1)
y_true = np.argmax(test_labels, axis=1)

# 计算准确率
acc = accuracy_score(y_true, y_pred)
print("使用加权平均法进行模型融合的准确率:", acc)

3.模型集成法:将多个 CNN 模型堆叠在一起,将其看作一个更强大的模型,并对测试集进行预测

inputs = keras.Input(shape=(28, 28, 1))
outputs = [model(inputs) for model in models]
y = layers.Average()(outputs)
ensemble_model = keras.Model(inputs=inputs, outputs=y)
ensemble_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
# 训练集成模型
ensemble_model.fit(train_images, train_labels, epochs=5, batch_size=64)

# 测试集成模型
predictions = ensemble_model.predict(test_images)
y_pred = np.argmax(predictions, axis=1)
y_true = np.argmax(test_labels, axis=1)

# 计算准确率
acc = accuracy_score(y_true, y_pred)
print("使用模型集成方法进行模型融合的准确率:", acc)

最后我们可以看三种方法的准确率,在实战案例中根据业务需求进行方法的选择。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/398739.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

CentOS7.5(1804)安装vsftpd(ftp)

1.准备安装包 1. vsftpd-3.0.2-29.el7_9.x86_64.rpm 2. ftp-0.17-67.el7.x86_64.rpm 可以自行下载,也可从我的博客中下载,下载传送门点 这里 2.安装vsftpd #1. 上传文件到服务器上,比如/home目录 #2. 执行以下命令安装 rpm -ivh vsftpd-3.0.2-29.el7_9.x86_64.rpm #3. 启动vsf…

极限的无穷小和无穷大

目录 无穷小: 无穷大: 无穷小: 举几个无穷小量的例子: 以0为极限的意思就是无穷小。 注:无穷小是变量,不能把很小很小的数混为一谈。 2:0是可以作为无穷小的唯一的一个数。 我们进行证明&…

完整教程:使用Spring Boot实现大文件断点续传及文件校验

一、简介 随着互联网的快速发展,大文件的传输成为了互联网应用的重要组成部分。然而,由于网络不稳定等因素的影响,大文件的传输经常会出现中断的情况,这时需要重新传输,导致传输效率低下。 为了解决这个问题&#xff…

【敏捷开发】jenkins「CI持续集成 CD持续部署」

文章目录前言一、安装jenkins1. 部署中的痛点2. 什么是jenkins3. jenkins的安装和配置(1)下载(2)安装二、上传到运行服务器1. jenkins构建服务器流程2. 安装jenkins常用插件3. 通过freestyle构建项目4. 将构建服务器上的代码上传到…

ASA材料3D打印服务 抗紫外线材料3D打印服务 抗紫外线模型制作-CASAIM中科院广州电子

3D打印技术又称增材制造,通常是采用数字技术材料打印机来实现的,常在模具制造、工业设计等领域被用于制造模型,后逐渐用于一些产品的直接制造。随着 3D 打印逐渐成为主流生产流程的一部分,ASA抗紫外线材料应运而生。中科院广州电子…

Ubuntu 搭建NextCloud私有云盘【内网穿透远程访问】

文章目录1.前言2.本地软件安装2.1 nextcloud安装2.2 cpolar安装3.本地网页发布3.1 Cpolar云端设置3.2 Cpolar本地设置4.公网访问测试5. 结语1.前言 对于爱好折腾的电脑爱好者来说,Linux是绕不开的、必须认识的系统(大部分服务器都是采用Linux操作系统&a…

华为OD机试题,用 Java 解【数组二叉树】问题

华为Od必看系列 华为OD机试 全流程解析+经验分享,题型分享,防作弊指南)华为od机试,独家整理 已参加机试人员的实战技巧华为od 2023 | 什么是华为od,od 薪资待遇,od机试题清单华为OD机试真题大全,用 Python 解华为机试题 | 机试宝典使用说明 参加华为od机试,一定要注意不…

将生成的NYUv2边界GT加载到dataloader中并进行训练

由上一篇我们可以知道,我们生成了一个label_img文件夹,里面存放的是索引对应图片的filename,每个filename里面存放的是GT的40个通道的边缘GT。train里面是这样,test里面也是这样。 加载数据我们要到train文件的dataloader中&…

Azure AD 与 AWS 单一帐户SSO访问集成,超详细讲解,包括解决可能出现的错误问题

本教程介绍如何将 AWS Single-Account Access 与 Azure Active Directory (Azure AD) 相集成。 将 AWS Single-Account Access 与 Azure AD 集成后,可以: 在 Azure AD 中控制谁有权访问 AWS Single-Account Access。让用户使用其 Azure AD 帐户自动登录…

SwiftUI 常用组件和属性(SwiftUI初学笔记)

本文为初学SwiftUI笔记。记录SwiftUI常用的组件和属性。 组件 共有属性(View的属性) Image("toRight").resizable().background(.red) // 背景色.shadow(color: .black, radius: 2, x: 9, y: 15) //阴影.frame(width: 30, height: 30) // 宽高 可以只设置宽或者高.…

2023年上半年软考中/高级一起报名考试+备考学习

软考是全国计算机技术与软件专业技术资格(水平)考试(简称软考)项目,是由国家人力资源和社会保障部、工业和信息化部共同组织的国家级考试,既属于国家职业资格考试,又是职称资格考试。 系统集成…

Springboot——自定义Filter使用测试总结

文章目录前言自定义过滤器并验证关于排除某些请求的方式创建测试接口请求测试验证异常过滤器的执行流程注意事项资料参考前言 在Java-web的开发领域,对于过滤器和拦截器用处还是很多,但两者的概念却极易混淆。 过滤器和拦截器都是采用AOP的核心思想&am…

【微服务】—— 初识微服务

文章目录1. 什么是微服务1.1 微服务的特性自主专用性1.2 微服务的优势敏捷性灵活扩展轻松部署技术自由可重复使用的代码弹性2. 微服务技术栈3. 微服务架构演进3.1 单体架构3.2 分布式架构服务治理3.3 微服务微服务结构微服务技术对比企业需求1. 什么是微服务 微服务是一种开发软…

【删繁就简】Echarts 视觉映射组件中国地图分段颜色显示,选中范围内外颜色设置策略

【删繁就简】Echarts 视觉映射组件中国地图分段颜色显示,选中范围内外颜色设置策略一、背景二、增加0值分段配置项三、解决方案3.1 更改地图底色3.2 更改outOfRange配置项一、背景 在前端项目开发过程中,需要在大屏模块模块中按照项目在各省份分部的数量…

【100个 Unity实用技能】 ☀️ | 脚本无需挂载到游戏对象上也可执行的方法

Unity 小科普 老规矩,先介绍一下 Unity 的科普小知识: Unity是 实时3D互动内容创作和运营平台 。包括游戏开发、美术、建筑、汽车设计、影视在内的所有创作者,借助 Unity 将创意变成现实。Unity 平台提供一整套完善的软件解决方案&#xff…

关于Go语言的底层,你想知道的都在这里!

文章目录1. GoLang语言1.1 Slice1.2 Map1.3 Channel1.4 Goroutine1.5 GMP调度1.6 垃圾回收机制1.7 其他知识点2. Web框架Gin和微服务框架Micro2.1 Gin框架2.2 Micro框架2.3 Viper2.4 Swagger2.5 Zap2.6 JWT文章字数大约1.95万字,阅读大概需要65分钟,建议…

洗牌发牌-第14届蓝桥杯STEMA测评Scratch真题精选

[导读]:超平老师的《Scratch蓝桥杯真题解析100讲》已经全部完成,后续会不定期解读蓝桥杯真题,这是Scratch蓝桥杯真题解析第105讲。 蓝桥杯选拔赛现已更名为STEMA,即STEM 能力测试,是蓝桥杯大赛组委会与美国普林斯顿多…

docker从安装到部署一个项目

一.centos安装docker 参考博客:https://blog.csdn.net/m0_47010003/article/details/127775185 1.设置一下下载Docker的镜像源 设置下载的镜像源为国内的阿里云,如果不设置,会默认去Docker的官方下载 yum-config-manager --add-repo http…

飞桨携手Hugging Face共建开源社区,文图生成黑科技画你所想!

最近的 AIGC 有多火,你不会不知道吧? AI绘画收到越来越多关注的同时,你想不想自己试试?如何基于开源项目训练自己的趣味模型,开源出来被更多人看到? 在这个人人都是创作家的时代,你可以脑洞大开…

element ui 的滚动条,Element UI 文档中没有被提到的滚动条

element ui 的滚动条,Element UI 文档中被提到的滚动条 Element UI 官网中有用到自定义的滚动条组件&#xff0c;但是发布的所有版本中都不曾提及&#xff0c;个中原因我们不得而知&#xff0c;不过我们还是可以拿过来引用到自己的项目中。 使用的时候&#xff0c; 放在 <el…