支持向量机SVM代码详解——多分类/降维可视化/参数优化【python】

news2024/9/23 11:48:12

篇1:SVM原理及多分类python代码实例讲解(鸢尾花数据)

SVM原理

支持向量机(Support Vector Machine,SVM),主要用于小样本下的二分类、多分类以及回归分析,是一种有监督学习的算法。基本思想是寻找一个超平面来对样本进行分割,把样本中的正例和反例用超平面分开,其原则是使正例和反例之间的间隔最大。

SVM学习的基本想法是求解能够正确划分训练数据集并且几何间隔最大的分离超平面。如下图所示,wx+b=0即为分离超平面,对于线性可分的数据集来说,这样的超平面有无穷多个(即感知机),但是几何间隔最大的分离超平面却是唯一的。

SVM实现分类代码 

1.数据集介绍——鸢尾花数据集

下载方式:通过UCI Machine Learning Repository下载或者直接使用代码

from sklearn.datasets import load_iris

数据展示与介绍(iris.data)

Iris.data中有5个属性,包括4个预测属性(萼片长度、萼片宽度、花瓣长度、花瓣宽度)和1个类别属性(Iris-setosa、Iris-versicolor、Iris-virginica三种类别)。首先,需要将第五列类别信息转换为数字,再选择输入数据和标签。

 2.多分类python代码(二分类可看做只有两类的多分类)

from sklearn import svm #引入svm包
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
from sklearn.model_selection import train_test_split
#定义字典,将字符与数字对应起来
def Iris_label(s):
    it={b'Iris-setosa':0, b'Iris-versicolor':1, b'Iris-virginica':2}
    return it[s]
#读取数据,利用np.loadtxt()读取text中的数据
path='iris.data'              #将下载的原始数据放到项目文件夹,即可不用写路径
data= np.loadtxt(path, dtype=float, delimiter=',', converters={4:Iris_label}) #分隔符为‘,'
#确定输入和输出
x,y=np.split(data,(4,),axis=1)   #将data按前4列返回给x作为输入,最后1列给y作为标签值
x=x[:,0:2]              #取x的前2列作为svm的输入,为了便于可视化展示
#划分数据集和标签:利用sklearn中的train_test_split对原始数据集进行划分,本实验中训练集和测试集的比例为7:3。
train_data,test_data,train_label,test_label=train_test_split(x,y,random_state=1,train_size=0.7,test_size=0.3)
#创建svm分类器并进行训练:首先,利用sklearn中的SVC()创建分类器对象,其中常用的参数有C(惩罚力度)、kernel(核函数)、gamma(核函数的参数设置)、decision_function_shape(因变量的形式),再利用fit()用训练数据拟合分类器模型。
'''C越大代表惩罚程度越大,越不能容忍有点集交错的问题,但有可能会过拟合(defaul C=1);
kernel常规的有‘linear’, ‘poly’, ‘rbf’, ‘sigmoid’, ‘precomputed’ ,默认的是rbf;
gamma是核函数为‘rbf’, ‘poly’ 和 ‘sigmoid’时的参数设置,其值越小,分类界面越连续,其值越大,分类界面越“散”,分类效果越好,但有可能会过拟合,默认的是特征个数的倒数;
decision_function_shape='ovr'时,为one v rest(一对多),即一个类别与其他类别进行划分,等于'ovo'时,为one v one(一对一),即将类别两两之间进行划分,用二分类的方法模拟多分类的结果。
'''
model=svm.SVC(C=2,kernel='rbf',gamma=10,decision_function_shape='ovo')
model.fit(train_data,train_label.ravel())    #ravel函数在降维时默认是行序优先
#利用classifier.score()分别计算训练集和测试集的准确率。
train_score = model.score(train_data,train_label)
print("训练集:",train_score)
test_score = model.score(test_data,test_label)
print("测试集:",test_score)
#决策函数的查看(可省略)
#print('train_decision_function:\n',model.decision_function(train_data))#(90,3)

输出结果:

分类结果的可视化(由于只选取了2个维度的输入数据,所以可视化可直接进行二维平面作图)

