深入浅出 diffusion(3):pytorch 实现 diffusion 中的 U-Net

news2025/1/19 3:07:21

导入python包

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

 silu激活函数

class SiLU(nn.Module):  
    # SiLU激活函数
    @staticmethod
    def forward(x):
        return x * torch.sigmoid(x)

归一化设置

def get_norm(norm, num_channels, num_groups):
    if norm == "in":
        return nn.InstanceNorm2d(num_channels, affine=True)
    elif norm == "bn":
        return nn.BatchNorm2d(num_channels)
    elif norm == "gn":
        return nn.GroupNorm(num_groups, num_channels)
    elif norm is None:
        return nn.Identity()
    else:
        raise ValueError("unknown normalization type")

 计算时间步长的位置嵌入,一半为sin,一半为cos

class PositionalEmbedding(nn.Module):
    def __init__(self, dim, scale=1.0):
        super().__init__()
        assert dim % 2 == 0
        self.dim = dim
        self.scale = scale

    def forward(self, x):
        device      = x.device
        half_dim    = self.dim // 2
        emb = math.log(10000) / half_dim
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        # x * self.scale和emb外积
        emb = torch.outer(x * self.scale, emb)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb

 上下采样层设置

class Downsample(nn.Module):
    def __init__(self, in_channels):
        super().__init__()

        self.downsample = nn.Conv2d(in_channels, in_channels, 3, stride=2, padding=1)
    
    def forward(self, x, time_emb, y):
        if x.shape[2] % 2 == 1:
            raise ValueError("downsampling tensor height should be even")
        if x.shape[3] % 2 == 1:
            raise ValueError("downsampling tensor width should be even")

        return self.downsample(x)

class Upsample(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.upsample = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(in_channels, in_channels, 3, padding=1),
        )
        
    def forward(self, x, time_emb, y):
        return self.upsample(x)

 使用Self-Attention注意力机制,做一个全局的Self-Attention

class AttentionBlock(nn.Module):
    def __init__(self, in_channels, norm="gn", num_groups=32):
        super().__init__()
        
        self.in_channels = in_channels
        self.norm = get_norm(norm, in_channels, num_groups)
        self.to_qkv = nn.Conv2d(in_channels, in_channels * 3, 1)
        self.to_out = nn.Conv2d(in_channels, in_channels, 1)

    def forward(self, x):
        b, c, h, w  = x.shape
        q, k, v     = torch.split(self.to_qkv(self.norm(x)), self.in_channels, dim=1)

        q = q.permute(0, 2, 3, 1).view(b, h * w, c)
        k = k.view(b, c, h * w)
        v = v.permute(0, 2, 3, 1).view(b, h * w, c)

        dot_products = torch.bmm(q, k) * (c ** (-0.5))
        assert dot_products.shape == (b, h * w, h * w)

        attention   = torch.softmax(dot_products, dim=-1)
        out         = torch.bmm(attention, v)
        assert out.shape == (b, h * w, c)
        out         = out.view(b, h, w, c).permute(0, 3, 1, 2)

        return self.to_out(out) + x

 用于特征提取的残差结构

class ResidualBlock(nn.Module):
    def __init__(
        self, in_channels, out_channels, dropout, time_emb_dim=None, num_classes=None, activation=F.relu,
        norm="gn", num_groups=32, use_attention=False,
    ):
        super().__init__()

        self.activation = activation

        self.norm_1 = get_norm(norm, in_channels, num_groups)
        self.conv_1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)

        self.norm_2 = get_norm(norm, out_channels, num_groups)
        self.conv_2 = nn.Sequential(
            nn.Dropout(p=dropout), 
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
        )

        self.time_bias  = nn.Linear(time_emb_dim, out_channels) if time_emb_dim is not None else None
        self.class_bias = nn.Embedding(num_classes, out_channels) if num_classes is not None else None

        self.residual_connection    = nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
        self.attention              = nn.Identity() if not use_attention else AttentionBlock(out_channels, norm, num_groups)
    
    def forward(self, x, time_emb=None, y=None):
        out = self.activation(self.norm_1(x))
        # 第一个卷积
        out = self.conv_1(out)
        
        # 对时间time_emb做一个全连接,施加在通道上
        if self.time_bias is not None:
            if time_emb is None:
                raise ValueError("time conditioning was specified but time_emb is not passed")
            out += self.time_bias(self.activation(time_emb))[:, :, None, None]

        # 对种类y_emb做一个全连接,施加在通道上
        if self.class_bias is not None:
            if y is None:
                raise ValueError("class conditioning was specified but y is not passed")

            out += self.class_bias(y)[:, :, None, None]

        out = self.activation(self.norm_2(out))
        # 第二个卷积+残差边
        out = self.conv_2(out) + self.residual_connection(x)
        # 最后做个Attention
        out = self.attention(out)
        return out

 U-Net模型设计

