机器学习复习(2)——线性回归SGD优化算法

news2024/9/28 5:24:52

目录

线性回归代码

线性回归理论

SGD算法

手撕线性回归算法

模型初始化

定义模型主体部分

定义线性回归模型训练过程

数据demo准备

模型训练与权重参数

定义线性回归预测函数

定义R2系数计算

可视化展示 

预测结果

训练过程 

sklearn进行机器学习

线性回归代码

class My_Model(nn.Module):
    def __init__(self, input_dim):
        super(My_Model, self).__init__()
        # 矩阵的维度(dimensions) 
        self.layers = nn.Sequential(
            nn.Linear(input_dim, 16),
            nn.ReLU(),
            nn.Linear(16, 8),
            nn.ReLU(),
            nn.Linear(8, 1)
        )

    def forward(self, x):
        x = self.layers(x)
        x = x.squeeze(1) # (B, 1) -> (B)
        return x

线性回归理论

回归算法是相对分类算法而言的,与我们想要预测的目标变量y的值类型有关。

如果目标变量y是分类型变量,如预测用户的性别(男、女),预测月季花的颜色(红、白、黄……),那我们就需要用分类算法去拟合训练数据并做出预测;

如果y是连续型变量,如预测用户的收入(4千,2万,10万……),预测患肺癌的概率(1%,50%,99%……),我们则需要用回归模型。

有时分类问题也可以转化为回归问题。可以用回归模型先预测出患肺癌的概率,然后再给定一个阈值,例如50%,概率值在50%以下为A类,50%以上为B类。

一元线性回归公式:

 具象化含义:

SGD算法

手撕线性回归算法

模型初始化

### 初始化模型参数
def initialize_params(dims):
    '''
    输入:
    dims:训练数据变量维度
    输出:
    w:初始化权重参数值
    b:初始化偏差参数值
    '''
    # 初始化权重参数为零矩阵
    w = np.zeros((dims, 1))
    # 初始化偏差参数为零
    b = 0
    return w, b
w,b=initialize_params(3)#用于测试
print("w初始化是",w)
print("b初始化是",b)

运行结果:

定义模型主体部分

包括线性回归公式、均方损失和参数偏导三部分
def linear_loss(X, y, w, b):
    '''
    输入:
    X:输入变量矩阵
    y:输出标签向量
    w:变量参数权重矩阵
    b:偏差项
    输出:
    y_hat:线性模型预测输出
    loss:均方损失值
    dw:权重参数一阶偏导
    db:偏差项一阶偏导
    '''
    # 训练样本数量
    num_train = X.shape[0]
    # 训练特征数量
    num_feature = X.shape[1]
    # 线性回归预测输出
    y_hat = np.dot(X, w) + b
    # 计算预测输出与实际标签之间的均方损失
    loss = np.sum((y_hat-y)**2)/num_train
    # 基于均方损失对权重参数的一阶偏导数
    dw = np.dot(X.T, (y_hat-y)) /num_train
    # 基于均方损失对偏差项的一阶偏导数
    db = np.sum((y_hat-y)) /num_train
    return y_hat, loss, dw, db

定义线性回归模型训练过程

### 定义线性回归模型训练过程
def linear_train(X, y, learning_rate=0.01, epochs=10000):
    '''
    输入:
    X:输入变量矩阵
    y:输出标签向量
    learning_rate:学习率
    epochs:训练迭代次数
    输出:
    loss_his:每次迭代的均方损失
    params:优化后的参数字典
    grads:优化后的参数梯度字典
    '''
    # 记录训练损失的空列表
    loss_his = []
    # 初始化模型参数
    w, b = initialize_params(X.shape[1])
    # 迭代训练
    for i in range(1, epochs):
        # 计算当前迭代的预测值、损失和梯度
        y_hat, loss, dw, db = linear_loss(X, y, w, b)
#y_hat是预测值,loss是损失,dw是权重参数一阶偏导,db是偏差项一阶偏导
        # 基于梯度下降的参数更新
        w += -learning_rate * dw
        b += -learning_rate * db
        # 记录当前迭代的损失
        loss_his.append(loss)
        # 每1000次迭代打印当前损失信息
        if i % 10000 == 0:
            print('epoch %d loss %f' % (i, loss))
        # 将当前迭代步优化后的参数保存到字典
        params = {
            'w': w,
            'b': b
        }
        # 将当前迭代步的梯度保存到字典
        grads = {
            'dw': dw,
            'db': db
        }     
    return loss_his, params, grads

