PyTorch到C++再到 CUDA 的调用链(C++ ATen 层) :以torch._amp_update_scale_调用为例

news2025/1/6 2:00:21

今天在看pytorch源码,遇到的问题,记录一下 。

source:/lib/python3.10/site-packages/torch/amp/grad_scaler.py

 torch._amp_update_scale_(
            _scale,
            _growth_tracker,
            found_inf_combined,
            self._growth_factor,
            self._backoff_factor,
            self._growth_interval,
        )

这段python代码是如何调用底层的cpp代码的

Tensor& _amp_update_scale_cuda_(Tensor& current_scale,
                                Tensor& growth_tracker,
                                const Tensor& found_inf,
                                double growth_factor,
                                double backoff_factor,
                                int64_t growth_interval) {
  TORCH_CHECK(growth_tracker.is_cuda(), "growth_tracker must be a CUDA tensor.");
  TORCH_CHECK(current_scale.is_cuda(), "current_scale must be a CUDA tensor.");
  TORCH_CHECK(found_inf.is_cuda(), "found_inf must be a CUDA tensor.");
  
  // 核函数调用
  amp_update_scale_cuda_kernel<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
    current_scale.mutable_data_ptr<float>(),
    growth_tracker.mutable_data_ptr<int>(),
    found_inf.const_data_ptr<float>(),
    growth_factor,
    backoff_factor,
    growth_interval);
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  return current_scale;
}

source: https://github.com/pytorch/pytorch/blob/0431d47eaaaceadab0c0f2400e186c02ef66d96c/aten/src/ATen/native/cuda/AmpKernels.cu#L222

在 PyTorch 框架中,Python 和底层 C++/CUDA 实现之间通过 PyTorch ATen 底层 API 和 PyBind11来完成。torch._amp_update_scale_ 是一个 Python 函数,它实际上通过 PyTorch 的 C++扩展机制 调用底层 C++ 实现(_amp_update_scale_cuda_)。以下是调用流程的完整解读:


1. Python 调用 C++ 的机制

torch._amp_update_scale_ 是一个绑定在 PyTorch ATen C++核心函数上的 Python 接口。其调用流程如下:

  1. Python 调用 torch._amp_update_scale_ 函数:

    • Python 参数 _scale, _growth_tracker, 和 found_inf_combinedtorch.Tensor 对象。
    • 这些对象在 Python 中使用,但底层通过 PyObject 引用 C++ 的 at::Tensor 对象(ATen 核心张量)。
  2. 通过绑定跳转到 C++ 函数 _amp_update_scale_cuda_

    • torch._amp_update_scale_ 被注册到一个 dispatcher(分发器),它会根据运行设备(CPU 或 GPU)选择合适的后端实现。
    • 在 GPU 上运行时,分发器会调用底层 CUDA 实现 _amp_update_scale_cuda_

2. Python 到 C++ 的具体流程

  1. ATen 和 PyTorch 的 Operator 注册系统
    PyTorch 使用 torch::RegisterOperators 注册 C++ 函数 _amp_update_scale_cuda_,并将其绑定到 Python 的 torch._amp_update_scale_

    注册流程示例

    TORCH_LIBRARY_IMPL(aten, CUDA, m) {
        m.impl("_amp_update_scale_", &_amp_update_scale_cuda_);
    }
    
    • TORCH_LIBRARY_IMPL 用于将 CUDA 实现 _amp_update_scale_cuda_ 注册到 ATen。
    • Python 代码调用 torch._amp_update_scale_ 时,会被自动映射到 C++ 实现 _amp_update_scale_cuda_
  2. Python 的 Tensor 转换为 C++ 的 at::Tensor
    torch._amp_update_scale_ 被调用时,Python 中的 Tensor 对象通过 PyBind11 自动转换为对应的 at::Tensor 对象。例如:

    torch._amp_update_scale_(
        _scale,             # Python Tensor -> at::Tensor
        _growth_tracker,    # Python Tensor -> at::Tensor
        found_inf_combined, # Python Tensor -> at::Tensor
        self._growth_factor, # Python float -> C++ double
        self._backoff_factor, # Python float -> C++ double
        self._growth_interval # Python int -> C++ int64_t
    )
    
  3. 调用 C++ 函数 _amp_update_scale_cuda_

    • 参数从 Python 传递到 _amp_update_scale_cuda_,对应 current_scale, growth_tracker, found_inf 等。
    • 在 C++ 中,_amp_update_scale_cuda_ 函数会调用底层 CUDA 核心函数 amp_update_scale_cuda_kernel,执行缩放更新逻辑。

