随机森林Proximity实现及应用

news2025/1/20 13:18:20

随机森林Proximity实现及应用

  • 1 算法
    • 1.1 随机森林Proximity简介
    • 1.2 RF-GAP
    • 1.3 实现代码
  • 2 应用
    • 2.1 离群点(outlier)检测
      • 2.1.1 原理和实现
      • 2.1.2 实验结果
  • 附录

项目主页:randomforest
C++ implementation of random forests classification, regression, proximity and variable importance.
推荐阅读:Random Forests C++实现:细节,使用与实验

1 算法

1.1 随机森林Proximity简介

随机森林Proximity是一种表示两个样本相似性的度量。Proximity的应用方向包括异常数据检测(Outlier detection)、数据插补(imputation)、可视化等。最初的随机森林Proximity计算方法为,将所有的训练样本经过随机树到达叶子节点,如果两个样本到达同一个叶子节点,那么他们的Proximity值就增加1,最后结果除以树的数量以归一化。这种计算方法非常直接、快速,也被广泛使用。后续学者又提出了几种改进的方法,以提高Proximity估计的准确性,或更加适用于特定应用领域。

1.2 RF-GAP

2023年Jake S. Rhodes等在PAMI上发表了一篇论文《Geometry- and Accuracy-Preserving Random Forest Proximities》(arxiv)。他提出了一种新的随机森林Proximity(RF-GAP)计算方法,从理论上证明了基于该方法的分类/回归预测结果等同于oob分类/回归结果,并且给出了实验验证结果。2023年还能在PAMI上发表纯粹讨论随机森林的文章,且内容又是涉及RF算法中一个比较冷门的方面,说明这个工作是非常扎实的。
在这里插入图片描述

以下简要介绍下该论文的思想,文中Proximity(RF-GAP)的计算公式如下:

P ( i , j ) = 1 ∣ T o o b ( i ) ∣ ∑ t ∈ T o o b ( i ) ∣ { j ∣ i n b a g t l e a f ( i ) } ∣ ∣ { i n b a g t l e a f ( i ) } ∣ P(i,j)=\frac{1}{\lvert T_{oob}(i) \rvert} \sum_{t\in T_{oob}(i)} \frac {|\{j |inbag_{t}^{leaf}(i)\}|} {\lvert\{inbag_{t}^{leaf}(i)\}\rvert} P(i,j)=Toob(i)∣1tToob(i)∣{inbagtleaf(i)}∣{jinbagtleaf(i)}
上述公式是根据我自己理解重写的,也可以参考论文中的公式,符号和表达式简要说明如下:

  1. ∣ ⋅ ∣ \lvert {·}\rvert :表示集合中元素的数量;

  2. P ( i , j ) P(i,j) P(i,j): 样本 i i i j j jProximity值;

  3. T o o b ( i ) T_{oob}(i) Toob(i): 随机树的集合,对于这些树样本 i i i是out-of-bag(oob)数据;

  4. { i n b a g t l e a f ( i ) } \{inbag_{t}^{leaf}(i)\} {inbagtleaf(i)}: oob样本 i i i落入树 t t t的某个叶子节点,该节点上所有训练样本(in-bag)的集合;

  5. { j ∣ i n b a g t l e a f ( i ) } \{j |inbag_{t}^{leaf}(i)\} {jinbagtleaf(i)}: 上述叶子节点的in-bag样本中样本 j j j的集合,有可能是0, 也有可能 ≥ 1 \ge1 1,因为每棵树的训练样本是从原始数据集中有放回抽样得到的(bagging),所以叶子节点上的样本 j j j数量可能大于1。

  6. 需要说明的是, P ( i , j ) ≠ P ( j , i ) P(i,j)\ne P(j,i) P(i,j)=P(j,i) P ( i , i ) = 0 P(i,i)=0 P(i,i)=0 ∑ j P ( i , j ) = 1 \sum_{j}P(i,j)=1 jP(i,j)=1

