半监督学习

news2024/10/5 17:19:04

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档

目录

  • 介绍
  • 一、Self Training自训练
    • 1、介绍
    • 2、代码示例
    • 3、参数解释
  • 二、Label Propagation(标签传播)
    • 1、介绍
    • 2、代码示例
    • 3、参数解释
  • 三、Label Spreading(标签扩散)
    • 1、介绍
    • 2、代码示例
    • 3、参数解释


介绍

半监督学习(Semi-Supervised Learning,SSL)是机器学习领域中的一个重要分支,它结合了监督学习和无监督学习的思想,用于处理标签数据稀缺而无标签数据丰富的场景。
常用方法:

  • Self Training自训练
  • Label Propagation标签传播
  • Label Spreading标签扩散

一、Self Training自训练

1、介绍

Self Training自训练是一种简单的半监督学习方法,它首先使用已标记的数据训练一个监督学习模型。然后,该模型用于预测未标记数据的标签。预测最自信的标签可以被选择添加到训练集中,然后模型在新的、更大的训练集上重新训练。先训练一个小模型,再继续预测标签,类似于迁移学习。当无标签数据和有标签数据分布相同时,使用自训练方法效果最佳。

2、代码示例

  • 读入数据
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

import numpy as np
import pandas as pd
import warnings
warnings.filterwarnings("ignore")

# 数据预处理计算函数
def preprocessing(df):
    from sklearn.impute import SimpleImputer
    from sklearn.preprocessing import StandardScaler
    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import OrdinalEncoder
    from sklearn.compose import ColumnTransformer
    
    cat_cols= df.select_dtypes(include=["object"])   # 分类型变量
    num_cols= df.select_dtypes(include=["int", "float"])   # 数值型变量

    # 连续型数据
    num_imp= SimpleImputer(strategy='mean')  # 缺失值
    num_std= StandardScaler()  # 标准化
    num_pipeline= Pipeline(steps=[("num_imp", num_imp), ("num_std", num_std)])
    # 分类型数据
    cat_imp= SimpleImputer(strategy="most_frequent")  # 缺失值
    cat_encode= OrdinalEncoder()   # 数据编码
    cat_pipeline= Pipeline(steps=[("cat_imp", cat_imp), ("cat_encode", cat_encode)])

    col_trans= ColumnTransformer(transformers=[("num_pipeline", num_pipeline, num_cols.columns),
                                           ("cat_pipeline", cat_pipeline, cat_cols.columns),])
    # 数据集处理的计算
    transfer= col_trans.fit(df)
    return transfer

# 读入数据
raw_data= pd.read_csv('半监督学习.csv')
labels= raw_data.pop("resp_flag")  # 标签
  • 缺失数据对比
print("缺失值/总样本:"+str(labels.isnull().sum())+"/"+str(len(labels)))

在这里插入图片描述

  • 数据处理
    注意:切分的测试数据集一定是有标签的样本
# sklearn中的半监督学习算法要求,所有缺失的标签必须用-1填充
labels_fill= labels.fillna(-1)

# 特征数据处理
transfer= preprocessing(raw_data)
data_trans= transfer.transform(raw_data)

data_concat= pd.concat([labels_fill, pd.DataFrame(data_trans)], axis= 1)

# 保存一部分有标签样本作为测试集
mask_labeled= (labels_fill != -1)
mask_unlabeled= (labels_fill == -1)

data_labeled= data_concat[mask_labeled]
data_unlabeled= data_concat[mask_unlabeled]

# 切分测试集
from sklearn.model_selection  import train_test_split
train, test= train_test_split(data_labeled, test_size=0.2, stratify= data_labeled["resp_flag"], random_state= 42)

Xtrain= pd.concat([train, data_unlabeled])
Ytrain= Xtrain.pop("resp_flag")
  • 使用模型
from sklearn.ensemble import RandomForestClassifier
RF= RandomForestClassifier(oob_score=True)

# Self Training
from sklearn.semi_supervised import SelfTrainingClassifier
RF_self_training= SelfTrainingClassifier(RF)
RF_self_training.fit(Xtrain, Ytrain)

# 测试集模型评估
Xtest= test
Ytest= Xtest.pop("resp_flag")

from sklearn.metrics import roc_auc_score
print("AUC: ", roc_auc_score(Ytest, RF_self_training.predict_proba(Xtest)[:, 1]))

