(含代码)利用TensorRT的8位PTQ将Stable Diffusion速度提高 2 倍

news2024/11/23 7:31:16

利用TensorRT的8位PTQ将Stable Diffusion速度提高 2 倍

在这里插入图片描述

在生成人工智能的动态领域中,扩散模型脱颖而出,成为生成带有文本提示的高质量图像的最强大的架构。 像稳定扩散这样的模型已经彻底改变了创意应用。

然而,由于需要迭代去噪步骤,扩散模型的推理过程可能需要大量计算。 这对于努力实现最佳端到端推理速度的公司和开发人员提出了重大挑战。

从 NVIDIA TensorRT 9.2.0 开始,我们开发了一流的量化工具包,具有改进的 8 位(FP8 或 INT8)训练后量化 (PTQ: Post-Training Quantization),可显着加快 NVIDIA 硬件上的扩散部署,同时保持图像质量 。 TensorRT 的 8 位量化功能已成为许多生成型 AI 公司的首选解决方案,特别是创意视频编辑应用程序的领先提供商。

在这篇文章中,我们讨论 TensorRT 与 Stable Diffusion XL 的性能。 我们介绍了使 TensorRT 成为低延迟稳定扩散推理的首选的技术优势。 最后,我们演示如何使用 TensorRT 通过几行更改来加速模型。

性能指标

与在 FP16 中运行的本机 PyTorch 的 torch.compile 相比,用于扩散模型的 NVIDIA TensorRT INT8 和 FP8 量化方案在 NVIDIA RTX 6000 Ada GPU 上实现了 1.72 倍和 1.95 倍的加速。 FP8 相对于 INT8 的额外加速主要归因于多头注意力 (MHA) 层的量化。 使用 TensorRT 8 位量化可以增强生成式 AI 应用程序的响应能力并降低推理成本。

在这里插入图片描述

除了加速推理之外,TensorRT 8 位量化还擅长保持图像质量。 通过专有的量化技术,它生成与原始 FP16 图像非常相似的图像。 我们将在本文后面介绍这些技术。

在这里插入图片描述

TensorRT 解决方案:克服推理速度挑战

尽管 PTQ 被认为是减少内存占用并加速许多 AI 任务推理的首选压缩方法,但它在扩散模型上并不能开箱即用。 扩散模型具有独特的多时间步去噪过程,并且噪声估计网络在每个时间步的输出分布可能会有很大变化。 这使得简单的 PTQ 校准方法不适用。

在现有技术中,SmoothQuant 作为一种流行的 PTQ 方法脱颖而出,可为 LLM 实现 8 位权重、8 位激活 (W8A8) 量化。 其主要创新在于解决激活异常值的方法,通过数学上等效的变换将量化挑战从激活转移到权重。

尽管它很有效,但用户在 SmoothQuant 中手动定义参数时经常遇到困难。 实证研究还表明,SmoothQuant 难以适应不同的图像特征,限制了其在现实场景中的灵活性和性能。 此外,其他现有的扩散模型量化技术仅针对单个版本的扩散模型量身定制,而用户正在寻找一种可以加速各种版本模型的通用方法。

为了应对这些挑战,NVIDIA TensorRT 开发了复杂的细粒度调整管道,以确定 SmoothQuant 模型每一层的最佳参数设置。 您可以根据特征图的具体特征开发自己的调整管道。 与基于客户需求的现有方法相比,此功能使 TensorRT 量化能够获得卓越的图像质量,保留原始图像的丰富细节。

根据 Q-Diffusion 的研究结果,激活分布在不同的时间步长内可能会有很大差异,并且图像的形状和整体风格主要在去噪过程的初始阶段确定。 因此,使用传统的最大校准会导致初始步骤中出现较大的量化误差。

在这里插入图片描述

相反,我们有选择地使用所选步骤范围中的最小量化缩放因子,因为我们发现激活中的异常值对最终图像质量并不那么重要。 这种量身定制的方法,我们将其命名为“百分比定量”,重点关注步长范围的重要百分位。 它使 TensorRT 能够生成与原始 FP16 精度生成的图像几乎相同的图像。

在这里插入图片描述

使用 TensorRT 8 位量化加速扩散模型

/NVIDIA/TensorRT GitHub 存储库现在托管端到端、SDXL、8 位推理管道,提供即用型解决方案以在 NVIDIA GPU 上实现优化的推理速度。

运行单个命令即可使用 Percentile Quant 生成图像,并使用 demoDiffusion 测量延迟。 在本节中,我们使用 INT8 作为示例,但 FP8 的工作流程基本相同。

python demo_txt2img_xl.py "enchanted winter forest with soft diffuse light on a snow-filled day" --version xl-1.0 --onnx-dir onnx-sdxl --engine-dir engine-sdxl --int8 --quantization-level 3

以下是该命令所涉及的主要步骤的概述:

  • 校准
  • 导出 ONNX
  • 构建 TensorRT 引擎

校准

校准是量化过程中计算目标精度范围的步骤。 目前,TensorRT 中的量化功能封装在 nvidia-ammo 中,该依赖项已包含在 TensorRT 8 位量化示例中。

