费舍尔线性分辩分析(Fisher‘s Linear Discriminant Analysis, FLDA)

news2024/10/6 5:53:29

费舍尔线性分辩分析(Fisher’s Linear Discriminant Analysis, FLDA)

目录

  • 费舍尔线性分辩分析(Fisher's Linear Discriminant Analysis, FLDA)
    • 1. 问题描述
    • 2. 二分类情况
    • 3. 多分类情况
    • 4. 代码实现
      • 4.1 二分类情况
      • 4.2 多分类情况
    • 5. 参考资料

1. 问题描述

为解决两个或多个类别的分类问题,大多数机器学习(ML)算法的工作方式相同。

通常,它们采用某种形式的转换来对输入数据进行处理,以降低原始输入维度到一个新的(更小)维度。其目的是将数据投影到新的空间中。然后,在投影后,它们尝试通过找到线性分离来对数据点进行分类。例如,我们有如下数据,
在这里插入图片描述

对数据直接进行线性分类显然不是最佳的方法,但是如果我们将数据投影到一维空间,我们可以找到一个线性分类器,将数据分为两个类别。这就是费舍尔线性判别分析(FLDA)的基本思想。我们将数据做如下操作:

y = x 0 2 + x 1 2 y=x_{0}^2+x_{1}^2 y=x02+x12

其中, x 0 x_{0} x0 x 1 x_{1} x1是原始数据的两个特征。我们可以看到,通过这种方式,我们将数据投影到了一维空间,然后我们可以找到一个线性分类器,将数据分为两个类别。投影后的数据如下图所示:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-uSabBiva-1690806540681)(image-1.png)]

通常,我们要探寻一种将数据从高维向低维度转换的方式,这被称为表征学习(Representation Learning)。深度学习也是表征学习的 一种,但在深度学习中,我们不需要猜测哪种转换会导致数据的最佳表示,算法会自行解决。

但是,请记住,无论是表示学习还是手工特征,模式都是相同的。我们需要以某种方式改变数据,使其更加适用于分类任务。

2. 二分类情况

假设我们有 C 1 C_1 C1 C 2 C_2 C2两个类别的样本,每个样本维度为 D D D,样本数为 n 1 n_1 n1 n 2 n_2 n2。我们的目标是找到一个投影矩阵 W W W,将数据投影到一维空间:

x ^ = W ⊤ x \hat{x}=W^{\top}x x^=Wx

设新样本 x ^ \hat{x} x^的维度为1,则 W W W的维度为 D × 1 D \times 1 D×1

那么,我们该如何寻找 W W W呢?换句话说,我们寻找的 W W W应该符合什么条件呢?

这就是费舍尔线性判别(Fisher’s Linear Discriminant, FLD)发挥作用的地方。

费舍尔提出的想法是最大化一个函数,该函数将在投影后的类均值之间产生大的分离,同时在每个类内部给出小的方差,从而最小化类之间的重叠。

换句话说,FLD选择最大化类别间分离的投影方法。为此,它最大化类别间方差与类别内方差之比。

简而言之,为了将数据投影到更小的维度并避免类别重叠,FLD保持了两个属性:

  • 数据集类别间具有很大的方差。

  • 数据集每个类别内部具有较小的方差。

请注意,较大的类别间方差意味着投影后的类别平均值应该尽可能远离彼此。相反,较小的类别内方差会使投影后的数据点更加接近。

计算每个类别的均值,我们可以得到:

μ 1 = 1 n 1 ∑ x ∈ C 1 x , μ 2 = 1 n 2 ∑ x ∈ C 2 x \mu_{1}=\frac{1}{n_{1}} \sum_{x \in C_{1}} x, \quad \mu_{2}=\frac{1}{n_{2}} \sum_{x \in C_{2}} x μ1=n11xC1x,μ2=n21xC2x

其中, m 1 m_1 m1 m 2 m_2 m2分别是 C 1 C_1 C1 C 2 C_2 C2类的均值。经过投影后,

μ ^ 1 = W ⊤ μ 1 , μ ^ 2 = W ⊤ μ 2 \hat{\mu}_{1}=W^{\top} \mu_{1}, \quad \hat{\mu}_{2}=W^{\top} \mu_{2} μ^1=Wμ1,μ^2=Wμ2

