【python】机器学习-K-近邻(KNN)算法

news2025/1/4 20:07:47

        

目录

一 . K-近邻算法(KNN)概述 

二、KNN算法实现

三、 MATLAB实现

四、 实战


一 . K-近邻算法(KNN)概述 

        K-近邻算法(KNN)是一种基本的分类算法,它通过计算数据点之间的距离来进行分类。在KNN算法中,当我们需要对一个未知数据点进行分类时,它会与训练集中的各个数据点进行特征比较,并找到与之最相似的前K个数据点。然后根据这K个数据点的类别来确定未知数据点所属的类别。

        KNN算法的步骤非常简单: 1)计算未知数据点与训练集中各个数据点之间的距离。常用的距离度量包括欧氏距离和曼哈顿距离。 2)按照距离递增的顺序对数据点进行排序。 3)选择距离最小的K个数据点。 4)根据这K个数据点的类别来确定未知数据点的类别。通常采用多数表决的方式,即统计K个数据点中各个类别出现的次数,将出现次数最多的类别作为未知数据点的预测类别。

        KNN算法的特点是简单易懂,容易实现。它没有显式的训练过程,仅依赖于已有的训练数据。然而,KNN算法的计算复杂度较高,尤其是当训练集很大时。此外,KNN算法还对训练样本的质量和数量敏感,需要合理地选择K值和距离度量方法。

     在KNN中,通过计算对象间距离来作为各个对象之间的非相似性指标,避免了对象之间的匹配问题,在这里距离一般使用欧氏距离或曼哈顿距离:

    

        同时,KNN通过依据k个对象中占优的类别进行决策,而不是单一的对象类别决策。这两点就是KNN算法的优势。

   接下来对KNN算法的思想总结一下:就是在训练集中数据和标签已知的情况下,输入测试数据,将测试数据的特征与训练集中对应的特征进行相互比较,找到训练集中与之最为相似的前K个数据,则该测试数据对应的类别就是K个数据中出现次数最多的那个分类,其算法的描述为:

  1. 首先需要收集足够的带有标签的训练数据,这些数据包含了输入特征和相应的输出标签。

  2. 对于输入的测试数据,需要计算它与每个训练数据之间的距离(如欧氏距离、曼哈顿距离等)。

  3. 选取距离测试数据最近的K个训练数据,并统计它们中出现最多的标签类别。

  4. 将测试数据归类为出现次数最多的标签类别。

二、KNN算法实现

        KNN算法的实现通常可以使用Python等编程语言进行实现

        

import numpy as np

class KNN():
    def __init__(self, k=3, distance='euclidean'):
        self.k = k
        self.distance = distance
        
    def fit(self, X, y):
        self.X_train = X
        self.y_train = y
        
    def predict(self, X):
        y_pred = []
        for x in X:
            distances = []
            for i, x_train in enumerate(self.X_train):
                if self.distance == 'euclidean':
                    dist = np.linalg.norm(x - x_train)
                elif self.distance == 'manhattan':
                    dist = np.sum(np.abs(x - x_train))
                distances.append((dist, self.y_train[i]))
            distances.sort()
            neighbors = distances[:self.k]
            classes = {}
            for neighbor in neighbors:
                if neighbor[1] in classes:
                    classes[neighbor[1]] += 1
                else:
                    classes[neighbor[1]] = 1
            max_class = max(classes, key=classes.get)
            y_pred.append(max_class)
        return y_pred

        这段代码实现了基本的KNN分类算法,包括fit函数进行训练集拟合,predict函数进行预测。其中k参数表示要选择的最近邻居数,distance参数为距离度量方法。在上述示例代码中,欧氏距离和曼哈顿距离两种距离度量方法均已实现。

        通过选择不同的数据集和参数,可以验证KNN算法的分类性能。在实现KNN算法时,还可以采用更加高效的数据结构(如kd树、球树)和距离度量方法等技巧,来对算法进行优化和改进。

三、 MATLAB实现

        

  1. 使用pdist2函数计算欧氏距离,而不是手动计算,可以极大地提高计算速度。

  2. 在计算距离之后,直接利用sort函数进行排序,并选择前k个最近邻。这样可以简化代码,并且使用向量化计算,计算速度更快。

  3. 使用mode函数求取邻居中出现次数最多的类别作为预测结果,并且使用2维输入方式保证正确性。

