机器学习:多元线性回归闭式解(Python)

news2024/11/17 15:55:48
import numpy as np
import matplotlib.pyplot as plt


class LRClosedFormSol:
    def __init__(self, fit_intercept=True, normalize=True):
        """
        :param fit_intercept: 是否训练bias
        :param normalize: 是否标准化数据
        """
        self.theta = None  # 训练权重系数
        self.fit_intercept = fit_intercept  # 线性模型的常数项。也即偏置bias,模型中的theta0
        self.normalize = normalize  # 是否标准化数据
        if normalize:
            self.feature_mean, self.feature_std = None, None  # 特征的均值,标准方差
        self.mse = np.infty  # 训练样本的均方误差
        self.r2, self.r2_adj = 0.0, 0.0  # 判定系数和修正判定系数
        self.n_samples, self.n_features = 0, 0  # 样本量和特征数

    def fit(self, x_train, y_train):
        """
        模型训练,根据是否标准化与是否拟合偏置项分类讨论
        :param x_train: 训练样本集
        :param y_train: 训练目标集
        :return:
        """
        if self.normalize:
            self.feature_mean = np.mean(x_train, axis=0)  # 按样本属性计算样本均值
            self.feature_std = np.std(x_train, axis=0) + 1e-8  # 样本方差,为避免零除,添加噪声
            x_train = (x_train - self.feature_mean) / self.feature_std  # 标准化
        if self.fit_intercept:
            x_train = np.c_[x_train, np.ones_like(y_train)]  # 添加一列1,即偏置项样本
        # 训练模型
        self._fit_closed_form_solution(x_train, y_train)  # 求闭式解

    def _fit_closed_form_solution(self, x_train, y_train):
        """
        线性回归的闭式解,单独函数,以便后期扩充维护
        :param x_train: 训练样本集
        :param y_train: 训练目标集
        :return:
        """
        # pinv伪逆,即(A^T * A)^(-1) * A^T
        self.theta = np.linalg.pinv(x_train).dot(y_train)  # 非正则化
        # xtx = np.dot(x_train.T, x_train) + 0.01 * np.eye(x_train.shape[1])  # 按公式书写
        # self.theta = np.dot(np.linalg.inv(xtx), x_train.T).dot(y_train)

    def get_params(self):
        """
        返回线性模型训练的系数
        :return:
        """
        if self.fit_intercept:  # 存在偏置项
            weight, bias = self.theta[:-1], self.theta[-1]
        else:
            weight, bias = self.theta, np.array([0])
        if self.normalize:  # 标准化后的系数
            weight = weight / self.feature_std.reshape(-1)  # 还原模型系数
            bias = bias - weight.T.dot(self.feature_mean.reshape(-1))
        return np.r_[weight.reshape(-1), bias.reshape(-1)]

    def predict(self, x_test):
        """
        测试数据预测,x_test:待预测样本集,不包括偏置项1
        :param x_test:
        :return:
        """
        try:
            self.n_samples, self.n_features = x_test.shape[0], x_test.shape[1]
        except IndexError:
            self.n_samples, self.n_features = x_test.shape[0], 1  # 测试样本数和特征数
        if self.normalize:
            x_test = (x_test - self.feature_mean) / self.feature_std  # 测试数据标准化
        if self.fit_intercept:
            x_test = np.c_[x_test, np.ones(shape=x_test.shape[0])]  # 存在偏置项,添加一列1
        return x_test.dot(self.theta)

    def cal_mse_r2(self, y_pred, y_test):
        """
        计算均方误差,计算拟合优度的判定系数R方和修正判定系数
        :param y_pred: 模型预测目标真值
        :param y_test: 测试目标真值
        :return:
        """
        self.mse = ((y_test - y_pred) ** 2).mean()  # 均方误差
        # 计算测试样本的判定系数和修正判定系数
        self.r2 = 1 - ((y_test - y_pred) ** 2).sum() / ((y_test - y_test.mean()) ** 2).sum()
        self.r2_adj = 1 - (1 - self.r2) * (self.n_samples - 1) / (self.n_samples - self.n_features - 1)
        return self.mse, self.r2, self.r2_adj

    def plt_predict(self, y_pred, y_test, is_show=True, is_sort=True):
        """
        绘制预测值与真实值对比图
        :return:
        """
        if self.mse is np.infty:
            self.cal_mse_r2(y_pred, y_test)
        if is_show:
            plt.figure(figsize=(7, 5))
        if is_sort:
            idx = np.argsort(y_test)
            plt.plot(y_pred[idx], "r:", lw=1.5, label="Predictive Val")
            plt.plot(y_test[idx], "k--", lw=1.5, label="Test True Val")
        else:
            plt.plot(y_pred, "r:", lw=1.5, label="Predictive Val")
            plt.plot(y_test, "k--", lw=1.5, label="Test True Val")
        plt.xlabel("Test sample observation serial number", fontdict={"fontsize": 12})
        plt.ylabel("Predicted sample value", fontdict={"fontsize": 12})
        plt.title("The predictive values of test samples \n MSE = %.5e, R2 = %.5f, R2_adj = %.5f"
                  % (self.mse, self.r2, self.r2_adj), fontdict={"fontsize": 14})
        plt.legend(frameon=False)
        plt.grid(ls=":")
        if is_show:
            plt.show()


