pytorch深度学习实践(二):梯度下降算法详解和代码实现(梯度下降、随机梯度下降、小批量梯度下降的对比)

news2024/11/18 11:17:29

目录

  • 一、梯度下降
    • 1.1 公式与原理
      • 1.1.1 cost(w)
      • 1.1.2 梯度
      • 1.1.3 w的更新
    • 1.2 训练过程可视化
    • 1.3 代码实现
  • 二、随机梯度下降(stochastic gradient descent,SDG)
    • 2.1 公式与原理
      • 2.1.1 w的更新
    • 2.2 代码实现
    • 2.3 梯度下降和随机梯度下降的优缺点对比
      • 2.3.1 梯度下降算法(Batch Gradient Descent)
      • 2.3.2 随机梯度下降算法(Stochastic Gradient Descent)
  • 三、小批量梯度下降(Mini-batch Gradient Descent)
    • 3.1 优势
    • 3.2缺点
    • 3.3 代码实现
  • 总结

一、梯度下降

1.1 公式与原理

1.1.1 cost(w)

cost为数据集中所有样本的误差值平方再求均值。

在这里插入图片描述

1.1.2 梯度

计算梯度时为所有样本的梯度。一个样本的梯度为: g r a d i = 2 ∗ x i ∗ ( x i ∗ w i − y i ) grad_i = 2*x_i*(x_i*w_i-y_i) gradi=2xi(xiwiyi),所有样本的梯度为所有样本的 g r a d i grad_i gradi的和求平均。
在这里插入图片描述

1.1.3 w的更新

一个epoch中:w会等到中所有的x和y都计算完平均值之后再更新。

1.2 训练过程可视化

一般正常的训练过程中cost function都是一直在波动中下降的,如果出现了cost先下降到最小然后又上升的情况(抛物线),则说明训练失败,一般的原因是因为学习率设置过大。

在这里插入图片描述

1.3 代码实现

import numpy as np
import matplotlib.pyplot as plt

x_data = [1.0,2.0,3.0]
y_data = [2.0,4.0,6.0]
w=1.0

def forward(x):
    return x*w

def cost(xs,ys):
    cost = 0
    for x,y in zip(xs,ys):
        y_pred = forward(x)
        cost += (y_pred - y) ** 2
    return cost/len(xs)

def gradient(xs,ys):
    grad = 0
    for x,y in zip(xs,ys):
        grad+= 2*x*(x*w-y)
    return grad/len(xs)

w_list = []
cost_list = []
w_list.append(0.1)
for epoch in range(101):
    cost_val = cost(x_data,y_data)
    grad_val = gradient(x_data,y_data)
    w-=0.01*grad_val
    w_list.append(w)
    cost_list.append(cost_val)
    print('Epoch:',epoch,'w=',w,'loss',cost_val)

plt.plot(range(101),cost_list)
plt.xlabel('epoch')
plt.ylabel('cost')
plt.show()

在这里插入图片描述
在这里插入图片描述

二、随机梯度下降(stochastic gradient descent,SDG)

2.1 公式与原理

随机梯度下降:从样本中随机抽出一组x和y,训练后按梯度更新一次,然后再抽取一组,再更新一次。

在这里插入图片描述

2.1.1 w的更新

计算一次 x i x_i xi y i y_i yi的梯度就进行一次参数更新。

一个epoch中:要进行样本个数次的参数更新

2.2 代码实现

import numpy as np
import matplotlib.pyplot as plt

x_data = [1.0,2.0,3.0]
y_data = [2.0,4.0,6.0]
w=1.0

def forward(x):
    return x*w

def loss(x,y):
    return (forward(x)-y)**2

def grad(x,y):
    return 2 * x * (x * w - y)


loss_list = []
for epoch in range(101):
    for x,y in zip(x_data,y_data):
        w -= 0.01*grad(x,y)
        l = loss(x,y)
    loss_list.append(l)
    print("epoch=",epoch,"w=",w,"loss=",loss)

