机器学习作业6——svm支持向量机

news2024/11/27 22:38:41

目录

一、理论

概念:

线性可分:

支持向量:

间隔:

目标:

软间隔:

梯度下降法:

别的方法:

拉格朗日函数:

SMO算法:

核函数:

二、代码

说明:

三、结果:

优缺点分析:

遇到的问题:


一、理论

svm的目的是找到一个最优的划分超平面或者决策边界,从而实现对数据的有效分割或者拟合。

超平面:

在二维情况下,上图的线就是超平面,而若特征有3维,则超平面就是一个平面,而高维情况很多,就统一叫作超平面。

所以当有了一个数据集后,主要的问题就是如何找出这个最优的超平面

概念:

线性可分:

现在先假设一个数据集是线性可分的。

因为超平面都可以用一个线性方程表示w^T x + b = 0,其中:w是超平面的法向量。x是数据点的特征向量。b是偏置。

有了这个概念,线性可分就可以定义为:

当标签为正类(y=1)时,w \cdot x_i + b \geq 0

当标签为负类(y=-1)时,w \cdot x_i + b < 0

将这两个式子合起来,简写为:y_i (w \cdot x_i + b) \geq 0,使得式子统一

支持向量:

由数学知识得到,假设一个平面为Ax+By+Cz+D=0, 那么将这个平面乘以一个数后,平面还是同一个平面,所以可以通过控制乘的这个数,使得w \cdot x_i + b \geq 1,y = +1,w \cdot x_i + b \leq -1,y = -1,化简一下变为:

y_i (w \cdot x_i + b) \geq 1

通过这样的缩放变换,当一个样本点使得w \cdot x_i + b = \pm 1,这个样本点就是距离这个超平面最近的点,我们把这些点称作支持向量。

虚线上的点就是支持向量

间隔:

在样本空间中,任意点到超平面的距离为:d = \frac{​{|w \cdot x + b|}}{\left \| w \right \|}

例如在三位空间中,点到平面距离公式为:d = \frac{​{|Ax_0 + By_0 + Cz_0 + D|}}{​{\sqrt{​{A^2 + B^2 + C^2}}}}

在支持向量中,{|w \cdot x + b|}这项是为1的,所以两个虚线之间的距离为:2*\frac{1}{\left \| w \right \|},这一项被称之为间隔

目标:

有了以上概念,我们的目标是:

希望最大化间隔\frac{2}{\left \| w \right \|},并且超平面满足约束条件y_i (w \cdot x_i + b) \geq 1,i = 1, 2, \ldots, n

而最大化间隔\frac{2}{\left \| w \right \|},可以等价为最小化\left \| w \right \|,又因为\left \| w \right \|始终为正值,但是带根号,所以简化为找到\frac{1}{2} || w ||^2的最小值(1/2的系数是为了方便求导)。

所以优化目标为:

\min_{w, b} \frac{1}{2} ||w||^2

软间隔:

当然以上条件都是在数据集线性可分的基础之上,才能这么去想的,而实际上,很少有数据集可以完美的符合线性可分的条件,所以要引入软间隔。

引入软间隔后,约束条件从y_i (w \cdot x_i + b) \geq 1,变成了y_i (w \cdot x_i + b) \geq 1- \xi_i,其中\xi_i叫做松弛变量

有了松弛变量后,就允许了一部分点可以被错误的分类。当然,我们希望松弛变量也是越小越好。

具体点说就是:

\xi_i<=0时,代表该样本点是正确分类的。

0<\xi_i<1时,代表该样本点分类虽然时正确的,但是是在自己标签的分离间隔和超平面之间的。

\xi_i=1时,代表该样本点在超平面上,无法正确分类。

\xi_i>1时,代表该样本点被错误分类了。

所以目标函数就变为:

\min_{w, b, \xi} \frac{1}{2} ||w||^2 + C \sum_{i=1}^{N} \xi_i,其中C是认为给出的正则化参数,用于控制\xi_i的大小。
把这个式子写成损失函数,就变成了以下形式,我们最小化损失函数即可。

L(y, f(x)) = \lambda \cdot ||w||^2 + \max(0, 1 - y \cdot f(x)),其中f(x) = w \cdot x + b

