注意力蒸馏技术

news2025/4/1 13:10:40

文章目录

  • 摘要
  • abstract
  • 论文摘要
  • 简介
  • 方法
    • 预备知识
    • 注意力蒸馏损失
    • 注意力引导采样
  • 实验
  • 结论
  • 总结
  • 参考文献

摘要

本周阅读了一篇25年二月份发表于CVPR 的论文《Attention Distillation: A Unified Approach to Visual Characteristics Transfer》,论文开发了Attention Distillation引导采样,这是一种改进的分类器引导方法,将注意力蒸馏损失整合到去噪过程中,大大加快了合成速度,并支持广泛的视觉特征迁移和合成应用。

abstract

This week I read a paper published in CVPR in February, "Attention Distillation: A Unified Approach to Visual Characteristics Transfer, this paper develops the Attention Distillation guided sampling, which is an improved classifier guided method to integrate the attention distillation loss into the denoising process. It greatly speeds up synthesis and supports a wide range of visual feature migration and synthesis applications.

下图中是给定参考图,文生图的示例:
在这里插入图片描述

论文摘要

最近扩散模型方面的进展显示了对图像风格和语义的内在理解。论文提出了一种新颖的注意力蒸馏损失,通过在潜在空间中反向传播来优化合成图像,同时改进了一个分类器引导,它将注意力蒸馏损失集成到去噪采样过程中,进一步加速合成过程。
在这里插入图片描述

简介

论文解决问题: 现有生成扩散模型在图像风格和语义理解方面虽然有进展,但在将参考图像的视觉特征转移到生成图像中时,使用即插即用注意力特征的方法存在局限性。

传统的方法通常将纹理定义为重复的局部模式,并通过从源图像中复制局部补丁来合成新的纹理。通常归结为以下三个原因导致的局限性:

  1. 域差距:当两幅图像存在显著差异时,目标Q(合成图像的查询)与参考图像的K,V之间的相似性较低且不可靠,导致错误的聚合结果(AdaIN和注意力能缓解这个问题)
  2. 误差积累:虽然扩散模型中的迭代采样过程可以改善目标Q和参考图中的K,V之间的巨大差异,但误差也可能积累。来自不同扩散模型层的特征集中于不同的信息,如语义和几何。不正确的匹配将会错误传播到马尔科夫链的后续层,并降低最终图像质量。
  3. 框架限制:在去噪网络的剩余分支内实现自注意力机制,参考图像中的自注意力特征可能对目标图像有潜在的影响,降低了合成的效力。
    为了解决上述局限性,本篇论文中引入一种新的注意力蒸馏损失AD loss,在此基础上,通过反向传播直接更新合成的图像。

提出方案: 首先,提出了一种新颖的注意力蒸馏损失,用于在理想和当前风格化结果之间计算损失,并在隐空间中通过反向传播优化合成图像。其次,开发了一种改进的分类器引导方法,即注意力蒸馏引导采样,将注意力蒸馏损失整合到去噪采样过程中。

方法

预备知识

隐空间扩散模型(LDM),如Stable Diffusion,由于其对复杂数据分布的强大建模能力,在图像生成方面达到了最先进的性能。在LDM中,首先使用预训练的VAE 将图像x压缩到一个学习到的隐空间中。随后,基于UNet的去噪网络被训练用于在扩散过程中预测噪声,通过最小化预测噪声与实际添加噪声之间的均方误差来实现。
L L D M = E z ∼ E ( x ) , y , ϵ ∼ N ( 0 , 1 ) , t [ ∥ ϵ θ ( z t , t , y ) − ϵ ∥ 2 2 ] \mathcal{L}_{\mathrm{LDM}}=\mathbb{E}_{z\sim\mathcal{E}(x),y,\epsilon\sim\mathcal{N}(0,1),t}\left[\|\epsilon_\theta(z_t,t,y)-\epsilon\|_2^2\right] LLDM=EzE(x),y,ϵN(0,1),t[ϵθ(zt,t,y)ϵ22]
其中 y 表示条件, 表示时间步长。去噪 UNet 通常由一系列卷积块和自注意力/交叉注力模块组成,所有这些都集成在残差架构的预测分支中。
KV注入在图像编辑、风格迁移和纹理合成中被广泛使用。它建立在自注意力机制之上,并将扩散模型中的自注意力特征用作即插即用的属性。自注意力机制的公式为:
S e l f − A t t n ( Q , K , V ) = s o f t m a x ( Q K T d ) V \mathrm{Self-Attn}(Q,K,V)=\mathrm{softmax}(\frac{QK^{T}}{\sqrt{d}})V SelfAttn(Q,K,V)=softmax(d QKT)V
在注意力机制的核心,是基于查询Q和键K之间的相似性计算权重矩阵,该矩阵用于对值V进行加权聚合。KV注入通过在不同的合成分支之间复制或共享KV特征来扩展这一机制。其关键假设是KV特征代表图像的视觉外观。在采样过程中,将合成分支中的KV特征替换为示例的相应时间步长的KV特征,可以实现从源图像到合成目标的外观转移。

