《PyTorch深度学习实践》第三讲 梯度下降

news2024/11/16 11:31:13

b站刘二大人《PyTorch深度学习实践》课程第三讲梯度下降笔记与代码:https://www.bilibili.com/video/BV1Y7411d7Ys?p=3&vd_source=b17f113d28933824d753a0915d5e3a90


image-20230629112641254

上一讲例子中,初始权重 w w w是随机给的,然后计算每个样本 x x x的预测值 y ^ \hat{y} y^与真实值 y y y的误差平方,再算整个训练集的均方根误差,选择最小的均方根误差对应的权重值

  • 上一讲中采用的是穷举法来确定权重值,即先确定权重值的一个大概范围,然后再里面进行采样,计算每个权重值 w w w的误差,最后选择误差最小的那个权重值

    image-20230629113300139

存在的问题

  • 在实际的学习任务中,损失函数 c o s t ( w ) cost(w) cost(w)并不是图中这种理想曲线
    • 一维的时候还可以使用线性搜索,但如果有两个权重 y ^ ( w 1 , w 2 , x ) \hat{y}(w_1,w_2,x) y^(w1,w2,x),那么就是在一个平面中进行搜索,一维如果是搜索100次,那么在二维的时候就是100的平方,权重再多点那么搜索量将会更大

分治法

  • 先进行稀疏的搜索,认为结果在值较小的区域,然后再在值较小的区域内进行稀疏的搜索,以此往复……

  • 问题:容易陷入局部最优;高维度无法搜索

    image-20230629114338708

将使目标函数最小的问题定义为优化问题

image-20230629114832863

Gradient Descent,梯度下降法

  • 有一个初始猜测值,但需要确定搜索方向?

    image-20230629115216058
  • 采用梯度确定搜索方向

    • 对代价函数 c o s t cost cost求关于权重 w w w的导数,可以得到上升方向

      • Δ x \Delta x Δx大于0,如果导数大于0,意味着 x x x加上 Δ x \Delta x Δx后函数变大了,即往正方向是上升的;如果导数小于0,意味着随着 x x x的增加函数在减小,即往负方向是上升的

        image-20230629120007928
    • 往梯度的负方向搜索(下降最快的方向),可以得到最小值

      • α \alpha α是学习率,即搜索步长
      image-20230629120229271
    • 梯度下降法只能保证找到局部最优点,不一定能找到全局最优

      • 实际深度学习问题中的损失函数并没有很多的局部最优点,不容易陷入局部最优点
      image-20230629120510278
      • 鞍点:梯度等于0(马鞍面)

        image-20230629121049402

接下去就是反复求梯度,往梯度负方向搜索

梯度计算看原视频20:17处,老师讲的很详细

image-20230629121444719 image-20230629121645181

代码实现:

import numpy as np
import matplotlib.pyplot as plt

# 训练集
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

# 初始权重的猜测值
w = 1.0

# 存储训练轮数以及对应的loss值用于绘图
epoch_list = []
cost_list = []


def forward(x):
    # 定义模型:y_hat = x * w
    return x * w


def cost(xs, ys):
    # 定义代价函数cost(w)。xs就是x_data, ys就是y_data
    cost = 0
    for x, y in zip(xs, ys):
        y_pred = forward(x)         # y_hat
        cost += (y_pred - y) ** 2   # (y_hat - y)^2,然后再累加
    return cost / len(xs)           # 累加后再除样本数量N,即MSE公式中的1/N


def gradient(xs, ys):
    # 计算梯度,即对cost(w)求关于w的偏导
    grad = 0
    for x, y in zip(xs, ys):
        grad += 2 * x * (x * w - y)  # 累加部分
    return grad / len(xs)            # 除样本数量N


print('Predict (before training)', 4, forward(4))

# 训练过程,迭代100次(100轮训练)
# 每次都是将权重w减去学习率乘以梯度
for epoch in range(100):
    cost_val = cost(x_data, y_data)         # 当前步的损失值
    grad_val = gradient(x_data, y_data)     # 当前的梯度
    w -= 0.01 * grad_val                    # 更新权重w,0.01是学习率
    print('Epoch: ', epoch, 'w = ', w, 'loss = ', cost_val)  # 打印每一轮训练的日志

    epoch_list.append(epoch)
    cost_list.append(cost_val)

print('Predict (after training)', 4, forward(4))


# loss曲线绘制,x轴是epoch,y轴是loss值
plt.plot(epoch_list, cost_list)
plt.ylabel('cost')
plt.xlabel('epoch')
plt.show()
image-20230629124613176 image-20230629124552846

实际中的训练曲线一般存在波动,但总体上是收敛的

  • 绘图的时候一般会采用指数加权均值来对曲线做平滑处理,以便观察训练趋势