其中, μ ^ 1 \hat{\mu}_{1} μ^1 μ ^ 2 \hat{\mu}_{2} μ^2分别是 C 1 ^ \hat{C_1} C1^ C 2 ^ \hat{C_2} C2^的均值。

我们计算类间方差(between-class variance)得到:

μ ^ 1 − μ ^ 2 = W ⊤ ( μ 1 − μ 2 ) \hat{\mu}_{1} - \hat{\mu}_{2}=W^{\top}\left(\mu_{1}-\mu_{2}\right) μ^1μ^2=W(μ1μ2)

类内方差(within-class variance)为:

s ^ 1 = ∑ i ∈ C 1 ( x ^ i − μ ^ 1 ) 2 , s ^ 2 = ∑ i ∈ C 2 ( x ^ i − μ ^ 2 ) 2 \hat{s}_{1}=\sum_{i\in C_{1}}\left(\hat{x}_{i}-\hat{\mu}_{1}\right)^{2}, \quad \hat{s}_{2}=\sum_{i\in C_{2}}\left(\hat{x}_{i}-\hat{\mu}_{2}\right)^{2} s^1=iC1(x^iμ^1)2,s^2=iC2(x^iμ^2)2

其中, s ^ 1 \hat{s}_{1} s^1 s ^ 2 \hat{s}_{2} s^2分别是 C 1 ^ \hat{C_1} C1^ C 2 ^ \hat{C_2} C2^的方差。

我们的目标是最大化类间方差与类内方差之比:

J ( W ) = ( μ ^ 1 − μ ^ 2 ) 2 s ^ 1 2 + s ^ 2 2 J(W)=\frac{\left(\hat{\mu}_{1}-\hat{\mu}_{2}\right)^{2}}{\hat{s}_{1}^{2}+\hat{s}_{2}^{2}} J(W)=s^12+s^22(μ^1μ^2)2

在此基础上,我们可以对 J ( W ) J(W) J(W)进行进一步的变换处理。

我们定义一些散度(Scatter)的度量如下:

S B = ( μ 1 − μ 2 ) ( μ 1 − μ 2 ) ⊤   S k = ∑ x ∈ C k ( x − μ k ) ( x − μ k ) ⊤   S W = S 1 + S 2 S_{B}=\left(\mu_{1}-\mu_{2}\right)\left(\mu_{1}-\mu_{2}\right)^{\top}\\ \ \\ S_{k}= \sum_{x \in C_{k}}\left(x-\mu_{k}\right)\left(x-\mu_{k}\right)^{\top}\\ \ \\ S_{W}=S_{1}+S_{2} SB=(μ1μ2)(μ1μ2) Sk=xCk(xμk)(xμk) SW=S1+S2

经过一些变换,我们得到:

J ( W ) = W ⊤ S B W W ⊤ S W W   J(W)=\frac{W^{\top} S_{B} W}{W^{\top} S_{W} W} \\ \ \\ J(W)=WSWWWSBW 

我们的目标是最大化 J ( W ) J(W) J(W),我们现在对 J ( W ) J(W) J(W)求导,得到:

∂ J ( W ) ∂ W = ( W ⊤ S W W ) ∂ ( W ⊤ S B W ) ∂ W − ( W ⊤ S B W ) ∂ ( W ⊤ S W W ) ∂ W \frac{\partial J(W)}{\partial W} = (W^{\top}S_{W}W)\frac{\partial(W^{\top}S_{B}W)}{\partial W}-(W^{\top}S_{B}W)\frac{\partial(W^{\top}S_{W}W)}{\partial W} WJ(W)=(WSWW)W(WSBW)(WSBW)W(WSWW)

令上式为0,我们得到:

( W ⊤ S W W ) 2 S B W − ( W ⊤ S B W ) 2 S W W = 0   ( W ⊤ S W W ) S B W = ( W ⊤ S B W ) S W W (W^{\top}S_{W}W)2S_{B}W-(W^{\top}S_{B}W)2S_{W}W=0 \\ \ \\ (W^{\top}S_{W}W)S_{B}W=(W^{\top}S_{B}W)S_{W}W (WSWW)2SBW(WSBW)2SWW=0 (WSWW)SBW=(WSBW)SWW