plt.plot(range(101),loss_list)
plt.xlabel('epoch')
plt.ylabel('loss')
plt.show()

在这里插入图片描述
在这里插入图片描述

2.3 梯度下降和随机梯度下降的优缺点对比

2.3.1 梯度下降算法(Batch Gradient Descent)

优点:

收敛性较好: 梯度下降在每次迭代中使用整个训练集计算梯度,通常能够更快地收敛到较好的解;
稳定性高: 由于使用整个训练集计算梯度,梯度下降的更新方向相对稳定,能够更稳定地接近最优解;
可并行化: 由于每次迭代使用整个训练集,梯度下降可以更容易地进行并行化计算,加快训练速度。

缺点:

内存消耗大: 梯度下降需要在内存中保存整个训练集,对于大规模数据集来说,内存消耗较大;
计算代价高: 每次迭代都需要计算整个训练集的梯度,对于大规模数据集和复杂模型,计算代价较高;
容易陷入局部最优解:梯度下降可能会陷入局部最优解,特别是在非凸优化问题中。

2.3.2 随机梯度下降算法(Stochastic Gradient Descent)

优点:

计算代价低: 随机梯度下降每次迭代只使用一个样本计算梯度,因此计算代价较低;
内存消耗小:由于只需要一个样本,随机梯度下降的内存消耗相对较小;
可适用于在线学习:随机梯度下降适用于在线学习,可以动态地更新模型。

缺点:

收敛性相对较差: 由于梯度的随机性,随机梯度下降的收敛性较梯度下降差,可能会陷入波动或震荡;
不稳定:由于每次迭代只使用一个样本,随机梯度下降的更新方向相对不稳定,可能无法稳定地接近最优解;
学习率选择困难: 由于样本的随机性,随机梯度下降的学习率选择较为困难,需要进行合适的学习率调度。

三、小批量梯度下降(Mini-batch Gradient Descent)

结合BGD和SGD的优点,每一个epoch中取batchsize个样本进行梯度的更新。在每次迭代中随机均匀采样多个样本来组成一个小批量来计算梯度,一个epoch周期内会进行(样本数目/批量大小)次的参数更新。

3.1 优势

小批量梯度下降(Mini-batch Gradient Descent)是梯度下降和随机梯度下降的一种折衷方案,它同时具有一些梯度下降和随机梯度下降的优势,主要包括以下几点优势:

  1. 较低的方差:相比于随机梯度下降,小批量梯度下降使用一小批样本来计算梯度,因此梯度估计的方差较低。 这使得小批量梯度下降相对更稳定,收敛性更好,并且可以更快地接近最优解。

  2. 较高的计算效率:相比于梯度下降,小批量梯度下降每次迭代只使用一小批样本计算梯度,因此计算代价较低。这使得小批量梯度下降在处理大规模数据集时更具优势,能够更快地完成一轮迭代。

  3. 更好的泛化性能:由于小批量梯度下降使用了一小批样本的信息,在每次迭代中能够更好地反映训练集的整体特点。这使得小批量梯度下降相对于随机梯度下降在一定程度上具有更好的泛化性能,可以得到更好的模型。

  4. 并行化能力:小批量梯度下降的计算可以进行一定程度的并行化处理。由于每次迭代使用了一小批样本,可以将这些样本分配给不同的计算单元进行计算,从而提高训练速度。

3.2缺点

与梯度下降相比,由于每次迭代只使用了一小批样本,可能会引入一些噪声,导致更新方向相对不稳定。

小批量梯度下降需要选择合适的批大小,过小的批大小可能导致收敛速度变慢,而过大的批大小可能会增加计算代价和内存消耗。

3.3 代码实现

import numpy as np
import matplotlib.pyplot as plt

x_data = [1.0,2.0,3.0]
y_data = [2.0,4.0,6.0]
w=1.0
n=2
x_data_n = x_data[0:2]
y_data_n = y_data[0:2]

def forward(x):
    return x*w

def loss(x1,y1,n):
    loss = 0
    for x,y in zip(x1,y1):
        loss += (forward(x)-y)**2
    return loss/n