1.3 实现代码

上述RF-GAP已经在我的randomforest上实现,分类与回归代码类似,以下仅贴出分类森林RF-GAP计算代码。

int ClassificationForestGAPProximity(LoquatCForest* forest, float** data, const int index_i, float*& proximities)
{
	if (NULL != proximities)
		delete[] proximities;

	proximities = new float[forest->RFinfo.datainfo.samples_num];
	memset(proximities, 0, sizeof(float) * forest->RFinfo.datainfo.samples_num);

	const int ntrees = forest->RFinfo.ntrees;
	int oobtree_num = 0;
	for (int t = 0; t < ntrees; t++)
	{
		//where the i-th sample is oob
		const struct LoquatCTreeStruct* tree = forest->loquatTrees[t];
		bool i_oob = false;
		for (int n = 0; n < tree->outofbag_samples_num; n++)
		{
			if (index_i == tree->outofbag_samples_index[n]) 
			{
				i_oob = true;
				break;
			}
		}

		if (false == i_oob)
			continue;

		oobtree_num++;

		map<int, int> index_multicity;
		const struct LoquatCTreeNode* leaf_i = GetArrivedLeafNode(forest, t, data[index_i]);
		
		if (leaf_i->samples_index != NULL)
		{
			for (int n=0; n<leaf_i->arrival_samples_num; n++)
			{
				if (index_multicity.find(leaf_i->samples_index[n]) == index_multicity.end())
					index_multicity.emplace(leaf_i->samples_index[n], 1);
				else
					index_multicity[leaf_i->samples_index[n]]++;
			}
		}else
		{
			// if forest did not store sample index arrrived at the leaf node, each in bag sample has to be tested
			for (int n = 0; n < tree->inbag_samples_num; n++)
			{
				const int j = tree->inbag_samples_index[n];
				const struct LoquatCTreeNode* leaf_j = GetArrivedLeafNode(forest, t, data[j]);
				if (leaf_i == leaf_j)
				{
					if (index_multicity.find(j) == index_multicity.end())
						index_multicity.emplace(j, 1);
					else
						index_multicity[j]++;
				}
			}
		}
		

		int M = 0;
		for (map<int, int>::iterator it = index_multicity.begin(); it != index_multicity.end(); it++)
		{
			M += it->second;
		}

		if (0 == M)
			continue;

		for (map<int, int>::iterator it = index_multicity.begin(); it != index_multicity.end(); it++)
			proximities[it->first] += it->second*1.0f/M;
	}

	if (0 == oobtree_num)
		return -1;

	for (int j = 0; j < forest->RFinfo.datainfo.samples_num; j++)
		proximities[j] = proximities[j] / oobtree_num;

	return 1;
}

2 应用

通过计算随机森林Proximity挖掘两个样本之间的相似度,可以应用于数据可视化、离群点检测、数据插补等场景。相比其他随机森林Proximity,据论文中描述RF-GAP在这些应用上展现了更好的效果。

2.1 离群点(outlier)检测

2.1.1 原理和实现

对于分类问题,离群点可以这样定义:该样本的outlier measure score显著地大于类内其他样本的值。这些异常样本可能与其他类别的样本相似,或者与所有类别样本都不相似。异常样本的存在会影响分类和回归算法的训练。随机森林离群点检测可以分为基于特征的方法和基于proximity的方法,前者比如Isolation Forest,基于RF-GAP的离群点检测方法属于后者。步骤如下:

  1. 对于样本 i i i,计算raw outlier measure score s c o r e ( i ) = n ∑ j ∈ c l a s s ( i ) P 2 ( i , j ) score(i)=\frac{n}{\sum_{j\in class(i)} P^2(i,j)} score(i)=jclass(i)P2(i,j)n
  2. 计算类内raw_score的中值 m c = m e d i a n { s c o r e ( i ) ∣ c l a s s ( i ) ∈ c } m_c=median\{score(i)|class(i)\in c\} mc=median{score(i)class(i)c},类内样本raw score与中值绝对差的均值 d e v c = 1 N c ∑ c l a s s ( i ) ∈ c ∣ s c o r e ( i ) − m c ∣ dev_c=\frac{1}{N_c}\sum_{class(i)\in c}|score(i)-m_c| devc=Nc1class(i)cscore(i)mc
  3. 计算得到规范化的raw outlier measure score s c o r e ( i ) ← m a x ( s c o r e ( i ) − m c d e v c , 0 ) score(i)\gets max(\frac{score(i)-m_c}{dev_c}, 0) score(i)max(devcscore(i)mc,0)

