代码解读:使用Stable Diffusion完成相似图像生成任务

news2025/2/23 11:13:04

Diffusion models代码解读:入门与实战

前言:作为内容生产重要的一部分,生成相似图像是一项有意义的工作,例如很多内容创作分享平台单纯依赖用户贡献的图片已经不够了,最省力的方法就是利用已有的图片生成相似的图片作为补充。这篇博客详细解读基于Stable Diffusion生成相似图片的原理和代码。

目录

原理详解

代码实战

环境安装

代码运行

参数解读

快速体验

效果展示

代码解读

模型加载

图像加噪

分类器引导


原理详解

首先读者需要熟悉Stable Diffusion的原理,这部分的可以参考我之前的博客:

Diffusion models代码实战:从零搭建自己的扩散模型_diffusion model 自己-CSDN博客

在Stable Diffusion 中,latent域注入了text condition;在相似图像生成的任务中,把这个text condition替换成Image conditon:

推理的时候,会先在图片中加入噪声,这个添加噪声的程度会用noise_level参数控制。然后把这个加噪过的图片embedding输入到模型中。

原理就这么简单……

代码实战

环境安装

pip install git+https://github.com/lllcho/image_variation.git
pip install modelscope

代码运行

from modelscope.pipelines import pipeline
from modelscope.outputs import OutputKeys
from PIL import Image
from image_variation import modelscope_warpper

model = 'damo/cv_image_variation_sd'
pipe = pipeline('image_variation_task', model=model, device='gpu',auto_collate=False,model_revision='v1.1.0')
out=pipe('https://vision-poster.oss-cn-shanghai.aliyuncs.com/lllcho.lc/data/test_data/sunset-landscape-sky-colorful-preview.jpg')
imgs=out[OutputKeys.OUTPUT_IMGS]
imgs[0].save(f'result.jpg')

参数解读

pipeline调用时的可调参数:

num_inference_steps: int, 默认为20
guidance_scale:float, 默认5.0
num_images_per_prompt:默认为1,每次调用返回几张图,可根据显存大小调整
seed:默认为None,int类型,取值范围[0, 2^32-1]
height::默认值512
width:默认值512
noise_level: int,默认值为0, 取值范围[0,999],表示像输入图像中加入噪声,值越大噪声越多,生成结果与输入图像的相似度越低

快速体验

https://modelscope.cn/studios/iic/image_variation/summary

效果展示

输入的图像:

输出的图像

代码解读

代码地址:https://modelscope.cn/studios/iic/image_variation/summary