def grad(x1,y1,n):
    grad = 0
    for x, y in zip(x1, y1):
        grad += 2*x*(x*w-y)
    return grad/n


loss_list = []

for epoch in range(101):
    w-=0.01*grad(x_data_n,y_data_n,n)
    loss_list.append(loss(x_data_n,y_data_n,n))
    print("epoch=",epoch,"w=",w,"loss=",loss)

plt.plot(range(101),loss_list)
plt.xlabel('epoch')
plt.ylabel('loss')
plt.show()

在这里插入图片描述
在这里插入图片描述

总结

现在多使用小批量随机梯度下降算法来进行梯度的更新。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1139723.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

漏洞复现-jquery-picture-cut 任意文件上传_(CVE-2018-9208)

jquery-picture-cut 任意文件上传_(CVE-2018-9208) 漏洞信息 jQuery Picture Cut v1.1以下版本中存在安全漏洞CVE-2018-9208文件上传漏洞 描述 ​ picture cut是一个jquery插件,以友好和简单的方式处理图像,具有基于bootstrap…

Vue3-小兔鲜项目

1.初始化项目 npm init vuelatest src目录调整 Git项目管理 基于create-vue创建出来的项目默认没有初始化git仓库,需要我们手动初始化 执行命令并完成首次提交 1.git init 2.git add 3.git commit -m "init" 别名路径联想提示 什么是别名路径联想…

通过requests库使用HTTP编写的爬虫程序

使用Python的requests库可以方便地编写HTTP爬虫程序。以下是一个使用requests库的示例: import requests# 发送HTTP GET请求 response requests.get("http://example.com")# 检查响应状态码 if response.status_code 200:# 获取响应内容html response.…

推荐5款助你高效工作的小软件

现在,有很多实用的工具和软件可以帮助我们更高效地完成各种任务。以下是5款值得推荐的工具软件,能够极大地提高我们的工作效率。 1.电子书阅读器——Koodo Reader ​ Koodo Reader 是一款开源免费的电子书阅读器,支持多达15种主流电子书格式…

laravel+vue2 element 一套项目级医院手术麻醉信息系统源码

手术麻醉临床信息系统源码,PHPmysqllaravelvue2 手术麻醉临床信息系统,采用计算机和通信技术,实现监护仪、麻醉机、输液泵等设备输出数据的自动采集,采集的数据能够如实准确地反映患者生命体征参数的变化,并实现信息高…

搜维尔科技:【应用】配备MTi-3的轻便型ROV,在水下进行地理标记视觉检测

部署潜水员进行水下摄像,不仅难度高而且费用昂贵,需要受过潜水和摄像两方面培训的专业人员来进行。但有些水下作业任务例如拍摄海底管道内部的照片,由于人员无法进入或危险度高的原因,无法由潜水员完成。 如今,俄罗…

看谷歌浏览器源码,为什么p标签和div标签为块元素

看谷歌浏览器源码 谷歌源码路径:third_party/blink/renderer/core/html/resources/html.css 为什么块级元素独占一行? 是谷歌浏览器设置div的默认样式 display:block 它才独占一行 p标签和div标签为块元素 strong,b,i,em等等标签为行内元素

如何在Excel中实现三联类模板?

本文由葡萄城技术团队原创并首发。转载请注明出处:葡萄城官网,葡萄城为开发者提供专业的开发工具、解决方案和服务,赋能开发者。 前言 在一些报表打印应用场景中,会有类似于如下图所示的排版格式: 一般情况下将这种类…

k8s statefulSet 学习笔记

缩写: sts 通过 kubectl api-resources 可以查到: NAMESHORTNAMESAPIVERSIONNAMESPACEDKINDstatefulsetsstsapps/v1trueStatefulSet web-sts.yaml apiVersion: v1 kind: Service metadata:name: nginxlabels:app: nginx spec:ports:- port: 80name: web-sts-svc…

22年上半年下午题

