diffusers中DDPMScheduler/AutoencoderKL/UNet2DConditionModel/CLIPTextModel代码详解

news2024/11/24 19:34:57

扩散模型的训练时比较简单的

上图可见,unet是epsθ是unet。noise和预测出来的noise做个mse loss。

训练的常规过程:

latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist_sample()
latents = latents*vae.config.scaling_factor
noise = torch.randn_like(latents)
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
            
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
 
target = noise
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

具体分析:

diffusers/models/autoencoder_kl.py

AutoencoderKL.encode->
h = self.encoder(x)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
AutoencoderKLOutput(posterior)

diffusers/schedulers/scheduing_ddpm.py

add_noise(original_samples,noise,timesteps)->
noisy_samples = sqrt_alpha_prod*original_samples+sqrt_one_inus_alpha_prod*noise

transformers/models/clip/modeling_clip.py

CLIPTextModel.forward->
self.text_model()->

hidden_states = self.embedding(input_ids,position_ids)->
causal_attention_mask = self._build_causal_attention_mask(bsz,seq_len,hidden_states)
encoder_outputs = self.encoder(hidden_states,attention_mask,causal_attention_mask,output_attention,output_hidden_states)
last_hidden_state = encoder_outputs[0]
last_hidden_state = self.final_layer_norm(last_hidden_state)
pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), input_ids.argmax(dim=-1)]

diffusers/models/unet_2d_condition.py

{
  "_class_name": "UNet2DConditionModel",
  "_diffusers_version": "0.19.3",
  "act_fn": "silu",
  "addition_embed_type": null,
  "addition_embed_type_num_heads": 64,
  "addition_time_embed_dim": null,
  "attention_head_dim": 8,
  "block_out_channels": [
    320,
    640,
    1280,
    1280
  ],
  "center_input_sample": false,
  "class_embed_type": null,
  "class_embeddings_concat": false,
  "conv_in_kernel": 3,
  "conv_out_kernel": 3,
  "cross_attention_dim": 768,
  "cross_attention_norm": null,
  "down_block_types": [
    "CrossAttnDownBlock2D",
    "CrossAttnDownBlock2D",
    "CrossAttnDownBlock2D",
    "DownBlock2D"
  ],
  "downsample_padding": 1,
  "dual_cross_attention": false,
  "encoder_hid_dim": null,
  "encoder_hid_dim_type": null,
  "flip_sin_to_cos": true,
  "freq_shift": 0,
  "in_channels": 4,
  "layers_per_block": 2,
  "mid_block_only_cross_attention": null,
  "mid_block_scale_factor": 1,
  "mid_block_type": "UNetMidBlock2DCrossAttn",
  "norm_eps": 1e-05,
  "norm_num_groups": 32,
  "num_attention_heads": null,
  "num_class_embeds": null,
  "only_cross_attention": false,
  "out_channels": 4,
  "projection_class_embeddings_input_dim": null,
  "resnet_out_scale_factor": 1.0,
  "resnet_skip_time_act": false,
  "resnet_time_scale_shift": "default",
  "sample_size": 64,
  "time_cond_proj_dim": null,
  "time_embedding_act_fn": null,
  "time_embedding_dim": null,
  "time_embedding_type": "positional",
  "timestep_post_act": null,
  "transformer_layers_per_block": 1,
  "up_block_types": [
    "UpBlock2D",
    "CrossAttnUpBlock2D",
    "CrossAttnUpBlock2D",
    "CrossAttnUpBlock2D"
  ],
  "upcast_attention": false,
  "use_linear_projection": false
}
model_pred = unet(noisy_latents,timesteps,encoder_hidden_states).sample

0.center input
sample = 2*sample-1

1.time
t_emb = self.time_proj(timesteps)
emb = self.time_embedding(t_emb,timestep_cond)

2.pre-process
sample = self.conv_in(sample)

3.down
for downsample_block in self.down_blocks:
    sample,res_samples = downsample_block(sample,emb)
    down_block_res_samples += res_samples

4.mid
sample = self.mid_block(sample,emb)

5.up
for i,upsample_block in enumerate(self.up_blocks):
    sample = upsample_block(hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size)

6.post-process
sample = self.conv_out(sample)

扩散模型的推理:

diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

StableDiffusionPipeline->

0.default height and width to unet

1.check inputs.
self.check_inputs(prompt,height,width,callback_steps,negative_prompt,prompt_embeds,negative_embeds)

2.define call parameters
batch
do_classifier_free_guidance

3.encode input prompt
prompt_embeds = self._encode_prompt(prompt,negative_prompt)

4.prepare timesteps
self.scheduler.set_timesteps(num_inference_steps)
timesteps = self.scheduler.timesteps

