Pytorch线性回归教程

news2024/11/26 15:25:57
import torch
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt

生成测试数据

# 长期趋势
def trend(time, slope=0):
    return slope * time

# 季节趋势
def seasonal_pattern(season_time):
    return np.where(season_time < 0.4,
                    np.cos(season_time * 2 * np.pi),
                    1 / np.exp(3 * season_time))
def seasonality(time, period, amplitude=1, phase=0):
    season_time = ((time + phase) % period) / period
    return amplitude * seasonal_pattern(season_time)

# 噪声
def noise(time, noise_level=1):
    return np.random.randn(len(time)) * noise_level
X = torch.arange(1, 1001)
# Y = 0.7 * X + 100 + torch.randn(X.size())
Y = trend(X, 0.3) + seasonality(X, period=365, amplitude=30) + noise(X, 15) + 200
X.shape, Y.shape
(torch.Size([1000]), torch.Size([1000]))
plt.plot(X.numpy(), Y.numpy());

对测试数据进行处理

# 模型的数据的类型需要是32位浮点型
X = X.type(torch.float32)
Y = Y.type(torch.float32)
X.dtype, Y.dtype
(torch.float32, torch.float32)
# 模型的数据需要进行归一化或者标准化,下面是归一化
X = (X - X.min()) / (X.max() - X.min())
Y = (Y - Y.min()) / (Y.max() - Y.min())
plt.plot(X.numpy(), Y.numpy());

定义模型和模型参数

# 线性模型只有两个参数斜率k,和偏置b
# 线性模型的方程为y = k * x + b
k = nn.Parameter(torch.rand(1, dtype=torch.float32))
b = nn.Parameter(torch.rand(1, dtype=torch.float32))
# 下面输出中的requires_grad=True 表示该参数需要计算梯度
# 梯度用于在反向传播中对参数进行优化,优化方法即梯度下降
k, b 
(Parameter containing:
 tensor([0.6231], requires_grad=True),
 Parameter containing:
 tensor([0.0044], requires_grad=True))
def linear_model(x):
    return k * x + b

梯度下降优化参数

# 可以通过改变学习率lr和epoch_num学习各自的用途
# 定义优化器,用于更新模型的参数,即传入的k和b
optimizer = torch.optim.SGD([k, b], lr=0.01)
# 损失函数,模型优化的目的
loss_func = nn.MSELoss()

# 每个epoch表示把全部的数据过一遍
epoch_num = 2000
for epoch in range(epoch_num):
    # 获取模型预测结果
    y_pred = linear_model(X)
    # 计算损失值
    loss = loss_func(y_pred, Y)
    # 将梯度设为0
    optimizer.zero_grad()
    # 反向传播,计算梯度
    loss.backward()
    # 执行梯度下降,优化参数
    optimizer.step()
k, b
(Parameter containing:
 tensor([0.8825], requires_grad=True),
 Parameter containing:
 tensor([0.0419], requires_grad=True))
# detach()函数用于将参数设置为不需要梯度
k2 = k.detach().numpy()[0]
b2 = b.detach().numpy()[0]

plt.plot(X, Y);
plt.plot(X, k2 * X + b2);

优化模型

class LinearModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.k = nn.Parameter(torch.rand(1, dtype=torch.float32))
        self.b = nn.Parameter(torch.rand(1, dtype=torch.float32))

    def forward(self, x):
        return self.k * x + self.b
model = LinearModel()
# 定义优化器,用于更新模型的参数,即传入的k和b
optimizer = torch.optim.SGD([k, b], lr=0.01)
# 损失函数,模型优化的目的
loss_func = nn.MSELoss()