对w求偏导,

1 - y \cdot f(x) \leq 0时,\max(0, 1 - y \cdot f(x))=0,所以梯度为\lambda \text{w}

1 - y \cdot f(x) > 0时,\max(0, 1 - y \cdot f(x))=1- y \cdot f(x),所以梯度为-y \cdot x+\lambda \text{w}

梯度下降法:

若使用梯度下降法的SVM,权重更新式子为:

1 - y \cdot f(x) \leq 0时,\text{weights} \mathrel{-}= learningrate \times \lambda \text{w}

1 - y \cdot f(x) > 0时,weights \mathrel{-}= learningrate \times (\lambda \text{w}-y \cdot x)

别的方法:

拉格朗日函数:

线性不可分的支持向量机的拉格朗日函数可以写为:

L(w, b, \xi, \alpha, \beta) = \frac{1}{2} ||w||^2 + C \sum_{i=1}^{N} \xi_i - \sum_{i=1}^{N} \alpha_i (y_i (w \cdot x_i + b) - 1 + \xi_i) - \sum_{i=1}^{N} \beta_i \xi_i

原始问题:

\min_{w, b,\xi } \max_{\alpha}L(w, b, \xi, \alpha, \beta)

因为满足KKT条件(不去深究),所以可以将这个原始问题转化为对偶问题

\max_{\alpha}\min_{w, b,\xi } L(w, b, \xi, \alpha, \beta),意思是先对w, b,\xi求极小值,在对\alpha求极大值。

为了让L得到极小值,接下来分别求偏导,并且令偏导数=0。

w求偏导得到:

\frac{\partial L}{\partial w} = w - \sum_{i=1}^{N} \alpha_i y_i x_i = 0w = \sum_{i=1}^{N} \alpha_i y_i x_i

b求偏导得到:

\frac{\partial L}{\partial b} = -\sum_{i=1}^{N} \alpha_i y_i = 0\sum_{i=1}^{N} \alpha_i y_i = 0

\xi求偏导得到:

\frac{\partial L}{\partial \xi_i} = C - \alpha_i - \beta_i = 0\alpha_i = C - \beta_i

将上述3个结果代入原式,得到这个式子:

\max_{\alpha} \sum_{i=1}^{N} \alpha_i - \frac{1}{2} \sum_{i=1}^{N} \sum_{j=1}^{N} \alpha_i \alpha_j y_i y_j x_i \cdot x_j,并且满足:0 \leq \alpha_i \leq C, \quad i=1,2,...,N\sum_{i=1}^{N} \alpha_i y_i = 0

在上述条件下,解出\alpha,将其代入w和b中,就可以解出w和b了

SMO算法:

解出上面的\alpha就是SMO算法优化的地方。

SMO 算法通过不断选择两个变量进行优化,固定其他变量,然后在选定的两个变量上优化目标函数,从而实现对目标函数的最大化。这个过程中,SMO 算法会不断地更新拉格朗日乘子 α,直到达到收敛条件,最终求出α。

理论过程对本人来说太难了,写不出来,望老师见谅。

核函数:

核函数的作用是将输入空间中的数据映射到一个高维特征空间,从而产生了新的特征矩阵,使得原始数据在新的特征空间中变得线性可分或更容易进行线性划分。这样做的目的是为了解决原始特征空间中线性不可分的问题。

有:线性核函数(Linear Kernel),多项式核函数(Polynomial Kernel),高斯核函数(Gaussian Kernel 或 RBF Kernel),其中高斯核函数是最常用的。

二、代码

梯度下降法:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# 加载鸢尾花数据集
iris = load_iris()
X = iris.data[:, :2]  # 只使用两个特征
y = iris.target

# 将标签转换为二元分类问题(假设类别 0 作为正例,其他类别作为负例)
y = np.where(y == 0, 1, -1)

# 将数据集分割为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# 特征缩放
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# 初始化模型参数
np.random.seed(42)
w = np.random.randn(X_train.shape[1])  # 权重
b = 0                                   # 偏置项
lr = 0.01                               # 学习率
epochs = 100                          # 迭代次数
lmd = 0.1

