机器学习:k近邻

news2025/3/14 1:12:38

所有代码和文档均在golitter/Decoding-ML-Top10: 使用 Python 优雅地实现机器学习十大经典算法。 (github.com),欢迎查看。

K 邻近算法(K-Nearest Neighbors,简称 KNN)是一种经典的机器学习算法,主要用于分类和回归任务。它的核心思想是:给定一个新的数据点,通过查找训练数据中最接近的 K 个邻居,并根据这些邻居的标签来预测新数据点的标签。

KNN 是一种 基于实例的学习(Instance-based learning)算法。在训练阶段,它并不构建显式的模型,而是将训练数据存储起来,在预测阶段计算待预测点与训练集中所有点的距离,然后选择 K 个最近的邻居,根据邻居的标签进行投票或平均来做出预测。

KNN 的优点在于其简单易懂、无需训练过程,并且适用于大多数任务。它能够处理复杂的非线性问题,不依赖数据分布假设,能够很好地适应复杂的决策边界。

然而,KNN 的缺点也很明显。它的计算开销大,因为每次预测都需要计算所有训练数据的距离,导致在大数据集上表现不佳。此外,KNN 需要存储所有训练数据,占用较大的内存空间,并且对异常值敏感,可能会影响预测结果的准确性。

KNN算法步骤:

  1. 选择 K 个邻居的数量,K 值通常是一个奇数,以避免平票的情况。
  2. 计算待预测数据点与训练数据集中每个点的距离。
  3. 根据计算出的距离选择 K 个最接近的点。
  4. 对于分类任务,返回 K 个邻居中最多的类别;对于回归任务,返回 K 个邻居标签的均值。

代码实现

数据处理:使用iris.data数据集,用PCA进行降维。

import numpy as np
import pandas as pd


def pca(X: np.array, n_components: int) -> np.array:
	"""
	PCA 进行降维。
	"""
	# 1. 数据标准化(去均值)
	X_mean = np.mean(X, axis=0)
	X_centered = X - X_mean

	# 2. 计算协方差矩阵
	covariance_matrix = np.cov(X_centered, rowvar=False)

	# 3. 计算特征值和特征向量
	eigenvalues, eigenvectors = np.linalg.eig(covariance_matrix)

	# 4. 按特征值降序排序
	sorted_indices = np.argsort(eigenvalues)[::-1]
	top_eigenvectors = eigenvectors[:, sorted_indices[:n_components]]

	# 5. 投影到新空间
	X_pca = np.dot(X_centered, top_eigenvectors)

	return X_pca


def get_data():
	data = pd.read_csv('iris.csv', header=None)
	# print(data.dtypes)
	unq = data.iloc[:, -1].unique()
	for i, u in enumerate(unq):
		data.iloc[:, -1] = data.iloc[:, -1].apply(lambda x: i if x == u else x)

	# print(data.sample(5))
	xuanze = np.random.choice([True, False], len(data), replace=True, p=[0.8, 0.2])
	train_data = data[xuanze]
	test_data = data[~xuanze]
	train_data = np.array(
		train_data,
		dtype=np.float32,
	)
	test_data = np.array(test_data, dtype=np.float32)
	# 归一化
	train_data[:, :-1] = (train_data[:, :-1] - train_data[:, :-1].mean(axis=0)) / train_data[:, :-1].std(axis=0)
	test_data[:, :-1] = (test_data[:, :-1] - test_data[:, :-1].mean(axis=0)) / test_data[:, :-1].std(axis=0)
	return (
		pca(train_data[:, :-1], 2),
		train_data[:, -1].astype(np.int32),
		pca(test_data[:, :-1], 2),
		test_data[:, -1].astype(np.int32),
	)


if __name__ == '__main__':
	x_train, y_train, x_test, y_test = get_data()
	print(y_train.dtype)
	print(x_test, y_test)
	print(x_train.shape, y_train.shape)

knn过程:

from data_processing import get_data
import numpy as np
import matplotlib.pyplot as plt


def euclidean_distance(x_train: np.array, x_test: np.array) -> np.array:
	"""
	计算欧拉距离
	"""
	return np.sqrt(np.sum((x_train - x_test) ** 2, axis=1))


