sklearn.neighbors 最近邻相关算法,最近邻分类和回归

news2024/12/22 19:33:23

文章目录

  • sklearn.neighbors 最近邻相关算法,分类和插值
    • 1. 查找最近邻元素
    • 2. 最近邻分类
    • 3. 最近邻回归
    • 4. NearestCentroid 最近邻质心分类
    • 5. Neighborhood Components Analysis 邻域成分分析

sklearn.neighbors 最近邻相关算法,分类和插值

主要介绍 sklearn.neighbors 相关方法

1. 查找最近邻元素

from sklearn.neighbors import NearestNeighbors
import numpy as np

'''
找到K近邻
X是训练集,NearestNeighbors 拟合
Y是输入,输出与Y最近的训练集中的样本和距离
'''

# x 是离散的一些二维点
X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]])
# 最近邻模型,n_neighbors=2 表示2个最近邻,algorithm可以选择使用的算法,结果是一致的,效率高低不同, metric 选择度量方法
nbrs = NearestNeighbors(n_neighbors=2, algorithm='ball_tree', metric='euclidean').fit(X) # ['auto', 'ball_tree', 'kd_tree', 'brute']


Y = np.array([[0, 0], [4, 4]])
plt.figure()
plt.plot(X[..., 0], X[..., 1], 'r+', Y[..., 0], Y[..., 1], 'g*',)
plt.show()

# 查找 Y的2个最近邻的距离  和  索引
distances, indices = nbrs.kneighbors(Y)
print(distances, indices)

输出距离Y最近的2个元素索引和距离

在这里插入图片描述

输出邻接矩阵 和 可选的度量方法

# 输出邻接矩阵,稀疏图
nbrs.kneighbors_graph(Y).toarray()
# 输出可以使用的距离指标
from sklearn.neighbors import KDTree, BallTree 
print(sorted(KDTree.valid_metrics))
print(sorted(BallTree.valid_metrics))

output:

在这里插入图片描述

2. 最近邻分类

最近邻分类,并不进行建模。
scikit-learn 实现了两种不同的最近邻分类器:
KNeighborsClassifier实现基于 查询点的k个最近邻居,其中k是由用户指定的整数值。
RadiusNeighborsClassifier根据固定半径内的邻居数量确定分类,其中r是由用户指定的浮点值。

k-neighbors 分类KNeighborsClassifier 是最常用的技术。k值的最优选择 高度依赖于数据:一般来说,一个更大的k 抑制噪声的影响,但使分类边界不那么明显。

在数据未均匀采样的情况下,基于半径的邻居分类RadiusNeighborsClassifier可能是更好的选择。用户指定固定半径r,使得稀疏邻域中的点使用较少的最近邻进行分类。
对于高维参数空间,由于所谓的“维数灾难”,这种方法变得不太有效。

基本的最近邻分类使用统一权重:也就是说,分配给查询点的值是根据最近邻的简单多数票计算得出的。在某些情况下,最好对邻居进行加权,使得更近的邻居对拟合的贡献更大。
这可以通过weights关键字来完成。默认值 为每个邻居分配统一的权重。 可以提供用户定义的距离函数来计算权重。weights = ‘uniform’ 或者 weights = ‘distance’

使用方法:

n_neighbors = 15
weights = 'distance'
clf = neighbors.KNeighborsClassifier(n_neighbors, weights=weights)
clf.fit(X, y)
clf.predict(input)

3. 最近邻回归

同理,也有最近邻回归
scikit-learn 实现了两个不同的邻居回归器:
KNeighborsRegressor实现基于 查询点的k个最近邻居,其中k是由用户指定的整数值。
RadiusNeighborsRegressor基于固定半径内的邻居实现学习查询点,其中r是由用户指定的浮点值。

使用方法:

n_neighbors = 5
weights = 'uniform'
knn = neighbors.KNeighborsRegressor(n_neighbors, weights=weights)
y_ = knn.fit(X, y).predict(input)

4. NearestCentroid 最近邻质心分类

NearestCentroid
如果我们不再是求解到所有样本的距离,而是求解到不同类别样本中心的距离,距离哪个样本中心最近,我们即认为该待预测样本属于哪个类,这就是NearestCentroid算法.
该算法能降低计算量,但是对于不是中心分布的样本来说,准确率不高。

from sklearn.neighbors import NearestCentroid
import numpy as np
X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]])
y = np.array([1, 1, 1, 2, 2, 2])
clf = NearestCentroid()
clf.fit(X, y)

print(clf.predict([[-0.8, -1]]))

shrink_threshold 参数

粗略理解是 类别中心的特征值 除以 类内方差, 再减去 shrink_threshold 参数,如果大于0,说明该特征方差较大,将被去除,避免影响分类结果。
目的在于去除noisy features,使用后效果有优化
如果有误,还请指正

shrinkage = 0.2
clf = NearestCentroid(shrink_threshold=shrinkage)
clf.fit(X, y)
y_pred = clf.predict(X)

5. Neighborhood Components Analysis 邻域成分分析

这篇博客解释的比较清楚

