k近邻法学习

news2025/1/12 5:52:29

k近邻法(k-nearest neighbor, k-NN)是一种基本分类与回归方法(下面只写分类的)

knn的输入为实例的特征向量,对应于特征空间的店;
输出为实例的类别。
knn假设给定的训练数据集,其中的实力类别已定,分类时,对新的实例,根据其k个最近邻的训练实例的类别,通过多数表决等方式进行预测。

算法

输入:训练数据集
T = { ( x 1 , y 1 ) , ( x 2 , y 2 ) , ⋯   , ( x N , y N ) } T = \left\{\left(\mathbf{x}_1,y_1\right), \left(\mathbf{x}_2,y_2\right), \cdots, \left(\mathbf{x}_N,y_N\right)\right\} T={(x1,y1),(x2,y2),,(xN,yN)}
其中 x i ∈ X ⊆ R n , y i ∈ Y = { c 1 , c s , ⋯   , c K } \mathbf{x}_i\in \mathcal{X} \subseteq \mathbb{R}^n, y_i \in \mathcal{Y}=\left\{c_1, c_s, \cdots,c_K\right\} xiXRn,yiY={c1,cs,,cK},这里大写的 K K K表示类别,和knn的 k k k没有关系
输出:实例 x \mathbf{x} x所属的类 y y y
(1)根据给定的距离度量,在训练集 T T T中找出与 x \mathbf{x} x最近的 k k k个店,涵盖这 k k k个点的 x \mathbf{x} x的领域记作 N k ( x ) N_k\left(\mathbf{x}\right) Nk(x)
(2)在 N k ( x ) N_k\left(\mathbf{x}\right) Nk(x)中根据分类决策规则(如多数表决)决定 x \mathbf{x} x的类别 y y y
y = arg ⁡ max ⁡ c j ∑ x i ∈ N k ( x ) I ( y i = c j ) , i = 1 , 2 , c … , N ; j = 1 , 2 , ⋯   , K y = \arg\max_{c_j}\sum_{\mathbf{x}_i\in N_k\left(\mathbf{x}\right)} I\left(y_i=c_j\right),\quad i=1,2,c\dots, N;\quad j=1,2,\cdots, K y=argcjmaxxiNk(x)I(yi=cj),i=1,2,c,N;j=1,2,,K
其中 I I I是指示函数, y i = c j y_i = c_j yi=cj时为 1 1 1,其他时候为 0 0 0

kd树

由于线性扫描比较耗时,所以用kd树

构造

输入: k k k维空间数据集 T = { x 1 , ⋯   , x N } T=\left\{\mathbf{x}_1, \cdots, \mathbf{x}_N\right\} T={x1,,xN}(注意这里的这个k和knn的k没有关系)
其中 x i = ( x i ( 1 ) , ⋯   , x i ( k ) ) T \mathbf{x}_i = \left(\mathbf{x}_i^{\left(1\right)},\cdots, \mathbf{x}_i^{\left(k\right)}\right)^T xi=(xi(1),,xi(k))T
输出:kd树
(1)开始:构造根节点,根节点对应于包含 T T T k k k维空间的超矩形区域
选择 x ( 1 ) \mathbf{x}^{\left(1\right)} x(1)为坐标轴,以 T T T中所有的实例的 x ( 1 ) \mathbf{x}^{(1)} x(1)坐标的中位数为切分点,将根节点对应的超矩形区域切分为两个子区域。

由根节点生成深度为1的左、右子节点:左子节点对应坐标 x ( 1 ) \mathbf{x}^{(1)} x(1)小于切分点的子区域,右子节点对应于坐标 x ( 1 ) \mathbf{x}^{(1)} x(1)大于切分点的子区域

(2)重复:对深度为 j j j的节点,选择 x ( l ) \mathbf{x}^{\left(l\right)} x(l)为切分的坐标轴, l = j ( m o d    k ) + 1 l=j\left(\mod k\right) + 1 l=j(modk)+1,以该节点的区域中所有实例的 x ( l ) \mathbf{x}^{(l)} x(l)坐标的中位数为切分点,将该节点对应的超矩形区域切分为两个子区域。

由该节点生成深度为 j + 1 j+1 j+1的左、右子节点:左子节点对应坐标 x ( l ) \mathbf{x}^{(l)} x(l)小于切分点的子区域,右子节点对应于坐标 x ( l ) \mathbf{x}^{(l)} x(l)大于切分点的子区域