由于投影操作,我们只关心 W W W的方向,上面的式子,可以去掉 ( W ⊤ S B W ) , ( W ⊤ S W W ) (W^{\top}S_{B}W),(W^{\top}S_{W}W) (WSBW),(WSWW),根据 S B S_{B} SB的定义, S B W S_BW SBW的方向与 ( μ 1 − μ 2 ) (\mu_{1}−\mu_{2}) (μ1μ2)一致,我们可以得到:
W ∝ S W − 1 ( μ 2 − μ 1 ) W∝S_{W}^{-1}(\mu_{2}−\mu_{1}) WSW1(μ2μ1)

3. 多分类情况

由于投影不再是一个标量,我们这里假设维度为 D ′ D^{\prime} D,我们使用Scatter矩阵的行列式来获得一个标量目标函数:


J ( W ) = ∣ W ⊤ S B W ∣ ∣ W ⊤ S W W ∣ J(W)=\frac{|W^{\top} S_{B} W|}{|W^{\top} S_{W} W|} J(W)=WSWWWSBW

对于上式的目标函数,最优投影矩阵 W W W的列向量是最大的 D ′ D^{\prime} D个特征向量,对应于 S W − 1 S B S_{W}^{-1}S_{B} SW1SB的最大的 D ′ D^{\prime} D个特征值。

对此,我们有求解公式:

W = max D ′ ( eig ( S W − 1 S B ) ) W=\text{max}_{D^{\prime}}(\text{eig}(S_{W}^{-1}S_{B})) W=maxD(eig(SW1SB))

4. 代码实现

4.1 二分类情况

rom sklearn.datasets import load_iris
import numpy as np
import pandas as pd
from torchvision import datasets
import matplotlib.pyplot as plt
from collections import Counter
from numpy.linalg import pinv
import matplotlib.lines as mlines

在MNIST上进行二分类,我们选择数字0和1,代码如下:

mnist_tr = datasets.MNIST('data', train=True, download=True)
mnist_te = datasets.MNIST('data', train=False, download=True)
x_train = mnist_tr.data.numpy()
y_train = mnist_tr.targets.numpy()
x_test = mnist_te.data.numpy()
y_test = mnist_te.targets.numpy()
two_class_data = []
two_class_target = []
for x, y in zip(x_train, y_train):
  # two class data
  if y == 0 or y == 1:
    two_class_data.append(x.flatten())
    two_class_target.append(y.squeeze())
two_class_data = np.asarray(two_class_data)
two_class_target = np.asarray(two_class_target)

划分数据为两类

C1_input = []
C1_target = []

C2_input = []
C2_target = []

for i in range(len(two_class_target)):
  y = two_class_target[i]
  x = two_class_data[i]
  if y == 0:
    C1_input.append(x.flatten())
    C1_target.append(y.squeeze())
  elif y == 1:
    C2_input.append(x.flatten())
    C2_target.append(y.squeeze())
    
C1_input = np.asarray(C1_input)
C1_target = np.asarray(C1_target)

C2_input = np.asarray(C2_input)
C2_target = np.asarray(C2_target)

计算类内均值

m1 = np.mean(C1_input,axis=0)
m2 = np.mean(C2_input,axis=0)

计算类内方差

tmp = np.subtract(C1_input, m1)
a = np.dot(tmp.T, tmp)

tmp = np.subtract(C2_input, m2)
b = np.dot(tmp.T, tmp)
SW = np.add(a,b)

计算变换矩阵 W W W

inv_SW = pinv(SW)
s = m2 - m1
W = np.dot(inv_SW, np.expand_dims(s,1))

投影后的数据

y = np.dot(two_class_data,W)

计算分类阈值,这里我们选择两类数据投影后的均值作为阈值

m1 = np.mean(y[C1_target==0])
m2 = np.mean(y[C2_target==1])
threshold = (m1+m2)/2

计算分类准确率

