机器学习:梯度下降法(Python)

news2024/9/20 18:49:26

LinearRegression_GD.py

import numpy as np
import matplotlib.pyplot as plt


class LinearRegression_GradDesc:
    """
    线性回归,梯度下降法求解模型系数
    1、数据的预处理:是否训练偏置项fit_intercept(默认True),是否标准化normalized(默认True)
    2、模型的训练:闭式解公式,fit(self, x_train, y_train)
    3、模型的预测,predict(self, x_test)
    4、均方误差,判决系数
    5、模型预测可视化
    """

    def __init__(self, fit_intercept=True, normalize=True, alpha=0.05, max_epochs=300, batch_size=20):
        """
        :param fit_intercept: 是否训练偏置项
        :param normalize: 是否标准化
        :param alpha: 学习率
        :param max_epochs: 最大迭代次数
        :param batch_size: 批量大小,若为1,则为随机梯度,若为训练集样本量,则为批量梯度,否则为小批量梯度
        """
        self.fit_intercept = fit_intercept  # 线性模型的常数项。也即偏置bias,模型中的theta0
        self.normalize = normalize  # 是否标准化数据
        self.alpha = alpha  # 学习率
        self.max_epochs = max_epochs
        self.batch_size = batch_size
        self.theta = None  # 训练权重系数
        if normalize:
            self.feature_mean, self.feature_std = None, None  # 特征的均值,标准方差
        self.mse = np.infty  # 训练样本的均方误差
        self.r2, self.r2_adj = 0.0, 0.0  # 判定系数和修正判定系数
        self.n_samples, self.n_features = 0, 0  # 样本量和特征数
        self.train_loss, self.test_loss = [], []  # 存储训练过程中的训练损失和测试损失

    def init_params(self, n_features):
        """
        初始化参数
        如果训练偏置项,也包含了bias的初始化
        :return:
        """
        self.theta = np.random.randn(n_features, 1) * 0.1

    def fit(self, x_train, y_train, x_test=None, y_test=None):
        """
        模型训练,根据是否标准化与是否拟合偏置项分类讨论
        :param x_train: 训练样本集
        :param y_train: 训练目标集
        :param x_test: 测试样本集
        :param y_test: 测试目标集
        :return:
        """
        if self.normalize:
            self.feature_mean = np.mean(x_train, axis=0)  # 样本均值
            self.feature_std = np.std(x_train, axis=0) + 1e-8  # 样本方差
            x_train = (x_train - self.feature_mean) / self.feature_std  # 标准化
            if x_test is not None:
                x_test = (x_test - self.feature_mean) / self.feature_std  # 标准化
        if self.fit_intercept:
            x_train = np.c_[x_train, np.ones_like(y_train)]  # 添加一列1,即偏置项样本
            if x_test is not None and y_test is not None:
                x_test = np.c_[x_test, np.ones_like(y_test)]  # 添加一列1,即偏置项样本
        self.init_params(x_train.shape[1])  # 初始化参数
        self._fit_gradient_desc(x_train, y_train, x_test, y_test)  # 梯度下降法训练模型

    def _fit_gradient_desc(self, x_train, y_train, x_test=None, y_test=None):
        """
        三种梯度下降求解:
        (1)如果batch_size为1,则为随机梯度下降法
        (2)如果batch_size为样本量,则为批量梯度下降法
        (3)如果batch_size小于样本量,则为小批量梯度下降法
        :return:
        """
        train_sample = np.c_[x_train, y_train]  # 组合训练集和目标集,以便随机打乱样本
        # np.c_水平方向连接数组,np.r_竖直方向连接数组
        # 按batch_size更新theta,三种梯度下降法取决于batch_size的大小
        best_theta, best_mse = None, np.infty  # 最佳训练权重与验证均方误差
        for i in range(self.max_epochs):
            self.alpha *= 0.95
            np.random.shuffle(train_sample)  # 打乱样本顺序,模拟随机化
            batch_nums = train_sample.shape[0] // self.batch_size  # 批次
            for idx in range(batch_nums):
                # 取小批量样本,可以是随机梯度(1),批量梯度(n)或者是小批量梯度(<n)
                batch_xy = train_sample[self.batch_size * idx: self.batch_size * (idx + 1)]
                # 分取训练样本和目标样本,并保持维度
                batch_x, batch_y = batch_xy[:, :-1], batch_xy[:, -1:]
                # 计算权重更新增量
                delta = batch_x.T.dot(batch_x.dot(self.theta) - batch_y) / self.batch_size
                self.theta = self.theta - self.alpha * delta
            train_mse = ((x_train.dot(self.theta) - y_train.reshape(-1, 1)) ** 2).mean()
            self.train_loss.append(train_mse)
            if x_test is not None and y_test is not None:
                test_mse = ((x_test.dot(self.theta) - y_test.reshape(-1, 1)) ** 2).mean()
                self.test_loss.append(test_mse)

    def get_params(self):
        """
        返回线性模型训练的系数
        :return:
        """
        if self.fit_intercept:  # 存在偏置项
            weight, bias = self.theta[:-1], self.theta[-1]
        else:
            weight, bias = self.theta, np.array([0])
        if self.normalize:  # 标准化后的系数
            weight = weight / self.feature_std.reshape(-1, 1)  # 还原模型系数
            bias = bias - weight.T.dot(self.feature_mean)
        return weight.reshape(-1), bias

    def predict(self, x_test):
        """
        测试数据预测
        :param x_test: 待预测样本集,不包括偏置项
        :return:
        """
        try:
            self.n_samples, self.n_features = x_test.shape[0], x_test.shape[1]
        except IndexError:
            self.n_samples, self.n_features = x_test.shape[0], 1  # 测试样本数和特征数
        if self.normalize:
            x_test = (x_test - self.feature_mean) / self.feature_std  # 测试数据标准化
        if self.fit_intercept:
            # 存在偏置项,加一列1
            x_test = np.c_[x_test, np.ones(shape=x_test.shape[0])]
        y_pred = x_test.dot(self.theta).reshape(-1, 1)
        return y_pred

    def cal_mse_r2(self, y_test, y_pred):
        """
        计算均方误差,计算拟合优度的判定系数R方和修正判定系数
        :param y_pred: 模型预测目标真值
        :param y_test: 测试目标真值
        :return:
        """
        self.mse = ((y_test.reshape(-1, 1) - y_pred.reshape(-1, 1)) ** 2).mean()  # 均方误差
        # 计算测试样本的判定系数和修正判定系数
        self.r2 = 1 - ((y_test.reshape(-1, 1) - y_pred.reshape(-1, 1)) ** 2).sum() / \
                  ((y_test.reshape(-1, 1) - y_test.mean()) ** 2).sum()
        self.r2_adj = 1 - (1 - self.r2) * (self.n_samples - 1) / \
                      (self.n_samples - self.n_features - 1)
        return self.mse, self.r2, self.r2_adj

    def plt_predict(self, y_test, y_pred, is_show=True, is_sort=True):
        """
        绘制预测值与真实值对比图
        :return:
        """
        if self.mse is np.infty:
            self.cal_mse_r2(y_pred, y_test)
        if is_show:
            plt.figure(figsize=(8, 6))
        if is_sort:
            idx = np.argsort(y_test)  # 升序排列,获得排序后的索引
            plt.plot(y_test[idx], "k--", lw=1.5, label="Test True Val")
            plt.plot(y_pred[idx], "r:", lw=1.8, label="Predictive Val")
        else:
            plt.plot(y_test, "ko-", lw=1.5, label="Test True Val")
            plt.plot(y_pred, "r*-", lw=1.8, label="Predictive Val")
        plt.xlabel("Test sample observation serial number", fontdict={"fontsize": 12})
        plt.ylabel("Predicted sample value", fontdict={"fontsize": 12})
        plt.title("The predictive values of test samples \n MSE = %.5e, R2 = %.5f, R2_adj = %.5f"
                  % (self.mse, self.r2, self.r2_adj), fontdict={"fontsize": 14})
        plt.legend(frameon=False)
        plt.grid(ls=":")
        if is_show:
            plt.show()

    def plt_loss_curve(self, is_show=True):
        """
        可视化均方损失下降曲线
        :param is_show: 是否可视化
        :return:
        """
        if is_show:
            plt.figure(figsize=(8, 6))
        plt.plot(self.train_loss, "k-", lw=1, label="Train Loss")
        if self.test_loss:
            plt.plot(self.test_loss, "r--", lw=1.2, label="Test Loss")
        plt.xlabel("Epochs", fontdict={"fontsize": 12})
        plt.ylabel("Loss values", fontdict={"fontsize": 12})
        plt.title("Gradient Descent Method and Test Loss MSE = %.5f"
                  % (self.test_loss[-1]), fontdict={"fontsize": 14})
        plt.legend(frameon=False)
        plt.grid(ls=":")
        # plt.axis([0, 300, 20, 30])
        if is_show:
            plt.show()

