优化 | 用CVXPY手搓一个SVM

news2024/11/18 19:29:46

在这里插入图片描述

支持向量机(Support Vector Machine,SVM)是一种常用的机器学习算法,用于分类和回归问题。SVM的基本思想是寻找一个最优的超平面,将不同类别的数据样本分隔开来。设样本点为 ( x i , y i ) ( i = 1 , ⋯   , n ) (x_i,y_i)(i = 1,\cdots,n) (xi,yi)(i=1,,n) ,其中 x i ∈ R m x_i\in \mathbb{R}^m xiRm ,标签 y i ∈ { 1 , − 1 } y_i\in\{1, -1\} yi{1,1} ,线性分类面方程为 w T x + b = 0 , w ∈ R m , b ∈ R w^Tx + b = 0, w\in\mathbb{R}^m, b\in\mathbb{R} wTx+b=0,wRm,bR 。SVM希望找到超平面参数 w , b w, b w,b,满足如下的优化问题
$$\min \frac{1}{2} w^Tw \ \text{s.t.} y_i(w^Tx_i + b) \geq 1$$

这是一个凸的二次规划问题,因此一定有最优解。引入对偶变量 α i ≥ 0 ( i = 1 , ⋯   , n ) \alpha_i\geq 0(i = 1,\cdots,n) αi0(i=1,,n),拉格朗日函数为
L ( w , b , α ) = 1 2 w T w − ∑ i = 1 n α i [ y i ( w T x i + b ) − 1 ] L(w, b, \alpha) = \frac{1}{2} w^Tw - \sum_{i = 1}^n \alpha_i[y_i(w^Tx_i + b) - 1] L(w,b,α)=21wTwi=1nαi[yi(wTxi+b)1]
KKT条件为
$$\nabla_w L(w, b, \alpha) = w - \sum_{i = 1}^n \alpha_iy_ix_i = 0 \ \nabla_b L(w, b, \alpha) = -\sum_{i = 1}^n \alpha_iy_i = 0 \
\alpha_i \geq 0 \ \alpha_i[y_i(w^Tx_i + b) - 1] = 0 \ 将第一个 K K T 条件和第二个 K K T 条件带入带入拉格朗日函数,我们就得到了对偶问题 将第一个KKT条件和第二个 KKT条件带入带入拉格朗日函数,我们就得到了对偶问题 将第一个KKT条件和第二个KKT条件带入带入拉格朗日函数,我们就得到了对偶问题\max Q(\alpha) = \sum_{i = 1}^n \alpha_i - \frac{1}{2}\sum_{i = 1}^n\sum_{j = 1}^n \alpha_i\alpha_jy_iy_jx_i^Tx_j \ \text{s.t.} \alpha_i \geq 0, \sum_{i = 1}^n \alpha_iy_i = 0$$
这同样是一个凸的二次规划问题。理论上,我们可以通过求解对偶问题得到对偶变量 α \alpha α,将其带入KKT条件就可得到超平面的参数。同时根据第四个KKT条件,我们得到 α i ≠ 0 \alpha_i \neq 0 αi=0 y i ( w T x i + b ) = 1 y_i(w^Tx_i + b) = 1 yi(wTxi+b)=1。这些点恰好位于边界上,因此被称为支持向量。这也是该算法被称为支持向量机的原因。

CVXPY是一个用于凸优化问题建模和求解的Python库。它提供了一种简洁而直观的方式来描述凸优化问题,并使用底层求解器来求解这些问题。我们将用CVXPY实现SVM算法。

CVXPY的安装非常简单。首先确保电脑已配置Python环境。在终端中输入

pip install cvxpy

即可。进入Python Console,输入以下命令

import cvxpy as cp
print(cp.installed_solvers())

如果出现以下输出说明CVXPY安装成功。输出显示了已安装的求解器。