image-20230629124918957

梯度下降法在实际中应用较少,用的比较多的是它的衍生版本,随机梯度下降

Stochastic Gradient Descent

  • 梯度下降法中是对数据集的损失cost进行梯度更新
  • 提供N个数据,随机梯度下降是从这N个数据中随机选一个,将其损失loss来更新,即单个样本的损失对权重求导
image-20230629125646911

使用随机梯度下降的原因:

  • cost function曲线存在鞍点

  • 每次随机取1个样本,会引入了随机噪声,那么即便陷入鞍点,随机噪声可能会对其进行推动,那么有可能离开鞍点

    image-20230629130117872

代码实现:

import matplotlib.pyplot as plt

# 训练集
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

# 初始权重的猜测值
w = 1.0

# 存储训练轮数以及对应的los值用于绘图
epoch_list = []
cost_list = []


def forward(x):
    # 定义模型:y_hat = x * w
    return x * w


def loss(x, y):
    # 计算loss function
    y_pred = forward(x)         # y_hat
    return (y_pred - y) ** 2    # (y_hat - y)^2


def gradient(x, y):
    # 计算梯度
    return 2 * x * (x * w - y)


print('Predict (before training)', 4, forward(4))

# 训练过程
for epoch in range(100):
    for x, y in zip(x_data, y_data):
        grad = gradient(x, y)           # 对每一个样本求梯度
        w = w - 0.01 * grad             # 用一个样本的梯度来更新权重,而不是所有的
        print("\tgrad: ", x, y, grad)
        l = loss(x, y)

    print("progress: ", epoch, "w = ", w, "loss = ", l)

print('Predict (after training)', 4, forward(4))
image-20230629131155745

在实际问题中,梯度下降中对 f ( x i ) f(x_i) f(xi)求导和对 f ( x i + 1 ) f(x_{i+1}) f(xi+1)求导是没有关联的,相互独立,因此是可以进行并行计算的

但是在随机梯度下降中,对 w w w求导,但是 w w w是要更新的,因此下一步的权重更新与上一步的更新结果之间存在关系,即不可进行并行计算

image-20230629131943529
  • 梯度下降法可以并行计算,时间复杂度低,但学习器的性能差
  • 随机梯度下降无法并行计算,时间复杂度高,但学习器的性能好

采用折中的方法:Batch,批量

批量随机梯度下降

  • 若干个样本一组,每次用这一组样本的梯度进行权重更新
image-20230629132420726

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/700787.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

JavaWeb 笔记-1

JavaWeb 笔记-1 初始JavaWeb什么是JavaWeb 一、JDBC1.1、JDBC简介1.2、API详解-DriverManager1.3、API详解-Connection1.4、API详解-Statement1.5、API详解-ResultSet1.6、API详解-PreparedStatement1.6.1、API详解-PreparedStatement-SQL注入演示1.6.2、API详解-PreparedState…

RK3568平台开发系列讲解(外设篇)RFID 模块调试

🚀返回专栏总目录 文章目录 一、RFID 工作原理二、硬件连接三、驱动程序四、设备树五、测试程序沉淀、分享、成长,让自己和他人都能有所收获!😄 📢本篇我们将讲解 RFID 模块调试。 一、RFID 工作原理 射频识别技术也就是 RFID,英文名为 Radio Frequency Identificati…

玻璃活动隔断安装需要注意什么

随着社会的发展和人们对空间利用的要求不断提高,玻璃活动隔断逐渐成为办公室和商业空间中常见的装修选择。玻璃活动隔断不仅可以有效分割空间,提供私密性,还能保持充足的采光和视觉效果。然而,为了确保玻璃活动隔断的安装质量和使…

PSP - MetaPredict 预测蛋白质序列的内源性无序区域 (Intrinsically Disordered Regions)

欢迎关注我的CSDN:https://spike.blog.csdn.net/ 本文地址:https://blog.csdn.net/caroline_wendy/article/details/131461900 MetaPredict 算法简介: 内源性无序区域(IDRs)在所有生命领域中都普遍存在,并…

【GPT】如何拥有离线版本的GPT以及部署过程中的问题

【背景】 目前很多公司由于数据安全的问题,不能使用OpenAI的GPT,同时也没有必要非得使用如此泛用化的GPT。很多公司因此有训练自己的离线GPT的需求,这样的GPT只需要具备专业知识即可。 要使这个成为可能,首先就需要能够让GPT的Mo…

InstructGPT学习

GPT发展历程 在回答这个问题之前,首先要搞清楚ChatGPT的发展历程。 GPT-1用的是无监督预训练有监督微调。GPT-2用的是纯无监督预训练。GPT-3沿用了GPT-2的纯无监督预训练,但是数据大了好几个量级。InstructGPT在GPT-3上用强化学习做微调,内…

