「Python」机器学习之线性判别分析(代码,不调包)

news2025/1/11 23:49:40

「Python」机器学习之线性判别分析(代码,不调包)

  • 前言
  • 1 线性判别分析(LDA)
  • 2 实现
    • 2.1 LDA实现
    • 2.2 数据集示例
  • 3 最后


前言

  1. 语言:python
  2. 库:numpy, matplotlib
  3. 教材参考:《机器学习》——周志华2016版(“西瓜书”)
  4. 平台版本:Linux 6.4.7-arch1-3,python10

在学习了机器学习的相关内容后,决定自己尝试实现,并记录下来,便于以后重温和回顾。

由于 python 在数学计算上的优势,所以使用 python以简化相关的数学运算操作,并尽量不采用现有的机器学习库,专注于算法实现。

1 线性判别分析(LDA)

注:内容来自周志华《机器学习》2016版第3章第4节
机器学习(Machine Learning)的任务主要有 拟合和分类 两种,线性判别分析(Linear Discriminant Analysis, LDA)是其中一种比较老的分类算法,示意图如下:
LDA示意图(图片来自周志华《机器学习》)
假设,数据集的样本可分为两类,分别用“+”和“-”表示,一个样本的类别由参数 x 1 x_1 x1 x 2 x_2 x2决定。要将此数据集进行分类,采用LDA算法:

  1. x 1 x 2 x_1x_2 x1x2平面上画一条直线 ω \omega ω
  2. 将数据点投影到直线 ω \omega ω上;
  3. 找到一条直线 ω \omega ω,令两类数据的投影点中心尽可能远,同类数据的投影点离投影中心尽可能近。

用数学语言描述:

数据集 D = { ( x i , y i ) } i m D = \{(x_i,y_i)\}_i^m D={(xi,yi)}im, y i ∈ { 0 , 1 } y_i \in \{0,1\} yi{0,1}, 令 X i , μ i , Σ i X_i, \mu_i, \Sigma_i Xi,μi,Σi分别表示第 i ∈ { 0 , 1 } i \in \{0,1\} i{0,1}类示例的集合、均值矩阵、协方差矩阵。

  1. 将数据集投影到直线 ω \omega ω上;
  1. 两类样本的中心在直线 ω \omega ω上的投影分别是 ω T μ 0 , ω T μ 1 \omega^T\mu_0, \omega^T\mu_1 ωTμ0,ωTμ1
  1. 两类样本的投影协方差分别为 ω T Σ 0 ω , ω T Σ 1 ω \omega^T\Sigma_0\omega, \omega^T\Sigma_1\omega ωTΣ0ω,ωTΣ1ω
  1. 要找到直线 ω \omega ω,使得两类的投影中心相距尽可能远 ( ∣ ∣ ω T μ 0 − ω T μ 1 ∣ ∣ 2 2 ||\omega^T\mu_0-\omega^T\mu_1||^2_2 ∣∣ωTμ0ωTμ122 尽可能大),类内的投影点离投影中心尽可能近 ( ω T Σ 0 ω + ω T Σ 1 ω \omega^T\Sigma_0\omega+\omega^T\Sigma_1\omega ωTΣ0ω+ωTΣ1ω 尽可能小)。
  1. 于是得到欲最大化目标函数: J = ∣ ∣ ω T μ 0 − ω T μ 1 ∣ ∣ 2 2 ω T Σ 0 ω + ω T Σ 1 ω = ω T ( μ 0 − μ 1 ) ( μ 0 − μ 1 ) T ω ω T ( Σ 0 + Σ 1 ) ω J=\frac{||\omega^T\mu_0-\omega^T\mu_1||^2_2}{\omega^T\Sigma_0\omega+\omega^T\Sigma_1\omega}=\frac{\omega^T(\mu_0-\mu_1)(\mu_0-\mu_1)^T\omega}{\omega^T(\Sigma_0+\Sigma_1)\omega} J=ωTΣ0ω+ωTΣ1ω∣∣ωTμ0ωTμ122=ωT(Σ0+Σ1)ωωT(μ0μ1)(μ0μ1)Tω
  1. 定义“类内散度” S ω = Σ 0 + Σ 1 S_\omega=\Sigma_0+\Sigma_1 Sω=Σ0+Σ1,定义“类间散度” S b = ( μ 0 − μ 1 ) ( μ 0 − μ 1 ) T S_b=(\mu_0-\mu_1)(\mu0-\mu_1)^T Sb=(μ0μ1)(μ0μ1)T
  1. 于是得到: J = ω T S b ω ω T S ω ω J=\frac{\omega^TS_b\omega}{\omega^TS_\omega\omega} J=ωTSωωωTSbω
  1. 由此可采用梯度下降方法训练得出直线斜率 ω \omega ω,除此之外,由于此式有解,可以直接公式求解(拉格朗日乘子法): ω = S ω − 1 ( μ 1 − μ 1 ) \omega=S_\omega^{-1}(\mu_1-\mu_1) ω=Sω1(μ1μ1)

