【深度学习】机器学习概述(二)优化算法之梯度下降法(批量BGD、随机SGD、小批量)

news2024/11/17 22:35:07


文章目录

  • 一、基本概念
  • 二、机器学习的三要素
    • 1. 模型
      • a. 线性模型
      • b. 非线性模型
    • 2. 学习准则
      • a. 损失函数
      • b. 风险最小化准则
    • 3. 优化
      • 机器学习问题转化成为一个最优化问题
      • a. 参数与超参数
      • b. 梯度下降法
        • 梯度下降法的迭代公式
        • 具体的参数更新公式
        • 学习率的选择
      • c. 随机梯度下降
        • 批量梯度下降法 (BGD)
        • 随机梯度下降法 (SGD)
        • 小批量梯度下降法 (Mini-batch Gradient Descent)
        • SGD 的优势
        • SGD 的挑战

一、基本概念

  机器学习:通过算法使得机器能从大量数据中学习规律从而对新的样本做决策
  机器学习是从有限的观测数据中学习(或“猜测”)出具有一般性的规律,并可以将总结出来的规律推广应用到未观测样本上。
在这里插入图片描述

二、机器学习的三要素

  机器学习方法可以粗略地分为三个基本要素:模型、学习准则、优化算法

1. 模型

a. 线性模型

f ( x ; θ ) = w T x + b f(\mathbf{x}; \boldsymbol{\theta}) = \mathbf{w}^T \mathbf{x} + b f(x;θ)=wTx+b

b. 非线性模型

  广义的非线性模型可以写为多个非线性基函数 ϕ ( x ) \boldsymbol{\phi}(\mathbf{x}) ϕ(x) 的线性组合: f ( x ; θ ) = w T ϕ ( x ) + b f(\mathbf{x}; \boldsymbol{\theta}) = \mathbf{w}^T \boldsymbol{\phi}(\mathbf{x}) + b f(x;θ)=wTϕ(x)+b其中, ϕ ( x ) = [ ϕ 1 ( x ) , ϕ 2 ( x ) , … , ϕ K ( x ) ] T \boldsymbol{\phi}(\mathbf{x}) = [\phi_1(\mathbf{x}), \phi_2(\mathbf{x}), \ldots, \phi_K(\mathbf{x})]^T ϕ(x)=[ϕ1(x),ϕ2(x),,ϕK(x)]T 是由 K K K 个非线性基函数组成的向量,参数 θ \boldsymbol{\theta} θ 包含了权重向量 w \mathbf{w} w 和偏置 b b b
  如果 ϕ ( x ) \boldsymbol{\phi}(\mathbf{x}) ϕ(x) 本身是可学习的基函数,例如:

