【深度学习】pytorch——线性回归

news2025/1/6 19:12:22

笔记为自我总结整理的学习笔记,若有错误欢迎指出哟~

深度学习专栏链接:
http://t.csdnimg.cn/dscW7

pytorch——线性回归

  • 线性回归简介
  • 公式说明
  • 完整代码
  • 代码解释

线性回归简介

线性回归是一种用于建立特征和目标变量之间线性关系的统计学习方法。它假设特征和目标变量之间存在一个线性的关系,并试图通过拟合最佳的线性函数来预测目标变量。

线性回归模型的一般形式可以表示为:

y = w 0 + w 1 x 1 + w 2 x 2 + … + w n x n y = w_0 + w_1x_1 + w_2x_2 + \ldots + w_nx_n y=w0+w1x1+w2x2++wnxn

其中, y y y 是目标变量(或因变量), x 1 , x 2 , … , x n x_1, x_2, \ldots, x_n x1,x2,,xn 是特征变量(或自变量), w 0 , w 1 , w 2 , … , w n w_0, w_1, w_2, \ldots, w_n w0,w1,w2,,wn 是模型的参数,分别对应截距和各个特征的权重。

线性回归模型的训练过程就是寻找最优的参数 w 0 , w 1 , w 2 , … , w n w_0, w_1, w_2, \ldots, w_n w0,w1,w2,,wn 来使得模型的预测值与实际值之间的差异最小化。

公式说明

以下是代码涉及到的数学公式

  1. 线性回归模型

线性回归模型用于建立特征 x x x 和目标变量 y y y 之间的线性关系。在本代码中,线性回归模型被表示为:

y = w x + b y = wx + b y=wx+b

其中, w w w 是权重(即斜率), b b b 是偏置(即截距), x x x 是输入特征, y y y 是预测值。

  1. 损失函数

损失函数用于衡量模型预测值与实际标签之间的差异。在本代码中,使用的损失函数是均方误差(Mean Squared Error,MSE):

l o s s = 1 2 n ∑ i = 1 n ( y p r e d ( i ) − y ( i ) ) 2 loss = \frac{1}{2n} \sum_{i=1}^{n} (y_{pred}^{(i)} - y^{(i)})^2 loss=2n1i=1n(ypred(i)y(i))2

其中, y p r e d ( i ) y_{pred}^{(i)} ypred(i) 是模型的第 i i i 个样本的预测值, y ( i ) y^{(i)} y(i) 是实际标签, n n n 是样本数量。

  1. 其他运算

代码中还涉及到了矩阵乘法、矩阵转置、元素级别的操作等。例如, x . m m ( w ) x.mm(w) x.mm(w) 表示将输入特征 x x x 与权重 w w w 进行矩阵乘法; x T . m m ( d y _ p r e d ) x^T.mm(dy\_pred) xT.mm(dy_pred) 表示将输入特征 x x x 的转置与梯度 d y _ p r e d dy\_pred dy_pred 进行矩阵乘法。

完整代码

import torch as t
%matplotlib inline
from matplotlib import pyplot as plt
from IPython import display

device = t.device('cpu') #如果你想用gpu,改成t.device('cuda:0')

# 设置随机数种子,保证在不同电脑上运行时下面的输出一致
t.manual_seed(1000) 

def get_fake_data(batch_size=8):
    ''' 产生随机数据:y=x*2+3,加上了一些噪声'''
    x = t.rand(batch_size, 1, device=device) * 5
    y = x * 2 + 3 +  t.randn(batch_size, 1, device=device)
    return x, y

'''
# 产生的x-y分布
x, y = get_fake_data(batch_size=100)
plt.scatter(x.squeeze().cpu().numpy(), y.squeeze().cpu().numpy())
'''


# 随机初始化参数
w = t.rand(1, 1).to(device)
b = t.zeros(1, 1).to(device)

lr =0.02 # 学习率

