DiffusionDet源码阅读(1)

news2024/9/21 18:08:50

本文仅仅适用于已经通读过全文的小伙伴

本文代码节选自 mmdet 中的 DiffusionDet 代码,目前该代码还处于 Development 阶段,所以我博客里写的代码和之后的稳定版本可能稍有不同,不过不用担心,我们只看最关键的部分

DDPM中扩散部分有个参数 β \beta β:

q ( z t ∣ z t − 1 ) : = N ( z t ; 1 − β t z t − 1 , β t I ) q(z_t | z_{t-1}) := \mathcal{N} (z_{t}; \sqrt{1 - \beta_t} z_{t-1}, \beta_t \bf{I} ) q(ztzt1):=N(zt;1βt zt1,βtI)

这就是每次的加噪过程,也可以视为 z t − 1 z_{t-1} zt1先经过一个缩放,再加一个随机噪声之后,就成了 z t z_{t} zt
每次加噪声通过一个参数 β t \beta_t βt来控制,这个参数是人为给定的,而不是可学习的,由于:

q ( z t ∣ z 0 ) : = N ( z t ; α ˉ t z 0 , ( 1 − α ˉ t ) I ) q(z_t | z_{0}) := \mathcal{N} (z_{t}; \sqrt{ \bar{\alpha}_t } z_{0}, (1-\bar{\alpha}_t) \bf{I} ) q(ztz0):=N(zt;αˉt z0,(1αˉt)I)
即:

z t = α ˉ t z 0 + ϵ 1 − α ˉ t ,    w h e r e    ϵ ∈ N ( 0 , I ) z_t = \sqrt{ \bar{\alpha}_t } z_{0} + \epsilon \sqrt{1 - \bar{\alpha}_t}, \ \ where \ \ \epsilon \in \mathcal{N}(0, \bf{I}) zt=αˉt z0+ϵ1αˉt ,  where  ϵN(0,I)

在给定 z 0 z_{0} z0 的基础上, q ( z t ∣ z 0 ) q(z_t | z_{0}) q(ztz0) 也是一个高斯分布,其中:

α t = 1 − β t α ˉ t = Π s = 0 t α s \alpha_t = 1 - \beta_t \\ \bar{\alpha}_t = \Pi_{s=0}^t \alpha_s αt=1βtαˉt=Πs=0tαs

α ˉ t \bar{\alpha}_t αˉt 取值趋近于0时, z t z_t zt 可以视为一个标准的高斯分布,在DiffusionDet中, β 1 : T \beta_{1:T} β1:T取了一系列零到一,且逐渐变大的值,以下是生成 β \beta β 的代码,这里我们取 T = 1000 T=1000 T=1000,即共采样 1000 1000 1000

def cosine_beta_schedule(timesteps, s=0.008):
    """Cosine schedule as proposed in
    https://openreview.net/forum?id=-NEXDKk8gZ."""
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps, dtype=torch.float64)
    alphas_cumprod = torch.cos(
        ((x / timesteps) + s) / (1 + s) * math.pi * 0.5)**2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)

c o s ( x ) cos(x) cos(x) c o s 2 ( x ) cos^2(x) cos2(x) 两个函数的曲线,红线是前者,蓝线是后者,二者有同一个零点 ( π 2 , 0 ) (\frac{\pi}{2}, 0) (2π,0)

请添加图片描述

这是 β \beta β的曲线

请添加图片描述

