使用线性回归模型逼近目标模型 | PyTorch 深度学习实战

news2025/2/5 16:01:31

前一篇文章,计算图 Compute Graph 和自动求导 Autograd | PyTorch 深度学习实战

本系列文章 GitHub Repo: https://github.com/hailiang-wang/pytorch-get-started

使用线性回归模型逼近目标模型

  • 什么是回归
  • 什么是线性回归
  • 使用 PyTorch 实现线性回归模型
    • 代码
    • 执行结果

什么是回归

在统计学中,回归分析(regression analysis)指的是确定两种或两种以上变量间相互依赖的定量关系的一种统计分析方法。

简单说,就是使用统计学手段,分析变量之间的规律。发现规律后,可以根据给定的数据猜测特征空间的因变量1的数据。

在这里插入图片描述
参考文章:https://zhuanlan.zhihu.com/p/669597409

什么是线性回归

用一条直线去逼近数据的分布,参考定义:

A linear regression is a straight line that describes how the values of a response variable y y y change as the predictor variable x x x changes.

线性回归在实际中,可以包含多元的情况,比如:

z = w 1 x + w 2 y z = w_1 x + w_2 y z=w1x+w2y

更多线性回归介绍,参考文章。

使用 PyTorch 实现线性回归模型

实现量化投资:现在假如我们观测到了某支股票的数据 v v v,并且这支股票和石油的价格 x x x、黄金的价格 y y y和原煤的价格 z z z 有关联。因此,我们取得了不同时刻的 x x x y y y z z z 和对应的股票价格 v v v,现在,依据这些数据,建立一个方程式:

v = a x + b y + c z + d v = ax + by + cz + d v=ax+by+cz+d

此时,依赖历史采集的数据,我们来求 a,b,c,d 的值。
使用 PyTorch,这个程序实现如下。

代码

import torch
import matplotlib.pyplot as plt
import numpy as np

# X and Y data,观测数据,包含多条
# 每条包含 3 个数据,分别代表石油、黄金、原煤的价格
x_data = [[65., 80., 75.],
          [89., 88., 93.],
          [80., 91., 90.],
          [30., 98., 100.],
          [50., 66., 70.]]
# 对应的这支股票的价格
y_data = [[152.],
          [185.],
          [189.],
          [196.],
          [142.]]

# 定义输入 tensor 和输出 tensor 的变量
x=torch.autograd.Variable(torch.Tensor(x_data)) 
y=torch.autograd.Variable(torch.Tensor(y_data))

# Our hypothesis XW+b,定义模型及参数
model=torch.nn.Linear(3,1,bias=True)

# cost criterion,定义损失函数
criterion=torch.nn.MSELoss()

# Minimize,优化器
optimizer=torch.optim.SGD(model.parameters(),lr=1e-7)

# 训练轮数
epochs=200
cost_h=np.zeros(epochs)

# Train the model,对于这个简单的问题,没有使用 SGD,每次都是将数据录入
for step in range(epochs):
    optimizer.zero_grad()
    hypothesis=model(x) # Our hypothesis
    cost=criterion(hypothesis,y)
    cost.backward()
    optimizer.step()
    cost_h[step]=cost.data.numpy()
    print(step,'Loss:',cost.data.numpy(),'\nPredict:\n',hypothesis.data.numpy())

for name, param in model.named_parameters():
    if param.requires_grad:
        print(name, param.data)

plt.plot(cost_h)
plt.show()

执行结果

使用 Python 运行上述程序,结果如下:

weight tensor([[-0.0980,  0.5064,  0.4115]])
bias tensor([-0.1257])

在这里插入图片描述

因为模型的定义是:

model=torch.nn.Linear(3,1,bias=True)

也就是包含了三个参数和一个偏置,最终机器学习得到的公式就是:
v = − 0.0980 x + 0.5064 y + 0.4115 z − 0.1257 v = -0.0980x + 0.5064y + 0.4115z -0.1257 v=0.0980x+0.5064y+0.4115z0.1257

