【机器学习】西瓜书习题3.5Python编程实现线性判别分析,并给出西瓜数据集 3.0α上的结果

news2025/2/23 14:35:41

参考代码
结合自己的理解,添加注释。

代码

  1. 导入相关的库
import numpy as np
import pandas as pd
import matplotlib
from matplotlib import pyplot as plt
  1. 导入数据,进行数据处理和特征工程
    得到数据集 D = { ( x i , y i ) } i = 1 m , y i ∈ { 0 , 1 } D=\{ (x_i,y_i) \}_{i=1}^m, y_i \in \{0,1\} D={(xi,yi)}i=1m,yi{0,1}
# 1.数据处理,特征工程
data_path = 'watermelon3_0_Ch.csv'
data = pd.read_csv(data_path).values
# 按照数据集3.0α,强制转换数据类型
X = data[:,7:9].astype(float)
y = data[:,9]
y[y=='是'] = 1
y[y=='否'] = 0
y = y.astype(int)
  1. 计算西瓜书60页中的 X i 、 μ i 、 Σ i X_{i}、\mu_i、\Sigma_i XiμiΣi
# 将X的数据根据label值分成X0和X1
pos = y == 1
neg = y == 0
X0 = X[neg]
X1 = X[pos]

# 计算u0,u1 keepdims保持原数据维数
u0 = X0.mean(0, keepdims=True)
u1 = X1.mean(0, keepdims=True)

# 计算sigma0,sigma1
sigma0 = np.dot((X0-u0).T,X0-u0)
sigma1 = np.dot((X1-u1).T,X1-u1)
  1. 根据式3.33计算类内散度矩阵
    S w = Σ 0 + Σ 1 = ∑ x ∈ X 0 ( x − μ 0 ) ( x − μ 0 ) T + ∑ x ∈ X 1 ( x − μ 1 ) ( x − μ 1 ) T S_w=\Sigma_0+\Sigma_1=\sum_{x\in X_{0}}(x-\mu_0)(x-\mu_0)^T+\sum_{x\in X_{1}}(x-\mu_1)(x-\mu_1)^T Sw=Σ0+Σ1=xX0(xμ0)(xμ0)T+xX1(xμ1)(xμ1)T
    根据式3.39计算 w w w
    w = S w − 1 ( μ 0 − μ 1 ) w=S_w^{-1}(\mu_0-\mu_1) w=Sw1(μ0μ1)
# 计算类内散度矩阵 with-class scatter matrix
sw = sigma0 + sigma1

# numpy.linalg.inv() 函数来计算矩阵的逆
w = np.dot(np.linalg.inv(sw),(u0-u1).T).reshape(1,-1)
  1. 画出样本点和得到的直线
fig, ax = plt.subplots()
ax.spines['right'].set_color('none')
ax.spines['top'].set_color('none')
ax.spines['left'].set_position(('data', 0))
ax.spines['bottom'].set_position(('data', 0))

plt.scatter(X1[:, 0], X1[:, 1], c='k', marker='o', label='good')
plt.scatter(X0[:, 0], X0[:, 1], c='r', marker='x', label='bad')

plt.xlabel('密度', labelpad=1)
plt.ylabel('含糖量')
plt.legend(loc='upper right')

x_tmp = np.linspace(-0.05, 0.15)
y_tmp = x_tmp * w[0, 1] / w[0, 0]
plt.plot(x_tmp, y_tmp, '#808080', linewidth=1)

得到下图
在这里插入图片描述

  1. 计算每个样本点在直线上的投影
    计算的理解参考这篇文章
# 求w这个向量的 单位向量 wu
# np.linalg.norm()默认求2 范数,表示向量中各个元素平方和 的 1/2 次方,L2 范数又称 Euclidean 范数或者 Frobenius 范数。
wu = w / np.linalg.norm(w)

# 正负样本点
# 求负样本的投影点,并连线
X0_project = np.dot(X0, np.dot(wu.T, wu))
plt.scatter(X0_project[:, 0], X0_project[:, 1], c='r', s=15)
for i in range(X0.shape[0]):
    plt.plot([X0[i, 0], X0_project[i, 0]], [X0[i, 1], X0_project[i, 1]], '--r', linewidth=1)