在这里插入图片描述

3、参数解释

base_estimator: BaseEstimator,# 基学习器
threshold: Float = 0.75,# 默认阈值0.75,大于0.75,小于0.25会被打标签,该参数比k_best更为常用
criterion: Literal['threshold', 'k_best'] = "threshold",# 默认值threshold,为该值时和threshold参数相同,即设阈值,k_best超参数阈值,如为10,则不考虑预测概率,只取排名前10的打标签
k_best: Int = 10,# 超参数阈值,如为10,则不考虑预测概率,只取排名前10的打标签
max_iter: int | None = 10,# 最大迭代次数
verbose: bool = False

二、Label Propagation(标签传播)

在sklearn中,基于图算法的半监督学习有Label Propagation和Label Spreading两种。他们的主要区别是第二种方法带有正则化机制。

1、介绍

Label Propagation(标签传播)基本原理:Label Propagation算法基于图理论。算法首先构建一个图,其中每个节点代表一个数据点,无论是标记的还是未标记的。节点之间的边代表数据点之间的相似性。算法的目的是通过图传播标签信息,使未标记数据获得标签。

关键特点:
相似性度量:通常使用K近邻(KNN)或者基于核的方法来定义数据点之间的相似性。
标签传播:标签信息从标记数据点传播到未标记数据点,通过迭代过程实现。
适用场景:适合于数据量较大、标记数据稀缺的情况。

  • 以环形数据为例,绿色全是为打标签的数据:
    在这里插入图片描述
    打标签后数据结果如图:
from sklearn.semi_supervised import LabelPropagation

label_propagation = LabelPropagation(kernel="knn")
label_propagation.fit(X, labels)

output= np.asarray(label_propagation.transduction_)
outer_numbers = np.where(output == outer)[0]
inner_numbers = np.where(output == inner)[0]

plt.figure(figsize=(4, 4))
plt.scatter(X[outer_numbers, 0], X[outer_numbers, 1],)
plt.scatter(X[inner_numbers, 0], X[inner_numbers, 1],);

在这里插入图片描述

2、代码示例

from sklearn.semi_supervised import LabelPropagation

label_propagation = LabelPropagation(kernel="knn")
label_propagation.fit(Xtrain, Ytrain)

Ytrain_propagation= label_propagation.transduction_

from sklearn.ensemble import RandomForestClassifier
RF_propagation= RandomForestClassifier(oob_score=True)
RF_propagation.fit(Xtrain, Ytrain_propagation)

print("AUC: ", roc_auc_score(Ytest, RF_propagation.predict_proba(Xtest)[:, 1]))

在这里插入图片描述

3、参数解释

    kernel: ((...) -> Any) | Literal['knn', 'rbf'] = "rbf",# knn:k近邻,RBF核用于计算图中节点之间的相似度。这些相似度值随后用于传播标签信息,从而根据相邻节点的标签来预测未知节点的标签,rbf函数和正态分布比较相似
    *,
    gamma: Float = 20, # rbf函数的系数,可以简单理解为正态分布的方差
    n_neighbors: Int = 7, # 附近的7个样本,哪个样本多,就打成哪个标签,为knn时生效
    max_iter: Int = 1000,# 迭代次数
    tol: float = 0.001,# 算法收敛的阈值
    n_jobs: Int | None = None

三、Label Spreading(标签扩散)

1、介绍

基本原理:Label Spreading和Label Propagation非常相似,但在处理标签信息和正则化方面有所不同。它同样基于构建图来传播标签。

关键特点:
正则化机制:Label Spreading引入了正则化参数,可以控制标签传播的过程,使算法更加健壮。
稳定性:由于正则化的存在,Label Spreading在面对噪声数据时通常比Label Propagation更稳定。
适用场景:同样适用于有大量未标记数据的情况,尤其当数据包含噪声或者不完全标记时。

2、代码示例

from sklearn.semi_supervised import LabelSpreading

label_spreading = LabelSpreading(kernel="knn", alpha= 0.2)
label_spreading.fit(Xtrain, Ytrain)

Ytrain_spreading= label_spreading.transduction_

from sklearn.ensemble import RandomForestClassifier
RF_spreading= RandomForestClassifier(oob_score=True)
RF_spreading.fit(Xtrain, Ytrain_spreading)

print("AUC: ", roc_auc_score(Ytest, RF_spreading.predict_proba(Xtest)[:, 1]))

