Python 中的机器学习简介:多项式回归

news2024/12/23 8:18:27

一、说明

        多项式回归可以识别自变量和因变量之间的非线性关系。本文是关于回归、梯度下降和 MSE 系列文章的第三篇。前面的文章介绍了简单线性回归、回归的正态方程和多元线性回归。

二、多项式回归

        多项式回归用于最适合曲线拟合的复杂数据。它可以被视为多元线性回归的子集。

        请注意,X₀ 是偏差的一列;这允许在第一篇文章中讨论的广义公式。使用上述等式,每个“自变量”都可以被视为 X₁ 的指数版本。

        这允许从多元线性回归使用相同的模型,因为只需要识别每个变量的系数。可以创建一个简单的三阶多项式模型作为示例。其等式如下:

        模型、梯度下降和 MSE 的广义函数可用于前面的文章:

# line of best fit
def model(w, X):
  """
    Inputs:
      w: array of weights | (num features, 1)
      X: array of inputs  | (n samples, num features)

    Output:
      returns the output of X@w | (n samples, 1)
  """

  return torch.matmul(X, w)
# mean squared error (MSE)
def MSE(Yhat, Y):
  """
    Inputs:
      Yhat: array of predictions | (n samples, 1)
      Y: array of expected outputs | (n samples, 1)
    Output:
      returns the loss of the model, which is a scalar
  """
  return torch.mean((Yhat-Y)**2) # mean((error)^2)
# optimizer
def gradient_descent(w):
  """
    Inputs:
      w: array of weights | (num features, 1)

    Global Variables / Constants:
      X: array of inputs  | (n samples, num features)
      Y: array of expected outputs | (n samples, 1)
      lr: learning rate to scale the gradient

    Output:
      returns the updated weights
  """ 

  n = X.shape[0]

  return w - (lr * 2/n) * (torch.matmul(-Y.T, X) + torch.matmul(torch.matmul(w.T, X.T), X)).reshape(w.shape)

三、创建数据

        现在,所需要的只是一些用于训练模型的数据。可以使用“蓝图”功能,并且可以添加随机性。这遵循与前面文章相同的方法。蓝图如下所示:

        可以创建大小为 (800, 4) 的训练集和大小为 (200, 4) 的测试集。请注意,除偏差外,每个特征都是第一个特征的指数版本。

import torch

torch.manual_seed(5)
torch.set_printoptions(precision=2)

# features
X0 = torch.ones((1000,1))
X1 = (100*(torch.rand(1000) - 0.5)).reshape(-1,1) # generates 1000 random numbers from -50 to 50
X2, X3 = X1**2, X1**3
X = torch.hstack((X0,X1,X2,X3))

# normal distribution with a mean of 0 and std of 8
normal = torch.distributions.Normal(loc=0, scale=8)

# targets
Y = (3*X[:,3] + 2*X[:,2] + 1*X[:,1] + 5 + normal.sample(torch.ones(1000).shape)).reshape(-1,1)

# train, test
Xtrain, Xtest = X[:800], X[800:]
Ytrain, Ytest = Y[:800], Y[800:]

        定义初始权重后,可以使用最佳拟合线绘制数据。

torch.manual_seed(5)
w = torch.rand(size=(4, 1))
w
tensor([[0.83],
        [0.13],
        [0.91],
        [0.82]])
import matplotlib.pyplot as plt

def plot_lbf():
  """
    Output:
      prints the line of best fit in comparison to the train and test data
  """

  # plot the train and test sets
  plt.scatter(Xtrain[:,1],Ytrain,label="train")
  plt.scatter(Xtest[:,1],Ytest,label="test")

  # plot the line of best fit
  X1_plot = torch.arange(-50, 50.1,.1).reshape(-1,1) 
  X2_plot, X3_plot = X1_plot**2, X1_plot**3
  X0_plot = torch.ones(X1_plot.shape)
  X_plot = torch.hstack((X0_plot,X1_plot,X2_plot,X3_plot))

  plt.plot(X1_plot.flatten(), model(w, X_plot).flatten(), color="red", zorder=4)

  plt.xlim(-50, 50)
  plt.xlabel("$X$")
  plt.ylabel("$Y$")
  plt.legend()
  plt.show()

plot_lbf()
图片来源:作者

四、训练模型

        为了部分最小化成本函数,可以使用 5e-11 和 500,000 epoch 的学习率与梯度下降一起使用。

lr = 5e-11
epochs = 500000

# update the weights 1000 times
for i in range(0, epochs):
  # update the weights
  w = gradient_descent(w)

  # print the new values every 10 iterations
  if (i+1) % 100000 == 0:
    print("epoch:", i+1)
    print("weights:", w)
    print("Train MSE:", MSE(model(w,Xtrain), Ytrain))
    print("Test MSE:", MSE(model(w,Xtest), Ytest))
    print("="*10)

plot_lbf()
epoch: 100000
weights: tensor([[0.83],
        [0.13],
        [2.00],
        [3.00]])