# 定义损失函数(hinge loss)
def hinge_loss(X, y, w, b):
    loss = 1 - y * (np.dot(X, w) + b)
    return np.maximum(0, loss)

# 训练 SVM 模型
cnt = 0
for epoch in range(epochs):
    for i, x_i in enumerate(X_train):
        if y_train[i] * (np.dot(x_i, w) + b) >= 1:  # 判断是否分类正确
            dw = 2 * lmd * w  
        else:
            dw = 2 * lmd * w - np.dot(y_train[i], x_i)  # 对于错误分类的样本,更新权重和偏置项
            db = -y_train[i]
            w -= lr * dw
            b -= lr * db
        cnt+=1
        if cnt%100 == 0:
            print(repr('更新了第') + repr(cnt) + repr('次') + repr('W:') + repr(w) + repr('    b:') + repr(b))

# 绘制决策边界
plt.scatter(X_train[:, 0], X_train[:, 1], c=y_train, cmap=plt.cm.Paired)
ax = plt.gca()
xlim = ax.get_xlim()
ylim = ax.get_ylim()

# 创建网格以绘制决策边界
xx, yy = np.meshgrid(np.linspace(xlim[0], xlim[1], 50),
                     np.linspace(ylim[0], ylim[1], 50))
Z = np.dot(np.c_[xx.ravel(), yy.ravel()], w) + b
Z = np.sign(Z)
Z = Z.reshape(xx.shape)

# 绘制决策边界
plt.contour(xx, yy, Z, colors='k', levels=[-1, 0, 1], alpha=0.5, linestyles=['--', '-', '--'], interpolation='nearest')

plt.xlabel('x1')
plt.ylabel('x2')
plt.title('result')   
plt.show()

y_pred_train = np.sign(np.dot(X_train, w) + b)
y_pred_test = np.sign(np.dot(X_test, w) + b)

accuracy_train = np.mean(y_pred_train == y_train)
accuracy_test = np.mean(y_pred_test == y_test)
print("训练集准确率:", accuracy_train)
print("测试集准确率:", accuracy_test)

说明:

if y_train[i] * (np.dot(x_i, w) + b) >= 1:  
            dw = 2 * lmd * w  
else:
            dw = 2 * lmd * w - np.dot(y_train[i], x_i)  
            db = -y_train[i]
            w -= lr * dw
            b -= lr * db

最关键的部分就是这里了,但是这里在上面理论部分的梯度下降法里头说明了,dw是L对w求偏导,db同理,lr是学习率,这个条件的意义是:当在当前超平面下,分割出来的当前这个样本点如果是正确的,并且处于间隔外,在惩罚中就不需要加入松弛参数变出的那一项。

三、结果:

可以看到,在更新次数为9000左右的时候,参数就稳定下来了。

训练结果如下图:

可以看到,有一个点虽然被错误分类了,但关系到总体,情况还是很好的。

优缺点分析:

梯度下降SVM:

优点:

  1. 全局最优解:梯度下降算法可以收敛到全局最优解(如果学习率合适,并且损失函数是凸函数),从而得到最佳的分类超平面。
  2. 易于实现:梯度下降算法的实现相对简单,只需计算损失函数关于模型参数的梯度,并根据梯度方向更新参数即可。
  3. 扩展性强:梯度下降算法可以轻松地扩展到大规模数据集和高维特征空间。

缺点:

  1. 学习率选择:梯度下降算法的性能高度依赖于学习率的选择。学习率太小会导致收敛速度慢,学习率太大可能会导致震荡或无法收敛。
  2. 局部最优解:在非凸损失函数的情况下,梯度下降算法可能会陷入局部最优解,而无法找到全局最优解。
  3. 对初始值敏感:梯度下降算法的性能受初始参数值的影响,不同的初始值可能会导致不同的收敛结果。

遇到的问题:

一开始把梯度下降法和SMO算法混起来了,主要是对梯度下降的损失函数和W的更新式子不知道怎么得出的,然后先去学了一遍拉格朗日函数,在看SMO理论的时候,感觉很难,不太像是梯度下降,回头多看了看最开始得出的目标函数\min_{w, b, \xi} \frac{1}{2} ||w||^2 + C \sum_{i=1}^{N} \xi_i,发现将松弛参数用超平面代入,再对W求偏导就可以得出W的更新式子了,梯度下降的问题就解决了。

