机器学习:k近邻算法(Python)

news2024/12/24 2:33:44

一、k近邻算法的定义

二、KD树结点信息封装 

kdtree_node.py


class KDTreeNode:
    """
    KD树结点信息封装
    """
    def __init__(self, instance_node=None, instance_label=None, instance_idx=None,
                 split_feature=None, left_child=None, right_child=None, kdt_depth=None):
        """
        用于封装kd树的结点信息结构
        :param instance_node: 实例点,一个样本
        :param instance_label: 实例点对应的类别标记
        :param instance_idx: 该实例点对应的样本索引,用于kd树的可视化
        :param split_feature: 划分的特征属性,x^(i)
        :param left_child: 左子树,小于划分点的
        :param right_child: 右子树,大于切分点的
        :param kdt_depth: kd树的深度
        """
        self.instance_node = instance_node
        self.instance_label = instance_label
        self.instance_idx = instance_idx
        self.split_feature = split_feature
        self.left_child = left_child
        self.right_child = right_child
        self.kdt_depth = kdt_depth

三、距离度量的工具类

distUtils.py

import numpy as np


class DistanceUtils:
    """
    距离度量的工具类,此处仅实现闵可夫斯基距离
    """
    def __init__(self, p=2):
        self.p = p  # 默认欧式距离,p=1曼哈顿距离,p=np。inf是切比雪夫距离

    def distance_func(self, xi, xj):
        """
        特征空间中两个样本示例的距离计算
        :param xi: k维空间某个样本示例
        :param xj: k维空间某个样本示例
        :return:
        """
        xi, xj = np.asarray(xi), np.asarray(xj)
        if self.p == 1 or self.p == 2:
            return (((np.abs(xi - xj)) ** self.p).sum()) ** (1 / self.p)
        elif self.p == np.inf:
            return np.max(np.abs(xi - xj))
        elif self.p == "cos":  # 余弦距离或余弦相似度
            return xi.dot(xj) / np.sqrt((xi ** 2).sum()) / np.sqrt((xj ** 2).sum())
        else:
            raise ValueError("目前仅支持p=1、p=2、p=np.inf或余弦距离四种距离...")

四、K近邻算法的实现

knn_kdtree.py

import numpy as np
from kdtree_node import KDTreeNode
from distUtils import DistanceUtils
import heapq  # 堆结构,实现堆排序
from collections import Counter  # 集合中的计数功能
import networkx as nx  # 网络图,可视化
import matplotlib.pyplot as plt


