学习使用Sklearn【LDA】线性判别分析,对iris数据分类!

news2025/1/5 9:33:49

数据集、代码均来自kaggle。地址:https://www.kaggle.com/datasets/himanshunakrani/iris-dataset?resource=download

🚀 揭示线性分类器的力量:线性判别分析的探索

欢迎来到线性分类器的世界和线性判别分析(LDA)的迷人领域!🌟在本笔记本中,我们将开始一场激动人心的冒险,揭开这些强大算法的内部工作原理,这些算法构成了许多机器学习应用程序的支柱。

线性分类器是机器学习领域的基本工具,是二元分类、多类分类等任务的基石。但你有没有想过,在这些看似简单却非常有效的模型背后发生了什么?🧐这就是线性判别分析发挥作用的地方。

和我一起深入研究线性分类器背后的原理,揭示线性判别分析背后的魔力。🎩准备好解开谜团,获得见解,并扩展您对机器学习中这些基本技术的理解!

所以,废话不多说,让我们一起深入了解线性分类器的秘密吧!💡

线性分类器

从本质上讲,简单的线性分类器的目标是找到一个区分特征空间中不同类的决策边界。在数学上,对于二值分类,这个边界表示为一个超平面:

w 0 + w 1 x 1 + w 2 x 2 + ⋯ + w m x m = 0 w_0 + w_1x_1 + w_2x_2 + \dots + w_mx_m = 0 w0+w1x1+w2x2++wmxm=0

其中, ( w 0 , w 1 , … , w m ) (w_0, w_1, \dots, w_m) (w0,w1,,wm) 被称为权重, ( x 1 , x 2 , … , x m ) (x_1, x_2, \dots, x_m) (x1,x2,,xm) 被称为特征.

在整个旅程中,我们将深入研究线性分类器的数学,探索优化,损失函数和梯度下降。🎓让我们一起深入了解线性分类器的简单性和强大功能!💫

理解线性判别分析(LDA)

线性判别分析(LDA)是一种用于降维和分类的强大技术。📊与简单的线性分类器不同,LDA考虑来自不同类别的数据点的分布来寻找最优决策边界。

在其核心,LDA寻求最大化类之间的分离,同时最小化每个类内的方差。在数学上,LDA的目标是找到一个最大类间分散和最小类内分散的投影。

给定一组 ( N ) (N) (N)数据点, ( m ) (m) (m)特征和 ( K ) (K) (K)类,LDA通过最大化以下准则来计算最优投影矩阵 ( W ) (W) (W):

J ( W ) = Tr ( W T S B W ) Tr ( W T S W W ) J(W) = \frac{{\text{Tr}(W^T S_B W)}}{{\text{Tr}(W^T S_W W)}} J(W)=Tr(WTSWW)Tr(WTSBW)

其中, ( S B ) (S_B) (SB)表示类间散点矩阵, ( S W ) (S_W) (SW)表示类内散点矩阵。

最优投影矩阵 ( W ) (W) (W)可通过求解广义特征值问题得到:
S B W = λ S W W S_B W = \lambda S_W W SBW=λSWW

一旦计算出投影矩阵 ( W ) (W) (W), LDA将原始数据投影到这个低维子空间上。然后可以在这个简化的空间中使用简单的线性分类器进行分类。

通过利用数据的统计属性,LDA提供了一种健壮的分类和降维方法。🎓让我们一起探索LDA的复杂性,释放它的潜力!!

导入所需依赖项

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from sklearn.linear_model import Lasso
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
import warnings
warnings.filterwarnings('ignore')

加载数据集

df = pd.read_csv('./iris.csv')
# 查看数据行数和列数
df.shape
(150, 5)
df.head()
sepal_lengthsepal_widthpetal_lengthpetal_widthspecies
05.13.51.40.2setosa
14.93.01.40.2setosa
24.73.21.30.2setosa
34.63.11.50.2setosa
45.03.61.40.2setosa
# 查看类别
df['species'].unique()
array(['setosa', 'versicolor', 'virginica'], dtype=object)
# 查看每一列有多少null值
df.isnull().sum()
sepal_length    0
sepal_width     0
petal_length    0
petal_width     0
species         0
dtype: int64

