使用Python从零实现多分类SVM

news2024/12/27 13:06:16

本文将首先简要概述支持向量机及其训练和推理方程,然后将其转换为代码以开发支持向量机模型。之后然后将其扩展成多分类的场景,并通过使用Sci-kit Learn测试我们的模型来结束。

SVM概述

支持向量机的目标是拟合获得最大边缘的超平面(两个类中最近点的距离)。可以直观地表明,这样的超平面(A)比没有最大化边际的超平面(B)具有更好的泛化特性和对噪声的鲁棒性。

为了实现这一点,SVM通过求解以下优化问题找到超平面的W和b:

它试图找到W,b,使最近点的距离最大化,并正确分类所有内容(如y取±1的约束)。这可以被证明相当于以下优化问题:

可以写出等价的对偶优化问题

这个问题的解决方案产生了一个拉格朗日乘数,我们假设数据集中的每个点的大小为m:(α 1, α 2,…,α _n)。目标函数在α中明显是二次的,约束是线性的,这意味着它可以很容易地用二次规划求解。一旦找到解,由对偶的推导可知:

注意,只有具有α>0的点才定义超平面(对和有贡献)。这些被称为支持向量。因此当给定一个新例子x时,返回其预测y=±1的预测方程为:

这种支持向量机的基本形式被称为硬边界支持向量机(hard margin SVM),因为它解决的优化问题(如上所述)强制要求训练中的所有点必须被正确分类。但在实际场景中,可能存在一些噪声,阻止或限制了完美分离数据的超平面,在这种情况下,优化问题将不返回或返回一个糟糕的解决方案。

软边界支持向量机(soft margin SVM)通过引入C常数(用户给定的超参数)来适应优化问题,该常数控制它应该有多“硬”。特别地,它将原优化问题修改为:

