NVIDIA TensorRT Model Optimizer

news2024/11/28 17:59:04

NVIDIA TensorRT Model Optimizer

NVIDIA TensorRT 模型优化器(ModelOpt)是一个用于优化 AI 模型的库,它通过量化和稀疏性技术减小模型大小并加速推理,同时保持模型性能。ModelOpt 支持多种量化格式和算法,包括 FP8、INT8、INT4,并提供 Python API 以实现轻松优化。它还支持后训练量化和量化感知训练。此外,ModelOpt 提供了稀疏性 API,以减少模型的内存占用,支持 NVIDIA 的稀疏模式和稀疏化方法,并推荐使用微调来最小化精度损失。

1 Install

1-1 System requirements

在这里插入图片描述

1-2 pip install

pip install "nvidia-modelopt[all]" --no-cache-dir --extra-index-url https://pypi.nvidia.com

1-3 Check installation

python -c "import modelopt.torch.quantization.extensions as ext; print(ext.cuda_ext); print(ext.cuda_ext_fp8)"

2 Quantization

2-1 选择量化方法时需要考虑的权衡因素

在这里插入图片描述

2-2 PyTorch Quantization

2-2-1 Post Training Quantization (PTQ)

可以通过在少量训练或评估数据(通常是128-512个样本)上进行简单的校准来实现后训练量化(PTQ)一个PyTorch模型。

import modelopt.torch.quantization as mtq

# Setup the model
model = get_model()

# Select quantization config
config = mtq.INT8_SMOOTHQUANT_CFG

# Quantization need calibration data. Setup calibration data loader
# An example of creating a calibration data loader looks like the following:
data_loader = get_dataloader(num_samples=calib_size)

# Define forward_loop. Please wrap the data loader in the forward_loop
def forward_loop(model):
    for batch in data_loader:
        model(batch)

# Quantize the model and perform calibration (PTQ)
model = mtq.quantize(model, config, forward_loop)

# Print quantization summary after successfully quantizing the model with mtq.quantize
# This will show the quantizers inserted in the model and their configurations
mtq.print_quantization_summary(model)

torch.onnx.export(model, sample_input, onnx_file)
2-2-2 Quantization-aware Training (QAT)

量化感知训练(QAT)是一种微调量化模型的技术,用于恢复由于量化导致的模型质量下降。尽管QAT比后训练量化(PTQ)需要更多的计算资源,但它在恢复模型质量方面非常有效。

import modelopt.torch.quantization as mtq

# Select quantization config
config = mtq.INT8_DEFAULT_CFG

# Define forward loop for calibration
def forward_loop(model):
    for data in calib_set:
        model(data)

# QAT after replacement of regular modules to quantized modules
model = mtq.quantize(model, config, forward_loop)

# Fine-tune with original training pipeline
# Adjust learning rate and training duration
train(model, train_loader, optimizer, scheduler, ...)

建议对原始训练epoch的10%进行量化感知训练(QAT)。对于大型语言模型(LLMs),即使对原始预训练持续时间的不到1%进行QAT微调也通常足以恢复模型质量。

2-2-3 Storing and loading quantized model

mto.modelopt_state() 提供了模型的量化器状态。这些量化器状态可以使用 torch.save 来保存。

import modelopt.torch.opt as mto

# Save quantizer states
torch.save(mto.modelopt_state(model), "modelopt_state.pt")

# Save model weights using torch.save or custom check-pointing function
# trainer.save_model("model.pt")
torch.save(model.state_dict(), "model.pt")

要恢复一个量化模型,首先使用 mto.restore_from_modelopt_state 来恢复量化器状态。在量化器状态恢复后,加载模型的权重。

import modelopt.torch.opt as mto

# Initialize the un-quantized model
model = ...

# Load quantizer states
model = mto.restore_from_modelopt_state(model, torch.load("modelopt_state.pt"))

# Load model weights using torch.load or custom check-pointing function
# model.from_pretrained("model.pt")
model.load_state_dict(torch.load("model.pt"))
2-2-4 Advanced Topics
TensorQuantizer

ModelOpt 的 mtq.quantize() 方法会在模型层(如线性层、卷积层等)中插入 TensorQuantizer(量化模块),并修改它们的前向传播方法来执行量化。

要创建 TensorQuantizer 实例,您需要指定 QuantDescriptor,该描述符描述了量化参数,如quantization bits, axis等。

from modelopt.torch.quantization.tensor_quant import QuantDescriptor
from modelopt.torch.quantization.nn import TensorQuantizer