class KNearestNeighborKDTree:
    """
    K近邻算法的实现,基于KD树结构
    1. fit: 特征向量空间的划分,即构建KD树(建立KNN算法模型)
    2. predict: 预测,近邻搜索
    3. 可视化kd树
    """
    def __init__(self, k: int=5, p=2, view_kdt=False):
        """
        KNN算法的初始化必要参数
        :param k: 近邻数
        :param p: 距离度量标准
        :param view_kdt: 是否可视化KD树
        """
        self.k = k  # 预测,近邻搜索时,使用的参数,表示近邻树
        self.p = p  # 预测,近邻搜索时,使用的参数,表示样本的近邻度
        self.view_kdt = view_kdt
        self.dis_utils = DistanceUtils(self.p)  # 距离度量的类对象
        self.kdt_root: KDTreeNode() = None  # KD树的根节点
        self.k_dimension = 0  # 特征空间维度,即样本的特征属性数
        self.k_neighbors = []  # 用于记录某个测试样本的近邻实例点

    def fit(self, x_train, y_train):
        """
        递归创建KD树,即对特征向量空间进行划分,递归调用进行创建
        :param x_train: 训练样本集
        :param y_train: 训练样本目标集合
        :return:
        """
        if self.k < 1:
            raise ValueError("k must be greater than 0 and be int.")
        x_train, y_train = np.asarray(x_train), np.asarray(y_train)
        self.k_dimension = x_train.shape[1]  # 特征维度
        idx_array = np.arange(x_train.shape[0])  # 训练样本索引编号
        self.kdt_root = self._build_kd_tree(x_train, y_train, idx_array, 0)
        if self.view_kdt:
            self.draw_kd_tree()  # 可视化kd树

    def _build_kd_tree(self, x_train, y_train, idx_array, kdt_depth):
        """
        递归创建KD树,KD树是二叉树,严格区分左子树右子树,表示对k维空间的一个划分
        :param x_train: 递归划分的训练样本子集
        :param y_train: 递归划分的训练样本目标子集
        :param idx_array: 递归划分的样本索引
        :param depth: kd树的深度
        :return:
        """
        if x_train.shape[0] == 0:  # 递归出口
            return

        split_dimension = kdt_depth % self.k_dimension  # 数据的划分维度x^(i)
        sorted(x_train, key=lambda x: x[split_dimension])  # 按某个划分维度排序
        median_idx = x_train.shape[0] // 2  # 中位数所对应的数据的索引
        median_node = x_train[median_idx]  # 切分点作为当前子树的根节点
        # 划分左右子树区域
        left_instances, right_instances = x_train[:median_idx], x_train[median_idx + 1:]
        left_labels, right_labels = y_train[:median_idx], y_train[median_idx + 1:]
        left_idx, right_idx = idx_array[:median_idx], idx_array[median_idx + 1:]
        # 递归调用
        left_child = self._build_kd_tree(left_instances, left_labels, left_idx, kdt_depth + 1)
        right_child = self._build_kd_tree(right_instances, right_labels, right_idx, kdt_depth + 1)
        kdt_new_node = KDTreeNode(median_node, y_train[median_idx], idx_array[median_idx],
                                  split_dimension, left_child, right_child, kdt_depth)
        return kdt_new_node

    def _search_kd_tree(self, kd_tree: KDTreeNode, x_test):
        """
        kd树的递归搜索算法,后序遍历,搜索k个最近邻实例点
        数据结构:堆排序,搜索过程中,维护一个小根堆
        :param kd_tree: 已构建的kd树
        :param x_test: 单个测试样本
        :return:
        """
        if kd_tree is None:  # 递归出口
            return

        # 计算测试样本与当前kd子树的根结点的距离(相似度)
        distance = self.dis_utils.distance_func(kd_tree.instance_node, x_test)
        # 1. 如果不够k个样本,继续递归
        # 2. 如果搜索了k个样本,但是k个样本未必是最近邻的。
        # 当计算的当前实例点的距离小于k个样本的最大距离,则递归,大于最大距离,没必要递归
        if (len(self.k_neighbors) < self.k) or (distance < self.k_neighbors[-1]["distance"]):
            self._search_kd_tree(kd_tree.left_child, x_test)  # 递归左子树
            self._search_kd_tree(kd_tree.right_child, x_test)  # 递归右子树
            # 在整个搜索路径上的kd树的结点,存储在self.k_neighbors中,包含三个值
            # 当前实例点,类别,距离
            self.k_neighbors.append({
                "node": kd_tree.instance_node,  # 结点
                "label": kd_tree.instance_label,  # 当前实例的类别
                "distance": distance  # 当前实例点与测试样本的距离
            })
            # 按照距离进行排序,选择最小的k个最近邻样本实例,更新最近邻距离
            # 小根堆,k_neighbors中第一个结点是距离测试样本最近的
            self.k_neighbors = heapq.nsmallest(self.k, self.k_neighbors,
                                               key=lambda d: d["distance"])

    def predict(self, x_test):
        """
        KD树的近邻搜索,即测试样本的预测
        :param x_test: 测试样本,ndarray: (n * k)
        :return:
        """
        x_test = np.asarray(x_test)
        if self.kdt_root is None:
            raise ValueError("KDTree is None, Please fit KDTree...")
        elif x_test.shape[1] != self.k_dimension:
            raise ValueError("Test Sample dimension unmatched KDTree's dimension.")
        else:
            y_test_hat = []  # 用于存储测试样本的预测类别
            for i in range(x_test.shape[0]):
                self.k_neighbors = []  # 调用递归搜索,则包含了k个最近邻的实例点
                self._search_kd_tree(self.kdt_root, x_test[i])
                # print(self.k_neighbors)
                y_test_labels = []
                # 取每个近邻样本的类别标签
                for k in range(self.k):
                    y_test_labels.append(self.k_neighbors[k]["label"])
                # 按分类规则(多数表决法)
                # print(y_test_labels)
                counter = Counter(y_test_labels)
                idx = int(np.argmax(list(counter.values())))
                y_test_hat.append(list(counter.keys())[idx])
        return np.asarray(y_test_hat)

    def _create_kd_tree(self, graph, kdt_node: KDTreeNode, pos=None, x=0, y=0, layer=1):
        """
        递归可视化KD树,递归构造树的结点、边。
        :param graph: 有向图对象,递归中逐步增加结点和左子树右子树
        :param kdt_node: 递归创建KD树的结点
        :param pos: 可视化中树结点位置,初始化(0, 0)绘制根结点
        :param x: 对应pos中的横坐标,随着递归,更新
        :param y: 对应pos中的纵坐标,随着递归,更新
        :param layer: kd树的层次
        :return:
        """
        if pos is None:
            pos = {}
        pos[str(kdt_node.instance_idx)] = (x, y)
        if kdt_node.left_child:
            # 父结点指向左子树
            graph.add_edge(str(kdt_node.instance_idx), str(kdt_node.left_child.instance_idx))
            l_x, l_y = x - 1 / 2 ** layer, y - 1  # 下一个树结点位置的计算
            l_layer = layer + 1  # 树的层次 + 1
            self._create_kd_tree(graph, kdt_node.left_child, x=l_x, y=l_y, pos=pos, layer=l_layer)  # 递归
        if kdt_node.right_child:
            # 父结点指向右子树
            graph.add_edge(str(kdt_node.instance_idx), str(kdt_node.right_child.instance_idx))
            r_x, r_y = x + 1 / 2 ** layer, y - 1
            r_layer = layer + 1
            self._create_kd_tree(graph, kdt_node.right_child, x=r_x, y=r_y, pos=pos, layer=r_layer)  # 递归
        return graph, pos

    def draw_kd_tree(self):
        """
        可视化kd树
        :return:
        """
        directed_graph = nx.DiGraph()  # 初始化一个有向图,树
        graph, pos = self._create_kd_tree(directed_graph, self.kdt_root)
        fig, ax = plt.subplots(figsize=(20, 10))  # 比例可以根据树的深度适当调节
        nx.draw_networkx(graph, pos, ax=ax, node_size=500, font_color="w", font_size=15,
                         arrowsize=20)
        plt.tight_layout()
        plt.show()

 五、K近邻算法的测试

