基于weka平台手工实现(LinearRegression | Ridge Regression,岭回归)

news2024/11/26 16:28:19

一、普通的线性回归

线性回归主要采用最小二乘法来实现,主要思想如下:

X = ( x 11 x 12 ⋯ x 1 d 1 x 21 x 22 ⋯ 5 1 ⋮ ⋮ ⋱ ⋮ ⋮ x m 1 x m 2 ⋯ x m d 1 ) X=\left( \begin{matrix} x_{11} & x_{12} & \cdots & x_{1d} & 1 \\ x_{21} & x_{22} & \cdots & 5 & 1 \\ \vdots & \vdots & \ddots & \vdots & \vdots \\ x_{m1} & x_{m2} & \cdots & x_{md} & 1 \\ \end{matrix} \right) X= x11x21xm1x12x22xm2x1d5xmd111

X为一个m行d+1列的矩阵,其中m行为数据的数量,d+1列表示参数的数量。d表示未知数系数的个数,1用来与截距相乘,相当于b。

ω ^ ∗ = arg ⁡ min ⁡ ω ( y − X ω ^ ) T ( y − X ω ^ ) \hat{\omega}^*=\mathop{\arg\min}\limits_{\omega}(y-X\hat{\omega})^T(y-X\hat{\omega}) ω^=ωargmin(yXω^)T(yXω^)

损失函数如下:

E ω ^ = ( y − X ω ^ ) T ( y − X ω ^ ) E_{\hat\omega}=(y-X\hat{\omega})^T(y-X\hat{\omega}) Eω^=(yXω^)T(yXω^)

对于一个梯度可求最大次方项为正的函数来讲,极小值点为导数为0的点。

∂ E ω ^ ∂ ω ^ = 2 X T ( X ω ^ − y ) = 0 \frac{\partial E_{\hat\omega}}{\partial\hat\omega}=2X^T(X\hat\omega-y)=0 ω^Eω^=2XT(Xω^y)=0

( X ω ^ − y ) = 0 (X\hat\omega-y)=0 (Xω^y)=0

X ω ^ = y X\hat\omega=y Xω^=y

X T X ω ^ = X T y X^TX\hat\omega=X^Ty XTXω^=XTy

ω ^ ∗ = ( X T X ) − 1 X T y \hat\omega^*=(X^TX)^{-1}X^Ty ω^=(XTX)1XTy

但是,矩阵求逆的前提是,当前矩阵是满秩的,也就是说它是非奇异矩阵,这就要求, X T X X^TX XTX 矩阵中的数据都是线性无关的,如果其中的数据存在线性相关性,那么是无法求逆的。

这对于一个高维特征的数据是很难做到的,因为维度之间很可能存在相关性,完全一致之后就会成为奇异矩阵(无法求逆)。

因此,岭回归就产生了。专门用来解决上述问题。

二、岭回归(Ridge Regression)

岭回归主要做法是,对上述公式的对角线上加了一个很小的值(岭系数),这样无论如何, X T X X^TX XTX 就都可以求导了。

ω ^ ∗ = ( X T X + λ I ) − 1 X T y (1) \hat\omega^*=(X^TX+\lambda I)^{-1}X^Ty\tag{1} ω^=(XTX+λI)1XTy(1)

岭回归不是随便加的 λ \lambda λ,因为我们要防止曲线过拟合,就需要控制多元函数的各个系数都不要太大,如果太大,就会造成过拟合,因此在损失函数后加上一个正则项。

E ω ^ = ( y − X ω ^ ) T ( y − X ω ^ ) + λ ω T ω = ( y T y − y T X ω ^ − ω ^ T X T y + ω ^ T X T X ω ^ ) + λ ω T ω \begin{aligned} E_{\hat\omega}&=(y-X\hat{\omega})^T(y-X\hat{\omega})+\lambda \omega^T\omega \\ &=(y^Ty-y^TX\hat{\omega}-\hat\omega^T X^Ty+\hat\omega^T X^TX\hat\omega)+\lambda\omega^T\omega \end{aligned} Eω^=(yXω^)T(yXω^)+λωTω=(yTyyTXω^ω^TXTy+ω^TXTXω^)+λωTω

∂ E ω ^ ∂ ω ^ = − X T Y − X T Y + 2 X T X ω ^ + 2 λ ω ^ = 2 ( − X T Y + X T X ω ^ + λ ω ^ ) = 0 \begin{aligned} \frac{\partial E_{\hat\omega}}{\partial\hat\omega}&=-X^TY-X^TY+2X^TX\hat\omega+2\lambda\hat\omega\\ &=2(-X^TY+X^TX\hat\omega+\lambda\hat\omega)\\ &=0 \end{aligned} ω^Eω^=XTYXTY+2XTXω^+2λω^=2(XTY+XTXω^+λω^)=0

