FreeU: Free Lunch in Diffusion U-Net 笔记

news2025/1/10 14:35:10

FreeU: Free Lunch in Diffusion U-Net

摘要

作者研究了 U-Net 架构对去噪过程的关键贡献,并发现其主干部分主要在去噪方面发挥作用,而其跳跃连接主要是向解码器模块引入高频特征,这使得网络忽略了主干部分的语义信息。基于这一发现,我们提出了一种简单却有效的方法-- “FreeU”,它无需额外训练或微调就能提升生成质量。我们的核心思路是从策略上对源自 U-Net 跳跃连接和主干特征图的贡献进行重新加权,以充分利用 U-Net 架构中这两个组件的优势。在图像和视频生成任务上取得的良好结果表明, FreeU 方法可以很容易地集成到现有的扩散模型中,例如稳定扩散(Stable Diffusion)、DreamBooth、ModelScope、Rerender 和 ReVersion 等,只需几行代码就能提升生成质量。
在这里插入图片描述
试验表明,如果把decoder阶段的全部backbone都放大,会导致oversmoothed texture。为了缓解这种情况,只在decoder的前两个阶段使用,放大backbone,并且缩小skip features。skip features需要进行FFT和IFFT,详见函数 fourier_filter代码。
完整的stable diffusion1.5的UNet结构可参考UNet2DConditionModel

SDXL效果对比

在这里插入图片描述

参数,来自于FreeU

SD1.4: (will be updated soon)
b1: 1.3, b2: 1.4, s1: 0.9, s2: 0.2

SD1.5: (will be updated soon)
b1: 1.5, b2: 1.6, s1: 0.9, s2: 0.2

SD2.1
b1: 1.1, b2: 1.2, s1: 0.9, s2: 0.2
b1: 1.4, b2: 1.6, s1: 0.9, s2: 0.2

SDXL
b1: 1.3, b2: 1.4, s1: 0.9, s2: 0.2 SDXL results

Range for More Parameters
When trying additional parameters, consider the following ranges:

b1: 1 ≤ b1 ≤ 1.2
b2: 1.2 ≤ b2 ≤ 1.6
s1: s1 ≤ 1
s2: s2 ≤ 1

代码

使用方法

import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16,
).to("cuda")
pipeline.enable_freeu(s1=0.9, s2=0.2, b1=1.3, b2=1.4)   ##add
generator = torch.Generator(device="cpu").manual_seed(13)
prompt = "A squirrel eating a burger"
image = pipeline(prompt, generator=generator).images[0]
image

FreeU函数(来自于diffusers)

def apply_freeu(
    resolution_idx: int, hidden_states: "torch.Tensor", res_hidden_states: "torch.Tensor", **freeu_kwargs
) -> Tuple["torch.Tensor", "torch.Tensor"]:
    """Applies the FreeU mechanism as introduced in https:
    //arxiv.org/abs/2309.11497. Adapted from the official code repository: https://github.com/ChenyangSi/FreeU.

    Args:
        resolution_idx (`int`): Integer denoting the UNet block where FreeU is being applied.
        hidden_states (`torch.Tensor`): Inputs to the underlying block.
        res_hidden_states (`torch.Tensor`): Features from the skip block corresponding to the underlying block.
        s1 (`float`): Scaling factor for stage 1 to attenuate the contributions of the skip features.
        s2 (`float`): Scaling factor for stage 2 to attenuate the contributions of the skip features.
        b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
        b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
    """
    if resolution_idx == 0:
        num_half_channels = hidden_states.shape[1] // 2
        hidden_states[:, :num_half_channels] = hidden_states[:, :num_half_channels] * freeu_kwargs["b1"]
        res_hidden_states = fourier_filter(res_hidden_states, threshold=1, scale=freeu_kwargs["s1"])
    if resolution_idx == 1:
        num_half_channels = hidden_states.shape[1] // 2
        hidden_states[:, :num_half_channels] = hidden_states[:, :num_half_channels] * freeu_kwargs["b2"]
        res_hidden_states = fourier_filter(res_hidden_states, threshold=1, scale=freeu_kwargs["s2"])

    return hidden_states, res_hidden_states