test_knn_1.py

import numpy as np
from sklearn.datasets import load_iris, load_breast_cancer
from knn_kdtree import KNearestNeighborKDTree
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, accuracy_score
import matplotlib.pyplot as plt
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import StandardScaler


iris = load_iris()
X, y = iris.data, iris.target

# bc_data = load_breast_cancer()
# X, y = bc_data.data, bc_data.target
X = StandardScaler().fit_transform(X)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0, stratify=y)
k_neighbors = np.arange(3, 21)

# acc = []
# for k in k_neighbors:
#     knn = KNearestNeighborKDTree(k=k)
#     knn.fit(X_train, y_train)
#     y_test_hat = knn.predict(X_test)
#     # print(classification_report(y_test, y_test_hat))
#     acc.append(accuracy_score(y_test, y_test_hat))

accuracy_scores = []  # 存储每个alpha阈值下的交叉验证均分
for k in k_neighbors:
    scores = []
    k_fold = StratifiedKFold(n_splits=10).split(X, y)
    for train_idx, test_idx in k_fold:
        # knn = KNearestNeighborKDTree(k=k, p="cos")
        knn = KNearestNeighborKDTree(k=k)
        knn.fit(X[train_idx], y[train_idx])
        y_test_pred = knn.predict(X[test_idx])
        scores.append(accuracy_score(y[test_idx], y_test_pred))
        del knn
    print("k = %d:" % k, np.mean(scores))
    accuracy_scores.append(np.mean(scores))

plt.figure(figsize=(7, 5))
plt.plot(k_neighbors, accuracy_scores, "ko-", lw=1)
plt.grid(ls=":")
plt.xlabel("K Neighbors", fontdict={"fontsize": 12})
plt.ylabel("Accuracy Scores", fontdict={"fontsize": 12})
plt.title("KNN(KDTree) Testing Scores under different K Neighbors", fontdict={"fontsize": 14})
plt.show()