#训练集和测试集的预测结果
trainPredict = (model.predict(train_data).reshape(-1, 1))
testPredict = model.predict(test_data).reshape(-1, 1)
#将预测结果进行展示,首先画出预测点,再画出分类界面
#预测点的画法,可参考https://zhuanlan.zhihu.com/p/81006952
#画图例和点集
x1_min,x1_max=x[:,0].min(),x[:,0].max()   #x轴范围
x2_min,x2_max=x[:,1].min(),x[:,1].max()   #y轴范围
matplotlib.rcParams['font.sans-serif']=['SimHei']   #指定默认字体
cm_dark=matplotlib.colors.ListedColormap(['g','r','b'])  #设置点集颜色格式
cm_light=matplotlib.colors.ListedColormap(['#A0FFA0','#FFA0A0','#A0A0FF'])  #设置边界颜色
plt.xlabel('length',fontsize=13)        #x轴标注
plt.ylabel('width',fontsize=13)        #y轴标注
plt.xlim(x1_min,x1_max)                   #x轴范围
plt.ylim(x2_min,x2_max)                   #y轴范围
plt.title('SVM result')          #标题
plt.scatter(x[:,0],x[:,1],c=y[:,0],s=30,cmap=cm_dark)  #画出测试点
plt.scatter(test_data[:,0],test_data[:,1],c=test_label[:,0],s=30,edgecolors='k',zorder=2,cmap=cm_dark) #画出预测点,并将预测点圈出 
#画分类界面
x1,x2=np.mgrid[x1_min:x1_max:200j,x2_min:x2_max:200j]#生成网络采样点
grid_test=np.stack((x1.flat,x2.flat),axis=1)#测试点
grid_hat=model.predict(grid_test)# 预测分类值
grid_hat=grid_hat.reshape(x1.shape)# 使之与输入的形状相同
plt.pcolormesh(x1,x2,grid_hat,cmap=cm_light)# 预测值的显示
plt.show()

可视化绘图结果:

 根据结果,发现分类的效果并不好。接着,将输入鸢尾花的所有特征进行训练,并将分类结果降维后可视化展示。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/732255.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

腾讯云对象存储的创建和S3 Browser的使用

简述 想想第一次接触对象存储的时候还是很兴奋的,同时也是一脸懵逼;然后开始网上疯狂的找资料,但因为客户当时给的文档写的是关于Amazon S3之类的,所以自以为的就只有Amazon S3这一家,接着开始查资料,经过一…

Spark学习---2、SparkCore(RDD概述、RDD编程(创建、分区规则、转换算子、Action算子))

1、RDD概述 1.1 什么是RDD RDD(Resilient Distributed Dataset)叫弹性分布式数据集,是Spark中对于分布式数据集的抽象。代码中是一个抽象类,它代表一个弹性的、不可变、可分区、里面的元素可并行计算的集合。 1.2 RDD五大特性 1、一组分区&#xff0…

Pyecharts 绘制各种统计图的案例

Pyecharts 绘制各种统计图的案例 基础使用 from pyecharts import options as opts from pyecharts.charts import Bar, Line, Pie, Scatter from pyecharts.faker import Faker# 柱状图示例 def bar_chart():x_data Faker.choose()y_data Faker.values()bar (Bar().add_xa…

simulink实战 建模 简单车辆动力学模型

Gmg Discrete-TimeIntegrator 离散时间积分器

CentOS 7 搭建 Impala 4.1.2 + Kudu 1.15.0 测试环境

安装依赖 这部分不过于详细介绍,如果有现成环境也可以直接拿来使用。 Java 下载 java 安装包,需要登录 oracle,请自行下载。 cd /mnt tar zxvf jdk-8u202-linux-x64.tar.gz配置环境变量到 /etc/bashrc,并执行 source /etc/bas…

关于深度学习图像数据增广

数据增广方法在广义上可以按照产生新数据的方式分为数据变形和数据过采样。由于操作简单,同时数据量上的需求远比现在要低得多,早期对数据增广的应用多是数据变形类方法。对于图像数据,基本的图像变换操作都属于数据变形类增广方法&#xff0…

Jvm参数设置-JVM(八)

上篇文章说了逃逸分析和标量,代码实例解析了内存分配先从eden区域开始,当内存不足的时候,才会进入s0和s1,发生yangGC,之后大内存会放入old,因为我们昨天程序运行了一个45M的对象,于是小对象在ed…

详解------>数组笔试题(必备知识)

目录 本章将通过列题进一步了解sizeof 与strlen的区别,加强对数组的理解。 1:一维数组列题 2:字符数组列题 3:二维数组列题 首先在进行这些习题讲解之前我们需要知道的知识点 sizeof:是一个关键字,可以…

KMP--高效字符串匹配算法(Java)

KMP算法 KMP算法算法介绍代码演示: KMP算法 KMP算法是为了解决这一类问题,给定一个字符串str1,和一个字符串str2,如果str2属于str1d的字串,则返回字串第一个出现位置的下标,不存在返回-1. 注意: 子串是连续的. 举个例子 str1 “abc123abs” str1 长度假设m str2 “123”; str2…