ϕ k ( x ) = h ( w k T ϕ ′ ( x ) + b k ) \phi_k(\mathbf{x}) = h(\mathbf{w}_k^T \boldsymbol{\phi}'(\mathbf{x}) + b_k) ϕk(x)=h(wkTϕ(x)+bk)其中, h ( ⋅ ) h(\cdot) h() 是非线性函数, ϕ ′ ( x ) \boldsymbol{\phi}'(\mathbf{x}) ϕ(x) 是另一组基函数, w k \mathbf{w}_k wk b k b_k bk 是可学习的参数,那么模型 f ( x ; θ ) f(\mathbf{x}; \boldsymbol{\theta}) f(x;θ) 就等价于神经网络模型。

2. 学习准则

a. 损失函数

b. 风险最小化准则

【深度学习】机器学习概述(一)机器学习三要素——模型、学习准则、优化算法

3. 优化

机器学习问题转化成为一个最优化问题

  一旦确定了训练集 D \mathcal{D} D、假设空间 F \mathcal{F} F 以及学习准则,接下来的任务就是通过优化算法找到最优的模型 f ( x , θ ∗ ) f(\mathbf{x}, \boldsymbol{\theta}^*) f(x,θ)。机器学习的训练过程本质上是最优化问题的求解过程。

a. 参数与超参数

  优化可以分为参数优化和超参数优化两个方面:

  1. 参数优化: ( x ; θ ) (\mathbf{x}; \boldsymbol{\theta}) (x;θ) 中的 θ \boldsymbol{\theta} θ 称为模型的参数,这些参数通过优化算法进行学习。这些参数可以通过梯度下降等算法迭代地更新,以使损失函数最小化。

  2. 超参数优化: 除了可学习的参数 θ \boldsymbol{\theta} θ 外,还有一类参数用于定义模型结构或优化策略,这些参数被称为超参数。例如,聚类算法中的类别个数、梯度下降法中的学习率、正则化项的系数、神经网络的层数、支持向量机中的核函数等都是超参数。与可学习的参数不同,超参数的选取通常是一个组合优化问题,很难通过优化算法自动学习。通常,超参数的设定是基于经验或者通过搜索的方法对一组超参数组合进行不断试错调整。

b. 梯度下降法

  在机器学习中,最简单而常用的优化算法之一是梯度下降法。梯度下降法用于最小化一个函数,通常是损失函数或者风险函数。这个函数关于模型参数(权重)的梯度指向了函数值增加最快的方向,梯度下降法利用这一信息来更新参数,使得函数值逐渐减小。

梯度下降法的迭代公式

θ t + 1 = θ t − α ∂ R D ( θ ) ∂ θ \boldsymbol{\theta}_{t+1} = \boldsymbol{\theta}_t - \alpha \frac{\partial \mathcal{R}_{\mathcal{D}}(\boldsymbol{\theta})}{\partial \boldsymbol{\theta}} θt+1=θtαθRD(θ)

其中:

  • θ t \boldsymbol{\theta}_t θt 是第 (t) 次迭代时的参数值。
  • α \alpha α 是学习率,控制参数更新的步长。
  • R D ( θ ) \mathcal{R}_{\mathcal{D}}(\boldsymbol{\theta}) RD(θ) 是风险函数,也可以是损失函数,表示在训练集 (\mathcal{D}) 上的性能。

梯度下降法的目标是通过迭代调整参数,使得风险函数最小化。

具体的参数更新公式

参数更新公式可以具体化为:

θ t + 1 = θ t − α 1 N ∑ n = 1 N ∂ L ( y ( n ) , f ( x ( n ) ; θ ) ) ∂ θ \boldsymbol{\theta}_{t+1} = \boldsymbol{\theta}_t - \alpha \frac{1}{N} \sum_{n=1}^{N} \frac{\partial \mathcal{L}(y^{(n)}, f(\mathbf{x}^{(n)}; \boldsymbol{\theta}))}{\partial \boldsymbol{\theta}} θt+1=θtαN1n=1NθL(y(n),f(x(n);θ))

其中:

  • N N N 是训练集中样本的数量。
  • L ( y ( n ) , f ( x ( n ) ; θ ) ) \mathcal{L}(y^{(n)}, f(\mathbf{x}^{(n)}; \boldsymbol{\theta})) L(y(n),f(x(n);θ)) 是损失函数,表示模型对样本 n n n 的预测误差。
学习率的选择

  学习率 α \alpha α 是一个关键的超参数,影响着参数更新的步长。选择合适的学习率很重要,过小的学习率可能导致收敛速度过慢,而过大的学习率可能导致参数在优化过程中发散。

  梯度下降法的一种改进是使用自适应学习率的变体,如 Adagrad、RMSprop 和 Adam 等。这些算法能够根据参数的历史梯度自动调整学习率,从而更灵活地适应不同参数的更新需求。

c. 随机梯度下降

在这里插入图片描述

批量梯度下降法 (BGD)

  在批量梯度下降法中,每一次迭代都要计算整个训练集上的梯度,然后更新模型参数,这导致了在大规模数据集上的高计算成本和内存要求。其迭代更新规则如下:

θ t + 1 = θ t − α ∇ R D ( θ t ) \theta_{t+1} = \theta_t - \alpha \nabla \mathcal{R}_{\mathcal{D}}(\theta_t) θt+1=θtαRD(θt)

其中, α \alpha α 是学习率, ∇ R D ( θ t ) \nabla \mathcal{R}_{\mathcal{D}}(\theta_t) RD(θt) 是整个训练集上损失函数关于参数 θ t \theta_t θt 的梯度。

随机梯度下降法 (SGD)

  随机梯度下降法通过在每次迭代中仅使用一个样本来估计梯度,从而减小了计算成本。其迭代更新规则如下:

θ t + 1 = θ t − α ∇ L ( θ t , x i , y i ) \theta_{t+1} = \theta_t - \alpha \nabla \mathcal{L}(\theta_t, \mathbf{x}_i, y_i) θt+1=θtαL(θt,xi,yi)

其中, ∇ L ( θ t , x i , y i ) \nabla \mathcal{L}(\theta_t, \mathbf{x}_i, y_i) L(θt,xi,yi) 是单个样本 ( x i , y i ) (\mathbf{x}_i, y_i) (xi,yi) 上的损失函数关于参数 θ t \theta_t θt 的梯度。

小批量梯度下降法 (Mini-batch Gradient Descent)

  为了权衡计算成本和梯度估计的准确性,通常使用小批量梯度下降法。该方法在每次迭代中使用一个小批量(mini-batch)样本来估计梯度,从而兼具计算效率和梯度准确性。

θ t + 1 = θ t − α ∇ L batch ( θ t , Batch ) \theta_{t+1} = \theta_t - \alpha \nabla \mathcal{L}_{\text{batch}}(\theta_t, \text{Batch}) θt+1=θtαLbatch(θt,Batch)

其中, ∇ L batch ( θ t , Batch ) \nabla \mathcal{L}_{\text{batch}}(\theta_t, \text{Batch}) Lbatch(θt,Batch) 是小批量样本集 Batch \text{Batch} Batch 上的损失函数关于参数 θ t \theta_t θt 的梯度。

SGD 的优势
  1. 计算效率: 相对于批量梯度下降法,SGD的计算成本更低,尤其在大规模数据集上更为实用。

  2. 在线学习: SGD具有在线学习的性质,每次迭代只需一个样本,使得模型可以逐步适应新数据。

  3. 跳出局部极小值: 由于每次迭代使用的样本不同,SGD有助于跳出局部极小值,从而更有可能找到全局最优解。

SGD 的挑战
  1. 不稳定性: SGD中每次迭代的更新可能受到单个样本的影响,导致更新方向波动较大。

  2. 学习率调整: 选择合适的学习率对于SGD的性能至关重要。学习率过大可能导致不稳定性,而学习率过小可能使模型收敛缓慢。

  3. 需调参: SGD的性能依赖于学习率、小批量大小等超参数的选择,需要进行调参。

在实践中,通常会使用学习率衰减、动量法等技术来改进SGD的性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1313929.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

[Linux] Tomcat部署和优化

一、Tomcat相关知识 1.1 Tomcat的简介 Tomcat 是 Java 语言开发的,Tomcat 服务器是一个免费的开放源代码的 Web 应用服务器,是 Apache 软件基金会的 Jakarta 项目中的一个核心项目,由 Apache、Sun 和其他一些公司及个人共同开发而成。 …

环境搭建及源码运行_java环境搭建_maven

1、介绍 1)管理项目依赖和版本 统一的项目依赖和版本管理 2)Maven支持多模块项目管理 通过定义父子模块的关系来管理多个子模块的构建和依赖关系。使用Maven可以实现多模块项目的统一管理和构建,从而提高项目的可维护性和可重用性。 3&#x…

