神经网络的工程基础(一)——利用PyTorch实现梯度下降法

news2025/1/10 11:38:29

相关说明

这篇文章的大部分内容参考自我的新书《解构大语言模型:从线性回归到通用人工智能》,欢迎有兴趣的读者多多支持。
本文涉及到的代码链接如下:regression2chatgpt/ch06_optimizer/gradient_descent.ipynb

本文将讨论利用PyTorch实现梯度下降法的细节。这是神经网络模型的共同工程基础。

关于大语言模型的内容,推荐参考这个专栏。

内容大纲

  • 相关说明
  • 一、为什么需要了解实现细节?
  • 二、梯度下降法的理论基础
  • 三、代码实现

一、为什么需要了解实现细节?

在我们使用经典机器学习模型对数据建模时,首先会从实际应用场景出发,初步分析数据的特征,获取灵感和直觉;然后,通过数学的抽象和变换,为问题选择合适的模型架构;最后,使用Python开源的算法库实现最终的模型,其中模型的参数已经被估计出来。

从软件设计的角度来讲,Python开源算法库在抽象(Abstraction)方面做得非常出色。它有效地隐藏了模型构建和训练的底层实现细节,使我们只需关注高层的概念和操作,即提供的一系列函数接口(API)。通过这些接口,通常只需几十行代码就能完成模型的构建和训练。在这个过程中,无须过多考虑模型背后复杂的数学计算,计机估计模型参数的算法实现也不再成为障碍。在理想情况下,所有底层的复杂性都被完美抽象,数据科学家的工作更加轻松和便捷(当然,作为硬币的另一面,这也可能导致数据科学家的门槛降低,进而影响相关职位的数量和薪水)。然而,不幸的是(或者幸运的是),由于模型涉及复杂的数学抽象和计算,即使软件设计和抽象再完美,也无法完全掩盖其复杂性,某些细节仍然可能“泄漏”出来,影响用户对系统的理解和操作,这就是抽象泄漏(Leaky Abstraction)。

举个例子,在训练逻辑回归模型时,某些数据集可能导致开源算法库出现错误,无法估计模型参数。对于相对经典或简单的模型,抽象泄漏的情况较少出现。然而,对于更复杂的模型,例如神经网络领域的深度学习和语言大模型,可能出现大量的抽象泄漏问题。如果不理解底层实现的细节,在这些领域将寸步难行:从理论角度来看,无法理解模型的精髓,就难以有效地优化模型,无法达到预期的模型效果;从实际应用角度来看,遇到程序问题难以修复,训练时间过长,除了参考示例实现,很难灵活运用算法库,也无法根据需求调整模型架构。

因此,这个系列的文章将深入研究开源算法库的核心细节,探讨如何基于模型的数学公式计算出相应的参数估计值。更具学术性的表述是——探讨解决最优化问题的算法。最优化问题有多种求解方法,不同算法适用于不同的模型,并在解决不同类型的问题上各有优势。鉴于篇幅限制,本文将重点关注最基础的算法:梯度下降法。后续的文章将继续讨论如何实现随机梯度下降法及其各种变种。

二、梯度下降法的理论基础

对于任何一个模型,它都对应着一个损失函数 L L L,假设选取的初始点为 a 0 , b 0 a_0,b_0 a0,b0;现在将这两个点稍稍移动一点,得到 a 1 , b 1 a_1,b_1 a1,b1。根据泰勒级数(Taylor Series)1,暂时只考虑一阶导数2,可以得到公式(1),其中 ∆ a = a 1 − a 0 , ∆ b = b 1 − b 0 ∆a = a_1 - a_0,∆b = b_1 - b_0 a=a1a0,b=b1b0
∆ L = L ( a 1 , b 1 ) − L ( a 0 , b 0 ) ≈ ∂ L ∂ a ∆ a + ∂ L ∂ b ∆ b (1) ∆L = L(a_1,b_1) - L(a_0,b_0) ≈\frac{∂L}{∂a} ∆a + \frac{∂L}{∂b} ∆b \tag{1} L=L(a1,b1)L(a0,b0)aLa+bLb(1)
如果令
( ∆ a , ∆ b ) = − η ( ∂ L / ∂ a , ∂ L / ∂ b ) (2) (∆a,∆b)= -η(∂L/∂a,∂L/∂b) \tag{2} (a,b)=η(L/a,L/b)(2)

