机器学习小结之KNN算法

news2025/1/16 21:41:26

文章目录

  • 前言
  • 一、概念
    • 1.1 机器学习基本概念
    • 1.2 k 值
    • 1.3 距离度量
    • 1.4 加权方式
  • 二、实现
    • 2.1 手写实现
    • 2.2 调库 Scikit-learn
    • 2.3 测试自己的数据
  • 三、总结
    • 3.1 分析
    • 3.2 KNN 优缺点
  • 参考

前言

KNN (K-Nearest Neighbor)算法是一种最简单,也是一个很实用的机器学习的算法,在《机器学习实战》这本书中属于第一个介绍的算法。它属于基于实例的有监督学习算法,本身不需要进行训练,不会得到一个概括数据特征的模型,只需要选择合适的参数 K 就可以进行应用。KNN的目标是在训练数据中发现最佳的 K 个近邻,并根据这些近邻的标签来预测新数据的标签。每次使用 KNN 进行预测时,所有的训练数据都会参与计算。

kNN有很多应用场景:

  • 分类问题,同时天然可以处理多分类问题,比如根据音乐的特征,将其归类到不同的类型。
  • 推荐系统,根据用户的历史行为,推荐相似的物品或服务
  • 图像识别,比如人脸识别、车牌识别等

一、概念

1.1 机器学习基本概念

机器学习是人工智能领域中非常重要的一个分支,它可以帮助我们从大量数据中发现规律并做出预测。

机器学习可以分为监督学习、无监督学习和半监督学习三种类型。

  • 监督学习是指在训练数据中已经标注了正确答案,通过这些数据来训练模型,然后对新数据进行预测。
  • 无监督学习是指在训练数据中没有标注正确答案,通过对数据的聚类、降维等操作来发现数据中的规律。
  • 半监督学习则是介于有监督学习和无监督学习之间的一种方法。

下表是对机器学习一些基本概念解释

概念解释备注
分类将数据集分为不同的类别属于监督学习
聚类将数据集分为由类似的对象组成多个类的过程属于无监督学习
回归指预测连续型数值数据属于监督学习
样本集一般指用于训练模型的数据集,一般分为训练集和测试集。在样本集中,每个样本都包含一个或多个特征和一个标签。
特征用于描述样本的属性或特点通常是训练样本集的列,他们是独立测量的结果,多个特征联系在一起共同组成一个训练样本
标签样本所属的类别或结果
模型从训练数据中学习到的规律或模式。在机器学习中,模型可以用于预测新数据的标签或值
梯度指函数在某一点处的变化率。在机器学习中,梯度可以用于以最小化损失函数优化模型参数

1.2 k 值

k值是指在多个邻居中,选择前k个最相似邻居的类别来决定当前样本的类别,通常 k 是不大于20的整数,常选择3或5

1.3 距离度量

距离度量是指在 kNN 算法中用来计算样本之间距离的方法。常用的距离度量有欧氏距离、曼哈顿距离、切比雪夫距离、闵可夫斯基距离等。

  • 欧式距离

    • 二维平面

      d = ( x 1 − x 2 ) 2 + ( y 1 − y 2 ) 2 d = \sqrt{(x_1 - x_2)^2 + (y_1 - y_2)^2} d=(x1x2)2+(y1y2)2

    • n维

      d = ∑ i = 1 n ∣ x i − y i ∣ 2 d=\sqrt{\sum_{i=1}^{n}{\left| x_{i}-y_{i} \right|^{2}}} d=i=1nxiyi2

  • 曼哈顿距离

    d = ∑ i = 1 n ∣ x i − y i ∣ d= \sum_{i=1}^{n}|x_i - y_i| d=i=1nxiyi

  • 切比雪夫距离

    d = m a x ( ∣ x 1 − x 2 ∣ , ∣ y 1 − y 2 ∣ , ⋯   , ∣ x i − y i ∣ ) d= max(|x_1 - x_2|, |y_1 - y_2|, \cdots, |x_i - y_i|) d=max(x1x2,y1y2,,xiyi)

  • 闵可夫斯基距离

    d = ∑ i = 1 n ( ∣ x i − y i ∣ ) p p d = \sqrt[p]{\sum_{i=1}^{n}(|x_i - y_i|)^p} d=pi=1n(xiyi)p

1.4 加权方式

KNN 算法中的加权方式指的是在计算距离时,对不同距离的样本使用不同的权重。这些权重可以是距离样本数据源的距离,也可以是不同样本之间的距离。加权的方式可以根据实际情况进行选择,以达到更好的分类或预测效果。