数据处理

# 为了理解方便,只保留 species 中类别为 setosa、versicolor 的数据
df = df[df['species'].isin(['setosa', 'versicolor'])]
# 将类别为versicolor替换为0,类别为setosa替换为1
df['species'].replace({'versicolor': 0,'setosa': 1}, inplace=True)
# 划分数据集为特征集X,以及目标集y
columns_y = ['species']
X = df.drop(columns=columns_y, inplace=False)
y = df['species']
# 划分数据为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42)

基本线性分类器

class LinearClassifier:
    def __init__(self):
        self.w = None
        self.bias = None
    
    def fit(self, X, y):
        class_0 = X[y == 0]
        class_1 = X[y == 1]
        # 两种方法计算每个类别的均值
        centroid_0 = np.mean(class_0, axis=0) 
        centroid_1 = np.sum(class_1, axis=0) / len(class_1) 
        
        # 求法向量w作为两类的质心之差
        self.w = centroid_1 - centroid_0
        
        # 求截距w0
        dist_to_centroid_0 = np.linalg.norm(centroid_0)
        dist_to_centroid_1 = np.linalg.norm(centroid_1)
        self.bias = 0.5 * (dist_to_centroid_0 - dist_to_centroid_1)
        
    def predict(self, X):
        # 基于线性分选机的分类预测
        predictions = np.dot(X, self.w) + self.bias
        return np.where(predictions >= 0, 1, 0)
classifier = LinearClassifier()

classifier.fit(X_train, y_train)

y_pred = classifier.predict(X_test)

accuracy = np.mean(y_pred == y_test)
print("Accuracy:", accuracy)
Accuracy: 0.44

LDA

df = pd.read_csv('./iris.csv')
df = df[df['species'].isin(['setosa', 'versicolor'])]
df['species'].replace({'versicolor': -1}, inplace=True)
df['species'].replace({'setosa': 1}, inplace=True)
columns_to_drop = ['species']
X = df.drop(columns=columns_to_drop, inplace=False)
y = df['species']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42)
y_test.unique()
array([-1,  1])
class FishersLDA:
    def __init__(self):
        self.w = None
        self.b = None  

    def fit(self, X, y):
        # 按照类别划分数据
        X0 = X[y == -1]
        X1 = X[y == 1]

        # 计算每个类的均值
        mean0 = np.mean(X0, axis=0)
        mean1 = np.mean(X1, axis=0)

        # 计算类内散点矩阵
        Sw = np.dot((X0 - mean0).T, (X0 - mean0)) + np.dot((X1 - mean1).T, (X1 - mean1))

        # 计算Fisher线性判别式
        self.w = np.dot(np.linalg.inv(Sw), mean1 - mean0)
        self.b = - 0.5 * (np.dot(mean0, np.dot(np.linalg.inv(Sw), mean0)) - np.dot(mean1, np.dot(np.linalg.inv(Sw), mean1)))

    def predict(self, X):
        if self.w is None or self.b is None:
            raise Exception("Model not trained yet!")

        # 计算判别函数
        f_x = np.dot(X, self.w) - self.b

        # 基于判别函数的符号进行分类
        y_pred = np.sign(f_x)

        return y_pred.astype(int)
    
    def get_z_projection(self, X):
        if self.w is None or self.b is None:
            raise Exception("Model not trained yet!")

        # 计算判别函数
        f_x = np.dot(X, self.w) - self.b
        
        return f_x
