从 0 手撸一个 pytorch

news2025/1/15 18:20:24

背景介绍

最近抽空看了下 Andrej Karpathy 的视频教程 building micrograd,教程的质量很高。教程不需要任何前置机器学习基础,只需要有高中水平的数学基础即可。整个教程从 0 到 1 手撸了一个类 pytorch 的机器学习库 micrograd,核心代码不到 100 行。虽然为了简化没有实现复杂的矩阵运算,但是对于理解 pytorch 的设计思想有很大帮助。

动手实践

为了验证 micrograd 的可用性,先基于 micrograd 实现了简单的线性回归算法。

首先构造出数据集,我使用随机数作为 x,通过线性回归确定结果后增加必要的噪声,对应的构造方法如下所示:

import numpy as np

def get_train_dataset(num_samples, noise):
    x = np.random.rand(num_samples)

    y = 4 * x + 3 + np.random.normal(0, noise, num_samples)

    return x.tolist(), y.tolist()

可以看到最终期望的结果为 y = 4 * x + 3

接下来实现训练流程,线性回归的模型的初始值都使用随机值,持续跟踪训练过程中损失值与对应参数的变化,实现如下所示:

import numpy as np
from micrograd.engine import Value

def zero_grad(w, b):
    w.grad = 0
    b.grad = 0

def step(w, b, learning_rate):
    w.data -= learning_rate * w.grad
    b.data -=  learning_rate * b.grad

def train_loop():
    dataset_x, dataset_y = get_train_dataset(10, 0.01)
    learning_rate = 0.1
    w = Value(np.random.rand())
    b = Value(np.random.rand())
    epoch = 40

    print(f"Init w {w.data}, b {b.data}")
    for idx in range(epoch):
        loss = 0

        for x, y in zip(dataset_x, dataset_y):
            x_value, y_value = Value(x), Value(y)
            y_pred = x_value * w + b
            current_loss = (y_value - y_pred) ** 2
            loss += current_loss.data

            zero_grad(w, b)
            current_loss.backward()
            step(w, b, learning_rate)

        print(f"Epoch {idx} got loss: {loss}, w {w.data}, b {b.data}")

上面的实现中使用 zero_grad() 方法重置参数的梯度,使用 step() 方法实际更新模型参数,训练流程就实现在 train_loop() 中。最终结果如下所示:
在这里插入图片描述
可以看到经过 40 轮训练后,损失值从最初的 55.69 下降至 0.0016,而参数 w, b 也接近期望的目标。从实践结果来看,micrograd 确实能实现简单模型的训练。

通过上面的实践来看,micrograd 最核心的就是 Value,按照 Andrej Karpathy 的说法,不到 100 行实现的 Value 就已经完成的 pytorch 中的 Tensor 90% 的功能了,除了这部分核心功能之外,pytorch 更多的是做了效率上的优化。

流程梳理

在机器学习中,模型训练都是基于 梯度下降 来更新模型的。模型训练的过程一般分为前向传播和反向传播:

  • 前向传播会根据训练数据确定对应的损失值,对应于上面的实现如下:
x_value, y_value = Value(x), Value(y)
y_pred = x_value * w + b
current_loss = (y_value - y_pred) ** 2

前向传播就是根据模型确定预测值 y_pred, 基于 MSE 确定损失值 (y - y_pred)^2。前向传播相对容易理解。

  • 反向传播就是根据确定的损失值进行模型参数的调整,从而降低损失值,对应的实现就是:
zero_grad(w, b)
current_loss.backward()
step(w, b, learning_rate)

上面最核心的功能就是调用 current_loss.backward() 确定各个参数对应的梯度,然后在 step() 方法中对参数的值进行更新。

参数更新的方案是相对明确,就是减去梯度与学习率之积实现。因此主要关注如何确定参数的梯度。梯度的计算存在如下所示的关注点:

  1. 数学运算各个元素对应的梯度如何计算,这部分就是微积分中导数的计算;
  2. 链式法则;
  3. 复杂模型中包含上亿参数,如何确定参数各自的梯度;