它允许每个点产生一些错误λ(例如,在超平面的错误一侧),并且通过将它们在目标函数中的总和加权C来减少它们。当C趋于无穷时(一般情况下肯定不会),它就等于硬边界。与此同时,较小的C将允许更多的“违规行为”(以换取更大的支持;例如,更小的w (w)。

可以证明,等价对偶问题只有在约束每个点的α≤C时才会发生变化。

由于允许违例,支持向量(带有α>0的点)不再都在边界的边缘。任何错误的支持向量都具有α=C,而非支持向量(α=0)不能发生错误。我们称潜在错误(α=C)的支持向量为“非错误编剧支持向量”和其他纯粹的支持向量(没有违规;“边界支持向量”(0<α<C)。

这样推理方程不变:

现在(xₛ,yₛ)必须是一个没有违规的支持向量,因为方程假设它在边界的边缘。

软边界支持向量机扩展了硬边界支持向量机来处理噪声,但通常由于噪声以外的因素,例如自然非线性,数据不能被超平面分离。软边界支持向量机可以用于这样的情况,但是最优解决方案的超平面,它允许的误差远远超过现实中可以容忍的误差。

例如,在左边的例子中,无论C的设置如何,软边界支持向量机都找不到线性超平面。但是可以通过某种转换函数z=Φ(x)将数据集中的每个点x映射到更高的维度,从而使数据在新的高维空间中更加线性(或完全线性)。这相当于用z替换x得到:

在现实中,特别是当Φ转换为非常高维的空间时,计算z可能需要很长时间。所以就出现了核函数。它用一个数学函数(称为核函数)的等效计算来取代z,并且更快(例如,对z进行代数简化)。例如,这里有一些流行的核函数(每个都对应于一些转换Φ到更高维度空间):

这样,对偶优化问题就变成:

直观地,推理方程(经过代数处理后)为:

上面所有方程的完整推导,有很多相关的文章了,我们就不详细介绍了。

Python实现

对于实现,我们将使用下面这些库:

 import numpy as np                  # for basic operations over arrays
 from scipy.spatial import distance  # to compute the Gaussian kernel
 import cvxopt                       # to solve the dual opt. problem
 import copy                         # to copy numpy arrays

定义核和SVM超参数,我们将实现常见的三个核函数:

 class SVM:
     linear = lambda x, xࠤ , c=0: x @ xࠤ.T
     polynomial = lambda x, xࠤ , Q=5: (1 + x @ xࠤ.T)**Q
     rbf = lambda x, xࠤ, γ=10: np.exp(-γ*distance.cdist(x, xࠤ,'sqeuclidean'))
     kernel_funs = {'linear': linear, 'polynomial': polynomial, 'rbf': rbf}

为了与其他核保持一致,线性核采用了一个额外的无用的超参数。kernel_funs接受核函数名称的字符串,并返回相应的内核函数。

继续定义构造函数:

 class SVM:
     linear = lambda x, xࠤ , c=0: x @ xࠤ.T
     polynomial = lambda x, xࠤ , Q=5: (1 + x @ xࠤ.T)**Q
     rbf = lambda x, xࠤ, γ=10: np.exp(-γ*distance.cdist(x, xࠤ,'sqeuclidean'))
     kernel_funs = {'linear': linear, 'polynomial': polynomial, 'rbf': rbf}
     
     def __init__(self, kernel='rbf', C=1, k=2):
         # set the hyperparameters
         self.kernel_str = kernel
         self.kernel = SVM.kernel_funs[kernel]
         self.C = C                  # regularization parameter
         self.k = k                  # kernel parameter
         
         # training data and support vectors (set later)
         self.X, y = None, None
         self.αs = None
         
         # for multi-class classification (set later)
         self.multiclass = False
         self.clfs = []              

SVM有三个主要的超参数,核(我们存储给定的字符串和相应的核函数),正则化参数C和核超参数(传递给核函数);它表示多项式核的Q和RBF核的γ。

为了兼容sklearn的形式,我们需要使用fit和predict函数来扩展这个类,定义以下函数,并在稍后将其用作装饰器:

 SVMClass = lambda func: setattr(SVM, func.__name__, func) or func

拟合SVM对应于通过求解对偶优化问题找到每个点的支持向量α:

设α为可变列向量(α₁α₂…α _n);y为标签(y₁α₂…y_N)常数列向量;K为常数矩阵,其中K[n,m]计算核在(x, x)处的值。点积、外积和二次型分别基于索引的等价表达式:

可以将对偶优化问题写成矩阵形式如下:

这是一个二次规划,CVXOPT的文档中解释如下:

可以只使用(P,q)或(P,q,G,h)或(P,q,G,h, A, b)等等来调用它(任何未给出的都将由默认值设置,例如1)。

对于(P, q, G, h, A, b)的值,我们的例子可以做以下比较:

为了便于比较,将第一个重写如下:

现在很明显(0≤α等价于-α≤0):

我们就可以写出如下的fit函数:

 @SVMClass
 def fit(self, X, y, eval_train=False):
     # if more than two unique labels, call the multiclass version
     if len(np.unique(y)) > 2:
         self.multiclass = True
         return self.multi_fit(X, y, eval_train)
     
     # if labels given in {0,1} change it to {-1,1}
     if set(np.unique(y)) == {0, 1}: y[y == 0] = -1
 
     # ensure y is a Nx1 column vector (needed by CVXOPT)
     self.y = y.reshape(-1, 1).astype(np.double) # Has to be a column vector
     self.X = X
     N = X.shape[0]  # Number of points
     
     # compute the kernel over all possible pairs of (x, x') in the data
     # by Numpy's vectorization this yields the matrix K
     self.K = self.kernel(X, X, self.k)
     
     ### Set up optimization parameters
     # For 1/2 x^T P x + q^T x
     P = cvxopt.matrix(self.y @ self.y.T * self.K)
     q = cvxopt.matrix(-np.ones((N, 1)))
     
     # For Ax = b
     A = cvxopt.matrix(self.y.T)
     b = cvxopt.matrix(np.zeros(1))
 
     # For Gx <= h
     G = cvxopt.matrix(np.vstack((-np.identity(N),
                                  np.identity(N))))
     h = cvxopt.matrix(np.vstack((np.zeros((N,1)),
                                  np.ones((N,1)) * self.C)))
 
     # Solve    
     cvxopt.solvers.options['show_progress'] = False
     sol = cvxopt.solvers.qp(P, q, G, h, A, b)
     self.αs = np.array(sol["x"])            # our solution
         
     # a Boolean array that flags points which are support vectors
     self.is_sv = ((self.αs-1e-3 > 0)&(self.αs <= self.C)).squeeze()
     # an index of some margin support vector
     self.margin_sv = np.argmax((0 < self.αs-1e-3)&(self.αs < self.C-1e-3))
     
     if eval_train:  
       print(f"Finished training with accuracy{self.evaluate(X, y)}")

我们确保这是一个二进制问题,并且二进制标签按照支持向量机(±1)的假设设置,并且y是一个维数为(N,1)的列向量。然后求解求解(α₁α₂…α _n) 的优化问题。

使用(α₁α₂…α _n) _来获得在与支持向量对应的任何索引处为1的标志数组,然后可以通过仅对支持向量和(xₛ,yₛ)的边界支持向量的索引求和来应用预测方程。我们确实假设非支持向量可能不完全具有α=0,如果它的α≤1e-3,那么这是近似为零(CVXOPT结果可能不是最终精确的)。同样假设非边际支持向量可能不完全具有α=C。

下面就是预测的方法,预测方程为:

 @SVMClass
 def predict(self, X_t):
     if self.multiclass: return self.multi_predict(X_t)
     # compute (xₛ, yₛ)
     xₛ, yₛ = self.X[self.margin_sv, np.newaxis], self.y[self.margin_sv]
     # find support vectors
     αs, y, X= self.αs[self.is_sv], self.y[self.is_sv], self.X[self.is_sv]
     # compute the second term
     b = yₛ - np.sum(αs * y * self.kernel(X, xₛ, self.k), axis=0)
     # compute the score
     score = np.sum(αs * y * self.kernel(X, X_t, self.k), axis=0) + b
     return np.sign(score).astype(int), score

我们还可以实现一个评估方法来计算精度(在上面的fit中使用)。

 @SVMClass
 def evaluate(self, X,y):  
     outputs, _ = self.predict(X)
     accuracy = np.sum(outputs == y) / len(y)
     return round(accuracy, 2)

最后测试我们的完整代码:

 from sklearn.datasets import make_classification
 import numpy as np
 
 # Load the dataset
 np.random.seed(1)
 X, y = make_classification(n_samples=2500, n_features=5, 
                            n_redundant=0, n_informative=5, 
                            n_classes=2,  class_sep=0.3)
 
 # Test Implemented SVM
 svm = SVM(kernel='rbf', k=1)
 svm.fit(X, y, eval_train=True)
 
 y_pred, _ = svm.predict(X)
 print(f"Accuracy: {np.sum(y==y_pred)/y.shape[0]}")  #0.9108
 
 # Test with Scikit
 from sklearn.svm import SVC
 clf = SVC(kernel='rbf', C=1, gamma=1)
 clf.fit(X, y)
 y_pred = clf.predict(X)
 print(f"Accuracy: {sum(y==y_pred)/y.shape[0]}")    #0.9108

多分类SVM

我们都知道SVM的目标是二元分类,如果要将模型推广到多类则需要为每个类训练一个二元SVM分类器,然后对每个类进行循环,并将属于它的点重新标记为+1,并将所有其他类的点重新标记为-1。

当给定k个类时,训练的结果是k个分类器,其中第i个分类器在数据上进行训练,第i个分类器被标记为+1,所有其他分类器被标记为-1。

 @SVMClass
 def multi_fit(self, X, y, eval_train=False):
     self.k = len(np.unique(y))      # number of classes
     # for each pair of classes
     for i in range(self.k):
         # get the data for the pair
         Xs, Ys = X, copy.copy(y)
         # change the labels to -1 and 1
         Ys[Ys!=i], Ys[Ys==i] = -1, +1
         # fit the classifier
         clf = SVM(kernel=self.kernel_str, C=self.C, k=self.k)
         clf.fit(Xs, Ys)
         # save the classifier
         self.clfs.append(clf)
     if eval_train:  
         print(f"Finished training with accuracy {self.evaluate(X, y)}")

然后,为了对新示例执行预测,我们选择相应分类器最自信(得分最高)的类。

 @SVMClass
 def multi_predict(self, X):
     # get the predictions from all classifiers
     N = X.shape[0]
     preds = np.zeros((N, self.k))
     for i, clf in enumerate(self.clfs):
         _, preds[:, i] = clf.predict(X)
     
     # get the argmax and the corresponding score
     return np.argmax(preds, axis=1), np.max(preds, axis=1)

完整测试代码:

 from sklearn.datasets import make_classification
 import numpy as np
 
 # Load the dataset
 np.random.seed(1)
 X, y = make_classification(n_samples=500, n_features=2, 
                            n_redundant=0, n_informative=2, 
                            n_classes=4, n_clusters_per_class=1,  
                            class_sep=0.3)
 
 # Test SVM
 svm = SVM(kernel='rbf', k=4)
 svm.fit(X, y, eval_train=True)
 
 y_pred = svm.predict(X)
 print(f"Accuracy: {np.sum(y==y_pred)/y.shape[0]}") # 0.65
 
 # Test with Scikit
 from sklearn.multiclass import OneVsRestClassifier
 from sklearn.svm import SVC
 
 clf = OneVsRestClassifier(SVC(kernel='rbf', C=1, gamma=4)).fit(X, y)
 y_pred = clf.predict(X)
 print(f"Accuracy: {sum(y==y_pred)/y.shape[0]}")    # 0.65

绘制每个决策区域的图示,得到以下图:

可以看到,我们的实现与Sci-kit Learn结果相当,说明在算法实现上没有问题。注意:SVM默认支持OVR(没有如上所示的显式调用),它是特定于SVM的进一步优化。

总结

我们使用Python实现了支持向量机(SVM)学习算法,并且包括了软边界和常用的三个核函数。我们还将SVM扩展到多分类的场景,并使用Sci-kit Learn验证了我们的实现。希望通过本文你可以更好的了解SVM。

https://avoid.overfit.cn/post/0b2410e6737a4911be507ca29cb3136c

作者:Essam Wisam

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1179932.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

我哭了,终于找到了合适的程序员接单平台!

说起我接单这条道路可真是艰难又漫长。 为什么说它艰难呢&#xff1f; 因为我总是被骗。 第1次接单的时候&#xff0c;由于是熟人&#xff0c;所以没好意思狠下心要价&#xff0c;结果辛辛苦苦搞了半个月到口袋的钱还没有我请别人帮我介绍单子的钱多还各种各样的挑剔。第2次我…

开联通支付牌照“易主”

据西米支付网报道&#xff0c;最新消息显示&#xff0c;持牌支付机构开联通支付服务有限公司&#xff08;以下简称“开联通支付”&#xff09;发生了股权出质。该公司已经出质的股权总额达到9000万元&#xff0c;占有公司股权总数的90%。 根据登记编号为91110108565839081K_000…

scrapy案例教程

文章目录 1 scrapy简介2 创建项目3 自定义初始化请求url4 定义item5 定义管道 1 scrapy简介 scrapy常用命令 |命令 | 格式 |说明| |–|–|–| |startproject |scrapy startproject <项目名> |创建一个新项目| |genspider| scrapy genspider <爬虫文件名> <域名…

【Linux】磁盘阵列,了解不同raid的特点

一、raid和阵列卡介绍 1、什么是磁盘阵列&#xff1a; 磁盘阵列是利用虚拟化存储技术把很多块独立的磁盘组合成一个容量巨大的磁盘组&#xff0c;利用个别磁盘提供数据所产生加成效果提升整个磁盘系统效能。利用这项技术&#xff0c;将数据切割成许多区段&#xff0c;分别存放…

CRM系统如何帮助无损检测设备企业发展?

得益于新兴行业的高速发展&#xff0c;近些年无损检测设备在国内市场得到了规模增长。通过搭建完整的CRM客户管理系统&#xff0c;打通营销、销售及服务各环节&#xff0c;进一步提高企业市场竞争力。CRM系统如何帮助无损检测设备企业发展&#xff1f; 无损检测设备企业无论在…

CSS 边框、轮廓线

一、CSS边框&#xff1a; CSS边框属性允许指定一个元素边框的样式和颜色。 1&#xff09;、边框样式&#xff1a;border-style属性用来定义边框的样式&#xff0c;border-style值&#xff1a; 2&#xff09;、边框宽度&#xff1a;border-width属性用于指定边框宽度。指定变宽…

TCP编程及基础知识

一、端口号 为了区分一台主机接收到的数据包应该转交给哪个进程来进行处理&#xff0c;使用端口号来区分TCP端口号与UDP端口号独立端口用两个字节来表示 2byte&#xff08;65535个&#xff09; 众所周知端口&#xff1a;1~1023&#xff08;1~255之间为众所周知端口&#xff…

软件测试/测试开发丨Python安装指南(macOS)

点此获取更多相关资料 下载 Python 解释器 下载地址: https://www.Python.org/downloads/macos 通过下载页面&#xff0c;可以在该页面上看到下载链接。 下载完成后会得到 Python-3.10.11-macos11.pkg安装文件 。 安装 Python 解释器 双击Python-3.10.11-macos11.pkg文件&a…

Vue3指令

Vue 指令&#xff08;Directives&#xff09;是 Vue.js 的一项核心功能&#xff0c;它们可以在 HTML 模板中以 v- 开头的特殊属性形式使用&#xff0c;用于将响应式数据绑定到 DOM 元素上或在 DOM 元素上进行一些操作。 Vue 指令是带有前缀 v- 的特殊 HTML 属性&#xff0c;它赋…

Linux操作系统中软件安装:用RPM包管理器安装软件步骤

安装软件的一般步骤如下&#xff1a; 1.打开终端&#xff0c;作为root用户或使用sudo命令获取管理员权限。 2.使用RPM命令进行软件包的安装。例如&#xff0c;使用“rpm -ivh 软件包名称.rpm”命令来安装软件包&#xff0c;其中“-i”表示安装&#xff0c;“-v”表示显示详细安…

【入门Flink】- 07Flink DataStream API【万字篇】

DataStream API 是 Flink 的核心层 API。一个 Flink 程序&#xff0c;其实就是对DataStream的各种转换。 代码基本上都由以下几部分构成&#xff1a; 执行环境&#xff08;Execution Environment&#xff09; 1&#xff09;创建执行环境StreamExecutionEnvironment StreamExe…

【启扬方案】基于RK3568核心板的激光打标机应用解决方案

激光打标机是一种利用激光技术进行标记和刻字的设备&#xff0c;作为激光技术应用的一个细分领域&#xff0c;是最早引入工业市场的一类激光装备&#xff0c;它采用激光束在工件表面进行刻印、打标&#xff0c;常用于工业生产中的物料标识、产品追溯、防伪标记等应用&#xff0…

centos7安装mysql-阿里云服务器

1.背景 2.安装 2.1.下载安装包 wget https://dev.mysql.com/get/mysql57-community-release-el7-8.noarch.rpm2.2.安装mysql rpm -ivh mysql57-community-release-el7-8.noarch.rpm 3.安装mysql服务 3.1.进入目录 首先进入cd /etc/yum.repos.d/目录 cd /etc/yum.repos.d/ 3.…

Netty 高性能原因之一 采用了高性能的NIO 模式

java IO简介 I/O 全称Input/Output&#xff0c;即输入/输出&#xff0c;通常指数据在内部存储器和外部存储器或其他周边设备之间的输入/输出。 涉及 I/O 的操作&#xff0c;不仅仅局限于硬件设备的读写&#xff0c;还要网络数据的传输。无论是从磁盘中读写文件&#xff0c;还…

【广州华锐互动】VR综合布线虚拟实验教学系统

随着科技的不断发展&#xff0c;虚拟现实&#xff08;VR&#xff09;技术已经逐渐渗透到各个领域&#xff0c;为人们的生活和工作带来了前所未有的便利。在建筑行业中&#xff0c;VR技术的应用也日益广泛&#xff0c;尤其是在综合布线方面。 广州华锐互动开发的VR综合布线虚拟实…

百度上线“文心一言”付费版本,AI聊天机器人市场竞争加剧

原创 | 文 BFT机器人 百度不愧是我国AI技术领域的先行者&#xff0c;每年致力于人工智能领域取得技术产品的突破和创新。据爆料称&#xff0c;百度的文心一言有突破了新境界&#xff0c;开创了文心大模型4.0会员版本。从线上的to C产品到试水商业化&#xff0c;百度都是争先走…

Python的requests库爬取商城优惠券

首先&#xff0c;我们需要了解要抓取的网页的结构和数据格式。在这个例子中&#xff0c;我们使用Python的requests库来发送HTTP请求&#xff0c;并使用BeautifulSoup库来解析HTML内容。 import requests from bs4 import BeautifulSoup然后&#xff0c;我们需要使用requests库的…

LeetCode | 160. 相交链表

LeetCode | 160. 相交链表 O链接 我们这里有两个问题&#xff0c;一是判断是否相交&#xff0c;二是找交点 思路一&#xff1a; 暴力求解 A链表所有节点依次取B链表找一遍&#xff08;时间复杂度是O(N^2)&#xff09; struct ListNode *getIntersectionNode(struct ListNod…

QT not in executable format:file truncated

今天在调研串口打印机的时候出现的&#xff0c;串口打印机有sdk&#xff0c;自己qt的编辑器用的 MinGW 64&#xff0c;编译出现次错误 出现这个错误&#xff0c;主要是sdk和编译器的版本位数不一致。 修改方法&#xff1a;把MinGW64 改为MinGW32&#xff0c;不过这个根据使用的…

为什么说制造企业需要部署MES管理系统

在数字化浪潮席卷的今天&#xff0c;每个企业都期望通过新技术、新模式来优化自身的运营。这其中&#xff0c;MES管理系统成为了不少企业的首选。那么&#xff0c;为何企业需要部署MES管理系统&#xff1f;又该如何搭建MES管理系统呢&#xff1f; 一、企业缘何钟情于MES系统&am…