# 实例化和拟合FLDA
flda = FishersLDA()
flda.fit(X_train.values, y_train)
# 对测试集进行预测
y_pred = flda.predict(X_test.values)
y_pred
array([-1, -1, -1,  1,  1,  1,  1, -1,  1,  1,  1,  1, -1,  1, -1,  1, -1,
       -1,  1,  1, -1, -1,  1,  1, -1])
# 计算准确性
accuracy = accuracy_score(y_test, y_pred)
print("Accuracy:", accuracy)

# 其他分类指标
print("Classification Report:")
print(classification_report(y_test, y_pred))
Accuracy: 1.0
Classification Report:
              precision    recall  f1-score   support

          -1       1.00      1.00      1.00        11
           1       1.00      1.00      1.00        14

    accuracy                           1.00        25
   macro avg       1.00      1.00      1.00        25
weighted avg       1.00      1.00      1.00        25

尝试使用两个相关性最大的特征

df = pd.read_csv('./iris.csv')
df = df[df['species'].isin(['setosa', 'versicolor'])]
df['species'].replace({'versicolor': -1}, inplace=True)
df['species'].replace({'setosa': 1}, inplace=True)
df.head()
sepal_lengthsepal_widthpetal_lengthpetal_widthspecies
05.13.51.40.21
14.93.01.40.21
24.73.21.30.21
34.63.11.50.21
45.03.61.40.21
# 计算相关系数
correlation_matrix = df.corr()
correlation_with_target = correlation_matrix['species'].drop('species')
correlation_with_target
sepal_length   -0.728290
sepal_width     0.684019
petal_length   -0.969955
petal_width    -0.960158
Name: species, dtype: float64
# 选择两个绝对相关系数最高的特征
selected_features = correlation_with_target.abs().nlargest(2).index
# 提取特征和目标
X_reduced = df[selected_features]
y = df['species']
X_reduced
petal_lengthpetal_width
01.40.2
11.40.2
21.30.2
31.50.2
41.40.2
.........
954.21.2
964.21.3
974.31.3
983.01.1
994.11.3

100 rows × 2 columns

# 将数据分成训练集和测试集
X_train_reduced, X_test_reduced, y_train, y_test = train_test_split(X_reduced, y, test_size=0.25, random_state=42)
lda = FishersLDA()
lda.fit(X_train_reduced.values, y_train)
y_pred = lda.predict(X_test_reduced.values)
# 评估准确性
accuracy = accuracy_score(y_test, y_pred)
print("Accuracy:", accuracy)
Accuracy: 1.0
z_projection = lda.get_z_projection(X_train_reduced)
# print(z_projection)
ind_pos = y_train.values==1
ind_neg = y_train.values==-1
z_pos = z_projection[ind_pos]
z_neg = z_projection[ind_neg]
hist_p, bin_edges_p = np.histogram(z_pos, bins=10)
hist_n, bin_edges_n = np.histogram(z_neg, bins=10)

# 类分布的直方图
plt.bar(bin_edges_p[:-1], hist_p, color=['blue'], width=0.02)
plt.bar(bin_edges_n[:-1], hist_n, color=['orange'], width=0.02)

plt.xlabel('Class')
plt.ylabel('Frequency')
plt.title('Histogram of Class Distribution')
plt.show()

在这里插入图片描述

from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
clf = LinearDiscriminantAnalysis()
clf.fit(X_train_reduced, y_train)
y_pred_sklearn = clf.predict(X_test_reduced)
# 评估准确性
accuracy = accuracy_score(y_test, y_pred_sklearn)
print("Accuracy:", accuracy)
Accuracy: 1.0

尝试LASSO回归进行特征选择

df = pd.read_csv('./iris.csv')
df = df[df['species'].isin(['setosa', 'versicolor'])]
df['species'].replace({'versicolor': -1}, inplace=True)
df['species'].replace({'setosa': 1}, inplace=True)
columns_to_drop = ['species']
X = df.drop(columns=columns_to_drop, inplace=False)
y = df['species']
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

# 正则化参数(alpha)的不同值
alphas = np.logspace(-4, 2, 100)