首先我们生成一组线性可分的数据。从均值为 [ − 3 − 3 ] \begin{bmatrix}-3 \\-3 \\\end{bmatrix} [33],协方差矩阵为 [ 2 − 1 − 1 2 ] \begin{bmatrix}2 & -1 \\-1 & 2 \\\end{bmatrix} [2112]的二元正态分布中抽取100个样本作为正例,从均值为 [ 3 3 ] \begin{bmatrix}3 \\3 \\\end{bmatrix} [33],协方差矩阵为 [ 2 − 1 − 1 2 ] \begin{bmatrix}2 & -1 \\-1 & 2 \\\end{bmatrix} [2112]的二元正态分布中抽取100个样本作为负例。核心代码如下。


def load_data():
    mean1 = [-3, -3]
    sigma1 = [[2, -1], [-1, 2]]
    mean2 = [3, 3]
    sigma2 = [[2, -1], [-1, 2]]
    X1 = np.random.multivariate_normal(mean1, sigma1, 100)
    X2 = np.random.multivariate_normal(mean2, sigma2, 100)
    X = np.vstack((X1, X2))
    y = np.hstack((np.ones(100), -np.ones(100)))
    return X, y

画出散点图如下。可见样本点线性可分。

接下来我们进行数据的拟合。为了统一运算,对负例样本乘 -1。设样本矩阵 X = [ x 1 , ⋯   , x n ] T X = [x_1, \cdots, x_n]^T X=[x1,,xn]T,全一向量 1 n = [ 1 , ⋯   , 1 ] T 1_n = [1, \cdots, 1]^T 1n=[1,,1]T ,则对偶问题 Q ( α ) Q(\alpha) Q(α)的优化目标可写作
1 n T α − 1 2 ( X T α ) T ( X T α ) 1_n^T\alpha - \frac{1}{2}(X^T\alpha)^T(X^T\alpha) 1nTα21(XTα)T(XTα)
然后定义优化目标和约束,调用CVXOPT求解器进行求解。核心代码如下。

def fit(X, y):
    x = X.copy()
    x[y == -1, :] = -x[y == -1, :]
    n = x.shape[0]
    alpha = cp.Variable(n)
    objective = cp.Minimize(0.5 * cp.sum_squares(x.T @ alpha) - np.ones(n) @ alpha)
    constraint = [alpha >= 0, y @ alpha == 0]
    prob = cp.Problem(objective, constraint)
    prob.solve(solver='CVXOPT')
    print(f"dual variable = {alpha.value}")

得到对偶变量的最优解后,代入第一个KKT条件可以计算出 w w w α i ≠ 0 \alpha_i \neq 0 αi=0的对偶变量对应了支持向量。代入支持向量满足的边界方程可以计算出 b b b。核心代码如下。

w = x.T @ alpha.value
    index = np.where(abs(alpha.value) > 1e-3)[0]
    print(f"support vector index = {index}")
    b = np.mean(y[index] - X[index, :] @ w)
    return w, b, index

程序运行结果如下。可见有3个支持向量。其余的对偶变量非常接近0。

dual variable = [ 1.43531769e-09  3.28358675e-11 -4.70081346e-11 ... -1.23193041e-12
  1.46991949e-11 -4.72866927e-11]
support vector index = [ 11  42 111]
w = [-0.30023937 -0.26737101], b = 0.007281446816055766

将分界面和支持向量可视化,可见SVM确实找到了最优分类面。

附程序完整代码。

import numpy as np
from matplotlib import pyplot as plt
import cvxpy as cp

def load_data():
    mean1 = [-3, -3]
    sigma1 = [[2, -1], [-1, 2]]
    mean2 = [3, 3]
    sigma2 = [[2, -1], [-1, 2]]
    X1 = np.random.multivariate_normal(mean1, sigma1, 100)
    X2 = np.random.multivariate_normal(mean2, sigma2, 100)
    X = np.vstack((X1, X2))
    y = np.hstack((np.ones(100), -np.ones(100)))
    return X, y

def fit(X, y):
    x = X.copy()
    x[y == -1, :] = -x[y == -1, :]
    n = x.shape[0]
    alpha = cp.Variable(n)
    objective = cp.Minimize(0.5 * cp.sum_squares(x.T @ alpha) - np.ones(n) @ alpha)
    constraint = [alpha >= 0, y @ alpha == 0]
    prob = cp.Problem(objective, constraint)
    prob.solve(solver='CVXOPT')
    print(f"dual variable = {alpha.value}")
    w = x.T @ alpha.value
    index = np.where(abs(alpha.value) > 1e-3)[0]
    print(f"support vector index = {index}")
    b = np.mean(y[index] - X[index, :] @ w)
    return w, b, index
    