2 实现

能力有限,可能存在错误,谨慎参考。

2.1 LDA实现

# -*- encoding: utf-8 -*-
"""
    @File		:	机器学习之线性判别分析.py
    @Description:	实验线性判别分析,LDA模型
    @Author		:	Daiwu Shen
    @Date		:	2023-07-10 20:24:22
"""
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt


class LinerDiscriminantAnalysis:
    def __init__(self, data: np.ndarray, label: np.ndarray) -> None:
        self.data = data # 数据
        self.label = label # 数据标签(类别)
        self.class_num = np.unique(self.label).size # 类型数
        self.omega = np.ones((self.data.shape[1])) # 斜率
        self.classify_data = [] # 不同类别的数据
        self.dataClassify()# 进行数据分类
        self.project_means = np.zeros(self.class_num) # 各类别的数据均值
        self.means = np.array([[np.mean(self.classify_data[j][:, i]) for i in range(self.classify_data[j].shape[1])] for j in range(len(self.classify_data))]) # 

        self.cov = np.array([np.cov(self.classify_data[i], rowvar=False) for i in range(len(self.classify_data))])

        self.Sw = self.cov[0]
        for i in range(1, self.cov.shape[0]):
            self.Sw += self.cov[i]

        self.Sb = np.zeros((self.means.shape[1], self.means.shape[1]))
        for i in range(self.means.shape[0]):
            for j in range(self.means.shape[0]):
                delta = np.array([self.means[i]-self.means[j]])
                self.Sb += np.dot(delta.T, delta)

        self.j = [np.dot(np.dot(self.omega, self.Sb), self.omega.T) / np.dot(np.dot(self.omega, self.Sw), self.omega.T)]

    def dataClassify(self):
        """ 对数据分为两类 """
        for _ in range(self.class_num):
            self.classify_data.append(np.empty((0, self.data.shape[1])))

        for index in range(self.label.size):
            self.classify_data[self.label[index][0]] = np.append(self.classify_data[self.label[index][0]], [self.data[index]], axis=0)

    def train(self, alpha: float = 0.01, num_iterator: int = 500):
        """ 开始训练 """
        for i in range(num_iterator):
            self.omega += self.objectiveFunc(alpha=alpha)
            if i % 100 == 0:
                print(self.J(), self.omega)

        for i in range(len(self.project_means)):
            self.project_means[i] = np.mean(self.omega.dot(self.classify_data[i].T))

    def objectiveFunc(self, alpha: float):
        """ 更新目标函数 """
        gredient = np.zeros(self.data.shape[1])
        for i in range(self.data.shape[1]):
            # 先算一边没有增加的损失函数
            j1 = self.J()
            # 然后增加omega
            self.omega[i] += alpha
            # 通过增加后的omega计算一边损失函数
            j2 = self.J()
            # 给omega还原
            self.omega[i] -= alpha
            # 如果omega增加后损失函数变大则增大(通过alpha改变增大的幅度)
            gredient[i] = (j2-j1)/alpha
        return gredient

    def J(self):
    	 """损失函数"""
        molecular = self.omega.dot(self.Sb).dot(self.omega.T)
        denominator = self.omega.dot(self.Sw).dot(self.omega.T)
        j = molecular/denominator
        self.j.append(j)
        return j

    def predict(self, x: np.ndarray):
        """ 预测 """
        result = []
        flag = True
        print(self.project_means)
        for i in range(len(x)):
            project = self.omega.dot(x[i])
            for j in range(1, len(self.project_means)):
                if project < (self.project_means[j-1]+self.project_means[j])/2:
                    result.append(j-1)
                    flag = False
                    break

            if flag:
                result.append(self.project_means.size-1)
            flag = True

        return result

    def correctRate(self, label1: np.ndarray, label2: np.ndarray):
        """ 准确率 """
        count = 0
        for i in range(len(label1)):
            if label1[i] == label2[i]:
                count += 1

        return count/len(label1)

    def Lagrange(self):
        """ 拉格朗日乘子法,可直接求解omega """
        """ 奇异值分解法求逆矩阵(优点:稳定) """
        u, sigma, v = np.linalg.svd(self.Sw, full_matrices=False)
        sw_I = np.matmul(v.T/sigma, u.T)

        """ 直接求逆矩阵 """
        # sw_I = np.linalg.inv(self.Sw)

        """ 拉格朗日乘子法:w=Sw_I (mean[0]-mean[1]) """
        """ 求均值的两两差值 """
        means = np.zeros(self.means.shape[1])
        for i in range(self.means.shape[0]):
            means += self.means[i]-self.means[0]

        self.omega = sw_I.dot(means)

        for i in range(len(self.project_means)):
            self.project_means[i] = np.mean(self.omega.dot(self.classify_data[i].T))

