论文解读:DiAD之SG网络

news2024/12/24 10:24:42

目录

  • 一、SG网络功能介绍
  • 二、SG网络代码实现

一、SG网络功能介绍

DiAD论文最主要的创新点就是使用SG网络解决多类别异常检测中的语义信息丢失问题,那么它是怎么实现的保留原始图像语义信息的同时重建异常区域?

与稳定扩散去噪网络的连接: SG网络被设计为与稳定扩散(Stable Diffusion, SD)去噪网络相连接。SD去噪网络本身具有强大的图像生成能力,但可能无法在多类异常检测任务中保持图像的语义信息一致性。SG网络通过引入语义引导机制,使得在重构异常区域时能够参考并保留原始图像的语义上下文。整个框架图中,SG网络与去噪网络的连接如下图所示。
在这里插入图片描述在这里插入图片描述
这是论文给出的最终输出,我认为图中圈出来的地方有问题,应该改为SG网络的编码器才对。

语义一致性保持: SG网络在重构过程中,通过在不同尺度下处理噪声,并利用空间感知特征融合(Spatial-aware Feature Fusion, SFF)块融合特征,确保重建过程中保留语义信息。这样,即使在重构异常区域时,也能使修复后的区域与原始图像的语义上下文保持一致。
多尺度特征融合: SFF块将高尺度的语义信息集成到低尺度中,使得在保留原始正常样本信息的同时,能够处理大规模异常区域的重建。这种机制有助于在处理需要广泛重构的区域时,最大化重构的准确性,同时保持图像的语义一致性。从下图中可以看到,特征融合模块还是很好理解的。

在这里插入图片描述

与预训练特征提取器的结合: SG网络还与特征空间中的预训练特征提取器相结合。预训练特征提取器能够处理输入图像和重建图像,并在不同尺度上提取特征。通过比较这些特征,系统能够生成异常图(anomaly maps),这些图显示了图像中可能存在的异常区域,并给出了异常得分或置信度。这一步骤进一步验证了SG网络在保留语义信息方面的有效性。
避免类别错误: 相比于传统的扩散模型(如DDPM),SG网络通过引入类别条件解决了在多类异常检测任务中可能出现的类别错误问题。LDM虽然通过交叉注意力引入了条件约束,但在随机高斯噪声下去噪时仍可能丢失语义信息。SG网络则通过其语义引导机制,有效地避免了这一问题。

二、SG网络代码实现

这部分代码大概有300行

