Python深度学习基础(八)——线性回归

news2024/12/30 3:47:23

线性回归

  • 引言
  • 损失函数
  • 解析解
    • 公式
    • 代码
    • 实例
  • 梯度下降
    • 理论
    • 随机梯度下降的手动实现代码
    • torch中的随机梯度下降


引言

我们生活中可能会遇到形如

y = w 1 x 1 + w 2 x 2 + w 3 x 3 + b y=w_1x_1+w_2x_2+w_3x_3+b y=w1x1+w2x2+w3x3+b

的问题,其中有y为输出,x为输入,w为权值,b为偏置

假设我们有一个房价预测的问题,我们有很多条数据,每一个数据项有很多特征,这些特征就是x,而房价就是y,线性回归要解决的就是得出一批合适的w和b来实现x向y的映射,使得我们得到x时就可以预测出y。

损失函数

为了求得权值和偏置的最优值,我们需要定义损失函数,通过降低损失函数的损失进行权值和偏置的优化,我们常用的有如下三种损失值
均方误差
l ( w , b ) = 1 2 ( x w + b − y ) 2 l(w, b)=\frac{1}{2}(xw+b-y)^2 l(w,b)=21(xw+by)2
这里的1/2并没有什么含义,只是为了求导后计算方便
在多个样本上可以表述为
L ( W , b ) = 1 2 n ∑ i = 1 n ( X ( i ) W + b − Y ( i ) ) 2 L(W, b)= \frac{1}{2n}\sum_{i=1}^n(X^{(i)}W+b-Y^{(i)})^2 L(W,b)=2n1i=1n(X(i)W+bY(i))2
同时为了方便处理,通常我们会在数据后加一列1,这样偏置也会并入到权值当中,即

[ x 1 ( 1 ) x 2 ( 1 ) 1 x 1 ( 2 ) x 2 ( 2 ) 1 x 1 ( 3 ) x 2 ( 3 ) 1 ] ⋅ [ w 1 w 2 b ] = [ y ( 1 ) y ( 2 ) y ( 3 ) ] \begin{bmatrix} x_1^{(1)}&x_2^{(1)}&1\\ x_1^{(2)}&x_2^{(2)}&1\\ x_1^{(3)}&x_2^{(3)}&1\\ \end{bmatrix} \cdot \begin{bmatrix} w_1 \\ w_2 \\ b \end{bmatrix}= \begin{bmatrix} y^{(1)} \\ y^{(2)} \\ y^{(3)} \end{bmatrix} x1(1)x1(2)x1(3)x2(1)x2(2)x2(3)111 w1w2b = y(1)y(2)y(3)
我们令
注:为了方便表示,这里用三个数据,每个数据有两个数据项的数据表示

我们令
W = [ w 1 w 2 b ] W=\begin{bmatrix} w_1 \\ w_2 \\ b \end{bmatrix} W= w1w2b
那么我们的均方误差就变成了
L ( W , b ) = 1 2 n ∑ i = 1 n ( X ( i ) W − Y ( i ) ) 2 L(W, b)= \frac{1}{2n}\sum_{i=1}^n(X^{(i)}W-Y^{(i)})^2 L(W,b)=2n1i=1n(X(i)WY(i))2

解析解

线性回归问题存在解析解

公式

首先我们在L上对W求导
∇ w L = 1 n ∑ i = 1 n ( X ( i ) W − Y ( i ) ) T X ( i ) \nabla _wL= \frac{1}{n} \sum_{i=1}^n(X^{(i)}W-Y^{(i)})^{T}X^{(i)} wL=n1i=1n(X(i)WY(i))TX(i)
最优的解即为L=0的解,即
W ( ∗ ) T X ( i ) T X ( i ) − Y ( i ) T X ( i ) = 0 ⇒ W ( ∗ ) T = Y ( i ) T X ( i ) ( X ( i ) T X ( i ) ) − 1 ⇒ W ( ∗ ) = ( X ( i ) T X ( i ) ) − 1 X ( i ) T Y ( i ) W^{(*)T}X^{(i)T}X^{(i)}-Y^{(i)T}X^{(i)}=0 \\ \Rightarrow W^{(*)T}=Y^{(i)T}X^{(i)}(X^{(i)T}X^{(i)})^{-1} \\ \Rightarrow W^{(*)} = (X^{(i)T}X^{(i)})^{-1} X^{(i)T}Y^{(i)} W()TX(i)TX(i)Y(i)TX(i)=0W()T=Y(i)TX(i)(X(i)TX(i))1W()=(X(i)TX(i))1X(i)TY(i)

代码

如果使用numpy,假设我们有X和y这个过程可以表述为