for ii in range(500):
    x, y = get_fake_data(batch_size=4)
    
    # forward:计算loss
    y_pred = x.mm(w) + b.expand_as(y) # x@W等价于x.mm(w);for python3 only
    loss = 0.5 * (y_pred - y) ** 2 # 均方误差
    loss = loss.mean()
    
    # backward:手动计算梯度
    dloss = 1
    dy_pred = dloss * (y_pred - y)
    
    dw = x.t().mm(dy_pred)
    db = dy_pred.sum()
    
    # 更新参数
    w.sub_(lr * dw)
    b.sub_(lr * db)
    
    if ii%50 ==0:
        # 画图
        display.clear_output(wait=True)
        x = t.arange(0, 6).view(-1, 1)
        y = x.float().mm(w) + b.expand_as(x)
        plt.plot(x.cpu().numpy(), y.cpu().numpy(),color='b') # predicted
        
        x2, y2 = get_fake_data(batch_size=100) 
        plt.scatter(x2.numpy(), y2.numpy(),color='r') # true data
        
        plt.xlim(0, 5)
        plt.ylim(0, 15)
        plt.show()
        plt.pause(0.5)
        
print('w: ', w.item(), 'b: ', b.item())

输出结果为:
在这里插入图片描述
w: 1.9709817171096802 b: 3.1699466705322266

代码解释

  1. 导入需要的库:
import torch as t
%matplotlib inline
from matplotlib import pyplot as plt
from IPython import display

导入PyTorch库以及绘图相关的库,%matplotlib inline是Jupyter Notebook中的魔法命令,用于在Notebook中显示绘图。

  1. 设置随机数种子:
t.manual_seed(1000)

这行代码设置随机数种子,保证每次运行结果的随机数生成过程一致。

  1. 定义生成随机数据的函数:
def get_fake_data(batch_size=8):
    ''' 产生随机数据:y=x*2+3,加上了一些噪声'''
    x = t.rand(batch_size, 1, device=device) * 5
    y = x * 2 + 3 +  t.randn(batch_size, 1, device=device)
    return x, y

该函数用于产生随机的输入特征x和对应的标签y,其中y满足线性关系y = x * 2 + 3,并添加了一些随机噪声。

  1. 初始化模型参数:
w = t.rand(1, 1).to(device)
b = t.zeros(1, 1).to(device)

这里使用随机数初始化模型参数wb,并指定在CPU上进行计算。

  1. 设置学习率:
lr = 0.02

学习率lr控制每次参数更新的步长。

  1. 进行模型训练:
for ii in range(500):
    # 生成随机数据
    x, y = get_fake_data(batch_size=4)
    
    # forward:计算损失
    y_pred = x.mm(w) + b.expand_as(y)
    loss = 0.5 * (y_pred - y) ** 2
    loss = loss.mean()
    
    # backward:手动计算梯度
    dloss = 1
    dy_pred = dloss * (y_pred - y)
    
    dw = x.t().mm(dy_pred)
    db = dy_pred.sum()
    
    # 更新参数
    w.sub_(lr * dw)
    b.sub_(lr * db)

这里使用一个循环进行模型的训练,每次迭代都包含以下步骤:

  • 生成随机数据;
  • 前向传播:计算预测值y_pred和损失函数loss
  • 反向传播:手动计算梯度dwdb
  • 更新参数:根据梯度和学习率更新参数wb
  1. 可视化模型训练过程:
if ii % 50 == 0:
    display.clear_output(wait=True)
    x = t.arange(0, 6).view(-1, 1)
    y = x.float().mm(w) + b.expand_as(x)
    plt.plot(x.cpu().numpy(), y.cpu().numpy(), color='b') # predicted line
    
    x2, y2 = get_fake_data(batch_size=100)
    plt.scatter(x2.numpy(), y2.numpy(), color='r') # true data
    
    plt.xlim(0, 5)
    plt.ylim(0, 15)
    plt.show()
    plt.pause(0.5)

