【前沿模型解析】潜在扩散模型 2-1 | 手撕感知图像压缩 基础块ResNet块

news2024/11/25 8:22:59

文章目录

  • 1 残差结构回顾
  • 2 LDM结构中的残差结构设计
    • 2.1 组归一化GroupNorm层
    • 2.2 激活函数层
    • 2.3 卷积层
    • 2.4 dropout层
  • 3 代码实现

1 残差结构回顾

残差结构应该是非常重要的基础块之一了,你肯定会在各种各样的网络模型结构里看到残差结构,他是非常强大的~

残差结构往往可以抽象为如下的结构

在这里插入图片描述

为什么要设置残差模块?

因为在深度学习中,网络层级很多复杂,非常重要的一部分是前后的梯度信息流动,否则很容易发生梯度爆炸或梯度消失

想象一下我们开车去一个遥远的目的地。我们可以选择直接开到目的地,也可以选择在途中设置几个“中转站”,但是中转站多了有可能会丢失最终目的地的信息,中途有好玩的就被吸引了,所以我们有时候要直接开车前往目的地,避免一些信息的丢失或遗忘。

2 LDM结构中的残差结构设计

第一阶段的残差块是主要的部分

负责不断将送进来的图像进行特征提取~,是多个ResnetBlock前后堆叠起来形成的

  • 其中中间部分的ResnetBlock是不改变特征图的通道和尺寸大小
  • 最后一个ResnetBlock将通道数加倍

在这里插入图片描述

2.1 组归一化GroupNorm层

不是用的传统的批归一化BN,而是用的组归一化GN

是因为BN在Batch_size比较小的时候,表现很差,而我们在图像生成的实际任务中,由于分辨率比较大,所以Batch_size往往比较小

这时候GN的效果会更好

GN实现方式就是按照通道去分组,每组各自归一化,比如图例中的通道数是128,分为32组,那么每组就是4个通道进行归一化

像了解更多归一化知识可以看文章

全面解读Group Normalization

2.2 激活函数层

没有采用传统的ReLU什么的,而是采用了一个组合形式

这是作者尝试不同实验效果比较好的

其实启发我们在自己的平时网络设计过程中,也可以进行损失函数的调整

2.3 卷积层

卷积层就很熟悉了

如果是中间层的ResNetBlock我们要实现两点

  1. 维持前后通道数不变
  2. 维持前后尺寸大小不变

所以用了这样的设计

  • 小卷积核3*3的大小,步长1,填充1

经过以后,尺寸大小可以维持不变

公式

特征图长或宽 − 卷积核尺寸 + 2 ∗ 填充尺寸 步长 + 1 \frac{特征图长或宽-卷积核尺寸+2*填充尺寸}{步长}+1 步长特征图长或宽卷积核尺寸+2填充尺寸+1

如果是最后的ResNetBlock

则进行通道数的改变

2.4 dropout层

防止过拟合的经典操作,让部分神经元失活(变为0)组织信息传递,避免模型能力过强

最后再经过一个残差输出即可

3 代码实现

from torch import nn
import torch

class ResnetBlock(nn.Module):
    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
                 dropout):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        self.conv1 = torch.nn.Conv2d(in_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)

        self.norm2 = torch.nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        if self.in_channels != self.out_channels:#如果分辨率改变则进行更小的卷积核
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels,
                                                     out_channels,
                                                     kernel_size=3,
                                                     stride=1,
                                                     padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels,
                                                    out_channels,
                                                    kernel_size=1,
                                                    stride=1,
                                                    padding=0)
    def forward(self, x):
        h = x
        h = self.norm1(h) #靠后面的时候有问题 shape [80, 256, 8, 8]
        h = h*torch.sigmoid(h)
        h = self.conv1(h)
        h = self.norm2(h)
        h = h*torch.sigmoid(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x+h

其中有几个特别的参数需要说明

  • in_channels 就是输入ResNet的网络通道数,out_channels 就是输出ResNet块的网络通道数,注意后者默认是None,这样的话就不会改变通道数,适用于中间ResNet块部分,最后一层的ResNet块要给出out_channels,方便做通道更改
  • use_conv_shortcut决定着最后一层的ResNet改变通道数的时候的方式,为True的话则是size=3的卷积核,而如果是false的话则是size为1的小卷积核
  • dropout是失活比率

注:与源码相比,简化了time_emb相关的内容,因为在源码中这部分也没有排上用场

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1575544.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【绘图案例-模拟ImageView Objective-C语言】

一、接下来,我们来模拟一下imageView, 1.我们先来写一段儿简单的代码,然后我们再来说怎么着去做啊, 首先呢,imageView啊,它初始化,有两个方法: 1)init: 2)initWithImage: 这两个方法,有什么样儿的区别呢,知道吗,除了默认带不带图片以外,还有没有其他的区别…

