【Pytroch】基于K邻近算法的数据分类预测(Excel可直接替换数据)

news2025/1/13 10:14:20

【Pytroch】基于K邻近算法的数据分类预测(Excel可直接替换数据)

  • 1.模型原理
  • 2.数学公式
  • 3.文件结构
  • 4.Excel数据
  • 5.下载地址
  • 6.完整代码
  • 7.运行结果

1.模型原理

K最近邻(K-Nearest Neighbors,简称KNN)是一种简单但常用的机器学习算法,用于分类和回归问题。它的核心思想是基于已有的训练数据,通过测量样本之间的距离来进行分类预测。在实现KNN算法时,可以使用PyTorch来进行计算和操作。

下面是使用PyTorch实现KNN算法的一般步骤:

  1. 准备数据集:首先,需要准备训练数据集,包括样本特征和对应的标签。

  2. 计算距离:对于每个待预测的样本,计算它与训练数据集中每个样本的距离。常见的距离度量包括欧氏距离、曼哈顿距离等。

  3. 排序与选择:将计算得到的距离按照从小到大的顺序进行排序,并选择距离最近的K个样本。

  4. 投票或平均:对于分类问题,选择K个样本中出现最多的类别作为预测结果;对于回归问题,选择K个样本的标签的平均值作为预测结果。

2.数学公式

当使用K最近邻(KNN)算法进行数据分类预测时,以下是其基本原理的数学描述:

  1. 距离度量:假设我们有一个训练数据集 D D D,其中包含 n n n 个样本。每个样本 x i x_i xi 都有 m m m 个特征,可以表示为 x i = ( x i 1 , x i 2 , … , x i m ) x_i = (x_{i1}, x_{i2}, \ldots, x_{im}) xi=(xi1,xi2,,xim)。对于一个待预测的样本 x new x_{\text{new}} xnew,我们需要计算它与训练集中每个样本的距离。常见的距离度量方式包括欧氏距离(Euclidean Distance)和曼哈顿距离(Manhattan Distance)等:

    • 欧氏距离: d ( x i , x new ) = ∑ j = 1 m ( x i j − x new , j ) 2 d(x_i, x_{\text{new}}) = \sqrt{\sum_{j=1}^m (x_{ij} - x_{\text{new},j})^2} d(xi,xnew)=j=1m(xijxnew,j)2

    • 曼哈顿距离: d ( x i , x new ) = ∑ j = 1 m ∣ x i j − x new , j ∣ d(x_i, x_{\text{new}}) = \sum_{j=1}^m |x_{ij} - x_{\text{new},j}| d(xi,xnew)=j=1mxijxnew,j

  2. 排序与选择:计算完待预测样本与所有训练样本的距离后,我们将距离按照从小到大的顺序排序。然后选择距离最近的 K K K 个训练样本。

  3. 投票或平均:对于分类问题,我们可以统计这 K K K 个样本中每个类别出现的次数,然后选择出现次数最多的类别作为预测结果。这就是所谓的“投票法”:

    • y ^ = argmax c ∑ i = 1 K I ( y i = c ) \hat{y} = \text{argmax}_{c} \sum_{i=1}^{K} I(y_i = c) y^=argmaxci=1KI(yi=c)

    其中, y ^ \hat{y} y^ 是预测的类别, y i y_i yi 是第 i i i 个样本的真实类别, c c c 是类别。

    对于回归问题,我们可以选择 K K K 个样本的标签的平均值作为预测结果。

总结起来,K最近邻算法的基本原理是通过测量样本之间的距离来进行分类预测。对于分类问题,通过投票法确定预测类别;对于回归问题,通过取标签的平均值来预测数值。在实际应用中,需要选择合适的距离度量和适当的 K K K 值,以及进行必要的数据预处理和特征工程。

3.文件结构

在这里插入图片描述

iris.xlsx						% 可替换数据集
Main.py							% 主函数

4.Excel数据

在这里插入图片描述

5.下载地址

- Excle资源下载地址

6.完整代码

import torch
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix
import matplotlib.pyplot as plt


def knn(X_train, y_train, X_test, k=5):
    X_train = torch.tensor(X_train, dtype=torch.float32)
    X_test = torch.tensor(X_test, dtype=torch.float32)
    y_train = torch.tensor(y_train, dtype=torch.long)

    predictions = []

    for i in range(X_test.shape[0]):
        distances = torch.sum((X_train - X_test[i]) ** 2, dim=1)
        _, indices = torch.topk(distances, k, largest=False)  # 获取距离最小的k个邻居的索引
        knn_labels = y_train[indices]
        pred = torch.mode(knn_labels).values  # 投票选出标签
        predictions.append(pred.item())

    return predictions

