pytorch构建深度网络的基本概念——随机梯度下降

news2024/11/26 19:41:43

文章目录

  • 随机梯度下降
    • 定义一个简单的模型
    • 定义Loss
    • 什么是梯度
    • 随机梯度下降

随机梯度下降

现在说说深度学习中的权重更新算法:经典算法SGD:stochastic gradient descent,随机梯度下降。

定义一个简单的模型

假设我们的模型就是要拟合一根直线:
那么直线的方程为:
y = ω x + b y = \omega x + b y=ωx+b
而且我们的训练集有 n n n个样本:
( x 1 , y 1 ) , . . . . ( x n , y n ) {(x_1, y_1), .... (x_n, y_n)} (x1,y1),....(xn,yn)
如图所示:

更简单一点,假设这根直线穿过原点,那么后面的常数b也没有了

我们通过图中的这些样本点去训练模型获得靠谱的 w w w

定义Loss

在反向传播的基本逻辑里,首先要定义一个损失函数,损失函数其实就是预测值与真实值(训练数据)的差异,在这个例子里,我们定义的这个差异是均方误差,用公式表示就是:
l = ( y i − y i ^ ) 2 l = ({y_i} - \hat{y_i})^2 l=(yiyi^)2

那么,联合上面两个公式可以得到:

l = w x i 2 − 2 w x i y i ^ − y i ^ 2 l = {wx_i }^2 - 2wx_i\hat{y_i} -\hat{y_i}^2 l=wxi22wxiyi^yi^2

我们就是要通过这个误差去反向计算 w w w的调整量。

我们在训练的时候,一般来说都是一批数据一起调整,每次都调整也太费劲了,作为一个batch,用一个batch的样本损失的平均值来进行权重调整。
一个batch的每个样本都有一个损失:对这些损失求和,比如一个batch有m个样本
e = ∑ i m = 1 m ( ( x 1 2 + . . . + x m 2 ) ∗ w 2 + ( − 2 ∗ x 1 ∗ y 1 + . . . + − 2 ∗ x m ∗ y + m ) ∗ w + ( y 1 2 + . . . + y m 2 ) ) e = \sum_i^m=\frac{1}{m}(({x_1^2 + ... + x_m^2})*w^2 + (-2*x_1*y_1 + ... + -2*x_m*y+m)*w + (y_1^2 + ... +y_m^2)) e=im=m1((x12+...+xm2)w2+(2x1y1+...+2xmy+m)w+(y12+...+ym2))

如果令
a = x 1 2 + . . . + x m 2 a = {x_1^2 + ... + x_m^2} a=x12+...+xm2
b = ( − 2 ∗ x 1 ∗ y 1 + . . . + − 2 ∗ x m ∗ y + m ) b = (-2*x_1*y_1 + ... + -2*x_m*y+m) b=(2x1y1+...+2xmy+m)
c = ( y 1 2 + . . . + y m 2 ) c = (y_1^2 + ... +y_m^2) c=(y12+...+ym2)

那么上面公式实际上就是以 w w w为变量的二次方程,图像是一个朝上的U型曲线,因为 a a a肯定大于0。那么这个二次曲线图可以如图(matplot画的图,不够细致,大概那个意思):

什么是梯度

上面已经把模型的输出损失定义和公式列举出来了。
那么为了尽快收敛,也就是为了尽快的让 w w w能达到我们想要的目标,也就是上面这个曲线的最低处(极值),我们就会让当前的梯度(初始化可能是一个正态分布或者随机数)加上某个偏移量,那么这个偏移量是多少呢,就是对上面的值就行求导,求导也就是求梯度。
梯度的正方向是远离极值的方向,所以需要取个负号,所以也叫梯度下降。用公式表示就是:
θ = − ∂ e ∂ l = − 1 n ( a w + b ) \theta = -\frac{\partial e}{\partial l} = -\frac{1}{n}(aw + b) θ=le=n1(aw+b)
我这里为了省事,把b去掉了,有兴趣的可以带入进去算一下,因为对 w w w求偏导的时候, b b b就是个常量,求完就直接等于0了。