Neighborhood Components Analysis 是将数据映射到另一个空间,(也可以理解为改变度量距离的函数)

在这里插入图片描述

在矩阵A的转换下,计算softmax 概率, 然后把同一类的概率加起来,希望这个概率大。

在这里插入图片描述

目标是希望正确分类的概率最大
目标函数为

在这里插入图片描述

目前使用 scipy 的 L-BFGS-B 进行Q的求解。 通过设置Q 的维度,可以达到降维的目的。

先使用 Neighborhood Components Analysis 进行转换,再应用KNeighborsClassifier 效果会更好, 因为Neighborhood Components Analysis 将样本转换到一个更好的表达空间。

如果用NCA降维,和LDA,PCA效果比较
NCA分数最高

在这里插入图片描述

代码如下:

# License: BSD 3 clause

import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.neighbors import KNeighborsClassifier, NeighborhoodComponentsAnalysis
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

n_neighbors = 3
random_state = 0

# Load Digits dataset
X, y = datasets.load_digits(return_X_y=True)

# Split into train/test
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.5, stratify=y, random_state=random_state
)

dim = len(X[0])
n_classes = len(np.unique(y))

# Reduce dimension to 2 with PCA
pca = make_pipeline(StandardScaler(), PCA(n_components=2, random_state=random_state))

# Reduce dimension to 2 with LinearDiscriminantAnalysis
lda = make_pipeline(StandardScaler(), LinearDiscriminantAnalysis(n_components=2))

# Reduce dimension to 2 with NeighborhoodComponentAnalysis
nca = make_pipeline(
    StandardScaler(),
    NeighborhoodComponentsAnalysis(n_components=2, random_state=random_state),
)

# Use a nearest neighbor classifier to evaluate the methods
knn = KNeighborsClassifier(n_neighbors=n_neighbors)

# Make a list of the methods to be compared
dim_reduction_methods = [("PCA", pca), ("LDA", lda), ("NCA", nca)]

# plt.figure()
for i, (name, model) in enumerate(dim_reduction_methods):
    plt.figure()
    # plt.subplot(1, 3, i + 1, aspect=1)

    # Fit the method's model
    model.fit(X_train, y_train)

    # Fit a nearest neighbor classifier on the embedded training set
    knn.fit(model.transform(X_train), y_train)

    # Compute the nearest neighbor accuracy on the embedded test set
    acc_knn = knn.score(model.transform(X_test), y_test)

    # Embed the data set in 2 dimensions using the fitted model
    X_embedded = model.transform(X)

    # Plot the projected points and show the evaluation score
    plt.scatter(X_embedded[:, 0], X_embedded[:, 1], c=y, s=30, cmap="Set1")
    plt.title(
        "{}, KNN (k={})\nTest accuracy = {:.2f}".format(name, n_neighbors, acc_knn)
    )
plt.show()

[1]https://scikit-learn.org/stable/modules/neighbors.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/124531.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

day31【代码随想录】回溯之子集||、递增子序列、全排列、全排列||

文章目录前言一、子集 II(力扣90)二、递增子序列(力扣491)三、全排列(力扣46)四、全排列||(力扣47)总结前言 1、子集|| 2、递增子序列 3、全排列 4、全排列|| 一、子集 II&#xff…

【C++】指针的基础知识 | 学习笔记

文章目录前言一、指针的定义和使用1.1、指针定义1.2、指针使用二、指针占用的内存空间三、空指针和野指针3.1.空指针3.2 野指针四、const修饰指针4.1 常量指针4.2 指针常量4.3 const既修饰指针也修饰常量五、指针,数组,函数混用案例5.1 指针和数组混用5.…

Talk预告 | 上海交通大学计算机系博士生李杰锋方浩树:多人场景,全身136关键点检测与跟踪框架AlphaPose技术讲解

本期为TechBeat人工智能社区第466期线上Talk! 北京时间12月28日(周三)20:00,上海交通大学计算机系博士生——李杰锋&方浩树的Talk将准时在TechBeat人工智能社区开播! 他们与大家分享的主题是: “多人场景,全身136关键点检测与…

初识Unity

视频教程:史上最全Unity3D教程 常用快捷键 1.按住鼠标滚轮,拖动场景 2.滑动鼠标滚轮,缩放场景 3.右键,旋转视角 4.右键W、A、S、D,漫游视角,同时按下Shift可加速移动 5.alt鼠标左键,环视…

【财务】FMS财务管理系统---费用管理

在FMS财务管理系统中,和公司主营业务收入相关的费用有哪些?本篇文章中,笔者对具体分类和流程进行了系统的分析和总结,与大家分享。 财务中的费用管理主要包括销售费用、财务费用、管理费用等几大部分,看到费用大家首先…

C#,图像二值化(06)——全局阈值的大津OTSU算法及其源代码

1、大津OTSU算法 最大类间方差法是1979年由日本学者大津(Nobuyuki Otsu)提出的,是一种自适应阈值确定的方法,又叫大津法,简称OTSU,是一种基于全局的二值化算法,它是根据图像的灰度特性,将图像分为前景和背景两个部分。…