from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score, mean_squared_error


housing = fetch_california_housing()
X, y = housing.data, housing.target  # 获取样本数据和目标数据
X_train, X_test, y_train, y_test = \
    train_test_split(X, y, test_size=0.3, random_state=1, shuffle=True)

lgcfs_obj = LRClosedFormSol(normalize=True, fit_intercept=True)
lgcfs_obj.fit(X_train, y_train)
theta = lgcfs_obj.get_params()  # 获得模型系数
print("线性回归模型拟合系数如下:")
for i, fn in enumerate(housing.feature_names):
    print(fn + ":", theta[i])
print("Const:", theta[-1])

# 模型预测,即针对测试样本进行预测
y_pred = lgcfs_obj.predict(X_test)
lgcfs_obj.plt_predict(y_pred, y_test, is_sort=True)

# 采用sklearn库函数进行线性回归和预测
lr = LinearRegression().fit(X_train, y_train)
print("sklearn截距:", lr.intercept_)  # 打印截距
print("sklearn系数:", lr.coef_)  # 打印模型系数
y_test_predict = lr.predict(X_test)
mse = mean_squared_error(y_test, y_test_predict)
r2 = r2_score(y_test, y_test_predict)
print("sklearn均方误差与判定系数为:", mse, r2)




 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1410936.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Chrome】浏览器怎么清除缓存并强制刷新

文章目录 1、正常刷新:正常刷新网页,网页有缓存则采用缓存。 F5 或 刷新键2、强制刷新:忽略缓存刷新,重新下载资源不用缓存。 CtrlF5 或 ShiftF5 或 CtrlShiftR3、在浏览器的设置里面清除所有数据

哪些 3D 建模软件值得推荐?

云端地球是一款免费的在线实景三维建模软件,不需要复杂的技巧,只要需要手机,多拍几张照片,就可以得到完整的三维模型! 无论是大场景倾斜摄影测量还是小场景、小物体建模,都可以通过云端地球将二维数据向三…

MTP与管理壳(AAS)有异曲同工之妙

在过去的几年中,流程工业中的不同部门(例如制药、精细化学品以及食品和饮料部门)遇到了一系列共同且可比较的新兴挑战。这些挑战包括: 新产品的需求迅速接连不断,更快交货和更低价格的压力,更多定制产品&a…

【wvp】关于码率等的相关流程设计

目录 流程设计 前端UI大致设计 终端上的相关修改界面参考 流程设计 前端UI大致设计 终端上的相关修改界面参考

【WPF.NET开发】WPF中的双向功能

本文内容 FlowDirectionFlowDocumentSpan 元素非文本元素的 FlowDirection数字替换 与其他任何开发平台不同,WPF 具有许多支持双向内容快速开发的功能,例如,同一文档中混合了从左到右和从右到左的数据。 同时,WPF 也为需要双向功…

文件IO讲解

💕"跑起来就有意义"💕 作者:Mylvzi 文章主要内容:文件IO讲解 一.与文件相关的基本概念 1.什么是文件 文件从广义上来说就是操作系统对其所持有的硬件设备和软件资源的抽象化表示,但是在日常生活中我们所提到的文件就…

《Visual Tree Convolutional Neural Network in Image Classification》阅读笔记

论文标题 《Visual Tree Convolutional Neural Network in Image Classification》 图像分类中的视觉树卷积神经网络 作者 Yuntao Liu、Yong Dou、Ruochun Jin 和 Peng Qiao 来自国防科技大学并行和分布式处理国家实验室 初读 摘要 问题: 在图像分类领域&…

1.25号c++

1.引用 引用就是给变量起别名 格式: 数据类型 &引用名 同类型的变量名 (& 引用符号) eg: int a 10; int &b a; //b引用a,或者给a变量取个别名叫b int *p; //指针可以先定义 后指向 p &a; //int &a…

谷歌推出 AutoRT 机器人代理大规模编排的具体基础模型,远程操作和收集 77,000 个机器人事件