上面公式计算出来的就是曲线当前点的梯度,或者说是切线的斜率。那么沿着斜率走多少了?也就是最终还是要求出一个 Δ w \Delta w Δw来,这就是训练中的参数学习率了,在学习率 η \eta η的情况下(可以理解成斜率上的x轴距离),权重的调整量就是:
w ^ = w + Δ w = w + η ∗ θ \hat{w}=w+\Delta{w} = w + \eta * \theta w^=w+Δw=w+ηθ

随机梯度下降

随机是指上面的这个batch是在整个训练集随机挑选样本到batch中,这个就可以减小样本之间造成的参数更新抵消问题。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/750081.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于深度学习的高精度课堂人脸检测系统(PyTorch+Pyside6+YOLOv5模型)

摘要:基于深度学习的高精度课堂人脸检测系统可用于日常生活中或野外来检测与定位课堂人脸目标,利用深度学习算法可实现图片、视频、摄像头等方式的课堂人脸目标检测识别,另外支持结果可视化与图片或视频检测结果的导出。本系统采用YOLOv5目标…

力扣876. 链表的中间结点

题目 给你单链表的头结点head,请你找出并返回链表的中间结点。如果有两个中间结点,则返回第二个中间结点。 题解 设置快慢指针slow和fast,slow每次走一步,fast每次走两步,当fast走完时,slow刚好指到链表中间…

Vue从小白到入门(保姆级教学)

文章目录 🍋Vue是什么?🍋MVVM思想 🍋vue2快速入门🍋注意事项 🍋数据单向渲染🍋数据双向渲染🍋作业布置 🍋事件绑定🍋事件处理机制🍋注意事项和细节&#x1f…

西门子S7300以太网模块labview软件介绍

借助捷米特ETH-S7300-JM01以太网模块,通过NetS7 OPC和NI OPC Servers,西门子S7-300与测控软件NI LABVIEW实现以太网通讯和监控。 功能简介 LabVIEW是一种程序开发环境,由美国国家仪器(NzI)公司研制开发,类…

Redis 宕机了,如何避免数据丢失?

前言 如果有人问你:"你会把 Redis 用在什么业务场景下?" 我想你大概率会说:"我会把它当作缓存使用,因为它把后端数据库中的数据存储在内存中,然后直接从内存中读取数据,响应速度会非常快。…

英华特在创业板上市:总市值约50亿元,国产品牌持续向上

7月13日,苏州英华特涡旋技术股份有限公司(下称“英华特”,SZ:301272)在深圳证券交易所创业板上市。本次上市,英华特的发行价为51.39元/股,发行数量为1463万股,募资总额约为7.52亿元,…

直播 | SDS 容灾方案,让制品数据更安全

近日,腾讯 CODING WePack 制品管理系统 V1 以及腾讯 CODING DevOps 研发效能管理平台 V7 与 XSKY 星辰天合的统一数据平台 XEDP 及天合翔宇分布式存储系统完成互相兼容认证,在数据层面满足了共同客户敏捷开发的高可用建设合规要求。 联合解决方案可以帮…

Linux stress命令---压力测试

一、使用场景 CPU压力测试 内存压力测试 磁盘IO测试 Swap可用性测试 二、语法及常用参数 stress [选项] [进程数] -?, --help:显示帮助信息 --version:显示版本信息 -v, --verbose:详细输出 -q, --quiet:静默输出 -t, --timeout&…

基于python 和anaconda搭建环境

目录 1.先了解以下几点。 2 方案:pycharmanaconda 3.基本步骤 4 熟悉anaconda。 4.1 虚拟环境的创建方法 4.2 anaconda prompt中,常用指令 4.3 在Anaconda Navigate中的一些操作 4.3.1给已有虚拟环境安装包 4.3.2 新建虚拟环境 4.4 在pycharm中…

JavaScript 深度剖析-函数式编程(一)

文章介绍 为什么要学习函数编程以及什么是函数式编程函数式编程的特性(纯函数、柯里化、函数组合等)函数式编程的应用场景函数式编程库 Lodash 为什么要学习函数式编程 函数式编程是非常古老的一个概念,早于第一台计算机的诞生,函数式编程的历史。 那…

