《机器学习公式推导与代码实现》chapter5-线性判别分析LDA

news2024/11/27 22:24:41

《机器学习公式推导与代码实现》学习笔记,记录一下自己的学习过程,详细的内容请大家购买作者的书籍查阅。

线性判别分析

线性判别分析(linear discriminant analysis, LDA)是一种经典的线性分类方法,其基本思想是将数据投影到低维空间,使得同类数据尽可能接近,异类数据尽可能疏远。

另外线性判别分析能够通过投影来降低样本维度,并且在投影过程中用到了标签信息,线性判别分析是一种监督降维方法
在这里插入图片描述

1 LDA数学推导

在这里插入图片描述
在这里插入图片描述

2 基于numpy的LDA算法实现

完整的LDA算法流程如下:

  • (1)对训练集按类别进行分组;
  • (2)分别计算每组样本的均值和协方差;
  • (3)计算类间散度矩阵 S w S_{w} Sw;
  • (4)计算两类样本的均值差 μ 0 − μ 1 \mu _{0} -\mu_{1} μ0μ1;
  • (5)对类间散度矩阵 S w S_{w} Sw进行奇异值分解,并求逆;
  • (6)根据$S_{w}^{-1} \left ( \mu _{0} -\mu _{1} \right ) 得到 得到 得到w$;
  • (7)计算投影后的数据点 Y = w X Y = wX Y=wX

模型定义:

import numpy as np

class LDA():
    def __init__(self): # 初始化权重矩阵
        self.w = None
    
    def calc_cov(self, X, Y=None): # 计算协方差矩阵
        m = X.shape[0]
        # 将数据缩放到均值为0,标准差为1的标准正态分布
        X = (X - np.mean(X, axis=0))/np.std(X, axis=0) # 数据标准化
        Y = X if Y == None else (Y - np.mean(Y, axis=0))/np.std(Y, axis=0)
        return 1 / m * np.matmul(X.T, Y)
    
    def fit(self, X, y): # LDA拟合过程
        
        # 按类分组
        X0 = X[y == 0] 
        X1 = X[y == 1]

        # 分别计算两类数据自变量的协方差矩阵
        sigma0 = self.calc_cov(X0)
        sigma1 = self.calc_cov(X1)
        
        Sw = sigma0 + sigma1 # 计算类内散度矩阵

        # 分别计算两类数据自变量的均值和差
        u0, u1 = np.mean(X0, axis=0), np.mean(X1, axis=0)
        mean_diff = np.atleast_1d(u0 - u1) # 如果 u0 - u1 是一个标量,则将其转换为长度为1的一维数组。如果 u0 - u1 是一个数组,则保持不变

        # 利用 SVD 分解,我们可以通过计算矩阵 U、Σ 和 V 来得到原始矩阵的逆矩阵的近似
        U, S, V = np.linalg.svd(Sw) # # 对类内散度矩阵进行奇异值分解
        Sw_ = np.dot(np.dot(V.T, np.linalg.pinv(np.diag(S))), U.T) # 计算类内散度矩阵的逆
        
        self.w = Sw_.dot(mean_diff) # 计算w
    
    def predict(self, X): # LDA分类预测
        y_pred = []
        for sample in X:
            h = sample.dot(self.w)
            y = 1 * (h < 0) # 如果h小于0,则将y设置为1,否则将y设置为0
            y_pred.append(y)
        return y_pred

读取数据:

from sklearn import datasets
from sklearn.model_selection import train_test_split

data = datasets.load_iris()
X = data.data
y = data.target
X = X[y != 2]
y = y[y != 2]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=41)
X_train.shape, X_test.shape, y_train.shape, y_test.shape
((80, 4), (20, 4), (80,), (20,))

数据测试:

lda = LDA()
lda.fit(X_train, y_train)
y_pred = lda.predict(X_test)

from sklearn.metrics import accuracy_score
accuracy = accuracy_score(y_test, y_pred)
print(accuracy)
0.85

结果可视化:

import matplotlib.pyplot as plt