# 求正样本的投影点,并连线
X1_project = np.dot(X1, np.dot(wu.T, wu))
plt.scatter(X1_project[:, 0], X1_project[:, 1], c='k', s=15)
for i in range(X1.shape[0]):
    plt.plot([X1[i, 0], X1_project[i, 0]], [X1[i, 1], X1_project[i, 1]], '--k', linewidth=1)

得到下图
在这里插入图片描述

将上述代码封装成类,如下:

class LDA(object):

    def fit(self, X_, y_, plot_=False):
        pos = y_ == 1
        neg = y_ == 0
        X0 = X_[neg]
        X1 = X_[pos]

        u0 = X0.mean(0, keepdims=True)  # (1, n)
        u1 = X1.mean(0, keepdims=True)

        sw = np.dot((X0 - u0).T, X0 - u0) + np.dot((X1 - u1).T, X1 - u1)
        w = np.dot(np.linalg.inv(sw), (u0 - u1).T).reshape(1, -1)  # (1, n)

        if plot_:
            # 设置字体为楷体
            plt.rcParams['axes.unicode_minus']=False #用来正常显示负号
            plt.rcParams['font.sans-serif'] = ['KaiTi']
            fig, ax = plt.subplots()
            ax.spines['right'].set_color('none')
            ax.spines['top'].set_color('none')
            ax.spines['left'].set_position(('data', 0))
            ax.spines['bottom'].set_position(('data', 0))

            plt.scatter(X1[:, 0], X1[:, 1], c='k', marker='o', label='good')
            plt.scatter(X0[:, 0], X0[:, 1], c='r', marker='x', label='bad')

            plt.xlabel('密度', labelpad=1)
            plt.ylabel('含糖量')
            plt.legend(loc='upper right')

            x_tmp = np.linspace(-0.05, 0.15)
            y_tmp = x_tmp * w[0, 1] / w[0, 0]
            plt.plot(x_tmp, y_tmp, '#808080', linewidth=1)

            wu = w / np.linalg.norm(w)

            # 正负样板店
            X0_project = np.dot(X0, np.dot(wu.T, wu))
            plt.scatter(X0_project[:, 0], X0_project[:, 1], c='r', s=15)
            for i in range(X0.shape[0]):
                plt.plot([X0[i, 0], X0_project[i, 0]], [X0[i, 1], X0_project[i, 1]], '--r', linewidth=1)

            X1_project = np.dot(X1, np.dot(wu.T, wu))
            plt.scatter(X1_project[:, 0], X1_project[:, 1], c='k', s=15)
            for i in range(X1.shape[0]):
                plt.plot([X1[i, 0], X1_project[i, 0]], [X1[i, 1], X1_project[i, 1]], '--k', linewidth=1)

            # 中心点的投影
            u0_project = np.dot(u0, np.dot(wu.T, wu))
            plt.scatter(u0_project[:, 0], u0_project[:, 1], c='#FF4500', s=60)
            u1_project = np.dot(u1, np.dot(wu.T, wu))
            plt.scatter(u1_project[:, 0], u1_project[:, 1], c='#696969', s=60)

            ax.annotate(r'u0 投影点',
                        xy=(u0_project[:, 0], u0_project[:, 1]),
                        xytext=(u0_project[:, 0] - 0.2, u0_project[:, 1] - 0.1),
                        size=13,
                        va="center", ha="left",
                        arrowprops=dict(arrowstyle="->",
                                        color="k",
                                        )
                        )

            ax.annotate(r'u1 投影点',
                        xy=(u1_project[:, 0], u1_project[:, 1]),
                        xytext=(u1_project[:, 0] - 0.1, u1_project[:, 1] + 0.1),
                        size=13,
                        va="center", ha="left",
                        arrowprops=dict(arrowstyle="->",
                                        color="k",
                                        )
                        )
            plt.axis("equal")  # 两坐标轴的单位刻度长度保存一致
            plt.show()

        self.w = w
        self.u0 = u0
        self.u1 = u1
        return self

    def predict(self, X):
        project = np.dot(X, self.w.T)

        wu0 = np.dot(self.w, self.u0.T)
        wu1 = np.dot(self.w, self.u1.T)

        return (np.abs(project - wu1) < np.abs(project - wu0)).astype(int)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/813013.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