pycharm汉化

安装pycharm 不多说了,直接下载安装即可 汉化 file -setting plugins 输入chinese进行搜索 点击 进行安装,等待安装完成 安装完成需要重启,点击重启,等待重启完成即可 出现上图,说明汉化成功了

【计算机视觉】YOLOv8的测试以及训练过程(含源代码)

文章目录 一、导读二、部署环境三、预测结果3.1 使用检测模型3.2 使用分割模型3.3 使用分类模型3.4 使用pose检测模型 四、COCO val 数据集4.1 在 COCO128 val 上验证 YOLOv8n4.2 在COCO128上训练YOLOv8n 五、自己训练5.1 训练检测模型5.2 训练分割模型5.3 训练分类模型5.4 训练…

Mybatis-xml和动态sql

xml映射方式 除了之前那种 select(语句) public void ...();通过注解定义sql语句&#xff0c;还可以通过xml的方式来定义sql语句 注意 在resource创建的是目录&#xff0c;要用斜线分隔 创建出文件后 先写约束 <?xml version"1.0" encoding"UTF-8"…

第4集丨JavaScript 使用原型(prototype)实现继承——最佳实战2

目录 一、临时构造器方式1.1 代码实现1.2 代码分析 二. 增加uber属性&#xff0c;用于子对象访问父对象2.1 实现分析2.2 代码实现 三. 将继承封装成extend()函数3.1 代码实现3.1.1 临时构造器实现extend()3.1.2 原型复制实现extend2() 3.2 代码测试3.2.1 测试extend()函数3.2.1…

uniapp打包嵌入app,物理返回键的问题

问题描述&#xff1a;将uniapp开发的应用打包成wgt包放入app后&#xff0c;发现手机自带的返回键的点击有问题&#xff0c;比如我从app原生提供的入口进入了uniapp的列表页&#xff0c;然后我又进入了详情页&#xff0c;这时候在详情页点击物理返回键的话&#xff0c;它直接就返…

C语言—最大公约数和最小公倍数

作者主页&#xff1a;paper jie的博客_CSDN博客-C语言,算法详解领域博主 本文作者&#xff1a;大家好&#xff0c;我是paper jie&#xff0c;感谢你阅读本文&#xff0c;欢迎一建三连哦。 本文录入于《算法详解》专栏&#xff0c;本专栏是针对于大学生&#xff0c;编程小白精心…

过河卒

题目描述 棋盘上 A 点有一个过河卒&#xff0c;需要走到目标 B 点。卒行走的规则&#xff1a;可以向下、或者向右。同时在棋盘上 C 点有一个对方的马&#xff0c;该马所在的点和所有跳跃一步可达的点称为对方马的控制点。因此称之为“马拦过河卒”。 棋盘用坐标表示&#xff…

云同步盘 vs 普通网盘:选择哪种更适合你?区别解析与选购指南!

云同步盘是一种基于云存储的在线服务&#xff0c;主要用于将本地文件存储到云端&#xff0c;并通过客户端软件实现文件的自动同步&#xff0c;从而保持本地和云端文件的同步更新。用户可以在任何设备上访问和共享这些文件。 云同步盘和普通云盘都是云存储服务&#xff0c;可以让…

Kubernetes CoreDNS

Kubernetes CoreDNS 1、DNS服务概述 coredns github 地址&#xff1a; https://github.com/kubernetes/kubernetes/blob/master/cluster/addons/dns/coredns/coredns.yaml.base service 发现是 k8s 中的一个重要机制&#xff0c;其基本功能为&#xff1a;在集群内通过服务名…

TL-ER2260T获取SSH密码并登录后台

TL-ER2260T获取SSH密码并登录后台 首先需要打开诊断模式 打开Ubuntu&#xff0c;通过如下指令计算SSH密码&#xff0c;XX-XX-XX-XX-XX-XX是MAC地址echo -n "XX-XX-XX-XX-XX-XX" | tr -d - | tr [a-z] [A-Z] | md5sum | cut -b 1-16SSH登录ssh -oKexAlgorithmsdiffie…

硬件打样和小批量生产

PCB 打样和小批量生产过程 包括PCB 定型、生产文件制作、元器件准备、装配图制作、贴片、全流程测试。 打样一般是 几块PCB 手工进行焊接。 其中生产文件根据加工厂 一般提供PCB或者Gerber。 元器件准备设计公司的物料管理&#xff0c;这里假设已经拿到了所需的物料。 装…