鸢尾花分类和手写数字识别(K近邻)

news2025/1/5 14:46:03

鸢尾花分类

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
import pandas as pd
import mglearn

# 加载鸢尾花数据集
iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris.data,
                                                    iris.target,
                                                    test_size=0.3,
                                                    random_state=0)

# 导入数据集
iris_dataframe = pd.DataFrame(X_train, columns=iris.feature_names)
# 绘制散点矩阵图
grr = pd.plotting.scatter_matrix(iris_dataframe, # 要绘制散点矩阵图的特征数据
                                 c=y_train, # 指定颜色映射的依据
                                 figsize=(15, 15),
                                 marker='o',
                                 hist_kwds={'bins': 20}, # 设置直方图的参数,将直方图分为 20 个区间
                                 s=60,
                                 alpha=.8,
                                 cmap=mglearn.cm3) # 设置颜色映射,这里是使用 mglearn.cm3 颜色映射

在这里插入图片描述

from sklearn.neighbors import KNeighborsClassifier

knn = KNeighborsClassifier(n_neighbors = 3)
knn.fit(X_train,y_train)

KNeighborsClassifier(n_neighbors=3)

import numpy as np
y_pred = knn.predict(X_test)
print("Test set predictions:\n{}".format(y_pred))
print("Test set score:{:.2f}".format(np.mean(y_pred == y_test)))
Test set predictions:
[2 1 0 2 0 2 0 1 1 1 2 1 1 1 1 0 1 1 0 0 2 1 0 0 2 0 0 1 1 0 2 1 0 2 2 1 0
 2 1 1 2 0 2 0 0]
Test set score:0.98

手写数字识别

import torch
from torch.utils.data import DataLoader
import torchvision.datasets as dsets
import torchvision.transforms as transforms

#指定每次训练迭代的样本数量
batch_size = 100
transform = transforms.ToTensor()  #将图片转化为PyTorch张量

train_dataset = dsets.MNIST(root='./data',
                            train=True,
                            transform=transforms.ToTensor(),
                            download=False)
test_dataset = dsets.MNIST(root='./data',
                           train=False,
                           transform=transforms.ToTensor(),
                           download=False)
#加载数据
train_loader = DataLoader(dataset=train_dataset,
                          batch_size=batch_size,
                          shuffle=True)
test_loader = DataLoader(dataset=test_dataset,
                         batch_size=batch_size,
                         shuffle=True)
print("train data:",train_dataset.data.size())
print("train labels:",train_dataset.targets.size())
print("test data:",test_dataset.data.size())
print("test labels:",test_dataset.targets.size())
train data: torch.Size([60000, 28, 28])
train labels: torch.Size([60000])
test data: torch.Size([10000, 28, 28])
test labels: torch.Size([10000])
import matplotlib.pyplot as plt

# 看下第99张图是什么
train = train_loader.dataset.data[98]
plt.imshow(train, cmap=plt.cm.binary)
plt.show()
print(train_loader.dataset.targets[98])

在这里插入图片描述

tensor(3)
# 若不进行图像预处理:
import numpy as np
import operator
 
# KNN分类器构建
class KNNClassifier:
    def __init__(self):
        self.Xtr = None
        self.ytr = None
 
    def fit(self, X_train, y_train):
        self.Xtr = X_train
        self.ytr = y_train
 
    def predict(self, k, dis, X_test):
        assert dis == 'E' or dis == 'M'  # E代表欧氏距离, M代表曼哈顿距离。确保变量dis的值必须是'E'或'M',否则会抛出异常
        num_test = X_test.shape[0]
        labellist = []
 
        if dis == 'E':
            for i in range(num_test):
                distances = np.sqrt(np.sum(((self.Xtr - np.tile(X_test[i], (self.Xtr.shape[0],1))) ** 2), axis=1))
                nearest_k = np.argsort(distances)[:k]
                classCount = {self.ytr[i]: 0 for i in nearest_k}
                for i in nearest_k:
                    classCount[self.ytr[i]] += 1
                sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
                labellist.append(sortedClassCount[0][0])
 
        elif dis == 'M':
            for i in range(num_test):
                distances = np.sum(np.abs(self.Xtr - np.tile(X_test[i], (self.Xtr.shape[0], 1))), axis=1)
                nearest_k = np.argsort(distances)[:k]
                classCount = {self.ytr[i]: 0 for i in nearest_k}
                for i in nearest_k:
                    classCount[self.ytr[i]] += 1
                sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
                labellist.append(sortedClassCount[0][0])
 
        return np.array(labellist)