实现细节

micrograd 最核心的实现位于 engine.py,主要关注 Value 类的实现。

初始化过程

关注初始化过程可以看到 Value 中包含的元素,实现如下:

def __init__(self, data, _children=(), _op=''):
    self.data = data
    self.grad = 0
    self._backward = lambda: None
    self._prev = set(_children)
    self._op = _op # the op that produced this node, for graphviz / debugging / etc

初始化阶段可以看到 Value 中最重要的两个参数,data 保存的是元素中的原始数据,grad 保存的是当前元素对应的梯度。

_backward() 方法保存的是反向传播的方法,用于计算反向传播的梯度

_prev 保存的是当前节点前置的节点,比如 y = w * x 中节点 y 对应的 _prev 保存的是 wx。通过不断的获取 _prev 节点,即可还原完整的运算链路。

数学运算支持

Value 中支持了不同的数学运算,首先以加法为例,实现如下所示:

def __add__(self, other):
    other = other if isinstance(other, Value) else Value(other)

    # 加法运算得到结果,同样是 Value 元素

    out = Value(self.data + other.data, (self, other), '+')

    # 加法反向传播函数

    def _backward():
        self.grad += out.grad
        other.grad += out.grad
    out._backward = _backward

    return out

前向传播计算的实现比较简单,直接基于 data 进行计算,通过加法运算生成了结果 out。同时将参与运算的元素 selfother 保存至 self._prev 中,方便还原运算链路。

out 对应的反向传播的方法 _backward() 是基于链式法则实现。举例如下:

c = a + b

那么 ∂l/∂a = ∂l/∂c * ∂c/∂a,而 ∂c/∂a = 1,因此 ∂l/∂a = ∂l/∂c,因此加法中元素的梯度就等于其结果的梯度。

那么为什么实现是 self.grad += out.grad 而不是 self.grad = out.grad 呢,因为单个元素涉及多个运算链路时,梯度是不同链路确定的梯度之和。

这个也带来一个隐患,每次重新计算梯度之前,需要将原有的梯度重置为 0。对应于上面的 zero_grad() 的实现。了解 pytorch 应该也会注意到 pytorch 训练过程中也存在类似情况。

同样来查看乘法运算,对应的实现如下:


def __mul__(self, other):
    other = other if isinstance(other, Value) else Value(other)
    out = Value(self.data * other.data, (self, other), '*')

    def _backward():
        self.grad += other.data * out.grad
        other.grad += self.data * out.grad
    out._backward = _backward

    return out

主要关注反向传播的实现,可以看到同样是链路法则的推演,举例如下:

c = a * b

那么 ∂l/∂a = ∂l/∂c * ∂c/∂a,而 ∂c/∂a = b, 因此 ∂l/∂a = ∂l/∂c * b, 因此就可以理解上面的实现了。

反向传播

通过上面的运算过程可以看到,通过不断保存其前置元素至 self._prev 中,可以构建出完整的运算链路图。而在运算过程中,元素反向传播计算的梯度的方法 _backward() 也被确定。因此反向传播就是从后往前调用 _backward() 来实现的:


def backward(self):

    topo = []
    visited = set()
    # 根据前置元素的关系构建拓扑排序的元素列表,保证最终调用时是从后往前的

    def build_topo(v):
        if v not in visited:
            visited.add(v)
            for child in v._prev:
                build_topo(child)
            topo.append(v)
    build_topo(self)

    # 最后元素的梯度为 1, 依次计算前置元素的梯度

    self.grad = 1
    for v in reversed(topo):
        v._backward()

最终反向传播就是调用 _backward() 即可确定各个元素的梯度。

总结