class SemanticGuidedNetwork(nn.Module):
    def __init__(
            self,
            image_size,
            in_channels,
            model_channels,
            hint_channels,
            num_res_blocks,
            attention_resolutions,
            dropout=0,
            channel_mult=(1, 2, 4, 8),
            conv_resample=True,
            dims=2,
            use_checkpoint=False,
            use_fp16=False,
            num_heads=-1,
            num_head_channels=-1,
            num_heads_upsample=-1,
            use_scale_shift_norm=False,
            resblock_updown=False,
            use_new_attention_order=False,
            use_spatial_transformer=False,  # custom transformer support
            transformer_depth=1,  # custom transformer support
            context_dim=None,  # custom transformer support
            n_embed=None,  # custom support for prediction of discrete ids into codebook of first stage vq model
            legacy=True,
            disable_self_attentions=None,
            num_attention_blocks=None,
            disable_middle_self_attn=False,
            use_linear_in_transformer=False,
    ):
        super().__init__()
        if use_spatial_transformer:
            assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'

        if context_dim is not None:
            assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
            from omegaconf.listconfig import ListConfig
            if type(context_dim) == ListConfig:
                context_dim = list(context_dim)

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        if num_heads == -1:
            assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'

        if num_head_channels == -1:
            assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'

        self.dims = dims
        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels
        if isinstance(num_res_blocks, int):
            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
        else:
            if len(num_res_blocks) != len(channel_mult):
                raise ValueError("provide num_res_blocks either as an int (globally constant) or "
                                 "as a list/tuple (per-level) with the same length as channel_mult")
            self.num_res_blocks = num_res_blocks
        if disable_self_attentions is not None:
            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
            assert len(disable_self_attentions) == len(channel_mult)
        if num_attention_blocks is not None:
            assert len(num_attention_blocks) == len(self.num_res_blocks)
            assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
            print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
                  f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
                  f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
                  f"attention will still not be set.")

        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.use_checkpoint = use_checkpoint
        self.dtype = th.float16 if use_fp16 else th.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.predict_codebook_ids = n_embed is not None

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])

        self.input_hint_block = TimestepEmbedSequential(
            conv_nd(dims, hint_channels, 16, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 16, 16, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 16, 32, 3, padding=1, stride=2),
            nn.SiLU(),
            conv_nd(dims, 32, 32, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 32, 96, 3, padding=1, stride=2),
            nn.SiLU(),
            conv_nd(dims, 96, 96, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 96, 256, 3, padding=1, stride=2),
            nn.SiLU(),
            zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
        )

        self._feature_size = model_channels
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1
        for level, mult in enumerate(channel_mult):
            for nr in range(self.num_res_blocks[level]):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        # num_heads = 1
                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                    if exists(disable_self_attentions):
                        disabled_sa = disable_self_attentions[level]
                    else:
                        disabled_sa = False

                    if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
                        layers.append(
                            AttentionBlock(
                                ch,
                                use_checkpoint=use_checkpoint,
                                num_heads=num_heads,
                                num_head_channels=dim_head,
                                use_new_attention_order=use_new_attention_order,
                            ) if not use_spatial_transformer else SpatialTransformer(
                                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
                                disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
                                use_checkpoint=use_checkpoint
                            )
                        )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self.zero_convs.append(self.make_zero_conv(ch))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                self.zero_convs.append(self.make_zero_conv(ch))
                ds *= 2
                self._feature_size += ch

        if num_head_channels == -1:
            dim_head = ch // num_heads
        else:
            num_heads = ch // num_head_channels
            dim_head = num_head_channels
        if legacy:
            # num_heads = 1
            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(
                ch,
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
                num_head_channels=dim_head,
                use_new_attention_order=use_new_attention_order,
            ) if not use_spatial_transformer else SpatialTransformer(  # always uses a self-attn
                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
                disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
                use_checkpoint=use_checkpoint
            ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self.middle_block_out = self.make_zero_conv(ch)
        self._feature_size += ch

        #SFF Block
        self.down11 = nn.Sequential(
            zero_module(nn.Conv2d(640, 1280, kernel_size=3, stride=2, padding=1, bias=False)),
            nn.InstanceNorm2d(1280),
            nn.SiLU(),
        )
        self.down12 = nn.Sequential(
            zero_module(nn.Conv2d(640, 1280, kernel_size=3, stride=2, padding=1, bias=False)),
            nn.InstanceNorm2d(1280),
            nn.SiLU(),
        )
        self.down13 = nn.Sequential(
            zero_module(nn.Conv2d(640, 1280, kernel_size=3, stride=2, padding=1, bias=False)),
            nn.InstanceNorm2d(1280),
            nn.SiLU(),
        )
        self.down21 = nn.Sequential(
            zero_module(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1, bias=False)),
            nn.InstanceNorm2d(1280),
            nn.SiLU(),
        )
        self.down22 = nn.Sequential(
            zero_module(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1, bias=False)),
            nn.InstanceNorm2d(1280),
            nn.SiLU(),
        )
        self.down23 = nn.Sequential(
            zero_module(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1, bias=False)),
            nn.InstanceNorm2d(1280),
            nn.SiLU(),
        )
        self.down31 = nn.Sequential(
            zero_module(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1, bias=False)),
            nn.InstanceNorm2d(1280),
            nn.SiLU(),
        )
        self.down32 = nn.Sequential(
            zero_module(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1, bias=False)),
            nn.InstanceNorm2d(1280),
            nn.SiLU(),
        )
        self.down33 = nn.Sequential(
            zero_module(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1, bias=False)),
            nn.InstanceNorm2d(1280),
            nn.SiLU(),
        )
        self.silu = nn.SiLU()

    def make_zero_conv(self, channels):
        return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))

    def forward(self, x, hint, timesteps, context, **kwargs):
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb)

        guided_hint = self.input_hint_block(hint, emb, context)

        outs = []

        h = x.type(self.dtype)
        for module, zero_conv in zip(self.input_blocks, self.zero_convs):
            if guided_hint is not None:
                h = module(h, emb, context)
                h += guided_hint
                guided_hint = None
            else:
                h = module(h, emb, context)
            outs.append(zero_conv(h, emb, context))

        #SFF Block Implementation
        outs[9] = self.silu(outs[9]+self.down11(outs[6])+self.down21(outs[7])+self.down31(outs[8]))
        outs[10] = self.silu(outs[10]+self.down12(outs[6])+self.down22(outs[7])+self.down32(outs[8]))
        outs[11] = self.silu(outs[11]+self.down13(outs[6])+self.down23(outs[7])+self.down33(outs[8]))

        h = self.middle_block(h, emb, context)
        outs.append(self.middle_block_out(h, emb, context))

        return outs

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1947778.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

