多标签分类怎么做?(Python)

news2024/11/16 7:54:07

一、基本介绍

首先简单介绍下,多标签分类与多分类、多任务学习的关系:

  • 多分类学习(Multi-class):分类器去划分的类别是多个的,但对于每一个样本只能有一个类别,类别间是互斥的。例如:分类器判断这只动物是猫、狗、猪,每个样本只能有一种类别,就是一个三分类任务。常用的做法是OVR、softmax多分类

  • 多标签学习(Multi-label ):对于每一个样本可能有多个类别(标签)的任务,不像多分类任务的类别是互斥。例如判断每一部电影的标签可以是多个的,比如有些电影标签是【科幻、动作】,有些电影是【动作、爱情、谍战】。需要注意的是,每一样本可能是1个类别,也可能是多个。而且,类别间通常是有所联系的,一部电影有科幻元素 同时也大概率有动作篇元素的。

  • 多任务学习(Multi-task):基于共享表示(shared representation),多任务学习是通过合并几个任务中的样例(可以视为对参数施加的软约束)来提高泛化的一种方式。额外的训练样本以同样的方式将模型的参数推向泛化更好的方向,当模型的一部分在任务之间共享时,模型的这一部分更多地被约束为良好的值(假设共享是合理的),往往能更好地泛化。某种角度上,多标签分类可以看作是一种多任务学习的简单形式。

二、多标签分类实现

实现多标签分类算法有DNN、KNN、ML-DT、Rank-SVM、CML,像决策树DT、最近邻KNN这一类模型,从原理上面天然可调整适应多标签任务的(多标签适应法),如按同一划分/近邻的客群中各标签的占比什么的做下排序就可以做到了多标签分类。这部电影10个近邻里面有5部是动作片,3部是科幻片,可以简单给这部电影至少打个【科幻、动作】。

这里着重介绍下,比较通用的多标签实现思路,大致有以下4种:

方法一:多分类思路

简单粗暴,直接把不同标签组合当作一个类别,作为一个多分类任务来学习。如上述 【科幻、动作】、【动作、爱情、谍战】、【科幻、爱情】就可以看作一个三分类任务。这种方法前提是标签组合是比较有限的,不然标签会非常稀疏没啥用。

方法二:OVR二分类思路

也挺简单的。将多标签问题转成多个二分类模型预测的任务。如电影总的子标签有K个,划分出K份数据,分别训练K个二分类模型,【是否科幻类、是否动作类....第K类】,对于每个样本预测K次打出最终的标签组合。

这种方法简单灵活,但是缺点是也很明显,各子标签间的学习都是独立的(可能是否科幻类对判定是否动作类的是有影响),忽略了子标签间的联系,丢失了很多信息。

对应的方法有sklearn的OneVsRestClassifier方法,

from xgboost import XGBClassifierfrom sklearn.multiclass import OneVsRestClassifierimport numpy as np
clf_multilabel = OneVsRestClassifier(XGBClassifier())
train_data = np.random.rand(500, 100)  # 500 entities, each contains 100 featurestrain_label = np.random.randint(2, size=(500,20))  # 20 targets
val_data = np.random.rand(100, 100)
clf_multilabel.fit(train_data,train_label)val_pred = clf_multilabel.predict(val_data)

方法三:二分类改良

在方法二的基础上进行改良,即考虑标签之间的关系。每一个分类器的预测结果将作为一个数据特征传给下一个分类器,参与进行下一个类别的预测。该方法的缺点是分类器之间的顺序会对模型性能产生巨大影响。

方法四:多个输出的神经网络

这以与多分类方法类似,但不同的是这里神经网络的多个输出,输出层由多个的sigmoid+交叉熵组成,并不是像softmax各输出是互斥的。

如下构建一个输出为3个标签的概率的多标签模型,模型是共用一套神经网络参数,各输出的是独立(bernoulli分布)的3个标签概率