通过上面的流程可以很容易理解机器学习模型训练框架的设计方案,这一套流程也完全适用于 pytorch,可以帮助更好地理解 pytorch 的训练流程。整体总结下实现思路:

  1. 前向传播过程中会逐层计算运行结果,并确定结果与运算元素梯度之前的关系,在结果元素梯度确定后就可以确定运算元素的梯度;
  2. 反向传播就是按照从后往前依次确认各个元素的梯度,方便后续根据梯度更新元素对应的值;

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1708094.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

关于VFX Graph的学习

关于VFX Graph的学习 转载自我的有道云笔记,目前内容不多,后续如果继续使用会更新。 前言 出于实习工作需要和毕设需要,我开始使用VFXGraph。 以前准备第一批作品集的时候,就简单地使用过,但是只是跟着教程一顿乱连…

PENDLE会是打响LSDFI赛道的第一枪吗?以bitget钱包为例

Pendle Finance是什么? PENDLE是Pendle Finance的原生通证,因此,在介绍Pendle币之前,我们需要对Pendle Finance有一个简单的了解。、 Pendle是一个建立在以太坊区块链上的无需许可的去中心化金融(DeFi)协议&#xff…

长三角智能科技高端盛会—南京人工智能展览会(南京智博会)

南京,作为一座历史悠久的文化名城,早已不仅仅以其深厚的文化底蕴和独特的自然风貌著称于世。而今,这座古老而又年轻的城市,正以其卓越的科技实力和创新精神,成为中国乃至全球科研领域的一颗璀璨明珠。南京不仅是中国三…

打造高质感的电子画册,这篇文章告诉你

​在数字化时代,电子画册作为一种全新的视觉传达方式,正逐渐成为各行各业展示形象、传播信息的重要工具。相较于传统的纸质画册,电子画册具有更高的质感、更好的互动性以及更低的制作成本,使得它愈发受到众多企业的青睐。那样怎么…

【umi-max】初识 antd pro

修改端口号 根目录下的 .env 文件: PORT8888目录结构 (umijs.org) 新增页面 在 umirc.ts 中进行配置。 新增页面 - Ant Design Pro 这里有一个配置 icon:string,可以在菜单加 icon 图标,默认使用 antd 的 icon 名,默认不适用二…

pands使用openpyxl引擎实现EXCEL条件格式

通过python的openpyxl库,实现公式条件格式。 实现内容:D列单元格不等于E列同行单元格时标红。 #重点是formula后面的公式不需要“”号。 from openpyxl.styles import Color, PatternFill, Font, Border from openpyxl.styles.differential import Dif…

记录深度学习GPU配置,下载CUDA与cuDnn

目标下载: cuda 11.0.1_451.22 win10.exe cudnn-11.0-windows-x64-v8.0.2.39.zip cuda历史版本网址 CUDA Toolkit Archive | NVIDIA Developer 自己下载过11.0.1版本 点击下载local版本,本地安装,有2个多GB,很大,我不喜欢network版本,容易掉线 cuDnn https://developer.nvi…

selenium源码学习

这里写自定义目录标题 欢迎使用Markdown编辑器新的改变功能快捷键合理的创建标题,有助于目录的生成如何改变文本的样式插入链接与图片如何插入一段漂亮的代码片生成一个适合你的列表创建一个表格设定内容居中、居左、居右SmartyPants 创建一个自定义列表如何创建一个…

Zookeeper的watch 机制

Watch机制介绍 我们可以把Watch理解成是注册在特定Znode上的触发器。当这个Znode发生改变,也就是调用了create,delete,setData方法的时候,将会触发Znode上注册的对应事件,请求Watch的客户端会收到异步通知 ZooKeeper…

基于LLM的优化器评测-非凸函数