# Load the SDXL-1.0 base model from HuggingFace
import torch
from diffusers import DiffusionPipeline
base = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True
)
base.to("cuda")
 
# Load calibration prompts:
from utils import load_calib_prompts
cali_prompts = load_calib_prompts(batch_size=2,prompts="./calib_prompts.txt")
 
# Create the int8 quantization recipe
from utils import get_percentilequant_config
quant_config = get_percentilequant_config(base.unet, quant_level=3.0, percentile=1.0, alpha=0.8)
 
# Apply the quantization recipe and run calibration  
import ammo.torch.quantization as atq 
quantized_model = atq.quantize(base.unet, quant_config, forward_loop)
 
# Save the quantized model
import ammo.torch.opt as ato
ato.save(quantized_model, 'base.unet.int8.pt')

导出 ONNX

获得量化模型检查点后,您可以导出 ONNX 模型。

# Prepare the onnx export  
from utils import filter_func, quantize_lvl
base.unet = ato.restore(base.unet, 'base.unet.int8.pt')
quantize_lvl(base.unet, quant_level=3.0)
atq.disable_quantizer(base.unet, filter_func) # `filter_func` is used to exclude layers you don't quantize
  
# Export the ONNX model
from onnx_utils import ammo_export_sd
base.unet.to(torch.float32).to("cpu")
ammo_export_sd(base, 'onnx_dir', 'stabilityai/stable-diffusion-xl-base-1.0')

构建 TensorRT 引擎

使用 INT8 UNet ONNX 模型,您可以构建 TensorRT 引擎。

trtexec --onnx=./onnx_dir/unet.onnx --shapes=sample:2x4x128x128,timestep:1,encoder_hidden_states:2x77x2048,text_embeds:2x1280,time_ids:2x6 --fp16 --int8 --builderOptimizationLevel=4 --saveEngine=unetxl.trt.plan

总结

在生成式人工智能时代,拥有优先考虑易用性的推理解决方案至关重要。 借助 NVIDIA TensorRT,您可以通过其专有的 8 位量化技术无缝实现高达 2 倍的推理速度加速,同时确保图像质量不受影响,从而实现卓越的用户体验。

TensorRT 对平衡速度和质量的承诺凸显了其作为加速 AI 应用程序的领先选择的地位,使您能够轻松交付尖端解决方案。

我将在 NVIDIA GTC 大会期间为大家带来免费中文在线解读:
NVIDIA CUDA 最新特性以及生成式 AI 相关内容,包括 Stable Diffusion 模型部署实践,以及介绍用于视觉内容生成的 Edify 模型,点击链接了解详情并注册参会:

https://www.nvidia.cn/gtc-global/session-catalog/?search=WP62435%20WP62832%20WP62400&ncid=ref-dev-945313#/

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1513161.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

vulhub靶场-Jangow

下载&部署 下载 下载链接: https://download.vulnhub.com/jangow/jangow-01-1.0.1.ova 点击下载即可 部署 因为源虚拟机是从virtualbox导出的,为了避免繁琐的操作步骤,用virtualbox来导入 virtualbox下载地址: https:…

2021年1月1日起四川启动食品经营许可与备案电子证书办理

12月25日,四川省市场监督管理局、四川省大数据中心发布《关于启用食品经营许可登记电子证书的公告》(以下简称《公告》)。 《公告》显示,为贯彻落实《食品经营许可证管理办法》、《四川省小食品作坊、小经营店、商贩管理规定》和国…

常见面试题之计算机网络

1. OSI 五层模型(或七层模型)是什么,每一层的作用是什么 应用层:又可细分为应用层、表示层、会话层。其中应用层主要做的工作就是为应用程序提供服务,常见的协议为 HTTP、HTTPS、DNS等;表示层主要做的工作…

PXE自动化安装操作系统

一、PXE基本介绍 PXE,全名Pre-boot Execution Environment,预启动执行环境; 通过网络接口启动计算机,不依赖本地存储设备(如硬盘)或本地已安装的操作系统; 由Intel和Systemsoft公司于1999年9月…

通过一篇文章让你了解什么是函数栈帧

函数栈帧的创建和销毁 前言一、什么是函数栈帧二、 理解函数栈帧能解决什么问题三、 函数栈帧的创建和销毁解析3.1 什么是栈3.2 认识相关寄存器和汇编指令相关寄存器eaxebxebpespeip 相关汇编命令 3.3 解析函数栈帧的创建和销毁3.3.1 预备知识3.3.2 函数的调用堆栈3.3.4 准备环…

*Javaweb -- MyBatis*

一:介绍: 1.MyBatis是一个优秀的 ①持久层 ②框架,用于简化JDBC的开发! ①:JAVAEE有三层的结构:表现层, 业务层, 持久层. 表现层代表的是页面的展示,业务层则指的是对于相关逻辑的处理, 而持久层, 指的则是对于数据进行持久化,保存在数据库当中. 持久层具体的来说就是负责…

宏碁又遭网络袭击,菲律宾分公司大量数据被盗

