
news2025/2/20 18:05:56


但是西瓜书中对于参数优化(即 Sequential Minimal Optimization,SMO 算法)部分讲解得十分简略,看起来不太好懂。因此这一部分参考的是 John C. Platt 1998 年发表的论文:Sequential Minimal Optimization: A Fast Algorithm for Training Support Vector Machines。

值得注意的是,S.S. Keerthi 在 2001 年又发表了一篇名为 Improvements to Platt’s SMO Algorithm for SVM Classifier Design 的文章。在这篇文章中,他改进了原版 SMO 的收敛条件,并融入了许多缓存机制,好处是求解速度更快了,但理解起来较为晦涩。

因为 SMO 的数学性较强,出于学习的考虑,我这里的实现参考的是 John C. Platt 1998 年发表的论文。





$$w^Tx+b=0$$



$$
\begin{aligned}
&\min_{w,b}\ \frac{1}{2}\|w\|^2\\
&\text{s.t.}\quad y_i(w^Tx_i+b) \ge 1,\quad i=1,2,\dots,m
\end{aligned}
$$

上面这个最优化问题的解 $w$ 和 $b$ 就是我们最终想要的结果。[1]





直到最后,被逼无奈之下去看了 John C. Platt 1998 年发表的原版论文(Sequential Minimal Optimization: A Fast Algorithm for Training Support Vector Machines),看完后真的有种柳暗花明又一村的感觉,比上述两位老师写的教材还要好懂一些,所以十分建议阅读。最重要的是,在文章的末尾,John C. Platt 前辈还提供了一段 C 语言的伪代码,对照着伪代码以及文章中的公式,再回过头来写 SVM 的模型就很容易了。[2]




package weka.classifiers.myf;

import weka.classifiers.Classifier;
import weka.core.Instance;
import weka.core.Instances;

import weka.filters.Filter;
import weka.filters.unsupervised.attribute.NominalToBinary;
import weka.filters.unsupervised.attribute.Standardize;

import java.io.Serializable;
import java.util.HashSet;
import java.util.Set;

/**
 * Kernel functions selectable in {@code mySMO.BinarySMO} via its {@code m_kernelType} field.
 * <p>
 * NOTE(review): the original constant list was lost; these are the four constants the
 * classifier's {@code kernelFunction} switch dispatches on.
 */
enum KernelType {
    /** Linear kernel: plain dot product of the two attribute vectors. */
    KERNEL_LINEAR,
    /** Polynomial kernel, dispatched to {@code polynomialKernel}. */
    KERNEL_POLYNOMIAL,
    /** Gaussian (RBF) kernel, dispatched to {@code rbfKernel}. */
    KERNEL_RBF,
    /** Sigmoid (tanh) kernel, dispatched to {@code sigmoidKernel}. */
    KERNEL_SIGMOID
}

 * @author YFMan
 * @Description 自定义的 SMO 分类器
 * @Date 2023/6/12 15:45
public class mySMO extends Classifier {

    // 二元支持向量机
    public static class BinarySMO implements Serializable {

        // alpha
        protected double[] m_alpha;

        // bias
        protected double m_b;

        // 训练集
        protected Instances m_train;

        // 权重向量
        protected double[] m_weights;

        // 训练数据的类别标签
        protected double[] m_class;

        // 支持向量集合 {i: 0 < m_alpha[i] < C}
        protected Set<Integer> m_supportVectors;

        // 惩罚因子C,超参数
        protected double m_C = 10.0;

        // 容忍参数
        protected double m_toleranceParameter = 1.0e-3;

        // 四舍五入的容忍参数
        protected double m_epsilon = 1.0e-12;

        // 最大迭代次数
        protected int m_maxIterations = 100000;

        // 当前已经执行的迭代次数
        protected int m_numIterations = 0;

        // 定义核函数枚举类型
        protected KernelType m_kernelType = KernelType.KERNEL_LINEAR;

        // 定义多项式核函数的参数
        protected double m_exponent = 2.0;

        // 定义 高斯核 和 拉普拉斯 核函数的参数
        protected double m_gamma = 1.0;

        // 定义 SIGMOID 核函数的参数 beta
        protected double m_sigmoidBeta = 1.0;