注意力蒸馏损失

尽管KV注入取得了显著的效果,但由于残差机制的影响,它在保留参考的风格或纹理细节方面表现不足;例如,下图(a)中。KV注入仅作用于残差,这意味着信息流(红色箭头)随后受到恒等连接的影响,导致信息传递不完整。因此,采样输出无法完全再现所需的视觉细节。
在这里插入图片描述
本论文通过在自注意力机制中重新聚合特征来提取视觉元素。利用预训练的T2I扩散模型SD的UNet,从自注意力模块中提取图像特征。
在这里插入图片描述
上图中,首先根据目标分支的Q,从参考分支重新聚合KV特征(Ks和Vs)的视觉信息,这与KV注入相同。
将此注意力输出视为理想的风格化。然后,我们计算目标分支的注意力输出,并计算相对于理想注意力输出的L1损失,这定义了AD损失:
L A D = ∥ S e l f − A t t n ( Q , K , V ) − S e l f − A t t n ( Q , K s , V s ) ∥ 1 \mathcal{L}_{\mathrm{AD}}=\|\mathrm{Self-Attn}(Q,K,V)-\mathrm{Self-Attn}(Q,K_{s},V_{s})\|_{1} LAD=SelfAttn(Q,K,V)SelfAttn(Q,Ks,Vs)1
可以使用提出的AD损失通过梯度下降来优化随机隐空间噪声,从而在输出中实现生动的纹理或风格再现;例如,参见上图(b)。这归因于优化中的反向传播,它不仅允许信息在(残差)自注意力模块中流动,还通过恒等连接流动。通过持续优化,Q和Ks之间的差距逐渐缩小,使得注意力越来越准确,最终特征被正确聚合以产生所需的视觉细节。

注意力引导采样

将注意力蒸馏损失以改进的分类器引导方式纳入扩散模型的采样过程中。
分类器引导在去噪过程中改变去噪方向,从而生成来自p(zt|c)的样本,其公式可以表示为:
ϵ ^ θ = ϵ θ ( z t , t , y ) − α σ t ∇ z t log ⁡ p ( c ∣ z t ) \hat{\epsilon}_\theta=\epsilon_\theta(z_t,t,y)-\alpha\sigma_t\nabla_{z_t}\log p(c|z_t) ϵ^θ=ϵθ(zt,t,y)ασtztlogp(czt)
其中,t是时间步长,y表示提示, ϵ θ \epsilon_\theta ϵθ   z t \ {z_t}  zt分别指去噪网络和LDM中的隐空间变量。 α \alpha α控制引导强度。使用基于注意力蒸馏损失的能量函数来引导扩散采样过程。

实验

由于补丁来源有限,使用传统方法合成超高分辨率纹理非常困难。在此,将注意力蒸馏引导的采样应用于MultiDiffusion模型,使纹理扩展到任意分辨率。尽管SD-1.5是在尺寸为512×512的图像上训练的,但令人惊讶的是,当结合注意力蒸馏时,它在大尺寸纹理合成中表现出了强大的能力。下图展示了将纹理扩展到512×1536的尺寸与GCD和GPDM的比较。
在这里插入图片描述
损失函数代码