在我的randomforest算法中实现计算outlier measure score的代码如下,过程细节借鉴了Leo Breiman的RF实现(Leo Breiman大神的代码是Fortran写的,关于outlier measure的计算在函数locateout中)。

int RawOutlierMeasure(LoquatCForest* loquatForest, float** data, int* label, float*& raw_score)
{
	if (NULL != raw_score)
		delete [] raw_score;
	
	int rv=1;
	const  int samples_num = loquatForest->RFinfo.datainfo.samples_num;
	const int class_num = loquatForest->RFinfo.datainfo.classes_num;

	raw_score = new float [samples_num];
	memset(raw_score, 0, sizeof(float)*samples_num);
	
	// 1. 计算raw outlier measure score
	float *proximities = NULL;
	for (int i=0; i<samples_num; i++)
	{
		ClassificationForestGAPProximity(loquatForest, data, i, proximities /*OUT*/);

		float  proximity2_sum = 0.f;
		for(int j=0; j<samples_num; j++)
		{
			if (label[j] != label[i] || j == i)
				continue;

			// within class
			proximity2_sum += proximities[j] * proximities[j];
		}
		
		raw_score[i] = samples_num / (proximity2_sum+1e-5);

		delete [] proximities;
		proximities = NULL;
	}

	// 2. 计算类内raw_score的中值,类内样本raw score与中值绝对差的均值
	float *dev = new float[class_num];
	float *median = new float[class_num];
	memset(dev, 0, sizeof(float)*class_num);
	memset(median, 0, sizeof(float)*class_num);

	for (int j=0; j<class_num; j++)
	{
		vector<float>  raw_score_class_j;
		
		for (int i=0; i<samples_num; i++)
		{
			if (label[i] == j)
				raw_score_class_j.push_back(raw_score[i]);
		}

		std::sort(raw_score_class_j.begin(), raw_score_class_j.end());

		const int sample_num_j = raw_score_class_j.size();
		if (0 == sample_num_j)
		{
			rv = 0;
			dev[j] = 1.f;
			median[j] = 1.f;
			continue;
		}

		if (sample_num_j%2 == 1)
			median[j] = raw_score_class_j[sample_num_j/2];
		else
			median[j] = 0.5f*(raw_score_class_j[sample_num_j/2] + raw_score_class_j[sample_num_j/2-1]);
		

		for (vector<float>::iterator it=raw_score_class_j.begin(); it != raw_score_class_j.end(); it++)
			dev[j] += abs(*it - median[j]);
		dev[j] = dev[j] / sample_num_j;
	}

    // 3. 计算得到规范化的raw outlier measure score
	for( int i=0, lb=0; i<samples_num; i++)
	{
		lb = label[i];
		raw_score[i] = RF_MAX( (raw_score[i] - median[lb])/(dev[lb]+1e-5), 0.0);
	}

	delete [] dev;
	delete [] median;

	return rv;
}

2.1.2 实验结果

