用你的手机/电脑运行文生图方案

news2025/1/15 18:19:52

f8b0fc16a231ae0cd7d8fcf298dd8fde.gif

随着ChatGPT和Stable Diffusion的发布,最近一两年,生成式AI已经火爆全球,已然成为移动互联网后一个重要的“风口”。就图片/视频生成领域来说,Stable Diffusion模型发挥着极其重要的作用。由于Stable Diffusion模型参数量是10亿参数的大模型,通常业界都是运行部署在显卡上。

但是随着量化、剪枝等模型压缩技术的进步,以及手机等终端设备的算力、带宽、内存持续增大。使得大模型在终端设备部署也成为的可能。大模型在终端部署可以有效保护用户隐私,而且终端设备日常广泛使用、用户可以随时随地生成想要的内容。

7ec7e6256a7667e08014fbdff9b54cea.png

MNN-Diffusion使用

本文是深度学习推理引擎MNN团队,做的Stable Diffusion端侧部署应用,代码开源,用户可以自行DIY各种好玩的Stable Diffusion应用。

MNN开源地址:

https://github.com/alibaba/MNN/tree/master

欢迎大家试用,使用教程如下:

https://mnn-docs.readthedocs.io/en/latest/transformers/diffusion.html



下面是在个人手机/电脑上生成的图片:

bebe439c12b0e97687df89db324f0bad.png

技术要点

业界加速Stable Diffusion部署通常有两个方向,一是算法层面的优化,包括优化网络结构、减少计算量或者降低推理迭代步数;二是工程部署优化,通过量化/算子高效实现等方式提高硬件计算效率、提高访存效率。MNN作为推理引擎,主要聚焦在工程部署优化上,下面分享下MNN Diffusion GPU在性能/内存方面做了优化工作。

  Self-Attention优化

Transformer结构中Self-Attention是一个基础结构,也是性能耗时的关键。如下结构是一个典型的Attention结构:

30df654d025272d5fc22fc5fd4eb287e.png

一个共有节点,分别经过三个Linear层,得到Query/Key/Value,Query/Key经过形状变换进行BatchMatMul操作,再进行Scale,取Softmax操作;该结果和Value经过形状变换做BatchMatMul;之后把结果进行形状变换,得到最终的输出。可以看到上述总共有19个算子,包括12个形状变化算子,7个计算型算子。



大量的形状变化会带来很多的访存耗时,对于GPU高算力的硬件来说,访存耗时往往容易成为热点。因此,将上述结构,融合成2个算子,第一个是将三个Linear层权重融合在一起,只做一个Linear,这样形成更大的矩阵乘尺寸,更容易打满GPU算力,带来性能收益;第二个算子是将Attention算子融合成一个算子Fused-MultiHead-Attention,融合之后在该新算子内部仅需5个Kernel就可以实现整个Attention功能。消除了大量额外的形状变换算子,降低了访存压力,同时可以更容易基于Attention算子特性做进一步优化工作。

500e2baca3a6db86eb4d5d2fd9550b6f.png

  GroupNorm/SplitGeLU融合

在Stable Diffusion中,有一个通用的结构ResnetBlock,其中包含了BroadCast Binary + GroupNorm + SiLU结构,在onnx模型图结构中包含了如下13个算子:

333736b66abbeff889e57a0fad37192b.png

可以看到GroupNorm采用InstanceNorm+形变算子实现,gamma/beta被单独拆解为mul/add算子,细碎的算子会增加全局内存的访存次数、以及Kernel launch的压力。因此将上述通用结构合并成一个GroupNorm算子,该算子把前面的BroadCast Binary和后续的SiLU激活函数,融合在一起。高效的只需一个Kernel就可以实现上述计算需求。



同样的图融合原理,在Transformer激活函数中,Stable Diffusion Feed-Forward模块中采用GEGLU结构,对应onnx图结构如下。将该8个onnx图算子,融合为通用的SplitGeLU算子。

c5251af434249b000dafca534c294940.png

  conv-winograd算法实现

在Stable Diffusion中有大量3x3卷积,在深度学习中,Winograd算法已经大量应用在加速3x3卷积实现。

Winograd F(m, r)算法,其中m代表一个计算tile的大小,r对应filter的尺寸,d=m+r-1 代表对应input tile大小。

4f6d41a9cea75a5c83330159ce22669b.png

下表是3x3 Winograd不同tile对应计算量的节省比例和中间内存占用的增大比例。

m

r

d

计算量前后比例

input中间内存

weight中间内存

2

3

4

9 : 4 = 2.25x

4x

1.78x

4

3

6

4 : 1 = 4x

2.25x

4x

6

3

8

81 : 16 = 5.06x

1.78x

7.11x

目前,我们使用的是F(2, 3) Winograd,控制内存增大量,同时带来一倍的性能提升效果。

  高性能Gemm/BatchGemm