(3)直到两个子区域没有实例存在时停止,从而形成kd树的区域划分

补充:
找中位数,可以使用C++的nth_element,也就是快排里的partition
在这里插入图片描述

搜索

假设寻找 x ∈ R k \mathbf{x}\in\mathbb{R}^k xRk k k k个最近邻
(1)设 L L L为一个有 k k k个空位的列表,用于保存已搜寻到的最近点。
(2)根据 x \mathbf{x} x的坐标值和每个节点的切分向下搜索
(3)当达叶子节点时,如果 L L L里不足 k k k个点,则将当前节点的特征坐标加入 L L L;如果 L L L不为空并且当前节点的特征与 x \mathbf{x} x的距离小于 L L L里最长的距离,则用当前特征替换掉 L L L中离 x \mathbf{x} x最远的点。
(4)如果当前节点不是整棵树根节点,执行 (a);反之,输出 L L L,算法完成。
(a) 向上一层(当前节点的父节点)执行1和2。

  1. 如果此时 L L L里不足 k k k个点,则将节点特征加入 L L L;如果 L L L中已满 k k k个点,且当前节点与 x \mathbf{x} x的距离小于 L L L里最长的距离,则用节点特征替换掉 L L L中离最远的点。
  2. 计算 x \mathbf{x} x和当前节点切分线的距离。如果该距离大于等于 L L L中距离 x \mathbf{x} x最远的距离并且 L L L中已有 k k k个点,则在切分线另一边不会有更近的点,执行 (4);如果该距离小于 L L L中最远的距离或者 L L L中不足 k k k个点,则切分线另一边可能有更近的点,因此在当前节点的另一个孩子中从 (2) 开始执行。

这里(4)-(a)-2说的切分线的距离,指:设根据第 l l l个维度切分,那么计算 x ( l ) x^{(l)} x(l)和切分线的距离
因此选的距离,应该是类似 L p L_p Lp这种,这样如果距离大于等于 L L L中距离 x \mathbf{x} x最远的距离并且 L L L中已有 k k k个点,另一个区域中的点 x i \mathbf{x}_i xi才能满足 d ( x , x i ) ≥ ∣ x ( l ) − x i ( l ) ∣ d\left(\mathbf{x}, \mathbf{x}_i\right) \ge \left|x^{(l)}-x_i^{(l)}\right| d(x,xi) x(l)xi(l) ,进而舍弃这些点

代码

#!/usr/bin/env python
# _*_ coding:utf-8 _*_
from collections import Counter

import numpy as np
import heapq

import matplotlib
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split


def distance(x, y):
    return np.sqrt(np.sum((x.squeeze() - y.squeeze()) ** 2))


class KDNode:
    def __init__(self, data=None, label=None, split_dim=None, split_val=None, left=None, right=None):
        self.data = data  # shape(n,)
        self.label = label  # shape(1,)
        self.split_dim = split_dim
        self.split_val = split_val
        self.left = left
        self.right = right


class KDTree:
    def __init__(self, k, distance):
        self.root = None
        self.k = k
        self.distance = distance

    def _build_tree(self, X, Y, l, r, depth):
        split_dim = depth % X.shape[1]
        if l + 1 == r:
            return KDNode(X[l], Y[l], split_dim, X[l, split_dim])
        elif l >= r:
            return None

        # mid = l + (r - l) // 2
        mid = (l + r - 1) // 2
        partition = l + np.argpartition(X[l:r, split_dim], mid - l)
        X[l:r] = X[partition]
        Y[l:r] = Y[partition]
        split_val = X[mid, split_dim]
        root = KDNode(X[mid], Y[mid], split_dim, split_val)
        root.left = self._build_tree(X, Y, l, mid, depth + 1)
        root.right = self._build_tree(X, Y, mid + 1, r, depth + 1)
        return root

    def build_tree(self, X, Y):
        self.root = self._build_tree(X, Y, 0, X.shape[0], 0)

    def _search(self, x, root: KDNode, ans: list, k: int):
        if not root:
            return
        elif root.left is None and root.right is None:
            dist = self.distance(root.data, x)
            if len(ans) < k:
                # the id(root) here is to prevent heapq comparing the data and the label, because it is not comparable
                heapq.heappush(ans, (-dist, id(root), root.data, root.label))
            elif len(ans) == k and -dist > ans[0][0]:  # dist1 < dist_max
                heapq.heapreplace(ans, (-dist, id(root), root.data, root.label))
            return
        split_dim = root.split_dim

        next_root, other_root = None, None
        if x[split_dim] < root.split_val:
            next_root = root.left
            other_root = root.right
        else:
            next_root = root.right
            other_root = root.left

        self._search(x, next_root, ans, k)

        dist = self.distance(root.data, x)
        if len(ans) < k:
            heapq.heappush(ans, (-dist, id(root), root.data, root.label))
        elif len(ans) == k and -dist > ans[0][0]:  # dist1 < dist_max
            heapq.heapreplace(ans, (-dist, id(root), root.data, root.label))
        if other_root is not None and np.abs(x[split_dim] - root.split_val) < -ans[0][0]:
            self._search(x, other_root, ans, k)

    def search(self, x):
        ans = []
        self._search(x.squeeze(), self.root, ans, self.k)
        # ans.sort(key=lambda cur: -cur[0])
        return [cur[2:] for cur in ans]