3. C++ 到 CUDA 核心函数的调用流程

_amp_update_scale_cuda_ 中,C++ 调用 CUDA 核心代码的主要流程是:

  1. 参数检查
    使用 TORCH_CHECK 确保 current_scale, growth_tracker, 和 found_inf 都是 CUDA 张量:

    TORCH_CHECK(growth_tracker.is_cuda(), "growth_tracker must be a CUDA tensor.");
    TORCH_CHECK(current_scale.is_cuda(), "current_scale must be a CUDA tensor.");
    TORCH_CHECK(found_inf.is_cuda(), "found_inf must be a CUDA tensor.");
    
  2. 启动 CUDA 核函数
    使用 CUDA 的核函数调用机制 <<<...>>> 启动 CUDA 内核函数:

    amp_update_scale_cuda_kernel<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
        current_scale.mutable_data_ptr<float>(), // 当前缩放因子的指针
        growth_tracker.mutable_data_ptr<int>(), // 成功步数计数器的指针
        found_inf.const_data_ptr<float>(),      // 梯度溢出的标志
        growth_factor,                          // 增长因子
        backoff_factor,                         // 回退因子
        growth_interval                         // 增长间隔
    );
    

    调用过程

    • CUDA 核函数 amp_update_scale_cuda_kernel 被调度到当前 GPU 的流(at::cuda::getCurrentCUDAStream())。
    • 各个张量(如 current_scale, growth_tracker)通过 .data_ptr<T>() 方法获取指针,传递给 CUDA 核函数。
  3. CUDA 核函数执行
    核函数 amp_update_scale_cuda_kernel 在 GPU 上执行,完成缩放因子的动态调整。逻辑详见问题中的 CUDA 实现。

  4. 内核启动检查
    启动内核后,通过 C10_CUDA_KERNEL_LAUNCH_CHECK() 检查 CUDA 内核是否成功运行。

    C10_CUDA_KERNEL_LAUNCH_CHECK();
    

4. 总结调用链路

完整调用链如下:

  1. Python 层

    torch._amp_update_scale_(
        _scale, _growth_tracker, found_inf_combined,
        self._growth_factor, self._backoff_factor, self._growth_interval
    )
    
    • Python 张量(torch.Tensor)通过 PyBind11 转换为 C++ 张量(at::Tensor)。
  2. C++ 层

    Tensor& _amp_update_scale_cuda_(
        Tensor& current_scale, Tensor& growth_tracker, const Tensor& found_inf,
        double growth_factor, double backoff_factor, int64_t growth_interval
    ) {
        // 调用 CUDA 核函数
        amp_update_scale_cuda_kernel<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
            current_scale.mutable_data_ptr<float>(),
            growth_tracker.mutable_data_ptr<int>(),
            found_inf.const_data_ptr<float>(),
            growth_factor, backoff_factor, growth_interval
        );
        C10_CUDA_KERNEL_LAUNCH_CHECK();
        return current_scale;
    }
    
  3. CUDA 层

// amp_update_scale_cuda_kernel is launched with a single thread to compute the new scale.
// The scale factor is maintained and updated on the GPU to avoid synchronization.
__global__ void amp_update_scale_cuda_kernel(float* current_scale,
                                             int* growth_tracker,
                                             const float* found_inf,
                                             double growth_factor,
                                             double backoff_factor,
                                             int growth_interval)
{
 // 核函数逻辑:根据是否溢出动态调整 current_scale 和 growth_tracker
  if (*found_inf) {
    *current_scale = (*current_scale)*backoff_factor;
    *growth_tracker = 0;
  } else {
    // Entering this branch means we just carried out a successful step,
    // so growth_tracker is incremented before comparing to growth_interval.
    auto successful = (*growth_tracker) + 1;
    if (successful == growth_interval) {
      auto new_scale = static_cast<float>((*current_scale)*growth_factor);
      // Do not grow the scale past fp32 bounds to inf.
      if (isfinite_ensure_cuda_math(new_scale)) {
          *current_scale = new_scale;
      }
      *growth_tracker = 0;
    } else {
      *growth_tracker = successful;
    }
  }
}

5. 补充说明