关于SMO,理不清楚原理,还是不写了。。(上次实验课腾讯会议里头的代码应该是梯度下降法)

参考的视频:

视频1

视频2

视频3


 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1802795.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

数据结构之ArrayList与顺序表(下)

找往期文章包括但不限于本期文章中不懂的知识点&#xff1a; 个人主页&#xff1a;我要学编程(ಥ_ಥ)-CSDN博客 所属专栏&#xff1a;数据结构&#xff08;Java版&#xff09; 目录 ArrayList的具体使用 118. 杨辉三角 扑克洗牌算法 接上篇&#xff1a;数据结构之ArrayLis…

三端植物大战僵尸杂交版来了

Hi&#xff0c;好久不见&#xff0c;最近植物大战僵尸杂交版蛮火的 那今天苏音整理给大家三端的植物大战僵尸杂交版包括【苹果端、电脑端、安卓端】 想要下载的直接划到最下方即可下载。 植物大战僵尸&#xff0c;作为一款古老的单机游戏&#xff0c;近期随着B站一位UP主潜艇…

英伟达黄仁勋最新主题演讲:“机器人时代“已经到来

6月2日&#xff0c;英伟达联合创始人兼首席执行官黄仁勋在Computex 2024&#xff08;2024台北国际电脑展&#xff09;上发表主题演讲&#xff0c;分享了人工智能时代如何助推全球新产业革命。 黄仁勋表示&#xff0c;机器人时代已经到来&#xff0c;将来所有移动的物体都将实现…

开源与新质生产力

在这个信息技术迅猛发展的时代&#xff0c;全球范围内的产业都在经历着深刻的变革。在这样的背景下&#xff0c;“新质生产力”的概念引起了广泛的讨论。无论是已经成为或正努力转型成为新质生产力的企业&#xff0c;都在寻求新的增长动力和竞争优势。作为一名长期从事开源领域…

什么是2+1退休模式?什么是链动2+1模式?

21退休模式又称链动21模式&#xff0c;主要是建立团队模式&#xff0c;同时快速提升销量。是目前成熟模式中裂变速度最快的模式。21退休模式合理合规&#xff0c;同时激励用户公司的利润分享机制&#xff0c;让您在享受购物折扣的同时&#xff0c;也能促进并获得客观收益。 模…

kettle从入门到精通 第六十六课 ETL之kettle kettle阻塞教程,轻松获取最后一行数据,so easy

场景:ETL沟通交流群内有小伙伴反馈,如何在同步一批数据完成之后记录下同步结果呢?或者是调用后续步骤、存储过程、三方接口等。 解决:使用步骤Blocking step进行阻塞处理即可。 1、下面的demo演示从表t1同步数据至表t2(t1表中有三条数据,t2为空表,两个表表结构相同),…

Plotly : 超好用的Python可视化工具

文章目录 安装&#xff1a;开始你的 Plotly 之旅基本折线图&#xff1a;简单却强大的起点带颜色的散点图&#xff1a;数据的多彩世界三维曲面图&#xff1a;探索数据的深度气泡图&#xff1a;让世界看到你的数据小提琴图&#xff1a;数据分布的优雅展现旭日图&#xff1a;分层数…

Vue学习day05笔记

day05 一、学习目标 1.自定义指令 基本语法&#xff08;全局、局部注册&#xff09;指令的值v-loading的指令封装 2.插槽 默认插槽具名插槽作用域插槽 3.综合案例&#xff1a;商品列表 MyTag组件封装MyTable组件封装 4.路由入门 单页应用程序路由VueRouter的基本使用 …

认识Java中的String类

前言 大家好呀&#xff0c;本期将要带大家认识一下Java中的String类&#xff0c;本期注意带大家认识一些String类常用方法&#xff0c;和区分StringBuffer和StringBuilder感谢大家收看 一&#xff0c;String对象构造方法与原理 String类为我们提供了非常多的重载的构造方法让…

kubesz(一键安装k8s)