上述分析可以看出,Attention/卷积3x3,核心计算量在BatchGemm上,Linear层实际上就是Gemm运算。实际上,Stable Diffusion中,核心的计算量或者说耗时的热点,归根溯源,都集中在Gemm/BatchGemm上。如何高效实现矩阵乘法 成为最核心的关键。

矩阵乘在各个维度上的分块策略,可以有效提升数据的复用度和数据cache命中率;合理的分块可以为矩阵乘法带来大幅度的性能提升。

a6c30dcdb2a2f826e6023ee302c4c2ab.png

上图展示了,矩阵乘在各个维度上面的分块变量,包括在并发M/N维度,单次数据访存向量化位宽、每个线程存取矩阵的尺寸、每个工作组存取矩阵的尺寸,以及如果使用local memory缓存的话每个线程/工作组的缓存量。

这些参量都决定了数据访存的效率、并发量的大小、计算访存比的大小。不同的设备有不同的寄存器资源、共享内存资源、访存带宽、计算核心数,这些参量都决定着矩阵乘法的性能效率。



对于特定的矩阵乘的尺寸M/N/K,针对特定设备采取Auto-Tuning的获取最佳的运行参数(OPWM/OPWN/OPTM/OPTN/VEC_M/VEC_N等),Tuning候选集数量是M的N次方(N是参数的个数、M是每个参数候选集个数)。如果暴力循环每个参数候选集,由于候选集数量巨大、并且大尺寸矩阵乘本身单次运行耗时较大,必然会导致要花费大量时间去Tuning完所有候选集。因此,根据经验和实际试跑,选出部分高频参数候选集进行Tuning,在控制好Tuning时间的同时,也可以带来极大的性能收益。

  Gemm Strassen探索

由于矩阵乘法是Stable Diffusion耗时的核心,因此进行了矩阵乘快速算法的研究探索。Strassen算法是利用矩阵拆解,通过引入矩阵加减法,来减少矩阵乘法次数的方式。最简单的方法,将M/N/K维度各对拆1/2的方法,朴素的矩阵拆解如下:

0b94ad48464fc4675377f555b66e0b0e.png

Strassen算法,通过15次子矩阵加减法,来减少一次子矩阵乘法。矩阵拆解如下:

46a06996401968f0e405de7c014f6a98.png

当N足够大时,矩阵加减法耗时会远低于矩阵乘法耗时,带来12.5%的计算量降低。当N较小时,受限于15次 子矩阵加减的 耗时,以及拆解子矩阵乘法算力打不满等损耗原因,将引起负优化。具体某个形状的矩阵乘法适不适合使用Strassen算法?



对于矩阵A形状为[M, K], 矩阵B形状为[N, K],输出矩阵C形状为[M, N]。15次子矩阵加减,数据访存量为:(3*M*K + 3*N*K + 3.5*M*N) * sizeof(DataType) Bytes。1次子矩阵乘法,数据计算量为:1/8 * M*N*K * 2 = 1/4 * M*N*K FLOPS。我们默认矩阵加减是带宽瓶颈,矩阵乘法是算力瓶颈。假设设备的内存带宽为X GB/s,算力是Y GFLOPS。

子矩阵加减耗时:(6*M*K + 6*N*K + 3.5*M*N)*sizeof(DataType) / X (ns)

子矩阵乘节省耗时:(1/4 * M * N * K) / Y (ns)



当节省的耗时大于损耗耗时,即可有性能收益。根据上述公式,计算访存比越低的设备,Strassen算法越容易有收益。对于手机设备来说,1024x1024x1024的子矩阵,通常可以获得约10%的性能收益。

  内存占用优化

在Attention优化中,Q/K做BatchMatMul得到中间数据QK时,张量维度为[Batch, HeadNum, SeqLen, SeqLen]。对于Stable Diffusion来说,会遇到Batch=2,HeadNum=16,SeqLen=4096。对于float16的数据类型,单个张量的存储就需要1GB的内存大小,这对于内存资源紧缺的端侧设备是不可接受的。

876e1c1eb1ae66fb68659378e93eced9.png

因此,将Attention操作进行分块处理,类似Paged Attention的思路,将整个Attention分成SeqNum次执行,这样每次仅需原先1/SeqNum中间内存大小,可以非常有效的控制内存的大小。

性能测评

MNN Stable Diffusion应用,生成512x512图片,在骁龙8Gen3上使用GPU float16精度达到2s/iter (20次迭代,手机上40s可以生成完一幅图),在Apple Mac M3上GPU float32精度达到1.1s/iter (20次迭代,Mac上22s可以生成完一幅图)。MNN CPU/GPU性能均较大幅度快于如下Stable Diffusion开源框架,例如:

  • stable-diffusion.cpp

    https://github.com/leejet/stable-diffusion.cpp/issues/15

  • Android OnnxRuntime Stable Diffusion应用

    https://github.com/ZTMIDGO/Android-Stable-diffusion-ONNX

