解锁机器学习-梯度下降:从技术到实战的全面指南

news2024/11/27 0:44:35

目录

  • 一、简介
    • 什么是梯度下降?
    • 为什么梯度下降重要?
  • 二、梯度下降的数学原理
    • 代价函数(Cost Function)
    • 梯度(Gradient)
    • 更新规则
      • 代码示例:基础的梯度下降更新规则
  • 三、批量梯度下降(Batch Gradient Descent)
    • 基础算法
    • 代码示例
  • 四、随机梯度下降(Stochastic Gradient Descent)
    • 基础算法
    • 代码示例
    • 优缺点
  • 五、小批量梯度下降(Mini-batch Gradient Descent)
    • 基础算法
    • 代码示例
    • 优缺点

本文全面深入地探讨了梯度下降及其变体——批量梯度下降、随机梯度下降和小批量梯度下降的原理和应用。通过数学表达式和基于PyTorch的代码示例,本文旨在为读者提供一种直观且实用的视角,以理解这些优化算法的工作原理和应用场景。

关注TechLead,分享AI全维度知识。作者拥有10+年互联网服务架构、AI产品研发经验、团队管理经验,同济本复旦硕,复旦机器人智能实验室成员,阿里云认证的资深架构师,项目管理专业人士,上亿营收AI产品研发负责人。

file

一、简介

梯度下降(Gradient Descent)是一种在机器学习和深度学习中广泛应用的优化算法。该算法的核心思想非常直观:找到一个函数的局部最小值(或最大值)通过不断地沿着该函数的梯度(gradient)方向更新参数。

什么是梯度下降?

简单地说,梯度下降是一个用于找到函数最小值的迭代算法。在机器学习中,这个“函数”通常是损失函数(Loss Function),该函数衡量模型预测与实际标签之间的误差。通过最小化这个损失函数,模型可以“学习”到从输入数据到输出标签之间的映射关系。

为什么梯度下降重要?

  1. 广泛应用:从简单的线性回归到复杂的深度神经网络,梯度下降都发挥着至关重要的作用。

  2. 解决不可解析问题:对于很多复杂的问题,我们往往无法找到解析解(analytical solution),而梯度下降提供了一种有效的数值方法。

  3. 扩展性:梯度下降算法可以很好地适应大规模数据集和高维参数空间。

  4. 灵活性与多样性:梯度下降有多种变体,如批量梯度下降(Batch Gradient Descent)、随机梯度下降(Stochastic Gradient Descent)和小批量梯度下降(Mini-batch Gradient Descent),各自有其优点和适用场景。


二、梯度下降的数学原理

file
在深入研究梯度下降的各种实现之前,了解其数学背景是非常有用的。这有助于更全面地理解算法的工作原理和如何选择合适的算法变体。

代价函数(Cost Function)

在机器学习中,代价函数(也称为损失函数,Loss Function)是一个用于衡量模型预测与实际标签(或目标)之间差异的函数。通常用 ( J(\theta) ) 来表示,其中 ( \theta ) 是模型的参数。

file

梯度(Gradient)

file

更新规则

file

代码示例:基础的梯度下降更新规则

import numpy as np

def gradient_descent_update(theta, grad, alpha):
    """
    Perform a single gradient descent update.
    
    Parameters:
    theta (ndarray): Current parameter values.
    grad (ndarray): Gradient of the cost function at current parameters.
    alpha (float): Learning rate.
    
    Returns:
    ndarray: Updated parameter values.
    """
    return theta - alpha * grad

# Initialize parameters
theta = np.array([1.0, 2.0])
# Hypothetical gradient (for demonstration)
grad = np.array([0.5, 1.0])
# Learning rate
alpha = 0.01

# Perform a single update
theta_new = gradient_descent_update(theta, grad, alpha)
print("Updated theta:", theta_new)

输出:

Updated theta: [0.995 1.99 ]

在接下来的部分,我们将探讨梯度下降的几种不同变体,包括批量梯度下降、随机梯度下降和小批量梯度下降,以及一些高级的优化技巧。通过这些内容,你将能更全面地理解梯度下降的应用和局限性。


三、批量梯度下降(Batch Gradient Descent)