class UNet(nn.Module):
    def __init__(
        self, img_channels, base_channels=128, channel_mults=(1, 2, 2, 2),
        num_res_blocks=2, time_emb_dim=128 * 4, time_emb_scale=1.0, num_classes=None, activation=F.silu,
        dropout=0.1, attention_resolutions=(1,), norm="gn", num_groups=32, initial_pad=0,
    ):
        super().__init__()
        # 使用到的激活函数,一般为SILU
        self.activation = activation
        # 是否对输入进行padding
        self.initial_pad = initial_pad
        # 需要去区分的类别数
        self.num_classes = num_classes
        
        # 对时间轴输入的全连接层
        self.time_mlp = nn.Sequential(
            PositionalEmbedding(base_channels, time_emb_scale),
            nn.Linear(base_channels, time_emb_dim),
            nn.SiLU(),
            nn.Linear(time_emb_dim, time_emb_dim),
        ) if time_emb_dim is not None else None
    
        # 对输入图片的第一个卷积
        self.init_conv  = nn.Conv2d(img_channels, base_channels, 3, padding=1)

        # self.downs用于存储下采样用到的层,首先利用ResidualBlock提取特征
        # 然后利用Downsample降低特征图的高宽
        self.downs      = nn.ModuleList()
        self.ups        = nn.ModuleList()
        
        # channels指的是每一个模块处理后的通道数
        # now_channels是一个中间变量,代表中间的通道数
        channels        = [base_channels]
        now_channels    = base_channels
        for i, mult in enumerate(channel_mults):
            out_channels = base_channels * mult
            for _ in range(num_res_blocks):
                self.downs.append(
                    ResidualBlock(
                        now_channels, out_channels, dropout,
                        time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation,
                        norm=norm, num_groups=num_groups, use_attention=i in attention_resolutions,
                    )
                )
                now_channels = out_channels
                channels.append(now_channels)
            
            if i != len(channel_mults) - 1:
                self.downs.append(Downsample(now_channels))
                channels.append(now_channels)

        # 可以看作是特征整合,中间的一个特征提取模块
        self.mid = nn.ModuleList(
            [
                ResidualBlock(
                    now_channels, now_channels, dropout,
                    time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation,
                    norm=norm, num_groups=num_groups, use_attention=True,
                ),
                ResidualBlock(
                    now_channels, now_channels, dropout,
                    time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation, 
                    norm=norm, num_groups=num_groups, use_attention=False,
                ),
            ]
        )

        # 进行上采样,进行特征融合
        for i, mult in reversed(list(enumerate(channel_mults))):
            out_channels = base_channels * mult

            for _ in range(num_res_blocks + 1):
                self.ups.append(ResidualBlock(
                    channels.pop() + now_channels, out_channels, dropout, 
                    time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation, 
                    norm=norm, num_groups=num_groups, use_attention=i in attention_resolutions,
                ))
                now_channels = out_channels
            
            if i != 0:
                self.ups.append(Upsample(now_channels))
        
        assert len(channels) == 0
        
        self.out_norm = get_norm(norm, base_channels, num_groups)
        self.out_conv = nn.Conv2d(base_channels, img_channels, 3, padding=1)
    
    def forward(self, x, time=None, y=None):
        # 是否对输入进行padding
        ip = self.initial_pad
        if ip != 0:
            x = F.pad(x, (ip,) * 4)

        # 对时间轴输入的全连接层
        if self.time_mlp is not None:
            if time is None:
                raise ValueError("time conditioning was specified but tim is not passed")
            time_emb = self.time_mlp(time)
        else:
            time_emb = None
        
        if self.num_classes is not None and y is None:
            raise ValueError("class conditioning was specified but y is not passed")
        
        # 对输入图片的第一个卷积
        x = self.init_conv(x)

        # skips用于存放下采样的中间层
        skips = [x]
        for layer in self.downs:
            x = layer(x, time_emb, y)
            skips.append(x)
        
        # 特征整合与提取
        for layer in self.mid:
            x = layer(x, time_emb, y)
        
        # 上采样并进行特征融合
        for layer in self.ups:
            if isinstance(layer, ResidualBlock):
                x = torch.cat([x, skips.pop()], dim=1)
            x = layer(x, time_emb, y)

        # 上采样并进行特征融合
        x = self.activation(self.out_norm(x))
        x = self.out_conv(x)
        
        if self.initial_pad != 0:
            return x[:, :, ip:-ip, ip:-ip]
        else:
            return x