灵活利用ChatAI,提升你的码力—程序员篇

前言 ChatGPT目前还完全无法替代程序员,尤其是在一些强上下文的编程场景下,比如一些重业务的编程场景,但是可以利用它来完成一些编程相关的事,把它当做一个工具来大幅度提升我们的工作效率 ​开发:微信小程序 用户交互…

pg手动清理pg_wal文件

1、由于我是docker安装的,要先进入docker容器 docker exec -it a470585a9cdc /bin/bash2、查看哪个检查点之前的日志可以清除 pg_controldata $PGDATA表示00000001000000E7000000CE之前的pg_wal文件可以删除 3、手动清理pg_wal pg_archivecleanup -d $PGDATA/pg…

当我掉入计算机的大坑中时,遇到简单的题也很吃力,这可如何是好呢?

一支笔,一双手,一道力扣(Leetcode)做一宿!!! 一、分享自己相关的经历 我们可能经常听到这句话,人永远赚不到认知以外的钱,如果把它放到程序员行业来说,同样适…

微信加粉计数器后台开发

后台包括管理后台与代理后台两部分 管理后台 管理后台自带网络验证卡密系统,一个后台可以完成对Pc端的全部对接,可以自定义修改分组名称 分享等等代理后台 分享页 调用示例 <?php$request new HttpRequest(); $request->setUrl(http://xxxxxxx/api); $request->…

ROS:URDF、Gazebo与Rviz结合使用

目录 一、机器人运动控制以及里程计信息显示1.1ros_control 简介1.2运动控制实现流程(Gazebo)1.2.1为 joint 添加传动装置以及控制器1.2.2xacro文件集成1.2.3启动 gazebo并控制机器人运动 1.3Rviz查看里程计信息1.3.1启动 Rviz1.3.2添加组件 二、雷达信息仿真以及显示2.1流程分…

路径规划算法:基于人工兔优化的路径规划算法- 附代码

路径规划算法&#xff1a;基于人工兔优化的路径规划算法- 附代码 文章目录 路径规划算法&#xff1a;基于人工兔优化的路径规划算法- 附代码1.算法原理1.1 环境设定1.2 约束条件1.3 适应度函数 2.算法结果3.MATLAB代码4.参考文献 摘要&#xff1a;本文主要介绍利用智能优化算法…

浮点数的近似保存与计算

这里写目录标题 负数的补码存储十进制浮点数与二进制的转换有限循环的二进制无限循环的二进制 计算机对浮点数的保存无限循环二进制数的保存浮点数的近似 参考文献 负数的补码存储 首先我们回忆一下负数的补码表示。我们都知道&#xff0c;有符号数的负数使用补码的方式进行存…

WVP+ZLMediaKit实现网络摄像头接入

​ 记录下本地调试监控摄像头相关信息。 参考来源&#xff1a;部署 WVPZLMediaKit 实现大华摄像头接入_wvp zlm_鬼畜的稀饭的博客-CSDN博客 ZLMediaKit 代码地址 WVP 代码地址 ⚠️ 摄像头需要连接PoE设备来供电&#xff08;插网线就能供电&#xff09; 资源清单&#xff1a…

如何通过设备管理系统实现设备全生命周期管理

设备是生产力的核心&#xff0c;对企业的运营和效益起着至关重要的作用。然而&#xff0c;随着设备数量和复杂性的增加&#xff0c;如何有效管理设备的全生命周期成为了一个挑战。 在这个时代&#xff0c;设备管理系统成为了一种重要的工具&#xff0c;帮助企业实现设备全生命周…

LJUBOMORA - 思维+二分

分析&#xff1a; 二分最小的嫉妒值&#xff0c;每次check需要判断每一种颜色需要分给几个小朋友&#xff0c;如果可以所有都分完那么返回true。 代码&#xff1a; #include <bits/stdc.h>using namespace std;typedef long long ll; typedef pair<int,int> pii;…