if __name__=="__main__":
    X, y = load_data()
    w, b, index = fit(X, y)
    print(f"w = {w}, b = {b}")
    plt.scatter(X[y == 1, 0], X[y == 1, 1])
    plt.scatter(X[y == -1, 0], X[y == -1, 1])
    plt.scatter(X[index, 0], X[index, 1], marker='*', s=100)
    plt.plot((-4, 4), ((-b + 4 * w[0]) / w[1], (-b - 4 * w[0]) / w[1]))
    plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1298534.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

排序算法---冒泡排序

1. 原理 对数组进行遍历,每次对相邻的两个元素进行比较,如果大的在前面,则交换两个元素的位置,完成一趟遍历后,数组中最大的数值到了数组的末尾。再对前面n-1个数值进行相同的遍历。一共完成n-1趟遍历就实现了排序。 1…

分享下我发现的16个AI辅助编程的网站

这些工具和服务覆盖了多个方面,包括编程辅助、代码生成、问题解决、Git指令辅助、代码安全扫描等,为开发者提供了丰富的选择。 Codegeex (codegeex.cn/zh-CN): 类型:AI编程助手支持语言:Python, C/C, Java, Go, JavaScript等特点…

docke网络之bridge、host、none

一、bridge网络 1.创建一个测试容器 [rootlocalhost ~]# docker run -d -it --name busybox_1 busybox /bin/sh -c "while true;do sleep 3600;done" 03b308c847edd23f21ba69afb825d92f7aaeb05b1ff4431dd47ccee439a0361a 2.查看当前机器docker有哪些网络 [rootlocal…

获取类class对象的方式

一、什么是class对象 Class类位于java核心包lang包中,它是反射的源头。Class对象用于记录每个类的运行时数据结构,或者说是在内存中访问类的静态数据的接口,每个类都有一个唯一的Class对象。Class对象不能直接通过new来获取,因为…

空中消防员:无人机森林防火应用全面升级

森林是生态系统的重要组成部分,也是人类得以生存的关键。我国森林面积广大,存在火灾频发的困境。提升森林火灾防控能力是维护生态平衡、保护资源和保障人民生命安全的必要步骤。随着无人机技术的发展,其在无人机森林防火中的应用为传统巡查工…

temu发货单在哪里打印

在Temu平台上,打印发货单是进行订单发货的重要步骤之一。通过打印发货单,您可以方便地记录订单信息并与物流公司进行配合。以下是在Temu平台上打印发货单的详细步骤和注意事项。 先给大家推荐一款拼多多/temu运营工具——多多情报通 多多情报通是拼多多…

Chart 8 内核优化

文章目录 前言8.1 内核融合和拆分8.2 编译选项8.3 Conformant(规范) vs. fast vs. native math functions8.4 Loop unrolling8.5 避免分支发散8.6 Handle image boundaries8.7 Avoid the use of size_t8.8 通用 vs. 具名内存地址空间8.9 Subgroup8.10 Us…

C++ 哈希表实现

目录 前言 一、什么是哈希表 二、直接定值法 三、开放定值法(闭散列) 1.开放定制法定义 2.开放定制法实现 2.1类的参数 2.2类的构造 2.3查找实现 2.4插入实现 2.5删除实现 2.6string做key 四、哈希桶(开散列) 1.开散…

让老板成为数据分析师,我用 ChatGpt 链接本地数据源实战测试

本文探究 ChatGpt 等AI机器人能否帮助老板快速的做数据分析?用自然语言同老板进行沟通,满足老板的所有数据分析的诉求? 一、背景 设想这样一个场景:你是某贸易公司的老板,公司所有的日常运转数据都在私域的进销存系统…

tqdm详细教程,实现tqdm进度条完美设计;解决进度条多行一直刷新的问题;如何使得滚动条不上下滚动(保持一行内滚动)