其中的shape操作说明:

import numpy as np
# 创建一个示例的训练数据集 X
X = np.array([[1, 2, 3],
              [4, 5, 6],
              [7, 8, 9],
              [10, 11, 12],
              [13, 14, 15]])
# 计算训练样本数量
shape0 = X.shape[0]
shape1 = X.shape[1]
print("shape0是",shape0)
print("shape1是",shape1)

运行结果:

数据demo准备

from sklearn.datasets import load_diabetes
diabetes = load_diabetes()
data = diabetes.data
target = diabetes.target 
print(data.shape)
print(target.shape)
print(data[:5])
print(target[:5])
###########################################
# 导入sklearn diabetes数据接口
from sklearn.datasets import load_diabetes
# 导入sklearn打乱数据函数
from sklearn.utils import shuffle
# 获取diabetes数据集
diabetes = load_diabetes()
# 获取输入和标签
data, target = diabetes.data, diabetes.target 
# 打乱数据集
X, y = shuffle(data, target, random_state=13)
# 按照8/2划分训练集和测试集
offset = int(X.shape[0] * 0.8)
# 训练集
X_train, y_train = X[:offset], y[:offset]
# 测试集
X_test, y_test = X[offset:], y[offset:]
# 将训练集改为列向量的形式
y_train = y_train.reshape((-1,1))
# 将验证集改为列向量的形式
y_test = y_test.reshape((-1,1))
# 打印训练集和测试集维度
print("X_train's shape: ", X_train.shape)
print("X_test's shape: ", X_test.shape)
print("y_train's shape: ", y_train.shape)
print("y_test's shape: ", y_test.shape)

模型训练与权重参数

# 线性回归模型训练
loss_his, params, grads = linear_train(X_train, y_train, 0.01, 200000)
# 打印训练后得到模型参数
print(params)

定义线性回归预测函数

### 定义线性回归预测函数
def predict(X, params):
    '''
    输入:
    X:测试数据集
    params:模型训练参数
    输出:
    y_pred:模型预测结果
    '''
    # 获取模型参数
    w = params['w']
    b = params['b']
    # 预测
    y_pred = np.dot(X, w) + b
    return y_pred
# 基于测试集的预测
y_pred = predict(X_test, params)
# 打印前五个预测值
y_pred[:5]

定义R2系数计算

R2系数,也称为决定系数(Coefficient of Determination),是一种用于评估回归模型拟合优度的统计指标。它表示模型对观测数据的方差解释比例,通常用于衡量回归模型的拟合程度。

R2系数的取值范围在0到1之间,具体含义如下:

  • 如果R2等于0,表示模型未能解释目标变量的任何方差,即模型无法拟合数据。
  • 如果R2等于1,表示模型完美拟合了数据,能够解释目标变量的所有方差。
  • 如果R2在0和1之间,表示模型能够解释一部分目标变量的方差,数值越接近1,说明模型的拟合程度越好。

计算公式如下:

其中:

  • SSR(Sum of Squares of Residuals)表示模型的残差平方和,即实际观测值与模型预测值之间的差异的平方和。
  • SST(Total Sum of Squares)表示总平方和,即实际观测值与观测值的均值之间的差异的平方和。

R2系数越接近1,说明模型对数据的拟合越好,而越接近0则表示模型的拟合效果较差。这个指标对于评估回归模型的性能非常有用,帮助我们了解模型解释数据方差的程度。

### 定义R2系数函数
def r2_score(y_test, y_pred):
    '''
    输入:
    y_test:测试集标签值
    y_pred:测试集预测值
    输出:
    r2:R2系数
    '''
    # 测试标签均值
    y_avg = np.mean(y_test)
    # 总离差平方和
    ss_tot = np.sum((y_test - y_avg)**2)
    # 残差平方和
    ss_res = np.sum((y_test - y_pred)**2)
    # R2计算
    r2 = 1 - (ss_res/ss_tot)
    return r2

可视化展示 

预测结果

import matplotlib.pyplot as plt
f = X_test.dot(params['w']) + params['b']

plt.scatter(range(X_test.shape[0]), y_test)
plt.plot(f, color = 'darkorange')
plt.xlabel('X_test')
plt.ylabel('y_test')
plt.show();