y[y<threshold] = 0
y[y>=threshold] = 1
acc = np.sum(y.squeeze()==two_class_target)/len(two_class_target)
print('acc:',acc)

4.2 多分类情况

three_class_data = {}
for x, y in zip(x_train, y_train):
  if y == 0 or y == 1 or y == 2:
    if y not in three_class_data:
      three_class_data[y] = [x.flatten()]
    else:
      three_class_data[y].append(x.flatten())
      
three_class_data[0] = np.asarray(three_class_data[0])
three_class_data[1] = np.asarray(three_class_data[1])
three_class_data[2] = np.asarray(three_class_data[2])

class DataSet:
  def __init__(self, data, targets, valid_classes=None):
    if valid_classes is None:
      self.valid_classes = np.unique(targets)
    else:
      self.valid_classes = valid_classes
     
    self.number_of_classes = len(self.valid_classes)
    self.data = self.to_dict(data,targets)
    
  def to_dict(self,data,targets):
    data_dict = {}
    for x, y in zip(data, targets):
      if y in self.valid_classes:
        if y not in data_dict:
          data_dict[y] = [x.flatten()]
        else:
          data_dict[y].append(x.flatten())
     
    for i in self.valid_classes:
      data_dict[i] = np.asarray(data_dict[i])

    return data_dict
  
  def get_data_by_class(self,class_id):
    if class_id in self.valid_classes:
      return self.data[class_id]
    else:
      raise("Class not found.")
  
  def get_all_data(self):
    data = []
    labels = []
    for label, class_i_data in self.data.items():
      data.extend(class_i_data)
      labels.extend(class_i_data.shape[0] * [label])
    data = np.asarray(data)
    labels = np.asarray(labels)
    return data, labels

dataset = DataSet(x_train, y_train, valid_classes=[0, 1, 2])

inputs, targets = dataset.get_all_data()

定义类别数量和目标维度

number_of_classes = three_class_data.keys()
D_prime = 2

计算类内均值

mk = []
for class_i, input_vectors in three_class_data.items():
  mk.append(np.mean(input_vectors,axis=0))
  mk[class_i] = np.asarray(mk[class_i])

计算类内方差

Sks = []
for (class_i, input_vectors), m in zip(three_class_data.items(),mk):
  tmp = np.subtract(input_vectors, m)
  Sks.append(np.dot(np.transpose(tmp), tmp))
  
Sks = np.asarray(Sks)

计算类间方差

N = 0
Nk = []
sum_ = 0
for class_i, data in three_class_data.items():
  Nk.append(data.shape[0])
  sum_ += np.sum(data,axis=0)

N = sum(Nk)
# m is the mean of the total data set
m = sum_ / N

SB = []
for class_i in three_class_data.keys():
  tmp = mk[class_i] - m
  SB.append(np.multiply(Nk[class_i], np.outer(tmp, tmp.T)))
SB = np.sum(SB,axis=0) # sum of K (# of classes) matrices

计算投影矩阵 W W W

from numpy.linalg import eig
matrix = np.dot(pinv(Sw),SB)
print("Out:",matrix.shape)

# find eigen values and eigen-vectors pairs for np.dot(pinv(SW),SB)
eigen_values, eigen_vectors = eig(matrix)

# sort eigen values and eigen-vectors pairs
idx = eigen_values.argsort()[::-1]
eigen_values = eigen_values[idx]
eigen_vectors = eigen_vectors[:,idx]
# find the projection matrix W
W = eigen_vectors[:,:D_prime]

投影后的数据

def inference(x,W):
  y = np.dot(x,W)
  return y

yk = []
for class_i, data in three_class_data.items():
  yk.extend(inference(data,W))
y = np.asarray(yk)
print(y.shape)

5. 参考资料

  • https://sthalles.github.io/fisher-linear-discriminant/

  • https://www.csd.uwo.ca/~oveksler/Courses/CS434a_541a/Lecture8.pdf

  • https://www.ccs.neu.edu/home/vip/teach/MLcourse/5_features_dimensions/lecture_notes/LDA/LDA.pdf

  • http://webdancer.is-programmer.com/posts/37867.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/817241.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ROS-PyQt小案例