python+opencv学习

1. 把所有绿色通道值变为0 import cv2 import numpy as npimgcv2.imread(2.jpg) #读取图片 img[:,:,1]0 #绿色通道变为0 cv2.imshow(图片,img) #显示图片 cv2.waitKey(0) #无限地显示窗口 人生建议&#xff0c;一定要买书来学习呀&#xff0c;不是说教程不好&#xff0c;而是书…

idea application.yml配置文件没有提示或读不到配置

1.首先确定你的resources文件夹正常且yml文件图表和下面一样 不一样的右键去设置 2.确保你已经缩进了且层级关系正常 3.如果以上都不是&#xff0c;先考虑删除.idea重开试试 4.以上解决不了就装以下两个插件解决

[深入理解NAND Flash] 闪存(NAND Flash) 学习指南

依公开知识及经验整理&#xff0c;付费内容&#xff0c;禁止转载。 所在专栏 《深入理解Flash:闪存特性与实践》 1. 我想和你说 漠然回首&#xff0c;从事存储芯片行业已多年&#xff0c;这些年最宝贵的青春都献给了闪存&#xff0c;虽不说如数家珍&#xff0c;但也算专业。 …

【Git】git reflog git log

前言 日常开发过程中&#xff0c;我们经常会遇到要进行版本回退的情况&#xff0c;这时候需要使用git reflog和git reset 命令 git reflog 常用命令&#xff1a; 1、git reflog -n 查看多少条 2、git reflog show origin 查看远程历史变动 git log 什么都不加默认显示当前分…

SpringBoot项目中的web安全防护

最近这个月公司对项目进行了几次安全性扫描&#xff0c;然后扫描出来了一些安全漏洞&#xff0c;所以最近也一直在修复各种安全漏洞&#xff0c;还有就是最近在备考软考高级系统架构设计师&#xff0c;也刚好复习到了网络安全这一个章节&#xff0c;顺便将最近修复的安全漏洞总…

漏洞利用-PoC-in-GitHub+msf简单利用

查找库-PoC-in-GitHub 里面集成了几乎所有cve漏洞 下载&#xff1a;https://github.com/nomi-sec/PoC-in-GitHub 演示&#xff1a; 如想要查找vulfocus靶场中 Metabase远程命令执行漏洞 的利用方法。 可以下载一个Yomm闪电文件搜索 Yomm闪电文件搜索下载&#xff1a;https://…

Github-Copilot初体验-Pycharm插件的安装与测试

引言&#xff1a; 80%代码秒生成&#xff01;AI神器Copilot大升级 最近copilot又在众多独角兽公司的合力下&#xff0c;取得了重大升级。GitHub Copilot发布还不到两年&#xff0c; 就已经为100多万的开发者&#xff0c;编写了46%的代码&#xff0c;并提高了55%的编码速度。 …

代理模式——对象的间接访问

1、简介 1.1、概述 由于某些原因&#xff0c;客户端不想或不能直接访问某个对象&#xff0c;此时可以通过一个被称为“代理”的第三者来实现间接访问&#xff0c;该方案对应的设计模式被称为代理模式。 代理模式是一种应用很广泛的结构型设计模式&#xff0c;而且变化很多。…

活动回顾|火山引擎 DataLeap 分享:DataOps、数据治理、指标体系最佳实践(文中领取 PPT)

更多技术交流、求职机会&#xff0c;欢迎关注字节跳动数据平台微信公众号&#xff0c;回复【1】进入官方交流群 在 7 月 21 日至 22 日举行的 ArchSummit 全球架构师峰会&#xff08;深圳站&#xff09;及 DataFunCon.数据智能创新与实践大会&#xff08;北京站&#xff09;上&…

C++ 类的组合

解决复杂问题的有效方法就是将其层层分解为简单问题的组合&#xff0c;首先解决简单问题&#xff0c;复杂问题也就迎刃而解了。实际上&#xff0c;这种部件组装的生产方式广泛应用在工业生产中。例如&#xff0c;电视机的一个重要部件是显像管&#xff0c;但很多电视机厂自己并…