运行结果:

训练过程 

plt.plot(loss_his, color='blue')
plt.xlabel('epochs')
plt.ylabel('loss')
plt.show()

运行结果:

sklearn进行机器学习

 和torch.nn类似:封装好了linear函数,直接掉包

### sklearn版本为1.0.2
# 导入线性回归模块
from sklearn import linear_model
from sklearn.metrics import mean_squared_error, r2_score
# 创建模型实例
regr = linear_model.LinearRegression()
# 模型拟合
regr.fit(X_train, y_train)
# 模型预测
y_pred = regr.predict(X_test)
# 打印模型均方误差
print("Mean squared error: %.2f" % mean_squared_error(y_test, y_pred))
# 打印R2
print('R2 score: %.2f' % r2_score(y_test, y_pred))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1429481.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

浙政钉(专有钉钉)

专有钉钉是浙政钉的测试版本,可在正式发布之前进行业务开发。 专有钉钉 原名政务钉钉 是高安全、强管控、灵活开放的面向大型组织专有独享的协同办公平台。支持专有云、混合云等多种方式灵活部署,以满足客户特定场景所需为目标,最大化以“平…

elk之简介

写在前面 本文看下es的简介。 1:简介 背后公司,elastic,08年纽交所上市,与腾讯,阿里等云厂商有合作,推出云产品,类似功能的产品由solr,splunk,但使用量es当前遥遥领先…

Vue-52、Vue技术插槽使用

1、默认插槽 App.vue <template><div class"container"><category title"美食"><img src"https://tse2-mm.cn.bing.net/th/id/OIP-C.F0xLT-eKcCSOI5tdjBv9FAHaE8?w303&h202&c7&r0&o5&pid1.7">&l…

网络安全之SSL证书加密

简介 SSL证书是一种数字证书&#xff0c;遵守SSL协议&#xff0c;由受信任的数字证书颁发机构&#xff08;CA&#xff09;验证服务器身份后颁发。它具有服务器身份验证和数据传输加密的功能&#xff0c;能够确保数据在传输过程中的安全性和完整性。 具体来说&#xff0c;SSL证…

番外篇 vue与django 交互流程

学习了一段时间的django和vue&#xff0c;对于前后端开发有了一个初步的了解&#xff0c;这里记录一下编写的流程和思路&#xff0c;主要是为了后面如果遗忘从哪里开始操作做一个起步引导作用 一、Django后端 参考下前面django的文档https://moziang.blog.csdn.net/article/det…

Web实战丨基于django+hitcount的网页计数器

文章目录 写在前面Django简介主要程序运行结果系列文章写在后面 写在前面 本期内容 基于djangohitcount的网页计数器 所需环境 pythonpycharm或vscodedjango 下载地址 https://download.csdn.net/download/m0_68111267/88795611 Django简介 Django 是一个开源的、基于 …

六、Nacos源码系列:Nacos健康检查

目录 一、简介 二、健康检查流程 2.1、健康检查 2.2、客户端释放连接事件 2.3、客户端断开连接事件 2.4、小结 2.5、总结图 三、服务剔除 一、简介 Nacos作为注册中心不止提供了服务注册和服务发现的功能&#xff0c;还提供了服务可用性检测的功能&#xff0c;在Nacos…

【QT+QGIS跨平台编译】之二十二:【FontConfig+Qt跨平台编译】(一套代码、一套框架,跨平台编译)

文章目录 一、FontConfig介绍二、文件下载三、文件分析四、pro文件五、编译实践 一、FontConfig介绍 FontConfig 是一个用于配置和定制字体的库&#xff0c;广泛应用于基于X Window系统的操作系统中&#xff0c;尤其是在Linux和Unix-like系统中。它为应用程序提供了一种统一的…

在VM虚拟机搭建NFS服务器

NFS共享要求如下&#xff1a; &#xff08;1&#xff09;共享“/mnt/自已姓名的完整汉语拼音”目录&#xff0c;允许XXX网段的计算机访问该共享目录&#xff0c;可进行读写操作。&#xff08;说明&#xff1a;XXX网段&#xff0c;请根据你的规划&#xff0c;再具体指定&#xf…

2024美赛数学建模B题思路分析 - 搜索潜水器