近日,宏碁(Acer)菲律宾公司方面证实,管理该公司员工考勤数据的第三方供应商遭遇网络攻击,部分员工数据被盗。 宏碁,中国台湾计算机硬件和电子产品制造商,以其电脑产品极具性价比而闻名。 据悉&a…

生成式AI来袭,FOSS全闪对象存储应时而生

AI大模型正飞速跃进,从引领文本生成革命的ChatGPT到开创文生视频新纪元的Sora,多模态交互技术连续迭代,促进了智算中心的快速落地。在这一过程中,算力的迅猛增长对存储系统提出了更高的要求和挑战。为满足这些日益增长的需求&…

3_springboot_shiro_jwt_多端认证鉴权_Redis缓存管理器.md

1. 什么是Shiro缓存管理器 上一章节分析完了Realm是怎么运作的,自定义的Realm该如何写,需要注意什么。本章来关注Realm中的一个话题,缓存。再看看 AuthorizingRealm 类继承关系 其中抽象类 CachingRealm ,表示这个Realm是带缓存…

【Linux】Centos7上安装MySQL5.7

目录 1.下载安装包2. 上传安装包3.将 mysql 解压到/usr/local/4.重命名5.创建mysql用户及用户组6. 进入 mysql 目录修改权限7. 安装依赖库8. 执行安装脚本9. 复制启动脚本到资源目录10. 拷贝 my.cnf,并赋予权限11. 配置环境变量12. 启动 mysqld13. 登录 MySQL&#…

深度解析:如何运用山海鲸可视化软件制作高效销售数据看板

在数字化时代,数据可视化已经成为企业决策和运营的重要工具。作为一名长期使用山海鲸可视化软件的资深用户,我深知其在制作销售数据可视化看板方面的优势。今天,我想分享一些我在使用山海鲸可视化软件制作销售数据可视化看板过程中的经验和感…

探索:C++继承中虚表与虚基表的内存存储

探讨:菱形虚拟继承的虚基表和虚表 在继承和多态里,总是能听到虚表、虚基表这样的词汇,没有洞悉其根本的人很容易将它们混淆,因此,我们对这两个“虚”“表”进行实践,来更好地理解它们。 通俗些说&#xf…

哪些行业实操会用到PMP的知识?

首先说项目管理适合那些行业。 项目管理覆盖的行业可以说非常广了,就我知道的,医疗啊,互联网啊,机械啊,建筑啊,金融啊,汽车啊,零售啊、广告啊等各行各业都是需要项目管理人员的。 …

H5简约星空旋转引导页源码

源码名称:H5简约星空旋转引导页 源码介绍:一款带有星空旋转背景特效的源码,带有四个按钮 需求环境:H5 下载地址: https://www.changyouzuhao.cn/11655.html

32x4点阵式LCD驱动芯片/抗干扰段码屏驱动/仪器仪表液晶驱动IC- VK1C21A/B SSOP48/LQFP48 COG

产品型号:VK1C21A/B 产品品牌:永嘉微电/VINKA 封装形式:SSOP48/LQFP48 可定制裸片:DICE(COB邦定片);COG(邦定玻璃用) 工程服务,技术支持! 概述: VK1C21A/B是一个点阵式存储映射…

一个简单的微信小程序表单提交样式模板

没什么东西&#xff0c;只是方便自己直接复制使用 .wxml <view class"box"><form bindsubmit"formSubmit"><view class"form-item"><text class"head">姓名&#xff1a;</text><input class"…

目标跟踪SORT算法原理浅析

SORT算法 Simple Online and Realtime Tracking(SORT)是一个非常简单、有效、实用的多目标跟踪算法。在SORT中&#xff0c;仅仅通过IOU来进行匹配虽然速度非常快&#xff0c;但是ID switch依然非常严重。 SORT最大特点是基于Faster RCNN的目标检测方法&#xff0c;并利用卡尔…

阿里又又发布了一个“AI神器”

阿里给“打工”朋友送上“节日礼物” 六一儿童节当天&#xff0c;阿里就给所有“打工”的大朋友送上了一份“节日礼物” 6月1日上午&#xff0c;阿里云发布了面向音视频内容的AI新品“通义听悟”&#xff0c;并正式公测 通义千问、通义听悟 这哥俩现在所处环境不同&#xff0…

Midjourney封禁Stability AI:恶意爬取数据,致服务器瘫痪24小时

这两家 AI 图像生成公司之间发生什么事了。虽然 AI 生图领域&#xff0c;看似百花齐放&#xff0c;但论资排辈&#xff0c;Midjourney、Stability AI 还是很受用户欢迎的。 Midjourney 把 Stability AI 拉入黑名单了&#xff0c;禁止后者所有员工使用其软件&#xff0c;直至另…

本地mysql5.7以上版本配置及my.ini

&#x1f339;作者主页&#xff1a;青花锁 &#x1f339;简介&#xff1a;Java领域优质创作者&#x1f3c6;、Java微服务架构公号作者&#x1f604; &#x1f339;简历模板、学习资料、面试题库、技术互助 &#x1f339;文末获取联系方式 &#x1f4dd; 往期热门专栏回顾 专栏…