使用mnist手写字符识别数据集训练随机森林分类器,样本数60000,特征数780。RF参数为:随机树 T = 200 T=200 T=200,分类候选特征数 780 \sqrt{780} 780 ,节点最小样本数5,最大树深度40。
计算outlier measure的核心是计算proximities P ( i , j ) P(i,j) P(i,j),当训练RF时叶子节点保存了落入其中的in-bag样本信息(至少有样本在训练集中的序号),那么初略估计RF-GAP算法的复杂度为 O ( T ∗ N ) O(T*N) O(TN) T T T N N N分别是随机树和训练集样本的数量。RF-GAP文章最后作者提到这个算法比较适用于样本在几千范围的数据集,下一步工作是改进算法以适用于更大的数据集。
计算所有样本"outlier measure score",其中部分类别样本的数值在下面图中展示(按数值升序),横坐标是样本,纵坐标是样本对应的"outlier measure score"。其中,类别4和9(即数字4–子图3,数字9–子图4)的极少数样本显著异常,可以对应查看下文中的离群样本图像,更加直观。

在这里插入图片描述

检测mnist手写数字识别数据集上离群样本,选取了几个类别中最大的raw outlier measure score对应的图像,从图像上来看确实较难辨认。比如数字4的离群样本,看起来是数字7,验证过确实标注类别4,在原训练集中的序号是59915 (序号从0开始),raw score 超过200,与类内其他样本有显著区别。
在这里插入图片描述

随机选取了一些raw outlier measure score为0的样本图像(RF认为类内相似度较高),示例如下,确实是一些正常可辨认的手写数字。
在这里插入图片描述

备注:mnist有几种训练集,对应特征有784维(对应图像28x28)和780维,本文选用的特征是780维。为了显示,在780维图像头部补4个0,使每个图像从780扩充到784,导致字符都有点偏右。
在这里插入图片描述

附录

使用我实现的randomforest算法进行RF训练+异常样本检测(计算样本的raw outlier measure score)的代码如下:

int main()
{
	// read training samples if necessary
	char filename[500] = "/to/direction/dataset/train-data.txt" 
	float** data = NULL;
	int* label = NULL;
	Dataset_info_C datainfo;
	int rv = InitalClassificationDataMatrixFormFile2(filename, data/*OUT*/, label/*OUT*/, datainfo/*OUT*/);
	// check the return value
	// 	... ...

	// setting random forests parameters
	RandomCForests_info rfinfo;
	rfinfo.datainfo = datainfo;
	rfinfo.maxdepth = 40;
	rfinfo.ntrees = 200;
	rfinfo.mvariables = (int)sqrtf(datainfo.variables_num);
	rfinfo.minsamplessplit = 5;
	rfinfo.randomness = 1;
	// train forest
	LoquatCForest* loquatCForest = NULL;
	rv = TrainRandomForestClassifier(data, label, rfinfo, loquatCForest /*OUT*/, 20);
	// check the return value
	// 	... ...

	//outlier measurement//
	float *raw_score=NULL;
	RawOutlierMeasure2(loquatCForest, data, label, raw_score);
	// raw_socre -- outlier measurements
	// ... ...
	delete [] raw_score;
	/outlier measurement//

	// clear the memory allocated for the entire forest
	ReleaseClassificationForest(&loquatCForest);
	// release money: data, label
	for (int i = 0; i < datainfo.samples_num; i++)
		delete[] data[i];
	delete[] data;
	delete[] label;
	return 0;
}

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/573867.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

可以免费使用的ChatGPT保姆级教程 (New Bing)

ChatGPT狂飙160天&#xff0c;世界已经不是之前的样子。https://ai.weoknow.com 每天给大家更新可用的国内可用chatGPT资源 最近&#xff0c;ChatGPT已经非常流行&#xff0c;但由于各种原因&#xff0c;国内用户无法直接免费使用ChatGPT的API&#xff0c;各种伟大的神也利用这…

沉浸式翻译 安装及使用

介绍一下最近非常或的沉浸式翻译工具&#xff0c;非常有助于外文阅读&#xff0c;包括网页、pdf等。可以同时显示原文和译文&#xff0c;操作简单&#xff0c;使用起来还是非常友好的。 先上链接&#xff1a;介绍 - 沉浸式翻译 如何使用 - 沉浸式翻译 1.安装 支持Edg…