C语言进阶课程学习记录-第25课 - # 和 ## 操作符使用分析

C语言进阶课程学习记录-第25课 - # 和 ## 操作符使用分析 #运算符实验-#转化字符串预处理后代码 实验-#输出函数名预处理后的代码 ##运算符实验-##定义变量预处理后代码 实验-##定义结构体预处理后的代码 小结 本文学习自狄泰软件学院 唐佐林老师的 C语言进阶课程,图…

linux基础篇:linux挂载本地yum源——centos7.9为例

linux挂载本地yum源——centos7.9为例 一、Linux本地yum源介绍 Linux本地YUM源是一个本地存储的软件包仓库,它允许用户在不连接互联网的情况下安装、更新和管理软件包。本地YUM源可以提高软件包安装速度,降低网络带宽消耗,并提高软件包管理…

C语言程序与设计——指针地址与main函数

指针变量 在C语言中,最重要的就是对于指针和地址的理解,因为C语言是更接近底层的编程语言,所以它可以允许开发者对内存操作,这也是区别于其它编程语言的一个重要特性。 如何对内存进行操作呢。我们知道在编程过程中,在…

软件设计师:11-结构化开发与UML

结构化开发(3-4分) 一、模块化 二、耦合(背) 三、内聚(背) 四、设计原则(背) 五、系统文档 六、数据流图 数据流的起点或终点必须有一个是加工 判断依据: 1、…

CTF之GET和POST

学过php都知道就一个简单传参,构造payload:?whatflag得到 flag{3121064b1e9e27280f9f709144222429} 下面是POST那题 使用firefox浏览器的插件Hackbar勾选POST传入whatflag flag{828a91acc006990d74b0cb0c2f62b8d8}

【TI毫米波雷达】官方工业雷达包的生命体征检测环境配置及避坑(Vital_Signs、IWR6843AOPEVM)

【TI毫米波雷达】官方工业雷达包的生命体征检测环境配置及避坑(Vital_Signs、IWR6843AOPEVM) 文章目录 生命体征基本介绍IWR6843AOPEVM的配置上位机配置文件避坑上位机start测试距离检测心跳检测呼吸频率检测空环境测试 附录:结构框架雷达基…

js,uniapp,vue,小写数字转化为大写