一、tqdm简介 tqdm是一个python进度条库,可以在 Python长循环中添加一个进度提示信息。 二、3种使用方法 1.tqdm(range)-自动更新 import time from tqdm import range# 自动更新 for i in tqdm(range(10)): # 共可以更新10次进度条time. Sleep(0.5) # 每次更新间…

nginx多端口部署

1.配置nginx.conf文件 有几个端口需要部署就写几个server,我这里只部署了两个端口分别为80和81端口,所以有两个server文件。80端口项目入口在根目录的test文件中,81端口项目入口在根目录的test1文件夹中。 2.准备项目文件html文件 在/test1…

2023年终总结-轻舟已过万重山

自我介绍 高考大省的读书人 白,陇西布衣,流落楚、汉。-与韩荆州书 我来自孔孟故里山东济宁,也许是小学时的某一天,我第一次接触到了电脑,从此对它产生了强烈的兴趣,高中我有一个愿望:成为一名计…

【漏洞复现】华脉智联指挥调度平台/xml_edit/fileread.php文件读取漏洞

Nx01 产品简介 深圳市华脉智联科技有限公司,融合通信系统将公网集群系统、专网宽带集群系统、不同制式、不同频段的短波/超短波对讲、模拟/数字集群系统、办公电话系统、广播系统、集群单兵视频、视频监控系统、视频会议系统等融为一体,集成了专业的有线…

【力扣】移除链表元素203

目录 1.前言2. 题目描述3. 题目分析3.1 不带哨兵位3.2 带哨兵位 4. 附代码4.1 不带哨兵位4.2 带哨兵位 1.前言 这里开始介绍从网上一些刷题网站上的题目,在这里做一些分享,和学习记录。 先来介绍一些力扣的OJ题目。 这里的OJ就是我们不需要写主函数&…

SpringBoot的监控(Actuator) 功能

目录 0、官方文档 一、引入依赖 二、application.yml文件中开启监控 三、具体使用 四、具体细节使用 五、端点开启与禁用 六、定制Endpoint 1. 定制 /actuator/health 2. 定制 /actuator/info (1)直接在配置文件中写死 (2&#xff…

【2023传智杯-新增场次】第六届传智杯程序设计挑战赛AB组-ABC题复盘解题分析详解【JavaPythonC++解题笔记】

本文仅为【2023传智杯-第二场】第六届传智杯程序设计挑战赛-题目解题分析详解的解题个人笔记,个人解题分析记录。 本文包含:第六届传智杯程序设计挑战赛题目、解题思路分析、解题代码、解题代码详解 文章目录 一.前言二.赛题题目A题题目-B题题目-C题题目-二.赛题题解A题题解-…

内存学习——堆(heap)

目录 一、概念二、自定义malloc函数三、Debug运行四、heap_4简单分析4.1 heap管理链表结构体4.2 堆初始化4.3 malloc使用4.4 free使用 一、概念 内存分为堆和栈两部分: 栈(Stack)是一种后进先出(LIFO)的数据结构&…

STM32-GPIO编程

一、GPIO 1.1 基本概念 GPIO(General-purpose input/output)通用输入输出接口 --GP 通用 --I input输入 --o output输出 通用输入输出接口GPIO是嵌入式系统、单片机开发过程中最常用的接口,用户可以通过编程灵活的对接口进行控制,…

MATLAB离线附加功能包下载与安装教程

MATLAB离线附加功能包下载与安装教程 本文介绍如何下载与安装MATLAB离线附加功能包,便于大家更加高效的使用MATLAB。 目录 MATLAB离线附加功能包下载与安装教程一、下载1. 获取MATLAB试用版账号2. 使用MATLAB Online搜索所需要的资源包3. 下载所需要的资源包二、安装由于不是…

【QED】井字棋

目录 题目背景题目描述输入格式输出格式测试样例 思路核心代码 题目背景 井字棋,英文名叫Tic-Tac-Toe,是一种在 3 3 3 \times 3 33格子上进行的连珠游戏,和五子棋类似。游戏时,由分别代表O和X的两名玩家轮流在棋盘格子里留下棋子…