机器学习(二十):偏差和方差问题

一、判断偏差和方差 以多项式回归为例&#xff0c;红点为训练集数据&#xff0c;绿点为交叉验证数据。 下图的模型&#xff0c;训练集误差大&#xff0c;交叉验证集误差大&#xff0c;这代表偏差很大 下图的模型&#xff0c;训练集误差小&#xff0c;交叉验证集误差小&#x…

Linux网络:传输层协议TCP(二)三次挥手四次握手详解

目录 一、TCP的连接管理机制 1.1三次握手 1.2四次挥手 二、理解 TIME_WAIT 状态 2.1解决TIME_WAIT 状态引起的 bind 失败的方法 三、理解CLOSE_WAIT状态 一、TCP的连接管理机制 在正常情况下, TCP 要经过三次握手建立连接, 四次挥手断开连接 1.1三次握手 三次握手顾名思…

vue import from

vue import from 导入文件&#xff0c;从XXXX路径&#xff1b;引入文件 import xxxx from “./minins/resize” import xxxx from “./minins/resize.js” vue.config.js 定义 : resolve(src)&#xff1b;就是指src 目录 import xxxx from “/utils/auth” im…

vue3知识

目录 基础vue开发前的准备vue项目目录结构模板语法属性绑定条件渲染列表渲染通过key管理状态事件处理事件传参事件修饰符数组变化侦测计算属性Class绑定style绑定侦听器表单输入绑定模板引用组件组成组件嵌套关系组件注册方式组件传递数据Props(父传子)组件传递多种数据类型组件…

怎么批量加密U盘?U盘批量加密的方法有哪些?

加密U盘是保护U盘数据安全的重要方法。而当需要加密的U盘数量较多时&#xff0c;我们需要批量加密U盘。那么&#xff0c;U盘怎么批量加密呢&#xff1f;下面我们就来了解一下。 U盘内存卡批量只读加密专家 U盘内存卡批量只读加密专家是一款专业的U盘加密软件&#xff0c;适用于…

什么牌子的充电宝又好又耐用?认准这几个充电宝品牌!错过就吃亏

在 2024 年&#xff0c;充电宝已然成为我们生活中不可或缺的电子配件。但面对市场上琳琅满目的充电宝产品&#xff0c;如何挑选出一款适合自己的&#xff0c;却让许多人感到困惑。充电宝要怎么挑&#xff1f;这可不是一个简单的问题。不同的使用场景、不同的设备需求&#xff0…

02 MySQL数据库管理

目录 1.数据库的结构 sql语言主要由以下几部分组成 2. 数据库与表的创建和管理 1&#xff0c;创建数据库 2&#xff0c;创建表并添加数据 3&#xff0c;添加一条数据 4&#xff0c;查询数据 5&#xff0c;更新数据 6&#xff0c;删除数据 3.用户权限管理 1.创建用户 …

3万多有分类的成语词典ACCESS\EXCEL数据库

今天最后发一个成语词典的数据库了&#xff0c;因为成语词典的数据库太多了导致我自己都有些糊涂了&#xff0c;今天这份数据库应该说是最好的成语词典了&#xff0c;不但包含了3级分类&#xff0c;而且还有级别&#xff08;不要太较真&#xff09;字段。 数据库包含多个表&…

利用 Databend 生态构建现代数据湖工作流