应用场景: 把1、2、3,转为一、二、三 方法: retBigSrt(num) {const changeNum [零, 一, 二, 三, 四, 五, 六, 七, 八, 九]const unit [, 十, 百]num parseInt(num)const getWan (temp) > {const strArr temp.toString().split().re…

Github项目推荐-ChatGPT-Admin-Web

项目地址 https://github.com/AprilNEA/ChatGPT-Admin-Web 项目简介 通过api接入大模型,并基于此封装了一层用户管理的功能,适合团队内使用。 项目截图

瑞吉外卖实战学习-17、用户地址簿相关功能

用户地址簿相关功能 效果图1、根据规则创建相关文件2、新增收货地址接口3、列表查询页面以及设置默认地址 效果图 1、根据规则创建相关文件 2、新增收货地址接口 获取到传入的数据然后将id添加进去,然后存储到数据库 3、列表查询页面以及设置默认地址 list接口&am…

C顺序表:通讯录

目录 前言 通讯录数据结构 通讯录初始化 查找名字 增加联系人 删除联系人 展示所有联系人 查找联系人 修改信息 销毁通讯录 完整通讯录代码 前言 数据结构中的顺序表如果已经学会了,那么我们就可以基于顺序表来完成一个通讯录了 通讯录其实我们使用前…

1.0-spring入门

文章目录 前言一、版本要求二、第一个spring程序1.引入pom2.代码部分2.1 spring bean2.2 springContext.xml2.3 测试2.4 执行结果 总结 前言 最近想要系统的学习下spring相关的框架,于是乎,来到了B站(真是个好地方),spring会专门开一个专栏出来,记录学习心得,与大家共勉。 Spri…

[每天一道面试题] HTTP,FTP,TFTP的底层实现协议是什么

HTTP、FTP和TFTP等这些协议都是属于互联网协议网络层模型中的应用层协议。它们的底层实现主要依赖于传输层的两种协议—— TCP(传输控制协议) 和 UDP(用户数据报协议)。 HTTP: 超文本传输协议(HTTP)通常在TCP协议的基础上操作。HTTP用于在网络上传输超文本,是万维网…

【Redis 知识储备】读写分离/主从分离架构 -- 分布系统的演进(4)

读写分离/主从分离架构 简介出现原因架构工作原理技术案例架构优缺点 简介 将数据库读写操作分散到不同的节点上, 数据库服务器搭建主从集群, 一主一从, 一主多从都可以, 数据库主机负责写操作, 从机只负责读操作 出现原因 数据库成为瓶颈, 而互联网应用一般读多写少, 数据库…

云仓酒庄旗下雷盛红酒入驻香港星怡SingLa餐厅共绘美食美酒新篇章

近日,云仓酒庄旗下品牌雷盛红酒正式入驻香港餐厅星怡SingLa,这一跨界合作不仅为香港市民和游客带来了全新的味蕾享受,也标志着美食与美酒文化的很好结合,共同绘就了一幅精彩绝伦的美食美酒新篇章。 云仓酒庄一直以来都致力于为消费…

MySQL基础练习题:创建数据库

这部分主要是为了帮助大家回忆回忆MySQL的基本语法,数据库来自于MySQL的官方简化版,题目也是网上非常流行的35题。这些基础习题基本可以涵盖面试中需要现场写SQL的问题。 创建数据库 在开始练习之前,我默认你的电脑上是没有本系列练习题需要…

项目管理与经济决策(项目投资经济决策)

1.现金流量 流出系统的现金称为现金流出(CO) 流入系统的现金称为现金流入(CI) 现金流入-现金流出=净现金流量(NCF) 构成现金流量的基本因素:投资、(付现)成本、 (现金)收入、税金、利润等&a…

ICLR 2024 | 联邦学习后门攻击的模型关键层

ChatGPT狂飙160天,世界已经不是之前的样子。 新建了免费的人工智能中文站https://ai.weoknow.com 新建了收费的人工智能中文站https://ai.hzytsoft.cn/ 更多资源欢迎关注 联邦学习使多个参与方可以在数据隐私得到保护的情况下训练机器学习模型。但是由于服务器无法…

Java初始——IDEA-web的启动

Tomcat 文件夹作用 bin 启动 关闭的脚本文件 conf 配置 lib 依赖的jar包 logs 日志 temp 临时文件 webapps 存放的网站 Maven 1.在javaweb中,需要使用大量的jar包,手动导入 2.Maven 架构管理工具 核心:约定大于配置 必须按照规则 web idea-we…

【ELK】搭建elk日志平台(使用docker-compose),并接入springboot项目

1、环境搭建 前提条件:请自行安装docker以及docker-compose环境 version: 3 services:elasticsearch:image: elasticsearch:7.14.0container_name: elasticsearchports:- "9200:9200"- "9300:9300"environment:# 以单一节点模式启动discovery…