其中 η > 0 η > 0 η>0,可以得到: ∆ L ≈ − η [ ( ∂ L / ∂ a ) 2 + ( ∂ L / ∂ b ) 2 ] ≤ 0 ∆L ≈ -η[(∂L/∂a)^2 + (∂L/∂b)^2] \le 0 Lη[(L/a)2+(L/b)2]0。这说明如果按公式(2)移动参数,损失函数的函数值始终是下降的,这正是我们想要达到的效果。如果一直重复这种移动,数学上可以证明,损失函数能最终得到它的最小值,整个过程就像鸡蛋在圆底锅里滚动一样,于是可以得到参数的迭代公式,见公式(3)。
a k + 1 = a k − η ∂ L ∂ a b k + 1 = b k − η ∂ L ∂ b (3) a_{k + 1} = a_k - η \frac{∂L}{∂a} \\ b_{k + 1} = b_k - η \frac{∂L}{∂b} \tag{3} ak+1=akηaLbk+1=bkηbL(3)

也可以换一个类比角度来理解梯度下降法的核心思想。想象你站在一个山坡上,目标是要找到最低的山谷。公式(3)就如同导航,在山坡上指引着你下山的方向。如果地势是向下的(损失函数的偏导数 ∂ L ⁄ ∂ a < 0 ∂L⁄∂a < 0 La<0),那么你会朝着这个方向迈出一步;相反,如果地势是向上的( ∂ L ⁄ ∂ a > 0 ∂L⁄∂a > 0 La>0),那么你会退回一步,避免走向更高的地方。

在数学上,向量 ∇ L = ( ∂ L / ∂ a , ∂ L / ∂ b ) ∇L = (∂L/∂a,∂L/∂b) L=(L/a,L/b)被称为损失函数L的梯度。这也是公式(3)表示的算法被称为梯度下降法的原因。同时可以证明,函数的梯度正好是函数值下降得最快的方向,因此梯度下降法也是最高效的“下降”方式。

综上,可以将梯度下降法的主要算法归纳为三步:根据当前参数和训练数据计算模型损失;计算当前的损失函数梯度;利用梯度,迭代更新模型参数,如图1所示。

图1

图1

需要强调的是,从严谨的数学角度来看,多元可微函数 L L L在点 P P P上的梯度,实际上是由 L L L在点 P P P上各个变量的偏导数构成的向量。然而在人工智能领域,尤其是神经网络领域,为了简化表达,我们通常会用“变量的梯度” 3这一术语来指代该变量在特定情况下的偏导数或者对偏导数的估计值。

三、代码实现

下面将探索如何利用PyTorch提供的封装函数来实现梯度下降法。实现梯度下降法涉及3个关键步骤。

  1. 根据当前参数和训练数据,计算模型损失。
  2. 计算当前的损失函数梯度:利用模型定义的损失函数及训练数据,计算得到当前损失函数的梯度。需要注意的是,损失函数梯度的计算依赖于损失函数的数学表达式、用于梯度计算的训练数据,以及当前的参数估计值。这一步可以由PyTorch封装好的反向传播算法4(Back Propagation,BP)来完成。
  3. 利用梯度,更新模型参数:在计算得到损失函数的当前梯度后,利用这个梯度来迭代更新模型参数的估计值。这一步可以由PyTorch提供的优化算法函数(例如torch.optim.SGD)来实现。

首先进行一些准备工作,包括生成训练所需的数据和定义模型的结构。尽管这部分代码相对简单,但仍需注意以下两点。

  1. 在程序清单1(完整代码)的第2—4行,对变量x进行归一化处理。这一步的目的在于保证梯度下降法的稳定性。实际上,读者可以很容易地修改代码,不对x进行归一化处理,但会影响梯度下降法的稳定性,进而可能导致无法收敛的情况。在实际建模过程中,几乎会对每个变量进行归一化处理,以确保模型的稳健性和可靠性。
  2. 在程序清单1的第9—28行,通过继承torch.nn.Module的方式来定义线性回归模型。在具体的实现中,需要重写两个核心函数:__init__和forward。__init__函数定义了模型所需的参数及相应的初始值,forward函数中描述了如何利用这些参数获得模型的预测结果5
程序清单1 定义模型和产生训练数据
 1 |  # 产生训练用的数据
 2 |  x_origin = torch.linspace(100, 300, 200)
 3 |  # 将变量x归一化,否则梯度下降法很容易不稳定
 4 |  x = (x_origin - torch.mean(x_origin)) / torch.std(x_origin)
 5 |  epsilon = torch.randn(x.shape)
 6 |  y = 10 * x + 5 + epsilon
 7 |  
 8 |  # 为了使用PyTorch的高层封装函数,通过继承Module类来定义函数
 9 |  class Linear(torch.nn.Module):