9e9496859a04ea0ac02b2af87f69bd83.png

后续研究

后续在性能优化和内存优化上面仍然有空间可以挖掘。

性能优化方面:

  • Conv Winograd采用更大的分块,获取更高的计算量降低收益。

  • 矩阵乘尝试Image存储内存访问模式,提高访存效率。

  • Attention进一步采用Flash Attention等思路优化。

内存占用优化方面:

  • 采用低比特权重(int8/int4量化)。

  • 在线转换动态内存可复用,Conv Winograd权重尝试采用在线转换。

  • Attention 采用Flash Attention优化节省中间内存使用。

8c56100ad51dd1b6565cc7b815d90b01.png

参考资料

  • https://blog.csdn.net/xian0710830114/article/details/129194419

  • https://github.com/NVIDIA/TensorRT/tree/release/8.6/demo/Diffusion

  • https://arxiv.org/abs/0707.2347

  • https://courses.cs.cornell.edu/cs6810/2023fa/Matrix.pdf

  • https://github.com/CNugteren/CLBlast/tree/master

  • https://arxiv.org/pdf/1703.06503

  • https://github.com/leejet/stable-diffusion.cpp/

  • https://github.com/ZTMIDGO/Android-Stable-diffusion-ONNX

e2e86c44f7745eb2a3c1103c68e973db.png

团队介绍

我们是大淘宝技术Meta Team,负责面向消费场景的3D/XR基础技术建设和创新应用探索,通过技术和应用创新找到以手机及XR 新设备为载体的消费购物3D/XR新体验。团队在端智能、商品三维重建、3D引擎、XR引擎等方面有深厚的技术积累。团队在OSDI、MLSys、CVPR、ICCV、NeurIPS、TPAMI等顶级学术会议和期刊上发表多篇论文。

¤ 拓展阅读 ¤

3DXR技术 | 终端技术 | 音视频技术

服务端技术 | 技术质量 | 数据算法

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2219270.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

PHP爬虫:获取商品销量数据的利器

在电子商务的激烈竞争中,掌握商品销量数据是商家洞察市场动态、制定销售策略的关键。通过PHP爬虫技术,我们可以高效地获取这些数据,为商业决策提供支持。 PHP爬虫的优势 PHP作为一种流行的服务器端脚本语言,拥有跨平台运行、丰富…

【C++篇】类与对象的秘密(上)

目录 引言 一、类的定义 1.1类定义的基本格式 1.2 成员命名规范 1.3 class与struct的区别 1.4 访问限定符 1.5 类的作用域 二、实例化 2.1 类的实例化 2.2 对象的大小与内存对齐 三、this 指针 3.1 this指针的基本用法 3.2 为什么需要this指针? 3.3 t…

数据结构——链表,哈希表

文章目录 链表python实现双向链表复杂度分析 哈希表(散列表)python实现哈希表哈希表的应用 链表 python实现 class Node:def __init__(self, item):self.item itemself.next Nonedef head_create_linklist(li):head Node(li[0])for element in li[1…

SQL Server 2019数据库“正常,已自动关闭”

现象: SQL Server 2019中,某个数据库在SQL Server Management Studio(SSMS)中的状态显示为“正常,已自动关闭”。 解释: 如此显示,是由于该数据库的AUTO_ CLOSE选项被设为True。 在微软的官…

JavaSE——IO流1:FileOutputStream(字节输出流)、FileInputStream(字节输入流)

目录 一、IO流概述 二、IO流的分类 三、字节输出流与字节输入流 (一)字节输出流——FileOutputStream 1.FileOutputStream书写步骤 2.FileOutputStream书写细节 3.FileOutputStream写数据的3种方式 4.FileOutputStream的换行和续写 (二)字节输入流——FileInputStream…

如何给手机换ip地址

在当今数字化时代,IP地址作为设备在网络中的唯一标识,扮演着举足轻重的角色。然而,有时出于隐私保护、网络访问需求或其他特定原因,我们可能需要更改手机的IP地址。本文将详细介绍几种实用的方法,帮助您轻松实现手机IP…

若依框架中spring security的完整认证流程,及其如何使用自定义用户表进行登录认证,学会轻松实现二开,嘎嘎赚块乾

1)熟悉之前的SysUser登录流程 过滤器链验证配置 这里security过滤器链增加了前置过滤器链jwtFilter 该过滤器为我们自定义的,每次请求都会经过jwt验证 ok我们按ctrl alt B跳转过去来看下 首先会获取登录用户LoginUser 内部通过header键,获…

Deep Learning

