机器学习作业二之KNN算法

news2024/11/15 18:41:12

KNN(K- Nearest Neighbor)法即K最邻近法,最初由 Cover和Hart于1968年提出,是一个理论上比较成熟的方法,也是最简单的机器学习算法之一。该方法的思路非常简单直观:如果一个样本在特征空间中的K个最相似(即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别。该方法在定类决策上只依据最邻近的一个或者几个样本的类别来决定待分样本所属的类别 。

该方法的不足之处是计算量较大,因为对每一个待分类的文本都要计算它到全体已知样本的距离,才能求得它的K个最邻近点。目前常用的解决方法是事先对已知样本点进行剪辑,事先去除对分类作用不大的样本。另外还有一种 Reverse KNN法,它能降低KNN算法的计算复杂度,提高分类的效率 。

KNN算法比较适用于样本容量比较大的类域的自动分类,而那些样本容量较小的类域采用这种算法比较容易产生误分 。

——百度百科

一、算法思想:

已经有的样本,均有n个特征值,可以用n个坐标轴表示出这个样本点的位置。

而测试集中的元素,也均有n个特征值,也可以用n个坐标轴表示出这个样本点的位置。

对于每个测试集中的元素,找到距离其最近的k个点(距离可使用欧氏距离d(x, y) = \sqrt{\sum_{i=1}^{n} (x_i - y_i)^2}),在这k个点中选出数量最多的一个种类,将这个种类作为其结果。

需要注意的是,由于每个坐标的相对大小不同, 需要将数值做归一化处理。


二、代码

思想不难,代码:

import csv
import math
import operator

from matplotlib import pyplot as plt

def guiyihua(train, input):
    maxval1 = 0
    minval1 = 101
    maxval2 = 0
    minval2 = 101
    for i in range( len(train) ):
        train[i][0] = float(train[i][0])
        train[i][1] = float(train[i][1])
        maxval1 = max(maxval1, train[i][0])
        minval1 = min(minval1, train[i][0])
        maxval2 = max(maxval2, train[i][1])
        minval2 = min(minval2, train[i][1])
    for i in range( len(input) ):
        input[i][0] = float(input[i][0])
        input[i][1] = float(input[i][1])
        maxval1 = max(maxval1, input[i][0])
        minval1 = min(minval1, input[i][0])
        maxval2 = max(maxval2, input[i][1])
        minval2 = min(minval2, input[i][1])
    for i in range( len(train)):
        train[i][0] = (train[i][0]-minval1)/(maxval1-minval1)
        train[i][1] = (train[i][1]-minval2)/(maxval2-minval2)
    for i in range( len(input) ):
        input[i][0] = (input[i][0]-minval1)/(maxval1-minval1)
        input[i][1] = (input[i][1]-minval2)/(maxval2-minval2)

def load(fname):
    with open(fname, 'rt') as csvfile:
        lists = csv.reader(csvfile)
        data = list(lists)
        return data
    
def euclideanDistance(atrain, ainput, needcal):
    re2 = 0
    for i in range(needcal):
        re2 += (atrain[i] - ainput[i])**2
    return math.sqrt(re2)

def jg(train, ainput, k):
    alldis = []
    needcal = len(ainput)-1 #需要计算的维度
    for i in range(len(train)):
        nowdis = euclideanDistance(train[i], ainput, needcal)
        alldis.append((train[i], nowdis))
    alldis.sort(key=operator.itemgetter(1))
    
    vote = {}
    for i in range(k):
        type = alldis[i][0][-1]
        if type in vote:
            vote[type] += 1
        else:
            vote[type] = 1
        
    sortvote = sorted(vote.items(), key=operator.itemgetter(1), reverse=True)#items()将字典转为列表,这样可以对第二个值进行排序
    return sortvote[0][0]
 
def showright(train, input):
    plt.subplot(2, 5, 1)
    plt.title("right")
    for i in range(len(train)):
        if train[i][-1] == "第一种" :  
            plt.scatter(train[i][0], train[i][1], c = '#0066FF', s = 10, label = "第一种")
        else :
            plt.scatter(train[i][0], train[i][1], c = '#CC0000', s = 10, label = "第二种")
    for i in range(len(input)):
        if input[i][-1] == "第一种" :
            plt.scatter(input[i][0], input[i][1], c = '#0066FF', s = 50, label = "cs第一种")
            #plt.scatter(input[i][0], input[i][1], c = '#FF3333', s = 30, label = "cs第一种")
        else :
            plt.scatter(input[i][0], input[i][1], c = '#CC0000', s = 50, label = "cs第一种")
            #plt.scatter(input[i][0], input[i][1], c = '#FF33FF', s = 30, label = "cs第二种")

def showtest(train, input, re, ki, cnt):
    plt.subplot(2, 5, ki+1)
    plt.title("k = "+ repr(ki)+" acc: "+repr(1.0*cnt/(1.0*len(input))*100 )+ '%')
    for i in range(len(train)):
        if train[i][-1] == "第一种" :
            plt.scatter(train[i][0], train[i][1], c = '#0066FF', s = 10, label = "第一种")
        else :
            plt.scatter(train[i][0], train[i][1], c = '#CC0000', s = 10, label = "第二种")
    for i in range(len(input)):
        if re[i] == "第一种" :
            plt.scatter(input[i][0], input[i][1], c = '#0066FF', s = 50, label = "cs第一种")
            #plt.scatter(input[i][0], input[i][1], c = '#00FF33', s = 30, label = "cs第一种")
        else :
            plt.scatter(input[i][0], input[i][1], c = '#CC0000', s = 50, label = "cs第二种")
            #plt.scatter(input[i][0], input[i][1], c = '#00FFFF', s = 30, label = "cs第二种")
def main():
    train = load("C:\\Users\\T.HLQ12\\Desktop\\wdnmd\\python\\jiqixuexi\\train.csv")
    input = load("C:\\Users\\T.HLQ12\\Desktop\\wdnmd\\python\\jiqixuexi\\test.csv")
    guiyihua(train, input)
    # print(train)
    # print(input)
    showright(train, input)
    for ki in range (1, 10):
        re = []
        k = ki
        cnt = 0
        for i in range(len(input)):
            type = jg(train, input[i], k)
            if(type == input[i][-1]):
                cnt += 1
            re.append(type)
            print("预测:" + type + ",实际上: " + input[i][-1])
        print("准确率: " + repr(1.0*cnt/(1.0*len(input))*100) + '%')
        showtest(train, input, re, ki, cnt)
    plt.show()
main()
    

逐个解释一下:

guiyihua:

不会归一化的英文,就写拼音了,从训练集和测试集中找出一个最大值和最小值。然后把训练集和测试集的数据都减去最小值,再除以最大值减最小值即可。

load:

使用with open可以不用人为关闭文件。其中csv.reader会返回一个迭代器,配合list将data赋值为二维数组。

euclideanDistance:

欧式距离,就是把所有维度平方下相加,然后再返回开根号的值。

jg:

这个是judge的缩写,判断输入的测试集中的一个元素的种类。函数的参数有训练集,一个输入的值和一个k。遍历训练集中的所有元素,算出距离测试点的欧式距离,然后添加到alldis数组里。最后对数组进行排序(参数中意味按照元组中第一个值排序,默认从大到小)。然后创建一个vote字典。这个字点的第一个值是种类,第二个值是种类的个数。循环遍历距离数组,每次碰到一个种类,就把这个种类的数量加一。循环结束后,对这个字典的第二个值进行排序(排序中参数:第一个item:将字典vote转换为一个包含键值对的列表, 第二个:对下标1进行排序,第三个:从大到小排序),选出最大数量的种类作为这个测试元素的结果。

shoright 与showtest:

这两个函数是用于绘制散点图的。Subplot中第一个参数是行数,第二个参数是列数。第三个参数是第几个部分。Title中可以设置这张图的标题。训练集中每个种类的颜色都不一样,点是使用scatter打上去的,其中第一个参数是这个点在第一条坐标轴上对应的值。第二个点是第二个坐标轴上对应的值,c是颜色。S是点的大小。 label是这个点的标签。循环遍历训练集和测试的每个点,就可以绘制出一张散点图。

三、实际问题

一、

如图所示,这个报错是因为vote中不存在vote括号中的值,修改为:

即可。

二、

这个问题不知道发生的原因是什么,查询资料本来以为是scatter可以使用切分,但是实际上没有办法使用,最后就替换成了循环遍历每个点来绘制散点图的方式。

四、实验结论

结果:(第一张图为对的。对于每张图,大的是测试集,小的是训练集,颜色相同的是一个种类)

数据:

从结果分析上来看,K在1~9范围内,不能很好的确定最优值,需要多次取值,反复确认才能锁定k值。

尽管有着计算量大,维度灾难等缺点,但是可以不用训练,容易理解,对于新手来说很友好。

(纯手打,求老师轻点批改)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1544981.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Ubuntu deb文件 安装 MySQL

更新系统软件依赖 sudo apt update && sudo apt upgrade下载安装包 输入命令查看Ubuntu系统版本 lsb_release -a2. 网站下载对应版本的安装包 下载地址. 解压安装 mkdir /home/mysqlcd /home/mysqltar -xvf mysql-server_8.0.36-1ubuntu20.04_amd64.deb-bundle.tar# …

Pandas操作MultiIndex合并行列的Excel,写入读取以及写入多余行及Index列处理,插入行,修改某个单元格的值

Pandas操作MultiIndex合并行列的excel,写入读取以及写入多余行及Index列处理 1. 效果图及问题2. 源码参考 今天是谁写Pandas的 复合索引MultiIndex,写的糊糊涂涂,晕晕乎乎。 是我呀… 记录下,现在终于灵台清明了。 明天在记录下直…

02-K近邻算法

机器学习其实有一个很朴实的想法: 预测 x x x的值, 那就在训练集 X X X中找到与 x x x相似的样本, 再把与x相似的这些样本的值加权作为预测值 那么我们如何度量样本之间的相似性?又该如何加权呢? 在k近邻中, 我们一般采…

【CXL协议-事务层之CXL.cache (3)】

3.2 CXL.cache 3.2.1 概述 CXL.cache 协议将设备和主机之间的交互定义为许多请求,每个请求至少有一个关联的响应消息,有时还有数据传输。 该接口由每个方向的三个通道组成: 请求、响应和数据。 这些通道根据其方向命名,D2H&…

基于FPGA实现的自适应三速以太网

一、三速以太网 千兆以太网PHY芯片是适配百兆和十兆的&#xff0c;十兆就不管了&#xff0c;我们的设计只适应千兆和百兆。 根据上图&#xff0c;我们是可以获取当前主机网口的速率信息的。 always(posedge w_rxc_bufr) beginif(w_rec_valid d0) beginro_speed < w_rec_…

【r-tree算法】一篇文章讲透~

目录 一、引言 二、R-tree算法的基本原理 1 数据结构 2 插入操作 3 删除操作 4 查询操作 5 代码事例 三、R-tree算法的性能分析 1 时间复杂度 2 空间复杂度 3 影响因素 四、R-tree算法的变体和改进 1 R*-tree算法 2 X-tree算法 3 QR-tree算法 五、R-tree算法的…

【物联网】Qinghub Kafka 数据采集

基础信息 组件名称 &#xff1a; kafka-connector 组件版本&#xff1a; 1.0.0 组件类型&#xff1a; 系统默认 状 态&#xff1a; 正式发布 组件描述&#xff1a;通用kafka连接网关&#xff0c;消费来自kafka的数据&#xff0c;并转发给下一个节点做相关的数据解析。 配置文…

http模块 获取http请求报文中的路径 与 查询字符串

虽然request.url已包含属性和查询字符串&#xff0c;但使用不便&#xff0c;若只需其中一个不好提取&#xff0c;于是用到了如下路径和字符串的单独查询方法&#xff1a; 一、获取路径 例如&#xff1a;我在启动谷歌端口时输入http://127.0.0.1:9000 后接了 "/search?k…

Docker 搭建Redis集群

目录 1. 3主3从架构说明 2. 3主3从Redis集群配置 2.1关闭防火墙启动docker后台服务 2.2 新建6个docker容器实例 2.3 进去任意一台redis容器&#xff0c;为6台机器构建集群关系 2.4 进去6381&#xff0c;查看集群状态 3. 主从容错切换迁移 3.1 数据读写存储 3.1.1 查看…

27---eMMC电路设计

视频链接 eMMC电路设计01_哔哩哔哩_bilibili eMMC电路设计 1、eMMC简介 eMMC叫嵌入式多媒体卡&#xff0c;英文全称为Embedded Multi Media Card。是一种闪存卡&#xff08;Flash Memory Card&#xff09;标准&#xff0c;它定义了MMC的架构以及访问Flash Memory的接口和协…

Linux 搭建jenkins docker

jekin docker gitee docker 安装 jenkins docker run -d --restartalways \ --name jenkins -uroot -p 10340:8080 \ -p 10341:50000 \ -v /home/docker/jenkins:/var/jenkins_home \ -v /var/run/docker.sock:/var/run/docker.sock \ -v /usr/bin/docker:/usr/bin/docker je…

【双指针】Leetcode 盛最多水的容器

题目解析 11. 盛水最多的容器 木桶效应&#xff0c;寻找一个区间使得这个区间的体积最大 算法讲解 1. 暴力枚举 遍历这个容器&#xff0c;将每一个区间的体积求出来&#xff0c;然后找出最大的 class Solution { public:int maxArea(vector<int>& height){int n…

SQLite数据库文件损坏的可能几种情况(一)

返回&#xff1a;SQLite—系列文章目录 上一篇&#xff1a;SQLiteC/C接口详细介绍sqlite3_stmt类&#xff08;十三&#xff09; 下一篇&#xff1a;SQLite使用的临时文件&#xff08;二&#xff09; 概述 SQLite数据库具有很强的抗损坏能力。如果应用程序崩溃&#xff0c…

如何在内网访问其他电脑?

网络的发展使得人与人之间的通信变得更加便捷&#xff0c;而在公司或者家庭中&#xff0c;也经常遇到需要内网访问其他电脑的需求。内网访问其他电脑可以实现在局域网内部进行数据共享、文件传输、远程控制等操作&#xff0c;提高工作效率和便利性。本文将介绍内网访问其他电脑…

labelImg | windows anaconda安装labelImg

labelImg 是图片标注软件&#xff0c;用于数据集的制作、标注等等。 下面介绍 labelImg 的安装过程。 用的是 anaconda&#xff0c;所以以 anaconda prompt 作为终端&#xff1a; 在 Anaconda Prompt 中依次运行以下命令&#xff08;注意大小写&#xff09;&#xff1a; pi…

评测 r5 8640HS和i5 12500H选哪个 锐龙r58640HS和酷睿i512500H对比

r7 8840HS采用 Zen 4架构 4 nm制作工艺8核 16线程主频 3.3GHz睿频5.1GHz 三 级缓存16MB TDP 功耗 28w 搭载AMD Radeon 780M核显 选r7 8840HS还是i5 12500H这些点很重要 http://www.adiannao.cn/dy i5 12500H为4大核8小核&#xff0c;12核心16线程设计&#xff0c;CPU主频 2.5…

面试知识汇总——垃圾回收器(分代收集算法)

分代收集算法 根据对象的存活周期&#xff0c;把内存分成多个区域&#xff0c;不同区域使用不同的回收算法回收对象。 对象在创建的时候&#xff0c;会先存放到伊甸园。当伊甸园满了之后&#xff0c;就会触发垃圾回收。 这个回收的过程是&#xff1a;把伊甸园中的对象拷贝到F…

Python私有属性和私有方法

私有属性和私有方法 在实际开发中&#xff0c;对象的某些属性或者方法只希望在对象内部被使用&#xff0c;而不希望在外界被访问。 私有属性&#xff1a;对象不希望公开的属性 私有方法&#xff1a;对象不希望公开的方法 定义方式&#xff1a;在属性名或者方法名前添加两个下划…

计算机网络常见题(持续更新中~)

1 描述一下HTTP和HTTPS的区别 2 Cookie和Session有什么区别 3 如果没有Cookie,Session还能进行身份验证吗&#xff1f; 4 BOI,NIO,AIO分别是什么 5 Netty的线程模型是怎么样的 6 Netty是什么&#xff1f;和Tomcat有什么区别&#xff0c;特点是什么&#xff1f; 7 TCP的三次…

基于SpringBoot+MyBatis网上点餐系统

采用技术 基于SpringBootMyBatis网上点餐系统的设计与实现~ 开发语言&#xff1a;Java 数据库&#xff1a;MySQL 技术&#xff1a;SpringBootMyBatis 工具&#xff1a;IDEA/Ecilpse、Navicat、Maven 页面展示效果 功能列表 前台首页功能 用户注册 用户登录 用户功能 …