2.2 数据集示例

由于 Iris 数据集的标签是 Iris 的种类名称,为方便计算机训练,需要将字符串的种类名称数值化:

def Classify2Number(data: np.ndarray, className: list = []):
    """ 将数据集的标签进行数值化(如用0表示A类数据,1表示B类数据) """
    result = np.empty((0, 1), int)
    for item in data:
        for i in range(len(className)):
            if item[0] in className[i]:
                result = np.append(result, [[i]], axis=0)
    return result

根据LDA的原理,可以知道LDA是可以直接实现简单的多分类的,但是当数据变复杂后效果较差。由于 Iris 数据集有3个种类,数据简单,可以实现三分类。

if __name__ == "__main__":
    data = pd.read_csv("数据集文件路径", header=None) # 读取数据集,无表头则header=None,有标头省略
    data.columns = ["sepal-l", "sepal-w", "petal-l", "petal-w", "class"] # 这里的Iris数据集文件无表头,在这里添加表头
    test = data.sample(frac=0.2) # 分出测试集和训练集
    train = data.drop(test.index)

		# 初始化LDA对象
    LDA = LinerDiscriminantAnalysis(
        train[["sepal-l", "sepal-w", "petal-l", "petal-w"]].values,
        Classify2Number(train[["class"]].values, ["Iris-setosa", "Iris-versicolor", "Iris-virginica"]))
    # 二分类改为如下(将setosa作为一类,将versicolor和virginica共同作为一类)
    # LDA = LinerDiscriminantAnalysis(
    #   train[["sepal-l", "sepal-w", "petal-l", "petal-w"]].values,
    #   Classify2Number(train[["class"]].values, ["Iris-setosa", ["Iris-versicolor", "Iris-virginica"]]))

    """ 梯度下降训练 """
    LDA.train(alpha=0.001)

		""" 拉格朗日乘子法直接求解 """
    # LDA.Lagrange()

    print(LDA.omega)

		""" 绘制损失函数 """
    plt.plot(range(len(LDA.j)), LDA.j)
    plt.show()

    """ 测试,并打印准确率 """
    print(LDA.correctRate(
    	LDA.predict(test[["sepal-l", "sepal-w", "petal-l", "petal-w"]].values), 
    	Classify2Number(test[["class"]].values, ["Iris-setosa", "Iris-versicolor", "Iris-virginica"])[:, 0]))

