梯度下降法

news2025/2/25 5:28:42

梯度下降法

对于一个二元一次函数 y = ax + b,我们只需要知道两个 (x,y) 点即可获取到 ab 的值,我们称其为精确解,如下图:在这里插入图片描述
但是如果该函数中存在已知分布的噪声,那么又该如何求解:
在这里插入图片描述

我们可以假设 ab为任意值,则根据输入 x 有预测输出 y',实际的 y 值与预测的 y' 的差距我们称为损失,对于已知的若干 (x,y) 点的损失则为损失函数:
l o s s = 1 n ∑ i = 1 n ( y i − y ′ i ) 2 loss = \frac{1}{n} \sum_{i=1}^n (y_i - {y'}_i)^2 loss=n1i=1n(yiyi)2

我们最终的目的是求得ab使损失值最小,所以可以从当前的参数取值,一步步的按照损失函数下坡的方向下降,直到走到最低点。第一要保证 loss 是下降的,第二要使得下降的趋势尽可能的快。微积分的基础知识告诉我们:沿着梯度的反方向,是函数值下降最快的方向,所以只需要对损失函数求导,并且沿着导数的反方向逐步移动,则会找到最佳的 ab。我们称这种求解方法为梯度下降法,称该问题为线性回归问题。
在这里插入图片描述

代码实现

ab 分别求偏导数:
l o s s = 1 n ∑ i = 1 n ( a x i + b − y i ) 2 loss = \frac{1}{n} \sum_{i=1}^n (ax_i + b - y_i)^2 loss=n1i=1n(axi+byi)2
∂ l o s s ∂ a = 2 n ∑ i = 1 n [ ( a x i + b − y i ) x i ] {\partial loss \over \partial a} = {2\over n}\sum_{i=1}^n [(ax_i+b-y_i)x_i] aloss=n2i=1n[(axi+byi)xi]
∂ l o s s ∂ b = 2 n ∑ i = 1 n ( a x i + b − y i ) {\partial loss \over \partial b} = {2\over n}\sum_{i=1}^n (ax_i+b-y_i) bloss=n2i=1n(axi+byi)

# 线性回归模型

import numpy as np
import matplotlib.pyplot as plt


# 创建样本 y = ax+b,其中x = 1.72,b = 3.69
def create_sample():
    # 生成 0-100 之间 200 个随机数
    x = np.random.rand(200) * 10
    y = x * 1.72 + 3.69 + (np.random.normal(size=200))
    # 转换为 [[x1,y1],[x2,y2]...[xn,yn]] 的矩阵
    return np.array(list(zip(x, y)))


# 计算一次梯度,并更新 a,b 值
def gradient(a_cur, b_cur, points, learning_rate):
    a_gradient = 0
    b_gradient = 0
    points_length = len(points)
    # 计算梯度
    for i in range(0, points_length):
        x = points[i, 0]
        y = points[i, 1]
        a_gradient += (2 / points_length) * x * (a_cur * x + b_cur - y)
        b_gradient += (2 / points_length) * (a_cur * x + b_cur - y)

    # 更新 a、b 值
    new_a = a_cur - learning_rate * a_gradient
    new_b = b_cur - learning_rate * b_gradient
    return [new_a, new_b]


# 计算梯度
def computer_loss(a, b, points):
    points_length = len(points)
    loss = 0
    # 计算梯度
    for i in range(0, points_length):
        x = points[i, 0]
        y = points[i, 1]
        loss += (a * x + b - y) ** 2
    return loss


points = create_sample()
print("points", points)
a = 0
b = 0
loss_list = list()
for i in range(0, 100000):
    [a, b] = gradient(a, b, points, 0.001)
    loss_list.append(computer_loss(a, b, points))

#求得 a = 1.7219045715547612 b = 3.6145089870651086
print("a = {} b = {}".format(a, b))

# 绘制损失函数
plt.plot(loss_list)
plt.show()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1037433.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

如何修复wmvcore.dll缺失问题,wmvcore.dll下载修复方法分享

近年来,电脑使用的普及率越来越高,人们在日常生活中离不开电脑。然而,有时候我们可能会遇到一些问题,其中之一就是wmvcore.dll缺失的问题。wmvcore.dll是Windows平台上用于支持Windows Media Player的动态链接库文件,如…

蓝桥杯每日一题2023.9.24

九进制转十进制 - 蓝桥云课 (lanqiao.cn) 题目描述 分析 #include<bits/stdc.h> using namespace std; int main() {cout << 2 * 9 * 9 * 9 0 * 9 * 9 2 * 9 2;return 0; } 顺子日期 - 蓝桥云课 (lanqiao.cn) 题目描述 分析 全部枚举 #include<bits/s…

Vector Art - 矢量艺术

什么是矢量艺术&#xff1f; 矢量图形允许创意人员构建高质量的艺术作品&#xff0c;具有干净的线条和形状&#xff0c;可以缩放到任何大小。探索这种文件格式如何为各种规模的项目提供创造性的机会。 什么是矢量艺术作品? 矢量艺术是由矢量图形组成的艺术。这些图形是基于…

LeetCode 494.目标和 (动态规划 + 性能优化)二维数组 压缩成 一维数组

494. 目标和 - 力扣&#xff08;LeetCode&#xff09; 给你一个非负整数数组 nums 和一个整数 target 。 向数组中的每个整数前添加 或 - &#xff0c;然后串联起所有整数&#xff0c;可以构造一个 表达式 &#xff1a; 例如&#xff0c;nums [2, 1] &#xff0c;可以在 2…

vue指令(代码部分二)

<template><view><view v-on:click"onClick">{{title}}</view><button click"clickNum">数值&#xff1a;{{num}}</button><view class"box" :style"{background:bgcolor}" click"clickB…

ROS 2官方文档(基于humble版本)学习笔记(三)

ROS 2官方文档&#xff08;基于humble版本&#xff09;学习笔记&#xff08;三&#xff09; 理解参数&#xff08;parameter&#xff09;ros2 param listros2 param getros2 param setros2 param dumpros2 param load在节点启动时加载参数文件 理解动作&#xff08;action&…

【python零基础入门学习】python进阶篇之时间表示方法和异常处理以及linux系统的os模块执行shell命令以及记账程序编写教学(一)

本站以分享各种运维经验和运维所需要的技能为主 《python零基础入门》&#xff1a;python零基础入门学习 《python运维脚本》&#xff1a; python运维脚本实践 《shell》&#xff1a;shell学习 《terraform》持续更新中&#xff1a;terraform_Aws学习零基础入门到最佳实战 《k8…

第1篇 目标检测概述 —(1)目标检测基础知识

前言&#xff1a;Hello大家好&#xff0c;我是小哥谈。目标检测是计算机视觉领域中的一项任务&#xff0c;旨在自动识别和定位图像或视频中的特定目标&#xff0c;目标可以是人、车辆、动物、物体等。目标检测的目标是从输入图像中确定目标的位置&#xff0c;并使用边界框将其标…

Docker 自动化部署(保姆级教程)

Docker 自动化部署 1. jenkins 介绍1.1 参考链接&#xff1a;1.2 jenkins 概述1.3 jenkins部署项目的流程 2. jenkins 安装2.1 基于docker 镜像2.2 启动 jenkins 后端服务2.3 登录 jenkins 服务后端 3. jenkins自动化部署开始3.1 下载需要的插件3.2 创建任务3.2.1 描述3.2.2 配…

Vue3+element-plus切换标签页时数据保留问题

记录一次切换标签页缓存失效问题&#xff0c;注册路由时name不一致可能会导致缓存失效

Visio——绘制倾斜线段

一、形状 -> 图表和数学图形 -> 多行 二、放置多行线&#xff0c;可以发现存在两个折点 三、选择多行线&#xff0c;右键选择删除点&#xff0c;即可得到倾斜线段

【python爬虫】爬虫所需要的爬虫代理ip是什么?

目录 前言 一、什么是爬虫代理 IP 二、代理 IP 的分类 1.透明代理 2.匿名代理 3.高匿代理 三、如何获取代理 IP 1.免费代理网站 2.付费代理服务 四、如何使用代理 IP 1.使用 requests 库 2.使用 scrapy 库 五、代理 IP 的注意事项 1.代理 IP 可能存在不稳定性 2…

Linux指令(ls、pwd、cd、touch、mkdir、rm)

whoami who pwd ls ls -l clearls指令 ls ls -l ls -a :显示当前目录下的隐藏文件&#xff08;隐藏文件以.开头&#xff09;ls -a -l 和 ls -l -a 和 ls -la 和 ls -al &#xff08;等价于ll&#xff09; pwd命令 显示用户当前所在的目录 cd指令 mkdir code &#xff08;创建…

《Python趣味工具》——ppt的操作(1)

前面我们学习了如何利用turtle模块制作emoji&#xff0c;今天来看看PPT的相关操作&#xff1a; 文章目录 一、PPT的基础结构&#xff1a;二、PPT的相关操作&#xff1a;1. 导入pptx模块2. ppt的基本操作&#xff1a; 三、总结&#xff1a;四、 完整源码&#xff1a; 小L想要把 …

Blender 学习笔记(二)之坐标

文章目录 归零世界坐标系与局部坐标系物体的编辑模式万向坐标系视图坐标轴游标坐标轴原点变换轴心点 归零 alt G 键 世界坐标系与局部坐标系 在blender 中的物体&#xff0c;默认情况下是世界坐标系&#xff0c;也就是全局坐标系 当你按G 键&#xff0c;再按一次x 键时&…

周赛364(模拟+贪心,枚举,单调栈+前后缀分解,枚举+DFS)

文章目录 周赛364[8048. 最大二进制奇数](https://leetcode.cn/problems/maximum-odd-binary-number/)贪心 模拟 [100049. 美丽塔 I](https://leetcode.cn/problems/beautiful-towers-i/)枚举 [100048. 美丽塔 II](https://leetcode.cn/problems/beautiful-towers-ii/)单调栈 …

数据结构与算法——16.二叉树

这篇文章我们来讲一下二叉树 目录 1.概述 2.代码实现 1.概述 树&#xff1a;&#xff08;Tree&#xff09;是计算机数据存储的一种结构&#xff0c;因为存储类型和现实生活中的树类似所以被称为树。 树的源头被称为根&#xff0c;树其余分叉点被称为节点&#xff0c;而树这…

未知非参数需求和有限价格变动的动态定价

英文题目&#xff1a;Dynamic Pricing with Unknown Non-Parametric Demand and Limited Price Changes 中文题目&#xff1a;未知非参数需求和有限价格变动的动态定价 单位&#xff1a;麻省理工学院&#xff0c;剑桥 时间&#xff1a;2019 论文链接&#xff1a;https://do…

制作频谱灯

最近研究了下傅里叶变换&#xff0c;用它可以通过采集声音信号由时域转换到频域内&#xff0c;从而得到声音的频谱信息&#xff0c;可以做个频谱灯。 主要使用ESP32来实现了他&#xff0c;实现效果如下&#xff1a; 频谱灯 为了可以带出去露营&#xff0c;我把它做的很大&…

ubuntu20.04下源码编译colmap

由于稠密重建需要CUDA&#xff0c;因此先安装CUDA&#xff0c;我使用的是3050GPU&#xff0c;nvidia-smi显示最高支持CUDA11.4。 不要用sudo apt安装&#xff0c;版本较低&#xff0c;30系显卡建议安装CUDA11.0以上&#xff0c;这里安装了11.1版本。 下载&#xff1a; cuda_1…