class Plot():
    def __init__(self): 
        self.cmap = plt.get_cmap('viridis')

    def _transform(self, X, dim):
        covariance = LDA().calc_cov(X)
        eigenvalues, eigenvectors = np.linalg.eig(covariance)
        # Sort eigenvalues and eigenvector by largest eigenvalues
        idx = eigenvalues.argsort()[::-1]
        eigenvalues = eigenvalues[idx][:dim]
        eigenvectors = np.atleast_1d(eigenvectors[:, idx])[:, :dim]
        # Project the data onto principal components
        X_transformed = X.dot(eigenvectors)

        return X_transformed

    # Plot the dataset X and the corresponding labels y in 2D using PCA.
    def plot_in_2d(self, X, y=None, title=None, accuracy=None, legend_labels=None):
        X_transformed = self._transform(X, dim=2)
        x1 = X_transformed[:, 0]
        x2 = X_transformed[:, 1]
        class_distr = []

        y = np.array(y).astype(int)

        colors = [self.cmap(i) for i in np.linspace(0, 1, len(np.unique(y)))]

        # Plot the different class distributions
        for i, l in enumerate(np.unique(y)):
            _x1 = x1[y == l]
            _x2 = x2[y == l]
            _y = y[y == l]
            class_distr.append(plt.scatter(_x1, _x2, color=colors[i]))

        # Plot legend
        if not legend_labels is None: 
            plt.legend(class_distr, legend_labels, loc=1)

        # Plot title
        if title:
            if accuracy:
                perc = 100 * accuracy
                plt.suptitle(title)
                plt.title("Accuracy: %.1f%%" % perc, fontsize=10)
            else:
                plt.title(title)

        # Axis labels
        plt.xlabel('class 1')
        plt.ylabel('class 2')

        plt.show()

Plot().plot_in_2d(X_test, y_pred, title="LDA", accuracy=accuracy)
Plot().plot_in_2d(X_test, y_test, title="LDA", accuracy=accuracy)

在这里插入图片描述

3 基于sklearn的LDA算法实现

from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
clf = LinearDiscriminantAnalysis()
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(accuracy)
1.0

笔记本_Github地址

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/742160.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

openGauss学习笔记-06 openGauss 基本概念

文章目录 openGauss学习笔记-06 openGauss 基本概念6.1 数据库&#xff08;Database&#xff09;6.2 数据块&#xff08;Block&#xff09;6.3 行&#xff08;Row&#xff09;6.4 列&#xff08;Cloumn&#xff09;6.5 表&#xff08;Table&#xff09;6.6 数据文件&#xff08…

Opencv之角点 Harris、Shi-Tomasi 检测详解

角点&#xff0c;即图像中某些属性较为突出的像素点 常用的角点有以下几种&#xff1a; 梯度最大值对应的像素点两条直线或者曲线的交点一阶梯度的导数最大值和梯度方向变化率最大的像素点一阶导数值最大&#xff0c;但是二阶导数值为0的像素点 API简介&#xff1a; void c…

Go语言网络编程:HTTP服务端之底层原理与源码分析——http.HandleFunc()、http.ListenAndServe()

一、启动 http 服务 import ("net/http" ) func main() {http.HandleFunc("/ping", func(w http.ResponseWriter, r *http.Request) {w.Write([]byte("ping...ping..."))})http.ListenAndServe(":8999", nil) }在 Golang只需要几行代…

MySQL存储过程和存储函数练习

创建表并插入数据 字段名 数据类型 主键 外键 非空 唯一 自增 id INT 是 否 是 是 否 name VARCHAR(50) 否 否 是 否 否 glass VARCHAR(50) 否 否 是 否 否 sch 表内容 id name glass 1 xiaommg glass 1 2 xiaojun glass 2 1、创建一个可以统计表格内记录条数的存储函数 &#…

耳夹式骨传导耳机哪个牌子好?耳夹骨传导耳机推荐

骨传导耳机品牌越来越多&#xff0c;选择骨传导耳机时可不是一件简单的事&#xff0c;在挑选的时候首先需要考虑到耳机自身的综合性能&#xff0c;以及耳机的配置如何都会影响到我们使用耳机的幸福感&#xff0c;接下来我来给大家挑选几款目前口碑不错的耳夹式骨传导耳机&#…

windows下使用cd命令切换到D盘的方法

windows下使用cd命令切换到D盘的方法 winr输入cmd进入终端

【CANFD详细介绍与CAN区别】

在汽车领域&#xff0c;随着人们对数据传输带宽要求的增加&#xff0c;传统的CAN总线由于带宽的限制难以满足这 种增加的需求。此外为了缩小CAN网络&#xff08;max. 1MBit/s&#xff09;与FlexRay(max.10MBit/s)网络的带宽差距&#xff0c;BOSCH公司推出了CAN FD。 CAN FD&…

基于控制屏障函数的安全关键系统二次规划(适用于ACC)(Matlab代码实现)

目录 &#x1f4a5;1 概述 &#x1f4da;2 运行结果 &#x1f389;3 参考文献 &#x1f468;‍&#x1f4bb;4 Matlab代码 &#x1f4a5;1 概述 基于控制屏障函数的安全关键系统二次规划&#xff08;适用于ACC&#xff09;是一种用于自适应巡航控制&#xff08;ACC&#x…