def ad_loss(
    q_list, ks_list, vs_list, self_out_list, scale=1, source_mask=None, target_mask=None
):
    loss = 0
    attn_mask = None
    for q, ks, vs, self_out in zip(q_list, ks_list, vs_list, self_out_list):
        if source_mask is not None and target_mask is not None:
            w = h = int(np.sqrt(q.shape[2]))
            mask_1 = torch.flatten(F.interpolate(source_mask, size=(h, w)))
            mask_2 = torch.flatten(F.interpolate(target_mask, size=(h, w)))
            attn_mask = mask_1.unsqueeze(0) == mask_2.unsqueeze(1)
            attn_mask=attn_mask.to(q.device)

        target_out = F.scaled_dot_product_attention(
            q * scale,
            torch.cat(torch.chunk(ks, ks.shape[0]), 2).repeat(q.shape[0], 1, 1, 1),
            torch.cat(torch.chunk(vs, vs.shape[0]), 2).repeat(q.shape[0], 1, 1, 1),
            attn_mask=attn_mask
        )
        loss += loss_fn(self_out, target_out.detach())
    return loss



def q_loss(q_list, qc_list):
    loss = 0
    for q, qc in zip(q_list, qc_list):
        loss += loss_fn(q, qc.detach())
    return loss

# weight = 200
def qk_loss(q_list, k_list, qc_list, kc_list):
    loss = 0
    for q, k, qc, kc in zip(q_list, k_list, qc_list, kc_list):
        scale_factor = 1 / math.sqrt(q.size(-1))
        self_map = torch.softmax(q @ k.transpose(-2, -1) * scale_factor, dim=-1)
        target_map = torch.softmax(qc @ kc.transpose(-2, -1) * scale_factor, dim=-1)
        loss += loss_fn(self_map, target_map.detach())
    return loss

# weight = 1
def qkv_loss(q_list, k_list, vc_list, c_out_list):
    loss = 0
    for q, k, vc, target_out in zip(q_list, k_list, vc_list, c_out_list):
        self_out = F.scaled_dot_product_attention(q, k, vc)
        loss += loss_fn(self_out, target_out.detach())
    return loss

下面这段代码主要通过自适应特征提取和优化,将内容图像的潜变量 (latents) 调整为具有风格图像特征的潜变量,实现风格迁移(Style Transfer)或风格控制 (Style-Adaptive Denoising, AD)。
1.使用了一种基于 AdaIN (Adaptive Instance Normalization) 的方法对 latents 进行风格调整:

if self.adain:
    noise = torch.randn_like(self.style_latent)
    style_latent = self.scheduler.add_noise(self.style_latent, noise, t)
    latents = utils.adain(latents, style_latent)

2.提取风格和内容特征:

qs_list, ks_list, vs_list, s_out_list = self.extract_feature(
    self.style_latent,
    t,
    self.null_embeds_for_style,
    add_noise=True,
)
if self.content_latent is not None:
    qc_list, kc_list, vc_list, c_out_list = self.extract_feature(
        self.content_latent,
        t,
        self.null_embeds,
        add_noise=True,
    )

3.优化 latents 使其匹配风格和内容特征:

optimizer = torch.optim.Adam([latents.requires_grad_()], lr=lr)
optimizer = self.accelerator.prepare(optimizer)

在 iters 轮优化中,计算损失 (style_loss 和 content_loss),并进行反向传播:

for j in range(iters):
    style_loss = ad_loss(q_list, ks_list, vs_list, self_out_list, scale=self.attn_scale)
    if self.content_latent is not None:
        content_loss = q_loss(q_list, qc_list)
    loss = style_loss + content_loss * weight
    self.accelerator.backward(loss)
    optimizer.step()

结论

这篇论文提出了一种统一的方法来处理各种视觉特征转移任务,包括风格/外观转移、特定风格的图像生成和纹理合成。该方法的关键是一种新颖的注意力蒸馏损失,它计算理想风格化与当前风格化之间的差异,并逐步修改合成。