# Create quantizer descriptor
quant_desc = QuantDescriptor(num_bits=8, axis=(-1,), unsigned=True)

# Create quantizer module
quantizer = TensorQuantizer(quant_desc)

quant_x = quantizer(x)  # Quantize input x
Customize quantizer config

ModelOpt 在常见的层中插入了输入量化器、权重量化器和输出量化器,但默认情况下禁用了输出量化器。用户如果想要自定义默认的量化器配置,可以使用通配符或过滤器函数匹配来更新提供给 mtq.quantize 的配置字典。

# Select quantization config
config = mtq.INT8_DEFAULT_CFG.copy()
config["quant_cfg"]["*.bmm.output_quantizer"] = {
    "enable": True
}  # Enable output quantizer for bmm layer

# Perform PTQ/QAT;
model = mtq.quantize(model, config, forward_loop)
Custom quantized module and quantizer placement

modelopt.torch.quantization 提供了一组默认的量化模块(详见 modelopt.torch.quantization.nn.modules 以获取详细列表)和量化器放置规则(输入、输出和权重量化器)。但是也允许自定义的量化模块和/或自定义量化器的放置位置。

from modelopt.torch.quantization.nn import TensorQuantizer

class QuantLayerNorm(nn.LayerNorm):
    def __init__(self, normalized_shape):
        super().__init__(normalized_shape)
        self._setup()

    def _setup(self):
        # Method to setup the quantizers
        self.input_quantizer = TensorQuantizer()
        self.weight_quantizer = TensorQuantizer()

    def forward(self, input):
        # You can customize the quantizer placement anywhere in the forward method
        input = self.input_quantizer(input)
        weight = self.weight_quantizer(self.weight)
        return F.layer_norm(input, self.normalized_shape, weight, self.bias, self.eps)


import modelopt.torch.quantization as mtq

# Register the custom quantized module
mtq.register(original_cls=nn.LayerNorm, quantized_cls=QuantLayerNorm)

# Perform PTQ
# nn.LayerNorm modules in the model will be replaced with the QuantLayerNorm module
model = mtq.quantize(model, config, forward_loop)
Fast evaluation

Weight folding 避免了在每次推理前向传递过程中权重的重复量化,并加速了eval 过程。

# Fold quantizer together with weight tensor
mtq.fold_weight(quantized_model)

# Run model evaluation
user_evaluate_func(quantized_model)

3 Sparsity

3-1 Post-Training Sparsification

后训练稀疏化技术允许您将已经训练好的密集模型转换为更高效的稀疏模型,而无需经过重新训练的过程。这一过程可以通过调用 mts.sparsify API 轻松实现,它通过接受稀疏配置和稀疏格式的参数来输出一个优化后的稀疏模型。

提供的稀疏配置是一个详细的字典,它定义了模型中哪些层需要进行稀疏化处理,并且可以选择性地包含一个数据加载器,用于在数据驱动的稀疏化方法(如 SparseGPT)中进行校准。

mts.sparsify() 提供了两种稀疏化方法的支持:NVIDIA ASP 用于基于权重幅度的稀疏化,而 SparseGPT 用于更高级的数据驱动稀疏化。

import torch
from transformers import AutoModelForCausalLM
import modelopt.torch.sparsity as mts

# User-defined model
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6b")

# Configure and convert for sparsity
sparsity_config = {
    # data_loader is required for sparsity calibration
    "data_loader": calib_dataloader,
    "collect_func": lambda x: x,
}
sparse_model = mts.sparsify(
    model,
    "sparsegpt",  # or "sparse_magnitude"
    config=sparsity_config,
)
Save and restore the sparse model

mto.save() 将保存模型的 state_dict,以及稀疏掩码和元数据,以便稍后正确重新创建稀疏模型。

mto.save(sparse_model, "modelopt_sparse_model.pth")

mto.restore() 将恢复模型的 state_dict,以及每个稀疏模块的稀疏掩码和元数据。普通的 PyTorch 模块将被转换为稀疏模块。当访问模型权重时,稀疏掩码将自动被应用。

import modelopt.torch.opt as mto

# Re-initialize the original, unmodified model
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6b")

# Restore the sparse model and metadata.
sparse_model = mto.restore(model, "modelopt_sparse_model.pth")

mts.export() 将稀疏模型导出为普通的 PyTorch 模型。稀疏掩码将被应用到模型权重上,并且所有稀疏相关的元数据都将被移除。导出后,在后续的微调过程中将不再强制执行稀疏性。如果你想继续微调,请不要导出模型。

3-2 Sparsity Concepts

3-2-1 Structured and Unstructured Sparsity