【️Zookeeper是CP还是AP的?】

😊引言 🎖️本篇博文约3000字,阅读大约10分钟,亲爱的读者,如果本博文对您有帮助,欢迎点赞关注!😊😊😊 🖥️Zookeeper是CP还是AP的? ✅…

Docker及其使用思维导图

Docker的架构 构建分发运行镜像 Client(客户端):是Docker的用户端,可以接受用户命令和配置标识,并与Docker daemon通信。Images(镜像):是一个只读模板,含创建Docker容器…

Sui第八轮资助:七个项目获得资助

今天,Sui基金会宣布本月的资助获得者,他们因构建项目以推动Sui的采用和发展而获得资助。要获得资助,项目必须提交提案,详细说明他们正在构建的内容、预算明细、关键里程碑、团队经验以及对Sui社区的预期贡献。 以下七个项目致力于…

ACT、NAT、NATPT和EASY-IP

目录 一、ACL 1.ACL 2.ACL的两种应用匹配机制 3.ACL的基本类型 4.ACL命令操作 5.ACL实验: 4.ACL的应用原则: 5.匹配原则: 二、NAT 1.NAT的原理及作用: 2.NAT分类 3.NAT配置 三、EASY-ip实验 四、NATPT 五、通配符 …

解决el-table组件中,分页后数据的勾选、回显问题?

问题描述: 1、记录一个弹窗点击确定按钮后,table列表所有勾选的数据信息2、再次打开弹窗,回显勾选所有保存的数据信息3、遇到的bug:切换分页,其他页面勾选的数据丢失;点击确认只保存当前页的数据&#xff1…

spring boot集成mybatis和springsecurity实现登录认证功能

参考了很多网上优秀的教程,结合自己的理解,实现了登录认证功能,不打算把理论搬过来,直接上代码可能入门更快,文中说明都是基于我自己的理解写的,可能存在表述或者解释不对的情况,如果需要理论支…