这种从 Python 到 C++ 再到 CUDA 的调用链是 PyTorch 的通用设计模式:

  • Python API 层:提供高层易用接口。
  • C++ ATen 层:实现设备无关的核心逻辑。
  • CUDA 内核层:实现高性能的设备特定操作。

后记

2025年1月2日15点22分于上海, 在GPT4o大模型辅助下完成。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2270660.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

实践:事件循环

实践&#xff1a;事件循环 代码示例 console.log(1); setTimeout(() > console.log(2), 0); Promise.resolve(3).then(res > console.log(res)); console.log(4);上述的代码的输出结果是什么 1和4肯定优先输出&#xff0c;因为他们会立即方式堆栈的执行上下文中执行&am…

从零开始开发纯血鸿蒙应用之逻辑封装

从零开始开发纯血鸿蒙应用 一、前言二、逻辑封装的原则三、实现 FileUtil1、统一的存放位置2、文件的增删改查2.1、文件创建与文件保存2.2、文件读取2.2.1、读取内部文件2.2.2、读取外部文件 3、文件删除 四、总结 一、前言 应用的动态&#xff0c;借助 UI 响应完成&#xff0…

ESP32 I2S音频总线学习笔记(一):初识I2S通信与配置基础

文章目录 简介为什么需要I2S&#xff1f;关于音频信号采样率分辨率音频声道 怎样使用I2S传输音频&#xff1f;位时钟BCLK字时钟WS串行数据SD I2S传输模型I2S通信格式I2S格式左对齐格式右对齐格式 i2s基本配置i2s 底层API加载I2S驱动设置I2S使用的引脚I2S读取数据I2S发送数据卸载…

CSS 中 content换行符实现打点 loading 正在加载中的效果

我们动态加载页面内容的时候&#xff0c;经常会使用“正在加载中…”这几个字&#xff0c;基本上&#xff0c;后面的 3 个点都是静态的。静态的问题在于&#xff0c;如果网络不流畅&#xff0c;加载时间比较长&#xff0c;就会给人有假死的 感觉&#xff0c;但是&#xff0c;如…

25考研王道数据结构课后习题笔记

声明&#xff1a;以下内容来自于B栈知名up主–白话拆解数据结构 回答&#xff1a;为什么要做这个&#xff0c;因为我这个学期学完了数据结构&#xff0c;而且这个数据结构是408的重头&#xff0c;为什么选择25的&#xff0c;因为这个25考研刚刚结束&#xff0c;25相对成熟&…

小程序发版后,强制更新为最新版本

为什么要强制更新为最新版本&#xff1f; 在小程序的开发和运营过程中&#xff0c;强制用户更新到最新版本是一项重要的策略&#xff0c;能够有效提升用户体验并保障系统的稳定性与安全性。以下是一些主要原因&#xff1a; 1. 功能兼容 新功能或服务通常需要最新版本的支持&…

GRAPE——RLAIF微调VLA模型:通过偏好对齐提升机器人策略的泛化能力(含24年具身模型汇总)

前言 24年具身前沿模型大汇总 过去的这两年&#xff0c;工作之余&#xff0c;我狂写大模型与具身的文章&#xff0c;加之具身大火&#xff0c;每周都有各种朋友通过CSDN私我及我司「七月在线」寻求帮助/指导(当然&#xff0c;也欢迎各大开发团队与我司合作共同交付&#xff09…

0xc0000020错误代码怎么处理,Windows11、10坏图像错误0xc0000020的修复办法

“0xc0000020”是一种 Windows 应用程序错误代码&#xff0c;通常表明某些文件缺失或损坏。这可能是由于系统文件损坏、应用程序安装或卸载问题、恶意软件感染、有问题的 Windows 更新等原因导致的。 比如&#xff0c;当运行软件时&#xff0c;可能会出现类似“C:\xx\xxx.dll …

pycharm+anaconda创建项目

pycharmanaconda创建项目 安装&#xff1a; Windows下PythonPyCharm的安装步骤及PyCharm的使用-CSDN博客 详细Anaconda安装配置环境创建教程-CSDN博客 创建项目&#xff1a; 开始尝试新建一个项目吧&#xff01; 选择好项目建设的文件夹 我的项目命名为&#xff1a;pyth…

基于Pytorch和yolov8n手搓安全帽目标检测的全过程

