扩散模型(2)--1

news2024/12/25 10:01:30
1.简介

        生成模型通过学习并建模输入数据的分布,从而采集生成新的样木,该模型广泛运用于图片视频生成、文本生成和药物分子生成。扩散模型是一类概率生成模型,扩散模型通过向数据中逐步加入噪声来破坏数据的结构,然后学习一个相对应的逆向过程来进行去噪,从而学习原始数据的分布,扩散模型可以生成与真实样本分布高度致的高质量新样本。
        原始图片在加噪过程中逐渐失去了所有信息,最终变成了无法辨识的白噪声。反过程是从噪声开始,模型逐渐对数据进行去噪,可辨识的信息越来越多,直到所有噪声全部被去掉,并生成了新的图片。展示了去噪过程中最重要的概念--分数函数,即当前数据对数似然的梯度,直观上它指向拥有更大似然(更少噪声)的数据分布。逆向过程中去噪的每一步都需要计算当前数据的分数函数,然后根据分数函数对数据进行去噪。一般的生成模型可以分为两类:一类可以直接对数据分布进行建模,比如自回归模型和能量模型;还有是基于潜在变量的模型,它们先假设了潜在变量的分布,然后通过学习一个随机或者非随机的变换将潜在变量进行转换,使转换后的分布逼近真实数据的分布。第二类的生成模型包括变分自编码器(VariationalAuto-Encoder,VAE)、生成对抗网络(Generative Adversarial Network,GAN)、归一化流(Normalizing Flow)。与变分自编码器、生成对抗网络、归一化流等基于潜在变量的生成模型类似,扩散模型也是对潜在变量进行变换,使变换后的数据分布逼近真实数据的分布。但是变分自编码器不仅需要学习从潜在变量到数据的“生成器”q\theta(xIZ),还需要学习用后验分布q\varphi(z|x)来近似真实后验分布q\theta(zIx)以训练生成器。

        而如何选择后验分布是变分自缘码器的难点,如果选得比较简单,那么很可能没办法近似真实后验分布,从而导致型效果不好;而如果选得比较复杂,那么其计算又会很复杂。虽然生成对抗网络和化流都不涉及计算后验分布,但它们也有各自的缺点。生成对抗网络的训练需要外的判别器,这导致其训练难、不稳定;归一化流则要求潜在变量到数据的映射是可映射,这大大限制了其表达能力,并且不能直接使用SOTA(state-ofthe-art)的神经网络框架。

        而扩散模型则综合了上述模型的优点并且避免了上述模型的缺点,只需要训练生成器即可。损失函数的形式简单且容易训练,不需要如判别器等其他的辅助网络表达能力强。当前对扩散模型的研究大多基于3个主要框架:去噪扩散概率模型、基于分数的生成模型、随机微分方程。