仙人掌之歌——权力的游戏(2)

他是特级战斗英雄 “那个李通&#xff0c;会不会看起来好吓人呀&#xff1f;” 云冰洁有些紧张的样子&#xff0c;几乎要让陈速笑出来。 “哪有&#xff0c;一个很 nice 的人好吧。就是看起来比较严肃而已&#xff0c;我也从没看他笑过倒是。” 陈速让云冰洁看菜单&#xff0…

【极海APM32F4xx Tiny】学习笔记01-模板工程创建

本项目的使用的开发板 关于芯片使用的其他笔记 外部晶振时钟 模板工程创建/LED工程 项目仓库 https://gitcode.net/u010261063/apm32_test_part 创建模板工程的核心要素 复制官方的标准外设库复制启动文件复制cmsis文件复制相关的公共头文件如apm32f4xx_int.h 和 system_apm…

mybatis trim标签使用详解

mybatis trim标签使用详解 mybatis的trim标签一般用于去除sql语句中多余的and关键字&#xff0c;逗号&#xff0c;或者给sql语句前拼接 “where“、“set“以及“values(“ 等前缀&#xff0c;或者添加“)“等后缀&#xff0c;可用于选择性插入、更新、删除或者条件查询等操作。…

Dubbo框架

文章目录 1. 什么是Dubbo2. Dubbo架构3. SpringBoot整合Dubbo框架3.1 前期准备3.1.1 Zookeeper的安装 3.2 项目创建3.3 添加依赖3.4 定义服务接口3.5 服务端的实现3.6 消费端请求任务3.7 服务端配置文件3.8 消费端配置文件3.9 启动应用 4. Dubbo负载均衡5. Dubbo集群容错 1. 什…

第一部分-基础篇-第一章:PSTN与VOIP(上篇)

文章目录 序言引言&#xff1a;什么是VOIP和PSTN1.1 PSTN起源与发展1.1.1 最早的电话网1.1.2 人工电话交换时代1.1.3自动电话交换时代1.1.4半电子交换机时代1.1.5空分交换机时代1.1.6 数字交换机时代1.1.7现代PSTN时代1.1.8 下一代网络及VoIP时代 1.2 电话实现技术1.2.1 电话号…

【MySQL】如何速通MySQL(1)

&#x1f4cc;前言&#xff1a;本篇博客介绍如何速通MySQL&#xff0c;主要介绍Mysql中主要的基础的入门&#xff0c;学习MySQL之前要先安装好MySQL&#xff0c;如果还没有安装的小伙伴可以看看博主前面的博客&#xff0c;里面有详细的安装教程。或者看一下下面这个链接~ &…

“AI孙燕姿”爆火背后,是内容合规问题的再次升级|上云那些事

“讽刺的是&#xff0c;人类再怎么快也无法超越它。”这是歌手孙燕姿关于自己AI分身遍布网络一事&#xff0c;在MAKE MUSIC网站的博客上发表的看法。 来源&#xff1a;孙燕姿MAKE MUSIC网站博客 当大家还在担心AIGC会不会让自己失业时&#xff0c;歌手孙燕姿就因为“AI孙燕姿”…

LDA算法实现鸢尾花数据集降维

目录 1. 作者介绍2. LDA降维算法2.1 基本概念2.2 算法流程 3. LDA算法实现3.1 数据集介绍3.2 代码实现3.3 结果展示 1. 作者介绍 唐杰&#xff0c;男&#xff0c;西安工程大学电子信息学院&#xff0c;2022级研究生 研究方向&#xff1a;机器视觉与人工智能 电子邮件&#xff…

深度学习笔记(八)——语义分割标注转换

核心思想&#xff1a;“将颜色转换成对应的标号” 形式一&#xff1a;Json格式的标注转换成调色板mask 形式二&#xff1a;RGB类型mask(24位三通道&#xff09;转成调色板mask&#xff08;8位单通道&#xff09;&#xff0c;调色板的格式为.png 形式三&#xff1a;对于二分类的…