第一大题题目 第一大题解答 第一小问 看加工交互和说明来得出实体的名字。如果不太确定,可以多去看几条数据流来确认答案。仔细一点,这分稳啦。 第二小问 需要对应加工结合说明得出数据存储的名称。 一般可以在后面加上表字或者加上信息表。自拟&…

2023年Q3企业邮箱安全性报告:境内钓鱼邮件超过境外攻击

10月25日,Coremail邮件安全联合北京中睿天下信息技术有限公司发布《2023年第三季度企业邮箱安全性研究报告》。2023年第三季度企业邮箱安全呈现出何种态势?作为邮箱管理员,我们又该如何做好防护? 以下为精华版阅读,如需…

u盘资料不小心删掉怎么找回来?一文教会你恢复方法

案例描述:“平时我都是使用U盘来存储和传输公司重要的资料。昨天,不小心将一个文件夹整个删除,里面包含了我准备好几个月的工作成果和重要的项目资料。怎么办!!!救救我的宝贝资料吧!” 在日常生…

简述低功耗语音芯片的含义与特点

低功耗语音芯片是一种功耗较低的集成电路,其集成了语音处理、控制逻辑等多个功能。相比传统的语音芯片,低功耗语音芯片能够在功耗较低的情况下完成更多的功能,因此非常适合移动设备和可穿戴设备等对功耗要求较高的场景。 低功耗语音芯片的主要…

字符串中的assert和stract

assert:函数原型是:void assert( int expression );其作用是现计算表达式 expression ,如果其值为假(即为0),那么它先 stderr 打印一条出信息,然后通过调用 abort 来终止程序运行。使用assert 的缺点是,频繁的调用会影…

AI口语APP的实现方法

开发AI口语应用程序涉及多个技术领域,包括语音识别、自然语言处理、机器学习和应用程序开发。下面是开发AI口语应用程序的一般步骤和实现方法,希望对大家有所帮助。北京木奇移动技术有限公司,专业的软件外包开发公司,欢迎交流合作…

揭秘!新手主播如何快速出圈,看拓世法宝分分钟打造百万直播间

互联网技术的不断进步催生了信息传播方式的革新,直播作为其重要产物之一,已成为人们获取信息和娱乐消遣的重要途径。尤其在当前信息爆炸的时代背景下,直播以其即时性和互动性满足了人们追求实时资讯和娱乐的需求。这种蓬勃发展的直播行业也在…

openEuler 22.03 LTS 环境使用 Docker Compose 一键部署 JumpServer (all-in-one 模式)

环境回顾 上一篇文章中,我们讲解了 openEuler 22.03 LTS 安装 Docker CE 和 Dcoker Compose,部署的软件环境版本分别如下: OS 系统:openEuler 22.03 LTS(openEuler-22.03-LTS-x86_64-dvd.iso)Docker Engine:Docker C…

采购供应链可见性的详细介绍(数智化采购供应链系统)

信息来源:专业的数智化采购供应链系统整体解决方案提供商郑州信源分享! 有这样一句话:“让看得见全局的人做决策。” 那么如何才能“看见”,并且可以看到“全局”呢? 答案就是采购供应链的可见性。 采购供应链可见…

【Ansible自动化运维工具 1】Ansible常用模块详解(附各模块应用实例和Ansible环境安装部署)

Ansible常用模块 一、Ansible1.1 简介1.2 工作原理1.3 Ansible的特性1.3.1 特性一:Agentless,即无Agent的存在1.3.2 特性二:幂等性 1.4 Ansible的基本组件 二、Ansible环境安装部署2.1 安装ansible2.2 查看基本信息2.3 配置远程主机清单 三、…

GIT在window是 配置SSHKEY

1、打开你得命令行工具,输入: cd ~/.ssh2、生成密钥 #设置自己的邮箱,随意设置 $ ssh-keygen -t rsa -C "wqzbxh163.com"#输入保存密钥的文件名字 Enter file in which to save the key (/c/Users/dahai/.ssh/id_rsa): wqzbxh剩下…