epoch_num = 2000
for epoch in range(epoch_num):
    y_pred = model(X)
    loss = loss_func(y_pred, Y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
k, b
(Parameter containing:
 tensor([0.8825], requires_grad=True),
 Parameter containing:
 tensor([0.0419], requires_grad=True))
k2 = k.detach().numpy()[0]
b2 = b.detach().numpy()[0]

plt.plot(X, Y);
plt.plot(X, k2 * X + b2);

随机梯度下降

# 前面执行梯度下降时,我们是一次将全部的数据都传入模型
# 但在实际应用中,可能会由于数据太大,没法全部传入模型
# 因此,可以一次传入一部分数据,这便是随机梯度下降
# 随机梯度下降的核心是,梯度是期望。期望可使用小规模的样本近似估计。
model = LinearModel()
# 定义优化器,用于更新模型的参数,即传入的k和b
optimizer = torch.optim.SGD([k, b], lr=0.01)
# 损失函数,模型优化的目的
loss_func = nn.MSELoss()

# 每个epoch表示把全部的数据过一遍
epoch_num = 2000 
# iter_step表示在一个epoch内抽取几个小规模样本
iter_step = 10
# batch_size表示小规模样本的大小
batch_size = 100
for epoch in range(epoch_num):
    for i in range(iter_step):
        random_samples = torch.randint(X.size()[0], (batch_size, ))
        X_i, Y_i = X[random_samples], Y[random_samples]
        y_pred = model(X_i)
        loss = loss_func(y_pred, Y_i)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
k, b
(Parameter containing:
 tensor([0.8825], requires_grad=True),
 Parameter containing:
 tensor([0.0419], requires_grad=True))
k2 = k.detach().numpy()[0]
b2 = b.detach().numpy()[0]

plt.plot(X, Y);
plt.plot(X, k2 * X + b2);

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1291945.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

el-tooltip (element-plus)修改长度

初始状态&#xff1a; 修改后&#xff1a; 就是添加 :teleported"false"&#xff0c;问题解决&#xff01;&#xff01;&#xff01; <el-tooltipeffect"dark"content"要求密码长度为9-30位&#xff0c;需包含大小写字母、数字两种或以上与特殊字…

你的轻量化设计能有效提高模型的推理速度吗?

写在前面&#xff1a;本博客仅作记录学习之用&#xff0c;部分图片来自网络&#xff0c;如需引用请注明出处&#xff0c;同时如有侵犯您的权益&#xff0c;请联系删除&#xff01; 文章目录 前言预备知识模型指标MACs计算卷积MACs全连接MACs激活函数MACsBN MACs 存储访问存储构…

EasyExcel如何读取全部Sheet页数据方法

一、需求描述 Excel表格里面大约有20个sheet页&#xff0c;每个sheet页65535条数据&#xff0c;需要读取全部数据&#xff0c;并导入至数据库。 找了好多种方式&#xff0c;EasyExcel比较符合&#xff0c;下面看代码。 二、实现方式 采用EasyExcel框架的doReadAll()方法 1、…

千梦网创:为什么要做一个专业的倒爷?

培训是什么&#xff1f;教书育人&#xff0c;传递价值。 割韭菜是什么&#xff1f;贩卖焦虑&#xff0c;制造需求。 做培训难还是割韭菜难&#xff1f;不言而喻。 有人边割边培&#xff0c;有人边培边割&#xff0c;有人只割不培&#xff0c;只培不割的有&#xff0c;但我接…

【移动端vant 地址选择滑动不了】

分析&#xff1a; H5页面直接在浏览器打开是没有任何问题的&#xff0c;但是内嵌到小程序中就会出现&#xff0c;目前已出现在抖音&#xff0c;快手&#xff0c;小程序中&#xff0c;其他的没有试 大致看了一下&#xff0c;滑动不了的原因&#xff0c;可能是页面禁止滑动或滚动…

Linux下查看端口占用

第一种&#xff1a;通过命令查看 1.netstat -ntulp&#xff1a;查看所有的被占用的端口 在列表中最后一列就列出了&#xff0c;某个端口被占用的进程 其中&#xff1a; -t : 指明显示TCP端口 -u : 指明显示UDP端口 -l : 仅显示监听套接字(所谓套接字就是使应用程序能够读写与收…

AI算力研究报告:智算供给格局分化国产化进程有望加速

今天分享的AI系列深度研究报告&#xff1a;《AI算力研究报告&#xff1a;智算供给格局分化国产化进程有望加速》。 &#xff08;报告出品方&#xff1a;华龙证券&#xff09; 报告共计&#xff1a;24页 1 大模型浪潮推动作用下,其力需求缺口将持续扩大 1.1 大模型发展对算力…

docker安装配置prometheus+node_export+grafana

简介 Prometheus是一套开源的监控预警时间序列数据库的组合&#xff0c;Prometheus本身不具备收集监控数据功能&#xff0c;通过获取不同的export收集的数据&#xff0c;存储到时序数据库中。Grafana是一个跨平台的开源的分析和可视化工具&#xff0c;将采集过来的数据实现可视…

【EI会议征稿】第二届材料科学与智能制造国际学术会议(MSIM 2024)

第二届材料科学与智能制造国际学术会议&#xff08;MSIM 2024&#xff09; 2024 2nd International Conference on Materials Science and Intelligent Manufacturing 2024年第二届材料科学与智能制造国际学术会议 &#xff08;MSIM2024&#xff09;将于2024年1月19日至21日在…

第60天:django学习(十)

choices参数的使用 choices参数应用场景&#xff1a; 学历&#xff1a;小学 初中 高中 本科 硕士 博士 客户来源:微信渠道 广告 介绍 QQ 性别&#xff1a;男 女 未知 对于以上可能被我们列举完的字段我们一般都是选择使用choices参来做 建表 class UserInfo(models.Model):us…

DBeaver 如何在没有外网的情况下连接数据库(下载驱动)

1.选择自己要连接的数据库 2.编辑驱动 3.选择你自己通过maven或者别的渠道下载的对应数据库的jar

深入探索Python delattr函数的威力与灵活性

引言&#xff1a; 在Python中&#xff0c;delattr函数是一个非常强大且灵活的工具&#xff0c;它允许我们删除对象的属性。使用delattr函数&#xff0c;我们可以动态地删除对象的属性&#xff0c;从而在编程中实现更灵活的操作。本文将详细介绍delattr函数的用法&#xff0c;帮…

数据结构与算法(五)回溯算法(Java)

目录 一、简介1.1 定义1.2 特性1.3 结点知识补充1.4 剪枝函数1.5 使用场景1.6 解空间1.7 实现模板 二、经典示例2.1 0-1 背包问题2.2 N皇后问题 一、简介 1.1 定义 回溯法&#xff08;back tracking&#xff09;是一种选优搜索法&#xff0c;又称为试探法&#xff0c;按选优条…

NVRAM相关

1. Modem NVRAM四个分区 nvdata&#xff1a;手机运行过程中&#xff0c;使用(读写)的NVRAM(除了存在protect_f和protect_s中的NVRAM)都是该分区的nvram文件。存储着普通NVRAM数据、 IMEI、barcode、Calibration数据等。对应的modem path是Z:\NVRAM。NVRAM目录下有CALIBRAT、NVD…

用序列化思想为自动化测试「提供动力」

Python 对象序列化技术 对象序列化是指将对象从内存转换为字节流的过程&#xff0c;以实现对象的持久化存储和网络传输。它在许多场景中都非常重要&#xff0c;比如远程调用、长期数据存储等。 在Python中&#xff0c;我们主要使用pickle和marshal这两个模块来实现对象的序列…

【银行测试】第三方支付平台业务流,功能/性能/安全测试方法...

目录&#xff1a;导读 前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结&#xff08;尾部小惊喜&#xff09; 前言 1、第三方支付平台…

vue+echarts实现桑吉图的效果

前言&#xff1a; 在我们项目使用图形的情况下&#xff0c;桑吉图算是冷门的图形了&#xff0c;但是它可以实现我们对多级数据之间数据流向更好的展示的需求&#xff0c;比如&#xff0c;我们实际数据流向中&#xff0c;具有1对多&#xff0c;多对多的情况下&#xff0c;如果用…

2023年淘宝天猫年终惊喜红包玩法

2023年淘宝天猫年终惊喜红包玩法&#xff0c;2023年淘宝年终好价节红包活动 随着2023年的尾声渐近&#xff0c;淘宝再次为广大用户带来了年终的惊喜——一场特别的红包活动。从12月8日零时开始&#xff0c;直至12月12日的午夜&#xff0c;淘宝app将开启一个为期五天的年终好价节…

Web前端工程的装机必备软件

前言 最近作者的电脑 C 盘变红了&#xff0c;这让我很难受(有点小强迫症)&#xff0c;所以准备重新安装下系统&#xff0c;顺便把 C 盘扩大点。 注意&#xff1a; 操作系统是 windows 11 23H2。 所有的命令行都是使用 Windows Terminal 中进行的。 安装 Windows Terminal 由于…

L1-026:I Love GPLT

题目描述 这道超级简单的题目没有任何输入。 你只需要把这句很重要的话 —— “I Love GPLT”——竖着输出就可以了。 所谓“竖着输出”&#xff0c;是指每个字符占一行&#xff08;包括空格&#xff09;&#xff0c;即每行只能有1个字符和回车。 输入样例&#xff1a; 无输出样…