权重稀疏性通过将模型中的部分权重设为零来优化模型。它分为两类:

  • 非结构化稀疏性:零权重在权重矩阵中随机分布,灵活但可能导致在GPU等硬件上利用率低。
  • 结构化稀疏性:零权重在权重矩阵中有规律地分布,内存访问更高效,支持更高的数学吞吐量。可以通过强制特定的稀疏模式实现。
    N:M 稀疏性
3-2-2 N:M Sparsity

N:M 稀疏性是一种特殊的结构化稀疏模式,每个由M个连续元素组成的块中最多有N个非零元素。它在GPU架构上有效,提供以下优势:

  • 降低内存带宽需求
  • 提高数学吞吐量(如2:4稀疏模式在稀疏张量核心上允许2倍的数学吞吐量)
3-2-3 Sparsification algorithm

实现权重稀疏性的方法有多种,如基于幅度的稀疏性(保留M个元素块中的N个最大元素)和数据驱动的稀疏性(如Optimal Brain Surgeon)。NVIDIA的模型优化器支持基于幅度的(NVIDIA ASP)和数据驱动的稀疏性(SparseGPT)。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1670679.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

深度学习知识点全面总结

ChatGPT 深度学习是一种使用神经网络来模拟人脑处理数据和创建模式的机器学习方法。下面是深度学习的一些主要知识点的总结: 1. 神经网络基础: - 神经元:基本的计算单元,模拟人脑神经元。 - 激活函数:用于增加神…

力扣HOT100 - 763. 划分字母区间

解题思路&#xff1a; class Solution {public List<Integer> partitionLabels(String s) {int[] last new int[26];int len s.length();for (int i 0; i < len; i) {last[s.charAt(i) - a] i;//记录字母最远的下标}List<Integer> partition new ArrayList…

大数据在IT行业的应用与发展趋势及IT行业的现状与未来

大数据在IT行业中的应用、发展趋势及IT行业的现状与未来 一、引言 随着科技的飞速发展&#xff0c;大数据已经成为IT行业的重要驱动力。从数据收集、存储、处理到分析&#xff0c;大数据技术为各行各业带来了深远的影响。本文将详细探讨大数据在IT行业中的应用、发展趋势&#…

做抖店如何提高与达人合作的几率?有效筛选+有效推品

我是王路飞。 总是有很多新手商家&#xff0c;找我吐槽&#xff0c;抖音上的达人特别不好找&#xff0c;好不容易加上了&#xff0c;要么是发消息不回复&#xff0c;要么是寄样后就没下文了。 虽然一直都说找达人带货玩法比较简单&#xff0c;但也离不开电商的基本逻辑&#…

【k8s多集群管理平台开发实践】九、client-go实现nginx-ingress读取列表、创建ingress、读取更新yaml配置

文章目录 简介 一.k8s的ingress列表1.1.controllers控制器代码1.2.models模型代码 二.创建ingress2.1.controllers控制器代码2.2.models模分代码 三.读取和更新ingress的yaml配置3.1.controllers控制器代码3.2.models模型代码 四.路由设置4.1.路由设置 五.前端代码5.1.列表部分…

低血压怎么办?低血压患者应该如何调理?

点击文末领取揿针的视频教程跟直播讲解 低血压在生活中也是一种常见的问题&#xff0c;低血压的朋友常有头晕眼黑、冒冷汗等症状&#xff0c;对工作学习产生了一定的影响。 什么是低血压呢&#xff1f; 低血压是指体循环动脉压力低于正常的状态。即血压低于正常水平。 ​一般…

LearnOpenGL(十四)之模型加载

Model类的结构&#xff1a; class Model {public:/* 函数 */Model(char *path){loadModel(path);}void Draw(Shader shader); private:/* 模型数据 */vector<Mesh> meshes;string directory;/* 函数 */void loadModel(string path);void processNode(aiNode …

初识指针(5)<C语言>

前言 在前几篇文章中&#xff0c;已经介绍了指针一些基本概念、用途和一些不同类型的指针&#xff0c;下文将介绍某些指针类型的运用。本文主要介绍函数指针数组、转移表&#xff08;函数指针的用途&#xff09;、回调函数、qsort使用举例等。 函数指针数组 函数指针数组即每个…

京东h5st4.7逆向分析

声明 本文章中所有内容仅供学习交流使用&#xff0c;不用于其他任何目的&#xff0c;不提供完整代码&#xff0c;抓包内容、敏感网址、数据接口等均已做脱敏处理&#xff0c;严禁用于商业用途和非法用途&#xff0c;否则由此产生的一切后果均与作者无关&#xff01; 本文章未…