def knn(k: int, x_train: np.array, y_train: np.array, x_test: np.array) -> np.array:
	"""
	k近邻算法
	"""
	predictions = []
	for test in x_test:
		distances = euclidean_distance(x_train, test)
		nearest_indices = np.argsort(distances)[:k]  # 返回最近的k个点的索引
		nearest_labels = y_train[nearest_indices]  # 返回最近的k个点的标签
		prediction = np.argmax(np.bincount(nearest_labels))  # 返回最近的k个点中出现次数最多的标签
		predictions.append(prediction)
	return np.array(predictions)


def accuracy(predictions: np.array, y_test: np.array) -> float:
	"""
	计算准确率
	"""
	return np.sum(predictions == y_test) / len(y_test)


if __name__ == '__main__':
	k = 5
	x_train, y_train, x_test, y_test = get_data()
	predictions = knn(k, x_train, y_train, x_test)
	acc = accuracy(predictions, y_test)
	print(f'准确率为: {acc * 100:.2f}')

	# 绘制训练数据
	plt.scatter(x_train[:, 0], x_train[:, 1], c=y_train, cmap='viridis', marker='o', label='Train Data', alpha=0.7)

	# 绘制测试数据
	plt.scatter(x_test[:, 0], x_test[:, 1], c=y_test, cmap='coolwarm', marker='x', label='Test Data', alpha=0.7)

	# 绘制预测结果
	plt.scatter(
		x_test[:, 0],
		x_test[:, 1],
		c=predictions,
		cmap='coolwarm',
		marker='.',
		edgecolor='black',
		alpha=0.7,
		label='Predictions',
	)

	# 添加标题和标签
	plt.title('KNN Classification Results')
	plt.xlabel('Feature 1')
	plt.ylabel('Feature 2')
	plt.legend()

	# 显示图形
	plt.show()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2300302.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

讯飞唤醒+VOSK语音识别+DEEPSEEK大模型+讯飞离线合成实现纯离线大模型智能语音问答。

在信息爆炸的时代,智能语音问答系统正以前所未有的速度融入我们的日常生活。然而,随着数据泄露事件的频发,用户对于隐私保护的需求日益增强。想象一下,一个无需联网、即可响应你所有问题的智能助手——这就是纯离线大模型智能语音…

Day4 25/2/17 MON

【一周刷爆LeetCode,算法大神左神(左程云)耗时100天打造算法与数据结构基础到高级全家桶教程,直击BTAJ等一线大厂必问算法面试题真题详解(马士兵)】https://www.bilibili.com/video/BV13g41157hK?p4&v…

HTML【详解】input 标签

input 标签主要用于接收用户的输入,随 type 属性值的不同,变换其具体功能。 通用属性 属性属性值功能name字符串定义输入字段的名称,在表单提交时,服务器通过该名称来获取对应的值disabled布尔值禁用输入框,使其无法被…

Jvascript网页设计案例:通过js实现一款密码强度检测,适用于等保测评整改

本文目录 前言功能预览样式特点总结:1. 整体视觉风格2. 密码输入框设计3. 强度指示条4. 结果文本与原因说明 功能特点总结:1. 密码强度检测2. 实时反馈机制3. 详细原因说明4. 视觉提示5. 交互体验优化 密码强度检测逻辑Html代码Javascript代码 前言 能满…

用React实现一个登录界面

使用React来创建一个简单的登录表单。以下是一个基本的React登录界面示例: 1. 设置React项目 如果你还没有一个React项目,你可以使用Create React App来创建一个。按照之前的步骤安装Create React App,然后创建一个新项目。 2. 创建登录组…

图论:tarjan 算法求解强连通分量

题目描述 有一个 n n n 个点, m m m 条边的有向图,请求出这个图点数大于 1 1 1 的强连通分量个数。 输入格式 第一行为两个整数 n n n 和 m m m。 第二行至 m 1 m1 m1 行,每一行有两个整数 a a a 和 b b b,表示有一条…

Java:单例模式(Singleton Pattern)及实现方式

一、单例模式的概念 单例模式是一种创建型设计模式,确保一个类只有一个实例,并提供一个全局访问点来访问该实例,是 Java 中最简单的设计模式之一。该模式常用于需要全局唯一实例的场景,例如日志记录器、配置管理、线程池、数据库…