前言&#xff1a;目前还在学习ROS无人机框架中&#xff0c;&#xff0c;&#xff0c; 更多更新文章详见我的个人博客主页【前往】 ROS与PyQt5结合的小demo&#xff0c;用于学习如何设计一个界面&#xff0c;并与ROS中的Service和Topic结合&#xff0c;从而控制多个小乌龟的运动…

从零开始搭建Vue3框架(二):Vue-Router4.0使用与配置

前言 上篇文章我们创建了模板项目并成功运行&#xff0c;但是运行后的页面只是一个静态页面&#xff0c;并没有页面间跳转。 对于Vue这种单页应用来说&#xff0c;最要紧的就是控制整个系统的页面路由。因为我们使用Vue3的框架&#xff0c;所以这里使用Vue-Router4.0版本。 …

1992-2021年全国及31省对外开放度测算数据含原始数据和计算过程(无缺失)

1992-2021年全国及31省对外开放度测算数据含原始数据和计算过程&#xff08;无缺失&#xff09; 1、时间&#xff1a;1992-2021年 2、范围&#xff1a;全国及31省 3、指标&#xff1a;进出口总额、国内生产总值、年均汇率 4、计算方法&#xff1a;对外开放度进出口总额/GDP…

【Git系列】Git配置SSH免密登录

&#x1f433;Git配置SSH免密登录 &#x1f9ca;1.设置用户名和邮箱&#x1f9ca;2. 生成密钥&#x1f9ca;3.远程仓库配置密钥&#x1f9ca;2. 免密登录 在以上push操作过程中&#xff0c;我们第一次push时&#xff0c;是需要进行录入用户名和密码的&#xff0c;比较麻烦。而且…

【数据分析专栏之Python篇】四、pandas介绍

前言 在上一篇中我们安装和使用了Numpy。本期我们来学习使用 核心数据分析支持库 Pandas。 一、pandas概述 1.1 pandas 简介 Pandas 是 Python 的 核心数据分析支持库&#xff0c;提供了快速、灵活、明确的数据结构&#xff0c;旨在简单、直观地处理关系型、标记型数据。 …

Resnet与Pytorch花图像分类

1、介绍 1.1数据集介绍 flower_data├── train│ └── 1-102&#xff08;102个文件夹&#xff09;│ └── XXX.jpg&#xff08;每个文件夹含若干张图像&#xff09;├── valid│ └── 1-102&#xff08;102个文件夹&#xff09;└── ─── └── XXX.jp…

如何使用免费敏捷工具Leangoo领歌管理Sprint Backlog

什么是Sprint Backlog&#xff1f; Sprint Backlog是Scrum的主要工件之一。在Scrum中&#xff0c;团队按照迭代的方式工作&#xff0c;每个迭代称为一个Sprint。在Sprint开始之前&#xff0c;PO会准备好产品Backlog&#xff0c;准备好的产品Backlog应该是经过梳理、估算和优先…

ffmpeg安装

简介 FFmpeg是一个开源的音视频处理库&#xff0c;它提供了一系列的工具和API&#xff0c;可以用于处理音视频文件。你可以使用FFmpeg的命令行工具来执行各种音视频处理操作&#xff0c;比如转码、剪辑、合并等。FFmpeg的命令格式通常是&#xff1a;ffmpeg [全局选项] {[输入文…

章节5:SQL注入之WAF绕过

章节5&#xff1a;SQL注入之WAF绕过 5.1 SQL注入之WAF绕过上 WAF拦截原理&#xff1a;WAF从规则库中匹配敏感字符进行拦截。 5.2 SQL注入之WAF绕过下 &#xff08;原理简单了解&#xff09; 关键词大小写绕过 有的WAF因为规则设计的问题&#xff0c;只匹配纯大写或纯小写的…

B. Binary Cafe(二进制的妙用)

题目&#xff1a;Problem - B - Codeforces 总结&#xff1a; 对于该题最简单的方法为使用二进制的数表示状态 例如&#xff1a; 对于一个数7的二进制&#xff1a;111 它的每一位都可表示两种状态我们可以理解为取或者不取 对于7这个数字它可以表示一种状态即在三个位置都…

