2024最新分别用sklearn和NumPy设计k-近邻法对鸢尾花数据集进行分类(包含详细注解与可视化结果)

news2024/10/9 6:29:59

本文章代码实现以下功能:

利用sklearn设计实现k-近邻法。

利用NumPy设计实现k-近邻法。

将设计的k-近邻法对鸢尾花数据集进行分类,通过准确率来验证所设计算法的正确性,并将分类结果可视化。

评估k取不同值时算法的精度,并通过可视化展示。

sklearn实现

# (1)数据导入,分割数据
# 导入iris数据集
from sklearn.datasets import load_iris
 
# 分割数据模块
from sklearn.model_selection import train_test_split
 
# (2)K最近(KNN,K-Nearest Neighbor)分类算法
from sklearn.neighbors import KNeighborsClassifier
 
# 加载iris数据集
data = load_iris()
# 导入数据和标签
data_X = data.data
data_y = data.target
 
# ———————画图,看第一和第三特征的分布——————————————————
import matplotlib.pyplot as plt
print(data.feature_names)
# print(data.data[:, 0])
# print(data.data[:, 2])
feature_1 = data.data[:, 0] # 设置横坐标标签 代表的是花萼长度
feature_3 = data.data[:, 2]  # 设置纵坐标标签 代表的是花瓣宽度
plt.scatter(feature_1, feature_3)  # 看数据分布
plt.show()
 
# _--------------------150个数据的行索引号0-149------------
plt.scatter(feature_1[:50], feature_3[:50], c='red')  # 第一类
plt.scatter(feature_1[50:100], feature_3[50:100], c='blueviolet')  # 第二类
plt.scatter(feature_1[100:], feature_3[100:], c='darkred')  # 第三类
plt.show()
 
# 分割数据
 
# 将完整数据集的70%作为训练集,30%作为测试集,
# 并使得测试集和训练集中各类别数据的比例与原始数据集比例一致(stratify分层策略),另外可通过设置shuffle=True 提前打乱数据。
X_train, X_test, y_train, y_test = train_test_split(data_X,
                                                    data_y,
                                                    random_state=12,
                                                    stratify=data_y,
                                                    test_size=0.3)
# 建立模型进行训练和预测
 
# 建立模型
knn = KNeighborsClassifier()
# knn=KNeighborsClassifier(n_neighbors=3)

 
# (3)训练模型
knn.fit(X_train, y_train)
print(knn.score(X_test, y_test))  # 计算模型的准确率
 
# (4)预测模型
y_pred = knn.predict(X_test)
print(y_pred - y_test)
 
# (5)评价— ——用accuracy_score计算准确率— ———————
from sklearn.metrics import accuracy_score
 
print(accuracy_score(y_test, y_pred))  # 也可以算正确率
 
print(accuracy_score(y_test, y_pred, normalize=False))  # 统计测试样本分类的个数
 

# 测试不同的k值
k_range = range(1, 31)  # 测试1到30的k值
accuracy = []  # 用于存储每个k值的准确率

for k in k_range:
    knn = KNeighborsClassifier(n_neighbors=k)
    knn.fit(X_train, y_train)
    y_pred = knn.predict(X_test)
    accuracy.append(accuracy_score(y_test, y_pred))

# 绘制k值与准确率的关系图
plt.figure(figsize=(10, 6))
plt.plot(k_range, accuracy, marker='o', linestyle='-', color='b')
plt.title('KNN Varying number of neighbors')
plt.xlabel('Number of neighbors, k')
plt.ylabel('Accuracy')
plt.xticks(k_range)
plt.grid(True)
plt.show()

# (6)保存和加载模型
import joblib
 
# 用joblib.dump保存模型
joblib.dump(knn, 'iris_KNN.pkl')
# # 用joblib.load加载已保存的模型
knn1 = joblib.load('iris_KNN.pkl')
# #测试读取后的Model
print(knn1.predict(data_X[0:1]))  # 预测第一个数据的类别
y_pred1 = knn1.predict(X_test)
print(y_pred1 - y_test)