        // 定义 sigmoid 核函数的参数 theta
        protected double m_sigmoidTheta = -1.0;

        // 程序精度误差
        protected double m_Del = 1000 * Double.MIN_VALUE;

         * @Author YFMan
         * @Description // 构建分类器
         * @Date 2023/6/12 22:03
         * @Param [instances 训练数据集, cl1 正类, cl2 负类]
         * @return void
        protected void buildClassifier(Instances instances, int cl1, int cl2) throws Exception {
            // 初始化 alpha
            m_alpha = new double[instances.numInstances()];

            // 初始化 bias
            m_b = 0;

            // 初始化训练集
            m_train = instances;

            // 初始化权重向量
            m_weights = new double[instances.numAttributes() - 1];

            // 初始化支持向量集合
            m_supportVectors = new HashSet<Integer>();

            // 初始化 m_class
            m_class = new double[instances.numInstances()];

            // 将标签转换为 -1 和 1
            for (int i = 0; i < m_class.length; i++) {
                // 如果实例的类别标签为负类,则将其转换为 -1
                if (instances.instance(i).classValue() == cl1) {
                    m_class[i] = -1;
                } else {
                    m_class[i] = 1;

            int numChanged = 0;         // 记录改变的拉格朗日乘子的个数
            boolean examineAll = true;  // 是否检查所有的实例

            while ((numChanged > 0 || examineAll) && (m_numIterations < m_maxIterations)) {
                numChanged = 0;
                if (examineAll) {
                    // loop over all training examples
                    for (int i = 0; i < m_train.numInstances(); i++) {
                        numChanged += examineExample(i);
                } else {
                    // loop over examples where alpha is not 0 & not C
                    for (int i = 0; i < m_train.numInstances(); i++) {
                        if ((m_alpha[i] != 0) && (m_alpha[i] != m_C)) {
                            numChanged += examineExample(i);

                if (examineAll) {
                    examineAll = false;
                } else if (numChanged == 0) {
                    examineAll = true;


         * @Author YFMan
         * @Description // 计算 SVM 的输出
         * @Date 2023/6/14 19:26
         * @Param [index, inst]
         * @return double
        public double SVMOutput(Instance instance) throws Exception {
            double result = 0;

            if (m_kernelType == KernelType.KERNEL_LINEAR) {
                for (int i = 0; i < m_weights.length; i++) {
                    result += m_weights[i] * instance.value(i);
            } else {
                // 非线性核函数 计算 SVM 的输出
                for (int i = 0; i < m_train.numInstances(); i++) {
                    // 只有支持向量的拉格朗日乘子才会大于 0 且两个向量不重合
                    if (m_alpha[i] > 0) {
                        result += m_alpha[i] * m_class[i] * kernelFunction(m_train.instance(i), instance);

            result -= m_b;
            return result;

         * @Author YFMan
         * @Description // 根据 i2 选择第二个变量,并且更新拉格朗日乘子
         * @Date 2023/6/14 19:58
         * @Param [i2]
         * @return int
        protected int examineExample(int i2) throws Exception {
            double y2 = m_class[i2];
            double alph2 = m_alpha[i2];
            double E2 = SVMOutput(m_train.instance(i2)) - y2;
            double r2 = E2 * y2;

            if (r2 < -m_toleranceParameter && alph2 < m_C || r2 > m_toleranceParameter && alph2 > 0) {
                // 第一种情况:违反KKT条件
                // 选择第二个变量
                if (m_supportVectors.size() > 1) {
                    // 选择第二个变量
                    int i1 = -1;
                    double max = 0;
                    for (Integer index : m_supportVectors) {
                        double E1 = SVMOutput(m_train.instance(index)) - m_class[index];
                        double temp = Math.abs(E1 - E2);
                        if (temp > max) {
                            max = temp;
                            i1 = index;
                    // 如果找到了第二个变量
                    if (i1 >= 0) {
                        if (takeStep(i1, i2) == 1) {
                            return 1;
                // 第二种情况:没有选择第二个变量
                for (int index : m_supportVectors) {
                    if (takeStep(index, i2) == 1) {
                        return 1;
                // 第三种情况:没有选择支持向量
                for (int index = 0; index < m_train.numInstances(); index++) {
                    if (takeStep(index, i2) == 1) {
                        return 1;
            return 0;

         * @Author YFMan
         * @Description // 根据 i1 和 i2 更新拉格朗日乘子
         * @Date 2023/6/14 19:59
         * @Param [i1, i2]
         * @return int
        protected int takeStep(int i1, int i2) throws Exception {
            if (i1 == i2) {
                return 0;

            double alph1 = m_alpha[i1];
            double alph2 = m_alpha[i2];
            double y1 = m_class[i1];
            double y2 = m_class[i2];
            double E1 = SVMOutput(m_train.instance(i1)) - y1;
            double E2 = SVMOutput(m_train.instance(i2)) - y2;
            double s = y1 * y2;

            double L = 0;
            double H = 0;
            if (y1 != y2) {
                L = Math.max(0, alph2 - alph1);
                H = Math.min(m_C, m_C + alph2 - alph1);
            } else {
                L = Math.max(0, alph2 + alph1 - m_C);
                H = Math.min(m_C, alph2 + alph1);

            if (L == H) {
                return 0;

            double k11 = kernelFunction(m_train.instance(i1), m_train.instance(i1));
            double k12 = kernelFunction(m_train.instance(i1), m_train.instance(i2));
            double k22 = kernelFunction(m_train.instance(i2), m_train.instance(i2));
            double eta = k11 + k22 - 2 * k12;

            double a1 = 0;
            double a2 = 0;

            if (eta > 0) {
                a2 = alph2 + y2 * (E1 - E2) / eta;
                if (a2 < L) {
                    a2 = L;
                } else if (a2 > H) {
                    a2 = H;
            } else {
                double f1 = y1 * (E1 + m_b) - alph1 * k11 - s * alph2 * k12;
                double f2 = y2 * (E2 + m_b) - s * alph1 * k12 - alph2 * k22;
                double L1 = alph1 + s * (alph2 - L);
                double H1 = alph1 + s * (alph2 - H);

                // objective function at a2=L
                double Lobj = L1 * f1 + L * f2 + 0.5 * L1 * L1 * k11 + 0.5 * L * L * k22 + s * L * L1 * k12;
                // objective function at a2=H
                double Hobj = H1 * f1 + H * f2 + 0.5 * H1 * H1 * k11 + 0.5 * H * H * k22 + s * H * H1 * k12;

                if (Lobj > Hobj + m_epsilon) {
                    a2 = L;
                } else if (Lobj < Hobj - m_epsilon) {
                    a2 = H;
                } else {
                    a2 = alph2;

            if (Math.abs(a2 - alph2) < m_epsilon * (a2 + alph2 + m_epsilon)) {
                return 0;

            if (a2 > m_C - m_Del * m_C) // m_Del = 1000 *
                // Double.MIN_VALUE,在精度误差上做了一点处理
                a2 = m_C;
            else if (a2 <= m_Del * m_C)
                a2 = 0;

            a1 = alph1 + s * (alph2 - a2);

            // Update threshold to reflect change in Lagrange multipliers
            double b1 = E1 + y1 * (a1 - alph1) * k11 + y2 * (a2 - alph2) * k12 + m_b;
            double b2 = E2 + y1 * (a1 - alph1) * k12 + y2 * (a2 - alph2) * k22 + m_b;

            if ((0 < a1 && a1 < m_C) && (0 < a2 && a2 < m_C)) {
                m_b = (b1 + b2) / 2;
            } else if (0 < a1 && a1 < m_C) {
                m_b = b1;
            } else if (0 < a2 && a2 < m_C) {
                m_b = b2;

            // Update weight vector to reflect change in a1 & a2, if linear SVM
            if (m_kernelType == KernelType.KERNEL_LINEAR) {
                int column = 0;
                for (int i = 0; i < m_train.numAttributes(); i++) {
                    if (i != m_train.classIndex()) {
                        m_weights[column] += y1 * (a1 - alph1) * m_train.instance(i1).value(i) + y2 * (a2 - alph2) * m_train.instance(i2).value(i);

            m_alpha[i1] = a1;
            m_alpha[i2] = a2;

            return 1;

         * @Author YFMan
         * @Description // 核函数
         * @Date 2023/6/14 19:29
         * @Param [i1, i2]
         * @return double
        protected double kernelFunction(Instance instance1, Instance instance2) throws Exception {
            switch (m_kernelType) {
                case KERNEL_LINEAR:
                    return linearKernel(instance1, instance2);
                case KERNEL_POLYNOMIAL:
                    return polynomialKernel(instance1, instance2);
                case KERNEL_RBF:
                    return rbfKernel(instance1, instance2);
                case KERNEL_SIGMOID:
                    return sigmoidKernel(instance1, instance2);
                    throw new Exception("Invalid kernel type.");

         * @Author YFMan
         * @Description // 线性核函数
         * @Date 2023/6/14 20:33
         * @Param [instance1, instance2]
         * @return double
        protected double linearKernel(Instance instance1, Instance instance2) {
            double result = 0;
            for (int i = 0; i < m_train.numAttributes() - 1; i++) {
                result += instance1.value(i) * instance2.value(i);
            return result;

        protected double polynomialKernel(Instance instance1, Instance instance2) {
            double result = 0;
            for (int i = 0; i < m_train.numAttributes() - 1; i++) {
                result += instance1.value(i) * instance2.value(i);
            return Math.pow(result + m_gamma, m_exponent);

         * @Author YFMan
         * @Description // 高斯核函数
         * @Date 2023/6/15 10:46
         * @Param [instance1, instance2]
         * @return double
        protected double rbfKernel(Instance instance1, Instance instance2) {
            double result = 0;
            for (int i = 0; i < m_train.numAttributes() - 1; i++) {
                result += Math.pow(instance1.value(i) - instance2.value(i), 2);
            return Math.exp(-result / (2 * m_gamma * m_gamma));

         * @Author YFMan
         * @Description // sigmoid 核函数
         * @Date 2023/6/15 10:47
         * @Param [instance1, instance2]
         * @return double
        protected double sigmoidKernel(Instance instance1, Instance instance2) {
            double result = 0;
            for (int i = 0; i < m_train.numAttributes() - 1; i++) {
                result += instance1.value(i) * instance2.value(i);
            return Math.tanh(m_sigmoidBeta * result + m_sigmoidTheta);


    // 归一化数据的过滤器
    public static final int FILTER_NORMALIZE = 0;

    // 标准化数据的过滤器
    public static final int FILTER_STANDARDIZE = 1;

    // 不使用过滤器
    public static final int FILTER_NONE = 2;

    // 二元分类器
    protected BinarySMO m_classifier = null;

    // 是否使用过滤器
    protected int m_filterType = FILTER_NORMALIZE;

    // 用于标准化/归一化数据的过滤器
    protected Filter m_Filter = null;

    // 用于标准化数据的过滤器
    protected Filter m_StandardizeFilter = null;

    // 用于二值化数据的过滤器
    protected Filter m_NominalToBinary = null;

     * @Author YFMan
     * @Description // 构建分类器
     * @Date 2023/6/14 20:29
     * @Param [insts]
     * @return void
    public void buildClassifier(Instances insts) throws Exception {
        // 标准化数据
        m_StandardizeFilter = new Standardize();
        insts = Filter.useFilter(insts, m_StandardizeFilter);

        // 二值化数据
        m_NominalToBinary = new NominalToBinary();
        insts = Filter.useFilter(insts, m_NominalToBinary);

        m_classifier = new BinarySMO();

        m_classifier.buildClassifier(insts, 0, 1);

     * @Author YFMan
     * @Description // 分类实例
     * @Date 2023/6/14 20:43
     * @Param [inst]
     * @return double[]
    public double[] distributionForInstance(Instance inst) throws Exception {
        // 过滤实例
        inst = m_StandardizeFilter.output();

        inst = m_NominalToBinary.output();

        double[] result = new double[2];

        double output = m_classifier.SVMOutput(inst);
        result[1] = 1.0 / (1.0 + Math.exp(-output));
        result[0] = 1.0 - result[1];

        return result;

     * @Author YFMan
     * @Description // 主函数
     * @Date 2023/6/14 20:42
     * @Param [argv]
     * @return void
    public static void main(String[] argv) {
        runClassifier(new mySMO(), argv);





  1. 《机器学习》周志华 ↩︎

  2. Platt J. Sequential minimal optimization: A fast algorithm for training support vector machines[J]. 1998. ↩︎