10 |      def __init__(self):
11 |          """
12 |          定义线性回归模型的参数:a, b
13 |          """
14 |          super().__init__()
15 |          self.a = torch.nn.Parameter(torch.zeros(()))
16 |          self.b = torch.nn.Parameter(torch.zeros(()))
17 |  
18 |      def forward(self, x):
19 |          """
20 |          根据当前的参数估计值,得到模型的预测结果
21 |          参数
22 |          ----
23 |          x :torch.tensor,变量x
24 |          返回
25 |          ----
26 |          y_pred :torch.tensor,模型预测值
27 |          """
28 |          return self.a * x + self.b
29 |  
30 |      def string(self):
31 |          """
32 |          输出当前模型的结果
33 |          """
34 |          return f'y = {self.a.item():.2f} * x + {self.b.item():.2f}'

接下来,进入核心的算法实现阶段,如程序清单2所示,其中包括定义模型的损失函数、计算损失函数的梯度,以及计算迭代更新参数估计值。这些步骤相对固定,几乎适用于所有模型。或许第13行中的“将上一次的梯度清零”操作可能会引发一些读者的困惑。实际上,这行代码与反向传播算法的工作机制息息相关,后续的文章[TODO]将对其进行详细的解释和讨论。

程序清单2 梯度下降法
 1 |  # 定义模型
 2 |  model = Linear()
 3 |  # 确定最优化算法
 4 |  learning_rate = 0.1
 5 |  optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
 6 |  
 7 |  for t in range(20):
 8 |      # 根据当前的参数估计值,得到模型的预测结果
 9 |      # 也就是调用forward函数
10 |      y_pred = model(x)
11 |      # 计算损失函数
12 |      loss = (y - y_pred).pow(2).mean()
13 |      # 将上一次的梯度清零
14 |      optimizer.zero_grad()
15 |      # 触发反向传播算法,计算损失函数的梯度
16 |      loss.backward()
17 |      # 迭代更新模型参数的估计值
18 |      optimizer.step()

本章运用PyTorch提供的高级封装函数来实现梯度下降法。尽管如此,整个算法的核心难点仍然被这些函数隐藏了,其中有两个关键函数起到了重要作用。首先是optimizer.step(),负责实现参数的迭代更新,其细节相对简单,可以轻松地实现,如图2所示;其次是负责反向传播算法的loss.backward()函数,其实现相当复杂,将在后续的文章[TODO]中详细讨论。

图2

图2


  1. 回顾一下泰勒一阶展开式,假设 f ( x 1 , x 2 , ⋯ , x n ) f(x_1,x_2,⋯,x_n) f(x1,x2,,xn)是一个一阶可导的函数,即 ∂ 2 f ∂ x i ∂ x j \frac{∂^2 f}{∂x_i ∂x_j } xixj2f都存在,则 f ( x 1 , x 2 , ⋯ , x n ) = f ( a 1 , a 2 , ⋯ , a n ) + ∑ i = 1 n ∂ f ( a 1 , a 2 ⋯ , a n ) ∂ x i ( x i − a i ) + o ( ∑ i ∣ x i − a i ∣ ) f(x_1,x_2,⋯,x_n)=f(a_1,a_2,⋯,a_n)+\sum_{i = 1}^n\frac{∂f(a_1,a_2⋯,a_n)}{∂x_i}(x_i-a_i) +o(\sum_i|x_i-a_i |) f(x1,x2,,xn)=f(a1,a2,,an)+i=1nxif(a1,a2,an)(xiai)+o(ixiai)其中, o ( ∑ i ∣ x i − a i ∣ ) o(\sum_i|x_i-a_i |) o(ixiai)表示相对于 ∑ i ∣ x i − a i ∣ \sum_i|x_i-a_i | ixiai的极小值。因此在x很靠近a时,有 f ( x ) ≈ f ( a ) + ∑ i ∂ f ( a ) ∂ x i ( x i − a i ) f(x) ≈ f(a) + \sum_i\frac{∂f(a)}{∂x_i}(x_i - a_i) f(x)f(a)+ixif(a)(xiai)。但是当x离a较远时,上述近似关系的误差就很大了。 ↩︎

  2. 如果考虑多阶导数,可以得到其他的最优化问题求解算法,比如使用二阶导数的共轭梯度法(Conjugate Gradient Method)等。这些算法对于特定问题可以更快地得到收敛解,但它们对损失函数的要求更多,计算复杂度也更高,并不适合神经网络和分布式机器学习,所以这里不做深入探讨。 ↩︎

  3. 这一概念在实际应用中非常重要,因为在优化算法中,需要计算或者估计损失函数关于某个参数的偏导数,以指导这个参数的更新。然而,若要准确地计算梯度,就需要对多元函数的每个偏导数进行计算,这让准确的数学表述变得非常烦琐。因此,通过使用“变量的梯度”这一术语,能够使表达更简洁,并在实际操作中更加便利地进行参数更新和优化。 ↩︎

  4. 在PyTorch中,算法的正式名字是自动微分(Autograd或Automatic Differentiation)算法。这两者指的其实是同一个算法。 ↩︎

  5. 或许有些读者会对“为什么将模型的预测函数称为forward”感到好奇。这是因为在神经网络领域,常常将计算模型的预测结果并评估损失的步骤称为向前传播,而将更新模型参数的步骤称为向后传播。这种命名习惯在PyTorch这个主要应用于神经网络的开源工具中得到了延续。 ↩︎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1696667.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