# 可视化分类结果
plt.figure(figsize=(8, 6)) #设置图形的大小为8英寸宽和6英寸高。

#绘制实际类别
plt.scatter(feature_1[:50], feature_3[:50], c='red', label='Actual Setosa')
plt.scatter(feature_1[50:100], feature_3[50:100], c='blueviolet', label='Actual Versicolor')
plt.scatter(feature_1[100:], feature_3[100:], c='darkred', label='Actual Virginica')

#绘制预测类别
plt.scatter(X_test[y_pred == 0][:, 0], X_test[y_pred == 0][:, 2], c='lightcoral', marker='x', label='Predicted Setosa')
plt.scatter(X_test[y_pred == 1][:, 0], X_test[y_pred == 1][:, 2], c='lightblue', marker='^', label='Predicted Versicolor')
plt.scatter(X_test[y_pred == 2][:, 0], X_test[y_pred == 2][:, 2], c='pink', marker='s', label='Predicted Virginica')

plt.xlabel('Sepal Length')
plt.ylabel('Petal Width')
plt.title('KNN Classification Result on Iris Dataset')
plt.legend()
plt.grid(True)
plt.show()

实现结果

用Numpy实现

import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt

# 加载iris数据集
data = load_iris()
X = data.data
y = data.target

# 分割数据集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=12, stratify=y)

# 定义欧氏距离函数
def euclidean_distance(x1, x2):
    return np.sqrt(np.sum((x1 - x2) ** 2))

# 实现KNN算法
class KNN:
    def fit(self, X, y):
        self.X_train = X
        self.y_train = y

    def predict(self, X, k=3):
        y_pred = [self._predict(x, k) for x in X]
        return np.array(y_pred)

    def _predict(self, x, k):
        # 计算x与训练集中每个点的距离
        distances = [euclidean_distance(x, x_train) for x_train in self.X_train]
        # 获取k个最近邻的索引
        k_indices = np.argsort(distances)[:k]
        # 获取这些最近邻对应的标签
        k_nearest_labels = [self.y_train[i] for i in k_indices]
        # 通过多数投票确定预测类别
        most_common = np.bincount(k_nearest_labels).argmax()
        return most_common

# 实例化KNN
knn = KNN()

# 训练模型
knn.fit(X_train, y_train)

# 进行预测
y_pred = knn.predict(X_test)

# 计算准确率
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: {accuracy:.2f}")

# 评估不同k值的准确率
k_values = range(1, 31)
accuracies = []

for k in k_values:
    knn = KNN()
    knn.fit(X_train, y_train)
    y_pred = knn.predict(X_test, k=k)
    accuracy = accuracy_score(y_test, y_pred)
    accuracies.append(accuracy)

# 可视化k值与准确率的关系
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.plot(k_values, accuracies, marker='o')
plt.xlabel('Number of Neighbors')
plt.ylabel('Accuracy')
plt.title('KNN Varying number of neighbors')
plt.grid(True)

# 可视化分类结果
plt.subplot(1, 2, 2)
plt.scatter(X_test[:, 2], X_test[:, 3], c=y_pred, cmap=plt.cm.Set1, edgecolor='k')
plt.title('KNN Classification Result')
plt.xlabel('Petal Length')
plt.ylabel('Petal Width')
handles, labels = plt.gca().get_legend_handles_labels()
plt.legend(handles, labels, loc='upper left')
plt.grid(True)

plt.tight_layout()
plt.show()

# 找出最高精度和对应的k值
max_accuracy = max(accuracies)
best_k = k_values[accuracies.index(max_accuracy)]
print(f"The best accuracy is {max_accuracy:.2f} with k = {best_k}")

实现结果

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2198325.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于SpringBoot的轻量级CRM管理系统+搭建教程

运行环境:jdk8 IntelliJ IDEA maven 宝塔面板 技术框架:SpringBoot lombok MyBatis 分页助手 freemarker SpringMVC SpringMail 系统功能: 这是一套轻量级的crm管理系统源码,基于SSM的SpringBoot架构。 这套源码用到很多潮流技术…