coefs = []

for alpha in alphas:
    lasso = Lasso(alpha=alpha)
    lasso.fit(X_scaled, y)
    coefs.append(lasso.coef_)

plt.figure(figsize=(10, 6))

for i in range(4):
    plt.plot(alphas, [coef[i] for coef in coefs], label=f'Feature {i+1}')

plt.xscale('log')
plt.xlabel('Alpha')
plt.ylabel('Coefficient Value')
plt.title('LASSO Regression Coefficients Shrinkage')
plt.legend()
plt.grid(True)
plt.show()

在这里插入图片描述

完整代码和数据

链接: https://pan.baidu.com/s/1igLuxvmivgGbr_8-SfYymg?pwd=mj24 提取码: mj24

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1942022.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

在服务器调用api操作rabbitmq

不同的rabbitmq版本可能api不同,仅做参考,RabbitMQ 3.7.18。同时,我基本没看官方api文档,根据rabbitmq客户端控制台调用接口参数来决定需要什么参数。例如: 1、添加用户 curl -u 用户名:密码 -H “Content-Type: a…

[亲测可用]俄罗斯方块H5-网页小游戏源码-HTML源码

本站的HTML模板资源:所见文章图片即所得,搭建和修改教程请看这篇文章:https://yizhi2024.top/8017.html

Maven 的模块化开发示例

Maven 的模块化开发是一种非常有效的软件开发方式,它允许你将一个大型的项目分割成多个更小、更易于管理的模块(modules)。每个模块都可以独立地构建、测试和运行,这不仅提高了开发效率,也便于团队协作和项目的维护。以…

华为云.云日志服务LTS及其基本使用

云计算 云日志服务LTS及其基本使用 - 文章信息 - Author: 李俊才 (jcLee95) Visit me at CSDN: https://jclee95.blog.csdn.netMy WebSite:http://thispage.tech/Email: 291148484163.com. Shenzhen ChinaAddress of this article:https://blog.csdn.net/qq_28550…

如何给7Z分卷文件设置密码?简单几步给文件加上安全锁

在压缩7Z文件的时候,如果文件比较大,很多小伙伴都会把文件压缩成7Z分卷文件,那想要保护7Z分卷文件,要如何设置密码呢?不清楚的小伙伴,一起来看看吧! 我们可以使用7-Zip解压缩文件,在…

安全的备忘录工具有哪些 安全好用的备忘录

在这个数字化的时代,我们的生活中充斥着各种各样的信息,从工作计划到个人琐事,从账号密码到重要日期,这些信息都需要我们牢记。然而,人的记忆毕竟有限,于是,备忘录工具成为了我们日常生活中不可…

easyExcel和poi的版本对应

easypoi3.0.5对应的poi版本_easypoi和poi版本对应-CSDN博客 https://github.com/alibaba/easyexcel/blob/v3.2.0/pom.xml 解决 java.lang.NoClassDefFoundError: org/apache/poi/POIXMLTypeLoader 报错-CSDN博客 参考这个文档解决的- 引入最佳版本是3.15版本 java.lang.NoClas…

将Excel或CSV文件导入MySQL

数据库信息 版本:mysql-5.7.22 字符集如下 一、将 Excel 文件导入 MySQL,此时 MySQL 中不存在该表。 在数据库中,右键-导入向导

windows下mysql开启慢sql监控

上代码 #开启慢sql监控 SET GLOBAL slow_query_log ON; #设置慢sql日志存储路径 示例 SET GLOBAL slow_query_log_file D:\\javaTools\\mysql-8.0.32-winx64\\mysql-8.0.32-winx64\\slowSql\\slowSql.log; #超时时间 SET GLOBAL long_query_time 10; #查看是否开启慢查询 …

RabbitMQ的学习和模拟实现|GTest测试框架的介绍和简单使用