常用的数值数据加权方式如下:

  1. 加权平均值:将K个邻居的属性值加权平均后作为新数据点的预测值。
  2. 均值法:将K个邻居的属性值取平均值后作为新数据点的预测值。
  3. 最差值:将K个邻居的属性值取最小值和最大值,再取平均值作为新数据点的预测值。

常见的离散型数据加权方式如下:

  1. 反函数
  2. 高斯函数
  3. 多项式函数

不同的加权方式可以根据实际情况选择,以达到更好的分类或预测效果。

二、实现

手写数字数据集 为 《机器学习实战》第二章 提供的数据集:https://github.com/pbharrin/machinelearninginaction

2.1 手写实现

import numpy as np
from collections import Counter
import operator
import math
from os import listdir

# inX 输入向量
# dataSet 训练集
# labels 训练集所代表的标签
# k 最近邻居数目
# output: label
def classify0(inX, dataSet, labels, k):
    sortedDistIndicies=euclideanDistance(inX, dataSet)
    classCount = {}
    for i in range(k):
        voteIlabel = labels[sortedDistIndicies[i]]
        classCount[voteIlabel] = classCount.get(voteIlabel, 0) +  1.0 * weight(sortedDistIndicies[i])
    sortedClassCount = sorted(classCount.items(), key = operator.itemgetter(1), reverse = True)

    return sortedClassCount[0][0]

def euclideanDistance(inX, dataSet):
    dataSetSize = dataSet.shape[0]
    diffMat = np.tile(inX, (dataSetSize, 1)) - dataSet
    sqDiffMat = diffMat ** 2
    sqDistances = sqDiffMat.sum(axis = 1)
    distances = sqDistances ** 0.5
    sortedDistIndicies = distances.argsort()
    return sortedDistIndicies

def weight(dist):
    return 1


def classify1(test, train, trainLabel, k):
    distances = []
    for i in range(len(train)):
        distance = np.sqrt(np.sum(np.square(test - train[i, :])))
        distances.append([distance, i])
    distances = sorted(distances)
    targets = [trainLabel[distances[i][1]] for i in range(k)]
    return Counter(targets).most_common(1)[0][0]

def img2vector(filename):
    returnVect = np.zeros((1,1024))
    fr = open(filename)
    for i in range(32):
        lineStr = fr.readline()
        for j in range(32):
            returnVect[0,32*i+j] = int(lineStr[j])
    return returnVect

def handWritingDataSet(inputDir):
    hwLabels = []
    fileNames = []
    dataFileList = listdir(inputDir)           
    m = len(dataFileList)
    dataMat = np.zeros((m,1024))
    for i in range(m):
        fileNameStr = dataFileList[i]
        fileStr = fileNameStr.split('.')[0]     
        classNumStr = int(fileStr.split('_')[0])
        hwLabels.append(classNumStr)
        fileNames.append(fileStr)
        dataMat[i,:] = img2vector( inputDir + '/%s' % fileNameStr)
    return dataMat,hwLabels,fileNames

trainMat, trainLabels, _ = handWritingDataSet('digits/trainingDigits/')
testMat, testLabels,testFileNames = handWritingDataSet('digits/testDigits/')

errorCount = 0
k = 3
for idx, testData in enumerate(testMat):
    prefictLabel = classify0(testData, trainMat, trainLabels, k)
    # prefictLabel = classify1(testData, trainMat, trainLabels, k)
    if testLabels[idx] != prefictLabel:
        errorCount+=1
        print("错误数据:%s.txt, 预测数字:%d" % (testFileNames[idx], prefictLabel))
print("k值:%d, 错误数量:%d, 错误率:%.3f%%" %(k, errorCount, errorCount / 1.0 / np.size(testMat, 0) * 100))

knn_classfi

2.2 调库 Scikit-learn

文档地址:https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KNeighborsClassifier.html#sklearn.neighbors.KNeighborsClassifier

from sklearn.neighbors import KNeighborsClassifier

trainMat, trainLabels, _ = handWritingDataSet('digits/trainingDigits/')
testMat, testLabels,testFileNames = handWritingDataSet('digits/testDigits/')

errorCount = 0
k = 3

neigh = KNeighborsClassifier(n_neighbors=k)
neigh.fit(trainMat, trainLabels)

for idx, testData in enumerate(testMat):
    prefictLabel = neigh.predict([testData])
    if testLabels[idx] != prefictLabel:
        errorCount+=1
        print("错误数据:%s.txt, 预测数字:%d" % (testFileNames[idx], prefictLabel))
print("k值:%d, 错误数量:%d, 错误率:%.3f%%" %(k, errorCount, errorCount / 1.0 / np.size(testMat, 0) * 100))

knn_scikit-learn

2.3 测试自己的数据

