FasterTransformer 005 初始化:如何将参数传给模型?

news2025/1/10 19:10:33

cpp的例子

device_malloc

  • cpp没有用具体数值初始化 float *d_from_tensor = NULL;device_malloc(&d_from_tensor, batch_size * seq_len * hidden_dim);
  • https://github1s.com/NVIDIA/FasterTransformer/blob/v1.0/sample/cpp/transformer_fp32.cc#L35-L38 直接用的cudaMalloc
void device_malloc(float** ptr, int size) // cudaMalloc函数为什么是二级指针的解释https://blog.csdn.net/CaiYuxingzzz/article/details/121112273
{
  cudaMalloc((void**)ptr, sizeof(float) * size);
}

allocator

  • allocator用于分配attr_out_buf_
https://github1s.com/NVIDIA/FasterTransformer/blob/v1.0/fastertransformer/bert_encoder_transformer.h#L131-L135
buf_ = reinterpret_cast<DataType_*>(allocator_.malloc(sizeof(DataType_) * buf_size * 6));
  • 然后将这些参数和encoder_param打包成multi_head_init_param
    在初始化(encoder_transformer_->initialize)时传给attention_->initialize(multi_head_init_param);
    attention_->initialize则只需将传入的参数初始化给attention对象的参数,等forward时调用自己的参数
接口包含两个方法malloc,free
class IAllocator{
 public:
  virtual void* malloc(size_t size) const = 0;
  virtual void free(void* ptr) const = 0;
};
//AllocatorTypeyouenum class AllocatorType{CUDA, TF}; 用的应该是CUDA的
template<>
class Allocator<AllocatorType::CUDA> : public IAllocator{
  const int device_id_;
 public:
  Allocator(int device_id): device_id_(device_id){}

  void* malloc(size_t size) const {
    void* ptr = nullptr;
    int o_device = 0;
    check_cuda_error(get_set_device(device_id_, &o_device));
    check_cuda_error(cudaMalloc(&ptr, size));
    check_cuda_error(get_set_device(o_device));
    return ptr;
  }
  
  void free(void* ptr) const {
    int o_device = 0;
    check_cuda_error(get_set_device(device_id_, &o_device));
    check_cuda_error(cudaFree(ptr));
    check_cuda_error(get_set_device(o_device));
    return;
  }
};
fastertransformer::Allocator<AllocatorType::CUDA> allocator(0); // 0是device_id_

encoder_param

  • EncoderInitParam encoder_param; //init param here 包含参数的结构体,成员记录了GPU数据的地址

initialize

  BertEncoderTransformer<EncoderTraits_> *encoder_transformer_ = new 
    BertEncoderTransformer<EncoderTraits_>(allocator, batch_size, from_seq_len, to_seq_len, head_num, size_per_head);
  encoder_transformer_->initialize(encoder_param);

trt_plugin的例子

将数值放入vector

  • https://github1s.com/NVIDIA/FasterTransformer/blob/v1.0/sample/tensorRT/transformer_trt.cc#L108-L136
  • 先分配地址
    host_malloc(&h_attr_kernel_Q, hidden_dim * hidden_dim);
  • 然后进行赋值
    h_attr_kernel_Q[i] = 0.001f;
   std::vector<T* > layer_param;
    layer_param.push_back(h_attr_kernel_Q);
 将值打包
    params.push_back(layer_param);
  }

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  TRT_Transformer<T>* trt_transformer = new TRT_Transformer<T>(batch_size, seq_len, head_num, hidden_dim, layers);
  trt_transformer->build_engine(params);

  trt_transformer->do_inference(batch_size, h_from_tensor, h_attr_mask, h_transformer_out, stream);

  delete trt_transformer;
  • 构建TRT_Transformer时会调用算子插件,权重在void build_engine(std::vector<std::vector<T* > > &weights)时传入
    https://github1s.com/NVIDIA/FasterTransformer/blob/v1.0/fastertransformer/trt_plugin/trt_model.h#L75-L77