file
批量梯度下降(Batch Gradient Descent)是梯度下降算法的一种基础形式。在这种方法中,我们使用整个数据集来计算梯度,并更新模型参数。

基础算法

批量梯度下降的基础算法可以概括为以下几个步骤:

file

代码示例

下面的Python代码使用PyTorch库演示了批量梯度下降的基础实现。

import torch

# Hypothetical data (features and labels)
X = torch.tensor([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]], requires_grad=True)
y = torch.tensor([[1.0], [2.0], [3.0]])

# Initialize parameters
theta = torch.tensor([[0.0], [0.0]], requires_grad=True)

# Learning rate
alpha = 0.01

# Number of iterations
n_iter = 1000

# Cost function: Mean Squared Error
def cost_function(X, y, theta):
    m = len(y)
    predictions = X @ theta
    return (1 / (2 * m)) * torch.sum((predictions - y) ** 2)

# Gradient Descent
for i in range(n_iter):
    J = cost_function(X, y, theta)
    J.backward()
    with torch.no_grad():
        theta -= alpha * theta.grad
    theta.grad.zero_()

print("Optimized theta:", theta)

输出:

Optimized theta: tensor([[0.5780],
        [0.7721]], requires_grad=True)

批量梯度下降的主要优点是它的稳定性和准确性,但缺点是当数据集非常大时,计算整体梯度可能非常耗时。接下来的章节中,我们将探索一些用于解决这一问题的变体和优化方法。


四、随机梯度下降(Stochastic Gradient Descent)

file
随机梯度下降(Stochastic Gradient Descent,简称SGD)是梯度下降的一种变体,主要用于解决批量梯度下降在大数据集上的计算瓶颈问题。与批量梯度下降使用整个数据集计算梯度不同,SGD每次只使用一个随机选择的样本来进行梯度计算和参数更新。

基础算法

随机梯度下降的基本步骤如下:

file

代码示例

下面的Python代码使用PyTorch库演示了SGD的基础实现。

import torch
import random

# Hypothetical data (features and labels)
X = torch.tensor([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]], requires_grad=True)
y = torch.tensor([[1.0], [2.0], [3.0]])

# Initialize parameters
theta = torch.tensor([[0.0], [0.0]], requires_grad=True)

# Learning rate
alpha = 0.01

# Number of iterations
n_iter = 1000

# Stochastic Gradient Descent
for i in range(n_iter):
    # Randomly sample a data point
    idx = random.randint(0, len(y) - 1)
    x_i = X[idx]
    y_i = y[idx]

    # Compute cost for the sampled point
    J = (1 / 2) * torch.sum((x_i @ theta - y_i) ** 2)
    
    # Compute gradient
    J.backward()

    # Update parameters
    with torch.no_grad():
        theta -= alpha * theta.grad

    # Reset gradients
    theta.grad.zero_()

print("Optimized theta:", theta)

输出:

Optimized theta: tensor([[0.5931],
        [0.7819]], requires_grad=True)

优缺点

SGD虽然解决了批量梯度下降在大数据集上的计算问题,但因为每次只使用一个样本来更新模型,所以其路径通常比较“嘈杂”或“不稳定”。这既是优点也是缺点:不稳定性可能帮助算法跳出局部最优解,但也可能使得收敛速度减慢。

在接下来的部分,我们将介绍一种折衷方案——小批量梯度下降,它试图结合批量梯度下降和随机梯度下降的优点。


五、小批量梯度下降(Mini-batch Gradient Descent)

file
小批量梯度下降(Mini-batch Gradient Descent)是批量梯度下降和随机梯度下降(SGD)之间的一种折衷方法。在这种方法中,我们不是使用整个数据集,也不是使用单个样本,而是使用一个小批量(mini-batch)的样本来进行梯度的计算和参数更新。

基础算法

小批量梯度下降的基本算法步骤如下:

file

代码示例

下面的Python代码使用PyTorch库演示了小批量梯度下降的基础实现。

import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical data (features and labels)
X = torch.tensor([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]], requires_grad=True)
y = torch.tensor([[1.0], [2.0], [3.0], [4.0]])

# Initialize parameters
theta = torch.tensor([[0.0], [0.0]], requires_grad=True)

