目录
前言
一、完整代码
二、输出结果
三、实现步骤解析
1.读取数据
2.创建模型并训练
3.可视化SVM结果
总结
前言
支持向量机(SVM,Support Vector Machine)是一种用于分类和回归的监督学习算法。它的核心思想是通过在特征空间中找到一个最佳的分隔超平面来将数据分成不同的类别。
一、完整代码
import pandas as pd
# 读取数据
data = pd.read_csv('iris.csv', header=None)
"""
使用SVM进行训练
"""
from sklearn.svm import SVC # SVC做分类 SVR做回归
# 获取特征和标签
x = data.iloc[:, [1, 3]]
y = data.iloc[:, -1]
svm = SVC(kernel='linear', C=10, random_state=0) # C=float('inf')将软间隔的惩罚设置为无穷大
svm.fit(x, y)
"""
可视化SVM结果
"""
# 参数w[原始数据为二维数组]
w = svm.coef_[0]
# 偏置项b[原始数据为一维数组]
b = svm.intercept_[0]
# 超平面方程:w1x1+w2x2+b=0
# ->>x2 = -(w1x1+b)/w2
import numpy as np
x1 = np.linspace(0, 7, 300) # 在0-7内生成300个数据
# 超平面方程
x2 = -(w[0] * x1 + b) / w[1]
# 上超平面方程
x3 = (1 - (w[0] * x1 + b)) / w[1]
# 下超平面方程
x4 = (-1 - (w[0] * x1 + b)) / w[1]
import matplotlib.pyplot as plt
data1 = data.iloc[:50, :]
data2 = data.iloc[50:, :]
# 原数据为四维 无法展示 这里选择两个特征进行二维展示
plt.scatter(data1[1], data1[3], marker='+')
plt.scatter(data2[1], data2[3], marker='o')
# plt.show()
# 可视化超平面
plt.plot(x1, x2, linewidth=2, color='r')
plt.plot(x1, x3, linewidth=1, color='r', linestyle='--')
plt.plot(x1, x4, linewidth=1, color='r', linestyle='--')
# 进行坐标轴限制
plt.xlim(4, 7)
plt.ylim(0, 6)
# 可视化支持向量
vets = svm.support_vectors_
plt.scatter(vets[:, 0], vets[:, 1], c='b', marker='x')
plt.show()
二、输出结果
三、实现步骤解析
1.读取数据
- 这里使用的是鸢尾花的数据
import pandas as pd
# 读取数据
data = pd.read_csv('iris.csv', header=None)
2.创建模型并训练
- svm模型里的C参数可以用来控制惩罚力度进而控制软间隔的程度
- C越大,惩罚越严格,软间隔程度越小,越准确,但也越容易过拟合
- C越小,惩罚越不严格,软间隔程度越大,越不准确,但也越不容易过拟合
"""
使用SVM进行训练
"""
from sklearn.svm import SVC # SVC做分类 SVR做回归
# 获取特征和标签
x = data.iloc[:, [1, 3]]
y = data.iloc[:, -1]
svm = SVC(kernel='linear', C=10, random_state=0) # C=float('inf')将软间隔的惩罚设置为无穷大
svm.fit(x, y)
3.可视化SVM结果
- 获取svm模型里返回的系数和截距
- 再通过系数和截距求出各直线方程
- 最后进行二维的展示
"""
可视化SVM结果
"""
# 参数w[原始数据为二维数组]
w = svm.coef_[0]
# 偏置项b[原始数据为一维数组]
b = svm.intercept_[0]
import numpy as np
x1 = np.linspace(0, 7, 300) # 在0-7内生成300个数据
# 超平面方程
x2 = -(w[0] * x1 + b) / w[1]
# 上超平面方程
x3 = (1 - (w[0] * x1 + b)) / w[1]
# 下超平面方程
x4 = (-1 - (w[0] * x1 + b)) / w[1]
import matplotlib.pyplot as plt
data1 = data.iloc[:50, :]
data2 = data.iloc[50:, :]
# 原数据为四维 无法展示 这里选择两个特征进行二维展示
plt.scatter(data1[1], data1[3], marker='+')
plt.scatter(data2[1], data2[3], marker='o')
# plt.show()
# 可视化超平面
plt.plot(x1, x2, linewidth=2, color='r')
plt.plot(x1, x3, linewidth=1, color='r', linestyle='--')
plt.plot(x1, x4, linewidth=1, color='r', linestyle='--')
# 进行坐标轴限制
plt.xlim(4, 7)
plt.ylim(0, 6)
# 可视化支持向量
vets = svm.support_vectors_
plt.scatter(vets[:, 0], vets[:, 1], c='b', marker='x')
plt.show()
总结
总的来说,SVM可以使用核函数处理非线性问题,通过将数据映射到更高维的空间。正则化参数C控制分类准确性与模型复杂度之间的平衡。SVM广泛应用于文本分类、图像识别、生物信息学和金融预测等领域。