我们就可以由某一天的黄金、石油、原煤的价格,来预测这支股票的价格。


  1. 因变量(dependent variable)函数中的专业名词,也叫函数值。函数关系式中,某些特定的数会随另一个(或另几个)会变动的数的变动而变动,就称为因变量。如:Y=f(X)。此式表示为:Y随X的变化而变化。Y是因变量,X是自变量。 ↩︎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2293381.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

深入浅出:频谱掩码 Spectral Masking —— 噪音消除利器

在语音处理领域,噪声是一个常见的敌人。无论是语音通话、语音识别,还是语音合成,噪声都会大大降低语音的质量和可理解性。为了解决这个问题,Spectral Masking(频谱掩码) 模型应运而生。它通过从带噪信号的频…

C++ Primer 多维数组

欢迎阅读我的 【CPrimer】专栏 专栏简介:本专栏主要面向C初学者,解释C的一些基本概念和基础语言特性,涉及C标准库的用法,面向对象特性,泛型特性高级用法。通过使用标准库中定义的抽象设施,使你更加适应高级…

Mac M1 ComfyUI 中 AnyText插件安装问题汇总?

Q1:NameError: name ‘PreTrainedTokenizer’ is not defined ? 该项目最近更新日期为2024年12月,该时间段的transformers 版本由PyPI 上的 transformers 页面 可知为4.47.1. A1: transformers 版本不满足要求,必须降级transformors &#…

C++基础(2)

目录 1. 引用 1.1 引用的概念和定义 1.2 引用的特性 1.3 引用的使用 2. 常引用 3. 指针和引用的关系 4. 内联函数inline 5. nullptr 1. 引用 1.1 引用的概念和定义 引用不是新定义一个变量,而是给已存在变量取了一个别名,编译器不会为引用变量开…

electron typescript运行并设置eslint检测

目录 一、初始化package.json 二、安装依赖 三、项目结构 四、配置启动项 五、补充:ts转js别名问题 已整理好的开源代码:Type-Electron: 用typescript开发的electron项目脚手架,轻量级、支持一键配置网页转PC - Gitee.com 一、初始化pac…

modbus协议处理

//------------------------0x01-------------------------------- //MDA_usart_send: aa 55 01 00 06 00 02 00 05 //转modbusTCP——Master——send:地址00002,寄存器数量:00005 00 00 00 00 00 06 01 01 00 02 00 05 //ModbusTCP——Slave…

java-(Oracle)-Oracle,plsqldev,Sql语法,Oracle函数

卸载好注册表,然后安装11g 每次在执行orderby的时候相当于是做了全排序,思考全排序的效率 会比较耗费系统的资源,因此选择在业务不太繁忙的时候进行 --给表添加注释 comment on table emp is 雇员表 --给列添加注释; comment on column emp.empno is 雇员工号;select empno,en…

c++可变参数详解

目录 引言 库的基本功能 va_start 宏: va_arg 宏 va_end 宏 va_copy 宏 使用 处理可变参数代码 C11可变参数模板 基本概念 sizeof... 运算符 包扩展 引言 在C编程中,处理不确定数量的参数是一个常见的需求。为了支持这种需求,C标准库提供了 &…

linux 函数 sem_init () 信号量、sem_destroy()

&#xff08;1&#xff09; &#xff08;2&#xff09; 代码举例&#xff1a; #include <stdio.h> #include <stdlib.h> #include <pthread.h> #include <semaphore.h> #include <unistd.h>sem_t semaphore;void* thread_function(void* arg) …

基于python的体育新闻数据可视化及分析

项目 &#xff1a;北京冬奥会体育新闻数据可视化及分析 摘 要 随着社会的不断进步与发展&#xff0c;新时代下的网络媒体获取的信息也更加庞大和繁杂&#xff0c;相比于传统信息来源更加难以分析和辨别&#xff0c;造成了新时代媒体从业者撰写新闻的难度。在此背景下&#xff…

代码随想录算法【Day36】