5.prepare latent variables
latents = self.prepare_latents(batch_size * num_images_per_prompt,num_channels_latents,height,width,prompt_embeds.dtype,device,generator,latents)

6.prepare extra step kwargs
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

7.denosing loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
for i,t in enumerate(timesteps):
    latent_model_input = torch.cat([latents]*2)
    latent_model_input = self.scheduler.scale_model_input(latent_model_input,t)
    
    # predict the noise residual
    noise_pred = self.unet(latent_model_input,t...)[0]
    
    if do_classifer_free_guidance:
        noise_pred_uncond,noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale*(noise_pred_text-noise_pred_uncond)
    
    # compute the previous noisy sample x_t->x_t-1
    latents = self.scheduler.step(noise_pred,t,latents,..)[0] # xt

image = self.image_processor.postprocess()

diffusers/schedulers/scheduling_ddpm.py

step->
t = timesteps
prev_t = self.previous_timestep(t)
- prev_t = timestep-self.config.num_train_timesteps//num_inference_steps

1.compute alpha,betas 
# 认为设置超参数beta,满足beta随着t的增大而增大,根据beta计算alpha
alpha_prod_t = self.alphas_cumprod[t]
alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev
current_alpha_t = alpha_prod_t / alpha_prod_t_prev
current_beta_t = 1 - current_alpha_t

2.compute predicted original sample from predicted noise also called predicted_x0
pred_original_sample = (sample-beta_prod_t**(0.5)*model_output)/alpha_prod_t**(0.5)

3.clip or threshold predicted x0
pred_original_sample = pred_original_sample.clamp(-self.config.clip_sample_range,self.config.clip_sample_range)

4.compute coefficients for pred_original_sample x0 and current sample xt
pred_original_sample_coeff = (alpha_prod_t_prev**0.5*current_beta_t)/beta_prod_t
current_sample_coeff = current_alpha_t**0.5*beta_prod_t_prev/beta_prod_t

5.compute predicted previous sample 
pred_prev_sample = pred_original_sample_coeff*pred_original_sample+current_sample_coeff*sample

6.add noise
variance_noise = randn_tensor()
variance = self._get_variance(t,predicted_variance)*variance_noise
pred_prev_sample = pred_prev_sample+variance

return pred_prev_sample,pred_original_sample

xt = pred_prev_sample,x0 = pred_original_sample,xt这个式子化简一下就是下面预测结果

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1024340.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

QT连接Sqlite

使用QTCreator; 根据资料,Qt自带SQLite数据库,不需要再单独安装,默认情况下,使用SQLite版本3,驱动程序为***QSQLITE***; 首先创建项目;在 Build system 中应选中qmake,…

前端自定义导出PPT

1、背景 前端导出PPT,刚接触这个需求,还是比较懵逼,然后就在网上查找资料,最终确认是可行的;这个需求也是合理的,我们做了一个可视化数据报表,报表导出成PPT,将在线报表转成文档类型…

【数据库系统概论】关系数据库中的关系数据结构

前言关系关系模式关系数据库关系模型的存储结构感谢 💖 前言 上一篇文章【数据库系统概论】数据模型介绍了数据库系统中的数据模型的基本概念。其中提到了关系模型是最重要的一种数据模型。下面将介绍支持关系模型的数据库系统——关系数据库。 按照数据模型的三大…

蓝牙核心规范(V5.4)10.5-BLE 入门笔记之HCI

HCI全称:HOST Constroller Interface 主机控制器接口(HCI)定义了一个标准化的接口,通过该接口,主机可以向控制器发出命令,并且控制器可以与主机进行通信。规范被分成几个部分,第一部分仅从功能的角度定义接口,不考虑具体的实现机制,而其他部分定义了在使用四种可能的…

Mac 上如何安装Mysql? 如何配置 Mysql?以及如何开启并使用MySQL

前言: 有许多开发的小伙伴,使用的是mac,那么在mac上如何安装,配置Mysql,以及使用Mysql了,今天来一个系统的教程。 安装Mysql 使用mysql前,我们需要先下载mysql,并按照以下几个步骤…

【Oracle】Oracle系列之四--用户管理

文章目录 往期回顾前言1. 创建/删除用户(1)创建用户(2)修改口令(3)删除用户 2. 用户授权管理(1)对用户直接授权(2)通过角色对用户授权 往期回顾 【Oracle】O…

Nodejs 相关知识

Nodejs是一个js运行环境,可以让js开发后端程序,实现几乎其他后端语言实现的所有功能,能够让js与其他后端语言平起平坐。 nodejs是基于v8引擎,v8是Google发布的开源js引擎,本身就是用于chrome浏览器的js解释部分&#…

day43 数据库

SQL分类 DDL:Date definition Language 数据定义语言 主要针对的是数据库对象进行创建修改删除的操作 包括:create, alter, drop, show, desc truncate DML:Data Manipulation Language 数据操作语言 对数据库中数据进行增加,修…

