一次输入多个数据-batchsize大于1的简单的线性回归模型-标量

news2024/12/27 13:14:15

最简单的线性回归模型-标量

接上篇,由于batchsize为1,因此loss有很大的波动,这篇我们讨论batchsize大于1的情况。若batchsize数量为N,则 y = w x + b y=wx+b y=wx+b的损失函数为:
L = ∑ i = 1 N ( w x i ∗ + b − y i ∗ ) 2 = ( w x T + b e T − y T ) ( w x + b e − y ) \begin{aligned} L&=\sum_{i=1}^{N}(wx_i^*+b-y_i^*)^2\\ &=(w\boldsymbol{x}^T+b\boldsymbol{e}^T-\boldsymbol{y}^T)(w\boldsymbol{x}+b\boldsymbol{e}-\boldsymbol{y}) \end{aligned} L=i=1N(wxi+byi)2=(wxT+beTyT)(wx+bey)
为了方便计算在对损失函数乘一个数值,不影响其极值,因此将损失函数变为:
L = 1 2 ∑ i = 1 N ( w x i ∗ + b − y i ∗ ) 2 L=\frac{1}{2}\sum_{i=1}^{N}(wx_i^*+b-y_i^*)^2 L=21i=1N(wxi+byi)2
求出 w w w b b b的梯度:
∂ L ∂ w = ∑ i = 1 N ( w x i ∗ + b − y i ∗ ) x i ∗ = ∑ i = 1 N w x i ∗ 2 + ∑ i = 1 N b x i ∗ − ∑ i = 1 N y i ∗ x i ∗ = w x T x + b e T x − y T x = ( w x T + b e T − y T ) x \begin{aligned} \frac{\partial{L}}{\partial{w}}&=\sum_{i=1}^{N}(wx_i^*+b-y_i^*)x_i^*\\ &=\sum_{i=1}^{N}wx_i^{*2}+\sum_{i=1}^{N}bx_i^*-\sum_{i=1}^{N}y_i^*x_i^*\\ &=w\boldsymbol{x}^T\boldsymbol{x}+b\boldsymbol{e}^T\boldsymbol{x}-\boldsymbol{y}^T\boldsymbol{x}\\ &=(w\boldsymbol{x}^T+b\boldsymbol{e}^T-\boldsymbol{y}^T)\boldsymbol{x} \end{aligned} wL=i=1N(wxi+byi)xi=i=1Nwxi2+i=1Nbxii=1Nyixi=wxTx+beTxyTx=(wxT+beTyT)x
∂ L ∂ b = ∑ i = 1 N ( w x i ∗ + b − y i ∗ ) = ( w x T + b e T − y T ) e \begin{aligned} \frac{\partial{L}}{\partial{b}}&=\sum_{i=1}^{N}(wx_i^*+b-y_i^*)\\ &=(w\boldsymbol{x}^T+b\boldsymbol{e}^T-\boldsymbol{y}^T)\boldsymbol{e} \end{aligned} bL=i=1N(wxi+byi)=(wxT+beTyT)e
其中 x \boldsymbol{x} x为每个batch中所有的 x ∗ x^* x组成的N维列向量, y \boldsymbol{y} y为每个batch中所有的 y ∗ y^* y组成的N维列向量, e \boldsymbol{e} e是长度为N的列向量,**使用向量表示可以让我们轻松使用numpy实现回归过程。**使用python实现结果如下:

import numpy as np
import random
import matplotlib.pyplot as plt

x = np.array([0.1,1.2,2.1,3.8,4.1,5.4,6.2,7.1,8.2,9.3,10.4,11.2,12.3,13.8,14.9,15.5,16.2,17.1,18.5,19.2])
y = np.array([5.7,8.8,10.8,11.4,13.1,16.6,17.3,19.4,21.8,23.1,25.1,29.2,29.9,31.8,32.3,36.5,39.1,38.4,44.2,43.4])
print(x,y)
plt.scatter(x,y)
plt.show()

散点图如下:
在这里插入图片描述
回归过程使用numpy中的矩阵计算完全按照上述损失函数和梯度直接计算即可:

# 设定步长
step=0.001
# 存储每轮损失的loss数组
loss_list=[]
# 定义epoch
epoch=500
# 定义batch_size
batch_size=18
# 定义单位列向量e
e=np.ones(batch_size).reshape(batch_size,1)

# 定义参数w和b并初始化
w=0.0
b=0.0