## 多标签 分类from keras.models import Modelfrom keras.layers import Input,Dense
inputs = Input(shape=(15,))hidden = Dense(units=10,activation='relu')(inputs)output = Dense(units=3,activation='sigmoid')(hidden)model=Model(inputs=inputs, outputs=output)model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])model.summary()
# 训练模型,x特征,y为多个标签model.fit(x, y.loc[:,['LABEL','LABEL1','LABEL3']], epochs=3)

通过共享的模型参数来完成多标签分类任务,在考虑了标签间的联系的同时,共享网络参数可以起着模型正则化的作用,可能对提高模型的泛化能力有所帮助的(在个人验证中,测试集的auc涨了1%左右)。这一点和多任务学习是比较有联系的,等后面有空再好好研究下多任务。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/108058.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

electron:获取MAC地址

一、背景 当我们需要用户“使用指定设备”访问程序的时候,我们需要获取用户设备的固定的id,设备id用户id实现业务需求,这个所谓的id就是MAC地址。 对于其他方法: uuid:uuid是一个唯一的字符串,可以存放到…

深度融合钉钉PaaS,授客学堂助力企业实现培训数字化

方案简介 授客学堂将企业培训领域的经验与钉钉开放能力深度融合,通过集成钉钉人事一体、酷应用、IM底座、待办等多种开放能力,实现学员培训数据实时互通,为客户提供更新更酷的能力,高效解决企业培训的数字化服务。 方案场景 在…

tensorflow feature_columns

总结来说: feature_column定义了一种数据预处理的方式,可以看作是一种格式,指定了key,用于后续读取输入流中对应列的数据feature_column不是tensor,所以如果在下一步应用到模型中是需要tensor,还需要通过f…

非互联网人士如何转行互联网?

结论是,具备互联网式的做事思维积累互联网项目经验。我靠着这个方法从一名传统销售顺利转行,(之前没有任何互联网工作经验)入职了一家互联网公司做用户运营,半年前跳槽成为一个4人运营小团队的leader。 在分享我自身的…

我国丁辛醇行业现状:上游丙烯供给充足 下游需求下滑 市场出现高差价现象

根据观研报告网发布的《中国丁辛醇行业发展深度分析与投资前景研究报告(2022-2029年)》显示,丁辛醇是一种丁醇和辛醇合成的有机物,无色透明、易燃的油状液体,具有特殊的气味,能与水及多种化合物形成共沸物&…

服务器IPMI(BMC)装机

将网线连接服务器的控制口与PC,服务器的控制口默认IP为192.168.100.100,网关默认为192.168.100.1,将PC的IP修改为与服务器控制口相同网段。打开浏览器,输入https://19168.100.100,进入IPMI登录界面。账号密码需要找运维…

Apache Airflow Hive Provider <5.0.0 存在操作系统命令注入漏洞

漏洞描述 Apache Airflow 是一个用于以编程方式创作、安排和监控工作流平台。Apache Airflow Hive Provider 是一个使用 SQL 读取、写入和管理分布式存储中的大型数据集的工具包。 Apache Airflow Hive Provider 在 5.0.0 之前的版本中由于对airflow/providers/apache/hive/h…

Stm32标准库函数6——f103 PWM 电调(50Hz)

#include "stm32f10x.h" #include "delay.h" TIM_TimeBaseInitTypeDef TIM_TimeBaseStructure; TIM_OCInitTypeDef TIM_OCInitStructure; u16 Ppm; /************************************************* 函数: void RCC_Configuration(void) 功能: 配…

D. Friends and Subsequences Codeforces Round #361 (Div. 2)RMQ+二分 单调队列

题目传送门 题意为 给定两个长度为n的数组,设为a数组和b数组,需要找到所有可能的区间中,a数组的最大值等于b数组的最小值的个数。 1:RMQ 二分 RMQ 能找到一个数组在任意区间的最大值或者最小值,只需要在O(n)的时间…

【Java基础知识复盘】HashMap篇——持续更新中