def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.Tensor":
    """Fourier filter as introduced in FreeU (https://arxiv.org/abs/2309.11497).

    This version of the method comes from here:
    https://github.com/huggingface/diffusers/pull/5164#issuecomment-1732638706
    """
    x = x_in
    B, C, H, W = x.shape

    # Non-power of 2 images must be float32
    if (W & (W - 1)) != 0 or (H & (H - 1)) != 0:
        x = x.to(dtype=torch.float32)
    # fftn does not support bfloat16
    elif x.dtype == torch.bfloat16:
        x = x.to(dtype=torch.float32)

    # FFT
    x_freq = fftn(x, dim=(-2, -1))
    x_freq = fftshift(x_freq, dim=(-2, -1))

    B, C, H, W = x_freq.shape
    mask = torch.ones((B, C, H, W), device=x.device)

    crow, ccol = H // 2, W // 2
    mask[..., crow - threshold : crow + threshold, ccol - threshold : ccol + threshold] = scale
    x_freq = x_freq * mask

    # IFFT
    x_freq = ifftshift(x_freq, dim=(-2, -1))
    x_filtered = ifftn(x_freq, dim=(-2, -1)).real

    return x_filtered.to(dtype=x_in.dtype)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2274386.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

JAVA 使用apache poi实现EXCEL文件的输出;apache poi实现标题行的第一个字符为红色;EXCEL设置某几个字符为别的颜色