#梯度下降回归
for i in range(epoch) :
    #计算当前输入x和标签y的索引,由于x和y数组长度一致,因此通过i整除x的长度即可获得当前索引
    index = i % int(len(x)/batch_size)
    # 当前轮次的x列向量值为:
    cx=x[index*batch_size:(index+1)*batch_size]
    cx=cx.reshape(len(cx),1)
    # 当前轮次的y列向量值为:
    cy=y[index*batch_size:(index+1)*batch_size]
    cy=cy.reshape(len(cy),1)

    # 计算当前loss
    curloss = (w*cx.T+b*e.T-cy.T).dot((w*cx+b*e-cy))
    loss_list.append(float(curloss))

    # 计算参数w和b的梯度
    grad_w = (w*cx.T+b*e.T-cy.T).dot(cx)
    grad_b = (w*cx.T+b*e.T-cy.T).dot(e)
    # 更新w和b的值
    w -= step*grad_w
    b -= step*grad_b

损失函数和最终拟合结果如下:

print(loss_list)
plt.plot(loss_list)
plt.show()

在这里插入图片描述

pred_y = w*x+b
plt.scatter(x,y)
plt.plot(x,pred_y.reshape(len(x)),c='r')
plt.show()

在这里插入图片描述
可以看到增大batsize后损失函数比较稳定。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/392729.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

html网页源码加密

html加密、网页加密、网页源码加密html网页源码能加密吗?能加密到何种程度? 某些时候,我们可能需要对html网页源码加密,使网页源码不那么容易被他人获得。出于这个目标,本文测试一种html加密方式。 提前透露&#xf…

Linux系统PWM应用编程

目录应用层如何操控PWM编写应用程序在开发板上测试本章我们将学习如何对开发板上的PWM 设备进行应用编程。 应用层如何操控PWM 与LED 设备一样,PWM 同样也是通过sysfs 方式进行操控,进入到/sys/class/pwm 目录下,如下所示: 这里…

Java多态性

用一句话概括就是:事物在运行过程中存在不同的状态。 多态的存在有三个前提: 1.要有继承关系 2.子类要重写父类的方法 3.父类引用指向子类对 但是其中又有很多细节需要注意。首先我们定义两个类,一个父类Animal,一个子类Cat。 父类Animal cl…

Hive学习——企业级调优

目录 一、计算资源调优 (一)Yarn资源配置——集群 1.Yarn配置说明 (1)yarn.nodemanager.resource.memory-mb (2)yarn.nodemanager.resource.cpu-vcores (3)yarn.scheduler.maximum-allocation-mb (4)yarn.scheduler.minimum-allocation-mb (二)MapReduce资源配置 二、…

裁员降本,扭转颓势!通用汽车吹响智能电动「中国集结号」

2023年,将是合资品牌能否搭上中国智能电动市场红利的关键一年。 全新一代VCS智能座舱(高通8155,30英寸6K曲面OLED显示屏,12.6英寸WHUD以及5G版本的别克eConnect车联系统)、全新一代Super Cruise超级辅助驾驶系统&#…

7.SpringSecurity中的权限管理

SpringSecurity中的权限管理 SpringSecurity是一个权限管理框架,核心是认证和授权,前面已经系统的给大家介绍过了认证的实现和源码分析,本文重点来介绍下权限管理这块的原理。 一、权限管理的实现 服务端的各种资源要被SpringSecurity的权限…

ccc-pytorch-卷积神经网络介绍(5)

文章目录一、卷积二、池化三、Batch Norm四、经典卷积网络简单介绍一、卷积 卷积连续函数形式: F(x)∫f(t)g(x−t)dtF(x)\int f(t)g(x-t)dtF(x)∫f(t)g(x−t)dt 物理意义是一个函数在另一个函数上的加权叠加。在2D卷积中指卷积核在数据矩阵中分割出的矩阵和卷积核相…

PMP和软考高项集成,更应该考哪个呢?

要看你自己的偏向,要说考的话,我是觉得都值得考的,一个证一份技术嘛。 我给你稍微介绍一下,PMP都是美国PMI发起的考试,软考高项是国内的考试。PMP是项目管理证书,学习的内容是项目管理,包含大约…

研报精选230306

目录 【行业230306东亚前海证券】食品饮料行业2023年年度投资策略:复苏在途,蓄势待发【行业230306国金证券】基础化工行业研究:MDI价格上行,新一轮国企改革在即【行业230306中银证券】华为汽车产业链深度报告:三种合作…

