关于svm机器学习模型,我主要学习的是周志华老师的西瓜书(《机器学习》);
但是西瓜书中对于参数优化(即:Sequential Minimal Optimization,smo算法)部分讲解的十分简略,看起来不太好懂。因此这一部分参考的是John C. Platt 1998年发表的论文:Sequential Minimal Optimization: A Fast Algorithm for Training Support Vector Machines
值得注意的是,S.S. Keerthi在2001年又发表了一篇名为Imrovements to Platt’s SMO Algorithm for SVM Classifier Design的文章,在这篇文章中,它改进了原版本smo的收敛条件,并融入了许多缓存机制,好处是求解速度更快了,但理解起来较为晦涩。
因为smo数学原理较强,处于学习考虑,我这里的实现参考的的是John C. Platt 1998年发表的论文。
一、支持向量机(SVM)模型
支持向量机就是想找到一个间隔最大的超平面,将正负两种样本分割开来,进而实现分类的一个模型。
支持向量机寻找到的超平面可以用如下公式来表示:
w T x + b = 0 w^Tx+b=0 wTx+b=0
如果输入x,结果大于0,就为正例;
如果输入x,结果小于0,就为负例,进而实现分类任务。
根据周志华西瓜书(《机器学习》p121-123)中的公式推导,我们要想寻找到参数的最优解,最终的优化目标如下:
min w , b 1 2 ∣ ∣ w ∣ ∣ 2 s . t . y i ( w T x i + b ) ≥ 1 , i = 1 , 2 , . . . , m \min_{w,b}\frac{1}{2}||w||^2\\ s.t. \quad y_i(w^Tx_i+b) \ge 1,\quad i=1,2,...,m w,bmin21∣∣w∣∣2s.t.yi(wTxi+b)≥1,i=1,2,...,m
上面这个最优化目标对应的 w w w 和 b b b 就是我们最终想要的结果。1
二、序列最小优化算法(SMO)
而优化上述模型所采用的优化算法一种就是二次规划,采用线程的优化包进行求解,但是当样本量非常大的情况下,约束目标的数量也会非常大,会出现维度爆炸的问题,而相比之下,SMO算法就可以很好地解决这个问题。
在学习SMO算法的时候,我首先阅读的是西瓜书上的相关内容,但是十分晦涩,读完后一头雾水。
然后我又找同学借了本李航老师的《统计机器学习》进行阅读,读了几遍之后感觉虽然有了一个大体的思路,但是具体如何编码实现呢?比如如何判定一个样例是否满足KKT条件?还是不太会。
直到最后,被逼无奈之下去看了John C. Platt 1998年发表的原版论文(Sequential Minimal Optimization: A Fast Algorithm for Training Support Vector Machines),看完后真的有种柳暗花明又一村的感觉,感觉比上述两位老师写的教材还要好懂一些,所以十分建议阅读。最重要的是,在文章的末尾,John C. Platt前辈还提供了一段C语言的伪代码,对照着伪代码以及文章中的公式,再回过头来写svm的模型就很容易了。2
三、基于weka实现的SMO
因为svm需要经过一个sigmoid函数类似的指数类型的变换以及核函数的处理(高斯核、拉普拉斯核都会涉及到指数),因此,其值大小是非常重要的。如果svm输出的值为几百或者几千,那么经过指数后,直接就会变为无穷大inf或者Nan。
我在编码的完成后,自己写的svm的性能总是不好,找了很久的原因,最终定位在了这个bug上。
package weka.classifiers.myf;
import weka.classifiers.Classifier;
import weka.core.Instance;
import weka.core.Instances;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.NominalToBinary;
import weka.filters.unsupervised.attribute.Standardize;
import java.io.Serializable;
import java.util.HashSet;
import java.util.Set;
enum KernelType {
KERNEL_LINEAR, KERNEL_POLYNOMIAL, KERNEL_RBF, KERNEL_SIGMOID
}
/**
* @author YFMan
* @Description 自定义的 SMO 分类器
* @Date 2023/6/12 15:45
*/
public class mySMO extends Classifier {
// 二元支持向量机
public static class BinarySMO implements Serializable {
// alpha
protected double[] m_alpha;
// bias
protected double m_b;
// 训练集
protected Instances m_train;
// 权重向量
protected double[] m_weights;
// 训练数据的类别标签
protected double[] m_class;
// 支持向量集合 {i: 0 < m_alpha[i] < C}
protected Set<Integer> m_supportVectors;
// 惩罚因子C,超参数
protected double m_C = 10.0;
// 容忍参数
protected double m_toleranceParameter = 1.0e-3;
// 四舍五入的容忍参数
protected double m_epsilon = 1.0e-12;
// 最大迭代次数
protected int m_maxIterations = 100000;
// 当前已经执行的迭代次数
protected int m_numIterations = 0;
// 定义核函数枚举类型
protected KernelType m_kernelType = KernelType.KERNEL_LINEAR;
// 定义多项式核函数的参数
protected double m_exponent = 2.0;
// 定义 高斯核 和 拉普拉斯 核函数的参数
protected double m_gamma = 1.0;
// 定义 SIGMOID 核函数的参数 beta
protected double m_sigmoidBeta = 1.0;
// 定义 sigmoid 核函数的参数 theta
protected double m_sigmoidTheta = -1.0;
// 程序精度误差
protected double m_Del = 1000 * Double.MIN_VALUE;
/*
* @Author YFMan
* @Description // 构建分类器
* @Date 2023/6/12 22:03
* @Param [instances 训练数据集, cl1 正类, cl2 负类]
* @return void
**/
protected void buildClassifier(Instances instances, int cl1, int cl2) throws Exception {
// 初始化 alpha
m_alpha = new double[instances.numInstances()];
// 初始化 bias
m_b = 0;
// 初始化训练集
m_train = instances;
// 初始化权重向量
m_weights = new double[instances.numAttributes() - 1];
// 初始化支持向量集合
m_supportVectors = new HashSet<Integer>();
// 初始化 m_class
m_class = new double[instances.numInstances()];
// 将标签转换为 -1 和 1
for (int i = 0; i < m_class.length; i++) {
// 如果实例的类别标签为负类,则将其转换为 -1
if (instances.instance(i).classValue() == cl1) {
m_class[i] = -1;
} else {
m_class[i] = 1;
}
}
int numChanged = 0; // 记录改变的拉格朗日乘子的个数
boolean examineAll = true; // 是否检查所有的实例
while ((numChanged > 0 || examineAll) && (m_numIterations < m_maxIterations)) {
numChanged = 0;
if (examineAll) {
// loop over all training examples
for (int i = 0; i < m_train.numInstances(); i++) {
numChanged += examineExample(i);
}
} else {
// loop over examples where alpha is not 0 & not C
for (int i = 0; i < m_train.numInstances(); i++) {
if ((m_alpha[i] != 0) && (m_alpha[i] != m_C)) {
numChanged += examineExample(i);
}
}
}
if (examineAll) {
examineAll = false;
} else if (numChanged == 0) {
examineAll = true;
}
m_numIterations++;
}
}
/*
* @Author YFMan
* @Description // 计算 SVM 的输出
* @Date 2023/6/14 19:26
* @Param [index, inst]
* @return double
**/
public double SVMOutput(Instance instance) throws Exception {
double result = 0;
if (m_kernelType == KernelType.KERNEL_LINEAR) {
for (int i = 0; i < m_weights.length; i++) {
result += m_weights[i] * instance.value(i);
}
} else {
// 非线性核函数 计算 SVM 的输出
for (int i = 0; i < m_train.numInstances(); i++) {
// 只有支持向量的拉格朗日乘子才会大于 0 且两个向量不重合
if (m_alpha[i] > 0) {
result += m_alpha[i] * m_class[i] * kernelFunction(m_train.instance(i), instance);
}
}
}
result -= m_b;
return result;
}
/*
* @Author YFMan
* @Description // 根据 i2 选择第二个变量,并且更新拉格朗日乘子
* @Date 2023/6/14 19:58
* @Param [i2]
* @return int
**/
protected int examineExample(int i2) throws Exception {
double y2 = m_class[i2];
double alph2 = m_alpha[i2];
double E2 = SVMOutput(m_train.instance(i2)) - y2;
double r2 = E2 * y2;
if (r2 < -m_toleranceParameter && alph2 < m_C || r2 > m_toleranceParameter && alph2 > 0) {
// 第一种情况:违反KKT条件
// 选择第二个变量
if (m_supportVectors.size() > 1) {
// 选择第二个变量
int i1 = -1;
double max = 0;
for (Integer index : m_supportVectors) {
double E1 = SVMOutput(m_train.instance(index)) - m_class[index];
double temp = Math.abs(E1 - E2);
if (temp > max) {
max = temp;
i1 = index;
}
}
// 如果找到了第二个变量
if (i1 >= 0) {
if (takeStep(i1, i2) == 1) {
return 1;
}
}
}
// 第二种情况:没有选择第二个变量
for (int index : m_supportVectors) {
if (takeStep(index, i2) == 1) {
return 1;
}
}
// 第三种情况:没有选择支持向量
for (int index = 0; index < m_train.numInstances(); index++) {
if (takeStep(index, i2) == 1) {
return 1;
}
}
}
return 0;
}
/*
* @Author YFMan
* @Description // 根据 i1 和 i2 更新拉格朗日乘子
* @Date 2023/6/14 19:59
* @Param [i1, i2]
* @return int
**/
protected int takeStep(int i1, int i2) throws Exception {
if (i1 == i2) {
return 0;
}
double alph1 = m_alpha[i1];
double alph2 = m_alpha[i2];
double y1 = m_class[i1];
double y2 = m_class[i2];
double E1 = SVMOutput(m_train.instance(i1)) - y1;
double E2 = SVMOutput(m_train.instance(i2)) - y2;
double s = y1 * y2;
double L = 0;
double H = 0;
if (y1 != y2) {
L = Math.max(0, alph2 - alph1);
H = Math.min(m_C, m_C + alph2 - alph1);
} else {
L = Math.max(0, alph2 + alph1 - m_C);
H = Math.min(m_C, alph2 + alph1);
}
if (L == H) {
return 0;
}
double k11 = kernelFunction(m_train.instance(i1), m_train.instance(i1));
double k12 = kernelFunction(m_train.instance(i1), m_train.instance(i2));
double k22 = kernelFunction(m_train.instance(i2), m_train.instance(i2));
double eta = k11 + k22 - 2 * k12;
double a1 = 0;
double a2 = 0;
if (eta > 0) {
a2 = alph2 + y2 * (E1 - E2) / eta;
if (a2 < L) {
a2 = L;
} else if (a2 > H) {
a2 = H;
}
} else {
double f1 = y1 * (E1 + m_b) - alph1 * k11 - s * alph2 * k12;
double f2 = y2 * (E2 + m_b) - s * alph1 * k12 - alph2 * k22;
double L1 = alph1 + s * (alph2 - L);
double H1 = alph1 + s * (alph2 - H);
// objective function at a2=L
double Lobj = L1 * f1 + L * f2 + 0.5 * L1 * L1 * k11 + 0.5 * L * L * k22 + s * L * L1 * k12;
// objective function at a2=H
double Hobj = H1 * f1 + H * f2 + 0.5 * H1 * H1 * k11 + 0.5 * H * H * k22 + s * H * H1 * k12;
if (Lobj > Hobj + m_epsilon) {
a2 = L;
} else if (Lobj < Hobj - m_epsilon) {
a2 = H;
} else {
a2 = alph2;
}
}
if (Math.abs(a2 - alph2) < m_epsilon * (a2 + alph2 + m_epsilon)) {
return 0;
}
if (a2 > m_C - m_Del * m_C) // m_Del = 1000 *
// Double.MIN_VALUE,在精度误差上做了一点处理
a2 = m_C;
else if (a2 <= m_Del * m_C)
a2 = 0;
a1 = alph1 + s * (alph2 - a2);
// Update threshold to reflect change in Lagrange multipliers
double b1 = E1 + y1 * (a1 - alph1) * k11 + y2 * (a2 - alph2) * k12 + m_b;
double b2 = E2 + y1 * (a1 - alph1) * k12 + y2 * (a2 - alph2) * k22 + m_b;
if ((0 < a1 && a1 < m_C) && (0 < a2 && a2 < m_C)) {
m_b = (b1 + b2) / 2;
} else if (0 < a1 && a1 < m_C) {
m_b = b1;
} else if (0 < a2 && a2 < m_C) {
m_b = b2;
}
// Update weight vector to reflect change in a1 & a2, if linear SVM
if (m_kernelType == KernelType.KERNEL_LINEAR) {
int column = 0;
for (int i = 0; i < m_train.numAttributes(); i++) {
if (i != m_train.classIndex()) {
m_weights[column] += y1 * (a1 - alph1) * m_train.instance(i1).value(i) + y2 * (a2 - alph2) * m_train.instance(i2).value(i);
column++;
}
}
}
m_alpha[i1] = a1;
m_alpha[i2] = a2;
return 1;
}
/*
* @Author YFMan
* @Description // 核函数
* @Date 2023/6/14 19:29
* @Param [i1, i2]
* @return double
**/
protected double kernelFunction(Instance instance1, Instance instance2) throws Exception {
switch (m_kernelType) {
case KERNEL_LINEAR:
return linearKernel(instance1, instance2);
case KERNEL_POLYNOMIAL:
return polynomialKernel(instance1, instance2);
case KERNEL_RBF:
return rbfKernel(instance1, instance2);
case KERNEL_SIGMOID:
return sigmoidKernel(instance1, instance2);
default:
throw new Exception("Invalid kernel type.");
}
}
/*
* @Author YFMan
* @Description // 线性核函数
* @Date 2023/6/14 20:33
* @Param [instance1, instance2]
* @return double
**/
protected double linearKernel(Instance instance1, Instance instance2) {
double result = 0;
for (int i = 0; i < m_train.numAttributes() - 1; i++) {
result += instance1.value(i) * instance2.value(i);
}
return result;
}
protected double polynomialKernel(Instance instance1, Instance instance2) {
double result = 0;
for (int i = 0; i < m_train.numAttributes() - 1; i++) {
result += instance1.value(i) * instance2.value(i);
}
return Math.pow(result + m_gamma, m_exponent);
}
/*
* @Author YFMan
* @Description // 高斯核函数
* @Date 2023/6/15 10:46
* @Param [instance1, instance2]
* @return double
**/
protected double rbfKernel(Instance instance1, Instance instance2) {
double result = 0;
for (int i = 0; i < m_train.numAttributes() - 1; i++) {
result += Math.pow(instance1.value(i) - instance2.value(i), 2);
}
return Math.exp(-result / (2 * m_gamma * m_gamma));
}
/*
* @Author YFMan
* @Description // sigmoid 核函数
* @Date 2023/6/15 10:47
* @Param [instance1, instance2]
* @return double
**/
protected double sigmoidKernel(Instance instance1, Instance instance2) {
double result = 0;
for (int i = 0; i < m_train.numAttributes() - 1; i++) {
result += instance1.value(i) * instance2.value(i);
}
return Math.tanh(m_sigmoidBeta * result + m_sigmoidTheta);
}
}
// 归一化数据的过滤器
public static final int FILTER_NORMALIZE = 0;
// 标准化数据的过滤器
public static final int FILTER_STANDARDIZE = 1;
// 不使用过滤器
public static final int FILTER_NONE = 2;
// 二元分类器
protected BinarySMO m_classifier = null;
// 是否使用过滤器
protected int m_filterType = FILTER_NORMALIZE;
// 用于标准化/归一化数据的过滤器
protected Filter m_Filter = null;
// 用于标准化数据的过滤器
protected Filter m_StandardizeFilter = null;
// 用于二值化数据的过滤器
protected Filter m_NominalToBinary = null;
/*
* @Author YFMan
* @Description // 构建分类器
* @Date 2023/6/14 20:29
* @Param [insts]
* @return void
**/
public void buildClassifier(Instances insts) throws Exception {
// 标准化数据
m_StandardizeFilter = new Standardize();
m_StandardizeFilter.setInputFormat(insts);
insts = Filter.useFilter(insts, m_StandardizeFilter);
// 二值化数据
m_NominalToBinary = new NominalToBinary();
m_NominalToBinary.setInputFormat(insts);
insts = Filter.useFilter(insts, m_NominalToBinary);
m_classifier = new BinarySMO();
m_classifier.buildClassifier(insts, 0, 1);
}
/*
* @Author YFMan
* @Description // 分类实例
* @Date 2023/6/14 20:43
* @Param [inst]
* @return double[]
**/
public double[] distributionForInstance(Instance inst) throws Exception {
// 过滤实例
m_StandardizeFilter.input(inst);
inst = m_StandardizeFilter.output();
m_NominalToBinary.input(inst);
inst = m_NominalToBinary.output();
double[] result = new double[2];
double output = m_classifier.SVMOutput(inst);
result[1] = 1.0 / (1.0 + Math.exp(-output));
result[0] = 1.0 - result[1];
return result;
}
/*
* @Author YFMan
* @Description // 主函数
* @Date 2023/6/14 20:42
* @Param [argv]
* @return void
**/
public static void main(String[] argv) {
runClassifier(new mySMO(), argv);
}
}
四、感悟
支持向量机的优化部分smo数学原理很强,论文中的推导非常清晰,因此文中并没有对其过多解读,因为我解读的再细致,以我对smo的理解,也不可能有原作者好。
同时,虽然自己能将smo侥幸实现,但只能说按照文中公式及伪代码来理解一二,并不敢说对其理解有多深刻。直到现在,也依然有很多不明白的点。
对于计算机科学这门应用学科而言,数学永远是天花板,也许我们能侥幸的把它用起来,但如果真正的想要有所建树和理论创新,可能还要回归到数学吧。
《机器学习》周志华 ↩︎
Platt J. Sequential minimal optimization: A fast algorithm for training support vector machines[J]. 1998. ↩︎