X T X ω ^ + λ ω ^ = X T Y X^TX\hat\omega+\lambda\hat\omega=X^TY XTXω^+λω^=XTY

( X T X + λ ) ω ^ = X T Y (X^TX+\lambda)\hat\omega=X^TY (XTX+λ)ω^=XTY

ω ^ = ( X T X + λ I ) − 1 X T Y (2) \hat\omega=(X^TX+\lambda I)^{-1}X^TY\tag{2} ω^=(XTX+λI)1XTY(2)
将 (2) 式和 (1) 式进行对比,可以发现,二者是相同的。

三、代码实现

代码实现注释已经较为清晰,在此不再赘述具体实现过程。

package weka.classifiers.myf;

import weka.classifiers.Classifier;
import weka.core.*;
import weka.core.matrix.Matrix;

/**
 * @author YFMan
 * @Description 自定义的 线性回归 分类器
 * @Date 2023/5/9 15:45
 */
public class myLinearRegression extends Classifier {

    // 用于存储 线性回归 系数 的数组
    private double[] m_Coefficients;

    // 类别索引
    private int m_ClassIndex;

    // 存储训练数据
    private Instances m_Instances;

    /*
     * @Author YFMan
     * @Description 根据训练数据 建立 线性回归模型
     * @Date 2023/5/9 22:08
     * @Param [data] 训练数据
     * @return void
     **/
    public void buildClassifier(Instances data) throws Exception {
        // 存储训练数据
        m_Instances = data;

        // 初始化类别索引
        m_ClassIndex = data.classIndex();

        // 用来存储 线性回归 系数 的数组
        m_Coefficients = null;

        // 初始化数据矩阵 X
        // 高度是样例数量,宽度是属性数量+1(1作为 截距参数b 的输入)
        Matrix X = new Matrix(data.numInstances(), data.numAttributes());

        // 初始化数据矩阵 Y
        // 高度是样例数量,宽度是1
        Matrix Y = new Matrix(data.numInstances(), 1);

        // 初始化矩阵值
        for (int i = 0; i < data.numInstances(); i++) {
            int column = 0;
            for (int j = 0; j < data.numAttributes(); j++) {
                if (j != data.classIndex()) {
                    X.set(i, column, data.instance(i).value(j));
                    column++;
                } else {
                    Y.set(i, 0, data.instance(i).value(j));
                }
            }
        }

        // 设置 X 的最后一列为 1,用于计算 截距参数b
        for (int i = 0; i < data.numInstances(); i++) {
            X.set(i, data.numAttributes() - 1, 1);
        }

        // 计算XTX
        Matrix XTX = X.transpose().times(X);
        // 计算XTY
        Matrix XTY = X.transpose().times(Y);

        // 由于XTX可能是奇异矩阵,所以需要加一个岭回归系数
        for (int i = 0; i < XTX.getRowDimension(); i++) {
            XTX.set(i, i, XTX.get(i, i) + 0.0001);
        }

        // 计算系数矩阵
        Matrix solution = XTX.inverse().times(XTY);

        // 将系数矩阵转换为数组
        m_Coefficients = new double[solution.getRowDimension()];
        for (int i = 0; i < solution.getRowDimension(); i++) {
            m_Coefficients[i] = solution.get(i, 0);
        }
    }

    /*
     * @Author YFMan
     * @Description 利用 建立的线性模型 对样例进行分类
     * @Date 2023/5/9 22:10
     * @Param [instance] 待分类的样例
     * @return double
     **/
    public double classifyInstance(Instance instance) throws Exception {
        // 计算回归模型的预测值
        double result = 0;
        int column = 0;
        for (int i = 0; i < instance.numAttributes(); i++) {
            if (m_ClassIndex != i) {
                result += instance.value(i) * m_Coefficients[column];
                column++;
            }
        }
        // 加上截距参数
        result += m_Coefficients[column];

        // 返回预测值
        return result;
    }