高斯过程学习笔记

目录 基础知识 例子 推荐 A Visual Exploration of Gaussian Processes (distill.pub) AB - Introduction to Gaussian Processes - Part I (bridg.land) 基础知识 高斯过程回归&#xff08;Gaussian Process Regression&#xff09; - 知乎 (zhihu.com) 高斯过程&#x…

VS2022上通过C++绘图库ROOT库绘制一个3D曲面图

ROOT库提供了强大的交互式图形功能。通过使用ROOT库的TCanvas和TApplication类&#xff0c;可以创建一个交互式的图形窗口&#xff0c;可以对图形进行缩放、平移、旋转等操作&#xff0c;并且可以通过鼠标和键盘与图形进行交互&#xff0c;这点实在是太厉害了&#xff0c;也就是…

贪心算法4(c++)

过河的最短时间 题目描述 输入 在漆黑的夜里&#xff0c;N位旅行者来到了一座狭窄而且没有护栏的桥边。如果不借助手电筒的话&#xff0c;大家是无论如何也不敢过桥去的。不幸的是&#xff0c;N个人一共只带了一只手电筒&#xff0c;而桥窄得只够让两个人同时过&#xff0c;如果…

Java进阶学习笔记21——泛型概念、泛型类、泛型接口

泛型&#xff1a; 定义类、接口、方法的时候&#xff0c;同时声明了一个或者多个类型变量&#xff08;如: <E>&#xff09;,称之为泛型类、泛型接口、泛型方法&#xff0c;我们统称之为泛型。 说明这是一个泛型类。 如果不使用泛型&#xff0c;我们可以往ArrayList中传…

【多线程开发 2】从代码到实战TransmittableThreadLocal

【多线程开发 2】从代码到实战TransmittableThreadLocal 本文将从以下几个点讲解TransmittableThreadLocal(为了方便写以下简称ttl)&#xff1a; 前身 是什么&#xff1f; 可以用来做什么&#xff1f; 源码原理 实战 前身 ThreadLocal 要了解ttl就要先了解Java自带的类…

【C语言】指针作为参数(传值调用vs传址调用)

前言 在前面讲了那些指针相关的内容后&#xff0c;是时候探讨一下指针有什么作用了。 在C语言中&#xff0c;指针有多种各不相同的应用&#xff0c;在本篇文章中&#xff0c;我们探讨一下指针作为函数参数的作用&#xff08;对比传值与传址两种不同函数调用方式&#xff09;。…

解决git克隆项目出现fatal无法访问git clone https://github.com/lvgl/lvgl.git

Windows 11系统 报错 $ git clone https://github.com/lvgl/lvgl.git Cloning into lvgl... fatal: unable to access https://github.com/lvgl/lvgl.git/: Failed to connect to github.com port 443 after 21141 ms: Couldnt connect to server 解决方法 git运行这两段代码…

008-Linux后台进程管理(作业控制:、jobs、fg、bg、ctrl + z、nohup)

文章目录 前言 1、& 2、ctrl z 3、jobs 4、fg&#xff1a;将后台进程调到前台执行 5、bg&#xff1a;将一个暂停的后台进程变为执行 6、&和nohup 总结 前言 有时候我们需要将一个进程放到后台去运行&#xff0c;或者将后台程序切换回前台&#xff0c;这时候就…

LabVIEW如何实现多张图拼接

在LabVIEW中实现相机多次拍摄进行拼接的过程&#xff0c;可以分为以下几个步骤&#xff1a;设置相机参数、控制相机拍摄、图像处理与拼接、显示和保存结果。以下是一个详细的实现方案&#xff1a; 1. 设置相机参数 首先需要配置相机的参数&#xff0c;例如分辨率、曝光时间、…

如何用ai打一场酣畅淋漓的数学建模比赛? 给考研加加分!