class KNN:
    def __init__(self, k, distance):
        self.kd_tree = KDTree(k, distance)

    def fit(self, X, Y):
        self.kd_tree.build_tree(X, Y)

    def predict_one(self, x):
        """

        :param x: x.shape=(n,)
        :return:
        """
        k_list = self.kd_tree.search(x)
        # print(k_list)
        cnt = Counter()
        for p, y in k_list:
            cnt.update({y: 1})
            # weighted by 1/ distance
            # cnt.update({y: 1.0 / distance(x, p)})
        return cnt.most_common(1)[0][0]

    def predict(self, X):
        return np.array([self.predict_one(x) for x in X], dtype=np.int64)


if __name__ == '__main__':
    # X = np.array([
    #     [6.27, 5.50],
    #     [1.24, -2.86],
    #     [17.05, -12.79],
    #     [-6.88, -5.40],
    #     [-2.96, -0.50],
    #     [7.75, -22.68],
    #     [10.80, -5.03],
    #     [-4.60, -10.55],
    #     [-4.96, 12.61],
    #     [1.75, 12.26],
    #     [15.31, -13.16],
    #     [7.83, 15.70],
    #     [14.63, -0.35]
    # ])
    # Y = np.random.randint(0, 2, X.shape[0])
    # print(X)
    # print(Y)
    # knn = KNN(3, distance)
    # knn.fit(X, Y)
    # print(knn.predict(np.array([[-1, -5], [-1, -5]])))

    iris = load_iris()
    X = iris.data[:, :2]  # (150,2)
    Y = iris.target  # (150,)
    X_train, X_test, y_train, y_test = train_test_split(X, Y, stratify=Y, random_state=42)

    n_neighbors = 5
    knn = KNN(5, distance)
    knn.fit(X_train, y_train)
    y_pred = knn.predict(X_test)
    print(y_pred)
    # 查看各项得分
    print("y_pred", y_pred)
    print("y_test", y_test)
    # print("score on train set", knn.score(X_train, y_train))
    # print("score on test set", knn.score(X_test, y_test))
    print("accuracy score", accuracy_score(y_test, y_pred))


    # 可视化
    # 自定义colormap
    def colormap():
        return matplotlib.colors.LinearSegmentedColormap.from_list('cmap', ['#FFC0CB', '#00BFFF', '#1E90FF'], 256)


    x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
    y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
    axes = [x_min, x_max, y_min, y_max]
    xp = np.linspace(axes[0], axes[1], 500)  # 均匀500的横坐标
    yp = np.linspace(axes[2], axes[3], 500)  # 均匀500个纵坐标
    xx, yy = np.meshgrid(xp, yp)  # 生成500X500网格点
    xy = np.c_[xx.ravel(), yy.ravel()]  # 按行拼接,规范成坐标点的格式
    y_pred = knn.predict(xy).reshape(xx.shape)  # 训练之后平铺

    # 可视化方法一
    # plt.figure(figsize=(15, 5), dpi=100)
    plt.contourf(xx, yy, y_pred, alpha=0.3, cmap=colormap())
    # 画三种类型的点
    p1 = plt.scatter(X[Y == 0, 0], X[Y == 0, 1], color='blue', marker='^')
    p2 = plt.scatter(X[Y == 1, 0], X[Y == 1, 1], color='green', marker='o')
    p3 = plt.scatter(X[Y == 2, 0], X[Y == 2, 1], color='red', marker='*')
    # 设置注释
    plt.legend([p1, p2, p3], iris['target_names'], loc='upper right', fontsize='large')
    # 设置标题
    plt.title(f"3-Class classification (k = {n_neighbors})", fontdict={'fontsize': 15})
    plt.show()