function y_pred = knn(X_train, y_train, X_test, k)
    n_train = size(X_train, 1);
    n_test = size(X_test, 1);
    y_pred = zeros(n_test, 1);

    % 计算欧氏距离
    distances = pdist2(X_train, X_test);
    
    % 选择前k个最近邻
    [~, indices] = sort(distances);
    neighbors = y_train(indices(1:k,:));
    
    % 使用投票法预测标签
    y_pred = mode(neighbors, 1)';
end

四、 实战

     在这里根据一个人收集的约会数据,根据主要的样本特征以及得到的分类,对一些未知类别的数据进行分类,大致就是这样。 

     我使用的是python 3.4.3,首先建立一个文件,例如date.py,具体的代码如下:

#coding:utf-8

from numpy import *
import operator
from collections import Counter
import matplotlib
import matplotlib.pyplot as plt


###导入特征数据
def file2matrix(filename):
    fr = open(filename)
    contain = fr.readlines()###读取文件的所有内容
    count = len(contain)
    returnMat = zeros((count,3))
    classLabelVector = []
    index = 0
    for line in contain:
        line = line.strip() ###截取所有的回车字符
        listFromLine = line.split('\t')
        returnMat[index,:] = listFromLine[0:3]###选取前三个元素,存储在特征矩阵中
        classLabelVector.append(listFromLine[-1])###将列表的最后一列存储到向量classLabelVector中
        index += 1
    
    ##将列表的最后一列由字符串转化为数字,便于以后的计算
    dictClassLabel = Counter(classLabelVector)
    classLabel = []
    kind = list(dictClassLabel)
    for item in classLabelVector:
        if item == kind[0]:
            item = 1
        elif item == kind[1]:
            item = 2
        else:
            item = 3
        classLabel.append(item)
    return returnMat,classLabel#####将文本中的数据导入到列表

##绘图(可以直观的表示出各特征对分类结果的影响程度)
datingDataMat,datingLabels = file2matrix('D:\python\Mechine learing in Action\KNN\datingTestSet.txt')
fig = plt.figure()
ax = fig.add_subplot(111)
ax.scatter(datingDataMat[:,0],datingDataMat[:,1],15.0*array(datingLabels),15.0*array(datingLabels))
plt.show()

## 归一化数据,保证特征等权重
def autoNorm(dataSet):
    minVals = dataSet.min(0)
    maxVals = dataSet.max(0)
    ranges = maxVals - minVals
    normDataSet = zeros(shape(dataSet))##建立与dataSet结构一样的矩阵
    m = dataSet.shape[0]
    for i in range(1,m):
        normDataSet[i,:] = (dataSet[i,:] - minVals) / ranges
    return normDataSet,ranges,minVals

##KNN算法
def classify(input,dataSet,label,k):
    dataSize = dataSet.shape[0]
    ####计算欧式距离
    diff = tile(input,(dataSize,1)) - dataSet
    sqdiff = diff ** 2
    squareDist = sum(sqdiff,axis = 1)###行向量分别相加,从而得到新的一个行向量
    dist = squareDist ** 0.5
    
    ##对距离进行排序
    sortedDistIndex = argsort(dist)##argsort()根据元素的值从大到小对元素进行排序,返回下标

    classCount={}
    for i in range(k):
        voteLabel = label[sortedDistIndex[i]]
        ###对选取的K个样本所属的类别个数进行统计
        classCount[voteLabel] = classCount.get(voteLabel,0) + 1
    ###选取出现的类别次数最多的类别
    maxCount = 0
    for key,value in classCount.items():
        if value > maxCount:
            maxCount = value
            classes = key
    return classes

##测试(选取10%测试)
def datingTest():
    rate = 0.10
    datingDataMat,datingLabels = file2matrix('D:\python\Mechine learing in Action\KNN\datingTestSet.txt')
    normMat,ranges,minVals = autoNorm(datingDataMat)
    m = normMat.shape[0]
    testNum = int(m * rate)
    errorCount = 0.0
    for i in range(1,testNum):
        classifyResult = classify(normMat[i,:],normMat[testNum:m,:],datingLabels[testNum:m],3)
        print("分类后的结果为:,", classifyResult)
        print("原结果为:",datingLabels[i])
        if(classifyResult != datingLabels[i]):
                                  errorCount += 1.0
    print("误分率为:",(errorCount/float(testNum)))
                                  
###预测函数
def classifyPerson():
    resultList = ['一点也不喜欢','有一丢丢喜欢','灰常喜欢']
    percentTats = float(input("玩视频所占的时间比?"))
    miles = float(input("每年获得的飞行常客里程数?"))
    iceCream = float(input("每周所消费的冰淇淋公升数?"))
    datingDataMat,datingLabels = file2matrix('D:\python\Mechine learing in Action\KNN\datingTestSet2.txt')
    normMat,ranges,minVals = autoNorm(datingDataMat)
    inArr = array([miles,percentTats,iceCream])
    classifierResult = classify((inArr-minVals)/ranges,normMat,datingLabels,3)
    print("你对这个人的喜欢程度:",resultList[classifierResult - 1])