# 第一步是增加一列1,这样可以使得w和b合并
X_b = np.c_[np.ones((X.shape[0], 1)), X]
theta_best = np.linalg.inv(X_b.T.dot(X_b)).dot(X_b.T).dot(y)  

其中np.linalg.inv是求矩阵的逆

实例

import torch
from torch.utils import data
import numpy as np
import random

首先我们定义一个生成数据的函数

def synthetic_data(w, b, num_examples):
    X = torch.normal(0, 1, (num_examples, w.shape[0]))  # 生成均值为0,方差为1的数据
    y = torch.matmul(X, w) + b  # 生成标签
    y += torch.normal(0, 0.01, y.shape)  # 均值为0,方差为0.01的正态分布
    return X, y.reshape((-1, 1))

我们假设w为[2, -3.4],b为4.2,我们生成线性数据

true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)

我们将w和b合并

X = np.array(features)
X_b = np.c_[np.ones((X.shape[0], 1)), X]
y = np.array(labels)

获得解析解

theta_best = np.linalg.inv(X_b.T.dot(X_b)).dot(X_b.T).dot(y)
theta_best

在这里插入图片描述
可以看到权值和偏置的解析解基本与真实值相同

梯度下降

理论

但是不是所有的问题都可以得到解析解,所以我们一般使用梯度下降的方式进行优化,优化方式是:各个参数沿着梯度的反方向更新,梯度方向就是方向导数最大的方向,用公式表示
( w , b ) ← ( w , b ) − ∇ ( w , b ) L (w, b) \leftarrow (w, b)-\nabla _{(w, b)}L (w,b)(w,b)(w,b)L

随机梯度下降的手动实现代码

数据迭代器
首先为了读出数据,我们先创建一个函数作为数据迭代器,其本质是生成器,安装小梯度的梯度大小生成数据

def data_iter(batch_size, features, labels):
    num_examples = len(features)
    indices = list(range(num_examples))
    random.shuffle(indices)
    for i in range(0, num_examples, batch_size):
        batch_indices = torch.tensor(indices[i:min(i+batch_size, num_examples)])
        yield features[batch_indices], labels[batch_indices]

权值和偏置初始化

