【机器学习】K近邻算法(K-NearestNeighbors , KNN)详解 + Java代码实现

news2024/11/15 7:46:32

文章目录

  • 一、KNN 基本介绍
  • 二、KNN 核心思想
  • 三、KNN 算法流程
  • 四、KNN 优缺点
  • 五、Java 代码实现 KNN
  • 六、KNN 改进策略


一、KNN 基本介绍

邻近算法,或者说K最邻近(KNN,K-NearestNeighbors)分类算法是分类方法中最简单的方法之一。所谓K最近邻,就是K个最近的邻居的意思,说的是每个样本都可以用它最接近的K个邻近值来代表。近邻算法就是将数据集合中每一个记录进行分类的方法。

KNN 最初由 Cover 和 Hart 于1968年提出,是一个理论上比较成熟的方法,也是最简单的机器学习算法之一。

该方法的思路非常简单直观:如果一个样本在特征空间中的 K 个最相似(即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别。


二、KNN 核心思想

KNN算法的核心思想是,如果一个样本在特征空间中的K个最相邻的样本中的大多数属于某一个类别,则该样本也属于这个类别,并具有这个类别上样本的特性。

该方法在确定分类决策上只依据最邻近的一个或者几个样本的类别来决定待分样本所属的类别。KNN方法在类别决策时,只与极少量的相邻样本有关。

由于KNN方法主要靠周围有限的邻近的样本,而不是靠判别类域的方法来确定所属类别的,因此对于类域的交叉或重叠较多的待分样本集来说,KNN方法较其他方法更为适合。

下面举一个具体的例子(来源于:https://zhuanlan.zhihu.com/p/143092725)

如下图所示,图中绿色的点就是我们要预测的那个点。

假设 K=3。那么 KNN 算法就会找到与它距离最近的三个点(这里用圆圈把它圈起来了),看看哪种类别多一些,比如这个例子中是蓝色三角形多一些,新来的绿色点就归类到蓝三角了。

在这里插入图片描述
但是,当 K=5 的时候,判定就变成不一样了。这次变成红圆多一些,所以新来的绿点被归类成红圆。如下图所示:

在这里插入图片描述

从这个例子中,我们就能看得出 K 的取值是很重要的。


三、KNN 算法流程

  1. 准备数据,对数据进行预处理
  2. 计算测试样本点(也就是待分类点)到其他每个样本点的距离
  3. 对每个距离进行排序,然后选择出距离最小的K个点
  4. 对K个点所属的类别进行比较,根据少数服从多数的原则,将测试样本点归入在K个点中占比最高的那一类

注意:由于 KNN 算法中需要计算两点之间的距离,距离有很多种度量方式,比如常见的曼哈顿距离、欧式距离、切比雪夫距离等等。不过通常 KNN 算法中使用的是欧式距离。


四、KNN 优缺点

KNN 优点

  • KNN 方法思路简单,易于理解,易于实现,无需估计参数(同样是分类算法,逻辑回归需要先对数据进行大量训练,最后才会得到一个算法模型。而 KNN 算法却不需要,它没有明确的训练数据的过程,或者说这个过程很快)
  • 模型训练时间快
  • 对异常值不敏感
  • 预测效果好

KNN 缺点

  • 对内存要求较高,因为该算法存储了所有训练数据
  • 预测阶段可能很慢,因为要从大量的训练数据中找到最近的 K 个点
  • 当样本不平衡时,如一个类的样本容量很大,而其他类样本容量很小时,有可能导致当输入一个新样本时,该样本的K个邻居中大容量类的样本占多数

五、Java 代码实现 KNN

由于网络上关于 Python 实现 KNN 的博客实在是太多啦,所以本篇博客就以 Java 实现 KNN !Python 的话可以直接调用 sklearn,非常方便~

TrainDataSet:训练集对象

public class TrainDataSet {

    /**
     * 特征集合
     **/
    public List<double[]> features = new ArrayList<>();
    /**
     * 标签集合
     **/
    public List<Integer> labels = new ArrayList<>();
    /**
     * 特征向量维度
     **/
    public int featureDim;

    public int size() {
        return labels.size();
    }

    public double[] getFeature(int index) {
        return features.get(index);
    }

    public int getLabel(int index) {
        return labels.get(index);
    }

    public void addData(double[] feature, int label) {
        if (features.isEmpty()) {
            featureDim = feature.length;
        } else {
            if (featureDim != feature.length) {
                throwDimensionMismatchException(feature.length);
            }
        }
        features.add(feature);
        labels.add(label);
    }

    public void throwDimensionMismatchException(int errorLen) {
        throw new RuntimeException("DimensionMismatchError: 你应该传入维度为 " + featureDim + " 的特征向量 , 但你传入了维度为 " + errorLen + " 的特征向量");
    }

}

KNearestNeighbors:KNN算法对象

public class KNearestNeighbors {
    /**
     * 训练数据集
     **/
    TrainDataSet trainDataSet;
    /**
     * k值
     **/
    int k;
    /**
     * 距离公式
     **/
    DistanceType distanceType;

    /**
     * @param trainDataSet: 训练数据集
     * @param k:            k值
     */
    public KNearestNeighbors(TrainDataSet trainDataSet, int k, DistanceType distanceType) {
        this.trainDataSet = trainDataSet;
        this.k = k;
        this.distanceType = distanceType;
    }

    // 传入特征,返回预测值
    public int predict(double[] feature) {
        if (feature.length != trainDataSet.featureDim) {
            trainDataSet.throwDimensionMismatchException(feature.length);
        }
        PriorityQueue<Node> nodePriorityQueue = new PriorityQueue<>();
        for (int i = 0; i < trainDataSet.size(); i++) {
            nodePriorityQueue.add(new Node(trainDataSet.getLabel(i), calcDistance(trainDataSet.getFeature(i), feature)));
        }
        int cnt = 0;
        Map<Integer, Integer> map = new HashMap<>();
        int predictLabel = -1;
        int maxNum = -1;
        for (int i = 0; i < k && !nodePriorityQueue.isEmpty(); i++) {
            int label = nodePriorityQueue.poll().label;
            if (map.containsKey(label)) {
                map.replace(label, map.get(label) + 1);
            } else {
                map.put(label, 1);
            }
            if (map.get(label) > maxNum) {
                maxNum = map.get(label);
                predictLabel = label;
            }
            cnt++;
        }
        if (cnt != k || maxNum == -1) {
            throw new RuntimeException("predict fail");
        }
        return predictLabel;
    }

    // 计算距离
    private double calcDistance(double[] arr1, double[] arr2) {
        switch (distanceType) {
            case EuclideanDistance:
                return calcEuclideanDistance(arr1, arr2);
            case ManhattanDistance:
                return calcManhattanDistance(arr1, arr2);
            case ChebyshevDistance:
                return calcChebyshevDistance(arr1, arr2);
            default:
                break;
        }
        throw new RuntimeException("未知的distanceType: " + distanceType);
    }

    // 计算欧式距离
    private double calcEuclideanDistance(double[] arr1, double[] arr2) {
        double res = 0d;
        for (int i = 0; i < arr1.length; i++) {
            res += Math.pow(arr1[i] - arr2[i], 2);
        }
        return Math.sqrt(res);
    }

    // 计算曼哈顿距离
    private double calcManhattanDistance(double[] arr1, double[] arr2) {
        double res = 0d;
        for (int i = 0; i < arr1.length; i++) {
            res += Math.abs(arr1[i] - arr2[i]);
        }
        return res;
    }

    // 计算切比雪夫距离
    private double calcChebyshevDistance(double[] arr1, double[] arr2) {
        double res = 0d;
        for (int i = 0; i < arr1.length; i++) {
            res = Math.max(res, Math.abs(arr1[i] - arr2[i]));
        }
        return res;
    }

    private static class Node implements Comparable<Node> {
        int label;
        double distance;

        public Node(int label, double distance) {
            this.label = label;
            this.distance = distance;
        }

        @Override
        public int compareTo(Node o) {
            return Double.compare(distance, o.distance);
        }
    }

    public enum DistanceType {
        // 欧式距离
        EuclideanDistance,
        // 曼哈顿距离
        ManhattanDistance,
        // 切比雪夫距离
        ChebyshevDistance;
    }

}

六、KNN 改进策略

目前对 KNN 算法改进的方向主要可以分为 4 类

  • 寻求更接近于实际的距离函数以取代标准的欧氏距离,典型的工作包括 WAKNN、VDM
  • 搜索更加合理的 K 值以取代指定大小的 K 值典型的工作包括 SNNB、 DKNAW
  • 运用更加精确的概率估测方法去取代简单的投票机制,典型的工作包括 KNNDW、LWNB、 ICLNB
  • 建立高效的索引,以提高 KNN 算法的运行效率,代表性的研究工作包括 KDTree、 NBTree

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/151202.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Cadence PCB仿真使用Allegro PCB SI为分立元件创建统一的模型并赋值方法图文教程

⏪《上一篇》   🏡《总目录》   ⏩《下一篇》 目录 1,概述2,配置方法3,总结1,概述 本文简单介绍使用Allegro PCB SI软件配置电压地网络电压的方法。 2,配置方法 第1步:打开待仿真的PCB文件,并确认软件为Allegro PCB SI 如果,打开软件不是Allegro PCB SI则可这样…

[ 数据结构 ] 背包问题(动态规划)

0 背包问题 有一个背包,容量为4磅,现有如下物品 物品重量价格吉他(G)11500音响(S)43000电脑(L)32000 1)要求达到的目标为装入的背包的总价值最大&#xff0c;并且重量不超出 2)要求装入的物品不能重复(01背包) 1 动态规划 动态规划(Dynamic Programming)算法的核心思想是&…

从0到1完成一个Vue后台管理项目(十一、前端分页实现)

往期 从0到1完成一个Vue后台管理项目&#xff08;一、创建项目&#xff09; 从0到1完成一个Vue后台管理项目&#xff08;二、使用element-ui&#xff09; 从0到1完成一个Vue后台管理项目&#xff08;三、使用SCSS/LESS&#xff0c;安装图标库&#xff09; 从0到1完成一个Vu…

测试分析--精准分析

测试分析的概念 测试分析是建立在对「需求本身」、「用户使用场景」以及对应的「系统架构」和「实现细节」的充分了解的基础上&#xff0c;通过对数据流、状态变化、逻辑时序、功能/性能/兼容性等方面的分析&#xff0c;得出测试点的过程&#xff1b; 在现阶段敏捷开发模式普遍…

【部分真题】2022年12月QMS质量管理体系试题(1-5题)尚大解析版

注1&#xff1a;由于是机考&#xff0c;题目顺序随机变化&#xff0c;但题目内容所有考生一致。 注2&#xff1a;选择题的选项顺序会随机改变&#xff0c;但选项的内容不变。 注3&#xff1a;为了方便学员学习与复习&#xff0c;已经按教程&考试大纲进行全面优化排序。 注4…

4644. 求和

4644. 求和 https://www.acwing.com/problem/content/description/4647/ 第十三届蓝桥杯省赛CA/C组 , 第十三届蓝桥杯省赛JAVAA组 算法标签&#xff1a;推公式&#xff1b;前缀和 思路 推公式做法&#xff1a; (a1a2a3...an)2a12a22a32...an22a1a22a1a3...2a1an2a2a3...2an−…

【模板】最小生成树(C++)

题目描述 如题&#xff0c;给出一个无向图&#xff0c;求出最小生成树&#xff0c;如果该图不连通&#xff0c;则输出 orz。 输入格式 第一行包含两个整数 N,MN,MN,M&#xff0c;表示该图共有 NNN 个结点和 MMM 条无向边。 接下来 MMM 行每行包含三个整数 Xi,Yi,ZiX_i,Y_i,…

设计师必备的免费样机素材

很多设计师会用样机模型来展示自己的作品&#xff0c;让设计图案、应用界面等作品应用到实物效果图中&#xff0c;能体现作品的最终效果&#xff0c;更加形象逼真。哪里能下载到样机模板呢&#xff1f;今天我就推荐6个网站帮你解决&#xff0c;赶紧收藏&#xff01; 1、菜鸟图库…

20230109测试ToyBrick的RK3588开发板运行Buildroot的V0.02版本(20220312)

20230109测试ToyBrick的RK3588开发板运行Buildroot的V0.02版本&#xff08;20220312&#xff09; 2023/1/9 18:03 https://wiki.t-firefly.com/zh_CN/Firefly-Linux-Guide/manual_buildroot.html 1. Buildroot 使用手册 1.1. 桌面应用 官方发布的 Buildroot 固件&#xff0c;默…

RabbitMQ学习一【尚硅谷】

一、消息队列 1、MQ的相关概念 2、RabbitMQ 2.1 四大核心概念 生产者&#xff1a; 交换机&#xff1a;交换机是 RabbitMQ非常重要的一个部件&#xff0c;一方面它接收来自生产者的消息&#xff0c;另一方面它将消息 推送到队列中。交换机必须确切知道如何处理它接收到的消息…

一文详解Linux Python3安装

在公司申请了一台CentOS 7的Linux版本虚拟机&#xff0c;需要安装一个Python3的环境&#xff0c;定期进行特定任务处理。这里对CentOS 7配置Python3环境的步骤进行了记录&#xff0c;供大家参考。 本文基于如下Linux系统版本&#xff1a; 一、默认Python版本 默认情况下&am…

Excelize 2.7.0 发布, 2023 年首个更新

Excelize 是 Go 语言编写的用于操作 Office Excel 文档基础库&#xff0c;基于 ECMA-376&#xff0c;ISO/IEC 29500 国际标准。可以使用它来读取、写入由 Microsoft Excel™ 2007 及以上版本创建的电子表格文档。支持 XLAM / XLSM / XLSX / XLTM / XLTX 等多种文档格式&#xf…

C 程序设计教程(13)—— 顺序结构程序设计练习题

C 程序设计教程&#xff08;13&#xff09;—— 顺序结构程序设计练习题 该专栏主要介绍 C 语言的基本语法&#xff0c;作为《程序设计语言》课程的课件与参考资料&#xff0c;用于《程序设计语言》课程的教学&#xff0c;供入门级用户阅读。 目录C 程序设计教程&#xff08;1…

【openGauss】在openEuler(ARM架构)上安装openGauss(一主两备含CM版)

一、系统版本介绍 当前案例中的openGauss安装&#xff0c;底层操作系统为openEuler-20.03-LTS版本&#xff0c;当前openGauss对Python版本兼容性最好的是Python 3.6版本与Python 3.7版本&#xff0c;该实验使用的openEuler版本自带Python 3.7.4&#xff0c;不需要再自行安装 二…

汽车电子系统网络安全活动

声明 本文是学习GB-T 38628-2020 信息安全技术 汽车电子系统网络安全指南. 下载地址 http://github5.com/view/764而整理的学习笔记,分享出来希望更多人受益,如果存在侵权请及时联系我们 汽车电子系统网络安全活动 7.1 概念设计阶段 7.1.1 概述 概念设计阶段的活动流程如图…

房产管理系统分布架构分析

一、数图互通房产管理系统采用分布式架构下的高可用设计&#xff1a; &#xff08;1)可以避免因单点故障造成系统平台宕机&#xff1a; a、负载均衡技术&#xff08;failover &#xff0c;选址&#xff0c;硬件负载&#xff0c;软件负载&#xff0c;去中心化负载&#xff08;g…

tp5处理前端上传的图片文件

前端上传了一个图片文件,tp5框架如何处理 效果图&#xff1a; 效果图一: 效果图二: 如果需要看前端如何展示、删除上传的缩略图请到此篇博客&#xff1a; 前端&#xff1a; <form id"upload_pic_wrap" target"upload_file" enctype"multipar…

任务间通讯

信号量与邮箱 系统中的多个任务在运行时&#xff0c;经常需要互相无冲突地访问同一个共享资源&#xff0c;或者需要互相支持和依赖&#xff0c;甚至有时还要互相加以必要的限制和制约&#xff0c;才保证任务的顺利运行。因此&#xff0c;操作系统必须具有对任务的运行进行协调…

C++11引入的尾置返回类型

C11引入的尾置返回类型一、什么是尾置返回类型(trailing return type)二、尾置返回的典型场景2.1 常规方式如何返回数组指针2.2 使用尾置返回类型三、尾置返回类型的应用四、总结一、什么是尾置返回类型(trailing return type) 我们先来看一下传统的函数是怎么定义的&#xff…

Leetcode N皇后

题目链接 Leetcode.51 N 皇后 Leetcode.52 N皇后 II N皇后 题目描述 按照国际象棋的规则&#xff0c;皇后可以攻击与之处在同一行或同一列或同一斜线上的棋子。 n 皇后问题 研究的是如何将 n 个皇后放置在 nn的棋盘上&#xff0c;并且使皇后彼此之间不能相互攻击。 给你一个…