k8s控制器

目录 一、控制器简介 二、控制器类型 1、RC和RS 2、Deployment 3、DaemonSet 4、Job 5、CronJob 6、StateFulSet 7、HPA 一、控制器简介 在kubernetes中,按照Pod的创建方式可以将其分为两类: 自主式:kubernetes直接创建出来的Pod,…

【Alamofire】【Swift】属性包装器注解@propertyWrapper

Alamofire 中的源码例子 import Foundationprivate protocol Lock {func lock()func unlock() }extension Lock {/// Executes a closure returning a value while acquiring the lock.////// - Parameter closure: The closure to run.////// - Returns: The value…

9.SpringSecurity核心过滤器-SecurityContextPersistenceFilter

SpringSecurity核心过滤器-SecurityContextPersistenceFilter 一、SpringSecurity中的核心组件 在SpringSecurity中的jar分为4个,作用分别为 jar作用spring-security-coreSpringSecurity的核心jar包,认证和授权的核心代码都在这里面spring-security-co…

Promise入门

Promise入门 Promise的基本概念 男女相爱了&#xff0c;女方向男方许下一个承诺怀孕new Promise&#xff0c;这是会产生两种结果怀上(resolve)和没怀上(reject)&#xff0c;resolve对应then&#xff0c;reject对应catch&#xff0c;无论是否怀上都会执行finally。 <script&…

【论文速递】CASE 2022 - EventGraph: 将事件抽取当作语义图解析任务

【论文速递】CASE 2022 - EventGraph: 将事件抽取当作语义图解析任务 【论文原文】&#xff1a;https://aclanthology.org/2022.case-1.2.pdf 【作者信息】&#xff1a;Huiling You, David Samuel, Samia Touileb, and Lilja vrelid 论文&#xff1a;https://aclanthology.o…

sql server 对比两个查询性能 ,理解Elapsed Time、CPU Time、Wait Time

分析 SET STATISTICS TIME ONyour sqlSET STATISTICS TIME OFF由上图分析: cpu time 是查询执行时占用的 cpu 时间。如果了解系统的多任务机制&#xff0c;就会知道系统会将整个 cpu 时间分为一个一个时间片&#xff0c;平均分配给运行的线程——一个线程在 cpu 上运行一段时间…

《PyTorch深度学习实践9》——卷积神经网络-高级篇(Advanced-Convolution Neural Network)

一、1∗11*11∗1卷积 由下面两张图&#xff0c;可以看出1∗11*11∗1卷积可以显著降低计算量。 通常1∗11*11∗1卷积还有以下功能&#xff1a; 一是用于信息聚合&#xff0c;同时增加非线性&#xff0c;1∗11*11∗1卷积可以看作是对所有通道的信息进行线性加权&…

Air101|Air103|Air105|Air780E|ESP32C3|ESP32S3|Air32F103开发板:概述及PinOut

1、合宙Air101&#xff08;芯片及开发板&#xff09; 合宙Air101是一款QFN32 封装&#xff0c;4mm x 4mm 大小的mcu。通用串口波特率&#xff0c;设置波特率为921600。 ​ 管脚映射表 GPIO编号 命名 默认功能及扩展功能 0 PA0 BOOT 1 PA1 I2C_SCL/ADC0 4 PA4 I2C_S…

前端必备技术之——AJAX

简介 AJAX 全称为 Asynchronous JavaScript And XML&#xff0c;就是异步的 JS 和 XML(现在已经基本被json取代)。通过 AJAX 可以在浏览器中向服务器发送异步请求&#xff0c;最大的优势&#xff1a;无刷新获取数据。AJAX 不是新的编程语言&#xff0c;而是一种将现有的标准组…

揭秘关键一环!数据安全服务大盘点

数据安全服务&#xff0c;数据安全体系建设的关键一环。通过数据安全服务解决数据安全建设难题&#xff0c;得到越来越多的重视。不久前&#xff0c;《工业和信息化部等十六部门关于促进数据安全产业发展的指导意见》发布&#xff0c;明确“壮大数据安全服务”&#xff0c;推进…

VScode 插件【配置】

写这篇博客的原因&#xff1a; vscode 很久以前的插件&#xff0c;忘记是干什么的了记录 vscode 好用的插件 插件介绍&#xff08;正文开始&#xff09; Auto Rename tag 开始/关闭标签内容 同步 Chinese (Simplified) VScode 中文化 CSS Peek 通过 html 代码查找到引用的样式…