这部分代码用于可视化模型训练的过程,每50次迭代将当前参数下的预测结果以蓝色线条的形式绘制出来,并将随机生成的100个样本以红色散点图显示出来。

  1. 输出最终训练得到的参数:
print('w: ', w.item(), 'b: ', b.item())

输出训练得到的参数wb的值。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1163160.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

GoLong的学习之路(十七)基础工具之GORM(操作数据库)(更新)

书接上回,上回写道,GORM的查询和创建(插入数据),这回继续些增删改查的改和删的操作。 文章目录 更新update修改单个列修改多个列修改选定字段批量更新新阻止全局更新 使用 SQL 表达式更新注意 根据子查询进行更新不使用…

Modbus转Profinet网关与流量变送器兼容转ModbusTCP协议博图配置案例

首先,我们需要明确电磁流量计的通信协议是Modbus,而西门子1200PLC的通信协议是Profinet。这两种协议在功能和特性上存在一定的差异,因此需要使用兴达易控Modbus转Profinet网关设备进行转换。兴达易控的XD-MDPN100是Profinet转ModbusTCP的网关…

功率放大器的种类和作用是什么

功率放大器是一种电子设备,用于将输入信号的功率增加到更高的水平,以驱动负载或输出设备。功率放大器广泛应用于各种领域,包括通信、音频、无线电频谱分析、激光器和雷达等。 根据应用需求和工作原理不同,功率放大器可分为几种不同…

笔记:IDEA如何修改代码后,不重启服务器局部更新资源

前言 平常用IDEA开发网页写调样式和测功能最讨厌改一丁点东西就要重启整个服务器,所以本文主要就是解决此问题从而提高开发效率,避免浪费过多时间。 具体步骤 1、打开设置框 2、先新增exploded结尾的,并apply应用,把没有结尾的…

【Kubernetes部署】二进制部署单Master Kurbernetes集群 超详细

二进制部署K8s 一、基本架构和系统初始化操作1.1 基本架构1.2 系统初始化操作 二、部署etcd集群2.1 证书签发Step1 下载证书制作工具Step2 创建k8s工作目录Step3 编写脚本并添加执行权限Step4 生成CA证书、etcd 服务器证书以及私钥 2.2 启动etcd服务Step1 上传并解压代码包Step…

云尘-Node1 js代码

继续做题 拿到就是基本扫一下 nmap -sP 172.25.0.0/24 nmap -sV -sS -p- -v 172.25.0.13 然后顺便fscan扫一下咯 nmap: fscan: 还以为直接getshell了 老演员了 其实只是302跳转 所以我们无视 只有一个站 直接看就行了 扫出来了两个目录 但是没办法 都是要跳转 说明还是需要…

轻松搭建Nextcloud私有云盘并实现远程访问【内网穿透】

文章目录 摘要1. 环境搭建2. 测试局域网访问3. 内网穿透3.1 ubuntu本地安装cpolar3.2 创建隧道3.3 测试公网访问 4 配置固定http公网地址4.1 保留一个二级子域名4.1 配置固定二级子域名4.3 测试访问公网固定二级子域名 摘要 Nextcloud,它是ownCloud的一个分支,是一个文件共享服…

opencv第一个例子

目的 这是用用QTopencv实现的一个完整的展示图片的例子,包括了项目的配置文件,完整的代码,以用做初次学习opencv用。 代码 工程文件: QT core guigreaterThan(QT_MAJOR_VERSION, 4): QT widgetsTARGET openCv1 TEMPL…

双路比例阀放大器

双路比例阀放大器是一种常见的电子设备,它能够将输入信号放大到所需的水平,并输出两个相等或不同的放大信号。这种放大器通常由一个放大器和一个驱动电路组成,可以用于各种应用中,如液压控制、气动控制等。 在液压控制方面&#…

物联网系统的基本构件