def plot_confusion_matrix(conf_matrix, classes):
    plt.figure(figsize=(8, 6))
    plt.imshow(conf_matrix, cmap=plt.cm.Blues, interpolation='nearest')
    plt.title("Confusion Matrix")
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes)
    plt.yticks(tick_marks, classes)
    plt.xlabel("Predicted Label")
    plt.ylabel("True Label")
    plt.tight_layout()
    plt.show()

def plot_predictions_vs_true(y_true, y_pred):
    plt.figure(figsize=(10, 6))
    plt.plot(y_true, 'go', label='True Labels')
    plt.plot(y_pred, 'rx', label='Predicted Labels')
    plt.title("True Labels vs Predicted Labels")
    plt.xlabel("Sample Index")
    plt.ylabel("Class Label")
    plt.legend()
    plt.show()

def main():
    # 读取Data.xlsx文件并加载数据
    data = pd.read_excel("iris.xlsx")

    # 划分特征值和标签
    features = data.iloc[:, :-1].values
    labels = data.iloc[:, -1].values

    # 将数据集拆分为训练集和测试集
    X_train, X_test, y_train, y_test = train_test_split(features, labels, test_size=0.2, random_state=42)

    y_pred = knn(X_train, y_train, X_test, k=5)
    accuracy = accuracy_score(y_test, y_pred)
    print("训练集准确率:{:.2%}".format(accuracy))


    conf_matrix = confusion_matrix(y_test, y_pred)
    print("混淆矩阵:")
    print(conf_matrix)

    classes = np.unique(y_test)
    plot_confusion_matrix(conf_matrix, classes)
    plot_predictions_vs_true(y_test, y_pred)

if __name__ == "__main__":
    main()

7.运行结果

在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/872207.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

python编程基础与案例集锦,python编程入门经典

大家好,本文将围绕python编程基础与案例集锦展开说明,python编程入门与案例详解是一个很多人都想弄明白的事情,想搞清楚python入门程序例子需要先了解以下几个事情。 【程序1】 题目:输入一行字符,分别统计出其中英文字…

【c++】七夕快到了却还没对象?手把手教你new一个出来!

前言 本章给大家带来的是C内存管理。在C语言阶段,我们经常使用malloc,calloc,realloc,free进行内存管理。但是,C语言的内存管理存在很多缺陷,会对程序的稳定性和安全性造成影响。 不过,C语言的…

cubemx hal stm32 舵机 可减速 任意位置停止 驱动代码

CubeMX配置 对于 STM32 F407VE 这里的84是来自APB1那路2倍频得到: 代码部分 两个舵机都是180度的 servo.c #include "servo.h" #include "tim.h" #include "stdio.h"__IO uint32_t g_SteerUWT[2] {0}; uint16_t g_SteerDeg[…

开发一个RISC-V上的操作系统(七)—— 硬件定时器(Hardware Timer)

目录 往期文章传送门 一、硬件定时器 硬件实现 软件实现 二、上板测试 往期文章传送门 开发一个RISC-V上的操作系统(一)—— 环境搭建_riscv开发环境_Patarw_Li的博客-CSDN博客 开发一个RISC-V上的操作系统(二)—— 系统引导…

Monkey测试真的靠谱吗?

Monkey测试,顾名思义,就是模拟一只猴子在键盘上乱敲,从而达到测试被测系统的稳定性。Monkey测试,是Android自动化测试的一种手段,Monkey测试本身非常简单,Android SDK 工具支持adb Shell命令,实…

jvm里的内存溢出

目录 堆溢出 虚拟机栈和本地方法栈溢出(栈溢出很少出现) 方法区和运行时常量池溢出 本机内存直接溢出(实际中很少出现、了解即可) 堆溢出 堆溢出:最常见的是大list,list里面有很多元素 堆溢出该怎么解决…

C++红黑树

一、红黑树的概念 红黑树是一种二叉搜索树,在其每个节点上增加一个存储位用于表示节点的颜色,可以是Red或Black 通过对任何一条从根到叶子的路径上的各个节点着色方式的限制,红黑树确保没有一条路径比其他路径长两倍 红黑树的性质&#xff…

SQL | 汇总数据

9-汇总数据 9.1-聚集函数 在实际开发过程中,可能会遇到下面这些情况: 确定大于某个值的有多少行数据,比如游戏排行榜,查询玩家排行多少名。 获取表中某些行的和,比如双十一当天,某个用户总订单价格是多少…

208、仿真-51单片机脉搏心率与心电报警Proteus仿真设计(程序+Proteus仿真+配套资料等)

毕设帮助、开题指导、技术解答(有偿)见文未 目录 一、硬件设计 二、设计功能 三、Proteus仿真图 四、程序源码 资料包括: 需要完整的资料可以点击下面的名片加下我,找我要资源压缩包的百度网盘下载地址及提取码。 方案选择 单片机的选择 方案一&a…