auto plugin = new TransformerPlugin<T>(
          hidden_dim_, head_num_, seq_len_, batch_size_, 
          point2weight(weights[i][0], hidden_dim_ * hidden_dim_),
  • 创建TransformerPlugin实例时会传入权重
TransformerPlugin(
    int hidden_dim, int head_num, int seq_len, int max_batch_size,
    const nvinfer1::Weights &w_attr_kernel_Q,...
  • 这里就是和cpp例子的不同了,其使用权重w_attr_kernel_Q
  • https://github1s.com/NVIDIA/FasterTransformer/blob/v1.0/fastertransformer/trt_plugin/bert_transformer_plugin.h#L103
cudaMallocAndCopy(d_attr_kernel_Q_, w_attr_kernel_Q, hidden_dim * hidden_dim);
  • cudaMallocAndCopy定义在https://github1s.com/NVIDIA/FasterTransformer/blob/v1.0/fastertransformer/trt_plugin/bert_transformer_plugin.h#L338-L352
    static void cudaMallocAndCopy(T *&dpWeight, const nvinfer1::Weights &w, int nValue) 
    {
      assert(w.count == nValue);
      check_cuda_error(cudaMalloc(&dpWeight, nValue * sizeof(T)));
      check_cuda_error(cudaMemcpy(dpWeight, w.values, nValue * sizeof(T), cudaMemcpyHostToDevice));

      T* data = (T*)malloc(sizeof(T) * nValue);
      cudaMemcpy(data, dpWeight, sizeof(T) * nValue, cudaMemcpyDeviceToHost);

    }
    static void cudaMallocAndCopy(T*&dpWeight, const T *&dpWeightOld, int nValue) 
    {
      check_cuda_error(cudaMalloc(&dpWeight, nValue * sizeof(T)));
      check_cuda_error(cudaMemcpy(dpWeight, dpWeightOld, nValue * sizeof(T), cudaMemcpyDeviceToDevice));
    }

cg

  • https://github.com/NVIDIA/TensorRT/blob/release/8.5/demo/Diffusion/models.py
    在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/644405.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【电子学会】2023年03月图形化四级 -- 绘制直尺

绘制直尺 编写一段程序&#xff0c;绘制一段7厘米的直尺。 1. 准备工作 &#xff08;1&#xff09;保留小猫角色&#xff0c;隐藏&#xff1b; &#xff08;2&#xff09;白色背景。 2. 功能实现 &#xff08;1&#xff09;点击绿旗&#xff0c;设置笔的颜色为红色&#…

事务和事务的隔离级别

一、事务 &#xff08;一&#xff09;为什么需要事务 事务是数据库管理系统&#xff08;DBMS&#xff09;执行过程中的一个逻辑单位&#xff08;不可再进行分割&#xff09;&#xff0c;由一个有限的数据库操作序列构成&#xff08;多个DML语句&#xff0c;select语句不包含事…

数字图像处理期末复习习题 SCUEC part1

1.在利用LoG算子做边缘检测的时候&#xff0c;作为一种经验法则&#xff0c;当滤波器空间参数为a7时&#xff0c;LoG滤波器空域模板大小应为 答&#xff1a;4343 理由是&#xff1a;n大于等于6a1 2.空间域方法主要分为灰度变换和空间滤波两类&#xff0c;灰度变换在图像的单…

【前端 - CSS】第 15 课 - 复合选择器

欢迎来到博主 Apeiron 的博客&#xff0c;祝您旅程愉快 &#xff01; 时止则止&#xff0c;时行则行。动静不失其时&#xff0c;其道光明。 目录 1、缘起 2、复合选择器 2.1、后代选择器 2.2、子代选择器 2.3、并集选择器 2.4、交集选择器&#xff08;了解&#xff09…

SpringBatch从入门到实战(三):父子Job和多步骤控制

一&#xff1a;Job嵌套 Job之前也可以嵌套&#xff0c;比如一个父Job封装多个已经存在的子Job。 Configuration public class ChildrenJobConfig {Autowiredprivate JobBuilderFactory jobBuilderFactory;Autowiredprivate StepBuilderFactory stepBuilderFactory;Beanpublic…

基础知识学习---牛客网C++面试宝典(八)操作系统--第三节

1、本栏用来记录社招找工作过程中的内容&#xff0c;包括基础知识学习以及面试问题的记录等&#xff0c;以便于后续个人回顾学习&#xff1b; 暂时只有2023年3月份&#xff0c;第一次社招找工作的过程&#xff1b; 2、个人经历&#xff1a; 研究生期间课题是SLAM在无人机上的应…

Golang每日一练(leetDay0096) 添加运算符、移动零

目录 282. 给表达式添加运算符 Expression Add Operators &#x1f31f;&#x1f31f;&#x1f31f; 283. 移动零 Move Zeroes &#x1f31f; &#x1f31f; 每日一练刷题专栏 &#x1f31f; Rust每日一练 专栏 Golang每日一练 专栏 Python每日一练 专栏 C/C每日一练 …

Cenos7 --- Redis下载和安装(Linux版本)

1.下载和安装 Download | Redis进入官网Download | Redis&#xff0c; 上边点击下载7.0.11,右键复制下载衔接 https://download.redis.io/releases/redis-7.0.2.tar.gz 1.weget获取 我这个安装包放在 /tools/installbags下 cd /tools/installbags wget https://download.red…

Java进阶 —— Java多线程编程笔记

❤ 作者主页&#xff1a;欢迎来到我的技术博客&#x1f60e; ❀ 个人介绍&#xff1a;大家好&#xff0c;本人热衷于Java后端开发&#xff0c;欢迎来交流学习哦&#xff01;(&#xffe3;▽&#xffe3;)~* &#x1f34a; 如果文章对您有帮助&#xff0c;记得关注、点赞、收藏、…

【头歌-Python】9.3 中英文词云绘制(project) 第1~3关

第1关&#xff1a;词云练习1 任务描述 本关任务&#xff1a;编写一个能制作词云的小程序。 相关知识 词云 词云&#xff0c;也叫文字云&#xff0c;是一种应用广泛的数据可视化方法。是过滤掉文本中大量的低频信息&#xff0c;形成“关键词云层”或“关键词渲染”&#xf…

基于VMWare组件安装Centos7.9

1.前提条件 使用VMware进行安装&#xff0c;VMware可以自行下载&#xff0c;需要介质(VMware和CentOS7.9)的同仁&#xff0c;请留言&#xff0c;我给你下载链接。 2.CentOS7.9安装 1.打开VMware&#xff0c;点击“新建虚拟机(N)...” 2.选择“典型” &#xff0c;点击“下一步…

基础知识学习---牛客网C++面试宝典(六)操作系统--第一节

1、本栏用来记录社招找工作过程中的内容&#xff0c;包括基础知识学习以及面试问题的记录等&#xff0c;以便于后续个人回顾学习&#xff1b; 暂时只有2023年3月份&#xff0c;第一次社招找工作的过程&#xff1b; 2、个人经历&#xff1a; 研究生期间课题是SLAM在无人机上的应…

A100 GPU服务器安装GPU驱动教程

大家好,我是爱编程的喵喵。双985硕士毕业,现担任全栈工程师一职,热衷于将数据思维应用到工作与生活中。从事机器学习以及相关的前后端开发工作。曾在阿里云、科大讯飞、CCF等比赛获得多次Top名次。现为CSDN博客专家、人工智能领域优质创作者。喜欢通过博客创作的方式对所学的…

【OpenCV DNN】Flask 视频监控目标检测教程 06

欢迎关注『OpenCV DNN Youcans』系列&#xff0c;持续更新中 【OpenCV DNN】Flask 视频监控目标检测教程 06 3.6 OpenCVFlask实时监控和视频播放cvFlask06 项目的文件树cvFlask06 项目的程序文件cvFlask06 项目的网页模板cvFlask06 项目的运行 本系列从零开始&#xff0c;详细…

chatgpt赋能python:Python排序算法实现及其应用

Python排序算法实现及其应用 排序是计算机科学中最基础也是最常用的算法之一。在数据分析、数据挖掘和机器学习等领域&#xff0c;排序算法有着广泛的应用。Python作为一种流行的编程语言&#xff0c;在排序方面具有一定的优势。本文将介绍一些常见的Python排序算法实现以及应…

有趣的图(三)(57)

小朋友们好&#xff0c;大朋友们好&#xff01; 我是猫妹&#xff0c;一名爱上Python编程的小学生。 和猫妹学Python&#xff0c;一起趣味学编程。 今日主题 咱们之前分别学习了图的基本概念&#xff0c;和图的深度优先遍历算法dfs。 你学会了吗&#xff1f; 咱们今天要学…

Linux系统的tty架构及UART驱动详解

​一、模块硬件学习 1.1. Uart介绍 通用异步收发传输器&#xff08;Universal Asynchronous Receiver/Transmitter)&#xff0c;通常称为UART&#xff0c;是一种异步收发传输器&#xff0c;是电脑硬件的一部分。它将要传输的资料在串行通信与并行通信之间加以转换。 作为把并…

面试问题总结----C/C++部分

1、本栏用来记录社招找工作过程中的内容,包括基础知识学习以及面试问题的记录等,以便于后续个人回顾学习; 暂时只有2023年3月份,第一次社招找工作的过程; 2、个人经历: 研究生期间课题是SLAM在无人机上的应用,有接触SLAM、Linux、ROS、C/C++、DJI OSDK等; 3、参加工作后…

C++程序流程结构

目录 程序流程结构 一、选择结构 1.1 If语句 1.2 三目运算符 1.3 switch语句 二、循环结构 2.1 while 循环语句 2.2 do…while循环 2.3 for循环 2.4 嵌套循环 三、跳转语句 3.1 break语句 3.2 continue 语句 3.3 goto语句 程序流程结构 C/C支持最基本的三种程…

20230623在WIN10安装PROTEL DXP2004(STEP-BY-STEP)

20230623在WIN10安装PROTEL DXP2004&#xff08;STEP-BY-STEP&#xff09; https://xiazai.zol.com.cn/detail/43/428470.shtml Protel DXP 2004 https://www.onlinedown.net/soft/580490.htm Protel DXP 2004 DXP2004 安装步骤 Failed To load Parallel Port Driver Welcom…