3D成像技术概述

工业4.0时代,三维机器视觉备受关注,目前,三维机器视觉成像方法主要分为光学成像法和非光学成像法,这之中,光学成像法是市场主流。 飞行时间3D成像 飞行时间成像(Time of Flight),简称TOF,是通过给目标连续发送光脉冲,然后用传感器接收从物体返回的光,通过探测光脉…

国庆中秋特辑(二)浪漫祝福方式 使用生成对抗网络(GAN)生成具有节日氛围的画作

要用人工智能技术来庆祝国庆中秋,我们可以使用生成对抗网络(GAN)生成具有节日氛围的画作。这里将使用深度学习框架 TensorFlow 和 Keras 来实现。 一、生成对抗网络(GAN) 生成对抗网络(GANs,…

基于Yolov8的野外烟雾检测(4):通道优先卷积注意力(CPCA),效果秒杀CBAM和SE等 | 中科院2023最新发表

目录 1.Yolov8介绍 2.野外火灾烟雾数据集介绍 3.CPCA介绍 3.1 CPCA加入到yolov8 4.训练结果分析 5.系列篇 1.Yolov8介绍 Ultralytics YOLOv8是Ultralytics公司开发的YOLO目标检测和图像分割模型的最新版本。YOLOv8是一种尖端的、最先进的(SOTA)模型&a…

Golang反射相关知识总结

1. Golang反射概述 Go语言的反射(reflection)是指在运行时动态地获取类型信息和操作对象的能力。在Go语言中,每个值都是一个接口类型,这个接口类型包含了这个值的类型信息和值的数据,因此,通过反射&#x…

Freertos学习笔记

文章目录 Freertos移植TCB控制块中断管理 (内部异常和外部中断)同步互斥与通信消息队列:邮箱:信号量:互斥量:事件组:任务通知:Freertos移植 其核心文件为,tasks.c、timers.c、queue.c、event_groups.c、croutine.c、list.c。源码兼顾了很多平台,但是我们可以删除一些…

亚马逊攀岩安全带ASTM F1772测试办理

本政策适用于主要用于攀岩或登山活动的安全带。攀岩安全带是一种装备,可穿戴在攀岩者或登山者的腰部和大腿处。攀岩安全带为绳子提供了一个连接点,并提供一种手段,以便在攀登、休息、绕绳下降或跌落的过程中为攀登者身体提供支撑。本政策涵盖…

整理mongodb文档:副本集一

个人博客 整理mongodb文档:副本集一 本文讲解较为粗糙,对于没有后台开发经验的人员,建议配合官网了解下相关概念。个人博客,日常求关注 文章概叙 文章会先花费几分钟讲解下关于垂直缩放以及水平缩放的概念,以方便大家对副本集…

Qt5开发及实例V2.0-第五章Qt主窗体

Qt5开发及实例V2.0-第五章Qt主窗体 第5章 Qt 5主窗体5.1.1 基本元素5.1.2 【综合实例】:文本编辑器5.1.3 菜单与工具栏的实现 5.2 Qt 5文件操作功能5.2.1 新建文件5.2.2 打开文件5.2.3 打印文件 5.3 Qt 5图像坐标变换5.3.1 缩放功能5.3.2 旋转功能5.3.3 镜像功能 5.…

适用于 Linux 的 Windows 子系统获得新的“镜像”网络模式

Microsoft 发布了 Windows Subsystem for Linux (WSL) 2.0.0,其中包含一组新的可选实验功能,包括新的网络模式以及自动内存和磁盘大小清理。 首先,新添加的“自动内存回收”功能通过回收缓存内存来动态减少 WSL 虚拟机 (VM) 的内存占用。 此…

Git学习笔记8

Gitlab: Gitlab是利用Ruby on Rails 一个开源的版本管理系统,实现一个自托管的git项目仓库,可通过web界面进行访问公开或私有的项目。 Gitlab安装: 安装之前,将虚拟机的内存改成了4个G。内存如果太小,会有…

Zipping

Zipping 信息收集端口扫描目录扫描webbanner信息收集 漏洞利用空字节绕过---->失败sqlI-preg_match bypass反弹shell 稳定维持 提权-共享库漏洞 参考:https://rouvin.gitbook.io/ibreakstuff/writeups/htb-season-2/zipping#sudo-privileges-greater-than-stock-…

基于Python开发的图片批量处理器(源码+可执行程序+程序配置说明书+程序使用说明书)

一、项目简介 本项目是一套基于Python开发的图片批量处理器,主要针对计算机相关专业的正在做毕设的学生与需要项目实战练习的Python学习者。 包含:项目源码、项目文档等,该项目附带全部源码可作为毕设使用。 项目都经过严格调试,…