    /*
     * @Author YFMan
     * @Description 输出建立的线性模型
     * @Date 2023/5/9 22:29
     * @Param []
     * @return java.lang.String
     **/
    public String toString() {

        try {
            StringBuilder text = new StringBuilder();
            int column = 0;
            boolean first = true;

            text.append("\nLinear Regression Model\n\n");

            text.append(m_Instances.classAttribute().name()).append(" =\n\n");
            for (int i = 0; i < m_Instances.numAttributes(); i++) {
                if (i != m_ClassIndex) {
                    if (!first)
                        text.append(" +\n");
                    else
                        first = false;
                    text.append(Utils.doubleToString(m_Coefficients[column], 12, 4)).append(" * ");
                    text.append(m_Instances.attribute(i).name());
                    column++;
                }
            }
            text.append(" +\n").append(Utils.doubleToString(m_Coefficients[column], 12, 4));
            return text.toString();
        } catch (Exception e) {
            return "Can't print Linear Regression!";
        }
    }

    /*
     * @Author YFMan
     * @Description 主函数 生成一个线性回归函数预测器
     * @Date 2023/5/9 22:35
     * @Param [argv]
     * @return void
     **/
    public static void main(String[] argv) {
        runClassifier(new myLinearRegression(), argv);
    }
}

四、结果分析

在weka平台中的cpu.arff数据集上进行实验。

我们自己写的模型结果:

在这里插入图片描述

weka中的线性回归模型结果:

在这里插入图片描述

可以看到,weka平台中的算法和我们自己实现的结果,训练得到的参数是一致的,不同点在于weka平台中的算法有对属性进行选择,并没有使用所有特征进行训练(weka使用了5个特征),而我自己实现的没有进行属性选择,所以一共有6个特征。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/701123.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Vinted店铺为什么被封?如何应对?

Vinted是一家在线二手交易平台&#xff0c;专门用于买卖衣物和时尚配件。自从2022年以来&#xff0c;Vinted也越来越向综合性跨境电商平台转变。细心的伙伴都会发现&#xff0c;近来Vinted这阵子封号确实很严重&#xff0c;感觉是风控变严格了&#xff0c;但是万变不离其宗&…

xhtmlrenderer 将html转换成pdf,设置多字体, 以及中文不显示的问题

接上一篇 https://blog.csdn.net/qq_21480147/article/details/131187202 多字体 字体文件自行搜索或者window中自带的搜索(C:\Windows\Fonts) 中文不显示 在要渲染的中文的地方中设置stylefont-family:[字体] 该字体需要对应指定的属性, 属性参考:

java程序改变io临时存储路径

System.setProperty(“java.io.temdir”,“your path”)

【UE5 Cesium】08-Cesium for Unreal 子关卡应用实例(上)

UE版本&#xff1a;5.1 效果 &#xff08;运行游戏可以看到进入关卡体积内楼房模型才会显现&#xff0c;以此来减少电脑性能消耗&#xff09; 步骤 一、新建两个子关卡&#xff08;以北京和上海为例&#xff09; 点击窗口-》关卡-》新建 命名第一个子关卡为“SubLevel_Bei…

计算机专业学生暑假要去看这些经典书籍!

好书在精不在多&#xff0c;每一本经典书籍都值得反复咀嚼&#xff0c;温故而知新&#xff01; 分享几本经典书籍。 重构 改善既有代码的设计 就像豆瓣评论所说的&#xff0c;看后有种醍醐灌顶、欲罢不能的感觉。无论你是初学者&#xff0c;还是深耕多年的老手&#xff0c;这…

Bytebase VS Yearning

下文对 Bytebase 和 Yearning 两个数据库管理工具进行了多维度比较&#x1f50d;。 产品功能定位 Yearning&#xff1a;功能较为单一的独立数据库审核工具&#xff0c;适合小团队进行简单的 SQL 审核&#xff0c;若要应对复杂需求必须进行大量二次开发&#xff0c;用户群更偏…

从功能测试进阶自动化测试,熬夜7天整理出这一份3000字超全学习指南!

因为我一直在分享自动化测试技术&#xff0c;所以&#xff0c;经常被问到&#xff1a; 功能测试想转自动化&#xff0c;请问应该怎么入手&#xff1f;或者有哪些书推荐&#xff1f; 那么&#xff0c;接下来我就结合自己的经历聊一聊我是如何在工作中做自动化测试的。 测试新…

【技术操作】EasyCVR如何在分享页增加控制台跳转?

EasyCVR可拓展性强、视频能力灵活、部署轻快&#xff0c;可支持的主流标准协议有GB28181、RTSP/Onvif、RTMP等&#xff0c;以及厂家私有协议与SDK接入&#xff0c;包括海康Ehome、海大宇等设备的SDK等&#xff0c;能对外分发RTSP、RTMP、FLV、HLS、WebRTC等格式的视频流。 在Ea…

6.4.2 文件隐藏属性