# knn = KNearestNeighborKDTree(k=3)
# knn.fit(X_train, y_train)
# knn.draw_kd_tree()


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1452166.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

对待不合理需求,前端工程师如何优雅的say no!

曾经有位老板&#xff0c; 每次给前端提需求&#xff0c;前端都说实现不了&#xff0c;后来他搜索了一下&#xff0c;发现网上都有答案。他就在招聘要求上加了条&#xff1a;麻烦你在说不行的时候&#xff0c;搜索一下。 上面是一个段子&#xff0c;说的有点极端了&#xff0c;…

简单DP算法(动态规划)

简单DP算法 算法思想例题1、01背包问题题目信息思路题解 2、摘花生题目信息思路题解 3、最长上升子序列题目信息思路题解 题目练习1、地宫取宝题目信息思路题解 2、波动数列题目信息思路题解 算法思想 从集合角度来分析DP问题 例如求最值、求个数 例题 1、01背包问题 题目…

2.16学习总结

1.邮递员送信&#xff08;dijkstra 不只是从起到到目标点&#xff0c;还要走回去&#xff09; 2.炸铁路(并查集) 3.统计方形&#xff08;数据加强版&#xff09;&#xff08;排列组合&#xff09; 4.滑雪&#xff08;记忆化&#xff09; 5.小车问题&#xff08;数学问题&#x…

高B格可视化大屏设计具备的10大特征

简洁明了&#xff1a; 可视化大屏界面应该尽可能简洁明了&#xff0c;突出重点&#xff0c;避免过多的信息和视觉干扰。同时&#xff0c;需要考虑到用户的视觉效果和易用性&#xff0c;使用户能够迅速地获取所需信息。 数据精准&#xff1a; 可视化大屏界面显示的数据应该准确…

阿里云BGP多线精品EIP香港CN2线路低时延,价格贵

阿里云香港等地域服务器的网络线路类型可以选择BGP&#xff08;多线&#xff09;和 BGP&#xff08;多线&#xff09;精品&#xff0c;普通的BGP多线和精品有什么区别&#xff1f;BGP&#xff08;多线&#xff09;适用于香港本地、香港和海外之间的互联网访问。使用BGP&#xf…

react+ts【项目实战一】配置项目/路由/redux

文章目录 1、项目搭建1、创建项目1.2 配置项目1.2.1 更换icon1.2.2 更换项目名称1.2.1 配置项目别名 1.3 代码规范1.3.1 集成editorconfig配置1.3.2 使用prettier工具 1.4 项目结构1.5 对css进行重置1.6 注入router1.7 定义TS组件的规范1.8 创建代码片段1.9 二级路由和懒加载1.…

今日早报 每日精选15条新闻简报 每天一分钟 知晓天下事 2月17日,星期六

每天一分钟&#xff0c;知晓天下事&#xff01; 2024年2月17日 星期六 农历正月初八 1、 中疾控&#xff1a;我国自主研发的猴痘mRNA疫苗即将进入临床试验。 2、 2024年度总票房破100亿元&#xff0c;其中春节档已突破70亿元。 3、 国产大飞机首次国外亮相&#xff0c;C919已抵…

5年前端老司机:浅谈web前端开发技术点

有部分同学和朋友问到过我相关问题。利用周末我就浅浅地谈谈我对web前端开发的理解和体会&#xff0c;仅仅能浅浅谈谈&#xff0c;高手请自己主动跳过本篇文章。 毕竟我如今经验并非非常足&#xff0c;连project师都算不上&#xff0c;更不用说大牛了。今天也不谈技术。技术非…

2.14日学习打卡----初学Zookeeper(一)

2.14日学习打卡 目录: 2.14日学习打卡Zookeeper概念一. 集中式到分布式单机架构集群架构什么是分布式三者区别 二. CAP定理分区容错性一致性可用性一致性和可用性的矛盾一致性和可用性如何选择 三. 什么是Zookeeper分布式架构Zookeeper从何而来Zookeeper介绍 四. 应用场景数据发…