# 1 赛题 问题B&#xff1a;搜索潜水器 总部位于希腊的小型海上巡航潜艇&#xff08;MCMS&#xff09;公司&#xff0c;制造能够将人类运送到海洋最深处的潜水器。潜水器被移动到该位置&#xff0c;并不受主船的束缚。MCMS现在希望用他们的潜水器带游客在爱奥尼亚海底探险&…

STM32F1 - 存储器映射

Memory mapping 1> 外设内存地址映射2> GPIO寄存器映射3> 存储器访问 1> 外设内存地址映射 1> STM32F103ZET6的地址线位宽为32位&#xff0c;所以寻址空间为4GB &#xff08;2 ^ 32 4GB&#xff09;&#xff1b; 2> STM32将&#xff0c;Flash&#xff0c;SR…

数仓建模维度建模理论知识

0. 思维导图 第 1 章 数据仓库概述 1.1 数据仓库概述 数据仓库是一个为数据分析而设计的企业级数据管理系统。数据仓库可集中、整合多个信息源的大量数据&#xff0c;借助数据仓库的分析能力&#xff0c;企业可从数据中获得宝贵的信息进而改进决策。同时&#xff0c;随着时间的…

【经典项目】Java小游戏 —— 弹力球

一、功能需求 设计一个Java弹球小游戏的思路如下&#xff1a; 创建游戏窗口&#xff1a;使用Java图形库&#xff08;如Swing或JavaFX&#xff09;创建一个窗口&#xff0c;作为游戏的可视化界面。 绘制游戏界面&#xff1a;在游戏窗口中绘制游戏所需的各个元素&#xff0c;包括…

同城外卖跑腿app开发:重新定义城市生活

随着科技的发展和人们生活节奏的加快&#xff0c;同城外卖跑腿app应运而生&#xff0c;成为现代城市生活中的重要组成部分。本文将探讨同城外卖跑腿app开发的意义、市场需求、功能特点以及未来的发展趋势。 一、同城外卖跑腿app开发的意义 同城外卖跑腿app作为一种便捷的生活…

Visual Studio 2022 查看类关系图

这里写自定义目录标题 右键要查看的项目 -“查看”-“查看类图”效果展示&#xff1a; 原文地址 www.cnblogs.com 步骤1&#xff1a;勾选扩展开发 步骤2: 勾选类设计器 右键要查看的项目 -“查看”-“查看类图” 效果展示&#xff1a;

使用 WebSocket 发送二进制数据:最佳实践

WebSocket 技术提供了一种在客户端和服务器间建立持久连接的方法&#xff0c;使得双方可以在打开连接后随时发送数据&#xff0c;而不必担心建立复杂的持久连接机制。同时&#xff0c;使用二进制数据&#xff0c;如ArrayBuffer&#xff0c;可以更有效率地传送图像、声音等信息。…

MySQL知识点总结(四)——MVCC

MySQL知识点总结&#xff08;四&#xff09;——MVCC 三个隐式字段row_idtrx_idroll_pointer undo logread viewMVCC与隔离级别的关系快照读和当前读 MVCC全称是Multi Version Concurrency Control&#xff0c;也就是多版本并发控制。它的作用是提高事务的并发度&#xff0c;通…

西瓜书学习笔记——主成分分析(公式推导+举例应用)

文章目录 算法介绍实验分析 算法介绍 主成分分析&#xff08;Principal Component Analysis&#xff0c;PCA&#xff09;是一种常用的降维技术&#xff0c;用于在高维数据中发现最重要的特征或主成分。PCA的目标是通过线性变换将原始数据转换成一组新的特征&#xff0c;这些新…

web前端--------渐变和过渡

线性渐变&#xff0c;是指颜色沿一条直线进行渐变&#xff0c;例如从上到下、从左到右。 当然&#xff0c;CSS中也支持使用角度来设置渐变的方向&#xff0c;角度单位为deg。 0deg&#xff0c;为12点钟方向&#xff0c;表示从下到上渐变。 90deg&#xff0c;为3点钟方向&…

C++ | 部分和函数partial_sum的使用技巧

如果你需要处理一个数组的前缀和&#xff0c;或者数组中某一段元素的前缀和&#xff0c;你会怎么做呢&#xff1f; partial_sum函数是STL中的函数&#xff0c;用于计算范围的部分和&#xff0c;并从结果开始分配范围中的每个元素&#xff0c;range[first,last)中相应元素的部分…