linux内核使用ppm图片开机

什么是ppm图片 PPM(Portable Pixmap)是一种用于存储图像的文件格式。PPM图像文件以二进制或ASCII文本形式存储,并且是一种简单的、可移植的图像格式。PPM格式最初由Jef Poskanzer于1986年创建,并经过了多次扩展和修改。 PPM图像…

金蝶云星空协同开发环境应用内执行单据类型脚本

文章目录 金蝶云星空协同开发环境应用内执行单据类型脚本业务界面查询单据类型表数据导出数据执行数据库脚本单据类型xml检验是否执行成功检查数据库检查业务数据 金蝶云星空协同开发环境应用内执行单据类型脚本 业务界面 查询单据类型表数据 先使用类型中文在单据类型多语言…

Windows10之wsl-Linux子系统安装JDK、Maven环境

Windows10之wsl-Linux子系统安装JDK、Maven环境 文章目录 1.环境2.安装2.1安装JDK2.1安装maven 3.配置setting.xml4.下载编译项目插件5.总结 1.环境 首先需要在windwos10上安装wsl的Linux子系统,我选择的是CentOs的操作系统的镜像(之前的文章中采用的是docker拉取一…

嵌入式开发人员需要具备哪些能力?

大家好,今天给大家介绍嵌入式开发人员需要具备哪些能力,文章末尾附有分享大家一个资料包,差不多150多G。里面学习内容、面经、项目都比较新也比较全!可进群免费领取。 嵌入式开发人员需要具备以下能力: 熟练掌握C/C语…

腾讯云Linux云服务器禁Ping设置

腾讯云Linux服务器默认是允许ping包的,但是在一些情况下为了安全考虑起见,我们都会把服务器设置为禁ping的模式。 1、首先检查Linux服务器当前是否禁ping 执行命令: cat /proc/sys/net/ipv4/icmp_echo_ignore_all 备注: 0----代…

Android画布Canvas绘图scale translate,Kotlin

Android画布Canvas绘图scale & translate&#xff0c;Kotlin <?xml version"1.0" encoding"utf-8"?> <androidx.appcompat.widget.LinearLayoutCompat xmlns:android"http://schemas.android.com/apk/res/android"xmlns:app"…

maven jar sort

1&#xff09;往常项目结构lib包排序 2&#xff09;maven的默认是没有排序的

2019年第八届数学建模国际赛小美赛A题放射性产生的热量解题全过程文档及程序

2019年第八届数学建模国际赛小美赛 A题 放射性产生的热量 原题再现&#xff1a; 假设我们把一块半衰期很长的放射性物质做成一个特定的形状。在这种材料中&#xff0c;原子核在衰变时会以随机的方向释放质子。我们假设携带质子的能量是一个常数。质子在穿过致密物质时&#x…

【STM32】STM32学习笔记-OLED调试工具(09)

00. 目录 文章目录 00. 目录01. STM32调试方式02. OLED简介03. 0.96寸OLED模块04. 0.96寸OLED驱动IC05. 0.96寸OLED原理图06. 硬件电路07. OLED驱动函数08. 附录 01. STM32调试方式 串口调试&#xff1a;通过串口通信&#xff0c;将调试信息发送到电脑端&#xff0c;电脑使用串…

Tekton 克隆 git 仓库

Tekton 克隆 git仓库 介绍如何使用 Tektonhub 官方 git-clone task 克隆 github 上的源码到本地。 git-clone task yaml文件下载地址&#xff1a;https://hub.tekton.dev/tekton/task/git-clone 查看git-clone task yaml内容&#xff1a; 点击Install&#xff0c;选择一种…

微服务组件Sentinel的学习(3)

Sentinel 隔离和降级Feign整合Sentinel线程隔离熔断降级熔断策略 授权规则&#xff1a;自定义异常 隔离和降级 虽然限流可以尽量避免因高并发而引起的服务故障&#xff0c;但服务还会因为其它原因而故障。而要将这些故障控制在一定范用避免雪崩&#xff0c;就要靠线程隔离(舱壁…

jmeter,读取CSV文件数据的循环控制

1、构造csv数据 保存文件时需要注意文件的编码格式 id,name,limit,status,address,start_time 100,小米100,1000,1,某某会展中心101,2023/8/20 14:20 101,小米101,1001,1,某某会展中心102,2023/8/21 14:20 2、在线程组下添加【CSV数据文件设置】元件 3、CSV文件数据的循环控…