Train MSE: tensor(163.87)
Test MSE: tensor(162.55)
==========
epoch: 200000
weights: tensor([[0.83],
        [0.13],
        [2.00],
        [3.00]])
Train MSE: tensor(163.52)
Test MSE: tensor(162.22)
==========
epoch: 300000
weights: tensor([[0.83],
        [0.13],
        [2.00],
        [3.00]])
Train MSE: tensor(163.19)
Test MSE: tensor(161.89)
==========
epoch: 400000
weights: tensor([[0.83],
        [0.13],
        [2.00],
        [3.00]])
Train MSE: tensor(162.85)
Test MSE: tensor(161.57)
==========
epoch: 500000
weights: tensor([[0.83],
        [0.13],
        [2.00],
        [3.00]])
Train MSE: tensor(162.51)
Test MSE: tensor(161.24)
==========
图片来源:作者

        即使有 500,000 个 epoch 和极小的学习率,该模型也无法识别前两个权重。虽然当前的解决方案非常准确,MSE为161.24,但可能需要数百万个epoch才能完全最小化它。这是多项式回归梯度下降的局限性之一。

五、正态方程

        作为替代方案,可以使用第二篇文章中的正态方程直接计算优化权重:

def NormalEquation(X, Y):
  """
    Inputs:
      X: array of input values | (n samples, num features)
      Y: array of expected outputs | (n samples, 1)
      
    Output:
      returns the optimized weights | (num features, 1)
  """
  
  return torch.inverse(X.T @ X) @ X.T @ Y

w = NormalEquation(Xtrain, Ytrain)
w
tensor([[4.57],
        [0.98],
        [2.00],
        [3.00]])

        正态方程能够立即识别每个权重的正确值,并且每组的MSE比梯度下降时低约100点:

MSE(model(w,Xtrain), Ytrain), MSE(model(w,Xtest), Ytest)
(tensor(60.64), tensor(63.84))

六、结论

        通过实现简单线性、多重线性和多项式回归,接下来的两篇文章将介绍套索和岭回归。这些类型的回归在机器学习中引入了两个重要概念:过拟合和正则化。

 参考文章:

亨特·菲利普斯

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/846167.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

BenchmarkSQL 支持 TiDB 驱动以及 tidb-loadbalance

作者: GangShen 原文来源: https://tidb.net/blog/3c274180 使用 BenchmarkSQL 对 TiDB 进行 TPC-C 测试 众所周知 TiDB 是一个兼容 MySQL 协议的分布式关系型数据库,用户可以使用 MySQL 的驱动以及连接方式连接 TiDB 进行使用&#xff0…

Butterfly安装文档(三)主题配置-1