接下来就是上边计算 α \alpha α α ˉ \bar{\alpha} αˉ之类的代码:

    def _build_diffusion(self):
        betas = cosine_beta_schedule(self.timesteps)
        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.)

        self.register_buffer('betas', betas)
        self.register_buffer('alphas_cumprod', alphas_cumprod)
        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod',
                             torch.sqrt(1. - alphas_cumprod))
        self.register_buffer('log_one_minus_alphas_cumprod',
                             torch.log(1. - alphas_cumprod))
        self.register_buffer('sqrt_recip_alphas_cumprod',
                             torch.sqrt(1. / alphas_cumprod))
        self.register_buffer('sqrt_recipm1_alphas_cumprod',
                             torch.sqrt(1. / alphas_cumprod - 1))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        # equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (
            1. - alphas_cumprod)
        self.register_buffer('posterior_variance', posterior_variance)

        # log calculation clipped because the posterior variance is 0 at
        # the beginning of the diffusion chain
        self.register_buffer('posterior_log_variance_clipped',
                             torch.log(posterior_variance.clamp(min=1e-20)))
        self.register_buffer(
            'posterior_mean_coef1',
            betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        self.register_buffer('posterior_mean_coef2',
                             (1. - alphas_cumprod_prev) * torch.sqrt(alphas) /
                             (1. - alphas_cumprod))

这三行计算了 β t \beta_t βt, α ˉ t \bar{\alpha}_t αˉt α ˉ t − 1 \bar{\alpha}_{t-1} αˉt1,其长度都是 T T T

        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.)

        self.register_buffer('betas', betas)
        self.register_buffer('alphas_cumprod', alphas_cumprod)
        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)

q ( z t ∣ z t − 1 ) : = N ( z t ; 1 − β t z t − 1 , β t I ) q(z_t | z_{t-1}) := \mathcal{N} (z_{t}; \sqrt{1 - \beta_t} z_{t-1}, \beta_t \bf{I} ) q(ztzt1):=N(zt;1βt zt1,βtI)

接下来计算 α ˉ t \sqrt{\bar{\alpha}_{t}} αˉt 1 − α ˉ t \sqrt{1 - \bar{\alpha}_{t}} 1αˉt log ⁡ ( 1 − α ˉ t ) \log{(1-\bar{\alpha}_{t})} log(1αˉt) 1 α ˉ t \frac{1}{\sqrt{\bar{\alpha}_{t}}} αˉt 1 1 α ˉ t − 1 \sqrt{\frac{1}{\bar{\alpha}_t} - 1} αˉt11

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod',
                             torch.sqrt(1. - alphas_cumprod))
        self.register_buffer('log_one_minus_alphas_cumprod',
                             torch.log(1. - alphas_cumprod))
        self.register_buffer('sqrt_recip_alphas_cumprod',
                             torch.sqrt(1. / alphas_cumprod))
        self.register_buffer('sqrt_recipm1_alphas_cumprod',
                             torch.sqrt(1. / alphas_cumprod - 1))

DDPM文中假设,后验分布 q ( z t − 1 ∣ z t , z 0 ) q(z_{t-1} | z_t, z_0) q(zt1zt,z0)也是高斯分布,有:

q ( z t − 1 ∣ z t , z 0 ) = N ( z t − 1 ; μ ~ ( z t , z 0 ) , β t ~ I ) q(z_{t-1} | z_t, z_0) = \mathcal{N} (z_{t-1} ; \tilde{\mu}(z_t, z_0), \tilde{\beta_t} \bm{I}) q(zt1zt,z0)=N(zt1;μ~(zt,z0),βt~I)

算式整理后有:

μ ~ t ( z t , z 0 ) = α ˉ t − 1 β t 1 − α ˉ t z 0 + α t ( 1 − α ˉ t − 1 ) 1 − α ˉ t z t \tilde{\mu}_t(z_t, z_0) = \frac{ \sqrt{\bar{\alpha}_{t-1}} \beta_t }{ 1 - \bar{\alpha}_t } z_{0} + \frac { \sqrt{\alpha_t} (1 - \bar{\alpha}_{t-1}) } { 1 - \bar{\alpha}_t } z_{t} μ~t(zt,z0)=1αˉtαˉt1 βtz0+1αˉtαt (1αˉt1)zt

β ~ t = 1 − α ˉ t − 1 1 − α ˉ t β t \tilde{\beta}_{t} = \frac { 1 - \bar{\alpha}_{t-1} } { 1 - \bar{\alpha}_t } \beta_{t} β~t=1αˉt1αˉt1βt