3 最后

  1. 在训练前应对数据进行应有的初始化处理,如归一化、标准化等,由于 Iris 数据集数据量小,且数据规范,这里没有对数据进行处理。
  2. 在进行二分类时,将setosa作为一类,将versicolor和virginica共同作为一类。这也是做普遍多分类的一种方法:一对其余(OvR),对N个类的数据集,每次将一个类的样例作为正例、所有其他类的样例作为反例来训练 N 个分类器,在测试时若仅有一个分类器预测为正类,则对应的类别标记作为最终分类结果。见周志华《机器学习》2016版第3章第5节

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/931552.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

MySql学习4:多表查询

教程来源 黑马程序员 MySQL数据库入门到精通&#xff0c;从mysql安装到mysql高级、mysql优化全囊括 多表关系 各个表结构之间存在各种关联关系&#xff0c;基本上分为三种&#xff1a;一对多&#xff08;多对一&#xff09;、多对多、一对一 一对多&#xff08;多对一&…

论文阅读 The Power of Tiling for Small Object Detection

The Power of Tiling for Small Object Detection Abstract 基于深度神经网络的技术在目标检测和分类方面表现出色。但这些网络在适应移动平台时可能会降低准确性&#xff0c;因为图像分辨率的增加使问题变得更加困难。在低功耗移动设备上实现实时小物体检测一直是监控应用的…

计网第四章(网络层)(四)

目录 一、IP数据报的发送和转发过程 发送&#xff1a; 1.直接交付和间接交付 如果判断源主机和目的主机是否在同一个网络中&#xff1f; 2.默认网关&#xff1a; 转发&#xff1a; 路由表&#xff1a; 一、IP数据报的发送和转发过程 发送&#xff1a; 由主机发送IP数据…

统信OS国产操作系统身份证读卡器社保卡读卡web网页开发使用操作流程

用于DONSEE系列身份证阅读器谷歌Chrome火狐Firefox插件&#xff0c;支持的型号有&#xff1a;EST-100、EST-100GS、EST-100G、EST-100U、EST-200G、EST-J13X等。 本方案无缝支持最新版本谷歌Chrome火狐Firefox等网页浏览器&#xff0c;支持H5、Vue、React、Node.js、Electron、…

Java“牵手”天猫商品销量API接口数据,天猫API接口申请指南

天猫平台商品销量接口是开放平台提供的一种API接口&#xff0c;通过调用API接口&#xff0c;开发者可以获取天猫商品的标题、价格、库存、月销量、总销量、库存、详情描述、图片等详细信息 。 获取商品销量接口API是一种用于获取电商平台上商品销量数据的接口&#xff0c;通过…

mysql 命令行 执行sql文件

方法1 source source file.sql; file.sql : 绝对路径或 相对路径。 方法2 mysql -u xxx -p < file.sql 方法3 MySQLImport 工具 mysqlimport [options] database file_name 其中&#xff0c;database为要导入数据的数据库名&#xff0c;file_name为要导入的SQL文件名。还可以…

框架分析(5)-Django

框架分析&#xff08;5&#xff09;-Django 专栏介绍Django核心概念以及组件讲解模型&#xff08;Model&#xff09;视图&#xff08;View&#xff09;模板&#xff08;Template&#xff09;路由&#xff08;URLconf&#xff09;表单&#xff08;Form&#xff09;后台管理&…

【动态规划】1137. 第 N 个泰波那契数

Halo&#xff0c;这里是Ppeua。平时主要更新C&#xff0c;数据结构算法&#xff0c;Linux与ROS…感兴趣就关注我bua&#xff01; 文章目录 0. 题目解析1.算法原理1.1 状态表示1.2 状态转移方程1.3初始化1.4 填表顺序1.5 返回值 2.算法代码 &#x1f427; 本篇是整个动态规划的…

RT-Thread 时钟管理

时钟节拍 任何操作系统都需要提供一个时钟节拍&#xff0c;以供系统处理所有和时间有关的事件&#xff0c;如线程的延时、时间片的轮转调度以及定时器超时等。 RTT中&#xff0c;时钟节拍的长度可以根据RT_TICK_PER_SECOND的定义来调整。rtconfig.h配置文件中定义&#xff1a…

软件测试用例经典方法 |一文了解软件测试规范

软件测试规范是测试工作的依据和准则&#xff0c;在进行软件测试时&#xff0c;应在相关国标文件的要求和指导下完成测试工作&#xff0c;这样可以从根本上保证软件测试工作的质量&#xff0c;进而提升软件产品的质量。 一个完整的软件测试规范应该包括对规范本身的详细说明&a…