test_linear_regression_gd.py

import numpy as np
from LinearRegression_GD import LinearRegression_GradDesc
from sklearn.model_selection import train_test_split


np.random.seed(42)
X = np.random.rand(1000, 6)  # 随机样本值,6个特征
coeff = np.array([4.2, -2.5, 7.8, 3.7, -2.9, 1.87])  # 模型参数
y = coeff.dot(X.T) + 0.5 * np.random.randn(1000)  # 目标函数值

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0, shuffle=True)

lr_gd = LinearRegression_GradDesc(alpha=0.1, batch_size=1)
lr_gd.fit(X_train, y_train, X_test, y_test)
theta = lr_gd.get_params()
print(theta)
y_test_pred = lr_gd.predict(X_test)
lr_gd.plt_predict(y_test, y_test_pred)
lr_gd.plt_loss_curve()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1424479.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Tomcat与网络8】从源码看Tomcat的层次结构

在前面我们介绍了如何通过源码来启动Tomcat&#xff0c;本文我们就来看一下Tomcat是如何一步步启动的&#xff0c;以及在启动过程中&#xff0c;不同的组件是如何加载的。 一般&#xff0c;我们可以通过 Tomcat 的 /bin 目录下的脚本 startup.sh 来启动 Tomcat&#xff0c;如果…