引言 Kubernetes&#xff08;K8s&#xff09;是一个开源的容器编排系统&#xff0c;用于自动化部署、扩展和管理容器化应用程序。kubeasz 是一个用于快速搭建 Kubernetes 高可用集群的项目&#xff0c;它基于 Ansible&#xff0c;通过提供一套简单、易用的配置&#xff0c;使得…

java异常处理知识点总结

一.前提知识 首先当运行出错的时候&#xff0c;有两种情况&#xff0c;一种叫做“错误”&#xff0c;另一种叫做“异常”。错误指的是运行过程中遇到了硬件或操作系统出错&#xff0c;这种情况程序员是没办法处理的&#xff0c;因为这是硬件和系统的问题&#xff0c;不能靠代码…

Linux: ubi rootfs 故障案例 (1)

文章目录 1. 前言2. ubi rootfs 故障现场3. 故障分析与解决4. 参考资料 1. 前言 限于作者能力水平&#xff0c;本文可能存在谬误&#xff0c;因此而给读者带来的损失&#xff0c;作者不做任何承诺。 2. ubi rootfs 故障现场 问题故障内核日志如下&#xff1a; Starting ker…

【数据结构与算法 | 二叉树篇】力扣101, 104, 111,LCR144

1. 力扣101 : 对称二叉树 (1). 题 给你一个二叉树的根节点 root &#xff0c; 检查它是否轴对称。 示例 1&#xff1a; 输入&#xff1a;root [1,2,2,3,4,4,3] 输出&#xff1a;true示例 2&#xff1a; 输入&#xff1a;root [1,2,2,null,3,null,3] 输出&#xff1a;false…

Go语言 几种常见的IO模型用法 和 netpoll与原生GoNet对比

【go基础】16.I/O模型与网络轮询器netpoller_go中的多路io复用模型-CSDN博客 字节开源的netPoll多路复用器源码解析-CSDN博客 一、几种常见的IO模型 1. 阻塞I/O (1) 解释&#xff1a; 用户调用如accept、read等系统调用&#xff0c;向内核发起I/O请求后&#xff0c;应用程序…

多样本上下文学习:开拓大模型的新领域

大模型&#xff08;LLMs&#xff09;在少量样本上下文学习&#xff08;ICL&#xff09;中展现出了卓越的能力&#xff0c;即通过在推理过程中提供少量输入输出示例来学习&#xff0c;而无需更新权重。随着上下文窗口的扩展&#xff0c;我们现在可以探索包含数百甚至数千个示例的…

基于JSP技术的文物管理系统

你好呀&#xff0c;我是计算机学长猫哥&#xff01;如果有相关需求&#xff0c;文末可以找到我的联系方式。 开发语言&#xff1a;Java 数据库&#xff1a;MySQL 技术&#xff1a;JSP技术 工具&#xff1a;IDEA/Eclipse、Navicat、Maven 系统展示 首页 管理员界面 用户前台…

步态控制之ZMP

零力矩点&#xff08;Zero Moment Point&#xff0c;ZMP&#xff09;概述 ZMP步态控制是人形机器人步态控制中的一个关键概念&#xff0c;旨在确保机器人在行走或站立过程中保持平衡。ZMP是指机器人接触面上力矩为零的点&#xff0c;确保在该点上机器人不会倾倒。这个示例展示…

Python的登录注册界面跳转汽车主页面

1.登录注册界面的代码&#xff1a; import tkinter as tk from tkinter import messagebox,ttk from tkinter import simpledialog from ui.car_ui import start_car_ui# 设置主题风格 style ttk.Style() style.theme_use("default") # 可以根据需要选择不同的主题…

竞拍商城系统源码后端PHP+前端UNIAPP

下载地址&#xff1a;竞拍商城系统源码后端PHP前端UNIAPP

Live800:深度解析,客户服务如何塑造品牌形象

在当今竞争激烈的市场环境中&#xff0c;品牌形象对于企业的成功至关重要。而客户服务作为品牌与消费者之间最直接的互动方式&#xff0c;不仅影响着消费者的购买决策&#xff0c;更在塑造品牌形象方面发挥着不可替代的作用。本文将深度解析客户服务如何塑造品牌形象&#xff0…