Python(.pyc)反编译:pycdc工具安装与使用

本文将介绍如何将python的.pyc文件反编译成源码&#xff0c;以便我们对源码的学习与改进。pycdc工具安装 下载地址&#xff1a; 1、Github地址&#xff1a;https://github.com/zrax/pycdc &#xff0c;下载后需要使用CMake进行编译。 2、已下载好及编译好的地址&#xff1a;ht…

Java多线程(十二)

目录 一、多线程环境使用哈希表 1.1 HashTable 1.2 ConcurrentHashTable 二、ConcurrentHashMap和Hashtable、HashMap 的区别 一、多线程环境使用哈希表 HashMap 本身就是线程不安全的&#xff0c;所以在多线程的环境下可以使用&#xff1a;HashTable、 ConcurrentHashMap 1.…

Mysql中explain执行计划信息中字段详解

Mysql中explain执行计划信息中字段详解 1. 获取执行计划2. 字段含义2.1 id2.2 select_type2.3 table2.4 partitions2.5 type2.6 possible_keys2.7 key2.8 ley_len2.9 ref2.10 rows2.11 extra 1. 获取执行计划 explain select * from t1; --或 desc select * from t1;2. 字段含…

Pandas数据分析教程-数据清洗-扩展数据类型

pandas-02-数据清洗&预处理 扩展数据类型1. 传统数据类型缺点2. 扩展的数据类型3. 如何转换类型文中用S代指Series,用Df代指DataFrame 数据清洗是处理大型复杂情况数据必不可少的步骤,这里总结一些数据清洗的常用方法:包括缺失值、重复值、异常值处理,数据类型统计,分…

23款奔驰GLE450轿跑升级原厂外观暗夜套件,战斗感满满的

升级的方案基本都是替换原来车身部位的镀铬件&#xff0c;可能会有人问&#xff1a;“难道直接用改色膜贴黑不好吗&#xff1f;”如果是贴膜的话&#xff0c;第一个是颜色没有那么纯正&#xff0c;这些镀铬件贴黑的技术难度先抛开不说&#xff0c;即使贴上去了&#xff0c;那过…

Mac电脑系统应该用什么软件进行优化清理?

作为一枚资深的Windows系统使用者&#xff0c;小编刚刚转向Mac系统的怀抱时&#xff0c;各种不适应&#xff0c;Windows系统中普遍使用的360清理软件目前暂时没有Mac版本的&#xff0c;这就让小编很是头疼了&#xff0c;大家的Mac都是用的什么清理软件呢&#xff1f; 经过一番…

Notion团队协作魔法:如何玩转数字工作空间?

Notion简介 Notion已经成为现代团队协作的首选工具之一。它不仅仅是一个笔记应用&#xff0c;更是一个强大的团队协作平台&#xff0c;能够满足多种工作场景的需求。 Notion的核心功能 Notion提供了丰富的功能&#xff0c;如文档、数据库、看板、日历等&#xff0c;满足团队的…

USB接口发展历程大全

1996年&#xff0c;由英特尔、微软、ibm等多家公司联合设计的usb标准问世&#xff0c;键盘、鼠标、智能手机以及打印机等等大多使用usb标准来实现供电和数据传输。 usb接口从诞生之初就是为了实现通用这个目的。在usb诞生之前&#xff0c;键盘、鼠标多使用ps二接口&#xff0c…

Doris异常处理

1、decimal 字段异常 修改为 2、连接超时 Caused by: com.mysql.cj.exceptions.CJCommunicationsException: Communications link failure The last packet successfully received from the server was 1,068 milliseconds ago. The last packet sent successfully to the ser…

kali开启SSH服务(简单无比)

1.切换到管理员用户&#xff1a; su root 提示输入root密码 2.启动SSH服务 命令为&#xff1a; /etc/init.d/ssh start 或者 systemctl start ssh 3.查看SSH服务状态是否正常运行&#xff0c;命令为&#xff1a; /etc/init.d/ssh status 图片仅供参考&#xff1a;