Day36 1049. 最后一块石头的重量 II 思路 把石头尽可能分成两堆&#xff0c;这两堆重量如果相似&#xff0c;相撞后所剩的值就是最小值 若石头的总质量为sum&#xff0c;可以将问题转化为0-1背包问题&#xff0c;即给一个容量为sum/2的容器&#xff0c;如何尽量去凑满这个容…

如可安装部署haproxy+keeyalived高可用集群

第一步&#xff0c;环境准备 服务 IP 描述 Keepalived vip Haproxy 负载均衡 主服务器 Rip&#xff1a;192..168.244.101 Vip&#xff1a;192.168.244.100 Keepalive主节点 Keepalive作为高可用 Haproxy作为4 或7层负载均衡 Keepalived vip Haproxy 负载均衡 备用服务…

如何运行Composer安装PHP包 安装JWT库

1. 使用Composer Composer是PHP的依赖管理工具&#xff0c;它允许你轻松地安装和管理PHP包。对于JWT&#xff0c;你可以使用firebase/php-jwt这个库&#xff0c;这是由Firebase提供的官方库。 安装Composer&#xff08;如果你还没有安装的话&#xff09;&#xff1a; 访问Co…

安全策略配置

1.拓扑信息 2. 实验需求 3.需求分析 1.需要在交换机LSW1配置分配vlan并且为配置通道 2/3/4/5 在web界面或者命令行制定相应的安全策略 由于存在默认的拒绝需求4中生产区在任何时刻访问不了web不允许单独配置&#xff0c;只配置动作为运行的策略 4.配置信息 先配置服务器 …

使用Chainlit快速构建一个对话式人工智能应用体验DeepSeek-R1

Chainlit是一个开源的 Python 包&#xff0c;用于构建可用于生产的对话式人工智能。 DeepSeek-R1 是一款强化学习&#xff08;RL&#xff09;驱动的推理模型&#xff0c;解决了模型中的重复性和可读性问题。在 RL 之前&#xff0c;DeepSeek-R1 引入了冷启动数据&#xff0c;进…

生成式AI安全最佳实践 - 抵御OWASP Top 10攻击 (下)

今天小李哥将开启全新的技术分享系列&#xff0c;为大家介绍生成式AI的安全解决方案设计方法和最佳实践。近年来生成式 AI 安全市场正迅速发展。据IDC预测&#xff0c;到2025年全球 AI 安全解决方案市场规模将突破200亿美元&#xff0c;年复合增长率超过30%&#xff0c;而Gartn…

家政预约小程序12服务详情

目录 1 修改数据源2 创建页面3 搭建轮播图4 搭建基本信息5 显示服务规格6 搭建服务描述7 设置过滤条件总结 我们已经在首页、分类页面显示了服务的列表信息&#xff0c;当点击服务的内容时候需要显示服务的详情信息&#xff0c;本篇介绍一下详情页功能的搭建。 1 修改数据源 在…

知识蒸馏教程 Knowledge Distillation Tutorial

来自于&#xff1a;Knowledge Distillation Tutorial 将大模型蒸馏为小模型&#xff0c;可以节省计算资源&#xff0c;加快推理过程&#xff0c;更高效的运行。 使用CIFAR-10数据集 import torch import torch.nn as nn import torch.optim as optim import torchvision.tran…

【Numpy核心编程攻略:Python数据处理、分析详解与科学计算】2.29 NumPy+Scikit-learn(sklearn):机器学习基石揭秘

2.29 NumPyScikit-learn&#xff1a;机器学习基石揭秘 目录 #mermaid-svg-46l4lBcsNWrqVkRd {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-46l4lBcsNWrqVkRd .error-icon{fill:#552222;}#mermaid-svg-46l4lBcsNWr…

【C语言】指针详解:概念、类型与解引用

博客主页&#xff1a; [小ᶻ☡꙳ᵃⁱᵍᶜ꙳] 本文专栏: C语言 文章目录 &#x1f4af;前言&#x1f4af;指针的基本概念1. 什么是指针2. 指针的基本操作 &#x1f4af;指针的类型1. 指针的大小2. 指针类型与所指向的数据类型3. 指针类型与数据访问的关系4. 指针类型的实际意…