一.背景 还是之前的主题&#xff0c;使用开源软件为公司搭建安全管理平台&#xff0c;从视觉模型识别安全帽开始。主要参考学习了开源项目 https://github.com/jomarkow/Safety-Helmet-Detection&#xff0c;我是从运行、训练、标注倒过来学习的。由于工作原因&#xff0c;抽空…

【PDF物流单据提取明细】批量PDF提取多个区域内容导出表格或用区域内容对文件改名,批量提取PDF物流单据单号及明细导出表格并改名的技术难点及小节

相关阅读及下载&#xff1a; PDF电子物流单据&#xff1a; 批量PDF提取多个区域局部内容重命名PDF或者将PDF多个局部内容导出表格&#xff0c;具体使用步骤教程和实际应用场景的说明演示https://mp.weixin.qq.com/s/uCvqHAzKglfr40YPO_SyNg?token720634989&langzh_CN扫描…

JavaWeb开发(五)Servlet-ServletContext

1. ServletContext 1.1. ServletContext简介 1.1.1. ServletContext定义 ServletContext即Servlet上下文对象&#xff0c;该对象表示当前的web应用环境信息。 1.1.2. 获取ServletContext对象: &#xff08;1&#xff09;通过ServletConfig的getServletContext()方法可以得到…

长时间序列预测算法---Informer

目录 一、传统的 Transformer 模型二、Informer原理2.1 Attention计算2.2 “积极”的Q筛选2.2.1 KL散度2.2.2 “懒惰”的q处理 2.3 Encoder结构2.4 Decoder结构2.4.1 Transformer的Decoder操作2.4.2 Informer的Decoder操作 2.5 Informer模型的改进 三、模型应用 时间序列相关参…

点击取消按钮,console出来数据更改了,页面视图没有更新

点击取消按钮&#xff0c;console出来数据更改了&#xff0c;页面视图没有更新 前言 实现效果&#xff1a;点击取消按钮&#xff0c;页面视图全部为空&#xff0c; 遇到的问题&#xff1a; 点击取消按钮&#xff0c;console出来数据更改了&#xff0c;SchemaJson 都是默认值啦…

RFID手持机与RFID工业平板在仓储物流管理系统中的选型

概述 随着物联网技术在仓储物流管理系统中的普及&#xff0c;RFID手持机与RFID工业平板作为基于RFID技术手持式读写器的两种重要终端设备形态&#xff0c;得到了广泛应用。尽管RFID手持机与RFID工业平板都具备读写 RFID标签的基本功能&#xff0c;使用场景较为类似&#xff0c…

UML之泛化、特化和继承

在UML&#xff08;统一建模语言&#xff09;中&#xff0c;泛化&#xff08;Generalization&#xff09;和特化&#xff08;Specialization&#xff09;是面向对象思想中继承&#xff08;Inheritance&#xff09;关系的重要概念&#xff0c;它们描述类与类&#xff08;或用例与…

vue 修改vant样式NoticeBar中的图标,不用插槽可以直接用图片

使用文档中是可以直接使用图片链接的 :left-icon"require(../../assets/newImages/noticeImg.png)" <html> .... <NoticeBarmode""color"#C6C6C6"background""v-if"global_info.site_bulletin":left-icon"r…

【漫话机器学习系列】028.CP

Mallows’ Cp&#xff1a;标准化公式解析与应用 Mallows’ Cp 是一种常用的模型选择工具&#xff0c;用于在一系列候选模型中权衡拟合度和复杂性&#xff0c;帮助我们选择性能最优的模型。本文将基于其标准化公式展开详细解析&#xff0c;并探讨其应用场景、实现方法、优点与局…

vs 2022 中xml 粘贴为Class 中,序列化出来的xml 的使用

上图是visual studio 2022 中使用的粘贴功能的菜单位置 在生成的xml 中&#xff0c;有些是类似如下类型的 [System.Serializable] [System.Xml.Serialization.XmlType] public class Item {private bool isVisibleField;private bool isVisibleFieldSpecified;[System.Xml.Se…

数据库自增 id 过大导致前端时数据丢失

可以看到&#xff0c;前端响应参数是没有丢失精度的 但是在接受 axios 请求参数时出现了精度丢失 解决方案一&#xff1a;改变 axios 字符编码 axios.defaults.headers[Content-Type] application/json;charsetUTF-8; 未解决 解决方案二&#xff1a;手动使用 json.parse() …