参考链接:GitCode - 开发者的代码家园icon-default.png?t=N7T8https://gitcode.com/bubbliiiing/ddpm-pytorch/tree/master?utm_source=csdn_github_accelerator&isLogin=1

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1413324.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

php项目下微信小程序对接实战问题与解决方案

一.实战问题与方案总结 1.SQL查询条件是一组数,传参却是一个字符串导致报错,如下 SQLSTATE[HY093]: Invalid parameter number (SQL: select count(*) as aggregate from car_video where province_id in (1492) and city_id in (1493) and county_id …

Nginx编译安装以及负载均衡配置(Ubuntu 22.04)

目录 Nginx编译安装以及负载均衡配置 Ubuntu 22.04.1 LTS 编译安装 nginx-1.22.1 1.安装依赖包 2. 下载nginx 3. 编译安装 报错解决 解决问题2 4.安装 5启动Nginx: 负载均衡 负载均衡算法 轮询 加权负载均衡 ip_hash算法 算法进行配置演示 加权负载均衡 轮询 IP 哈希…

ES文档索引、查询、分片、文档评分和分析器技术原理

技术原理 索引文档 索引文档分为单个文档和多个文档。 单个文档 新建单个文档所需要的步骤顺序: 客户端向 Node 1 发送新建、索引或者删除请求。节点使用文档的 _id 确定文档属于分片 0 。请求会被转发到 Node 3,因为分片 0 的主分片目前被分配在 …

[MQ]常用的mq产品图形管理web界面或客户端

一、MQ介绍 1.1 定义 MQ全称为Message Queue,消息队列是应用程序和应用程序之间的通信方法。 如果非要用一个定义来概括只能是抽象出来一些概念,概括为跨服务之间传递信息的软件。 1.2 MQ产品 较为成熟的MQ产品:IBMMQ(IBM We…

《动手学深度学习(PyTorch版)》笔记4.1

注:书中对代码的讲解并不详细,本文对很多细节做了详细注释。另外,书上的源代码是在Jupyter Notebook上运行的,较为分散,本文将代码集中起来,并加以完善,全部用vscode在python 3.9.18下测试通过。…

万界星空科技可视化数据大屏的作用

随着科技的不断发展和进步,当前各种数据化的设备也是如同雨后春笋般冒了出来,并且其可以说是给我们带来了极大的便利的。在这其中,数据大屏就是非常具有代表性的一个例子。 数据大屏的主要作用包括: 数据分析:数据大屏…

k8s 进阶实战笔记 | Pod 创建过程详解

Pod 创建过程详解 ​ 初始状态0 controller-manager、scheduler、kubelet组件通过 list-watch 机制与 api-server 通信并检查资源变化 第一步 用户通过 CLI 或者 WEB 端等方式向 api-server 发送创建资源的请求(比如:我要创建一个replicaset资源&…

【七、centos要停止维护了,我选择Almalinux】

搜索镜像 https://developer.aliyun.com/mirror/?serviceTypemirror&tag%E7%B3%BB%E7%BB%9F&keywordalmalinux dvd是有界面操作的,minimal是最小化只有命里行 镜像下载地址 安装和centos基本一样的,操作命令也是一样的,有需要我…

Redis创建集群

主要内容 搭建redis集群 能力目标 搭建redis集群 一 应用场景 为什么需要redis集群? 当主备复制场景,无法满足主机的单点故障时,需要引入集群配置。 一般数据库要处理的读请求远大于写请求 ,针对这种情况,我们优…

【C/C++】C/C++编程——C++ 关键字和数据类型简介

C 关键字和数据类型简介 大家好,我是 shopeeai,也可以叫我虾皮,中科大菜鸟研究生。昨天已经成功运行了第一个C程序,今天来学习一下C 关键字和数据类型。C 中的关键字是由 C 标准预先定义的。它们被保留作为语言的一部分&#xff…

python222网站实战(SpringBoot+SpringSecurity+MybatisPlus+thymeleaf+layui)-帖子管理实现

锋哥原创的SpringbootLayui python222网站实战: python222网站实战课程视频教程(SpringBootPython爬虫实战) ( 火爆连载更新中... )_哔哩哔哩_bilibilipython222网站实战课程视频教程(SpringBootPython爬虫实战) ( 火…

Golang中make与new有何区别

📕作者简介: 过去日记,致力于Java、GoLang,Rust等多种编程语言,热爱技术,喜欢游戏的博主。 📗本文收录于go进阶系列,大家有兴趣的可以看一看 📘相关专栏Rust初阶教程、go语言基础系…

【新课上架】安装部署系列Ⅲ—Oracle 19c Data Guard部署之两节点RAC部署实战

01 课程介绍 Oracle Real Application Clusters (RAC) 是一种跨多个节点分布数据库的企业级解决方案。它使组织能够通过实现容错和负载平衡来提高可用性和可扩展性,同时提高性能。本课程基于当前主流版本Oracle 19cOEL7.9解析如何搭建2节点RAC对1节点单机的DATA GU…

滴滴基于 Ray 的 XGBoost 大规模分布式训练实践

背景介绍 作为机器学习模型的核心代表,XGBoost 在滴滴众多策略算法业务场景中发挥着至关重要的作用。因此,保障并持续提升 XGBoost 模型的离线训练及在线推理稳定性一直是机器学习平台的重点工作。同时,面对多样化的业务场景定制需求和数据规…

学习gin框架知识的注意点

这几天重新学习了一遍gin框架:收获颇多 Gin框架的初始化 有些项目中 初始化gin框架写的是: r : gin.New() r.Use(logger.GinLogger(), logger.GinRecovery(true)) 而不是r : gin.Default() 为什么呢? 点击进入Default源码发现其实他也是…

大数据就业方向-(工作)ETL开发

上一篇文章: 大数据 - 大数据入门第一篇 | 关于大数据你了解多少?-CSDN博客 目录 🐶1.ETL概念 🐶2. ETL的用处 🐶3.ETL实现方式 🐶4. ETL体系结构 🐶5. 什么是ETL技术? &…

Linux——搭建FTP服务器

1、FTP简介 FTP(File Transfer Protocol) :是一种处于应用层的用于文件传输的协议。FTP客户端和FTP服务器之间的通信使用TCP/IP协议族。它规定了客户端和服务器之间的通信格式和命令集,包括用户认证、文件传输、文件名和目录信息等,允许用户…

掌握可视化大屏:提升数据分析和决策能力的关键(下)

🤍 前端开发工程师、技术日更博主、已过CET6 🍨 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 🕠 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 🍚 蓝桥云课签约作者、上架课程《Vue.js 和 E…

Java - OpenSSL与国密OpenSSL

文章目录 一、定义 OpenSSL:OpenSSL是一个开放源代码的SSL/TLS协议实现,也是一个功能丰富的加密库,提供了各种主要的加密算法、常用的密钥和证书封装管理功能以及SSL协议。它被广泛应用于Web服务器、电子邮件服务器、VPN等网络应用中&#x…

dvwa靶场文件上传high

dvwa upload high 第一次尝试(查看是否是前端验证)第二次尝试我的上传思路最后发现是图片码上传修改配置文件尝试蚁🗡连接菜刀连接 第一次尝试(查看是否是前端验证) 因为我是初学者,所以无法从代码审计角度…