GTest 项目仓库:https://github.com/ffengc/HareMQ GTest GTest是什么我们需要学习的GTest功能宏断言事件机制 全局测试套件独立测试套件 GTest是什么 GTest是一个跨平台的 C单元测试框架,由google公司发布。gtest是为了在不同平台上为编写C单元测…

数学建模学习(112):FAHP模糊层次分析法

文章目录 一、FAHP方法由来二、模糊层次分析法原理2.1 AHP缺陷2.2 模糊集理论2.3 模糊层次分析法(FAHP)三、模糊层次分析法步骤3.1 问题定义与层次结构建立3.2 构造模糊判断矩阵3.2.1 计算模糊判断矩阵的列和向量3.2.2 计算模糊综合向量3.2.3 计算模糊权重向量3.3 解模糊数3.…

【Python】NumPy简要教程

文章目录 一、简介二、 ndarray 对象三、矩阵拼接四、数值运算4.1 数值选取4.2 单个数组的运算4.21 NumPy定义的常量4.22 单数组运算 4.3 数组之间的运算4.31 常见运算🟢4.32 广播机制:Broadcasting 五、数值类型、类型转换六、文件I/O 一、简介 NumPy …

56 网络层

本节重点 理解网络层的作用,深入理解IP协议的基本原理 对整个TCP/IP协议有系统的理解 对TCP/IP协议体系下的其他重要协议和技术有一定的了解 目录 前置认识ip协议基本概念协议头格式网段划分特殊的ip地址ip地址的数量限制私有ip和公有ip路由路由表生成算法 在复杂…

2024全网最全面及最新且最为详细的网络安全技巧 七之 XSS漏洞典例分析EXP以及 如何防御和修复(2)———— 作者:LJS

目录 8.5 Exploiting XSS with 20 characters limitation(蓝色为翻译)​编辑 Unicode compatibility 20 length limitation problem Taking advantage Next steps 8.6 Intigriti XSS 系列挑战 Writeups 8.6.1 xss challenge 1220 题目概述 思路分析 POC a.有交互 b.无交互 …

Ubuntu22.04安装与卸载nginx

换源 如果是国内的就不用换 中科大的源,由于我这里是Ubuntu,所以我就直接选Ubuntu22.04就行 点击下载,或者你直接复制这个sources.list的内容到linux中的/etc/apt/sources.list也可以,把原来的sources.list备份一下,…

python+pyqt开发海康相机数据采集系统

pythonpyqt开发海康相机数据采集系统 pythonpyqt开发海康相机数据采集系统 1 开发软件功能: 支持搜索相机:Gige相机设备和USB相机设备支持两种触发模式:软件触发和编码器触发支持数据采集过程中图像实时保存支持参数调节和实时预览&#xff…

安装好anaconda,打开jupyter notebook,新建 报500错

解决办法: 打开anaconda prompt 输入 jupyter --version 重新进入jupyter notebook: 可以成功进入进行代码编辑

批量打断相交线——ArcGISpro 解决方法

在数据处理,特别是地理空间数据处理或是任何涉及图形和线条分析的场景中,有时候需要把相交的线全部从交点打断一个常见的需求。这个过程对于后续的分析、编辑、或是可视化展现都至关重要,因为它可以确保每条线都是独立的,避免了因…

VPN概述

什么是VPN? VPN --- 虚拟专用网 --- 是指依靠ISP或者其他NSP或者企业自身,构建的专用的安全的数据通 信网络,只不过,这个专线网络是逻辑上的,而不是物理上的,所以叫做虚拟专用网 VPN诞生的原因是什么? 1&…

Qt实战:专栏内容介绍及目录

1、专栏介绍 Qt相比Visual Studio (VS) 的优势主要体现在跨平台能力、‌丰富的功能、‌高性能、‌现代UI设计、‌社区支持和企业支持等方面。‌ 跨平台能力:‌Qt 允许应用程序在多个操作系统上编译和运行,‌无需为每个平台编写特定的代码,‌…