支持向量机 (support vector machine,SVM)
flyfish
支持向量机是一种用于分类和回归的机器学习模型。在分类任务中,SVM试图找到一个最佳的分隔超平面,使得不同类别的数据点在空间中被尽可能宽的间隔分开。
超平面方程和直线方程
超平面(hyperplane)是一个在高维空间中将空间分成两个部分的几何对象。它的方程可以在不同维度的空间中有不同的形式。
一维空间中的“超平面”
在一维空间中,超平面就是一个点。假设我们在一维空间中有一个超平面,它可以表示为:
x
=
a
x = a
x=a
其中,
a
a
a 是某个常数。这表示一维空间中的一个特定点,将空间分成两个部分:
x
<
a
x < a
x<a 和
x
>
a
x > a
x>a。
二维空间中的超平面(直线)
在二维空间中,超平面就是一条直线。直线的方程可以表示为:
y
=
k
x
+
b
y = kx + b
y=kx+b
其中,
k
k
k 是斜率,
b
b
b 是截距。或者,可以表示为标准形式:
a
x
+
b
y
+
c
=
0
ax + by + c = 0
ax+by+c=0
其中,
a
a
a、
b
b
b、
c
c
c 是常数。
这条直线将二维空间分成两个半平面。
三维空间中的超平面(平面)
在三维空间中,超平面是一个平面。平面的方程可以表示为:
a
x
+
b
y
+
c
z
+
d
=
0
ax + by + cz + d = 0
ax+by+cz+d=0
其中,
a
a
a、
b
b
b、
c
c
c 和
d
d
d 是常数。
这个平面将三维空间分成两个半空间。
一般形式的超平面方程
在更高维度的空间中,超平面的方程一般可以表示为:
w
⋅
x
+
b
=
0
\mathbf{w} \cdot \mathbf{x} + b = 0
w⋅x+b=0
其中:
-
w = ( w 1 , w 2 , … , w n ) \mathbf{w} = (w_1, w_2, \ldots, w_n) w=(w1,w2,…,wn) 是一个权重向量,定义了超平面的方向。
-
x = ( x 1 , x 2 , … , x n ) \mathbf{x} = (x_1, x_2, \ldots, x_n) x=(x1,x2,…,xn) 是一个点的坐标向量。
-
b b b 是偏置。
这个超平面将 n n n 维空间分成两个半空间。
直线方程是超平面方程在二维空间中的一种特例。一般来说,超平面是 n n n 维空间中的一个 ( n − 1 ) (n-1) (n−1) 维的对象:
-
在一维空间中,超平面是一个点。
-
在二维空间中,超平面是一个直线。
-
在三维空间中,超平面是一个平面。
-
在四维及更高维空间中,超平面是一个 ( n − 1 ) (n-1) (n−1) 维的对象。
示例和理解
一维空间中的超平面
x
=
2
x = 2
x=2
这是在一维空间中的一个点,将空间分为
x
<
2
x < 2
x<2 和
x
>
2
x > 2
x>2 两部分。
二维空间中的超平面
标准形式:
2
x
+
3
y
−
6
=
0
2x + 3y - 6 = 0
2x+3y−6=0
或者:
y
=
−
2
3
x
+
2
y = -\frac{2}{3}x + 2
y=−32x+2
这是在二维空间中的一条直线。
三维空间中的超平面
2
x
+
3
y
+
4
z
−
5
=
0
2x + 3y + 4z - 5 = 0
2x+3y+4z−5=0
这是在三维空间中的一个平面。
import numpy as np
import matplotlib.pyplot as plt
from sklearn import svm
# 生成一些数据
np.random.seed(0)
X = np.r_[np.random.randn(20, 2) - [2, 2], np.random.randn(20, 2) + [2, 2]]
Y = [0] * 20 + [1] * 20
# 拟合模型
clf = svm.SVC(kernel='linear')
clf.fit(X, Y)
# 绘制数据点和分类超平面
plt.scatter(X[:, 0], X[:, 1], c=Y, cmap=plt.cm.Paired)
ax = plt.gca()
xlim = ax.get_xlim()
ylim = ax.get_ylim()
# 创建网格以评估模型
xx = np.linspace(xlim[0], xlim[1], 30)
yy = np.linspace(ylim[0], ylim[1], 30)
YY, XX = np.meshgrid(yy, xx)
xy = np.vstack([XX.ravel(), YY.ravel()]).T
Z = clf.decision_function(xy).reshape(XX.shape)
# 绘制分类超平面
ax.contour(XX, YY, Z, colors='k', levels=[-1, 0, 1], alpha=0.5, linestyles=['--', '-', '--'])
ax.scatter(clf.support_vectors_[:, 0], clf.support_vectors_[:, 1], s=100, linewidth=1, facecolors='none', edgecolors='k')
plt.show()
import numpy as np
import matplotlib.pyplot as plt
from sklearn import svm
from mpl_toolkits.mplot3d import Axes3D
# 生成三维数据
np.random.seed(0)
X = np.r_[np.random.randn(20, 3) - [2, 2, 2], np.random.randn(20, 3) + [2, 2, 2]]
Y = [0] * 20 + [1] * 20
# 拟合模型
clf = svm.SVC(kernel='linear')
clf.fit(X, Y)
# 创建一个网格来绘制分类平面
xx, yy = np.meshgrid(np.linspace(-5, 5, 50), np.linspace(-5, 5, 50))
zz = (-clf.intercept_[0] - clf.coef_[0][0] * xx - clf.coef_[0][1] * yy) / clf.coef_[0][2]
# 绘制数据点和分类平面
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(X[:20, 0], X[:20, 1], X[:20, 2], color='b', marker='o', label='Class 0')
ax.scatter(X[20:, 0], X[20:, 1], X[20:, 2], color='r', marker='^', label='Class 1')
ax.plot_surface(xx, yy, zz, color='g', alpha=0.5, rstride=100, cstride=100)
ax.set_xlabel('X1')
ax.set_ylabel('X2')
ax.set_zlabel('X3')
plt.legend()
plt.show()
最大间隔解释
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.svm import SVC
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# 生成一个简单的二维分类数据集
X, y = datasets.make_blobs(n_samples=50, centers=2, random_state=6)
# 训练一个线性支持向量机
clf = SVC(kernel='linear', C=1000)
clf.fit(X, y)
# 获取分隔超平面
w = clf.coef_[0]
b = clf.intercept_[0]
# 计算分隔超平面的两个端点
x = np.linspace(-10, 10, 100)
y_hyperplane = -w[0]/w[1] * x - b/w[1]
# 计算间隔边界
margin = 1 / np.sqrt(np.sum(w ** 2))
y_margin_up = y_hyperplane + margin
y_margin_down = y_hyperplane - margin
# 绘制数据点、分隔超平面及其间隔边界
plt.scatter(X[:, 0], X[:, 1], c=y, cmap='coolwarm')
plt.plot(x, y_hyperplane, 'k-', label='分隔超平面')
plt.plot(x, y_margin_up, 'k--', label='上间隔边界')
plt.plot(x, y_margin_down, 'k--', label='下间隔边界')
# 绘制支持向量
plt.scatter(clf.support_vectors_[:, 0], clf.support_vectors_[:, 1],
s=100, facecolors='none', edgecolors='k', label='支持向量')
plt.legend()
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.title('最大化间隔的 SVM')
plt.show()
拉格朗日乘子法