深度学习 文章目录 前言面向开发人员的 NVIDIA AI 平台每个 AI 框架 - 加速统一平台从开发到部署前言 深度学习是 AI 和机器学习的一个子集,它使用多层人工神经网络在对象检测、语音识别、语言翻译等任务中提供最先进的准确性。 深度学习与传统机器学习技术的不同之处在于,深…

python爬虫加解密分析及实现

第一种: 1、找到加密的接口地址,通过加密的接口地址全局搜索 2、通过打断点的方式,操作页面,跑到断点处时,即可找到加密串,如图二; 3、找到用的是哪种加密方式,如: cr…

PCL 点云配准 基于目标对称的ICP算法(精配准)

目录 一、概述 1.1原理 1.2实现步骤 1.3应用场景 二、代码实现 2.1关键函数 2.1.1计算点云的法线 2.1.2基于对称误差估计的ICP配准 2.1.3可视化 2.2完整代码 三、实现效果 PCL点云算法汇总及实战案例汇总的目录地址链接: PCL点云算法与项目实战案例汇总…

OpenCV高级图形用户界面(20)更改窗口的标题函数setWindowTitle()的使用

操作系统:ubuntu22.04 OpenCV版本:OpenCV4.9 IDE:Visual Studio Code 编程语言:C11 算法描述 在OpenCV中,cv::setWindowTitle函数用于更改窗口的标题。这使得您可以在程序运行时动态地更改窗口的标题文本。 函数原型 void cv::…

SpringBoot日常:封装redission starter组件

文章目录 逻辑实现POM.xmlRedissionConfigRedissionPropertiesRedissionUtilsspring.factories 功能测试application.yml配置POM.xmlTestController运行测试 本章内容主要介绍如何通过封装相关的redission连接配置和工具类,最终完成一个通用的redission starter。并…

论文速读:通过目标感知双分支蒸馏进行跨域目标检测(CVPR2022)

原文标题:Cross Domain Object Detection by Target-Perceived Dual Branch Distillation 中文标题:通过目标感知双分支蒸馏进行跨域目标检测 论文地址: https://arxiv.org/abs/2205.01291 代码地址: GitHub - Feobi1999/TDD 这篇…

使用多块AMD GPU通过Megatron-DeepSpeed进行大型语言模型的预训练

Pre-training a large language model with Megatron-DeepSpeed on multiple AMD GPUs 2024年1月24日,作者:Douglas Jia 在这篇博客中,我们将向你展示如何使用Megatron-DeepSpeed框架在多块AMD GPU上预训练GPT-3模型。我们还将展示如何使用你…

5、JavaScript(二) 对象+DOM

17.对象 1、对象:⽤来存储多个数据的 是由多个键值对/key value对组成的 ⽤来描述⼀个事物的 相当于多个变量的集合 2、格式 :{key:value,key:value} 键/值对 属性名:属性值 3、对象的属性值是不限制数据类型的,甚至还可以是对…

常用的字符集(ASCII、GBK)

目录 1.ASCII字符集 2.各版本的字符集 3. GBK字符集在计算机中的存储规则 4. 总结 1.ASCII字符集 计算机中最小的存储单元是一个字节,一个字节8bit 0-127 一共是128个 2.各版本的字符集 只需要掌握GBK和Unicode两个字符集。GBK是简体中文window操作系统默认使…

85.【C语言】数据结构之顺序表的中间插入和删除及遍历查找

目录 3.操作顺序表 1.分析中间插入函数 函数的参数 代码示例 图片分析 main.c部分改为 在SeqList.h添加SLInsert函数的声明 运行结果 2.分析中间删除函数 函数的参数 代码示例 图片分析 main.c部分改为 在SeqList.h添加SLErase函数的声明 运行结果 承接84.【C语…

Atlas800昇腾服务器(型号:3000)—YOLO全系列NPU推理【检测】(五)

服务器配置如下: CPU/NPU:鲲鹏 CPU(ARM64)A300I pro推理卡 系统:Kylin V10 SP1【下载链接】【安装链接】 驱动与固件版本版本: Ascend-hdk-310p-npu-driver_23.0.1_linux-aarch64.run【下载链接】 Ascend-…

spring boot实现不停机更新

主要实现思路:发布新的应用程序(与原端口不同),启动成功后,将原端口进行优雅关闭,同时将应用程序端口动态切换至原端口 application.yml server:port: 8000shutdown: graceful DatapickCliApplication package com.zy.datapickcli;import org.springframework.boot.SpringAp…

保研考研机试攻略:python笔记(1)

🐨🐨🐨宝子们好呀 ~ 我来更新欠大家的python笔记了,从这一篇开始我们来学下python,当然,如果只是想应对机试并且应试语言以C和C为主,那么大家对python了解一点就好,重点可以看高分篇…