oracle安装

服务端安装&#xff08;公司中不需要&#xff0c;只安装客户端就行&#xff09; 1、挂载一个Windows系统 双击vmx文件 启动 2、网络配置 添加一个网络 自己电脑看控制面板是否添加虚拟网卡 查看连接的网络&#xff0c;ip地址不能为1&#xff0c;为1就自己修改&#xff0c;…

深度剖析:C++内存池的设计与实现

深度剖析&#xff1a;C内存池的设计与实现 一、引言&#xff08;Introduction&#xff09;1.1 内存管理的重要性1.2 内存池的基本概念1.3 内存池的应用场景 二、C内存管理机制&#xff08;C Memory Management Mechanism&#xff09;2.1 C内存分配与释放2.2 C内存管理的问题2.3…

《Kali渗透基础》04. 主动信息收集(一)

kali渗透 1&#xff1a;主动信息收集2&#xff1a;发现3&#xff1a;二层发现3.1&#xff1a;arping3.2&#xff1a;nmap3.3&#xff1a;netdiscover3.4&#xff1a;Scapy 4&#xff1a;三层发现4.1&#xff1a;ping4.2&#xff1a;Scapy4.3&#xff1a;nmap4.4&#xff1a;fpi…

Win2016服务器DNS服务搭建

文章目录 前言一、什么是DNS&#xff1f;1.为什么需要DNS系统2.为DNS&#xff08;Domain Name System&#xff0c;域名系统&#xff09;的功能3.域名解决方案的演进 二、域名介绍1.域名空间结构2.常见的顶级域名 三、DNS解析原理1.查询过程及方式2.DNS的查询分类 四、配置DNS服…

【论文精读】ICLR2022 - 语言驱动的语义分割

【论文精读】ICLR2022 - 语言驱动的语义分割 【论文原文】&#xff1a;LANGUAGE-DRIVEN SEMANTIC SEGMENTATION 【作者信息】&#xff1a;Boyi Li Cornell University, Cornell Tech Kilian Q. Weinberger Cornell University Serge Belongie University of Copenhagen Vladl…

2023年试用uniapp、vue2、vue3、typescript、vite、nvue

1. 前言 试用了一下 uniapp、vue2、vue3、typescript、vite、nvue 等技术&#xff0c;写了两个页面&#xff0c;两个页面加起来不到400行代码。 尝试使用了四种组合&#xff1a; 组合1&#xff1a;uniapp vue2 JavaScript nvue文件 非fast模式 组合2&#xff1a;uniapp…

【C++系列Pn】模板搞不懂,脑阔抖三抖(精讲模板,快来复习趴)

前言 大家好吖&#xff0c;欢迎来到 YY 滴 C系列 &#xff0c;热烈欢迎&#xff01;本章主要内容面向接触过C的老铁&#xff0c;主要内容含 目录 一.模板 1.函数模板 一.函数模板概念 二.函数模板的格式 三.函数模板的实例化 1.隐式实例化 2.显式实例化 3.模板参数的…

数据库的简介

文章目录 前言一、为什么需要数据库二、数据库基本概念1.什么是数据库2.什么是数据库管理系统3.数据库表4.数据库表 三、常见的数据库管理系统 前言 数据库的简介 一、为什么需要数据库 信息时代数据容量海量增长&#xff0c;结构化存储大量数据&#xff0c;便于高效的检索和…

如何在华为OD机试中获得满分?Java实现【区块链文件转储系统】一文详解!

✅创作者&#xff1a;陈书予 &#x1f389;个人主页&#xff1a;陈书予的个人主页 &#x1f341;陈书予的个人社区&#xff0c;欢迎你的加入: 陈书予的社区 &#x1f31f;专栏地址: Java华为OD机试真题&#xff08;2022&2023) 文章目录 1. 题目描述2. 输入描述3. 输出描述…