ARM裸机-7

1、S5PV210的地址映射 1.1、什么是地址映射 S5PV210属于ARM Cortex-A8架构&#xff0c;32位CPU&#xff0c;CPU设计时就有32根地址线&32根数据线。32根地址线决定了CPU的地址空间为4G&#xff0c;那么这4G空间如何分配使用&#xff1f;这个问题就是地址映射问题。 1.2、S…

AnimateDiff论文解读-基于Stable Diffusion文生图模型生成动画

文章目录 1. 摘要2. 引言3. 算法3.1 Preliminaries3.2. Personalized Animation3.3 Motion Modeling Module 4. 实验5.限制6. 结论 论文&#xff1a; 《AnimateDiff: Animate Your Personalized Text-to-Image Diffusion Models without Specific Tuning》 github: https://g…

高级 IO

目录 前言 什么是IO&#xff1f; 有哪些IO的的方式呢&#xff1f; 五种IO模型 这五种模型在特性有什么差别呢&#xff1f; 其他高级IO 非阻塞IO fcntl 实现函数SetNonBlock I/O多路转接之select 初识select select函数 参数说明&#xff1a; 关于timeval结构 函数…

【解惑笔记】树莓派+OpenCV+YOLOv5目标检测(Pytorch框架)

【学习资料】 子豪兄的零基础树莓派教程https://github.com/TommyZihao/ZihaoTutorialOfRaspberryPi/blob/master/%E7%AC%AC2%E8%AE%B2%EF%BC%9A%E6%A0%91%E8%8E%93%E6%B4%BE%E6%96%B0%E6%89%8B%E6%97%A0%E7%97%9B%E5%BC%80%E6%9C%BA%E6%8C%87%E5%8D%97.md#%E7%83%A7%E5%BD%95…

【多线程中的线程安全问题】线程互斥

1 &#x1f351;线程间的互斥相关背景概念&#x1f351; 先来看看一些基本概念&#xff1a; 1️⃣临界资源&#xff1a;多线程执行流共享的资源就叫做临界资源。2️⃣临界区&#xff1a;每个线程内部&#xff0c;访问临界资源的代码&#xff0c;就叫做临界区。3️⃣互斥&…

python与深度学习(十一):CNN和猫狗大战

目录 1. 说明2. 猫狗大战2.1 导入相关库2.2 建立模型2.3 模型编译2.4 数据生成器2.5 模型训练2.6 模型保存2.7 模型训练结果的可视化 3. 猫狗大战的CNN模型可视化结果图4. 完整代码5. 猫狗大战的迁移学习 1. 说明 本篇文章是CNN的另外一个例子&#xff0c;猫狗大战&#xff0c…

建立动态数组,输入5个学生的,另外用一个函数检查其中有无低于60分的,输出不合格的成绩。

题为c程序设计&#xff08;第五版&#xff09;谭浩强 例8.30 文章目录 前言一、pandas是什么&#xff1f;二、使用步骤 1.引入库2.读入数据总结 前言 这篇博客&#xff0c;让我们一起来学习内存的动态分配。 那么&#xff0c;什么是内存的动态分配呢&#xff1f;C语言允许建立…

RS485或RS232转ETHERCAT连接ethercat转换器

最近&#xff0c;生产管理设备中经常会遇到两种协议不相同的情况&#xff0c;这严重阻碍了设备之间的通讯&#xff0c;串口设备的数据不能直接传输给ETHERCAT。这可怎么办呢&#xff1f; 别担心&#xff0c;捷米JM-ECT-RS485/232来了&#xff01;这是一款自主研发的ETHERCAT从站…

FreeRTOS源码分析-7 消息队列

目录 1 消息队列的概念和作用 2 应用 2.1功能需求 2.2接口函数API 2.3 功能实现 3 消息队列源码分析 3.1消息队列控制块 3.2消息队列创建 3.3消息队列删除 3.4消息队列在任务中发送 3.5消息队列在中断中发送 3.6消息队列在任务中接收 3.7消息队列在中断中接收 1 消…