使用Roles模块搭建LNMP架构

使用Roles模块搭建LNMP架构 1.Ansible-playbook中部署Nginx角色2.Ansible-playbook中部署PHP角色3.Ansible-playbook中部署MySQL角色4.启动安装分布式LNMP 1.Ansible-playbook中部署Nginx角色 创建nginx角色所需要的工作目录&#xff1b; mkdir -p /etc/ansible/playbook/rol…

剖析 Kubernetes 控制器:Deployment、ReplicaSet 和 StatefulSet 的功能与应用场景

&#x1f337;&#x1f341; 博主猫头虎 带您 Go to New World.✨&#x1f341; &#x1f984; 博客首页——猫头虎的博客&#x1f390; &#x1f433;《面试题大全专栏》 文章图文并茂&#x1f995;生动形象&#x1f996;简单易学&#xff01;欢迎大家来踩踩~&#x1f33a; &a…

【kubernetes】k8s单master集群环境搭建及kuboard部署

k8s入门学习环境搭建 学习于许大仙: https://www.yuque.com/fairy-era k8s官网 https://kubernetes.io/ kuboard官网 https://kuboard.cn/ 基于k8s 1.21.10版本 前置环境准备 一主两从&#xff0c;三台虚拟机 CPU内存硬盘角色主机名IPhostname操作系统4C16G50Gmasterk8s-mast…

JSON动态生成表格

<!DOCTYPE html> <html><head><meta charset"utf-8"><title></title></head><body><script>var fromjava"{\"total\":3,\"students\":[{\"name\":\"张三\",\&q…

哔哩哔哩缓存转码|FFmpeg将m4s文件转为mp4|PHP自动批量转码B站视频

window下载安装FFmpeg 打开ffMpeg官网选择window>Windows builds from gyan.dev 打开https://www.gyan.dev/ffmpeg/builds/ 这里是上面提取的下载链接如果过期不能用自己去官网下 配置FFmpeg环境变量 上面下载的FFmpeg是绿色软件&#xff0c;下载解压到你的常用软件安装目…

配置IPv6 over IPv4 GRE隧道示例

组网需求 如图1&#xff0c;两个IPv6网络分别通过SwitchA和SwitchC与IPv4公网中的SwitchB连接&#xff0c;客户希望两个IPv6网络中的PC1和PC2实现互通。 其中PC1和PC2上分别指定SwitchA和SwitchC为自己的缺省网关。 图1 配置IPv6 over IPv4 GRE隧道组网图 配置思路 要实现I…

【LeetCode每日一题合集】2023.7.24-2023.7.30(TODO Lazy 线段树)

文章目录 771. 宝石与石头代码1——暴力代码2——位运算集合⭐&#xff08;英文字母的long集合表示&#xff09; 2208. 将数组和减半的最少操作次数&#xff08;贪心 优先队列&#xff09;2569. 更新数组后处理求和查询⭐⭐⭐⭐⭐&#xff08;线段树&#xff09;2500. 删除每行…

这所985很保护一志愿,每年招150+!非常稳定!

一、学校及专业介绍 中国海洋大学&#xff08;Ocean University of China&#xff0c;OUC&#xff09;&#xff0c;位于山东省青岛市&#xff0c;是中华人民共和国教育部直属的综合性全国重点大学&#xff0c;位列国家“双一流”、“985工程”、“211工程”重点建设高校。 1.1…

CHI中的error处理

Error Handling Error types 包含两种sub-packet级别的error, 和两种packe级别的error; Packet level error Data Error, DERR □ 访问的地址是正确的&#xff0c;但是访问的数据有错误&#xff1b;通常是在数据崩溃的时候使用&#xff0c;例如ECC&#xf…

三分钟白话RocketMQ系列—— 核心概念

目录 关键字摘要 Q1&#xff1a;RocketMQ是什么&#xff1f; Q2: 作为消息中间件&#xff0c;RocketMQ和kafka有什么区别&#xff1f; Q3: RocketMQ的基本架构是怎样的&#xff1f; Q4&#xff1a;RocketMQ有哪些核心概念&#xff1f; 总结 RocketMQ是一个开源的分布式消…