信息量、熵、KL散度、交叉熵概念理解

信息量、熵、KL散度、交叉熵概念理解 (1) 信息量 信息量是对事件的不确定性的度量。 假设我们听到了两件事&#xff0c;分别如下&#xff1a;事件A&#xff1a;巴西队进入了世界杯决赛圈。 事件B&#xff1a;中国队进入了世界杯决赛圈。仅凭直觉来说&#xff0c;显而易见事件…

SpringAMQP-消息转换器

这边发送消息接收消息默认是jdk的序列化方式&#xff0c;发送到服务器是以字节码的形式&#xff0c;我们看不懂也很占内存&#xff0c;所以我们要手动设置一下 我这边设置成json的序列化方式&#xff0c;注意发送方和接收方的序列化方式要保持一致 不然回报错。 引入依赖&#…

STM32_HAL_TIM_1介绍

1.F1的定时器类型&#xff08;高的拥有低级的全部功能&#xff09; 高级定时器&#xff08;TIM1和TIM8&#xff09;&#xff1a; 16位自动重装载计数器。支持多种工作模式&#xff0c;包括中心对齐模式、边沿对齐模式等。可以产生7个独立的通道&#xff0c;用于PWM、输出比较、…

Cosmo Bunny Girl

可爱的宇宙兔女郎的3D模型。用额外的骨骼装配到Humanoid上,Apple混合了形状。完全模块化,包括不带衣服的身体。 技术细节 内置,包括URP和HDRP PDF。还包括关于如何启用URP和HDRP的说明。 LOD 0:面:40076,tris 76694,verts 44783 装配了Humanoid。添加到Humanoid中的其他…

测试用例编写规范

1.1目的 统一测试用例编写的规范&#xff0c;为测试设计人员提供测试用例编写的指导&#xff0c;提高编写的测试用例的可读性&#xff0c;可执行性、合理性。为测试执行人员更好执行测试&#xff0c;提高测试效率&#xff0c;最终提高公司整个产品的质量。 1.2使用范围 适用…

数字人实训室助推元宇宙人才培养

如今&#xff0c;全身动作捕捉设备已经大量应用在影视、动画、游戏领域&#xff0c;在热门的元宇宙内容领域中&#xff0c;全身动作捕捉设备逐步发挥着重要的作用&#xff0c;在包括体育训练、数字娱乐虚拟偶像、虚拟主持人、非物质文化遗产保护等等场景&#xff0c;数字人实训…

第5章 处理GET请求参数

1 什么是GET请求参数 表单GET请求参数是指在HTML表单中通过GET方法提交表单数据时所附带的参数信息。在HTML表单中&#xff0c;可以通过表单元素的name属性来指定表单字段的名称&#xff0c;通过表单元素的value属性来指定表单字段的值。当用户提交表单时&#xff0c;浏览器会将…

【数据结构】有关栈和队列相互转换问题

文章目录 用队列实现栈思路实现 用栈实现队列思路实现 用队列实现栈 Leetcode-225 用队列实现栈 思路 建立队列的基本结构并实现队列的基本操作 这部分这里就不多说了&#xff0c;需要的可以看笔者的另一篇博客 【数据结构】队列详解(Queue) 就简单带过一下需要实现的功能 …

金融业开源软件应用 评估规范

金融业开源软件应用 评估规范 1 范围 本文件规定了金融机构在应用开源软件时的评估要求&#xff0c;对开源软件的引入、维护和退出提出了实现 要求、评估方法和判定准则。 本文件适用于金融机构对应用的开源软件进行评估。 2 规范性引用文件 下列文件中的内容通过文中的规范…

数据科学:使用Optuna进行特征选择

大家好&#xff0c;特征选择是机器学习流程中的关键步骤&#xff0c;在实践中通常有大量的变量可用作模型的预测变量&#xff0c;但其中只有少数与目标相关。特征选择包括找到这些特征的子集&#xff0c;主要用于改善泛化能力、助力推断预测、提高训练效率。有许多技术可用于执…

【kettle012】kettle访问FTP服务器文件并处理数据至PostgreSQL(已更新)

1.一直以来想写下基于kettle的系列文章,作为较火的数据ETL工具,也是日常项目开发中常用的一款工具,最近刚好挤时间梳理、总结下这块儿的知识体系。 2.熟悉、梳理、总结下FTP服务器相关知识体系 3.欢迎批评指正,跪谢一键三连! kettle访问FTP服务器文件并处理数据至PostgreS…