文章目录 数学建模比赛1. 数学建模是什么&#xff1f;2. 数学建模分工合作2.1 第一&#xff1a;组队和分工合作2.2 第二&#xff1a;充分的准备2.3 第三&#xff1a;比赛中写论文过程 3. 数学建模基本过程4. 2023全年数学建模竞赛时间轴5. 数学建模-资料大全6. 数学建模实战 数…

精品PPT | MES设计与实践,业务+架构+实施(免费下载))

【1】关注本公众号&#xff0c;转发当前文章到微信朋友圈 【2】私信发送 MES设计与实践 【3】获取本方案PDF下载链接&#xff0c;直接下载即可。 如需下载本方案PPT/WORD原格式&#xff0c;请加入微信扫描以下方案驿站知识星球&#xff0c;获取上万份PPT/WORD解决方案&#x…

Ant Design pro 6.0.0 搭建使用以及相关配置

一、背景 在选择一款比较合适的中台的情况下&#xff0c;挑选了有arco design、ant design pro、soybean、vue-pure-admin等中台系统&#xff0c;经过筛选就选择了ant design pro。之前使用过arco design 搭建通过组件库拼装过后台管理界面&#xff0c;官方文档也比较全&#…

软件安全复习

文章目录 第一章 软件安全概述1.1 信息定义1.2 信息的属性1.3 信息安全1.4 软件安全1.5 软件安全威胁及其来源1.5.1 软件缺陷与漏洞1.5.1.1 软件缺陷1.5.1.2 漏洞1.5.1.3 软件漏洞1.5.1.4 软件缺陷和漏洞的威胁 1.5.2 恶意软件1.5.2.1 恶意软件的定义1.5.2.2 恶意软件的威胁 1.…

《安富莱嵌入式周报》第337期:超高性能信号量测量,协议分析的开源工具且核心算法开源,工业安全应用的双通道数字I/O模组,低成本脑机接口,开源音频合成器

周报汇总地址&#xff1a;http://www.armbbs.cn/forum.php?modforumdisplay&fid12&filtertypeid&typeid104 视频版&#xff1a; https://link.zhihu.com/?targethttps%3A//www.bilibili.com/video/BV1PT421S7TR/ 《安富莱嵌入式周报》第337期&#xff1a;超高性…

镜子摆放忌讳多

镜子是我们日常生活中不可或缺的物品。在风水中&#xff0c;镜子的作用非常多&#xff0c;能够起到一定的作用。镜子的摆放位置也是非常有讲究的&#xff0c;摆放不好会直接影响到家人的事业、财运、婚姻乃至健康等诸多方面。 第一个风水忌讳&#xff0c;镜子对大门。大门的正前…

开发自定义菜单之创建菜单

文章目录 申请测试账号换取Token接口测试提交自定义菜单查看效果校验菜单配置清空菜单配置结束语 申请测试账号 https://mp.weixin.qq.com/debug/cgi-bin/sandboxinfo?actionshowinfo&tsandbox/index 或 得到appid和secret 换取Token 使用appid和secret换取token令牌…

企业内网开源OA服务器(办公自动化系统),搭建O2OA基于Linux(openEuler、CentOS8)

本实验环境为openEuler系统(以server方式安装)&#xff08;CentOS8基本一致&#xff0c;可参考本文) 目录 知识点实验下载安装O2OA安装mysql配置O2OA 知识点 “O2OA” 是一个开源的、基于Java的办公自动化&#xff08;Office Automation&#xff09;系统。其名称中的“O2OA”…

Linux操作指令大全

目录 &#x1f349;引言 &#x1f349; 基础命令 &#x1f348;pwd &#x1f348;cd &#x1f348;ls &#x1f348;mkdir &#x1f348;rmdir &#x1f348;cp &#x1f348;mv &#x1f348;rm &#x1f349; 文件操作命令 &#x1f348;cat &#x1f348;tac …

Web课外练习9

<!DOCTYPE html> <html> <head><meta charset"utf-8"><title>邮购商品业务</title><!-- 引入vue.js --><script src"./js/vue.global.js" type"text/javascript"></script><link rel&…

关于微信小程序低功耗蓝牙ECharts实时刷新

最近搞了这方面的东西&#xff0c;是刚刚开始接触微信小程序&#xff0c;因为是刚刚开始接触蓝牙设备&#xff0c;所以这篇文章适合既不熟悉小程序&#xff0c;又不熟悉蓝牙的新手看。 项目要求是获取到蓝牙传输过来的数据&#xff0c;并显示成图表实时显示&#xff1b; 我看了…