基于LLM的优化器评测-非凸函数 目标函数测试结果测试代码测试日志 背景: ​ 很多时候我们需要为系统寻找最优的超参.比如模型训练,推理的量化等.本文尝试将LLM当成优化器,帮忙我们寻找最优的超参. 验证方法: 1.设计一个已知最优解的多项式,该多项式有3个变量(因为3个变量可以…

深度解析Java 11核心新特性

码到三十五 &#xff1a; 个人主页 < 免责声明 > 避免对文章进行过度解读&#xff0c;因为每个人的知识结构和认知背景都不同&#xff0c;没有一种通用的解决方案。对于文章观点&#xff0c;不必急于评判。融入其中&#xff0c;审视自我&#xff0c;尝试从旁观者角度认清…

基于python flask +pyecharts实现的气象数据可视化分析大屏

背景 气象数据可视化分析大屏基于Python Flask和Pyecharts技术&#xff0c;旨在通过图表展示气象数据的分析结果&#xff0c;提供直观的数据展示和分析功能。在当今信息化时代&#xff0c;气象数据的准确性和实时性对各行业具有重要意义。通过搭建气象数据可视化分析大屏&…

【Linux】Linux基本指令1

1.软件&#xff0c;OS&#xff0c;驱动 我们看看计算机的结构层次 1.1.操作系统 操作系统是一款做 软硬件管理 的软件 操作系统&#xff08;计算机管理控制程序&#xff09;_百度百科 (baidu.com) 操作系统&#xff08;英语&#xff1a;Operating System&#xff0c;缩写&a…

60. UE5 RPG 使用场景查询系统(EQS,Environment Query System)实现远程敌人寻找攻击位置

UE的Environment Query System&#xff08;EQS&#xff09;是环境查询系统&#xff0c;它是UE4和UE5中用于AI决策制定过程中的数据采集和处理的一个强大工具。EQS可以收集场景中相关的数据&#xff0c;利用生成器&#xff08;Generator&#xff09;针对用户的测试&#xff08;T…

身份认证页面该怎么设计更加合理?

一、认证页面的作用 认证页面在应用程序中具有以下几个重要的作用&#xff1a; 验证用户身份&#xff1a;认证页面的主要作用是验证用户的身份。通过要求用户提供正确的凭据&#xff08;如用户名和密码、生物特征、验证码等&#xff09;&#xff0c;认证页面可以确认用户是合法…

安卓开机启动阶段

目录 概述一、boot_progress_start二、boot_progress_preload_start三、boot_progress_preload_end四、boot_progress_system_run五、boot_progress_pms_start六、boot_progress_pms_system_scan_start七、boot_progress_pms_data_scan_start八、boot_progress_pms_scan_end九、…

Docker(三) 容器管理

1 容器管理概述 Docker 的容器管理可以通过 Docker CLI 命令行工具来完成。Docker 提供了丰富的命令&#xff0c;用于管理容器的创建、启动、停止、删除、暂停、恢复等操作。 以下是一些常用的 Docker 容器命令&#xff1a; 1、docker run&#xff1a;用于创建并启动一个容器。…

ubuntu22.04安装调节显示器亮度工具

1 介绍 软件名叫 DDC/CI control&#xff0c;官网 2 安装方法 sudo apt install intltool i2c-tools libxml2-dev libpci-dev libgtk2.0-dev liblzma-dev3 效果 进入软件&#xff0c;忽略告警信息

selenium 学习笔记(一)

pip的安装 新建一个txt curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py 把上面的代码复制进去后&#xff0c;把后缀名改为.bat然后双击运行 当前目录会出现一个这个文件 然后在命令行pyhon get-pip.py等它下好就可以了selenium安装 需要安装到工程目…

【进程空间】通过页表寻址的过程

文章目录 前言介绍页表、页框、页目录的概念页框页表页目录页表和页目录的分配 一级页表和二级页表一级页表寻址过程 二级页表寻址过程 一级页表和二级页表的对比 前言 我们知道每个进程都有属于自己的虚拟地址空间&#xff0c;且每个进程的虚拟地址都是统一的。要想通过虚拟地…