参考:
统计学习方法(李航)
https://zhuanlan.zhihu.com/p/23966698
https://bitbucket.org/StableSort/play/src/master/src/com/stablesort/kdtree/KDTree.java
https://zhuanlan.zhihu.com/p/343657182

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/496452.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

使用vscode进行python的单元测试,提高开发效率

背景知识 单元测试在我们的开发过程中非常有必要&#xff0c;它可以验证实现的一个函数是否达到预期。以前在学校写代码时&#xff0c;都是怼一堆代码&#xff0c;然后直接运行&#xff0c;如果报错再一步步调试&#xff0c;这样大部分时间都浪费在调试工作上。工作后发现大家…

【c/c++】curl编译(CMake方式)

一、curl下载 下载地址&#xff1a;curl - Download 进入下载页面&#xff0c;选择Old Releases。 二、CMake下载 这玩意居然有官网&#xff0c;刷新了我的认知&#xff0c;省事啊。 Download | CMake 三、CMake生成VS项目 1、点击【Browse Source ...】&#xff0c;先选择…

蓝牙耳机哪个品牌最好?数码博主整理2023超高性价比蓝牙耳机推荐

近来收到很多私信不知道蓝牙耳机哪个品牌最好&#xff0c;希望我能进行一期蓝牙耳机推荐&#xff0c;考虑到大家的预算不高&#xff0c;我特意花费时间测评了当下主流品牌的热销平价蓝牙耳机&#xff0c;最终整理成了这份超高性价比蓝牙耳机推荐&#xff0c;感兴趣的朋友们可以…

ASN.1-PKCS10

ASN1采用一个个的数据块来描述整个数据结构&#xff0c;每个数据块都有四个部分组成&#xff1a; 1、数据块数据类型标识&#xff08;一个字节&#xff09; 数据类型包括简单类型和结构类型。 简单类型是不能再分解类型&#xff0c;如整型(INTERGER)、比特串(BIT STRING)、字…

【Unity】搭建Jenkins打包工作流,远程打热更、构建App

Jenkins是团队协作项目打包常用的工作流&#xff0c;不多做介绍。 Jenkins的部署Unity打包环境还是非常简单的&#xff1a; 工作流程如下&#xff1a; 1. 在Jenkins中添加打包配置参数(如: 版本号, 目标平台等), 参数将以UI的形式显示在Jenkins Web界面以便打包前填写参数&a…

机器人抓取检测——Dex-Net

如今&#xff0c;在各种期刊顶会都能看到平面抓取检测的论文&#xff0c;他们声称能应对多物体堆叠场景&#xff0c;然而实际效果都不尽人意&#xff0c;我认为主要原因有如下几点&#xff1a; 缺乏多物体堆叠场景的抓取数据集。现在最常用的Cornell Grasp Dataset, Jacquard数…

政务网中使用内部华为云

项目按甲方要求&#xff0c;部署在政务网&#xff0c;各种需要在系统中播放的视频存放于内部华为云&#xff1b;然后&#xff0c;系统需要在互联网上访问。 经过一天捣鼓&#xff0c;终于搞定。过程中遇到了许多问题&#xff0c;有nginx代理的&#xff0c;docker域名解析的&am…

FTP Entering Extended Passive Mode

目录 原因 两种方法解决,哪个行用哪种 方法一 方法二 原因 FTP的连接建立有两种模式PORT

10个优秀设计网站盘点

从平面广告设计、包装设计和标志设计到游戏特效&#xff0c;都与我们的生活息息相关。过去&#xff0c;设计师依靠一张图纸和一支笔&#xff0c;但进入数字时代后&#xff0c;设计工作从图纸转移到了电脑上。 各种设计网站和在线设计工具相继衍生&#xff0c;简化了工作步骤&a…

Packet Tracer - 配置扩展 ACL - 场景 1