w = torch.normal(0, 0.01, size=(2, 1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)

定义线性模型

def linreg(X, w, b):
    return torch.matmul(X, w) + b

定义损失函数

def squared_loss(y_hat, y):
    return (y_hat - y) ** 2 / 2

定义随机梯度下降优化器

def sgd(params, lr, batch_size):
    with torch.no_grad():
        for param in params:
            param -= lr * param.grad / batch_size
            param.grad.zero_()

参数设置

lr = 0.03
num_epochs = 3
net = linreg
loss = squared_loss

训练模型

for epoch in range(num_epochs):
    for X, y in data_iter(batch_size, features,labels):
        l = loss(net(X, w, b), y)  # 小批量损失
        l.sum().backward()  # 总损失进行反向传递
        sgd([w, b], lr, batch_size)
        
    with torch.no_grad():
        train_1 = loss(net(features, w, b), labels)
        print(f'epoch {epoch+1}, loss:{float(train_1.mean()):f}')

结果如下:

epoch 1, loss:0.033043
epoch 2, loss:0.000118
epoch 3, loss:0.000050

求损失

print(f'w损失:{true_w - w.reshape(true_w.shape)}')
print(f'b损失:{true_b - b}')

结果如下:

w损失:tensor([ 0.0007, -0.0011], grad_fn=<SubBackward0>)
b损失:tensor([0.0002], grad_fn=<RsubBackward1>)

torch中的随机梯度下降

数据加载器
使用torch.utils中的data来构建数据加载器
使用data.TensorDataset从tensor格式的数据中构建数据集,传入的参数应该是数据和标签组成的元组
使用data.DataLoader,传入数据集和批量大小,以及是否打乱顺序,使用这个函数按照批量大小加载数据

def load_array(data_arrays, batch_size, is_train=True):
    dataset = data.TensorDataset(*data_arrays)
    return data.DataLoader(dataset, batch_size, shuffle=is_train)

设置参数

batch_size = 10
data_iter = load_array((features, labels), batch_size)
next(iter(data_iter))

结果如下:
在这里插入图片描述

构建模型
在torch中,Linear为全连接层,通过Sequential构建线性模型

from torch import nn
​
net = nn.Sequential(nn.Linear(2, 1))
help(nn.Linear)

在其中可以看到Attributes里边有weight和bias,通过这个函数我们可以获取权值和偏置
在这里插入图片描述

net[0].weight

结果如下:

Parameter containing:
tensor([[0.6657, 0.1449]], requires_grad=True)

net[0].bias

结果如下:

Parameter containing:
tensor([-0.6534], requires_grad=True)

help(net[0].weight)

这里我们可以看到使用data可以查看weight的数据
在这里插入图片描述

net[0].weight.data

tensor([[0.6657, 0.1449]])

初始化权值和偏置
我们使用均值为0,方差为0.01的正态分布的数据初始化weight,将偏置设置为0

net[0].weight.data.normal_(0, 0.01)

tensor([[-0.0092, 0.0053]])

net[0].weight.data

tensor([[-0.0092, 0.0053]])

net[0].bias.data.fill_(0)

tensor([0.])

net[0].bias.data

tensor([0.])

损失函数和优化器

选用均方误差作为损失函数,随机梯度下降优化参数

loss = nn.MSELoss()
trainer = torch.optim.SGD(net.parameters(), lr=0.03)

训练

num_epochs= 3
for epoch in range(num_epochs):
    for X, y in data_iter:
        l = loss(net(X), y)
        trainer.zero_grad()
        l.backward()
        trainer.step()
        
    l = loss(net(features), labels)
    print(f'epoch:{epoch+1}, loss:{l:f}')

epoch:1, loss:0.000201
epoch:2, loss:0.000100
epoch:3, loss:0.000100

net[0].weight.data

tensor([[ 1.9992, -3.3989]])

net[0].bias.data

tensor([4.2002])

print(f'w损失:{net[0].weight.data - true_w}')
print(f'b损失:{net[0].bias.data - true_b}')

w损失:tensor([[-0.0008, 0.0011]])
b损失:tensor([0.0002])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/146715.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Java设计模式中工厂模式是啥?静态工厂、简单工厂与抽象工厂,工厂方法模式又是啥,怎么用,

继续整理记录这段时间来的收获&#xff0c;详细代码可在我的Gitee仓库SpringBoot克隆下载学习使用&#xff01; 4.3 工厂模式 4.3.1 背景 若创建对象时直接new对象&#xff0c;则会使对象耦合严重&#xff0c;更换对象则很复杂 4.3.2 简单工厂 4.3.3 特点 不是一种设计模…

c语言 文件处理2 程序环境和预处理

对比函数 sprintf&#xff08;把一个格式化数据转化为字符串&#xff09; sscanf &#xff08;从一个字符串中读一个格式化数据&#xff09; struct S {char arr[10];int age;float f; };int main() {struct S s { "hello", 20, 5.5f };//把这个转化为一个字符串s…

idea调试unity里面的lua代码

前言 本人一名java后端开发&#xff0c;看到前端同事调试lua代码无脑print&#xff0c;甚为鄙视&#xff0c;百度加实操写一份调试unity的lua脚本文档 操作 1.安装lua lua官网下载页面 最终下载页面 2.idea安装插件 emmylua 3.idea打开unity的lua脚本 idea->file->op…

【面试题】面试如何正确的介绍项目经验

大厂面试题分享 面试题库前端面试题库 &#xff08;面试必备&#xff09; 推荐&#xff1a;★★★★★地址&#xff1a;前端面试题库1、在面试前准备项目描述&#xff0c;别害怕&#xff0c;因为面试官什么都不知道面试官是人&#xff0c;不是神&#xff0c;拿到你的简历的时候…

ospf双向重发布,LSA优化综合

目录实验分析ip地址划分写公网缺省路由区域0公网MGRE搭建各个区域ospf的宣告改变ospf接口工作方式和更改接口优先级ospf多进程及双向重发布减少LSA的更新量1&#xff0c;减少特殊区域的LSA更新量2&#xff0c;骨干区域的优化域间汇总域外汇总防环nat的设置实验分析 如图实际的…

VS Code 为 Clang for MSVC 配置 cmake cmake tools

介绍 在windows平台上&#xff0c;由于平台API差异过大&#xff0c;一般为linux设计的项目&#xff08;POSIX兼容&#xff09;无法通过MSVC的编译&#xff0c;而是会报非常多的头文件错误。如果要修改&#xff0c;工程量将巨大。Windows平台上&#xff0c;主要有两个类POSIX兼容…

【JavaScript】事件--总结

千锋 1.Event 对象 代表事件的状态&#xff0c;比如事件在其中发生的元素、键盘按键的状态、鼠标的位置、鼠标按钮的状态。 div{width: 200px;height: 200px;background-color: yellow;} </style> <body><input type"text" id"username"&…

JavaScript 事件

文章目录JavaScript 事件HTML 事件常见的HTML事件JavaScript 可以做什么?JavaScript 事件 HTML 事件是发生在 HTML 元素上的事情。 当在 HTML 页面中使用 JavaScript 时&#xff0c; JavaScript 可以触发这些事件。 HTML 事件 HTML 事件可以是浏览器行为&#xff0c;也可以是…

babel做兼容处理 到底怎么使用?

1.背景 日常项目开发中总是避免不了对低版本浏览器做一些兼容处理&#xff0c;最常见的手段就是结合编译工具使用babel来处理一些语法的兼容&#xff0c;但是每次使用的时候其实Babel的配置和使用到的相关库总是云里雾里&#xff0c;网上的各种推荐方案眼花缭乱不知道到底应该…

自定义DotNetCore 项目模板

在进行代码开发时候&#xff0c;各自的团队或者公司都会有服务自己要求的项目代码模板&#xff0c;再创建新的项目时&#xff0c;都需要按照模板规范进行定义&#xff0c;NET支持自定义项目模板&#xff0c;这样在进行项目创建时就会高效很多。 官方参考文档&#xff1a;dotne…

软测复习01:软件测试概述

文章目录软件测试的目的软件测试的定义软件测试与软件开发软件测试发展软件测试的目的 基于不同的立场&#xff0c;存在着两种完全不同的测试目的 从用户的角度出发&#xff0c;希望通过软件测试暴露软件中隐藏的错误和缺陷&#xff0c;以考虑是否可接受该产品。从软件开发者的…

Java当中的定时器

目录 一、什么是定时器 二、Java当中的定时器 ①schedule()方法&#xff1a; ②TimerTask ​编辑 ③delay 三、实现一个定时器 前提条件: 代码实现: ①确定一个“任务”&#xff08;MyTask)的描述&#xff1a; ②schedule方法&#xff1a; ③需要一个计时器 属性…

MAT-内存泄漏工具使用

目录 一、MAT简介 1.1 MAT介绍 1.2 MAT工具的下载安装 二、使用MAT基本流程 2.1 获取HPROF文件 2.2 MAT主界面介绍 2.3 MAT中的概念介绍 2.3.1 Shallow heap 2.3.2 Retained Heap 2.3.3 GC Root 2.4 MAT中的一些常用的视图 2.4.1 Thread OvewView 2.4.2 Group 2.…

复杂工况下少样本轴承故障诊断的元学习

摘要&#xff1a;近年来&#xff0c;基于深度学习的轴承故障诊断得到了较为系统的研究。但是&#xff0c;这些方法中的大多数的成功在很大程度上依赖于大量的标记数据&#xff0c;而这些标记数据在实际生产环境中并不总是可用的。如何在有限的数据条件下训练出鲁棒的轴承故障诊…

线程状态到底是5种还是六种?傻傻分不清楚

目录 从操作系统层面上描述线程状态 从javaAPI层面上理解线程的6种状态 线程的状态转换. NEW --> RUNNABLE 1.RUNNABLE <--> WAITING 2.RUNNABLE <--> WAITING 3.RUNNABLE <--> WAITING 1.RUNNABLE <--> TIMED_WAITING 2.RUNNABLE <--&…

开源天气时钟项目删减和更新

开源天气时钟项目删减和更新&#x1f4cc;原项目开源地址&#xff1a;https://gitee.com/liuzewen/ESP8266-SSD1306-Watch-mini ✨本文只针对Arduino IDE平台代码进行删减和更新。 &#x1f4fa;按键菜单功能 &#x1f33c;天气时钟功能整体架构描述 代码中所使用的库&…

【MySQL】十,SQL执行流程

MySQL中的SQL执行流程 MySQL的查询流程 查询缓存 Server 如果在查询缓存中发现了这条 SQL 语句&#xff0c;就会直接将结果返回给客户端 如果没有&#xff0c;就进入到解析阶段&#xff08;MySQL 8.0 已经废弃了查询缓存功能&#xff09;。 解析器 在解析器中对 SQL 语句进行…

36、基于STM32的电子闹钟(DS1302)

编号&#xff1a;36 基于STM32的电子闹钟&#xff08;DS1302&#xff09; 功能描述&#xff1a; 本设计由STM32单片机液晶1602按键DS1302时钟模块声光报警模组成。 1、采用STM32F103最小系统。 2、利用DS1302芯片提供时钟信号 3、液晶1602实时显示年月日、时分秒、星期等信息…

java线程池原理

背景&#xff1a;为什么需要线程池java中的线程池是运用场景最多的并发框架&#xff0c;几乎所有需要异步或并发执行任务的程序都可以使用线程池。在开发过程中&#xff0c;合理的使用线程池能够带来3个好处。降低资源消耗。通过重复利用已创建的线程降低线程创建和销毁造成的消…

(1分钟了解)SLAM的七大问题:地图表示、信息感知、数据关联、定位与构图、回环检测、深度、绑架

编辑切换为居中添加图片注释&#xff0c;不超过 140 字&#xff08;可选&#xff09;SLAM问题也被称为是CML问题。编辑切换为居中添加图片注释&#xff0c;不超过 140 字&#xff08;可选&#xff09;编辑切换为居中添加图片注释&#xff0c;不超过 140 字&#xff08;可选&…