接下来的几行代码用来计算这几个系数:

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        # equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (
            1. - alphas_cumprod)
        self.register_buffer('posterior_variance', posterior_variance)

        # log calculation clipped because the posterior variance is 0 at
        # the beginning of the diffusion chain
        self.register_buffer('posterior_log_variance_clipped',
                             torch.log(posterior_variance.clamp(min=1e-20)))
        self.register_buffer(
            'posterior_mean_coef1',
            betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        self.register_buffer('posterior_mean_coef2',
                             (1. - alphas_cumprod_prev) * torch.sqrt(alphas) /
                             (1. - alphas_cumprod))

以上就是函数 _build_diffusion 的全部内容,集中几个log项可能是之后计算loss用的

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/456615.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

mybatis中大数据量foreach插入效率对比

1.controller代码 RequestMapping("/testInsert")public String testInsert(Integer sum){testService.testInsert(sum);return "发送成功";}2.service代码 Overridepublic void testInsert(Integer sum) {long start System.currentTimeMillis();List<…

LightGBM面试题

1.偏差 vs 方差? 偏差是指由有所采样得到的大小为m的训练数据集&#xff0c;训练出的所有模型的输出的平均值和真实模型输出之间的偏差。 通常是由对学习算法做了错误的假设导致的描述模型输出结果的期望与样本真实结果的差距。分类器表达能力有限导致的系统性错误&#xff0c…

基于AT89C52单片机的温度检测报警设计

点击链接获取Keil源码与Project Backups仿真图&#xff1a; https://download.csdn.net/download/qq_64505944/87708680?spm1001.2014.3001.5503 源码获取 主要内容&#xff1a; 本系统的设计主要是了解了单片机微型计算机&#xff0c;根据现实生活的需要以及已掌握的理论知识…

Linux应用程序开发:静态库与动态库的制作及使用

目录 一、库的简介二、静态库与动态库的简介三、静态库制作与调用案例四、动态库制作与调用案例 一、库的简介 库是一种可执行的二进制文件&#xff0c;是编译好的代码。使用库可以提高开发效率。而Linux库的种类可分为动态库和静态库。 二、静态库与动态库的简介 1、静态库&a…

第十七届中国CFO大会圆满举办 用友蝉联中国CFO首选智能财务厂商!

4月21日&#xff0c;由财政部指导、《新理财》杂志社主办、用友等单位协办的「数字财务 智能引领」第十七届中国CFO大会在北京圆满举办&#xff01;业内专家、权威学者以及众多来自央国企等知名大型企业的财务领路人荟聚一堂&#xff0c;共襄中国CFO领域的顶尖盛会&#xff0c;…

数字化转型中的石头和沙子问题

作者介绍 朱金衡&#xff0c;西门子Mendix 高级技术咨询顾问及架构师&#xff0c;Mendix Certified 中级培训讲师以及TOGAF Certified 企业架构师。作为专家服务架构师提供咨询服务&#xff0c;如方案设计、开发辅导、故障排除、应用程序审查等&#xff0c;同时创造了许多专门…

【算法】从x的n次方看递归时间复杂度计算

从x的n次方看递归时间复杂度计算 1.循环 这个问题&#xff0c;最简单的办法是用循环 int pow1(int x,int n) {int result 1;for(int i0;i<n;i){result*x;}return result; }如上算法的时间复杂度为O(N)&#xff0c;但还是不够理想。这时尝试使用递归算法 2.递归1 int po…

交换机的电口和光口到底是个啥东东,做网络的这个常识得懂!

在计算机网络中&#xff0c;交换机是一个非常重要的设备&#xff0c;它可以将来自不同设备的数据包进行转发和交换。交换机通常具有多个接口&#xff0c;其中包括光口和电口。在本文中&#xff0c;我们将详细讨论交换机的光口和电口的概念以及它们的不同之处。 电口 电口是交换…

应届生的天坑,悔不该进那外包啊.....

关于计算机专业应届生毕业之后会遇到的就业问题&#xff0c;网上已经有许多的套路&#xff0c;实际上许多人在选择专业的时候并没有考虑到之后的就业方向&#xff0c;甚至于自己所学的专业面向的工作岗位都不是特别清楚。计算机专业毕业大概率是要做程序员的&#xff0c;而目前…

RichTextBox控件详解

RichTextBox和TextBox的区别 从外观来看 multiline设置为true or false区别 textbox RichTextBox 先看截图 属性 AcceptsTab AutoWordSelection BulletIndent DetectUrls Dock EnableAutoDragDrop HideSelection Lines ScrollBars WordWrap SelectionIndent and SelectionC…

FE_TA不知道的CSS 换行系列【1】white-space

在W3C官方描述中&#xff0c;white-space主要有以下两个作用&#xff1a; 是否进行空格合并&#xff0c;以及控制空格合并的方式&#xff1b;是否在soft wrap opportunities&#xff08;文本中可进行换行的断点位置&#xff09;处进行文本换行。 从字面意思来看white-space即…

从github下载项目并进行环境配置

文章目录 1 设置虚拟环境2 git clone 链接地址3 环境配置 1 设置虚拟环境 利用pycharm打开项目&#xff1a;File->Open配置对应的虚拟环境&#xff1a;File->Setting->Project->Python解释器&#xff0c;然后选择对应的虚拟环境如果没有提前设置虚拟环境&#xf…

私人工具集6——使用C# 创建一个简单的restful风格的WebAPI

创建一个简单的WebApi 工具&#xff1a;VS2022 创建新项目 打开VS2022,创建新项目&#xff0c;可以搜索API作为关键字。 为项目取个名字 创建的应用程序&#xff0c;选择WebAPI&#xff0c;注意&#xff0c;右侧的信息默认即可&#xff0c;不要随意选择。 点击创建&#xff…

高可用消息服务消息一致、可靠性、链路稳定性核心关注点

面临的问题 初期业务主要的场景是直播间的群聊消息以及一小部分的单聊消息。由于是教育场景&#xff0c;所以业务在划分聊天室的时候是以班级为单位进行划分的&#xff0c;假设每个聊天室的人数为500人。 问题一&#xff1a;用户的维护 直播场景的群聊与微信等常见的群聊在用…

可视化大屏模板|不玩虚的,套用立得报表

写在前面&#xff1a;这是报表&#xff0c;是可视化大屏报表&#xff0c;是可以直接套用来分析我们自己数据源的可视化大屏报表模板。不是单纯的图片&#xff01; 在一些社交平台上经常看到有人误将可视化大屏图片当做报表求分享。可以理解大家都想要将报表做得好看&#xff0…

信息收集(三)端口和目录信息收集

信息收集&#xff08;一&#xff09;域名信息收集 信息收集&#xff08;二&#xff09;IP信息收集 端口是什么 "端口"是英文port的意译&#xff0c;可以认为是设备与外界通讯交流的出口。端口可分为虚拟端口和物理端口&#xff0c;其中虚拟端口指计算机内部或交换机…

Matlab绘图中的一些技能

目录 1、matlab坐标轴设置多种字体(复合字体) 2、matlab图片中title生成的标题转移至图像下端 3、指定对应格式和期望dpi的图像进行保存、以及不留白保存 4、设置字体字号&#xff08;x、y轴&#xff0c;标题。全局字体等&#xff09; 5、设置刻度值信息&#xff0c;只有左…

ConcurrentHashMap是如何保证线程安全的

ConcurrentHashMap是如何保证线程安全的 定义和问题解决JDK 1.7实现原理JDK 1.8性能优化总结 定义和问题解决 ConcurrentHashMap相当于HashMap的多线程版本。 它的功能本质上和HashMap没有什么区别&#xff0c;因为HashMap在并发操作的时候会出现各种问题&#xff0c;比如&am…

Android混淆和反混淆

本篇来介绍下Android的混淆和反混淆&#xff0c;说起混淆&#xff0c;大家都会很自然地想到ProGuard&#xff0c;此外还有R8。事实上&#xff0c;AGP3.3之后&#xff0c;官方默认使用R8做代码优化、混淆和压缩。ProGuard和R8常常用于混淆最终的Android项目&#xff0c;增加项目…

Vue-router【VUE】

6. vue-router 6.1 相关理解 6.1.1 vue-router 的理解 路由就是一组key-value的对应关系。多个路由&#xff0c;需要经过路由器的管理vue是一个插件库&#xff0c;专门用来实现SPA应用。 6.1.2 对SPA应用的理解 单页 Web 应用&#xff08;single page web application, SP…