Packet Tracer - 配置扩展 ACL - 场景 1 拓扑图 地址分配表 设备 接口 IP 地址 子网掩码 默认网关 R1 G0/0 172.22.34.65 255.255.255.224 不适用 G0/1 172.22.34.97 255.255.255.240 不适用 G0/2 172.22.34.1 255.255.255.192 不适用 服务器 NIC 172.22.…

戴尔Alienware x15R1 x15R2原厂win11系统带F12 Support Assist OS Recovery恢复功能

戴尔Alienware x15R1 x15R2原厂win11系统带F12 Support Assist OS Recovery恢复功能 恢复各机型预装系统&#xff0c;带所有dell主题壁纸、dell软件驱动、带戴尔SupportAssist OS Recovery恢复功能&#xff0c;一次性恢复成新机状态&#xff0c;并且以后不用重装系统&#xff…

pyinstaller打包Mediapipe时遇到的问题

使用pyinstaller对python文件打包 打包流程 安装pyinstaller pip install pyinstaller打包文件 pyinstaller test.py 打包完成后会生成一个dist文件夹,打包的文件会在里面,找到test.exe。 pyinstaller -F test.py 加上-F会把所有的文件打包成一个exe,也是在dist文件夹下…

Docker File

DockerFile 是用来构建Docker镜像的构建文件&#xff0c;是由一些列命令和参数构成的脚本。 一、DockerFile 一、在home目录下创建docker-test-volume目录 cd /home mkdir docker-test-volume 二、在home目录下的docker-test-volume目录创建dockerfile1文件 vim dockerfile1…

运营-8.内容分发

内容分发本质要解决的问题包含两点&#xff1a; 1.高效的连接人与信息 2.过滤出有价值的信息&#xff0c;让合适的人看到合适的信息。 常见的内容分发方式 1.编辑分发 2.订阅分发 3.社交分发 4.算法分发 TIPS&#xff1a;根据产品性质、技术实力等因素&#xff0c;不同…

长尾学习(一):Long-Tail Learning via Logit Adjustment

一、背景 这是一篇从损失函数入手解决长尾问题的一种新思路&#xff0c;借鉴基于标签频次的logit adjustment方法&#xff0c;鼓励模型在高频类别与低频类别之间的Margin较大&#xff0c;提出了两种校准方法&#xff1a; 事后校准&#xff08;post-hoc adjustment&#xff09;…

tiechui_lesson03_缓冲读写与自定义控制

学习了与应用层通过缓冲区方式的交互&#xff0c;包括读写&#xff0c;自定义控制等。小坑比较多&#xff0c;大部分是是头文件和设置上的错误&#xff0c;跟着视频敲想快进就跳过了一些细节。包括&#xff1a; <windef.h> 头文件的引用 //使用DWORD等类型switch语句…

iOS开发多target

场景 背景:设想一下有一个场景,一个业务分为多种身份,他们大部分功能是相同的,但是也有自己的差异性。这种情况,想要构建出不同身份的APP。你会怎么做??? 当然,你可以拷贝一份代码出来,给项目重新命名。这样做的好处是,他们互相不会冲突,不用去关心是否有逻辑的冲…

Python中变量赋值过程的理解

Python中变量赋值过程的理解 在Python中对变量赋值过程的理解&#xff0c;有助于学习者对Python的变量和所指向的对象之间的指向关系深刻理解&#xff0c;避免编程中多个变量赋值后&#xff0c;对变量结果的不确定&#xff0c;减少赋值过程中疑问和困惑。 1.赋值过程基本过程 …

全文检索-Elasticsearch-进阶检索

文章目录 前言一、SearchAPI1.1 URL 后接参数检索1.2 URL 加请求体检索 二、Query DSL2.1 基本语法格式2.2 匹配查询 match2.3 短语匹配 match_phase2.4 多字段匹配 multi_match2.5 复合查询 bool2.6 过滤 filter2.7 查询 term2.8 聚合 aggregations 三、Mapping3.1 待完成3.2 …

Mybatis动态SQL用法

动态SQL是Mybatis的一大重要特性&#xff0c;它可以完成不同条件下的SQL拼接&#xff0c;降低了因为SQL语句书写中的小错误而造成程序报错的概率&#xff0c;例如拼接时要确保不能忘记添加必要的空格&#xff0c;还要注意去掉列表最后一个列名的逗号&#xff0c;利用动态SQL就可…