语言 修改站点配置文件 _config.yml 默认语言是 en 主题支持三种语言 default(en)zh-CN (简体中文)zh-TW (繁体中文) 网站资料 修改网站各种资料,例如标题、副标题和邮箱等个人资料,请修改博客根目录的_config.yml 导航栏设置 (Navigation bar set…

Data analysis|Tableau基本介绍及可实现功能

一、基础知识介绍 (一)什么是tableau tableau 成立于 2003 年,是斯坦福大学一个计算机科学项目的成果,该项目旨在改善分析流程并让人们能够通过可视化更轻松地使用数据。Tableau可以帮助用户更好地理解和发现数据中的价值&#x…

工具推荐之不出网环境下上线CS

前言 在实战攻防演练中,我们经常会遇到目标不出网的情况,即便获取了目标权限也不方便在目标网络进行下一步横向移动。本期我们将会推荐两个常用的代理工具,使我们能在不出网的环境下让目标上线到CS,方便后渗透的工作。 工具1&…

vue如何对node_modules源码进行修改,对第三方依赖包源码修改

方法 用patch-package给node_module中的包打补丁,解决修改源码的问题 使用 1、下载 patch-package 包:npm install patch-package -D 2、package.json文件中增加命令:"postinstall": "patch-package" {"scripts&quo…

【go-zero】docker镜像直接部署go-zero的API与RPC服务 如何实现注册发现?docker network 实现 go-zero 注册发现

一、场景&问题 使用docker直接部署go-zero微服务会发现API无法找到RPC服务 1、API无法发现RPC服务 用docker直接部署 我们会发现API无法注册发现RPC服务 原因是我们缺少了docker的network网桥 2、系统内查看 RPC服务运行正常API服务启动,通过docker logs 查看日志还是未…

MyCat垂直分库案例以及全局表概念讲解

这里的分片指的就是分库分表 1.垂直拆分 1.1场景介绍 1.2 数据库准备 1.3MyCat配置 schema.xml: <schema name"shopping" checkSQLschema"true" sqlMaxLimit"100"><table name"tb_goods_base" dataNode"dn1" pr…

⛳ Java注解

目录 ⛳ Java注解&#x1f3ed; 一&#xff0c;常见的注解&#x1f3a8; 二&#xff0c;JDK元注解&#x1f69c; 三&#xff0c;通过反射获取注解&#x1f43e; 3.1、JDK常用注解&#x1f463; 3.2、简单注解&#x1f4e2; 3.3、复杂注解 ⛳ Java注解 从 JDK 5.0 开始, Java 增…

python --windows获取启动文件夹路径/获取当前用户名/添加自启动文件

如何使用Python获取计算机用户名 一、Python自带的getpass模块可以用于获取用户输入的密码&#xff0c;但是它同样可以用来获取计算机用户名。 import getpassuser getpass.getuser() print("计算机用户名为&#xff1a;", user)二、使用os模块获取用户名 Python的…

深度学习部署:FastDeploy部署教程(CSharp版本)

FastDeploy部署教程(CSharp版本) 1. FastDeploy介绍 FastDeploy是一款全场景、易用灵活、极致高效的AI推理部署工具&#xff0c; 支持云边端部署。提供超过 &#x1f525;160 Text&#xff0c;Vision&#xff0c; Speech和跨模态模型&#x1f4e6;开箱即用的部署体验&#xf…

[机器学习]线性回归模型

线性回归 线性回归&#xff1a;根据数据&#xff0c;确定两种或两种以上变量间相互依赖的定量关系 函数表达式&#xff1a; y f ( x 1 , x 2 . . . x n ) y f(x_1,x_2...x_n) yf(x1​,x2​...xn​) ​ 回归根据变量数分为一元回归[ y f ( x ) yf(x) yf(x)]和多元回归[ y …

CSS 属性计算过程

CSS 属性计算过程 你是否了解 CSS 的属性计算过程呢&#xff1f; 有的同学可能会讲&#xff0c;CSS属性我倒是知道&#xff0c;例如&#xff1a; p{color : red; }上面的 CSS 代码中&#xff0c;p 是元素选择器&#xff0c;color 就是其中的一个 CSS 属性。 但是要说 CSS 属…

国内大模型在局部能力上已超ChatGPT

中文大模型正在后来居上&#xff0c;也必须后来居上。 数科星球原创 作者丨苑晶 编辑丨大兔 从GPT3.5彻底出圈后&#xff0c;大模型的影响力开始蜚声国际。一段时间内&#xff0c;国内科技公司可谓被ChatGPT按在地上打&#xff0c;毫无还手之力。 彼时&#xff0c;很多企业…

echarts实现中国地图下钻进入下一级行政区(地图钻取)

获取geo数据&#xff1a; 可以使用node爬虫获取数据 最好多爬几遍&#xff0c;因为有时候会获取错误 echarts实现 html <div ref"echarts-dom" class"echarts-content"></div>js: export default {data() {return {mapChart: null,addressC…

太心累!企业IT维修呼唤更专业的维修平台

大数据产业创新服务媒体 ——聚焦数据 改变商业 设想这样一个场景&#xff1a;在繁忙的工作日早晨&#xff0c;企业的运营部门突然发现一批重要的办公设备&#xff0c;台式电脑、笔记本电脑和打印机&#xff0c;出现了各种技术问题。无法连接网络、电脑启动异常、软件冲突等问…

【SQL应知应会】索引(一)• MySQL版

欢迎来到爱书不爱输的程序猿的博客, 本博客致力于知识分享&#xff0c;与更多的人进行学习交流 本文收录于SQL应知应会专栏,本专栏主要用于记录对于数据库的一些学习&#xff0c;有基础也有进阶&#xff0c;有MySQL也有Oracle 索引 • MySQL版 前言一、索引1.简介1.1 索引的优点…

业界首个云管理产品与服务图谱发布,九州未来入选!

近日&#xff0c;由中国信息通信研究院和中国通信标准化协会联合主办的第十届可信云大会在北京成功召开&#xff0c;会上发布业界首个云管理全景图《云管理产品与服务图谱&#xff08;2023&#xff09;》。 九州未来凭借在云管理领域的多年深耕&#xff0c;成功入选“智慧应用…

【数学建模学习(9):模拟退火算法】

模拟退火算法(Simulated Annealing, SA)的思想借 鉴于固体的退火原理&#xff0c;当固体的温度很高的时候&#xff0c;内能比 较大&#xff0c;固体的内部粒子处于快速无序运动&#xff0c;当温度慢慢降 低的过程中&#xff0c;固体的内能减小&#xff0c;粒子的慢慢趋于有序&a…

无涯教程-Perl - defined函数

描述 如果 EXPR 的值不是undef值,则此函数返回true&#xff1b;如果未指定 EXPR ,则检查$_的值。它可以与许多功能一起使用以检测操作失败,因为如果出现问题,它们将返回undef。简单的布尔测试不会区分false,零,空字符串或字符串.0。 如果 EXPR 是函数或函数引用,则在定义函数…

ORCA优化器浅析——CFunctionProp function properties

CFunctionProp CFunctionProp代表了function properties函数属性&#xff0c;主要由function stability函数易变性&#xff08; enum EFuncStbl { EfsImmutable, /* never changes for given input */ EfsStable, /* does not change within a scan */ EfsVolatile, /* can ch…