总结

这篇论文提出了一种基于注意力蒸馏(Attention Distillation, AD)的新方法,用于改进扩散模型在视觉特征迁移任务中的表现。作者引入注意力蒸馏损失(AD Loss),通过反向传播优化合成图像,使其更好地匹配目标风格。此外,论文提出注意力蒸馏引导采样,将AD Loss整合到去噪过程中,加快图像合成速度,并提升细节保真度。实验表明,该方法在风格迁移、特定风格图像生成和纹理合成等任务中均优于现有技术,特别是在高分辨率纹理生成方面表现突出。该方法通过改进查询-键-值(Q-K-V)特征聚合,有效缓解域差距、误差积累和框架限制问题。

参考文献

[1] Attention Distillation: A Unified Approach to Visual Characteristics Transfer

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2324400.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

PERL开发环境搭建>>Windows,Linux,Mac OS

特点 简单 快速 perl解释器直接对源代码程序解释执行,是一个解释性的语言, 不需要编译器和链接器来运行代码>>速度快 灵活 借鉴了C/C, Basic, Pascal, awk, sed等多种语言, 定位于实用性语言,既具备了脚本语言的所有功能,也添加了高级语言功能 开源.免费 没有&qu…

鸿蒙项目源码-记账本app个人财物管理-原创!原创!原创!

鸿蒙记账项目源码个人财务管理含文档包运行成功ArkTS语言。 我一个月写的原创作品,请尊重原创。 原创作品,盗版必究!!! api12 SDK5.0.0仅适用于最新的2024版本DevEco studio 共9个页面:广告倒计时页、登录、…

Ovito的python脚本

在 OVITO 里,Python 对象是构建脚本化操作的基础。下面为你详细介绍 OVITO 中 Python 对象的基本概念: 1. 数据管道(Pipeline) 数据管道是 OVITO 里最核心的对象之一。它就像一个流水线,把数据输入进来,经过一系列处理步骤,最后输出处理好的数据。 创建管道:借助 imp…

【免费】2007-2019年各省地方财政文化体育与传媒支出数据

2007-2019年各省地方财政文化体育与传媒支出数据 1、时间:2007-2019年 2、来源:国家统计局、统计年鉴 3、指标:行政区划代码、地区、年份、地方财政文化体育与传媒支出 4、范围:31省 5、指标说明:地方财政在文化、…

NOIP2007提高组.矩阵取数游戏

题目 492. 矩阵取数游戏 思路 不难发现, 每一行之间是独立的, 因此可以求出每一行的最大值, 然后行与行之间最大值相加, 就是总的最大值 对于行内来说, 每次可以选取左边或者右边, 可以使用区间 d p dp dp求解, 时间复杂度 O ( n 3 ) O(n ^ 3) O(n3), 因为列的最大值是 80 …

项目实战--权限列表