if __name__ == '__main__':
    X_train = train_loader.dataset.data.numpy()  # 转化为numpy
    X_train = X_train.reshape(X_train.shape[0], 28 * 28)
    y_train = train_loader.dataset.targets.numpy()
 
    X_test = test_loader.dataset.data[:1000].numpy()
    X_test = X_test.reshape(X_test.shape[0], 28 * 28)
    y_test = test_loader.dataset.targets[:1000].numpy()
 
    num_test = y_test.shape[0]
 
    knn = KNNClassifier()
    knn.fit(X_train, y_train)
 
    
    y_test_pred = knn.predict(5, 'M', X_test)
 
    num_correct = np.sum(y_test_pred == y_test)
    accuracy = float(num_correct) / num_test
    print('Got %d / %d correct => accuracy: %f' % (num_correct, num_test, accuracy))
Got 368 / 1000 correct => accuracy: 0.368000

KNN算法在计算距离时对特征的尺度非常敏感。如果图像的尺寸或像素值范围(亮度或颜色深度)不统一,可能会导致距离计算偏向于尺度较大的特征。如:将所有图像归一化到同一尺寸和/或将像素值标准化到同一范围(如0到1),可以确保不同的特征对距离的贡献均衡,从而使KNN分类器更公平、更准确。

KNN算法是基于距离度量(如欧氏距离或曼哈顿距离)来确定每个测试点的“邻居”,因此确保所有特征具有相似的尺度是至关重要的。对于KNN来说,由于它对数据的尺度敏感,选择均值方差归一化通常是更好的选择。

# 代码改进:添加一个均值方差归一化函数

def standardize_image(image): #均值方差归一化
    mean = np.mean(image)
    std = np.std(image)
    return (image - mean) / std
 # 图像(归一化)
train = train_loader.dataset.data[:1000].numpy()
digit_01 = train[33]
digit_02 = standardize_image(digit_01)
plt.imshow(digit_01, cmap=plt.cm.binary)
plt.show()
plt.imshow(digit_02, cmap=plt.cm.binary)
plt.show()
print(train_loader.dataset.targets[33])
print("Before standardization: mean = {}, std = {}".format(np.mean(digit_01), np.std(digit_01)))
print("After standardization: mean = {}, std = {}".format(np.mean(digit_02), np.std(digit_02)))

在这里插入图片描述

在这里插入图片描述

tensor(9)
Before standardization: mean = 27.007653061224488, std = 70.88341925375865
After standardization: mean = 5.890979314337566e-17, std = 1.0

虽然归一化前后的图像在视觉上似乎并无明显的差异,但通过打印归一化前后的像素值平均值和标准差可以发现标准化后,数据的平均值为5.890979314337566e-17,接近0(浮点数计算的微小误差,这在数值计算中可以视为0),这是标准化的预期结果,旨在将数据的均值中心化到0;标准差为1.0,确保数据的尺度一致。

if __name__ == '__main__':
    #训练数据
    X_train = train_loader.dataset.data.numpy()  #转化为numpy
    X_train = X_train.reshape(X_train.shape[0], 28 * 28)
    X_train = standardize_image(X_train) #均值方差归一化处理
    y_train = train_loader.dataset.targets.numpy()
 
    #测试数据
    X_test = test_loader.dataset.data[:1000].numpy()
    X_test = X_test.reshape(X_test.shape[0], 28 * 28)
    X_test = standardize_image(X_test)
    y_test = test_loader.dataset.targets[:1000].numpy()
 
    num_test = y_test.shape[0]
 
    knn = KNNClassifier()
    knn.fit(X_train,y_train)
 
   # y_test_pred = kNN_classify(5,'M',X_train,y_train,X_test)
    y_test_pred = knn.predict(5,'M',X_test)
 
    num_correct = np.sum(y_test_pred == y_test)
    accuracy = float(num_correct) / num_test
    print('Got %d / %d correct => accuracy: %f' % (num_correct,num_test,accuracy))