# Learning rate and batch size
alpha = 0.01
batch_size = 2

# Prepare DataLoader
dataset = TensorDataset(X, y)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Mini-batch Gradient Descent
for epoch in range(100):
    for X_batch, y_batch in data_loader:
        J = (1 / (2 * batch_size)) * torch.sum((X_batch @ theta - y_batch) ** 2)
        J.backward()
        with torch.no_grad():
            theta -= alpha * theta.grad
        theta.grad.zero_()

print("Optimized theta:", theta)

输出:

Optimized theta: tensor([[0.6101],
        [0.7929]], requires_grad=True)

优缺点

小批量梯度下降结合了批量梯度下降和SGD的优点:它比SGD更稳定,同时比批量梯度下降更快。这种方法广泛应用于深度学习和其他机器学习算法中。

小批量梯度下降不是没有缺点的。选择合适的批量大小可能是一个挑战,而且有时需要通过实验来确定。

关注TechLead,分享AI全维度知识。作者拥有10+年互联网服务架构、AI产品研发经验、团队管理经验,同济本复旦硕,复旦机器人智能实验室成员,阿里云认证的资深架构师,项目管理专业人士,上亿营收AI产品研发负责人。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1089270.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

java io读取数据

1.字节流读取数据 2.字节流读取数据: read() package wwx;import jdk.swing.interop.SwingInterOpUtils;import java.io.*; import java.nio.charset.StandardCharsets;public class Test {public static void main(String[] args) {FileInp…

1.1 向量与线性组合

一、向量的基础知识 两个独立的数字 v 1 v_1 v1​ 和 v 2 v_2 v2​,将它们配对可以产生一个二维向量 v \boldsymbol{v} v: 列向量 v v [ v 1 v 2 ] v 1 v 的第一个分量 v 2 v 的第二个分量 \textbf{列向量}\,\boldsymbol v\kern 10pt\boldsymbol …

机器人制作开源方案 | 杠杆式6轮爬楼机器人

1. 功能描述 本文示例将实现R281b样机杠杆式6轮爬楼机器人爬楼梯的功能(注意:演示视频中为了增加轮胎的抓地力,在轮胎上贴了双面胶,请大家留意)。 2. 结构说明 杠杆式6轮爬楼机器人是一种专门用于爬升楼梯或不平坦地面…

thinkphp6 - 超详细使用阿里云短信服务发送验证码功能,TP框架调用对接阿里云短信发验证码(详细示例代码,一键复制开箱即用)

效果图 在thinkphp 5/6 框架(只要不是太低的版本就能用)中,实现接入调用阿里云短信服务详细教程,整个配置过程及示例代码保证小白也能轻松完成! 直接复制就行,改个阿里云参数就能用了。

MongoDB 简介和安装

一、MongoDB 相关概念 1.1 业务应用场景 1.1.1 三高需求 传统的关系型数据库(如 MySQL) ,在数据操作的 “三高” 需求以及应对 Web2.0 的网站需求面前,显得力不从心。”三高“ 需求如下所示,而 MongoDB 可应对 “三高…

C++入门指南:类和对象总结友元类笔记(下)

C入门指南:类和对象总结友元类笔记(下) 一、深度剖析构造函数1.1 构造函数体赋值1.2 初始化列表1.3 explicit关键字 二、static成员2.1 概念2.2 特性 三、友元3.1 友元函数3.2 友元类 四、 内部类4.1 概念4.2 特征 五、拷贝对象时的一些编译器优化六、深…

【C++】C++11 ——— 类的新功能

​ ​📝个人主页:Sherry的成长之路 🏠学习社区:Sherry的成长之路(个人社区) 📖专栏链接:C学习 🎯长路漫漫浩浩,万事皆有期待 上一篇博客:【C】STL…

JAVA中的垃圾回收

JVM规范说了并不需要必须回收方法区,不具有普遍性,永久代使用的是JVM之外的内存 引用计数:效率要比可达性分析要强,随时发现,随时回收,实现简单,但是可能存在内存泄漏 局部变量表,静态引用变量 …

Radius OTP实现VPN登录认证 安当加密

实现Radius OTP认证来完成VPN登录,需要使用支持Radius协议的VPN设备和客户端,以及一个Radius服务器来处理用户认证。 安当ASP身份认证平台作为一个企业通用的身份认证系统,集成了Radius认证功能,可满足所有支持radius认证的设备登…

【string题解 C++】字符串相乘 | 翻转字符串III:翻转单词

字符串相乘 题面 力扣(LeetCode)官网 - 全球极客挚爱的技术成长平台 给定两个以字符串形式表示的非负整数 num1 和 num2,返回 num1 和 num2 的乘积,它们的乘积也表示为字符串形式。 注意:不能使用任何内置的 BigIn…

基本微信小程序的电影票务系统-电影票预订系统

项目介绍 在传统的模式下,电影购票需要到当地的影院进行线下购票,既浪费时间,又消耗人力。线上购票可以满足消费者查看电影信息及购买电影票的需求,在一定程度上降低经济和时间成本[9]。目前已有一些手机app可以线上购票&#xf…

优化|优化处理可再生希尔伯特核空间的非参数回归中的协变量偏移

原文:Optimally tackling covariate shift in RKHS-based nonparametric regression. The Annals of Statistics, 51(2), pp.738-761, 2023.​ 原文作者:Cong Ma, Reese Pathak, Martin J. Wainwright​ 论文解读者:赵进 编者按: …

无名管道和有名管道

进程间通信的几种方式 无名管道(pipe) 无名管道(Unnamed Pipe)是一种在进程间进行单向通信的机制。它可以用于父进程与子进程之间的通信,或者同一父进程中不同子进程之间的通信。无名管道是一种特殊的文件&#xff0…

【(数据结构) —— 顺序表的应用-通讯录的实现】

(数据结构)—— 顺序表的应用-通讯录的实现 一.通讯录的功能介绍1.基于动态顺序表实现通讯录(1). 功能要求(2).重要思考 二. 通讯录的代码实现1.通讯录的底层结构(顺序表)(1)思路展示(2)底层代码实现(顺序表) 2.通讯录上层代码实现(通讯录结构…

IEDA 自动生成类注释和方法注释

1. 新建类,自动生成类注释的模板配置 File->Settings->Editor->File and Code Templates->Class /*** Description: TODO* Author: LT* Date: ${YEAR}-${MONTH}-${DAY}* Version:V3.5.6*/2. 通过快捷键,添加类注释和方法注释的模板设置 类…

六轴传感器 SH3001

SH3001简介 SH3001是Senodia公司生产的一款六轴姿态传感器,可测量芯片自身X、Y、Z轴的加速度以及角速度参数,通过姿态融合,进而得到姿态角。 三轴加速度计(Accelerometer):测量X、Y、Z轴的加速度。 三轴陀…

C++基础入门详解(一)

文章目录 前言命名空间展开命名空间使用C官方命名空间中的输入输出IO流输入输出的使用方法 缺省参数半缺省 函数重载 “你总以为机会无限,所以不珍惜眼前人” 前言 提到C,大部分人都想到的是复杂的语法和大量的规则,相对于java和python等语言…

基于YOLOv8模型的绵羊目标检测系统(PyTorch+Pyside6+YOLOv8模型)

摘要:基于YOLOv8模型的绵羊目标检测系统可用于日常生活中检测与定位车辆目标,利用深度学习算法可实现图片、视频、摄像头等方式的目标检测,另外本系统还支持图片、视频等格式的结果可视化与结果导出。本系统采用YOLOv8目标检测算法训练数据集…

基于simulink的单相光伏系统并网储能控制仿真

本仿真涉及到:基于电导增量法的最佳功率点跟踪算法、蓄电池恒流_恒压充电算法、光伏逆变器并网算法、双向(同步)DCDC电路设计等知识。 辐照度变化曲线: 模拟仿真,低辐照度情况,蓄电池与光伏逆变器共同向…

软件项目管理实践指南:有效规划、执行和控制

软件项目管理是使软件产品、应用程序和系统成功交付的重要规程。它有助于确保软件在预算内按时开发,同时满足客户的质量和功能需求。 软件项目管理是管理软件项目生命周期的一种有组织的方法,包括计划、开发、发布、维护和支持。它是在满足客户需求的同时…