模型加载

    scheduler = UniPCMultistepScheduler(beta_start=0.00085,beta_end=0.012,beta_schedule='scaled_linear')
    vae = AutoencoderKL.from_pretrained(ckpt_dir, subfolder='vae')
    vae.eval()
    unet = UNet2DConditionModel.from_pretrained(ckpt_dir, subfolder='unet')
    cond_model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms('ViT-H-14',
        pretrained=osp.join(ckpt_dir,'CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin')     

图像加噪

对输入的图片embedding加噪后输入到模型中:

    def noise_image_embeddings(self,image_embeds,noise_level,generator=None):
        noise = randn_tensor(image_embeds.shape, generator=generator, device=image_embeds.device, dtype=image_embeds.dtype)
        noise_level = torch.tensor([noise_level] * image_embeds.shape[0], device=image_embeds.device)
        
        meanstd=torch.from_numpy(np.load(self.norm_file)[None]).to(device=image_embeds.device, dtype=image_embeds.dtype)
        mean,std=torch.chunk(meanstd,2,dim=1)
        #scale
        image_embeds-=mean
        image_embeds/=std
        
        image_embeds=self.scheduler.add_noise(image_embeds,noise,noise_level)
        
        #unscale
        image_embeds*=std
        image_embeds+=mean
    
        return image_embeds

分类器引导

和text condition 一样,这里也有“negative prompt”的分类器引导。

图像的“negative prompt”是用masked image替代的,maked image用了一个图片值全0的图片表示:

        mask=torch.ones(2,1,height//8,width//8,device=self.device,dtype=self.dtype)
        masked_img=torch.zeros(2,3,height,width,device=self.device,dtype=self.dtype)
        masked_image_latents = self.vae.encode(masked_img).latent_dist.sample()
        masked_image_latents = masked_image_latents * self.vae.config.scaling_factor
        
        mask=mask.repeat(num_images_per_prompt,1,1,1)
        masked_image_latents=masked_image_latents.repeat(num_images_per_prompt,1,1,1)

完整的推理代码如下:

   @torch.no_grad()
    def __call__(self,
                 image,
                 num_inference_steps=20,
                 guidance_scale=5.0,
                 num_images_per_prompt=1,
                 seed=None,
                 height=512,
                 width=512,
                 noise_level: int=0,
                 ):
        if seed is None:
            seed=random.randint(0,2**32-1)
        set_seed(seed)
        self.scheduler.set_timesteps(num_inference_steps, device=self.device)
        timesteps = self.scheduler.timesteps
        
        mask=torch.ones(2,1,height//8,width//8,device=self.device,dtype=self.dtype)
        masked_img=torch.zeros(2,3,height,width,device=self.device,dtype=self.dtype)
        masked_image_latents = self.vae.encode(masked_img).latent_dist.sample()
        masked_image_latents = masked_image_latents * self.vae.config.scaling_factor
        
        mask=mask.repeat(num_images_per_prompt,1,1,1)
        masked_image_latents=masked_image_latents.repeat(num_images_per_prompt,1,1,1)
         
        latents = randn_tensor((num_images_per_prompt,4,height//8,width//8),device=self.device, generator=None,dtype=self.dtype)
        latents = latents * self.scheduler.init_noise_sigma

        cond_image=image
        clip_img=self.preprocess(read_img(cond_image).convert('RGB')).unsqueeze(0)
        cond_embedding=self.cond_model.encode_image(clip_img.to(self.device,self.dtype)).to(self.dtype)
        cond_embedding=cond_embedding.repeat(num_images_per_prompt,1,1)
        if noise_level>0:
            cond_embedding=self.noise_image_embeddings(cond_embedding,noise_level)
        cond_embedding=torch.cat([cond_embedding*0,cond_embedding])
        
        for i, t in enumerate(timesteps):
            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
            latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
            
            noise_pred = self.unet(latent_model_input,t,encoder_hidden_states=cond_embedding,cross_attention_kwargs={}).sample
            
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            latents = self.scheduler.step(noise_pred, t, latents, **{}).prev_sample
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents.to(self.dtype)).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        image = (image * 255).round().astype("uint8")
        imgs=[Image.fromarray(img) for img in image]
        return imgs

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1586146.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

vscode 之 win11前端环境安装(javascrip、html、nodejs以及插件推荐)

javascript 也可以用来编写一些小工具,包括但不限于浏览器的插件,浏览器的书签以及进行一些其他操作的小工具等。 这时候就需要进行安装前端相关的测试环境 1. nodejs (1)确保电脑存在 nodejs 的环境 在 cmd 输入 npm -v node -…

DataEase-V1.18版本源码通过Docker镜像部署与静态资源通过阿里云OSS存储实现看这一篇就够了

修改DataEase实现静态资源阿里云OSS存储 后端源码文件读取配置类配置 1.阿里云OSS配置类 /*** ClassName AliyunConfig.java* author shuyixiao* version 1.0.0* Description 阿里云OSS配置* createTime 2024年04月03日 10:03:00*/ Data Configuration public class AliyunC…

Docker端口一直占用问题,docker重置(端口无法释放)(彻底重置docker环境)

文章目录 背景解决方法:彻底重置docker环境1. 停止所有Docker容器2. 删除所有容器3. 删除所有Docker镜像4. 删除所有Docker网络5. 删除所有Docker卷6. 清理Dangling资源7. 停止Docker服务8. 删除Docker数据和配置文件9. 重启Docker服务10. 验证 在这里插入图片描述验…

PostgreSQL入门到实战-第十四弹

PostgreSQL入门到实战 PostgreSQL数据过滤(七)官网地址PostgreSQL概述PostgreSQL中BETWEEN 命令理论PostgreSQL中BETWEEN 命令实战更新计划 PostgreSQL数据过滤(七) BETWEEN运算符允许您检查值是否在值的范围内。 官网地址 声明: 由于操作系统, 版本更新等原因, 文章所列内容…

“桃花庵主”是我国哪位古代名人的称号?2024年4月12日蚂蚁庄园今日答案

原文来源:蚂蚁庄园今日答案 - 词令 蚂蚁庄园是一款爱心公益游戏,用户可以通过喂养小鸡,产生鸡蛋,并通过捐赠鸡蛋参与公益项目。用户每日完成答题就可以领取鸡饲料,使用鸡饲料喂鸡之后,会可以获得鸡蛋&…

2024年腾讯云新用户云服务器价格表

腾讯云作为国内领先的云服务提供商,以其稳定可靠、灵活高效的服务赢得了广大用户的信赖。对于新用户而言,腾讯云提供了丰富的云服务器产品,并且制定了具有竞争力的价格策略,以吸引更多的新用户加入。 首先,我们来看一下…

【C++】STL--stackquene

这一节主要学习stack、quene和priority_quene的使用以及模拟实现,最后介绍了容器适配器。 目录 stack的介绍和使用 stack的介绍 stack的使用 stack的模拟实现 queue的介绍和使用 queue的介绍 queue的使用 queue的模拟实现 priority_queue的介绍和使用 pri…

Spring Boot与Vue联手打造智能化学生选课平台

末尾获取源码作者介绍:大厂全栈码农|毕设实战开发,专注于大学生项目实战开发、讲解和毕业答疑辅导。 更多项目:CSDN主页YAML墨韵 学如逆水行舟,不进则退。学习如赶路,不能慢一步。 目录 一、项目简介 二、开发技术与…

嵌入式工程师需要掌握哪些技术?

嵌入式系统是当今科技领域中的重要组成部分,它们存在于我们生活的方方面面,从智能手机到汽车控制系统,从家电到医疗设备。因此,对于那些想要进入嵌入式行业的人来说,掌握一些必要的技术能力是至关重要的。在本篇中&…

springboot 反射调用ServiceImpl时报错:java.lang.NullPointerExceptio、,mapper为null【解决方法】

springboot 反射调用ServiceImpl时报错:java.lang.NullPointerException、mapper为null【解决方法】 问题描述问题分析解决方案创建SpringBootBeanUtil编写调用方法 executeMethod调用 总结 问题描述 在使用Spring Boot时,我们希望能够通过反射动态调用…

Win11又来「重大」更新!

ChatGPT狂飙160天,世界已经不是之前的样子。 新建了免费的人工智能中文站https://ai.weoknow.com 新建了收费的人工智能中文站ai人工智能工具 更多资源欢迎关注 Windows 11预览通道的22635.3420版本迎来了几个比较大的改进,主要有三个方面: …

Springboot 大事务问题的常用优化方案

🏷️个人主页:牵着猫散步的鼠鼠 🏷️系列专栏:Java全栈-专栏 🏷️个人学习笔记,若有缺误,欢迎评论区指正 目录 1.前言 2.什么是大事务 3.解决办法 3.1.少用Transactional注解 3.2..将查询…

医疗图像分割 | 基于Pyramid-Vision-Transformer算法实现医疗息肉分割

项目应用场景 面向医疗图像息肉分割场景,项目采用 Pytorch Pyramid-Vision-Transformer 深度学习算法来实现。 项目效果 项目细节 > 具体参见项目 README.md (1) 模型架构 (2) 项目依赖,包括 python 3.8、pytorch 1.7.1、torchvision 0.8.2(3) 下载…

【实战】ZLMediaKit问题解决

项目中遇到的问题 1.不带音频的rtsp转rtmp后,出现了音频 1.1判断元素rtsp是否有音频的方法 使用vlc进行访问rtsp流,看如图位置: 音频 -> 音轨 ,是否为灰色,为灰色就是不带音频 1.2 解决方法 在zlmediakit的web页面进行全局配置修改如图, 1.将3和4处修改为 否,再保存, …

网络协议——RSTP(快速生成树)与MSTP(多实例生成树)

一. RSTP 1. STP的不足 1、依靠计时器超时的方式进行收敛导致它的收敛时间需要30到50秒 2、端口状态和端口角色没有细致区分,指导数据转发依靠的不是端口状态而是端口所扮演角色。 3、如果拓扑频繁变化导致用户通信质量差,甚至通信中断&#xf…

MyBatis中的动态SQL的用法

前言:我们要想在Spring Boot环境下使用动态SQL,必须先在application.yml中添加配置 mybatis:mapper-locations: classpath:mapper/**Mapper.xml 并且新建一个xml文件,路径及写法按照配置好的形式写 在新建好的xml文件中复制进去以下代码&a…

Golang——方法

一. 方法定义 Golang方法总是绑定对象的实例,并隐式将实例作为第一实参。 只能为当前包内命名类型定义方法参数receiver可以任意命名。如方法中未曾使用,可省略参数名参数receiver类型可以是T或*T。基类型T不能是接口或指针类型(即多级指针)不支持方法重…

【JAVASE】抽象类和接口及其抽象类和接口的区别

✅作者简介:大家好,我是橘橙黄又青,一个想要与大家共同进步的男人😉😉 🍎个人主页:再无B~U~G-CSDN博客 目标: 1. 抽象类 2. 接口 3. Object 类 1. &am…

性能测试--数据库慢 SQL 语句分析

一 慢 SQL 语句的几种常见诱因 1. 无索引或索引失效 ​ 当查询基于一个没有索引的列进行过滤、排序或连接时,数据库可能被迫进行全表扫描,即逐行检查所有数据,导致性能显著下降。 ​ 虽然我们很多时候建立了索引,但在一些特定的…

第3章 存储系统(2)

3.3 主存储器与CPU连接 3.3.1 连接原理 现代计算机的MAR和MDR都在CPU内部。 (1)主存储器通过数据总线,地址总线,控制总线与CPU连接。 (2)数据传输率数据总线宽度*总线频率。 (4)控制总线(读写线)控制读写操作。 3.3.2 主存的扩展 数据总线宽度等于存储字长 1.位扩展法【增加…