Got 950 / 1000 correct => accuracy: 0.950000

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1798012.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

指挥中心操作台厂家的优势有哪些

指挥中心操作台厂家的优势众多,它们以专业的技术、优质的产品和全面的服务,满足了各行各业对高效、安全、稳定指挥中心的需求。以下将从几个方面详细阐述指挥中心操作台厂家的优势。 指挥中心操作台厂家具备强大的研发实力。这些厂家通常拥有专业的研发团…

作为工程师的我,假装我很忙~(摸鱼软件推荐)

引言 最近IT行业内及(几)精(经)美(内)康(扛),多次内卷,造就了假装勤奋(忙碌)的假象。 为此,我推荐各位技术大佬&#xf…

國際知名榮譽顧問加入台灣分析集團總部,全面升級量子電腦Q系統

近期,國際知名的榮譽顧問正式加入台灣分析集團總部,利用相同的量子數據規格訊息數據庫,進行全方位的系統升級。此次升級後,量子電腦Q系統的精確預測和迅速反應能力提升了3.29%。透過高級的數據處理和技術分析,社群用戶將在瞬息萬變的市場中保持領先地位。 “量子電腦Q系統”由資…

小程序名片怎么生成?AI名片生成器源码系统 为企业店铺创建自己的数字名片

在数字化时代,小程序名片已经成为企业店铺展示自身形象、推广产品和服务的重要工具。分享一个AI名片生成器源码系统春哥AI雷达智能名片小程序系统企业商业运营版,含完整代码包和详细的图文安装部署搭建教程,新手也能轻松使用,源码…

PGL图学习之图游走类metapath2vec模型[系列五]

本项目链接:https://aistudio.baidu.com/aistudio/projectdetail/5009827?contributionType1 有疑问查看原项目 相关项目参考: 关于图计算&图学习的基础知识概览:前置知识点学习(PGL)系列一 https://aistudio.…

【数据结构】使用堆实现 求最小K个数