清华大模型公开课第二季 | Lecture 2 神经网络与大模型基础 Part 1

本文由readlecture.cn转录总结。ReadLecture专注于音、视频转录与总结,2小时视频,5分钟阅读,加速内容学习与传播。 大纲 引言 课程介绍 主讲人介绍 课程内容概述 神经网络基础知识 神经网络的定义和结构 神经元的基本单元 多维输入和权重…

从《被程序员耽搁的外卖员》看IT就业前景

《被程序员耽搁的外卖员》这部作品乍一看,似乎只是一个轻松幽默的故事,讲述一位外卖员因为学习编程而改变生活轨迹的小故事。然而,它在反映社会现实、揭示IT行业就业前景方面具有诸多启示。本文旨在通过此故事来分析当前IT就业的现状和未来发…

Spring Boot读取resources目录下文件(打成jar可用),并放入Guava缓存

1、文件所在位置&#xff1a; 2、需要Guava依赖&#xff1a; <dependency><groupId>com.google.guava</groupId><artifactId>guava</artifactId><version>23.0</version></dependency>3、启动时就读取放入缓存的代码&#xf…

​Leetcode 746. 使用最小花费爬楼梯​ 入门dp C++实现

问题&#xff1a;Leetcode 746. 使用最小花费爬楼梯 给你一个整数数组 cost &#xff0c;其中 cost[i] 是从楼梯第 i 个台阶向上爬需要支付的费用。一旦你支付此费用&#xff0c;即可选择向上爬一个或者两个台阶。 你可以选择从下标为 0 或下标为 1 的台阶开始爬楼梯。 请你…

Linux源码阅读笔记-以太网驱动分析

驱动框架 Linux 内核网络设备驱动框架分别为四个模块&#xff0c;分别为网络协议借口模块、网络设备接口模块、设备驱动功能模块和网络设备与媒介模块。具体视图如下&#xff1a; 网络协议接口模块&#xff1a;主要功能 网络接口卡接收和发送数据在 Linux 内核当中处理流程如下…

LoRA技术详解---附实战代码

LoRA技术详解—附实战代码 引言 随着大语言模型规模的不断扩大&#xff0c;如何高效地对这些模型进行微调成为了一个重要的技术挑战。Low-Rank Adaptation&#xff08;LoRA&#xff09;技术应运而生&#xff0c;它通过巧妙的低秩分解方法&#xff0c;显著减少了模型微调时需要…

UNIAPP popper气泡弹层【unibest框架下】vue3+typescript

看了下市场的代码&#xff0c;要么写的不怎么好&#xff0c;要么过于复杂。于是把市场的代码下下来了自己改。200行代码撸了个弹出层组件。兼容H5和APP。 功能&#xff1a; 1)只支持上下左右4个方向的弹层不支持侧边靠齐 2)不对屏幕边界适配 3)支持弹层外边点击自动隐藏 4)支持…

重学SpringBoot3-集成Redis(八)之限时任务(延迟队列)

更多SpringBoot3内容请关注我的专栏&#xff1a;《SpringBoot3》 期待您的点赞&#x1f44d;收藏⭐评论✍ 重学SpringBoot3-集成Redis&#xff08;八&#xff09;之限时任务&#xff08;延迟队列&#xff09; 1. 延迟任务的场景2. Redis Sorted Set基本原理3. 使用 Redis Sorte…

粗糙表面的仿真和处理软件

首款基于粗糙表面的仿真和处理软件&#xff0c;该软件具有三种方法&#xff0c;主要是二维数字滤波法&#xff0c;相位频谱法和共轭梯度法。可以分别仿真具有高斯和非高斯分布的粗糙表面&#xff0c;其中非高斯表面利用Johnson转换系统进行变换给定偏度和峰度。对生成的粗糙表面…

Mysql高级篇(下)——数据库备份与恢复