chattr指令只能在Ext2/Ext3/Ext4的 Linux 传统文件系统上面完整生效&#xff0c; 其他的文件系统可能就无法完整的支持这个指令了&#xff0c;例如 xfs 仅支持部份参数而已。 chattr &#xff08;设置文件隐藏属性&#xff09; 这个指令很重要&#xff0c;在系统的数据安全上面…

互联网医院平台|互联网医院搭建|线上医疗系统开发必要功能

医疗服务行业一直以来都有着较好的发展市场&#xff0c;为了进一步拓展医疗行业的发展空间&#xff0c;开始选择布局线上渠道&#xff0c;互联网医院平台的出现解决了线下就医的一些困境&#xff0c;比如改善人流如潮的情况&#xff0c;提升医护人员的工作效率&#xff0c;那么…

AutoSAR系列讲解(入门篇)3.1-RTE概述

一、什么是RTE RTE的作用有点像一个快递中转站或者说是电话接线员&#xff08;就是上个世界那种要先打电话到接线员那里&#xff0c;然后通过接线员转接电话线到目的地&#xff09;&#xff0c;其作 用就是将一个SWC的信息通过RTE连接到其他SWC或者BSW上。且RTE具有管理这些信…

前端新增校验关键属性是否重复

需求&#xff1a;前端新增某个属性时&#xff0c;该属性下可新增列表&#xff0c;列表编码禁止重复&#xff08;未提交该属性时前端校验列表编码是否重复&#xff09; js&#xff1a;新增后校验 let arrayCode if (this.collectionPointList.length 0) {this.collectionPoint…

自制聊天机器人实现与chatgpt或微信好友对话【附代码】

闲来无事&#xff0c;想实现一个可与chatgpt或者微信好友对话的聊天机器人。该聊天机器人还可应用于QQ好友或者其他地方的语音输入。功能还是比较简单的&#xff0c;后期会慢慢更新&#xff0c;让人机交互体验感不断提升。 项目描述&#xff1a; 语音输入"开启语音助手&…

Linux常用命令——fmt命令

在线Linux命令查询工具 fmt 读取文件后优化处理并输出 补充说明 fmt命令读取文件的内容&#xff0c;根据选项的设置对文件格式进行简单的优化处理&#xff0c;并将结果送到标准输出设备。 语法 fmt(选项)(参数)选项 -c或--crown-margin&#xff1a;每段前两列缩排&#…

Django期末复习总结【内含思维导图帮助梳理】

Django-最下面有笔记的下载链接 初始Django框架 MTV设计模式 Model&#xff08;模型&#xff09; Template&#xff08;模板&#xff09; View&#xff08;视图&#xff09; Django项目框架搭建 创建项目骨架 django-admin startproject my_project1 启动服务 python mana…

2 线程基础知识复习

1、并发相关Java包 涉及到的包内容 java.util.concurrent java.util.concurrent.atomic java.util.concurrent.locks 2、并发始祖 3、start线程解读 初始程序 public static void main(String[] args) {Thread t1 new Thread(() ->{},"t1");t1.start();}//…

从功能测试到自动化测试,待遇翻倍,我整理的超全学习指南!

在这个吃技术的IT行业来说&#xff0c;我刚入行的时候每天做的也是最基础的工作&#xff0c;但是随着时间的消磨&#xff0c;我产生了对自我和岗位价值和意义的困惑。一是感觉自己在浪费时间&#xff0c;另一个就是做了快2年的测试&#xff0c;感觉每天过得浑浑噩噩&#xff0c…

一个JVM参数,服务超时率降了四分之三

先说结论&#xff1a;通过优化Xms&#xff0c;改为和Xmx一致&#xff0c;使系统的超时率降了四分之三 1. 背景 一个同事说他负责的服务在一次上线之后超时率增加了一倍 2. 分析 2.1 机器的监控 首先找了一台机器&#xff0c;看了监控 上线后最明显的变化就是CPU使用率变高了…

Redis6之主从复制

主从复制 是指将一台Redis服务器的数据&#xff0c;复制到其他Redis服务器。前者称为主节点&#xff0c;后者称为从节点&#xff1b;数据复制是单向的&#xff0c;只能由主节点复制到从节点&#xff1b;主节点以写为主&#xff0c;从节点以读为主。 特点 1.使用异步复制&#…

VS2019 QT5 第一个项目

(1条消息) VS2017PyQt5环境配置以及第一个HellowPyQt5_vs pyqt_2011老王的博客-CSDN博客 利用工具里的PyUIC5&#xff0c;将ui转为py 选中刚加入的ui文件&#xff0c;工具》PyUIC5 利用工具里的PyUIC5&#xff0c;将ui转为py 选中刚加入的ui文件&#xff0c;工具》PyUIC5 利用…