node.js(nest.js控制器)学习笔记

nest.js控制器&#xff1a; 控制器负责处理传入请求并向客户端返回响应。 为了创建基本控制器&#xff0c;我们使用类和装饰器。装饰器将类与所需的元数据相关联&#xff0c;并使 Nest 能够创建路由映射&#xff08;将请求绑定到相应的控制器&#xff09;。 1.获取get请求传参…

使用 Enigma Protector 无需管理员权限即可注册 ActiveX/COM 组件

我们的客户对如何保护 .NET 应用程序免遭破解和转储提出了许多问题。在本文中&#xff0c;我们将尝试描述保护此类特定文件的所有薄弱环节和细节。 The Enigma Protector 是一款专门设计用来为应用程序添加高强度保护的强大工具。它旨在防止非法复制、反编译和修改代码等操作&…

数据结构与算法面试系列-03

1. 一球从100米高度自由落下,每次落地后反跳回原高度的一半;再落下,求它在 第10次落地时,共经过多少米?第10次反弹多高? 程序代码 package com.jingxuan.system;public class Sphere {public static void main(String[] args) {double s = 0;double t = 100;for (int i…

搜维尔科技:第九届元宇宙数字人大赛,参赛小组报名确认公告!

各位参赛选手大家好&#xff0c;近期已收到新增报名信息如下表&#xff0c;请各位参赛选手确认&#xff0c;如果信息有误或信息不完整请电话联系赛务组工作人员进行更正 随着元宇宙时代的来临&#xff0c;数字人设计成为了创新前沿领域之一。为了提高大学生元宇宙虚拟人角色策划…

多线程代码案例之线程池

作者简介&#xff1a; zoro-1&#xff0c;目前大二&#xff0c;正在学习Java&#xff0c;数据结构&#xff0c;javaee等 作者主页&#xff1a; zoro-1的主页 欢迎大家点赞 &#x1f44d; 收藏 ⭐ 加关注哦&#xff01;&#x1f496;&#x1f496; 创建线程池 public class Poo…

Java面试题宝典(万字长文)

Java 基础 1. JDK 和 JRE 有什么区别&#xff1f; JRE是Java运行环境&#xff0c;即&#xff08;Java Runtime Environment&#xff09;&#xff0c;也就是Java平台。所有的Java程序都要在JRE下才能运行。 JDK是开发工具包&#xff0c;即&#xff08;Java Development Kit&am…

经典mysql实操和行专列操作

1.删除除了学号字段以外&#xff0c;其它字段都相同的冗余记录&#xff0c;只保留一条&#xff01;&#xff08;也就是要删除王五和赵六中一条重复数据只留一条&#xff09; 要求的预期效果: 原始数据创建表结构&#xff1a; CREATE TABLE tb_student (id int(16) NOT NULL,na…

科技云报道:云原生PaaS,如何让金融业数字化开出“繁花”?

科技云报道原创。 在中国金融业数字化转型的历史长卷中&#xff0c;过去十年无疑是一部磅礴的史诗。 2017年&#xff0c;南京银行第一次将传统线下金融业务搬到了线上。那一年&#xff0c;它的互联网金融信贷业务实现了过去10年的业务总额。 2021年&#xff0c;富滇银行通过…

【大数据安全】数据管理安全安全分析隐私保护

目录 一、数据管理安全 &#xff08;一&#xff09;数据溯源 &#xff08;二&#xff09;数字水印 &#xff08;三&#xff09;策略管理 &#xff08;四&#xff09;完整性保护 &#xff08;五&#xff09;数据脱敏 二、安全分析 &#xff08;一&#xff09;大数据安全…

数据库技术栈 —— B树与B+树

数据库技术栈 —— B树与B树 一、复习二、MySQL中的B树应用 一、复习 B树是多路平衡查找树的意思 参考文章或视频链接[1] 【王道计算机考研 数据结构】 二、MySQL中的B树应用 这篇文章里的计算题还是讲的不错的。 参考文章或视频链接[1] 《探究MySQL的索引结构选型》

Wireshark网络协议分析 - Wireshark速览

在我的博客阅读本文 文章目录 1. 版本与平台2. 快速上手2.1. 选择网络接口进行捕获&#xff08;Capture&#xff09;2.2. 以Ping命令为例进行抓包分析2.3. 设置合适的过滤表达式2.4. 数据包详情2.5. TCP/IP 四层模型 3. 参考资料 1. 版本与平台 Wireshark是一个开源的网络数据…

IDEA的properties默认编码是UTF-8但是不显示中文

问题描述 今天打开IDEA项目&#xff0c;发现messages_zh_CN.properties不显示中文了 但奇怪的是target下的文件就是展示的中文 而且我IDEA已经配置了编码格式是UTF-8了 使用nodepad打开源文件&#xff0c;也是展示编码格式是UTF-8 &#xff08;打开target下的文件&#xff0c;…

QWT开源库使用

源代码地址&#xff1a;Qwt Users Guide: Qwt - Qt Widgets for Technical Applications Qwt库包含GUI组件和实用程序类&#xff0c;它们主要用于具有技术背景的程序。除了2D图的框架外&#xff0c;它还提供刻度&#xff0c;滑块&#xff0c;刻度盘&#xff0c;指南针&#xf…

EDR、SIEM、SOAR 和 XDR 的区别

在一个名为网络安全谷的神秘小镇&#xff0c;居住着四位守护者&#xff0c;他们分别是EDR&#xff08;艾迪&#xff09;、SIEM&#xff08;西姆&#xff09;、SOAR&#xff08;索亚&#xff09;和XDR&#xff08;艾克斯&#xff09;。他们各自拥有独特的能力&#xff0c;共同守…

Android组件化中的Arouter学习

假设现在有两个业务组件登录和问答模块之间需要进行通信&#xff0c;可能会想到用反射的方式&#xff0c;是可以但是会影响性能&#xff0c;而写的代码比较多类名这些要记清楚。 路由可以看做表&#xff0c;每个map对应一张表 我们可以试着这么写&#xff0c;完成MainActivity跳…

03. 【Linux教程】安装虚拟机

前面小节介绍了 Linux 和 GUN 项目&#xff0c;本小节开始学习如何在 Windows 上安装虚拟机&#xff0c;虚拟机安装之后可以在虚拟机中安装 Linux 相关的操作系统&#xff0c;常见的虚拟机软件有 VirtualBox、VMware 等等&#xff0c;本教程使用 VMware 虚拟机软件来演示如何安…

【Linux取经路】进程控制——进程等待

文章目录 一、进程创建1.1 初识 fork 函数1.2 fork 函数返回值1.3 写时拷贝1.4 fork 的常规用法1.5 fork 调用失败的原因1.6 创建一批进程 二、进程终止2.1 进程退出场景2.2 strerror函数2.3 errno全局变量2.4 程序异常2.5 进程常见退出方法2.6 exit 函数2.7 _exit 函数和 exit…

2024年股市走向!温州哪家证券公司股票开户佣金最低呢,最低可以多少?

​ 对于2024年股市的预测很难做出准确的判断&#xff0c;因为股市受到多种因素的影响&#xff0c;包括经济状况、政策变化、国际形势等。然而&#xff0c;我们可以根据当前的一些趋势和因素来对未来的股市做出一些预测。 首先&#xff0c;随着全球经济的逐步恢复&#xff0c…

【从零开始的rust web开发之路 三】orm框架sea-orm入门使用教程

【从零开始的rust web开发之路 三】orm框架sea-orm入门使用教程 文章目录 前言一、引入依赖二、创建数据库连接简单链接连接选项开启日志调试 三、生成实体安装sea-orm-cli创建数据库表使用sea-orm-cli命令生成实体文件代码 四、增删改查实现新增数据主键查找条件查找查找用户名…