演示 AutoRT 向多个建筑物中的20多个机器人提出指令,并通过远程操作和自主机器人策略收集77,000个真实的机器人事件。实验表明,AutoRT 收集的此类“野外”数据明显更加多样化,并且 AutoRT 使用 LLMs 允许遵循能够符合人类偏好的数据收集机器人…

Jenkins全局工具配置

目录 Jenkins全局工具全局工具配置Settings 文件配置Maven配置JDK配置Git配置 Jenkins全局工具 我们在安装了Jenkins之后,就可以开始使用Jenkins来进行一些自动化构建发布工作,但是开始之前我们还需要进行全局工具的配置,Jenkins全局工具配置…

QT入门篇---无门槛学习

1.1 什么是 Qt Qt 是⼀个 跨平台的 C 图形⽤⼾界⾯应⽤程序框架 。它为应⽤程序开发者提供了建⽴艺术级图形界⾯所需的所有功能。它是完全⾯向对象的,很容易扩展。Qt 为开发者提供了⼀种基于组件的开发模式,开发者可以通过简单的拖拽和组合来实现复杂的…

【深度学习】【注意力机制】【自然语言处理】【图像识别】深度学习中的注意力机制详解、self-attention

1、深度学习的输入 无论是我们的语言处理、还是图像处理等,我们的输入都可以看作是一个向量。通过Model最终输出结果。这里,我们的vector大小是不会改变的。 然而,我们有可能会遇到这样的情况: 输入的sequence的长度是不定的怎…

Idea上操作Git回退本地版本,怎么样保留已修改的文件,回退本地版本的四种方式代表什么?

Git的基本概念:Git是一个版本控制系统,用于管理代码的变更历史记录。核心概念包括仓库、分支、提交和合并。 1、可以帮助开发者合并开发的代码 2、如果出现冲突代码的合并,会提示后提交合并代码的开发者,让其解决冲突 3、代码文件版本管理 问题描述 当我们使用git提交代码…

Hive3.1.3基础学习

文章目录 一、Hive入门与安装1、Hive入门1.1 简介1.2 Hive架构原理 2、Hive安装2.1 安装地址2.2 Hive最小化安装(测试用)2.3 MySQL安装2.4 配置Hive元数据存储到MySQL2.5 Hive服务部署2.6 Hive服务启动脚本(了解) 3、Hive使用技巧3.1 Hive常用交互命令3.2 Hive参数配置方式3.3 …

蓝桥杯(C++ 左移右移 买二增一 松散子序列 填充 有奖问答 更小的数 )

目录 左移右移 思路: 代码: 买二增一 思路: 代码: 松散子序列 思路: 代码: 填充 思路: 代码 : 有奖问答 思路: 代码: 更小的数 思路&#…

【快影】怎么制作卡拉OK字幕

您好,您添加了字幕之后可以添加动画,选择卡拉OK,其中 卡拉OK1是支持修改颜色的,卡拉OK2只支持修改文字的底色。

OpenCV使用基础、技巧

OpenCV概述与安装 视觉概述 人类的视觉能够很轻易地从图像中识别出内容。但是,计算机视觉不会像人类视觉那样能够对图像进行感知和识别,更不会自动控制焦距和光圈,而是把图像解析为按照栅格状排列的数字。 这些按照栅格状排列的数字包含大量…

Java中Integer(127)==Integer(127)为True,Integer(128)==Integer(128)却为False,这是为什么?

文章目录 1.前言2. 源码解析3.总结 1.前言 相信大家职业生涯中或多或少的碰到过Java比较变态的笔试题,下面这道题目大家应该不陌生: Integer i 127; Integer j 127;Integer m 128; Integer n 128;System.out.println(i j); // 输出为 true System.o…

华为数通方向HCIP-DataCom H12-831题库(判断题:121-140)

第121题 BGP/MPLS IP VPN内层采用MP-BGP分配的标签区分不同的VPN实例,外层可采用多种隧道类型,例如GRE隧道。 正确 错误 答案: 错误 解析: VPN业务的转发需要隧道来承载,隧道类型包括GRE隧道、LSP隧道、TE隧道(即CR-LSP)。 如果网络边缘的PE设备具备MPLS功能,但骨干网核…

Deepin基本环境查看(四)【硬盘/分区、文件系统、硬连接/软连接】

Linux操作系统(Deepin、Ubuntu)操作系统中,硬盘分区的管理与Windows操作系统不同; 在Linux系统中维护着一个统一的文件目录体系,而硬盘和分区是以资源的形式由操作系统挂接和调度;此外Linux系统中连接(硬连…