在这里插入图片描述

3、参数解释

	kernel: ((...) -> Any) | Literal['rbf', 'knn'] = "rbf",
    *,
    gamma: Float = 20,
    n_neighbors: Int = 7,
    alpha: Float = 0.2, # 正则化参数,用于控制算法对标签平滑的程度,值较小时,会更强调邻居节点信息,值较大时,更倾向于保持原始标签
    max_iter: Int = 30,
    tol: Float = 0.001, # 算法收敛的阈值
    n_jobs: Int | None = None

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1833200.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

突然挣不到钱了?带货主播大降薪,有人收入“腰斩”!时薪低至20元,“不如街头发小广告”

韭菜都想来割韭菜了,从00后到60后都在直播带货,部分业内人士认为不懂行的商家以及海量素人主播的加入,拉低了行业的平均薪酬。 2024年的电商年中大促接近尾声,电商直播市场再次成为广为关注的焦点。然而,与热闹的“618…

解放代码:识别与消除循环依赖的实战指南

目录 一、对循环依赖的基本认识 (一)代码中形成循环依赖的说明 (二)无环依赖的原则 二、识别和消除循环依赖的方法 (一)使用JDepend识别循环依赖 使用 Maven 集成 JDepend 分析报告识别循环依赖 &a…

C++使用spdlog输出日志文件

参考博客: 日志记录库 spdlog 基础使用_spdlog 写日志-CSDN博客 GitHub - gabime/spdlog: Fast C logging library. 首先在github上下载spdlog源码,不想编译成库的话,可以直接使用源码,将include文件夹下的spdlog文件夹&#x…

CSS选择符和可继承属性

属性选择符&#xff1a; 示例&#xff1a;a[target"_blank"] { text-decoration: none; }&#xff08;选择所有target"_blank"的<a>元素&#xff09; /* 选择所有具有class属性的h1元素 */ h1[class] { color: silver; } /* 选择所有具有hre…

116.网络游戏逆向分析与漏洞攻防-邮件系统数据分析-解析结构数据的创建信息

免责声明&#xff1a;内容仅供学习参考&#xff0c;请合法利用知识&#xff0c;禁止进行违法犯罪活动&#xff01; 如果看不懂、不知道现在做的什么&#xff0c;那就跟着做完看效果&#xff0c;代码看不懂是正常的&#xff0c;只要会抄就行&#xff0c;抄着抄着就能懂了 内容…

智谱AI GLM-4V-9B视觉大模型环境搭建推理

引子 最近在关注多模态大模型&#xff0c;之前4月份的时候关注过CogVLM&#xff08;CogVLM/CogAgent环境搭建&推理测试-CSDN博客&#xff09;。模型整体表现还不错&#xff0c;不过不支持中文。智谱AI刚刚开源了GLM-4大模型&#xff0c;套餐里面包含了GLM-4V-9B大模型&…

springboot+vue+mybatis酒店房间管理系统+PPT+论文+讲解+售后

随着现在网络的快速发展&#xff0c;网络的应用在各行各业当中它很快融入到了许多商家的眼球之中&#xff0c;他们利用网络来做这个电商的服务&#xff0c;随之就产生了“酒店房间管理系统”&#xff0c;这样就让人们酒店房间管理系统更加方便简单。 对于本酒店房间管理系统的…

本地localhost与目标地址跨域问题的解决方法

场景 开发过程中遇到一个控件&#xff0c;上传图片到某cdn&#xff0c;目标地址对localhost会有跨域问题&#xff1a; 解决方法 参照此博客&#xff0c;将本地地址定义为某网址&#xff0c;如abc&#xff1a; win10修改本地host文件&#xff0c;用以增加自定义本地访问域名12…

不做题,可以通过PMP考试吗?

如果你想要避免浪费3900元并且不想再支付2500元的补考费&#xff0c;我建议你多做一些PMP考试的练习题&#xff1b;如果你不在意这些费用&#xff0c;也可以选择资助我&#xff0c;嘿嘿。不做题的话&#xff0c;通过PMP考试的几率是非常小的。因为做题是检验我们学习成果并发现…

躬行践履始玉成,行而不辍终致远 | 中创算力季度优秀员工表彰大会