欢迎浏览高耳机的博客 希望我们彼此都有更好的收获 感谢三连支持! 首先我们会想到,通过建立小根堆,使堆顶元素为数组中的最小元素; 然后使堆顶元素出堆,循环K次; public int[] smallestK2(int[] arr, int …

Error: [WinError 2] 系統找不到指定的檔案

背景及相关说明 由于工作的需要,自己电脑上是多python版本环境,分别是python3.6.8,python3.8.8,python3.9.2,默认的环境是python3.6.8,现在想要安装一下paddleocr进行文字识别,然后打算使用创建…

Docker 容器 mysql 配置主从

1、前提条件 集群的条件下 服务器 172.16.11.195 13316:3306 服务器 172.16.11.196 13317:3306 配置好主数据库和从数据 2、配置主从数据库 2.1使用portainer 来管理容器 建立数据库密码 新增配置文件 # mysql-master.cnf [mysqld] server_id110 log-binmysql-binrela…

2024/6/7 英语每日一段

A recent review study examining a decade of research on technology and sleep found the link is more nuanced than previously thought. “It’s an interaction between a person’s vulnerabilities--and not everyone has these vulnerabilities--and the type of act…

公寓远程抄表系统:智能管理方法新的篇章

1.界定和功能 公寓远程抄表系统是一种前沿的自动化控制,它允许物业管理管理人员在远离现场部位收集和分析公寓里的电力能源使用数据,似水、电、气等。根据集成传感器、物联网产品和云计算,系统能实时检测并记录公寓的能耗状况,大…

Linux C语言:字符数组和字符串

一、字符数组 1、定义 字符数组是元素的数据类型为字符类型的数组 √ char c[10]; √ char ch[3][4]; 2、 字符数组初始化 字符数组的初始化 :√ 逐个字符赋值 3、字符串 C语言中无字符串变量,一般用字符数组处理字符串字符串结束标志&#xff1a…

mysql 数据库datetime 类型,转换为DO里面的long类型后,只剩下年了,没有了月和日

解决方法也简单: 自定义个一个 Date2LongTypeHandler <resultMap id="BeanResult" type="XXXX.XXXXDO"><result column="gmt_create" property="gmtCreate" jdbcType="DATE" javaType="java.lang.Long"…

ts类型声明文件、内置声明文件

1. ts类型声明文件 在ts中以d.ts为后缀的文件就是类型声明文件&#xff0c;主要作用是为js模块提供类型信息支持&#xff0c;从而获得类型提示 1.1 第三方包用ts编写的&#xff0c;会自动生成一个 .d.ts文件&#xff0c;进行类型声明 1.2 有些包不是用ts编写的&#xff0c;在…

Type-C PD芯片,带充电的OTG转接器方案 LDR6500

随着现代社会生活水平的飞速提升&#xff0c;人们的电子设备日益丰富多样。从智能手机、平板电脑到笔记本电脑、智能手表&#xff0c;再到无线耳机、游戏主机如任天堂Switch、索尼PS5等&#xff0c;这些电子设备已经成为了我们生活中不可或缺的一部分。然而&#xff0c;这些设备…

verilog阻塞和非阻塞语法

阻塞和非阻塞是FPGA硬件编程中需要了解的一个概念,绝大部分时候,因为非阻塞的方式更加符合时序逻辑设计的思想,有利于时钟和信号的同步,更加有利于时序收敛,所以除非特殊情况,尽量采用非阻塞方式。 1,非阻塞代码 非阻塞赋值,A和B是同时被赋值的,具体是说在时钟的上升…

mac安装nigix且配置 vue/springboot项目(本地/服务器)

一、mac安装Nigix 1. 查看是否存在 nginx 执行brew search nginx 命令查询要安装的软件是否存在 brew search nginx 2. 安装nginx brew install nginx 3. 查看版本 nginx -v 4. 查看信息 查看ngxin下载的位置以及nginx配置文件存放路径等信息 brew info nginx 下载的存…

鸿蒙OS流转之跨端迁移

前言 流转在HarmonyOS中泛指多设备分布式操作&#xff0c;也是HarmonyOS的亮点之一。流转按体验可以分为跨端迁移和多端协同&#xff0c;这里主要跟大家讲一下如何进行跨端迁移&#xff0c;以及我在项目开发过程中&#xff0c;所遇到的问题与解决方法。 开发步骤 在开发过程…

剪画小程序:图片去除文字,我用它只要10秒!

Hello&#xff0c;大家好呀&#xff01;我是不会画画的小画~ 图片上的文字该如何去除&#xff1f; 在工作或者学习中&#xff0c;我们常常需要处理一些图片文件&#xff0c;比如扫描的文件、 电子文档等。有时候&#xff0c;图片上可能会有文字&#xff0c;这时候需要将图片…

源码!源码!商城源码!如何选择

选择合适的商城源码是电商平台成功运营的关键因素之一。下面将从多个维度分析如何选择适合的商城源码&#xff1a; 安全性与稳定性 安全机制&#xff1a;安全的商城源码能保护用户数据和交易安全&#xff0c;避免信息泄露和被黑风险。 稳定运行&#xff1a;稳定的商城系统可以…

JAVA-LeetCode 热题-第24题:两两交换链表中的节点

思路&#xff1a; 定义三个指针&#xff0c;其中一个临时指针&#xff0c;进行交换两个节点的值&#xff0c;重新给临时指针赋值&#xff0c;移动链表 class Solution {public ListNode swapPairs(ListNode head) {ListNode pre new ListNode(0,head);ListNode temp pre;wh…