Mysql高级篇&#xff08;下&#xff09;——数据库备份与恢复 一、物理备份与逻辑备份1、物理备份2、逻辑备份3、对比4、总结 二、mysqldump实现逻辑备份1、mysqldump 常用选项2、mysqldump 逻辑备份语法&#xff08;1&#xff09;备份一个数据库&#xff08;2&#xff09;备份…

linux自动挂载tf卡

本人使用的是armbian系统&#xff0c;ssh工具使用的是finalshell&#xff0c;挂载的是一张64G TF卡。 1.查看系统所检测到的磁盘&#xff0c;这里的 sda1检测到的硬盘但是没有被挂载 lsblk //查看信息 2.在根目录新建一个目录tfcard用于挂载硬盘&#xff0c;命令如下&#xf…

【万字长文】Word2Vec计算详解(一)

【万字长文】Word2Vec计算详解&#xff08;一&#xff09; 写在前面 本文用于记录本人学习NLP过程中&#xff0c;学习Word2Vec部分时的详细过程&#xff0c;本文与本人写的其他文章一样&#xff0c;旨在给出Word2Vec模型中的详细计算过程&#xff0c;包括每个模块的计算过程&a…

电商选品/跟卖| 亚马逊商品类爬取

电商跟卖,最重要是了解哪些商品可以卖, 哪些商品不能卖, 为了更好了解商品信息,我们会经常爬取商品类目的信息. 需求 亚马逊类目信息链接爬虫 打开亚马逊类目信息地址 https://www.amazon.com/gp/new-releases/automotive/refzg_bsnr_nav_automotive_0 一直递归下去&#x…

云原生(四十七) | PHP软件安装部署

文章目录 PHP软件安装部署 一、PHP软件部署步骤 二、安装与配置PHP PHP软件安装部署 一、PHP软件部署步骤 第一步&#xff1a;安装 EPEL 仓库 与 Remi仓库 第二步&#xff1a;启用 Remi 仓库 第三步&#xff1a;安装 PHP、PHP-FPM 第四步&#xff1a;启动并开机启用 PH…

10.8 sql语句查询(未知的)

1.查询结果去重 关键字:distinct (放在查询的后面) AC: select distinct university from user_profile 2.查询结果限制返回行数 关键字:limit AC: select device_id from user_profile limit 0,2 3.将查询后的列重新命名 关键字:as AC: select device_id as user_infos…

wildcard使用教程,解决绝大多数普通人的海外支付难题

许多人可能已经注意到,国外的一些先进AI工具对国内用户并不开放。而想要使用这些工具,我们通常会面临两个主要障碍:一是网络访问的限制,二是支付问题。网络问题很容易解决&#xff0c;难的是如何解决在国内充值海外软件。 今天给大家推荐一个工具——wildcard&#xff0c;用它…

【CSS in Depth 2 精译_046】7.1 CSS 响应式设计中的移动端优先设计原则(下)

当前内容所在位置&#xff08;可进入专栏查看其他译好的章节内容&#xff09; 第一章 层叠、优先级与继承&#xff08;已完结&#xff09; 1.1 层叠1.2 继承1.3 特殊值1.4 简写属性1.5 CSS 渐进式增强技术1.6 本章小结 第二章 相对单位&#xff08;已完结&#xff09; 2.1 相对…

StoryMaker: Towards Holistic Consistent Characters in Text-to-image Generation

https://arxiv.org/pdf/2409.12576v1https://github.com/RedAIGC/StoryMaker 问题引入 针对的是文生图的模型&#xff0c;现在已经有方法可以实现指定人物id的情况下进行生成&#xff0c;但是还没有办法保持包括服装、发型等整体&#xff0c;本文主要解决这个问题&#xff1b…

时间卷积网络(TCN)原理+代码详解

目录 一、TCN原理1.1 因果卷积&#xff08;Causal Convolution&#xff09;1.2 扩张卷积&#xff08;Dilated Convolution&#xff09; 二、代码实现2.1 Chomp1d 模块2.2 TemporalBlock 模块2.3 TemporalConvNet 模块2.4 完整代码示例 参考文献 在理解 TCN 的原理之前&#xff…