后端数据: 用表格实现权限列表 const dataSource [{key: 1,name: 胡彦斌,age: 32,address: 西湖区湖底公园1号,},{key: 2,name: 胡彦祖,age: 42,address: 西湖区湖底公园1号,}, ];const columns [{title: 姓名,dataIndex: name,key: name,},{title: 年龄,dataInd…

若依赖前端处理后端返回的错误状态码

【背景】 后端新增加了一个过滤器,用来处理前端请求中的session 若依赖存放过滤器的目录:RuoYi-Vue\ruoyi-framework\src\main\java\com\ruoyi\framework\security\filter\ 【问题】 后端返回了一个状态码为403的错误,现在前端需要处理这…

【计网】数据包

期末复习自用的,处理得比较草率,复习的同学或者想看基础的同学可以看看,大佬的话可以不用浪费时间在我的水文上了 1.数据包的定义: 数据包是网络通信中的基本单元,它包含了通过网络传输的所有必要信息。数据包的结构…

web权限划分提权和移权

前言:权限的基本认知 渗透权限划分:假如我们通过弱口令进入到web的后台 这样我们就拿到了web的管理员权限 管理员权限是web中最高的权限(一般我们进入web的时候数据库会进行用户权限的划分:假设 0-10为最高的权限 11-10000为普通…

LocalDateTime序列化总结

版权说明: 本文由CSDN博主keep丶原创,转载请保留此块内容在文首。 原文地址: https://blog.csdn.net/qq_38688267/article/details/146703276 文章目录 1.背景2.序列化介绍常见场景关键问题 3.总体方案4.各场景实现方式WEB接口EasyExcelMybat…

[ 春秋云境 ] Initial 仿真场景

文章目录 靶标介绍:外网内网信呼oa永恒之蓝hash传递 靶标介绍: Initial是一套难度为简单的靶场环境,完成该挑战可以帮助玩家初步认识内网渗透的简单流程。该靶场只有一个flag,各部分位于不同的机器上。 外网 打开给的网址, 有一…

unity 截图并且展现在UI中

using UnityEngine; using UnityEngine.UI; using System.IO; using System.Collections.Generic; using System; using System.Collections;public class ScreenshotManager : MonoBehaviour {[Header("UI 设置")]public RawImage latestScreenshotDisplay; // 显示…

中断管理常用API(四)

一、request_irq(...) request_irq 函数主要用于硬中断相关操作,它的核心作用是把一个中断处理函数和特定的中断号进行绑定。当硬件设备触发该中断号对应的中断时,内核就会调用绑定的中断处理函数,像 irqhandler_func 这类。 此函数在多种硬件…

pyspark学习rdd处理数据方法——学习记录

python黑马程序员 """ 文件,按JSON字符串存储 1. 城市按销售额排名 2. 全部城市有哪些商品类别在售卖 3. 上海市有哪些商品类别在售卖 """ from pyspark import SparkConf, SparkContext import os import jsonos.environ[PYSPARK_P…

【HTML 基础教程】HTML <head>

HTML <head> 查看在线实例 - 定义了HTML文档的标题"><title> - 定义了HTML文档的标题 使用 <title> 标签定义HTML文档的标题 - 定义了所有链接的URL"><base> - 定义了所有链接的URL 使用 <base> 定义页面中所有链接默认的链接目…

混合知识表示系统框架python示例

前文我们已经深入学习了框架表示法、产生式规则和一阶谓词逻辑,并对它们进行了深度对比,发现它们在不同的应用场景下各有优缺点。 一阶谓词逻辑适合复杂逻辑推理场景,具有数学定理证明、形式化系统规范的优点;产生式规则适合动态决策系统,支持实时决策(如风控、诊断),规…

MATLAB 控制系统设计与仿真 - 30

用极点配置设计伺服系统 方法2-反馈修正 如果我们想只用前馈校正输入&#xff0c;从而达到伺服控制的效果&#xff0c;我们需要很精确的知道系统的参数模型&#xff0c;否则系统输出仍然具有较大的静态误差。 但是如果我们在误差比较器和系统的前馈通道之间插入一个积分器&a…

Baklib知识中台驱动智能架构升级

构建四库体系驱动架构升级 在数字化转型过程中&#xff0c;企业普遍面临知识资源分散、隐性经验难以沉淀的痛点。Baklib通过构建知识库、案例库、流程库及资源库四层核心体系&#xff0c;为知识中台搭建起结构化基础框架。知识库以AI分类引擎实现文档标签化存储&#xff0c;案…

IP第一次笔记

一、TCP协议 第0步&#xff1a;如果浏览器和host文件存在域名对应的P地址记录关系 则直接封装HTTP数据报文&#xff0c;如果没有记录则触发DNS解析获 取目标域名对应的P地址 第一步&#xff1a;终端主机想服务器发起TCP三次握手 1.TCP的三次握手 2.传输网页数据 HTTP --应用层…

vue3实现router路由

说明&#xff1a; vue3实现router路由 效果图&#xff1a; step1:项目结构 src/ ├── views/ │ ├── Home.vue │ └── User.vue ├── router/ │ └── index.js ├── App.vue └── main.jsstep2:左边路由列表C:\Users\wangrusheng\PycharmProjects\un…