蓬勃发展&#xff0c;根基在人。在中创发展的道路上&#xff0c;有初心不改的领导者、有披星戴月的业务标杆、也有默默坚守的员工&#xff0c;他们扎根中创&#xff0c;用努力、拼搏、坚持&#xff0c;在中创的历程上镌刻下 属于自己的一份印记&#xff01; 为了表彰优秀&…

电商商品项目||电商竞品分析|主流电商商品API接口在竞品分析中的重要应用

竞争数据采集 竞争数据是对在电子商务业务中彼此存在竞争关系的商家、品牌、产品(即竞争对手&#xff09;等各项运营数据的总称&#xff0c;在电子商务企业的经营过程中&#xff0c;对竞争对手进行分析可以帮助决策者和管理员了解竞争对手的发展势头&#xff0c;为企业成略制定…

多尺度特征提取:原理、应用与挑战

多尺度 多尺度特征提取&#xff1a;原理、应用与挑战**原理****应用****挑战****总结** 多尺度特征提取&#xff1a;原理、应用与挑战 在计算机视觉、自然语言处理和信号处理等领域&#xff0c;有效地捕捉和解析数据的多种尺度特性是至关重要的。多尺度特征提取是一种技术&…

24年下半年安徽教资认定准确时间和流程

安徽教资认定准确时间 网上报名时间&#xff1a; 第一批次&#xff1a;4月8日至4月19日17时 第二批次&#xff1a;6月17日至6月28日17时 注意&#xff1a;符合安徽省申请条件的普通大中专院校2024届全日制毕业生&#xff0c;应统一选择6月17日至6月28日17时的时间段进行网上报名…

VM4.3 二次开发02 方案加载、执行及显示

效果 这是二次开发的第二个文章&#xff0c;所以不重复说明环境配置相关的内容。如果不懂的可以看本专栏的上一个文章。 海康视觉算法平台VisionMaster 4.3.0 C# 二次开发01 加载方案并获取结果-CSDN博客 界面代码 <Window x:Class"VmTestWpf.App.MainWindow"x…

浏览器必装插件推荐:最新版Simple Allow Copy,解除网页复制限制!

经常在网上找资料的朋友&#xff0c;尤其是学生党&#xff0c;总会遇到一个问题&#xff1a;很多资料网站的文字是禁止复制的。于是大家通常会使用各种文字识别软件来图文转换&#xff0c;或者直接手打。 今天这款小工具&#xff0c;可以轻松复制各种氪金网站上的任何文字&…

爆肝整理AI Agent:在企业应用中的6种基础类型

AI Agent智能体在企业应用中落地的价值、场景、成熟度做了分析&#xff0c;并且探讨了未来企业IT基础设施与架构如何为未来Gen AI&#xff08;生成式AI&#xff09;做好准备。在这样的架构中&#xff0c;我们把最终体现上层应用能力的AI Agent从不同的技术要求与原理上分成了几…

考试系统Spring Security的配置

设置Spring Security配置类 1、设置包括认证、授权方法 protected void configure(HttpSecurity http) throws Exception {http.headers().frameOptions().disable();List<String> securityIgnoreUrls systemConfig.getSecurityIgnoreUrls();String[] ignores new Str…

重生奇迹mu圣导师介绍

出生地&#xff1a;勇者大陆 性 别&#xff1a;男 擅 长&#xff1a;统率&宠物使用 转 职&#xff1a;祭师&#xff08;3转&#xff09; 介 绍&#xff1a;当玩家账号中有一个Lv250以上角色时&#xff0c;便可以创建职业为圣导师的新角色&#xff0c;圣导师每升一级获得…

Trying to access array offset on value of type null

主要原因是版本7.4以后PHP解析器会对null类型的下标访问直接报错 背景&#xff1a; laravel框架 同时使用了扩展A和扩展B 扩展A要求 php>7.4,同时扩展B的对null类型的下标访问不兼容php7.4 修改扩展B不太现实&#xff0c;毕竟扩展B中有太多的对null类型的下标访问。 解决…

6月17日(周一),AH 股行情总结

AI手机及苹果概念股全日走强&#xff0c;领益智造、山东精密等多股涨停&#xff0c;立讯精、歌尔股份涨逾6% 。新车型秦L销售预期提振股价&#xff0c;比亚迪涨超1% &#xff1b;航运、煤炭、地产板块下跌。 文章正文 周一&#xff0c;A股低开高走&#xff0c;上证指数收跌0.…