Servlet文档2

servlet文档2 HttpServletRequest 获取请求头API getMethod()获取请求的方式getRequestURI()获取请求的uri&#xff08;相对路径&#xff09;getRequestURL()获取请求的url&#xff08;绝对路径&#xff09;getRemoteAddr()获取请求的地址getProtocol()获取请求的协议getRem…

Vue3 CSS v-bind 计算和三元运算

官方文档 中指出&#xff1a;CSS 中的 v-bind 支持 JavaScript 表达式&#xff0c;但需要用引号包裹起来&#xff1a; 例子如下&#xff1a; <script lang"ts" setup> const treeContentWidth ref(140); </script><style lang"less" scop…

mschart Label Formart显示数值的格式化

默认这个数值想显示2位小数&#xff0c; 格式化代码如下。 series1.Label "#VAL{###.###}";

字符指针?指针数组?数组指针?《C语言指针进阶第一重奏》

目录 一.字符指针 1.1字符指针的认识 1.2字符指针存放字符串 1.3字符指针的使用 二.指针数组 2.1指针数组的认识 三.数组指针 3.1数组指针的认识 3.2数组名和&数组名的区别 3.3数组指针的使用 3.4数组参数&#xff0c;指针参数 3.5一维数组传参 3.6二维数组传…

如何让Stable Diffusion正确画手(1)-通过embedding模型优化图片质量

都说AI画手画不好手&#xff0c; 看这些是我用stable diffusion生成的图片&#xff0c;小姐姐都很漂亮&#xff0c;但手都千奇百怪&#xff0c;破坏了图片的美感。 其实只需要一个提示词&#xff0c;就能生成正确的手部&#xff0c;看这是我重新生成的效果&#xff0c;每一个小…

【leetcode】面试题 02.01. 移除重复节点 (python + 链表)

题目链接&#xff1a;[leetcode] 面试题 02.01. 移除重复节点 # Definition for singly-linked list. # class ListNode(object): # def __init__(self, x): # self.val x # self.next Noneclass Solution(object):def removeDuplicateNodes(self, he…

MySQL为什么采用B+树作为索引底层数据结构?

索引就像一本书的目录&#xff0c;通过索引可以快速找到我们想要找的内容。那么什么样的数据结构可以用来实现索引呢&#xff1f;我们可能会想到&#xff1a;二叉查找树&#xff0c;平衡搜索树&#xff0c;或者是B树等等一系列的数据结构&#xff0c;那么为什么MySQL最终选择了…

尚硅谷Docker实战教程-笔记12【高级篇,Docker-compose容器编排】

尚硅谷大数据技术-教程-学习路线-笔记汇总表【课程资料下载】视频地址&#xff1a;尚硅谷Docker实战教程&#xff08;docker教程天花板&#xff09;_哔哩哔哩_bilibili 尚硅谷Docker实战教程-笔记01【基础篇&#xff0c;Docker理念简介、官网介绍、平台入门图解、平台架构图解】…

一篇文章搞懂Libevent网络库的原理与应用

1. Libevent介绍 Libevent 是一个用C语言编写的、轻量级的开源高性能事件通知库&#xff0c;主要有以下几个亮点&#xff1a; > - 事件驱动&#xff08; event-driven&#xff09;&#xff0c;高性能; > - 轻量级&#xff0c;专注于网络&#xff1b; > - 源代码相当…

前端(五)——从 Vue.js 到 UniApp:开启一次全新的跨平台开发之旅

&#x1f642;博主&#xff1a;小猫娃来啦 &#x1f642;文章核心&#xff1a;从 Vue.js 到 UniApp&#xff1a;开启一次全新的跨平台开发之旅 文章目录 UniApp和vue.js什么是UniApp&#xff1f;UniApp的写法什么是vue.js&#xff1f;UniApp与vue.js是什么关系&#xff1f; 为什…

Python+Appium+Pytest自动化测试-参数化设置

来自APP Android端自动化测试初学者的笔记&#xff0c;写的不对的地方大家多多指教哦。&#xff08;所有内容均以微博V10.11.2版本作为例子&#xff09; 在自动化测试用例执行过程中&#xff0c;经常出现执行相同的用例&#xff0c;但传入不同的参数&#xff0c;导致我们需要重…

【Redis基础】快速入门

一、初识Redis 1. 认识NoSQL 2. 认识Redis Redis诞生于2009年&#xff0c;全称是Remote Dictionary Server&#xff08;远程词典服务器&#xff09;&#xff0c;是一个基于内存的键值型NoSQL数据库特征 &#xff08;1&#xff09;键值&#xff08;key-value&#xff09;型&am…