Python爬虫实战:股票分时数据抓取与存储 (1)

在金融数据分析中,股票分时数据是投资者和分析师的重要资源。它能够帮助我们了解股票在交易日内的价格波动情况,从而为交易决策提供依据。然而,获取这些数据往往需要借助专业的金融数据平台,其成本较高。幸运的是,通过…

将图片base64编码后,数据转成图片

将图片数据进行base64编码后,可以在浏览器上查看图片,只需在前端加上data:image/png;base64,即可 在线工具: Base64转图片 - 加菲工具

天翼云910B部署DeepSeek蒸馏70B LLaMA模型实践总结

一、项目背景与目标 本文记录在天翼云昇腾910B服务器上部署DeepSeek 70B模型的全过程。该模型是基于LLaMA架构的知识蒸馏版本,模型大小约132GB。 1.1 硬件环境 - 服务器配置:天翼云910B服务器 - NPU:8昇腾910B (每卡64GB显存) - 系统内存&…

Jetson Agx Orin平台preferred_stride调试记录--1924x720图像异常

1.问题描述 硬件: AGX Orin 在Jetpack 5.0.1和Jetpack 5.0.2上测试验证 图像分辨率在1920x720和1024x1920下图像采集正常 但是当采集图像分辨率为1924x720视频时,图像输出异常 像素格式:yuv_uyvy16 gstreamer命令如下 gst-launch-1.0 v4l2src device=/dev/video0 ! …

DeepSeek冲击(含本地化部署实践)

DeepSeek无疑是春节档最火爆的话题,上线不足一月,其全球累计下载量已达4000万,反超ChatGPT成为全球增长最快的AI应用,并且完全开源。那么究竟DeepSeek有什么魔力,能够让大家趋之若鹜,他又将怎样改变世界AI格…

CF 144A.Arrival of the General(Java实现)

题目分析 一个n个身高数据,问最高的到最前面,最矮的到最后面的最短交换次数 思路分析 首先,如果数据有重复项,例如示例二中,最矮的数据就是最后一个出现的数据位置,最高的数据就是最先出现的数据位置&…

set的使用(c++)

STL里面已经为我们实现了两种红黑树,一种是存储关键字的set,另一种是存储双关键字的map,今天主要来了解set,无论是set还是map后面都跟一个multi,它们区别是set 不能存相同元素, multiset 可以存相同的元素&…

IDEA单元测试插件 SquareTest 延长试用期权限

SquareTest是一款强大的IDEA单元测试生成插件工具,具体使用方法就不过多介绍了,这里主要介绍变更试用期,方便大家使用 配置信息 我的电脑安装前提配置条件 IntelliJ IDEA 2023.2windows 系统 软件安装 IntelliJ IDEA 直接安装插件Squar…

25/2/17 <嵌入式笔记> 桌宠代码解析

这个寒假跟着做了一个开源的桌宠,我们来解析下代码,加深理解。 代码中有开源作者的名字。可以去B站搜着跟着做。 首先看下main代码 #include "stm32f10x.h" // Device header #include "Delay.h" #include &quo…

油田安全系统:守护能源生命线的坚固壁垒

油田安全系统:不可或缺的能源护盾 在能源领域,油田作为国家重要的能源供应基地,其安全生产的重要性不言而喻。油田安全系统犹如一道坚固的护盾,全方位守护着人员生命、企业财产以及生态环境,是油田平稳运行与可持续发展…

【故障处理】- 执行命令crsctl query crs xxx一直hang

【故障处理】- 执行命令crsctl query crs xxx一直hang 一、概述二、故障处理三、解决方法 一、概述 Oracle RAC环境中,遇到执行crsctl query crs xxx等相关命令不返回任何结果,一直hang在那里。系统下执行命令ps -ef |grep crsctl query crs softwarever…

JMeter工具介绍、元件和组件的介绍

Jmeter功能概要 JDK常用文件目录介绍 Bin目录:存放可执行文件和配置文件 Docs目录:是Jmeter的API文档,用于开发扩展组件 printable_docs目录:用户帮助手册 lib目录:存放JMeter依赖的jar包和用户扩展所依赖的Jar包…