1.基本组件 云服务器 数据库消息服务器应用服务器管理平台 云APP 云服务器的维护终端微信客户端网页管理平台 页面式的更全面的管理。组态软件和PLC软件 编程软件终端设备 PLC 主要指标,模拟数字接口数量 DO有 继电器和1.5,2.5.5V数字输出一般支持扩展IO模块模拟量…

利用win32的GetLastInputInfo函数实现锁屏(C#)

前两天看到群里面讨论这个问题,刚好我们上一家公司的系统也有这个功能,就研究了一下,我们这边实现这个功能的目的如下:当用户长时间不操作系统时,自动退出系统并退回到登录界面,想要使用系统,就…

软文投放、发稿:如何写一篇优质的软文

在当今的营销世界中,软文是一种强大的工具,可以用来宣传产品、建立品牌形象,以及与受众建立更深层次的联系。然而,要写一篇优质的软文并不容易。本文将介绍如何撰写一篇引人入胜的软文,以吸引读者的兴趣和赢得他们的信…

用 Java 实现 Syslog 功能

1、业务场景 用一个 Spring Boot 的项目去实现对管控设备的监控、日志收集等。同时需要将接收到的日志进行入库,每天存一张表,如device_log_20231026… 2、Syslog客户端(接收日志的服务器,即运行Java程序的服务器) 2…

JavaScript 基础 - 第4天

理解封装的意义,能够通过函数的声明实现逻辑的封装,知道对象数据类型的特征,结合数学对象实现简单计算功能。 理解函数的封装的特征掌握函数声明的语法理解什么是函数的返回值知道并能使用常见的内置函数 函数 理解函数的封装特性&#xff0c…

软件测试/测试开发丨UbuntuServer环境准备

点此获取更多相关资料 前提 现有设备是一套 i54090 的组合,安装了 Ubuntu 22.04.3 LTS Server 版本,后文的安装步骤都是基于这套系统和配置进行操作。 系统准备 查看是否安装了 gcc 命令行中执行 gcc -v 正常输入如图效果的,说明已经成功…

kubeadm部署kubernetes1.28

k8s在1.24版本以后删除了内置dockershim插件,原生不再支持docker运行时,需要使用第三方cri接口cri-docker https://github.com/Mirantis/cri-dockerd.git 安装前,需要先升级systemd和主机内核,本操作文档安装的是最新的版本kube…

微信小程序渲染的富文本里面除了img标签外什么都没有,该如何设置img的大小

微信小程序富文本渲染&#xff1a; <rich-text nodes"{{content}}"style"{{style}}" ></rich-text> content是接口得到的值 let cont object.contentlet a cont.replace(/<img/gi,<img style"max-width:94%;height:auto;margi…

大厂面试题-什么是IO的多路复用机制?

IO多路复用机制&#xff0c;核心思想是让单个线程去监视多个连接&#xff0c;一旦某个连接就绪&#xff0c;也就是触发了读/写事件。 就通知应用程序&#xff0c;去获取这个就绪的连接进行读写操作。 也就是在应用程序里面可以使用单个线程同时处理多个客户端连接&#xff0c…

四川竹哲电子商务有限公司服务怎么样?

随着抖音电商的日益崛起&#xff0c;越来越多的商家开始关注这个充满无限商机的平台。四川竹哲电子商务有限公司作为一家专业的抖音电商服务公司&#xff0c;凭借其丰富的经验和优秀的服务&#xff0c;成为了众多商家在抖音电商领域中的重要合作伙伴。 一、专业实力 作为一家专…

同为科技(TOWE)国标10A电瓶车专用智能定时桌面PDU插线板

电动车作为现在国内保有量较高的一种交通工具&#xff0c;因其价格亲民、环保便捷、使用成本低等原因受到广大人民群众的欢迎。然而&#xff0c;电瓶车充电问题长期以来是广受关注的社会性话题&#xff0c;过充短路爆充引起火灾事故的新闻时有发生&#xff0c;80%的电动车火灾都…