2.去噪扩散概率模型

        去噪扩散概率模型(DDPM),定义了一个马尔可夫链(MarkovChain)(马尔科夫链是一种随机过程,它描述了一个系统在不同状态之间转换的概率模型。在马尔科夫链中,系统的未来状态只依赖于当前状态,而与过去的状态无关。这种性质称为无记忆性马尔科夫性质并缓慢地向数据添加随机噪声,然后学习逆向扩散过程,从噪声中构建所需的数据样本。一个DDPM由两个马尔可夫链组成,一个正向马尔可夫链(以下简称“正向链”)将数据转化为噪声;一个逆向马尔可夫链(以下简称“逆向链”)将噪声转化为数据。正向链通常是预先设计的,其目标是逐步将数据分布转化为简单的先验分布如标准高斯分布。而逆向链的每一步的转移核(转移核通常指的是转移概率分布,它描述了马尔科夫链中从一种状态转移到另一种状态的概率)是由深度神经网络学习得到的,其目标是逆向链转正向链从而生成数据。新数据的生成需要先从先验分布中抽取随机向量,然后将此随机向量输入逆向链并使用祖先采样法(祖先采样法它通过构建一个马尔科夫链来近似目标分布,然后通过这个链进行采样)生成新数据。

        超参数是指在学习过程开始之前设置的参数,而不是通过训练数据直接估计的参数。这些参数通常用于控制学习过程中的行为,比如算法的复杂度、学习率、迭代次数等。

3.代码

生成图片分类的架构

        __init__.py是一个特殊的文件,它允许一个目录被识别为Python的包。这个文件可以是空的,也可以包含初始化包的Python代码,写上后就可以导入包了。

init代码部分

from .DiffusionCondition import *
from .ModelCondition import *
from .TrainCondition import *

DiffusionCondition代码部分

import torch
import torch.nn as nn
#torch.nn模块包含了构建神经网络所需的类和函数
import torch.nn.functional as F
#这行代码导入了PyTorch的torch.nn.functional模块,并给它起了一个别名F
import numpy as np
def extract(v,t,x_shape):
#用于从给定的张量v中提取特定索引t对应的值,并且将结果重新塑形为特定的形状 x_shape
    device=t.device()
#设备确定,获取索引张量t所在的设备,以确保后续操作在同一设备上执行
    out=torch.gather(v,index=t,dim=0).float().to(device)
#这个函数沿着维度0,通常是第一个维度从张量v中根据索引t提取元素,index=t指定了要提取的元素的索引
# .to(device): 确保结果张量在与索引张量 t 相同的设备上
    return out.view([t.shape[0]]+[1]*(len(x_shape)-1))

# t.shape[0]: 获取索引张量t的第一个维度的大小,这通常是批量大小
# [1] * (len(x_shape) - 1): 创建一个由1组成的列表,长度为 x_shape的长度减1,这通常用于保持输出张量与输入张量t在其他维度上的一致性
# .view(): 将张量out重新塑形为指定的形状,这个形状的第一个维度是t的批量大小,其余维度都是1


TrainCondition代码部分

import os
#import os
from typing import Dict
# 用于类型注解,Dict 是 typing 模块的一部分,用于指定字典的键和值的类型
import numpy as np
import torch
import torch.optim as optim
#导入PyTorch的优化器模块,用于在训练过程中更新模型的权重
from tqdm import tqdm
# 从tqdm库导入tqdm类,这是一个进度条库
from torch.utils.data import DataLoader
#用于创建数据加载器,可以批量加载数据
from torchvision import transforms
#供了一系列的图像预处理功能,如归一化、裁剪、旋转等
from torchvision.datasets import CIFAR10
#一个常用的小型图像数据集
from torchvision.utils import save_image
from DiffusionFreeGuidence.DiffusionCondition import GaussianDiffusionSampler,GaussianDiffusionTrainer
#自定义的类,用于实现去噪扩散模型的采样和训练
from DiffusionFreeGuidence.ModelCondition import UNet
#导入UNet类,UNet是一种常见的网络架构,常用于图像分割和生成任务
from Scheduler import GradualWarmupScheduler
#一个自定义的学习率调度器,用于在训练初期逐渐增加学习率,有助于模型更稳定地开始训练
def train(modelConfig:Dict):
    #接受一个类型为字典的参数modelConfig
    device=torch.device(modelConfig["device"])
    #创建了一个CIFAR-10数据集实例
    dataset=CIFAR10(
        # 指定了数据集的存储路径,train表示加载训练集,download=True 表示如果数据集不在本地,则下载它
        root='./CIFAR10',train=True,download=True,
        transform=transforms.Compose([
            transforms.Compose(),
            transforms.ToTensor(),#将PIL图像或NumPy数组转换为torch.Tensor,Tensor是多维数组的数学表示,类似于NumPy的数组
            #但Tensor可以在GPU上使用,从而加速计算。PyTorch模型的输入和输出都是Tensor,但Tensor可以在GPU上使用,从而加速计算。PyTorch模型的输入和输出都是Tensor
            transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)),#将图像的每个通道归一化到[-1, 1]范围内
        ]))
    #创建了一个DataLoader实例,用于批量加载数据集,并且可以并行加载数据以提高效率
    dataloader=DataLoader(#设置了每个批次的样本数,shuffle=true表示在每个epoch开始时打乱数据num_workers=4用于数据加载的子进程数
        #drop_last=True 表示如果数据集大小不能被批次大小整除,则丢弃最后一个不完整的批次
        #pin_memory=True 表示如果使用CUDA,则将数据加载到CUDA固定内存中,这样转移到GPU会更快
        dataset,batch_size=modelConfig["batch_size"],shuffle=True,num_workers=4,drop_last=True,pin_memory=True)