新建test.py文件了解程序的运行结果,代码:

#coding:utf-8

from numpy import *
import operator
from collections import Counter
import matplotlib
import matplotlib.pyplot as plt

import sys
sys.path.append("D:\python\Mechine learing in Action\KNN")
import date
date.classifyPerson()


                

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1117587.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【CSS】全局滚动条样式设置

直接在 App.vue 全局文件下设置滚动条样式: ::-webkit-scrollbar {width: 5px;position: absolute; } ::-webkit-scrollbar-thumb {background: #1890ff; } ::-webkit-scrollbar-track {background: #ddd; }

力扣每日一题51:N皇后问题

题目描述: 按照国际象棋的规则,皇后可以攻击与之处在同一行或同一列或同一斜线上的棋子。 n 皇后问题 研究的是如何将 n 个皇后放置在 nn 的棋盘上,并且使皇后彼此之间不能相互攻击。 给你一个整数 n ,返回所有不同的 n 皇后问…

【51单片机外部中断控制流水灯转向】2023-10-21

缘由单片机不会搞 原理都清晰合一块成傻杯了 各位爷 用keil Vison5 还有Proteus8仿真图给出一下吧_嵌入式-CSDN问答 #include <reg52.h> unsigned char Js0; bit k0; void main() {//缘由unsigned char ls0; EA1;//总中断允许EX01;//允许外部中断0中断TH0(65536-50000)…

mysql优化之explain详解

mysql的explain&#xff08;执行计划&#xff09;用于解释sql的执行的过程&#xff0c;然后把sql的执行过程用一张表格表示出来&#xff0c;它并不真正的执行sql&#xff0c;如下图。explain能够为我们优化sql提供很好参考作用。 下面我来看下执行计划表中各个字段是什么意思 i…

【Linux】kill 命令使用

经常用kill -9 XXX 。一直在kill&#xff0c;除了kill -9 -15 &#xff0c;还能做什么&#xff1f;今天咱们一起学习一下。 kill 命令用于删除执行中的程序或工作。 kill命令 -Linux手册页 命令选项及作用 执行令 man kill 执行命令结果 参数 -l 信号&#xff0c;若果…

【吞噬星空】又被骂,罗峰杀人目无法纪,但官方留后手,增加审判戏份

Hello,小伙伴们&#xff0c;我是小郑继续为大家深度解析国漫吞噬星空资讯。 吞噬星空动画中&#xff0c;罗峰复仇的戏份&#xff0c;简直是帅翻了&#xff0c;尤其是秒杀阿特金三大巨头&#xff0c;让人看的也是相当的解气&#xff0c;相当的爽&#xff0c;一点都不拖沓&#x…

有什么站内搜索引擎优化的方法?今天跟大家分享!

在你的网站上安装站内搜索引擎对于提升用户体验和增加互动至关重要。在今天快节奏的数字世界中&#xff0c;用户希望能够快速、轻松地找到信息。通过提供站内搜索引擎&#xff0c;用户能够轻松浏览你的网站&#xff0c;帮助他们找到他们正在寻找的具体信息。接下来我将跟大家介…

浅析高校用电问题及智慧电力监管平台的构建

安科瑞 崔丽洁 摘 要&#xff1a;介绍了当前高校用电存在的问题&#xff0c;进行了原因分析&#xff0c;由此提出建立高校用电智慧监管平台。对高校用电智慧监管平台的构架进行设计&#xff0c;运用物联网技术&#xff0c;实现各回路实时自主控制&#xff0c;并细化管理权限&a…

ATA-8202射频功率放大器在超声雾化研究中的应用

超声雾化技术是一种利用高频声波能量产生微细液滴的技术&#xff0c;广泛应用于医学、生物科学、材料科学等领域。在超声雾化过程中&#xff0c;功率放大器扮演着关键的角色&#xff0c;它能提供足够的能量来驱动超声发射装置&#xff0c;并调节声波参数&#xff0c;实现有效的…

【数据结构】线性表(八)队列:顺序队列及其基本操作(初始化、判空、判满、入队、出队、存取队首元素)

文章目录 一、队列1. 定义2. 基本操作 二、顺序队列0. 顺序表1. 头文件和常量2. 队列结构体3. 队列的初始化4. 判断队列是否为空5. 判断队列是否已满6. 入队7. 出队8. 存取队首元素9. 主函数10. 代码整合 堆栈Stack 和 队列Queue是两种非常重要的数据结构&#xff0c;两者都是特…

论文浅尝 | Concept2Box:从双视图学习知识图谱的联合几何嵌入模型

笔记整理&#xff1a;张钊源&#xff0c;天津大学硕士&#xff0c;研究方向为知识图谱 链接&#xff1a;https://virtual2023.aclweb.org/paper_P4210.html 动机 知识图嵌入&#xff08;KGE&#xff09;已被广泛研究&#xff0c;用于嵌入大规模关系数据以满足许多现实世界的应用…

Spring Security总体架构介绍

参考&#xff1a;架构 :: Spring Security Reference (springdoc.cn) 一、过滤器 Spring Security 框架对 Servlet 请求的处理是基于过滤器机制。 容器会提前创建好FilterChain对每一个请求进行过滤&#xff0c;FilterChain中包含Filter 实例和 Servlet&#xff08;Spring MV…

编写后台登录滑动成功获取验证码 人机验证

vue-puzzle-vcode Vue 纯前端的拼图人机验证、右滑拼图验证 安装vue-puzzle-vcode npm install vue-puzzle-vcode --save使用vue-puzzle-vcode import Vcode from "vue-puzzle-vcode";<Vcode :show"isShow" success"onSuccess" close"…

ZooKeeper+HBase分布式集群环境搭建

安装版本&#xff1a;hadoop-2.10.1、zookeeper-3.4.12、hbase-2.3.1 一、zookeeper集群搭建与配置 1.下载zookeeper安装包 2.解压移动zookeeper 3.修改配置文件&#xff08;创建文件夹&#xff09; 4.进入conf/ 5.修改zoo.cfg文件 6.进入/usr/local/zookeeper-3.4.12/zkdat…

虚拟机与主机(win10之间的通信)

(201条消息) Ubuntu虚拟机不显示ip地址【已解决】_ubuntu没有ip_不爱赖床的懒虫的博客-CSDN博客 sudo /sbin/dhclient VMTool安装与卸载 (201条消息) ubuntu中vmtools的安装与彻底卸载_卸载vmtools_林麦安的博客-CSDN博客 (202条消息) 解决虚拟机安装 VMware Tools 灰色无法…

聊聊RocketMQ中的broker的TPS和QPS为何相差巨大,是如何统计的

这里是weihubeats,觉得文章不错可以关注公众号小奏技术&#xff0c;文章首发。拒绝营销号&#xff0c;拒绝标题党 最近在看RocketMQ的一些监控指标的时候&#xff0c;总觉得一些监控指标不太对&#xff0c;好像对不上。 所以打算研究下看看RocketMQ中的 broker TPS、broker QP…

嵌入式学习笔记(60)内存管理之堆

1.7.1.什么是堆&#xff08;heap&#xff09; 内存管理对OS来说是一件非常复杂的事&#xff0c;因为首先内存容量大&#xff0c;其次内存需求在时间和大小块上没有规律&#xff08;OS上运行着几十、几百、几千个进程随时都会申请或者释放内存&#xff0c;申请或者释放的内存块…

JavaWeb从入门到起飞笔记——导学课程

学完这一节&#xff0c;我不知道学Web开发究竟能干什么&#xff1f;你知道吗&#xff1f; 以下是黑马程序员Java从入门到起飞的笔记 一、学完Javaweb能干什么&#xff1f; 学完Java后我们可以独立开发一些后台管理系统&#xff0c;例如CRMER器&#xff0c;京东和淘宝&#x…

tuxera ntfs2024破解版mac电脑磁盘读写软件

大家都知道由于操作系统的原因&#xff0c;在苹果电脑上不能够读写NTFS磁盘&#xff0c;但是&#xff0c;今天小编带来的这款tuxera ntfs 2024 mac版&#xff0c;完美的解决了这个问题。这是一款在macOS平台上使用的磁盘读写软件&#xff0c;能够实现苹果Mac OS X系统读写Micro…

C++ 字符串编码转换封装函数,UTF-8编码与本地编码互转

简介 字符串编码转换封装函数&#xff0c;UTF-8编码与本地编码互转。 中文乱码的解决方法 有时候我们会遇到乱码的字符串&#xff0c;比如&#xff1a; 古文码 可能是用GBK方式读取UTF-8编码的中文导致的&#xff0c;用下面的Utf8ToLocal(string str)函数转换一下就可以了。…