设置输出文件的列宽,防止文件过于丑陋 Sheet sheet workbook.createSheet(FileConstants.ERROR_FILE_SHEET_NAME); sheet.setColumnWidth(0, 40 * 256); sheet.setColumnWidth(1, 20 * 256); sheet.setColumnWidth(2, 20 * 256); sheet.setColumnWidth(3, 20 * 25…

【STM32】无源蜂鸣器播放音乐《千与千寻》,HAL库

目录 一、工程链接 二、简单介绍 主要特点: 应用: 驱动电路: 三、原理图 四、cubeMX配置 时钟配置 五、keil配置 六、驱动编写 演奏函数 主函数编写 七、效果展示 八、驱动附录 music.h music.c 一、工程链接 STM32无源蜂鸣…

在 Vue 3 集成 e签宝电子合同签署功能

实现 Vue 3 e签宝电子合同签署功能,需要使用 e签宝提供的实际 SDK 或 API。 e签宝通常提供针对不同平台(如 Web、Android、iOS)的 SDK,而 Web 端一般通过 WebView 或直接使用嵌入式 iframe 来加载合同签署页面。 下面举个 &…

04、Redis深入数据结构

一、简单动态字符串SDS 无论是Redis中的key还是value,其基础数据类型都是字符串。如,Hash型value的field与value的类型,List型,Set型,ZSet型value的元素的类型等都是字符串。redis没有使用传统C中的字符串而是自定义了…

如何用Python编程实现自动整理XML发票文件

传统手工整理发票耗时费力且易出错,而 XML 格式发票因其结构化、标准化的特点,为实现发票的自动化整理与保存提供了可能。本文将详细探讨用python来编程实现对 XML 格式的发票进行自动整理。 一、XML 格式发票的特点 结构化数据:XML 格式发票…

Linux——修改USB网卡设备节点名称

修改驱动: 测试: 参考资料: https://blog.csdn.net/ablexu2018/article/details/144868950

(STM32笔记)十二、DMA的基础知识与用法 第三部分

我用的是正点的STM32F103来进行学习,板子和教程是野火的指南者。 之后的这个系列笔记开头未标明的话,用的也是这个板子和教程。 DMA的基础知识与用法 三、DMA程序验证1、DMA 存储器到存储器模式实验(1)DMA结构体解释(2…

论文笔记(六十一)Implicit Behavioral Cloning

Implicit Behavioral Cloning 文章概括摘要1 引言2 背景:隐式模型的训练与推理3 隐式模型与显式模型的有趣属性4 policy学习成果5 理论见解:隐式模型的通用逼近性6 相关工作7 结论 文章概括 引用: inproceedings{florence2022implicit,titl…

高斯函数Gaussian绘制matlab

高斯 约翰卡尔弗里德里希高斯,(德语:Johann Carl Friedrich Gau,英语:Gauss,拉丁语:Carolus Fridericus Gauss)1777年4月30日–1855年2月23日,德国著名数学家、物理学家…

vue的路由守卫逻辑处理不当导致部署在nginx上无法捕捉后端异步响应消息等问题

近期对前端的路由卫士有了更多的认识。 何为路由守卫?这可能是一种约定俗成的名称。就是VUE中的自定义函数,用来处理路由跳转。 import { createRouter, createWebHashHistory } from "vue-router";const router createRouter({history: cr…

如何在 Ubuntu 22.04 上使用 LEMP 安装 WordPress 教程

简介: 本教程旨在指导你如何在 Ubuntu 22.04 上使用 LEMP 栈安装 WordPress。 WordPress 是一个用 PHP 编写的开源内容管理系统。LEMP 栈是 Linux,NGINX,MySQL 和 PHP 的缩写。WordPress 非常用户友好,并提供了多种选项&#xff…

PySide6基于QSlider实现QDoubleSlider

我在写小工具的时候,需要一个支持小数的滑动条。 我QSpinBox都找到了QDoubleSpinBox,QSlider愣是没找到对应的东西。 网上有好多对QSlider封装实现QDoubleSlider的文章。 似乎Qt真的没有这个东西,需要我们自行实现。 于是我也封装了一个&…

升级 Spring Boot 3 配置讲解 —— 支持断点传输的文件上传和下载功能

学会这款 🔥全新设计的 Java 脚手架 ,从此面试不再怕! 在现代 Web 应用中,文件上传和下载是非常常见的需求。然而,当文件较大时,传统的上传下载方式可能会遇到网络不稳定或传输中断的问题。为了解决这些问题…

Backend - C# EF Core 执行迁移 Migrate

目录 一、创建Postgre数据库 二、安装包 (一)查看是否存在该安装包 (二)安装所需包 三、执行迁移命令 1. 作用 2. 操作位置 3. 执行(针对visual studio) 查看迁移功能的常用命令: 查看…

KG-CoT:基于知识图谱的大语言模型问答的思维链提示

一些符号定义 知识图谱实体数量: n n n 知识图谱中关系类型数量: m m m 三元组矩阵: M ∈ { 0 , 1 } n n m \textbf{M} \in \{0, 1\}^{n \times n \times m} M∈{0,1}nnm, M i j k 1 M_{ij}^k 1 Mijk​1则说明实体 i i i和实…

HTML+CSS+JS制作中国传统节日主题网站(内附源码,含5个页面)

一、作品介绍 HTMLCSSJS制作一个中国传统节日主题网站,包含首页、节日介绍页、民俗文化页、节日活动页、联系我们页等5个静态页面。其中每个页面都包含一个导航栏、一个主要区域和一个底部区域。 二、页面结构 1. 顶部横幅区 包含传统中国风格的网站标题中国传统…

大模型WebUI:Gradio全解11——Chatbot:融合大模型的多模态聊天机器人(1)

大模型WebUI:Gradio全解11——Chatbots:融合大模型的聊天机器人(1) 前言本篇摘要11. Chatbot:融合大模型的多模态聊天机器人11.1 gr.ChatInterface()快速创建Chatbot11.1.1 定义聊天函数1. 随机回答“是”或“否”的聊…

springboot + vue+elementUI图片上传流程

1.实现背景 前端上传一张图片&#xff0c;存到后端数据库&#xff0c;并将图片回显到页面上。上传组件使用现成的elementUI的el-upload。、 2.前端页面 <el-uploadclass"upload-demo"action"http://xxxx.xxx.xxx:9090/file/upload" :show-file-list&q…

开源生成式物理引擎Genesis,可模拟世界万物

这是生成大模型时代 —— 它们能生成文本、图像、音频、视频、3D 对象…… 而如果将所有这些组合到一起&#xff0c;我们可能会得到一个世界&#xff01; 现在&#xff0c;不管是 LeCun 正在探索的世界模型&#xff0c;还是李飞飞想要攻克的空间智能&#xff0c;又或是其他研究…

【PPTist】批注、选择窗格

前言&#xff1a;本篇文章研究批注和选择窗格两个小功能 一、批注 批注功能就是介个小图标 点击可以为当前页的幻灯片添加批注&#xff0c;还能删除之前的批注 如果我们增加了登录功能&#xff0c;还可以在批注上显示当前的用户名和头像&#xff0c;不过现在是写死的。 左侧…