#model
#创建了一个UNet模型实例,UNet是一种常用于图像分割任务的卷积神经网络架构,T=modelConfig["T"] 表示网络中的时间步长或层数,T可能从其他地方读取的所以不是数字
#num_labels=10 表示数据集中类别的数量,ch=modelConfig["channel"] 表示输入通道的数量,ch_mult表示通道数的乘数,用于在网络的不同层中扩展或减少通道数
    net_model=UNet(T=modelConfig["T"],num_labels=10,ch=modelConfig["channel"],ch_mult=modelConfig["channel_mult"],
                   #num_res_blocks表示残差块的数量,表示Dropout率,用于正则化以减少过拟合
                   num_res_blocks=modelConfig["num_res_blocks"],dropout=modelConfig["dropout"]).to(device)
    #这行代码检查modelConfig字典中的"training_load_weight"键是否不为None。如果不为None,则表示需要加载预训练的模型权重
    if modelConfig["training_load_weight"] is not None:
        net_model.load_state_dict(torch.load(os.path.join( #strict=False: 允许模型结构和权重文件中的结构不完全匹配,即权重文件中的参数可以少于模型参数
            modelConfig["save_dir"],modelConfig["training_load_weight"]),map_location=device),strict=False)
        print("Model weight load down.")
    optimizer=torch.optim.AdamW(
        # 使用AdamW优化器,它是Adam优化器的一个变种,增加了权重衰减
        #获取模型的所有参数,lr设置学习率,weight_decay=1e-4: 设置权重衰减率,用于正则化以减少过拟合
        net_model.parameters(),lr=modelConfig["lr"],weight_deacy=1e-4)
    #optim.lr_scheduler.CosineAnnealingLR: 使用余弦退火学习率调度器,它在训练过程中逐渐减小学习率
    cosineScheduler=optim.lr_scheduler.CosineAnnealingLR(#optimizer=optimizer: 指定优化器,并设置调度器的最大周期eta_0,设置学习率的最小值,表示从0开始计数epoch
        optimizer=optimizer,T_max=modelConfig["epoch"],eta_min=0,last_epoch=-1)
    #GradualWarmupScheduler: 使用预热调度器,在训练的初始阶段逐渐增加学习率
    #optimizer=optimizer, multiplier=modelConfig["multiplier"]: 指定优化器,设置预热阶段学习率的乘数
    # warm_epoch=modelConfig["epoch"]//10: 设置预热阶段的周期数,after_scheduler=cosineScheduler: 设置预热调度器之后的调度器
    warmUpScheduler=GradualWarmupScheduler(optimizer=optimizer,multiplier=modelConfig["multiplier"],
                                           warm_epoch=modelConfig["epoch"]//10,after_scheduler=cosineScheduler)
    #创建一个高斯扩散训练器,这是一个用于训练扩散模型的训练器
    trainer=GaussianDiffusionTrainer(
        #modelConfig["beta_1"], modelConfig["beat_T"], modelConfig["T"]: 分别设置扩散过程中的超参数
        #超参数:是机器学习和深度学习中的一个重要概念。它们是在学习过程开始之前设置的参数,不同于模型训练过程中学习的参数
        net_model,modelConfig["beta_1"],modelConfig["beat_T"],modelConfig["T"]).to(device)
# start training
    for e in range(modelConfig["epoch"]):
        #tqdm 是一个快速、可扩展的Python进度条库,使用 tqdm 库来包装数据加载器 dataloader,以便在训练过程中显示进度条
        # dynamic_ncols=True 表示进度条的宽度会根据终端的列数动态调整
        with tqdm(dataloader,dynamic_ncols=True) as tqdmDataLoader:
            for images,labels in tqdmDataLoader:
                b=images.shape[0] #获取当前批次的大小,即图像的数量
                optimizer.zero_grad()#清除(重置)模型参数的梯度,这是在每次权重更新前必须执行的步骤
                x_0=images.to(device)
                labels=labels.to(device)+1#将标签数据移动到指定的设备上,并加1,这可能是为了将标签从0开始的索引转换为1开始的索引
                if np.random.rand() <0.1: #随机决定是否进行某种操作,概率为10%
                    labels=torch.zeros_like(labels).to(device) #如果上述条件满足,将标签设置为0,并移动到指定设备上
                loss=trainer(x_0,labels).sum()/b**2#调用trainer函数计算损失,通常是模型的前向传播和计算损失,损失值被求和后再除以批次大小的平方
                loss.backward()#计算损失函数关于模型参数的梯度
                torch.nn.utils.clip_grad_norm_(
                    net_model.parameters(),modelConfig["grad_clip"]) #使用梯度裁剪来防止梯度爆炸,将梯度的大小限制在 modelConfig["grad_clip"] 指定的阈值内
                optimizer.step()#根据计算出的梯度更新模型的参数
                tqdmDataLoader.set_postfix(order_dict={ #设置进度条的后缀,显示当前的训练轮数、损失、图像形状和学习率
                    "epoch":e,
                    "loss":loss.item(),
                    "img shape:":x_0.shape,
                    "LR":optimizer.state_dict()['param_groups'][0]["lr"]
                })
                #更新预热调度器,这通常在训练的初期阶段逐步增加学习率
        warmUpScheduler.step()
        torch.save(net_model.state_dict(),os.path.join(
            modelConfig["save_dir"],'ckpt_'+str(e)+"_.pt"))#保存模型的参数到指定目录,文件名包含当前的训练轮数 e
def eval(modelConfig:Dict):
    device=torch.device(modelConfig["device"])
    with torch.no_grad():#使用 torch.no_grad() 上下文管理器来停止计算梯度,这通常在模型评估阶段使用,以减少内存消耗和计算时间
        step=int(modelConfig["batch_size"]//10)#计算步长,将批次大小除以10并取整
        labelList=[]#初始化一个空列表,用于存储标签
        k=0#初始化一个变量 k,用于跟踪当前的标签值
        for i in range(1,modelConfig["batch_size"]+1):
            labelList.append(torch.ones(size=[1]).long()*k)#将一个大小为1的全1张量乘以 k,然后添加到 labelList 中
            if i %step==0:
                if k<10-1:#如果 k 小于9,增加 k 的值,用于下一个标签
                    k=k+1
        labels=torch.cat(labelList,dim=0).long().to(device)+1#参数 dim 指定了要沿着哪个维度进行连接,dim=0 表示沿着第一个维度通常是批次维度进行连接
        print("labels:",labels)
        model=UNet(T=modelConfig["T"],num_labels=10,ch=modelConfig["channel"],ch_mult=modelConfig["channel_mult"],
                   num_res_blocks=modelConfig["num_res_blocks"],dropout=modelConfig["dropout"]).to(device)
        ckpt=torch.load(os.path.join(
            modelConfig["save_dir"],modelConfig["test_load_weight"],map_location=device))
        model.load_state_dict(ckpt)#加载保存的模型权重
        print("model load weight done.")
        model.eval()#将模型设置为评估模式
        sampler=GaussianDiffusionSampler(#创建一个高斯扩散采样器,并将其移动到指定设备上
            model,modelConfig["beta_1"],modelConfig["beta_T"],w=modelConfig["w"]).to(device)
        noisyImage=torch.randn(
            size=[modelConfig["batch_size"],3,modelConfig["img_size"],modelConfig["img_size"]],device=device)
        saveNoisy=torch.clamp(noisyImage*0.5+0.5,0,1)#将噪声图像的值限制在一定范围
        sampledImgs=sampler(noisyImage,labels)#使用采样器和噪声图像以及标签生成采样图像
        sampleImages=sampleImages*0.5+0.5
        print(sampleImages)
        save_image(sampleImages,os.path.join(
            modelConfig["sampled_dir"],modelConfig["sampledImgName"]),nrow=modelConfig["nrow"])
#这行代码的作用是将sampleImages张量中的图像保存到由modelConfig字典指定的路径和文件名中,并且每行显示 modelConfig["nrow"]个图像

        
    

ModelCondition代码

import math
#telnetlib用于与Telnet服务器进行交互,PRAGMA_HEARTBEAT通常用于设置心跳机制,以保持连接的活跃状态,这在网络编程中比较常见,但在深度学习代码中不常见,可能是误加入的
from telnetlib import PRAGMA_HEARTBEAT
import torch
#nn模块包含了构建神经网络所需的类和函数,比如层、激活函数、损失函数等
from torch import nn
#init模块包含了用于神经网络参数初始化的函数,比如正态分布初始化、均匀分布初始化
from torch.nn import init
from torch.nn import functional as F
#drop_ratio 是要丢弃的连接的比例
def drop_connect(x,drop_ratio):
    #计算保留的比率,即1减去丢弃比率
    keep_ratio=1.0-drop_ratio
    #创建一个与输入张量x相同数据类型的空张量mask,其形状为[s.shape[0], 1, 1, 1],其中s.shape[0]应该是输入张量x的第一个维度的大小(通常是批量大小)
    mask=torch.empty([s.shape[0],1,1,1],dtype=x.dtype,device=x.device)
    #使用伯努利分布填充mask张量,每个元素有keep_ratio的概率为1,keep是保留的也就是成功的,不成功为0
    mask.bernoulli_(p=keep_ratio)
    #将输入张量x的每个元素除以keep_ratio,这是为了调整未被丢弃的权重,以保持输出的期望值不变
    x.div_(keep_ratio)
    #将调整后的输入张量x与mask相乘,实现权重的丢弃
    x.mul_(mask)
    #返回经过DropConnect处理后的张量x
    return x
class Swish(nn.Module):
    def forward(self,x):
        return x*torch.sigmoid(x)
class TimeEmbedding(nn.Module):
    # T:时间序列的长度
    # d_model:嵌入的维度,必须是偶数
    # dim:最终输出的维度
    def __init__(self,T,d_model,dim):
        assert d_model % 2==0#确保 d_model 是偶数,因为后面的代码中会用到 d_model // 2
        super().__init__()#调用父类的初始化方法
        emb=torch.arange(0,d_model,ste=2)/d_model *math.log(10000)#创建一个从0开始到 d_model 结束,步长为2的序列
        emb=torch.exp(-emb)#计算序列的指数衰减
        pos=torch.arrange(T).float()
        emb=pos[:,None]*emb[None,:]#将位置序列和衰减序列进行外积
        assert list(emb.shape)==[T,d_model//2]#assert list(emb.shape)==[T,d_model//2] 
        emb=torch.stack([torch.sin(emb),torch.cos(emb)],dim=-1)
        assert list(emb.shape)==[T,d_model//2,2]#确保张量的形状正确
        emb=emb.view(T,d_model)#将张量重新塑形为 T x d_model
        self.timembedding=nn.Sequential(
            nn.Embedding.form_pretrained(emb,freeze=False),#使用预训练的嵌入张量 emb 初始化嵌入层,freeze=False 表示参数可以训练
            nn.Linear(d_model,dim),#一个线性层,将输入从 d_model 维映射到 dim 维
            Swish(),#前面定义的 Swish 激活函数
            nn.Linear(dim,dim),#另一个线性层,将输入从 dim 维映射回 dim 维
            )
    def forward(self,t):
        emb=self.timembedding(t)#将时间索引 t 通过时间嵌入层
        return emb
class ConditionalEmbedding(nn.Module):
    def __init__(self,num_labels,d_model,dim):#num_labels:类别标签的数量,d_model:嵌入的维度
        assert d_model % 2 == 0 
        super().__init__()
        self.condEmbedding=nn.Sequential(#创建一个序列模块
            #创建一个嵌入层,num_embeddings 是嵌入的总数,embedding_dim 是每个嵌入的维度,padding_idx=0 表示索引0是填充索引,通常不用于训练
            nn.Embedding(num_embeddings=num_labels+1,embedding_dim=d_model,padding_idx=0),
            nn.Linear(d_model,dim),
            Swish(),
            nn.Linear(dim,dim),#另一个线性层,将输入从 dim 维映射回 dim 维
        )
    def forward(self,t):#定义了类的前向传播方法
        emb=self.condEmbedding(t)# 将类别标签索引 t 通过条件嵌入层
        return emb#返回嵌入的输出
class DownSample(nn.Module):
    def __init__(self,in_ch):#in_ch:输入通道的数量
        super().__init__()
        self.c1=nn.Conv2d(in_ch,in_ch,3,stride=2,padding=1)#3:卷积核的大小为3x3
        self.c2=nn.Conv2d(in_ch,in_ch,5,stride=2,padding=2)
    def forward(self,x,temp,cemb):
        _,_,H,W=x.shape#获取输入特征图x的形状,其中_ 表示忽略批次大小和通道数,只获取高度 H 和宽度 W
        x=self.t(x)
        x=self.c(x)
        return x#返回经过两次卷积操作后的特征图
class AttnBlock(nn.Module):
    def __init__(self,in_ch):
        super().__init__()
        self.group_norm=nn.GroupNorm(32,in_ch)#创建一个组归一化层,组大小为32,通道数为 in_ch
        self.proj_q=nn.Conv2d(in_ch,in_ch,1,stride=1,padding=0)#padding=0:无边缘填充
        self.proj_q=nn.Conv2d(in_ch,in_ch,1,stride=1,padding=0)#用于生成键向量 q
        self.proj_k=nn.Conv2d(in_ch,in_ch,1,stride=1,padding=0)#用于生成值向量 v
        self.proj_v=nn.Conv2d(in_ch,in_ch,1,stride=1,padding=0)
        self.proj=nn.Conv2d(in_ch,in_ch,1,stride=1,padding=0)
    def forward(self,x):
        B,C,H,W,=x.shape#获取输入特征图 x 的形状,其中 B 是批次大小,C 是通道数,H 是高度,W 是宽度
        h=self.group_norm(x)#将输入特征图 x 通过组归一化层
        q=self.proj_q(h)#将归一化后的特征图 h 通过卷积层生成查询向量 q
        k=self.proj_k(h)
        v=self.proj_v(h)
        q=q.permute(0,2,3,1).view(B,H*W,C)#调整 q 的维度并展平,以便于进行批矩阵乘法
        k=k.view(B,C,H*W)#调整 k 的维度
        w=torch.bmm(q,k)*(int(c)**(-0.5))
        assert list(w.shape)==[B,H*W,H*W]#确保权重矩阵 w 的形状正确
        W=F.softmax(w,dim=-1)#对权重矩阵 w 应用 softmax 函数,以获取注意力权重
        v=v.permute(0,2,3,1).view(B,H*W,C)
        h=torch.bmm(w,v)#计算加权的值向量
        assert list(h.shape)==[B,H*W,C]#确保输出 h 的形状正确
        h=h.view(B,H,W,C).permute(0,3,1,2)#将输出 h 调整回原始的批次和通道维度
        h=self.proj(h)#将输出 h 通过最后的卷积层进行投影
        return x+h#返回输入特征图 x 和输出特征图 h 的和,实现残差连接
class ResBlock(nn.Module):#这个类实现了一个残差块
    def __init__(self,in_ch,out_ch,tdim,dropout,attn=True):#out_ch:输出通道的数量,tdim:温度嵌入的维度,是否包含注意力模块,默认为 True
        super().__init__()
        self.block1=nn.Sequential(
            nn.GroupNorm(32,in_ch),#一个组归一化层,组大小为32,通道数为 in_ch
            Swish(),
            nn.Conv2d(in_ch,out_ch,3,stride=1,padding=1),
        )
        self.temp_proj=nn.Sequential(
            Swish(),
            nn.Linear(tdim,out_ch),
        )
        self.cond_proj=nn.Sequential(
            nn.GroupNorm(32,out_ch),
            Swish(),
            nn.Dropout(dropout),
            nn.Conv2d(out_ch,out_ch,3,stride=1,padding=1),
        )
        if in_ch!=out_ch:#如果输入通道和输出通道不同
            self.shortcut=nn.Conv2d(in_ch,out_ch,1,stride=1,padding=0)
        else:
            self.shortcut=nn.Identity()#使用 nn.Identity 作为快捷连接,相当于恒等映射
        if attn:
            self.attn=AttnBlock(out_ch)#创建一个 AttnBlock 注意力模块
        else:
            self.attn=nn.Identity()#使用 nn.Identity 作为注意力模块,相当于恒等映射
    def forward(self,x,temb,labels):
        h=self.block1(x)#将输入特征图 x 通过第一个卷积块
        h+=self.temp_proj(temb)[:,:,None,None]#将温度嵌入temb 通过线性层,并加到h上,[:,:,None,None] 是为了增加两个维度以匹配 h 的维度
        h+=self.cond_proj(labels)[:,:,None,None]#将条件嵌入 labels 通过条件投影块,并加到 h 上
        h=self.attn(h)# h 通过注意力模块
        return h
class UNet(nn.Module):#UNet 是一个典型的卷积神经网络架构,常用于图像分割任务。它由一个编码器(逐步降采样)和一个解码器(逐步上采样)组成,中间通过一个瓶颈层连接
    def __init__(self,T,num_labels,ch,ch_mult,num_res_blocks,dropout):
        super().__init__()
        tdim=ch*4#计算温度嵌入和条件嵌入的维度
        self.time_embedding=TimeEmbedding(T,ch,tdim)#创建一个时间嵌入层
        self.cond_embedding=ConditionalEmbedding(num_labels,ch,tdim)#创建一个条件嵌入层
        self.head=nn.Conv2d(3,ch,kernel_size=3,stride=1,padding=1)#创建一个卷积层作为网络的头部,将输入的3通道图像转换为 ch 通道的特征图
        self.downblocks=nn.ModuleList()#创建一个模块列表,用于存储下采样阶段的残差块
        chs=[ch]#初始化一个列表,用于存储当前的通道数
        for i,mult in enumerate(ch_mult):#开始一个循环,遍历 ch_mult 列表,i 是索引,mult 是当前的倍增因子
            out_ch=ch*mult#开始一个循环,根据 num_res_blocks 参数,添加指定数量的残差块
            for _ in range(num_res_blocks):#开始一个循环,根据 num_res_blocks 参数,添加指定数量的残差块
                self.downblocks.append(ResBlock(in_ch=now_ch,out_ch=out_ch,tdim=tdim,dropout=dropout))#创建一个 ResBlock 实例并添加到下采样模块列表 self.downblocks 中
                now_ch=out_ch#更新当前通道数 now_ch 为新的输出通道数 out_ch
                chs.append(now_ch)#将更新后的通道数 now_ch 添加到列表 chs 中
            if i!=len(ch_mult)-1:#如果当前不是 ch_mult 列表中的最后一个元素,意味着还需要进一步下采样
                self.downblocks.append(DownSample(now_ch))#创建一个 DownSample 实例并添加到下采样模块列表 self.downblocks 中
                chs.append(now_ch)#再次将当前通道数 now_ch 添加到列表 chs 中
        self.middleblocks=nn.ModuleList([#创建一个模块列表 self.middleblocks,包含两个 ResBlock 实例作为瓶颈层,一个带有注意力机制,另一个不带
            ResBlock(now_ch,now_ch,tdim,dropout,attn=True),
            ResBlock(now_ch,now_ch,tdim,dropout=False),
        ])
        self.middleblocks=nn.ModuleList([
            ResBlock(now_ch,now_ch,tdim,dropout,attn=True),
            ResBlock(now_ch,now_ch,tdim,dropout,attn=False),
        ])
        self.upblocks=nn.ModuleList()#创建一个空的模块列表 self.upblocks,用于存储上采样阶段的残差块
        for i ,mult in reversed(list(enumerate(ch_mult))):#开始一个循环,反向遍历 ch_mult 列表,进行上采样
            out_ch=ch*mult#计算当前上采样阶段的输出通道数 out_ch
            for _ in range(num_res_blocks+1):#开始一个循环,添加 num_res_blocks + 1 个残差块,因为每个上采样阶段开始时会多一个残差块
                self.upblocks.append(ResBlock(in_ch=chs.pop()+now_ch,out_ch=out_ch,tdim=tdim,dropout=dropout,attn=False))#chs.pop() 用于获取并移除之前下采样阶段的通道数
                now_ch=out_ch#更新当前通道数 now_ch 为新的输出通道数 out_ch
            if i!=0:#如果当前不是 ch_mult 列表中的第一个元素,意味着还需要进一步上采样
                self.upblocks.append(UpSample(now_ch))
        assert len(chs)==0#断言 chs 列表为空,确保所有的通道数都被正确地使用
        self.tail=nn.Sequential(
            nn.GroupNorm(32,now_ch),
            SWish(),
            #创建一个序列模块 self.tail 作为网络的尾部,包含一个组归一化层、一个 Swish 激活函数和一个卷积层,将最终的特征图转换回3通道的输出
            nn.Conv2d(now_ch,3,3,stride=1,padding=1)
        )
def forward(self,x,t,labels):
    temb=self.time_embedding(t)#将时间信息 t 传递给时间嵌入层 self.time_embedding,生成时间嵌入 temb
    cemb=self.cond_embedding(labels)#将条件信息 labels 传递给条件嵌入层 self.cond_embedding,生成条件嵌入 cemb
    h=self.head(x)#将输入特征图 x 传递给网络头部(self.head),通常是卷积层,得到初步处理的特征图 h
    hs=[h]#初始化一个列表 hs,用于存储下采样过程中的特征图 h
    for layer in self.downblocks:#遍历下采样阶段的所有层
        h=layer(h,temb,cemb)#对于每一层,将当前特征图 h、时间嵌入 temb 和条件嵌入 cemb 传递进去,更新特征图 h
        hs.append(h)#将每一层的输出特征图 h 添加到列表 hs 中,这些特征图将在上采样阶段用于跳跃连接
    for layer in self.middleblocks:#遍历瓶颈层的所有残差块
        h=layer(h,temb,cemb)
    for layer in self.upblocks:
        if isinstance(layer,ResBlock):#如果当前层是 ResBlock 类型,进行跳跃连接
            h=torch.cat([h,hs.pop()],dim=1)#当前特征图 h 和 hs 列表中弹出的特征图(即对应下采样阶段的特征图)在通道维度(dim=1)上进行拼接
        h=layer(h,temb,cemb)
    h=self.tail(h)
    assert len(hs)==0
    return h
if __name__==' __ main __':
    batch_size=8#定义一个变量 batch_size,设置批处理大小为8
    model=UNet(
        T=1000,num_labels=10,ch=128,ch_mult=[1,2,2,2],
        
    # T=1000:时间序列长度。
    # num_labels=10:类别标签的数量。
    # ch=128:基础通道数。
    # ch_mult=[1,2,2,2]:每个下采样阶段的通道数倍增因子列表。
    # num_res_blocks=2:每个下采样阶段的残差块数量。
    # dropout=0.1:Dropout层的丢弃率。

        num_res_blocks=2,dropout=0.1)
    x=torch.randn(batch_size,3,32,32)#生成一个随机的输入张量 x,其形状为 [batch_size, 3, 32, 32],表示有8个样本,每个样本有3个通道,每个通道的大小是32x32
    t=torch.randint(1000,size=[batch_size])#生成一个随机的时间信息张量 t,其形状为 [batch_size],每个元素是从0到999的整数
    labels=torch.randint(10,size=[batch_size])#生成一个随机的条件信息张量 labels,其形状为 [batch_size],每个元素是从0到9的整数,代表类别标签。
    y=model(x,t,labels)#将输入张量 x、时间信息张量 t 和条件信息张量 labels 传递给 UNet 模型进行前向传播,得到输出 y
    print(y.shape)#打印输出张量 y 的形状

    

        

scheduler代码部分

from torch.optim.lr_scheduler import _LRScheduler#它定义了学习率调度器的接口。学习率调度器用于在训练过程中根据一定的策略调整学习率
class GradualWarmupScheduler(_LRScheduler):
    #optimizer:用于模型训练的优化器,multiplier:学习率放大的倍数,用于预热结束后的学习率调整
    #warm_epoch:预热期的周期数(epoch),在这期间学习率将逐渐增加,after_scheduler:可选参数,预热期结束后要使用的另一个学习率调度器
    def __init__(self,optimizer,multiplier,warm_epoch,after_scheduler=None):
        self.multiplier=multiplier#初始化成员变量 self.multiplier 以存储学习率放大倍数
        self.total_epoch=warm_epoch
        self.after_scheduler=after_scheduler
        self.finished=False#初始化成员变量 self.finished,用于标记预热是否完成
        self.last_epoch=None#初始化成员变量 self.last_epoch 以存储当前的 epoch 数
        self.base_lrs=None
        super().__init__(optimizer)
        def get_lr(self):
            if self.last_epoch>self.total_epoch:#如果当前 epoch 超过预热期
                if self.after_schduler:#如果设置了预热期后的学习率调度器
                    if not self.finished:
                        self.after_scheduler.base_lrs=[base_lr*self.multiplier for base_lr in self.base_lrs]#更新预热后学习率调度器的基础学习率
                        self.finished=True
                    return self.after_scheduler.get_lr()
                return [base_lr*self.multiplier for base_lr in self.base_lrs]
            return [base_lr*((self.multiolier-1.)*self.last_epoch/self.total_epoch+1.) for base_lr in self.base_lrs]
        def step(self,epoch=None,metrics=None):
            if self.finished and self.after_schduler:
                if epoch is None:
                    self.after_scheduller.step(None)
                else:
                    self.after_scheduler.step(epoch-self.total_epoch)
            else:
                return super(GradualWarmupScheduler,self).step(epoch)

Maincondition代码部分

from DiffusionFreeGuidence.TrainCondition import train, eval
def main(model_config=None):
    modelConfig={
        "state":"train",
        "epoch":70,
        "batch_size":80,
        "T":500,
        "channel":128,
        "channel_mult":[1,2,2,2],
        "num_res_block":2,
        "dropout":0.15,
        "lr":1e-4,
        "multipllier":2.5,
        "beta_1":1e-4,
        "beta_T":0.028,
        "img_size":32,
        "grad_clip":1.,
        "device":"cpu:0",
        "w":1.8,
        "save_dir": "./CheckpointsCondition/",
        "training_load_weight":None,
        "test_load_weight": "ckpt_63_.pt",
        "sampled_dir":"./Sampled_images/",
        "sampledNoisyImgName":"SampledImgs/",
        "sampledImagName":"SampledGuidenceImgs.png",
        "nrow":8
    }
    if model_config is not None:
        modelConfig=model_config
    if modelConfig["state"]=="train":
        train(modelConfig)
    else:
        eval(modelConfig)
if __name__ == "__main__":
    main()


       

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2178148.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

在Windows系统上安装的 Boost C++ 库

步骤一 https://www.boost.org/users/history/version_1_86_0.html 下载Boost库文件: 步骤二 安装: https://www.boost.org/doc/libs/1_52_0/doc/html/bbv2/installation.html 点击运行.\bootstrap.bat脚本在当前目录的powershell中执行:./b2 install --prefixPREFIX 然后…

优选拼团平台架构解析与关键代码逻辑概述

一、系统架构设计 唐古拉优选拼团平台采用多层架构设计&#xff0c;主要包括前端展示层、业务逻辑层、数据访问层及数据存储层。 前端展示层&#xff1a;负责用户界面的展示和交互&#xff0c;包括商品列表、拼团详情、订单管理等页面。前端采用现代前端框架&#xff08;如Vue…

第十四周学习周报

目录 摘要Abstract1. LSTM的代码实现2. 序列到序列模型3. 梯度与方向导数总结 摘要 在上周的学习基础之上&#xff0c;本周学习的内容有LSTM的代码实现&#xff0c;通过对代码的学习进一步加深了对LSTM的理解。为了切入到transformer的学习&#xff0c;本文通过对一些应用例子…

JUC高并发编程4:集合的线程安全

1 内容概要 2 ArrayList集合线程不安全 2.1 ArrayList集合操作Demo 代码演示 /*** list集合线程不安全*/ public class ThreadDemo4 {public static void main(String[] args) {// 创建ArrayList集合List<String> list new ArrayList<>();for (int i 0; i <…

铺铜修改后自动重铺

很多初学者对于敷铜操作感到比较麻烦&#xff1a;为什么每次打过孔&#xff0c;修改走线后都需要手动右击-重新修改敷铜。如何提升layout的效率&#xff1f; 版本&#xff1a;Altium Designer 21.9.2 首先&#xff0c;点击面板右边的小齿轮&#xff0c;进入设置 接下来&#…

9.29学习

1.线上问题rebalance 因集群架构变动导致的消费组内重平衡&#xff0c;如果kafka集内节点较多&#xff0c;比如数百个&#xff0c;那重平衡可能会耗时导致数分钟到数小时&#xff0c;此时kafka基本处于不可用状态&#xff0c;对kafka的TPS影响极大 产生的原因 ①组成员数量发…

【C++并发入门】摄像头帧率计算和多线程相机读取(上):并发基础概念和代码实现

前言 高帧率摄像头往往应用在很多opencv项目中&#xff0c;今天就来通过简单计算摄像头帧率&#xff0c;抛出一个单线程读取摄像头会遇到的问题&#xff0c;同时提出一种解决方案&#xff0c;使用多线程对摄像头进行读取。同时本文介绍了线程入门的基础知识&#xff0c;讲解了…

2-107 基于matlab的hsv空间双边滤波去雾图像增强算法

基于matlab的hsv空间双边滤波去雾图像增强算法&#xff0c;原始图像经过光照增强后&#xff0c;将RGB转成hsv&#xff0c;进行图像增强处理&#xff0c;使图像更加清晰。程序已调通&#xff0c;可直接运行。 下载源程序请点链接&#xff1a; 2-107 基于matlab的hsv空间双边滤…

“找不到emp.dll,无法继续执行代码”需要怎么解决呢?分享6个解决方法

在日常使用电脑玩游戏的过程中&#xff0c;我们可能会遇到一些错误提示&#xff0c;其中最常见的就是“emp.dll丢失”。那么&#xff0c;emp.dll到底是什么&#xff1f;它为什么会丢失&#xff1f;丢失后会对我们的电脑产生什么影响&#xff1f;本文将为您详细解析emp.dll的概念…

超详细的华为ICT大赛报名流程

1、访问华为人才在线官网&#xff0c;点击右上角“登录/注册“&#xff0c;登录华为账号。 报名链接&#xff1a; https://e.huawei.com/cn/talent/cert/#/careerCert?navTypeauthNavKey ▲如已有华为Uniportal账号&#xff0c;完成实名认证后方可报名大赛。 ▲如没有华为…

【有啥问啥】具身智能(Embodied AI):人工智能的新前沿

具身智能&#xff08;Embodied AI&#xff09;&#xff1a;人工智能的新前沿 引言 在人工智能&#xff08;AI&#xff09;的进程中&#xff0c;具身智能&#xff08;Embodied AI&#xff09;正逐渐成为研究与应用的焦点。具身智能不仅关注于机器的计算能力&#xff0c;更强调…

需求5:增加一个按钮

在之前的几个需求中&#xff0c;我们逐步从修改字段到新增字段&#xff0c;按部就班地完成了相关工作。通过最近的文章&#xff0c;不难看出我目前正在处理前端的“未完成”和“已完成”按钮。借此机会&#xff0c;我决定趁热打铁&#xff0c;重新梳理一下之前关于按钮实现的需…

4、MapReduce编程实践

目录 1、创建文件2、启动HDFS3、启动eclipse 创建项目并导入jar包file->new->java project导入jar包finish 4、编写Java应用程序5、编译打包应用程序&#xff08;1&#xff09;查看直接运行结果&#xff08;2&#xff09;打包程序&#xff08;3&#xff09;查看 JAR 包是…

软硬协同方案破解IT瓶颈,龙蜥衍生版KOS助力内蒙古大学成功迁移10+业务软件 | 龙蜥案例

2024 云栖大会上&#xff0c;龙蜥社区发布了《龙蜥操作系统生态用户实践精选 V2》&#xff0c;为面临 CentOS 迁移的广大用户提供成熟实践样板。截至目前&#xff0c;阿里云、浪潮信息、中兴通讯 | 新支点、移动、联通、龙芯、统信软件等超 12 家厂商基于龙蜥操作系统发布商业衍…

【在Linux世界中追寻伟大的One Piece】命名管道

目录 1 -> 命名管道 1.1 -> 创建一个命名管道 1.2 -> 匿名管道与命名管道的区别 1.3 -> 命名管道的打开规则 1.4 -> 例子 1 -> 命名管道 管道应用的一个限制就是只能在具有共同祖先(具有亲缘关系)的进程间通信。如果我们想在不相关的进程之间交换数据&…

串行化执行、并行化执行

文章目录 1、串行化执行2、并行化测试&#xff08;多线程环境&#xff09;3、任务的执行是异步的&#xff0c;但主程序的继续执行是同步的 可以将多个任务编排为并行和串行化执行。 也可以处理编排的多个任务的异常&#xff0c;也可以返回兜底数据。 1、串行化执行 顺序执行、…

C++类和对象(下) 初始化列表 、static成员、友元、内部类等等

1.再探构造函数 之前使用构造函数时都是在函数体内初始化成员变量&#xff0c;还有一种构造函数的用法&#xff0c;叫做初始化列表&#xff1b;那么怎么使用呢&#xff1f; 使用方法用冒号开始(" : ")要写多个就用逗号(" , ")隔开数据成队列每个成员变量后…

DC00023基于jsp+MySQL新生报到管理系统

1、项目功能演示 DC00023基于jsp新生报到管理系统java webMySQL新生管理系统 2、项目功能描述 基于jspMySQL新生报到管理系统项目分为学生、辅导员、财务处和系统管理员四个角色。 2.1 学生功能 1、系统登录 2、校园新闻、报到流程、学校简介、在线留言、校园风光、入校须知…

解决Qt每次修改代码后首次运行崩溃,后几次不崩溃问题

在使用unique_ptr声明成员变量后&#xff0c;我习惯性地在初始化构造列表中进行如下构造&#xff1a; 注意看&#xff0c;我将m_menuBtnGroup的父类指定为ui->center_menu_widget&#xff0c;这便是导致崩溃的根本原因&#xff0c;解决办法便是先用this初始化&#xff0c;后…

pdf页面尺寸裁减

1、编辑pdf 2、点击裁减页面&#xff0c;并在空白区域双击裁减 3、输入裁减数据&#xff1a;