上面的手写数据集,训练集有1934个,测试集有946个,都是32x32的图片转的文本。如果想测试自己的手写数字,那就需要将手写数字图片先转成32x32像素格式的图片,然后再转成文本,下面是一个图片转文本代码

import cv2
import os

def img2txt(inputDir):
    dataFileList = os.listdir(inputDir)

    for file in dataFileList:
        if not file.endswith('png'):
            continue
        img = cv2.imread(inputDir + file, cv2.IMREAD_GRAYSCALE)
        fr = open(inputDir + file.split('.')[0] + '.txt', 'w')
        height, width = img.shape[0:2]

        for row in range(height):
            line = ''
            for col in range(width):
                if img[row, col] > 250:
                    line+='0'
                else:
                    line+='1'
            fr.write(line)
            fr.write('\n')

        fr.close()
if __name__ == '__main__':
    img2txt('img/')

下面准备自己手写的0-9 十个数字进行测试,下面数字是用windows画图工具,先裁剪为32x32像素,再用鼠标手写实现。

self-digits

将10个数字转成文本进行测试,结果错误率在30%

self_knn

三、总结

3.1 分析

  • 识别测试集手写数字时,总是有一些样本不能正确识别,通过观察发现是因为与其他类别特征比较接近
  • 使用自身手写数字识别,识别错误的样本并不是因为相类似,例如4被识别为7,这个不太明白,可能与样本特征有关

3.2 KNN 优缺点

  • 优点
    1. 思想简单,理论成熟,既可以用来做分类也可以用来做回归
    2. 由于选择距离最近的 k 个,对异常值不敏感
  • 缺点
    1. KNN 需要计算每个测试样本与所有训练样本之间的距离,时间复杂度很高,计算成本也很高
    2. 无法给出数据任何的基础结构信息
    3. 算法比较简单,当训练数据较小时,对于一些很相似的不同类数据很难区分

参考

  1. https://github.com/pbharrin/machinelearninginaction

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/487758.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

VLAD Diffusion,一个更好用且易于安装的Stable Diffusion Web UI

VLAD Diffusion 是我们前面介绍过的 AUTOMATIC1111/stable-diffusion-webui的一个定制的更新,它主要是为了更频繁发布的更新和错误修复。它包含 新的安装程序,并且提供了高级CUDA调优不在依赖Accelerate,因为Accelerate是分布式的&#xff0…

setTimeout不准时,CSS精准实现计时器功能

实际开发过程中,我们会经常遇到,首次进入页面进行相应提示,然后指定时间后自动消失或者前端时钟展示等需求。 按照传统方案,我们可以使用 setTimeout 实现。但其存在:实际延时比设定值更久的情况。 setTimeout 不准时…

单个案例奖金2000元!AidLux AI 应用案例悬赏征集活动第二期选题上线啦

AidLux AI 应用案例悬赏征集活动第一期开发者作品新鲜"出炉"啦! 得益于AidLux在AI应用部署端的极大优势,开发者们在短时间内轻松落地了大批AI应用。 其中,不乏后厨老鼠识别告警系统、粮食作物特定病虫害告警系统、基于视觉的仰卧起…

专注主业、管控风险,中国春来的“非激进式扩张”

近日,中国春来发布截至2023年2月28日止六个月的中期业绩公告,期内收入同比增长14.2%至7.49亿元,利润同比上涨32%至3.31亿元,交出了亮眼的成绩单。 探究中国春来业绩上涨的原因,关键在于扩大招生。而招生规模很大程度上…

ChatGPT终于被我问到胡说八道的程度了!

问:Python是强类型语言,还是弱类型语言 chatgpt:Python是强类型语言。Python很少会隐式地转换变量的类型,所以Python是强类型的语言 问:什么是强类型语言 chatgpt:强类型语言是指在编程语言中&#xff0…

自动控制原理笔记-频率响应法-系统的开环频率特性图的绘制