企业邮箱如何将一个用户设置到多个部门/群组

1、使用管理员账号postmaster登录企业邮局,点击“邮局管理”。 2、点击“组织与成员”。 3、勾选需要设置的用户,点击“设置所属部门/群”。(例如:我们需要将所属销售分公司的高阳,加入到以下四个分销部中,…

Hive on Zeppelin

** Hive on Zeppelin ** 官网:zeppelin.apache.org 做大数据的人应该对Hive不陌生,Hive应该是大数据SQL引擎的鼻祖。历经多个版本的改进,现在的Hive3已经具备比较完善的ACID功能,能够同时满足交互式查询和ETL 两种场景。 那怎…

Linux内核的编译、安装、调试

这里写目录标题 编译安装内核下载内核安装依赖更改.config编译内核安装首先安装模块安装内核更改引导更改grub重启 其他操作清理内核源目录卸载安装的内核修改内核配置菜单实现对新加入内核源码的控制 常见问题1. Module.symvers is missing2. No rule to make target ‘debian…

Revit三维视图:第一人称的视角看模型,生成局部三维视图

​  一、Revit中怎么以第一人称的视角看空间效果 我们创建一栋完整的楼模型后,会不会想说假设在里面看看是什么效果呢,就是说想看看第一视角的空间效果,那么如何可以看第一人称的空间效果图呢?以下看步骤: 1、 打开楼层平面图 …

系统架构设计师 6:数据库设计

一、数据库系统 数据库系统(DataBase System, DBS)是一个采用了数据库技术,有组织地、动态地存储大量相关联数据,从而方便多用户访问的计算机系统。广义上讲,DBS包括了数据库管理系统(DBMS)。 …

详细认识二叉树【图片+代码】

目录 一、树的概念及结构 1.1树的概念 1.2树的相关概念 1.3树的表示 1.4树在实际中的应用(目录树) 二、二叉树概念及结构 2.1概念 2.2特殊的二叉树 2.3二叉树的性质 2.4二叉树存储结构 三、二叉树的顺序结构及实现 3.1二叉树的顺序结构 3…

Redis6之集群

集群,就是通过增加服务器的数量,提供相同的服务,从而让服务器达到一个稳定、高效的状态 必要性 单个redis存在不稳定性。当redis服务宕机了,就没有可用的服务了。而且单个redis的读写能力是有限的。使用redis集群可以强化redis的…

PIL.Image 调色板模式处理标签数据

文章目录 1 使用PIL.Image库进行调色板模式2 转回原来的色彩3 效果参考 1 使用PIL.Image库进行调色板模式 基本步骤: 自定义调色板,数据格式是一个Nx3的二维数组,一维数组的位置为分类的下标数据类型为np.uint8转化为调色板模式后img.conve…

想知道音频怎么转文字吗?

随着数字化技术的不断发展,我们生活中产生的各种音频越来越多,例如会议录音、采访录音等等。虽然音频记录信息方便,但它们在信息处理、存储和分享方面也存在问题。比如当我们需要对音频中的内容进行编辑或整理时,手动打字出现漏字…

Eclipse中项目的配置

1、修改本地运行时Tomcat对应的JRE版本 老项目升级JDK,在eclipse修改了项目的jdk、编译等级,但还是启动失败,报“java.lang.UnsupportedClassVersionError”。 观察发现,启动日志,tomcat还是使用的jdk1.5,…

编程题分享:有⼀堆糖果,其数量为n,现将糖果分成不同数量的堆数

背景 近期面试遇到一家公司的编程题,觉得挺有参考价值 此处使用 PHP语言,进行编码测试, 编码之前要进行思路分析,避免无头苍蝇,走一步看一步 最后,希望后期面试顺利!欢迎指摘 . 题目&#xff1…

形态学操作之膨胀

note // 膨胀原理:操作过程中,若膨胀因子某点是1,且原图该点为1,则锚点位置为1 code // 膨胀 // 膨胀原理:操作过程中,若膨胀因子某点是1,且原图该点为1,则锚点位置为1 typedef e…

gma 2 教程(一)概述:1.GMA 简介

地理与气象分析库(Geographic and Meteorological Analysis. gma),是一个基于 Python 的地理、气象数据快速处理分析和地理制图函数包。构建过程参考了ArcGIS和QGIS的操作逻辑和特点,并添加诸多独创性、独有的功能,具有…

QT Creator上位机学习(三)QString及其相关控件介绍

系列文章目录 文章目录 系列文章目录字符串QStringQLableQLineEditQString的常用功能 字符串QString QSting类,用于处理字符串,进行字符串和数字之间的转化 转换函数: //字符串转数字 QString str......; int numstr.toInt(); float num2s…