Git简介以及安装

目录 一、Git简介 1、版本控制系统简介 2、 Git的安装 a、安装git b、Git 的配置 二,本地仓库 三、GIT分支操作 1、关于分支 2. 分支基本操作 3、分支合并 4、冲突 一、Git简介 1、版本控制系统简介 版本控制系统(VCS)是将『什么…

【数据结构】直接插入排序,希尔排序,选择排序,堆排序

文章目录排序的概念直接插入排序希尔排序选择排序堆排序排序的概念 排序:所谓排序,就是使一串记录,按照其中的某个或某些关键字的大小,递增或递减的排列起来的操作。 稳定性:假定在待排序的记录序列中,存在…

keras环境搭建

目录 1. 安装miniconda 2. 安装CPU版本的tensorflow 2. 安装keras 3. 安装依赖库 4. 测试 环境:win10,无独立显卡,不用GPU加速。 1. 安装miniconda Miniconda3-latest-Windows-x86_64.exe (1)安装目录可自选&a…

生成对抗:Pix2Pix

cGAN : Pix2Pix 生成对抗网络还有一个有趣的应用就是,图像到图像的翻译。例如:草图到照片,黑白图像到RGB,谷歌地图到卫星视图,等等。Pix2Pix就是实现图像转换的生成对抗模型,但是Pix2Pix中的对抗网络又不同于普通的GAN…

计网第三章.数据链路层—可靠传输

以下来自湖科大计算机网络公开课的笔记 文章目录0.基本概念1. 停止等待协议SW2. 回退N帧协议GBN3. 选择重传SR首先,这部分说的可靠传输的实现机制不只限于数据链路层,而是适用于整个计算机网络体系 0.基本概念 一般情况下,有线链路的误码率…

Docker 中的挂载卷

我们现在有这样一个需求。 我们有一个 Spring 的项目是部署在容器中的,如果不进行任何配置的话,这个项目运行的所有日子都会在容器中。 当容器重启说着终止后,上面的日志比较难进行查看。 我们希望我们的日志同时也记录在操作系统中&#…

阿贡国家实验室:量子中继器及其在量子网络中的作用

很多人小时候都玩过传声筒游戏:A将消息小声告诉B,然后B将他听到的内容小声告诉C,依此类推,玩过的人都知道,最后传达到的信息往往和真实消息完全不同。 从某种意义上说,这和中继器技术的重要性强相关。中继器…

MySQL锁,锁的到底是什么?

只要学计算机,「锁」永远是一个绕不过的话题。MySQL锁也是一样。 一句话解释MySQL锁: MySQL锁是解决资源竞争的一种方案。 短短一句话却包含了3点值得我们注意的事情: 对什么资源进行竞争?竞争的方式(或者说情形&a…

舆情监控和应急处理方案,如何做好网络舆情监控?

舆情监控是指通过不同的渠道,如社交媒体、新闻媒体、博客、论坛等,对公众的言论进行收集、分析、评估和反馈的过程。舆情监控的目的是帮助企业或组织了解公众的观点和情绪,并且能够及时做出回应,避免可能出现的舆论危机。接下来TO…

2022年度投影仪行业数据分析报告:十大热门品牌排行榜

在当前的大环境下,线下娱乐受阻,而用户对于足不出户的观影、娱乐需求推动着智能投影设备的增长。近几年来,投影仪行业保持着较快速度的增长,面对整体市场需求不振的形势,投影仪仍在保持正向增长。随着家用智能投影在市…

Charles - 阻塞请求、修改请求与响应内容、重定向请求地址、指定文件为响应内容

1、阻塞请求 1、鼠标放在指定接口上 > 右键 > 勾选 Block List 2、重新访问这接口,这条请求被阻塞,不会有返回信息 取消阻塞接口: 鼠标放在指定接口上 > 右键 > 取消勾选 Block List 2、修改请求与响应内容 第一步&#xff1…

【一文看懂 ES 核心】存储查询集群

一文看懂 ES 核心 Elasticsearch 作为一个搜索引擎,其可以提供高效的搜索匹配数据的能力,对于这类工具了解其运行原理其实是有一套功法的。 聊存储,ES 是如何存储数据的?聊方法,ES 是如何进行搜索匹配的?…

【Linux】文件描述符、文件操作、重定向的模拟实习

目录 一、重温C语言文件操作 1.1 文件打开方式 1.2 文件写操作 1.3 文件读操作 1.3 标准输入输出 二、系统接口的使用 2.1 open 函数 2.2 close 函数 2.3 write 函数 2.4 read 函数 三、文件描述符 3.1 如何管理文件 3.2 0 & 1 & 2 3.3 文件描述符的分配…

种草!超好用的PDF转换器上线啦~

宝子们 重磅福利来啦 你还在为每次转换文件头疼吗 老铁,大拿版万能转换器正式上线啦 以前的文件转换器,不是充会员就是收费高 最坑的是花钱还解决不了问题 每次转换文件内容有误.... 特殊符号或者公式更是无法有效转换 为了整顿这种局面&#xff0c…