目录 一、系统的开环对数频率特性图(Bode图) 绘制方法I:(各环节的Bode图求和) 绘制方法II:(不求和,直接绘图) 二、系统的开环幅相特性图(Nyquist图、极坐标…

Linux网络编程——网络基础[1]

目录 1.网络发展 2.初识协议 2.1协议分层 2.2OSI七层模型 2.3TCP/IP四层(五层)模型 3.网络传输的基本流程 3.1协议报头 3.2局域网通信原理 3.3广域网通信原理 3.4数据包的封装和分用 4.网络中的地址管理 1.网络发展 计算机是帮助人解决计算问题的,而人…

实在智能出席第六届数字中国建设峰会,入围2022年信息技术应用创新优秀解决方案榜单

最美榕城四月天,山海之间尽显数字澎湃。这一周来,实在智能来到了“有福之州”,为数字中国建设增添实在色彩。 4月25日,实在华夏行抵达福州站,与众多生态合作伙伴携手共话数字发展新未来; 4月26日&#xff…

在DARTS空间中进行神经架构搜索(NAS)

前言 神经架构搜索(NAS):自动化设计高性能深度神经网络架构的技术神经架构搜索任务主要有三个关键组成部分,即: 模型搜索空间,定义了一个要探索的模型的集合一个合适的策略作为探索这个模型空间的方法一个模型评估器,…

全景丨0基础学习VR全景制作,平台篇第15章:热点功能-音图文

大家好,欢迎观看蛙色VR官方——后台使用系列课程! 功能说明 应用场景 热点,指在全景作品中添加各种类型图标的按钮,引导用户通过按钮产生更多的交互,增加用户的多元化体验。 音图文热点,即音频、图片、文字…

如何将redis部署在linux操作系统中:(十分详细的步骤)

一:通过虚拟机安装一个linux环境 注意:安装一个带有可视化界面的环境 将指标选中install centos7 按enter键 选择自己需要的语言 选中gui:桌面(可视化界面) 只需要配置软件设置即可,其他的则进行默认配置进行 root用…

【stm32疑难杂症】:Error: L6218E: Undefined symbol TIM_Cmd (referred from timer.o).

项目场景: 在使用工程是发现问题: ..\OBJ\OLED.axf: Error: L6218E: Undefined symbol TIM_Cmd (referred from timer.o). ..\OBJ\OLED.axf: Error: L6218E: Undefined symbol TIM_ITConfig (referred from timer.o). ..\OBJ\OLED.axf: Error: L6218E: …

少儿编程scratch

目录 少儿编程scratch 第一课 孙悟空72变 说绕口令的小猫 欢乐音乐会 海底世界 多变的章鱼哥 益虫与害虫 猫抓老鼠 监控报警器 神奇的画笔 小蝙蝠逃生记 森林里的体育课 寻找小狗哈哈 我是小小饲养员 青蛙王子 少儿编程scratch 第一课 需求描述:scratch的…

安卓缓存那些事情面试,一篇全部搞定

安卓缓存那些事情面试,一篇全部搞定 安卓缓存机制LruCache算法手写Bitmap的三级缓存一.为什么Bitmap三级缓存?二.原理三.代码 Bitmap的二次采样和质量压缩一.为什么二次采样二.哪二次采样三.代码:网络请求图片进行尺寸压缩四.质量压缩1.方法介绍2.案例&a…

携手中国电信打造 5G 智慧机场, ALVA Systems 创新 AR 应用闪耀云生态成果展

4 月 26 日,由国家网信办、国家发改委、科技部、工信部、国务院国资委、福建省人民政府共同主办,福州市人民政府等有关单位承办的第六届数字中国建设峰会数字福州暨生态大会在福州举办。 作为数字中国建设主力军之一,中国电信天翼云重磅亮相&…

Shell脚本编程入门--Day1

文章目录 什么是shell?变量环境变量的设置和显示特殊变量特殊的状态变量 什么是shell? 从技术角度,Shell的最简单定义:命令行解释器(command Interpreter)主要包含: 1, 将使用者的命令翻译给核…

COS 压测指南

COSBench 简介 COSBench 是一款由 Intel 开源,用于对象存储的压测工具。腾讯云对象存储(Cloud Object Storage,COS)作为兼容 S3 协议的对象存储系统,可使用该工具进行读写性能压测。 系统环境 工具推荐运行在 CentO…

PyCharm 下载安装教程(中文语言包)

文章目录 下载安装简单创建项目中文语言包 Py Charm是由JetBrains打造的一款Python IDE(Integrated Development Environment,集成开发环境) 下载 点击链接进入官网:https://www.jetbrains.com/pycharm/download/#sectionwindows …

PM864AK01-eA一极用于直流电压电平,地面是用于海底/地下电缆的永久返回路径

​ PM864AK01-eA一极用于直流电压电平,地面是用于海底/地下电缆的永久返回路径 高压直流输电 电力以交流电的形式产生和传输,但对于长距离传输,会产生很大的损耗,或者在两个交流系统无法同步的情况下。所以我们可以用直流输电的方…

【Java EE】-HTTP请求构造以及HTTPS的加密流程

作者:学Java的冬瓜 博客主页:☀冬瓜的主页🌙 专栏:【JavaEE】 分享: 在满园弥漫的沉静的光芒之前,一个人更容易看到时间,并看到自己的身影。——史铁生《我与地坛》 主要内容:构造http请求&…