本人知识复盘系列的博客并非全部原创,大部分摘自网络,只是为了记录在自己的博客方便查阅,往后也会陆续在本篇博客更新本人查阅到的新的知识点,望悉知! HashMap 概述 HashMap 是一个散列表,它存储的内容是…

hashMap相关

文章目录HashMapHashMap介绍HashMap在 JDK1.7和 JDK1.8中的区别JDK1.7中HashMap头插法死循环的原因HashMap的底层原理HashMap的扩容机制解决Hash冲突的方法为什么在解决hash冲突的时候选择先用链表,再转红黑树?HashMap为什么线程不安全一般用什么作为HashMap的key?…

程序员需要达到什么水平才能顺利拿到 20k 无压力?

很有趣的是,在程序员身上,我看到了最明显,也最有趣的贫富差距。 根据2022最新版大厂新入职员工职级对应表,大厂技术线的员工轻而易举地拿到了20w的水平,而只要往上够一够,30w也不是什么难事。 然而&#xf…

玩转云服务器:怎样用云服务器架设大型3D魔幻手游【魔域】服务器,实现联机多人同玩,带你一起搞机,了解游戏搭建过程,详细教程

准备工作: 你首先要准备一台云服务器! 服务器配置:2核4G以上配置! 服务器系统:win2012 开始搭建: 下载游戏服务端(有些多人叫源码,这里我就不解释了,喜欢怎么叫就怎…

SpringCloud Alibaba | 网关(三) : SpringCloudGateway 过滤器获取application/json中body数据

SpringCloudGateway 过滤器获取application/json中body数据一、前言二、通过cachedRequestBodyObject缓存获取三、ServerHttpRequest getBody方法获取四、(* ̄︶ ̄)一、前言 项目接口需要加解密,就在网关层进行解密操作。那么问题来了怎么在gateway 的filt…

基于松鼠算法改进的DELM预测-附代码

松鼠算法改进的深度极限学习机DELM的回归预测 文章目录松鼠算法改进的深度极限学习机DELM的回归预测1.ELM原理2.深度极限学习机(DELM)原理3.松鼠算法4.松鼠算法改进DELM5.实验结果6.参考文献7.Matlab代码1.ELM原理 ELM基础原理请参考:https:…

线程池相关

文章目录为什么需要线程池?池化思想常用方法execute()方法submit()方法shutdownisShutdownisTerminatedawaitTerminationshutdownNow创建线程池 七个参数流程JAVA线程池有哪几种类型?线程池常用的阻塞队列有哪些?源码中线程池是怎么复用线程的?如何合理配置线程池…

EMQX Cloud 自定义函数实现多种 IoT 数据形式的灵活转化

物联网场景中,各类设备终端的种类繁杂,所使用的通信协议各异,从而使得应用层的数据格式也各不相同。为了帮助用户实现统一数据格式,EMQX Cloud 最近推出了自定义函数功能:根据用户自定义的脚本对设备上报的数据进行预处…

上美股份在港交所上市:预计全年利润下滑,一叶子收入持续走低

12月22日,上海上美化妆品股份有限公司(HK:02145,下称“上美股份”)在港交所上市。本次上市,上美股份的发行价格为25.20港元/股,为此前发行区间的最低值。据此计算,上美股份的募资总额约为9.31亿…

CDH6.3.2集成Apache Atlas2.1.0

1 环境准备 1.1 CDH6.3.2 环境搭建 参考文档如下 Cloudera Manager安装CDH6教程-(一)虚拟环境安装配置 Cloudera Manager安装CDH6教程-(二)搭建Cloudera和CDH6 CM和CDH在安装的时候遇到的问题 CDH6.3.2 各组件版本 1.2 apa…

火爆“有机新消费”驶入酱油赛道 好记打造我国有机酱油行业领导品牌

根据观研报告网发布的《2022年中国有机酱油市场分析报告-市场竞争策略与发展动向前瞻》显示,有机酱油是指采用有机农作物为原料酿制的酱油。有机酱油含有浓郁的酱香和脂香,是一种不可多得的上等调味品,适合于蘸食,红烧&#xff0c…