improved-diffusion-main代码理解

news2025/1/22 15:46:03

目录

  • 一、 TimestepEmbedSequential
  • 二、PyTorch之Checkpoint机制
  • 三、AttentionBlock
  • 四、use_scale_shift_norm

和nanoDiffusion-main相比,improved-diffusion-main代码是相似的,但有几个不是很好理解的地方记录一下。

一、 TimestepEmbedSequential

代码中class ResBlock继承自TimestepBlock,需要执行时间步嵌入操作,其他不需要。

class TimestepBlock(nn.Module):
    """
    Any module where forward() takes timestep embeddings as a second argument.
    """

    @abstractmethod
    def forward(self, x, emb):
        """
        Apply the module to `x` given `emb` timestep embeddings.
        """


class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential module that passes timestep embeddings to the children that
    support it as an extra input.
    """

    def forward(self, x, emb):
        for layer in self:
            if isinstance(layer, TimestepBlock):
                x = layer(x, emb)
            else:
                x = layer(x)
        return x
class ResBlock(TimestepBlock):

二、PyTorch之Checkpoint机制

def checkpoint(func, inputs, params, flag):
    """
    Evaluate a function without caching intermediate activations, allowing for
    reduced memory at the expense of extra compute in the backward pass.

    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
                   explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    """
    if flag:
        args = tuple(inputs) + tuple(params)
        return CheckpointFunction.apply(func, len(inputs), *args)
    else:
        return func(*inputs)

checkpoint 是在 torch.no_grad() 模式下计算的目标操作的前向函数,这并不会修改原本的叶子结点的状态,有梯度的还会保持。只是关联这些叶子结点的临时生成的中间变量会被设置为不需要梯度,因此梯度链式关系会被断开。

三、AttentionBlock


class AttentionBlock(nn.Module):
    def __init__(self, channels, num_heads=1, use_checkpoint=False):
        super().__init__()
        self.channels = channels
        self.num_heads = num_heads
        self.use_checkpoint = use_checkpoint

        self.norm = normalization(channels)
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        self.attention = QKVAttention()
        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    def forward(self, x):
        return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint)

    def _forward(self, x):
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        qkv = qkv.reshape(b * self.num_heads, -1, qkv.shape[2])
        h = self.attention(qkv)
        h = h.reshape(b, -1, h.shape[-1])
        h = self.proj_out(h)
        return (x + h).reshape(b, c, *spatial)


class QKVAttention(nn.Module):
    """
    A module which performs QKV attention.
    """

    def forward(self, qkv):
        """
        Apply QKV attention.

        :param qkv: an [N x (C * 3) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x C x T] tensor after attention.
        """
        ch = qkv.shape[1] // 3
        q, k, v = th.split(qkv, ch, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = th.einsum(
            "bct,bcs->bts", q * scale, k * scale
        )  # More stable with f16 than dividing afterwards
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        return th.einsum("bts,bcs->bct", weight, v)

    @staticmethod
    def count_flops(model, _x, y):
    
        b, c, *spatial = y[0].shape
        num_spatial = int(np.prod(spatial))
        # We perform two matmuls with the same number of ops.
        # The first computes the weight matrix, the second computes
        # the combination of the value vectors.
        matmul_ops = 2 * b * (num_spatial ** 2) * c
        model.total_ops += th.DoubleTensor([matmul_ops])

下面这个函数是准备适合的qkv矩阵

    def _forward(self, x):
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)-》输入转换为(b,c,N)
        qkv = self.qkv(self.norm(x))-》通过卷积转换为(b,3*c,N)
        qkv = qkv.reshape(b * self.num_heads, -1, qkv.shape[2])
        h = self.attention(qkv)
        h = h.reshape(b, -1, h.shape[-1])
        h = self.proj_out(h)
        return (x + h).reshape(b, c, *spatial)

class QKVAttention的forward就是下面的公式:
在这里插入图片描述

四、use_scale_shift_norm

    def _forward(self, x, emb):
        h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = th.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h

在深度学习中,特别是在处理如扩散模型(Diffusion Models)或任何需要精细控制输出特征的神经网络时,use_scale_shift_norm引入一种灵活的变换,这种变换通过缩放(scale)和平移(shift)来调整网络层的输出。
use_scale_shift_norm是一个布尔值(True或False),用于决定是否应用这种缩放和平移的归一化方法。如果use_scale_shift_norm为True,则执行以下步骤:

分割输出层:首先,代码将self.out_layers(一个包含网络层的列表)分割为两部分。out_norm是列表中的第一个层,负责进行某种形式的归一化或变换(尽管这里的名字是out_norm,但它可能不仅仅执行归一化,而是任何形式的变换层)。out_rest是列表中剩余的所有层,这些层将在缩放和平移之后应用。
提取缩放和平移参数:接下来,从emb_out(可能是嵌入层的输出或其他某种特征表示)中提取缩放(scale)和平移(shift)参数。这里假设emb_out的维度被设计为包含这两组参数,通过th.chunk(emb_out, 2, dim=1)沿着第二维(dim=1)将其分割成两部分,分别代表缩放和平移参数。
应用缩放和平移:然后,将h(可能是之前某个层的输出)通过out_norm层进行变换,之后使用从emb_out中提取的缩放和平移参数对结果进行调整。调整的方式是将out_norm(h)的输出乘以(1 + scale)并加上shift。这个步骤实质上是在对out_norm(h)的输出进行线性变换,以引入额外的灵活性和控制。
通过剩余层:最后,将调整后的输出h通过out_rest中剩余的层进行进一步的处理。
这种技术的一个关键优势是它能够以一种灵活且数据驱动的方式调整网络层的输出,而不需要在模型架构中硬编码特定的归一化或变换策略。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1899071.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

栈复用(覆盖栈上的有用数据)

栈复用&#xff08;覆盖栈上的有用数据&#xff09; 程序给的输入长度&#xff0c;不够溢出 到返回值&#xff0c;甚至都 不到bp位置 &#xff0c;这是要考虑覆盖之前函数(或当前函数)的栈上的有用数据&#xff0c;任何利用 程序后续函数调用 时要利用该位置上的数据&#xff…

机器人具身智能Embodied AI

强调智能体&#xff08;如机器人&#xff09;通过物理身体在物理世界中的实时感知、交互和学习来执行任务。 通过物理交互来完成任务的智能系统。它由“本体”&#xff08;即物理身体&#xff09;和“智能体”&#xff08;即智能核心&#xff09;耦合而成&#xff0c;能够在复…

《HIRI-ViT: Scaling Vision Transformer with High Resolution Inputs》解读

期刊&#xff1a;TPAMI 年份&#xff1a;2024 摘要 视觉Transformer(ViT)和卷积神经网络(CNN)的混合深度模型已经成为一类强大的视觉任务骨干。扩大这种混合主干网的输入分辨率自然会增强模型的能力&#xff0c;但不可避免地要承受二次扩展的沉重计算成本。相反&#xff0c;…

SQL索引事务

SQL索引事务 索引 创建主键约束(primary key),唯一约束(unique),外键约束(foreign key)时,会自动创建对应列的索引 1.1 查看索引 show index from 表名 现在这个表中没有索引,那么我们现在将这几个表删除之后创建新表 我们现在建立一个班级表一个学生表,并且学生表与班级表存…

高速PCB设计Tips

在进行原理图输入过程中&#xff0c;需要注意将设计分解为功能块&#xff0c;将所有相关组件放在同一页。例如&#xff0c;以太网相关的组件&#xff0c;通常运行在50MHz或更高频率&#xff0c;在原理图设计中应集中在同一页。清晰标记高速连接和电源连接。差分信号和单端阻抗控…

免费分享:中国三级及以上河流(附下载方法)

河流分级法的分级方法是从源头最小河流开始,称为一级河流;两条一级河流汇合成二级河流;以此类推,三级河流等等;最后是干流。本文将介绍中国三级及以上河流数据。 数据简介 1:100万中国三级及以上河流矢量数据是涵盖了全国范围内三级及以上级别河流的详细地理信息和空间分布。这…

5百多本分章节古籍内容大全ACCESS\EXCEL数据库

很多明清小说现在越来越不容易查看其内容&#xff0c;虽然之前搞到过一份《3万8千多古代文学大全ACCESS数据库》&#xff0c;但简体中文总让我感觉有删减、非原版的印象&#xff0c;今天正好遇到一个好的古籍网站&#xff0c;繁体字繁体文&#xff0c;感觉非常不错&#xff0c;…

期权学习必看圣书:《3小时快学期权》要在哪里看?

今天带你了解期权学习必看圣书&#xff1a;《3小时快学期权》要在哪里看&#xff1f;《3小时快学期权》是一本关于股票期权基础知识的书籍。 它旨在通过简明、易懂的语言和实用的案例&#xff0c;让读者在短时间内掌握股票期权的基本概念、操作方法和投资策略。通过这本书&…

LeetCode刷题记录:(15)三角形最小路径和

知识点&#xff1a;倒叙的动态规划 题目传送 解法一&#xff1a;二维动态规划【容易理解】 class Solution {public int minimumTotal(List<List<Integer>> triangle) {int n triangle.size();if (n 1) {return triangle.get(0).get(0);}// dp[i][j]:走到第i层第…

Gemini for China 大更新,现已上架 Android APP!

官网&#xff1a;https://gemini.fostmar.online/ Android APP&#xff1a;https://gemini.fostmar.online/gemini_1.0.apk 一、Android APP 如果是 Android 设备&#xff0c;则会直接识别到并给下载链接。PC 直接对话即可。 二、聊天记录 现在 Gemini for China&#xff…

QWidget成员函数功能和使用详细说明(四)(文字+用例+代码+效果图)

文章目录 1.测试工程配置2.成员函数2.1 void setParent(QWidget *parent)2.2 void setMouseTracking(bool enable)2.3 bool hasMouseTracking() const2.4 void setPalette(const QPalette &)2.5 const QPalette &palette() const2.6 int QWidget::grabShortcut(const Q…

gradle构建工具

setting.gradle // settings.gradle rootProject.name my-project // 指定根项目名称include subproject1, subproject2 // 指定子项目名称&#xff0c;可选jar包名称 方式一 jar {archiveBaseName my-application // 设置 JAR 文件的基本名称archiveVersion 1.0 // 设置…

实验四 图像增强—灰度变换之直方图变换

一&#xff0e;实验目的 1&#xff0e;掌握灰度直方图的概念及其计算方法&#xff1b; 2&#xff0e;熟练掌握直方图均衡化计算过程&#xff1b;了解直方图规定化的计算过程&#xff1b; 3&#xff0e;了解色彩直方图的概念和计算方法 二&#xff0e;实验内容&#xff1a; …

泰迪智能科技企业项目试岗实训——数据分析类型

当今社会&#xff0c;就业技能的瞩目度正与日俱增。在竞争激烈的就业市场中&#xff0c;重视技能的培养和提升&#xff0c;通过学习与实践&#xff0c;增强自己的竞争力&#xff0c;为未来的职业发展打下坚实的基础&#xff0c;不仅有助于解决各类工作技能问题&#xff0c;更能…

跨境干货|最新注册Google账号方法分享

谷歌账号对做跨境外贸业务的人来说是刚需&#xff0c;目前来说大部分的海外社媒平台、工具都可以用谷歌账号来注册。但是仍然有很多朋友并不知道如何注册这个谷歌账号&#xff0c;今天就来给大家分享2个注册谷歌账号的方法&#xff0c;一个是手机号注册&#xff0c;一个是如何跳…

PEFT - 安装及简单使用

LLM、AIGC、RAG 开发交流裙&#xff1a;377891973 文章目录 一、关于 PEFT二、安装1、使用 PyPI 安装2、使用源码安装 三、快速开始1、训练2、保存模型3、推理4、后续步骤 本文翻译整理自&#xff1a;https://huggingface.co/docs/peft/index 一、关于 PEFT &#x1f917;PEFT…

关于如何做好淘汰 IT 资产数据安全销毁工作的思考 文件销毁 硬盘销毁 数据销毁 物料销毁 文件粉碎

在当今数字化时代&#xff0c;企业的 IT 资产不断更新换代&#xff0c;淘汰的 IT 资产中往往存储着大量的敏感数据。如何确保这些数据在资产淘汰过程中被安全销毁&#xff0c;成为了企业面临的重要挑战。以下是对如何做好淘汰 IT 资产数据安全销毁工作的一些思考。 一、明确数…

51单片机STC89C52RC——14.1 直流电机调速

目录 目的/效果 1&#xff1a;电机转速同步LED呼吸灯 2 通过独立按键 控制直流电机转速。 一&#xff0c;STC单片机模块 二&#xff0c;直流电机 2.1 简介 2.2 驱动电路 2.2.1 大功率器件直接驱动 2.2.2 H桥驱动 正转 反转 2.2.3 ULN2003D 引脚、电路 2.3 PWM&…

【C++】 解决 C++ 语言报错:Segmentation Fault

文章目录 引言 段错误&#xff08;Segmentation Fault&#xff09;是 C 编程中常见且令人头疼的错误之一。段错误通常发生在程序试图访问未被允许的内存区域时&#xff0c;导致程序崩溃。本文将深入探讨段错误的产生原因、检测方法及其预防和解决方案&#xff0c;帮助开发者在…

智能插座搭配BIOS唤醒功能实现远程定时开关机

智能插座 智能插座凭借其强大的联网能力&#xff0c;不仅能够实现远程操控开关电源&#xff0c;部分高端型号更是集成了电量统计与自动化操作功能&#xff0c;为用户带来了前所未有的便捷体验。以下是我对几款体验过的智能插座的简要评价&#xff0c;因版本差异可能有所不同。…