Android 车载应用开发之SystemUI 详解

一、SystemUI SystemUI全称System User Interface,直译过来就是系统级用户交互界面,在 Android 系统中由SystemUI负责统一管理整个系统层的 UI,它是一个系统级应用程序(APK),源码在/frameworks/base/packages/目录下,而不是在/packages/目录下,这也说明了SystemUI这个…

集群聊天项目

不懂的一些东西 (const TcpConnectionPtr&&#xff09;作为形参啥意思&#xff1a;接收一个常量引用&#xff0c;函数内部不允许修改该指针所指向的对象。 优势 1.网络层与业务层分离&#xff1a;通过网络层传来的id&#xff0c;设计一个map存储id以及对印的业务处理器&…

文件上传漏洞--Upload-labs--Pass01--前端绕过

一、前端绕过原理 通俗解释&#xff0c;我们将写有恶意代码的php后缀文件上传到网页&#xff0c;网页中的javascript代码会先对文件的后缀名进行检测&#xff0c;若检测到上传文件的后缀名为非法&#xff0c;则会进行alert警告。若想上传php后缀的文件&#xff0c;就要想办法对…

windows一开机一直循环:No Boot Device Found. Press any key to reboot the machine解决方法

一、长按F12 二、选择Settiings/General/Boot Sequence 三、选择UEFI模式&#xff0c; 四、选择下方APPLY 五、退出&#xff1a;

【Spring面试题】

目录 前言 1.Spring框架中的单例bean是线程安全的吗? 2.什么是AOP? 3.你们项目中有没有使用到AOP&#xff1f; 4.Spring中的事务是如何实现的&#xff1f; 5.Spring中事务失效的场景有哪些&#xff1f; 6.Spring的bean的生命周期。 7.Spring中的循环引用 8.构造方法…

Linux下解压tar.xz文件的命令

tar -c: 建立压缩档案-x&#xff1a;解压-t&#xff1a;查看内容-r&#xff1a;向压缩归档文件末尾追加文件-u&#xff1a;更新原压缩包中的文件 ------------------------------------------ 这五个是独立的命令&#xff0c;压缩解压都要用到其中一个&#xff0c;可以和别的…

机器学习入门--循环神经网络原理与实践

循环神经网络 循环神经网络&#xff08;RNN&#xff09;是一种在序列数据上表现出色的人工神经网络。相比于传统前馈神经网络&#xff0c;RNN更加适合处理时间序列数据&#xff0c;如音频信号、自然语言和股票价格等。本文将介绍RNN的基本数学原理、使用PyTorch和Scikit-Learn…

上位机图像处理和嵌入式模块部署(图像项目处理过程)

【 声明&#xff1a;版权所有&#xff0c;欢迎转载&#xff0c;请勿用于商业用途。 联系信箱&#xff1a;feixiaoxing 163.com】 对于一般的图像项目来说&#xff0c;图像处理只是工作当中的一部分。在整个项目处理的过程中有很多的内容需要处理&#xff0c;比如说了解需求、评…

信息安全技术基础知识

一、考点分布 信息安全基础&#xff08;※※&#xff09;信息加密解密技术&#xff08;※※※&#xff09;密钥管理技术&#xff08;※※&#xff09;访问控制及数字签名技术&#xff08;※※※&#xff09;信息安全的保障体系 二、信息安全基础 信息安全包括5个基本要素&#…

【COMP337 LEC3】

LEC 3 Mathematical Preliminaries Common Discrete Probability Distributions 1. Bernoulli distribution : 伯努利分布 models binary outcomes (coin flip). 模型二进制结果 P ( X head ) p and P ( X tail ) 1 − p 2. Generalised Bernoulli distribution…

牛客网SQL进阶123:高难度试卷的得分的截断平均值

官网链接&#xff1a; SQL类别高难度试卷得分的截断平均值_牛客题霸_牛客网牛客的运营同学想要查看大家在SQL类别中高难度试卷的得分情况。 请你帮她从exam_。题目来自【牛客题霸】https://www.nowcoder.com/practice/a690f76a718242fd80757115d305be45?tpId240&tqId2180…