数据是洞察力的基石&#xff0c;越来越多的企业开始建设以数据资产为中心的存储和分析一体化方案&#xff0c;这要求 Data Infra 架构能够提供可扩展、灵活且统一的数据工作流。现代数据湖架构同时兼顾数据湖的可扩展性和数据仓库的性能&#xff0c;满足对大规模数据处理的需求…

视频文件怎么压缩到最小 视频文件怎么压缩到最小内存 4个简单的方法工具分享简单步骤

如何压缩大视频文件以减小其大小&#xff1f;在分享或存储大视频文件时&#xff0c;有效压缩是关键&#xff0c;以降低文件大小且不显著牺牲视觉和听觉质量。视频文件的大小直接影响传输、分享和存储的成本与便捷性。掌握压缩视频的技能对于数字内容处理至关重要&#xff0c;能…

【Android】linux

android系统就是跑在linux上的系统。Linux层里面包含系统和硬件驱动等一些本地代码的环境。 linux的目录 mount: 用于查看哪个模块输入只读&#xff0c;一般显示为&#xff1a; [rootlocalhost ~]# mount /dev/cciss/c0d0p2 on / type ext3 (rw) proc on /proc type proc (…

SpringBoot 实现图形验证码

一、最终结果展示 二、前端代码 2.1 index.html <!DOCTYPE html> <html lang"en"><head><meta charset"utf-8"><title>验证码</title><style>#inputCaptcha {height: 30px;vertical-align: middle;}#verifica…

(leetcode学习)236. 二叉树的最近公共祖先

给定一个二叉树, 找到该树中两个指定节点的最近公共祖先。 百度百科中最近公共祖先的定义为&#xff1a;“对于有根树 T 的两个节点 p、q&#xff0c;最近公共祖先表示为一个节点 x&#xff0c;满足 x 是 p、q 的祖先且 x 的深度尽可能大&#xff08;一个节点也可以是它自己的祖…

Q238. 除自身以外数组的乘积

思路 一开始想到的是按位乘 看了题解&#xff0c;思路是存i左边的乘积和 与 i右边的乘积和 代码一&#xff1a; 需要三次循环,需要额外空间 left和right数组 代码&#xff1a; public int[] productExceptSelf(int[] nums) {int[] left new int[nums.length];int[] right …

python题解

空间三角形 输入在三维空间的三角形三个顶点A&#xff0c;B&#xff0c;C的坐标&#xff08;x,y,z&#xff09;&#xff0c;计算并输出三角形面积。不考虑不能构成三角形的特殊情况。 格式 输入格式&#xff1a; 依次输入三个顶点A&#xff0c;B&#xff0c;C的坐标&#xff…

CISSP,信息安全圈公认的高含金量证书

在数字化和信息化迅速发展的时代&#xff0c;信息安全的重要性愈发突出。 网络攻击、数据泄露和隐私问题频发&#xff0c;使得企业和组织对信息安全专业人士的需求不断增加。 CISSP&#xff08;Certified Information Systems Security Professional&#xff09;作为信息安全领…

文字描边效果

文字描边效果可以通过text-shadow来实现&#xff0c;也可以通过-webkit-text-stroke来实现 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, i…

MySQL数据库练习(5)

1.建库建表 # 使用数据库 use mydb16_trigger;# 表格goods create table goods( gid char(8) primary key, name varchar(10), price decimal(8,2), num int);# 表格orders create table orders( oid int primary key auto_increment, gid char(10) not null, name varchar(10…

MYSQL第五次作业

1、触发器 建立两个表:goods(商品表)、orders(订单表) mysql> use mydb16_trigger; Database changed mysql> create table goods-> (-> gid char(8) primary key,-> name varchar(10),-> price decimal(8,2),-> num int-> ); Query O…

MySQL零散拾遗(四)--- 使用聚合函数时需要注意的点点滴滴

聚合函数 聚合函数作用于一组数据&#xff0c;并对一组数据返回一个值。 常见的聚合函数&#xff1a;SUM()、MAX()、MIN()、AVG()、COUNT() 对COUNT()聚合函数的更深一层理解 COUNT函数的作用&#xff1a;计算指定字段在查询结果中出现的个数&#xff08;不包含NULL值&#…