回归预测 | MATLAB实现基于PSO-LSSVM-Adaboost粒子群算法优化最小二乘支持向量机结合AdaBoost多输入单输出回归预测

回归预测 | MATLAB实现基于PSO-LSSVM-Adaboost粒子群算法优化最小二乘支持向量机结合AdaBoost多输入单输出回归预测 目录 回归预测 | MATLAB实现基于PSO-LSSVM-Adaboost粒子群算法优化最小二乘支持向量机结合AdaBoost多输入单输出回归预测预测效果基本介绍模型描述程序设计参考…

MySQL8安装和删除教程 保姆级(Windows)

下载 官网: mysql官网点击Downloads->MySQL Community(GPL) Downloads->MySQL Community Server(或者点击MySQL installer for Windows) Windows下有两种安装方式 在线安装 一般带有 web字样 这个需要联网离线安装 一般没有web字样 安装 下载好之后,版本号可以不一样&…

Web 自动化测试学会这一招,下班至少早一小时

♥ 前 言 大家都知道,我们在通过 Selenium 执行 Web 自动化测试时,每次都需要启动/关闭浏览器,如果是多线程执行还会同时打开多个,比较影响工作的正常进行。那有没有办法可以不用让浏览器的自动化执行干扰我们的工作呢&#xf…

graphab 教程 ——生成廊道

Graphab软件包括图谱创建、基于图谱的连通性计算、分析与推广、制图四个模块。Graphab软件的图谱创建基于栅格数据进行,包括斑块识别和连接建立两个步骤。Graphab 软件可识别的栅格数据格式包括TIFF、ASCI和RST,栅格像元记录数值用于识别斑块类型,识别规则可以选择四邻域或八邻…

B100-技能提升-线程池分布式锁

目录 线程池什么是线程池?为什么用线程池?线程池原理常见四种线程池和自定义线程池 线程池 什么是线程池? 池化技术 为什么用线程池? 1 由于设置最大线程数,防止线程过多而导致系统崩溃。 2 线程复用,不需要频繁创建或销毁…

WebAPIs 第三天

DOM 事件进阶 事件流事件委托其他事件元素尺寸与位置 一.事件流 事件流与两个阶段说明事件捕获事件冒泡阻止冒泡解绑事件 1.1 事件流与两个阶段说明 ① 事件流:指的是事件完整执行过程中的流动路径 ② 事件流分为捕获阶段和冒泡阶段 1.2 事件捕获 从DOM根元素…

opsForHash() 与 opsForValue 请问有什么区别?

&#x1f449;&#xff1a;&#x1f517;官方API参考手册 如图&#xff0c;opsForHash()返回HashOperations<K,HK,HV>但是 opsForValue()返回ValueOperations<K,V>… 区别就是opsForHash的返回值泛型中有K,HK,HV,其中K是Redis指定的某个数据库里面某一个关键字(由…

三分钟带你快速掌握MongoDB数据库和集合基础操作

文章目录 前言一、案例需求二、数据库操作1. 选择和创建数据库2. 数据库的删除 三、集合操作1. 集合的显式创建&#xff08;了解&#xff09;2. 集合的隐式创建3. 集合的删除 总结 前言 为了巩固所学的知识&#xff0c;作者尝试着开始发布一些学习笔记类的博客&#xff0c;方便…

太牛了!国内版ChatDoc企业知识库,直接操作Doc、Docx、PDF、txt等文件

自ChatGPT问世以来&#xff0c;国外就有ChatPDF、ChatDOC等基于文档问答的项目&#xff0c;但是国内还一直处于对话类产品的研发中。 贵州猿创科技研发了基于本地向量模型的ChatDoc知识库系统&#xff0c;可以直接上传Doc、Docx、PDF、txt、网页链接等进行问答。 体验地址&…

【算法篇C++实现】常见查找算法

文章目录 &#x1f680;一、预备知识⛳&#xff08;一&#xff09;查找的定义⛳&#xff08;二&#xff09;数组和索引 &#x1f680;二、二分查找&#x1f680;三、穷举搜索&#x1f680;四、并行搜索⛳&#xff08;一&#xff09;并发的基本概念⛳&#xff08;二&#xff09;…

修改VS Code终端的显示行数

文章目录 前言修改VS Code终端显示行数参考 前言 在我们使用VS Code运行代码的过程中&#xff0c;有时需要再终端中显示很多的运行过程信息或者结果。然而&#xff0c;VS Code的终端默认显示1000行的内容&#xff0c;随着显示内容的增多&#xff0c;之前的内容就丢失了。为了解…