AIGC笔记--基于Stable Diffusion实现图片的inpainting

news2024/11/15 15:50:04

1--完整代码

SD_Inpainting

2--简单代码

import PIL
import torch
import numpy as np
from PIL import Image
from tqdm import tqdm
import torchvision
from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler
from transformers import CLIPTextModel, CLIPTokenizer

# 预处理mask
def preprocess_mask(mask):
    mask = mask.convert("L") # 转换为灰度图: L = R * 299/1000 + G * 587/1000+ B * 114/1000。
    w, h = mask.size # 512, 512
    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
    mask = mask.resize((w // 8, h // 8), resample = PIL.Image.NEAREST) # 64, 64
    mask = np.array(mask).astype(np.float32) / 255.0 # 归一化 64, 64
    mask = np.tile(mask, (4, 1, 1)) # 4, 64, 64
    mask = mask[None].transpose(0, 1, 2, 3)
    mask = 1 - mask  # repaint white, keep black # mask图中,mask的部分变为0
    mask = torch.from_numpy(mask)
    return mask

# 预处理image
def preprocess(image):
    w, h = image.size
    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    return 2.0 * image - 1.0

if __name__ == "__main__":
    model_id = "runwayml/stable-diffusion-v1-5" # online download
    # model_id = "/mnt/dolphinfs/hdd_pool/docker/user/hadoop-waimai-aigc/liujinfu/All_test/test0714/huggingface.co/runwayml/stable-diffusion-v1-5" # local path

    # 读取输入图像和输入mask
    input_image = Image.open("./images/overture-creations-5sI6fQgYIuo.png").resize((512, 512))
    input_mask = Image.open("./images/overture-creations-5sI6fQgYIuo_mask.png").resize((512, 512))

    # 1. 加载autoencoder
    vae = AutoencoderKL.from_pretrained(model_id, subfolder = "vae")

    # 2. 加载tokenizer和text encoder 
    tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder = "tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder = "text_encoder")

    # 3. 加载扩散模型UNet
    unet = UNet2DConditionModel.from_pretrained(model_id, subfolder = "unet")

    # 4. 定义noise scheduler
    noise_scheduler = DDIMScheduler(
        num_train_timesteps = 1000,
        beta_start = 0.00085,
        beta_end = 0.012,
        beta_schedule = "scaled_linear",
        clip_sample = False, # don't clip sample, the x0 in stable diffusion not in range [-1, 1]
        set_alpha_to_one = False,
    )

    # 将模型复制到GPU上
    device = "cuda"
    vae.to(device, dtype = torch.float16)
    text_encoder.to(device, dtype = torch.float16)
    unet = unet.to(device, dtype = torch.float16)

    # 设置prompt和超参数
    prompt = "a mecha robot sitting on a bench"
    negative_prompt = ""
    strength = 0.75
    guidance_scale = 7.5
    batch_size = 1
    num_inference_steps = 50
    generator = torch.Generator(device).manual_seed(0)

    with torch.no_grad():
        # get prompt text_embeddings
        text_input = tokenizer(prompt, padding = "max_length", 
            max_length = tokenizer.model_max_length, 
            truncation = True, 
            return_tensors = "pt")
        text_embeddings = text_encoder(text_input.input_ids.to(device))[0]

        # get unconditional text embeddings
        max_length = text_input.input_ids.shape[-1]
        uncond_input = tokenizer(
            [negative_prompt] * batch_size, padding = "max_length", max_length = max_length, return_tensors = "pt"
        )
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
        # concat batch
        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        # 设置采样步数
        noise_scheduler.set_timesteps(num_inference_steps, device = device)

        # 根据strength计算timesteps
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = noise_scheduler.timesteps[t_start:]

        # 预处理init_image
        init_input = preprocess(input_image)
        init_latents = vae.encode(init_input.to(device, dtype=torch.float16)).latent_dist.sample(generator)
        init_latents = 0.18215 * init_latents
        init_latents = torch.cat([init_latents] * batch_size, dim=0)
        init_latents_orig = init_latents

        # 处理mask
        mask_image = preprocess_mask(input_mask)
        mask_image = mask_image.to(device=device, dtype=init_latents.dtype)
        mask = torch.cat([mask_image] * batch_size)
        
        # 给init_latents加噪音
        noise = torch.randn(init_latents.shape, generator = generator, device = device, dtype = init_latents.dtype)
        init_latents = noise_scheduler.add_noise(init_latents, noise, timesteps[:1])
        latents = init_latents # 作为初始latents

        # Do denoise steps
        for t in tqdm(timesteps):
            # 这里latens扩展2份,是为了同时计算unconditional prediction
            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = noise_scheduler.scale_model_input(latent_model_input, t) # for DDIM, do nothing

            # 预测噪音
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

            # Classifier Free Guidance
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # x_t -> x_t-1
            latents = noise_scheduler.step(noise_pred, t, latents).prev_sample
            
            # 将unmask区域替换原始图像的nosiy latents
            init_latents_proper = noise_scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
            # mask的部分数值为0
            # 因此init_latents_proper * mask为保留原始latents(不mask)
            # 而latents * (1 - mask)为用生成的latents替换mask的部分
            latents = (init_latents_proper * mask) + (latents * (1 - mask)) 

        # 注意要对latents进行scale
        latents = 1 / 0.18215 * latents
        image = vae.decode(latents).sample
        
        # 转成pillow
        img = (image / 2 + 0.5).clamp(0, 1).detach().cpu()
        img = torchvision.transforms.ToPILImage()(img.squeeze())
        img.save("./outputs/output.png")
        print("All Done!")

运行结果:

3--基于Diffuser进行调用

import torch
import torchvision
from PIL import Image
from diffusers import StableDiffusionInpaintPipelineLegacy

if __name__ == "__main__":
    # load inpainting pipeline
    model_id = "runwayml/stable-diffusion-v1-5"
    # model_id = "/mnt/dolphinfs/hdd_pool/docker/user/hadoop-waimai-aigc/liujinfu/All_test/test0714/huggingface.co/runwayml/stable-diffusion-v1-5" # local path
    pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained(model_id, torch_dtype = torch.float16).to("cuda")

    # load input image and input mask
    input_image = Image.open("./images/overture-creations-5sI6fQgYIuo.png").resize((512, 512))
    input_mask = Image.open("./images/overture-creations-5sI6fQgYIuo_mask.png").resize((512, 512))

    # run inference
    prompt = ["a mecha robot sitting on a bench", "a cat sitting on a bench"]
    generator = torch.Generator("cuda").manual_seed(0)
    with torch.autocast("cuda"):
        images = pipe(
            prompt = prompt,
            image = input_image,
            mask_image = input_mask,
            num_inference_steps = 50,
            strength = 0.75,
            guidance_scale = 7.5,
            num_images_per_prompt = 1,
            generator = generator
        ).images

    # 转成pillow
    for idx, image in enumerate(images):
        image.save("./outputs/output_{:d}.png".format(idx))
    print("All Done!")

运行结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1928725.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【全面介绍Pip换源】

🎥博主:程序员不想YY啊 💫CSDN优质创作者,CSDN实力新星,CSDN博客专家 🤗点赞🎈收藏⭐再看💫养成习惯 ✨希望本文对您有所裨益,如有不足之处,欢迎在评论区提出…

产品经理-产品经理会在项目中遇到的几个问题(16)

项目中遇到了需求变更怎么办? 首先要弄清楚需求变更的原因是什么。如果是因为在迭代的过程中更好地理解了用户需求 进而产生了更好的需求则完全是正常的。如果是因为老板的需求 那就需要和老板沟通清楚,并且确保自己能理解老板的需求,而且这个…

【数据结构】高效解决连通性问题的并查集详解及Python实现

文章目录 1. 并查集:一种高效的数据结构2. 并查集的基本操作与优化2.1 初始化2.2 查找操作与路径压缩2.3 合并操作与按秩合并 3. 并查集的应用3.1 判断连通性3.2 计算连通分量 4. 并查集的实际案例4.1 图的连通性问题4.2 网络连接问题 5. 并查集的优缺点5.1 优点5.2…

哪些网站是获取独立站外链的最佳选择?

想要为独立站获取外链,有几个地方可以考虑,首先自然是最有效的博客和文章投稿网站,找那些与你的行业相关的博客和内容平台,撰写高质量的文章,里面自然地嵌入你的链接。这是最有价值的外链 然后不分其他,效…

ESP32-S3多模态交互方案在线AI语音设备应用,启明云端乐鑫代理商

随着物联网(IoT)和人工智能(AI)技术的飞速发展,嵌入式设备正逐渐变得智能化,让我们的家庭生活变得更加智能化和个性化。 随着大型语言模型的不断进步和优化,AI语音机器人设备能够实现更加智能、…

超越 Transformer开启高效开放语言模型的新篇章

在人工智能快速发展的今天,对于高效且性能卓越的语言模型的追求,促使谷歌DeepMind团队开发出了RecurrentGemma这一突破性模型。这款新型模型在论文《RecurrentGemma:超越Transformers的高效开放语言模型》中得到了详细介绍,它通过…

软件工程课设——成绩管理系统

软件工程课设——成绩管理系统 该文档是软件工程课程设计,成绩管理子系统的开发模块仓库。 功能分析 从面向的用户分,成绩管理子系统主要面向三类用户,即至少需要满足这三类用户的需求: 学生:学生是成绩管理系统的…

实现keepalive+Haproxyde 的高可用

需要准备五台实验机 一台客户机:test1 两台:一主一备的实验机:test2 test3 两台真实服务器:nginx1 nginx2 实验 首先在两台实验机上安装Haproxy 安装依赖环境,并将Haproxy的包进行解压处理 yum install -y pcre…

什么ISP?什么是IAP?

做单片机开发的工程师经常会听到两个词:ISP和IAP,但新手往往对这两个概念不是很清楚,今天就来和大家聊聊什么是ISP,什么是IAP? 一、ISP ISP的全称是:In System Programming,即在系统编程&…

vscode常用组件

1.vue-helper 启用后点击右下角注册,可以通过vue组件点击到源码里面 2.【Auto Close Tag】和【Auto Rename Tag】 3.setting---Auto Reveal Exclude vscode跳转node_modules下文件,没有切换定位到左侧菜单目录> 打开VSCode的setting配置&#xff…

Redis的使用(四)常见使用场景-缓存使用技巧

1.绪论 redis本质上就是一个缓存框架,所以我们需要研究如何使用redis来缓存数据,并且如何解决缓存中的常见问题,缓存穿透,缓存击穿,缓存雪崩,以及如何来解决缓存一致性问题。 2.缓存的优缺点 2.1 缓存的…

Transformer模型解析:走进自然语言处理的新时代

UPDATED:2023 年 1 月 27 日,本文登上 ATA 头条。(注:ATA 全称 Alibaba Technology Associate,是阿里集团最大的技术社区)UPDATED:2023 年 2 月 2 日,本文在 ATA 获得鲁肃点赞。&…

华为OD算法题汇总

60、计算网络信号 题目 网络信号经过传递会逐层衰减,且遇到阻隔物无法直接穿透,在此情况下需要计算某个位置的网络信号值。注意:网络信号可以绕过阻隔物 array[m][n],二维数组代表网格地图 array[i][j]0,代表i行j列是空旷位置 a…

数据结构(4.0)——串的定义和基本操作

串的定义(逻辑结构) 串,即字符串(String)是由零个或多个字符组成的有序数列。 一般记为Sa1a2....an(n>0) 其中,S是串名,单引号括起来的字符序列是串的值;ai可以是字母、数字或其他字符;串中字符的个数n称为串的长度。n0时的…

分布式对象存储minio

本教程minio 版本:RELEASE.2021-07-*及以上 1. 分布式文件系统应用场景 互联网海量非结构化数据的存储需求 电商网站:海量商品图片视频网站:海量视频文件网盘 : 海量文件社交网站:海量图片 1.1 Minio介绍 MinIO 是一个基于Ap…

Spring解决循环依赖:三级缓存

1.什么是循环依赖 通俗来讲,循环依赖指的是一个实例或多个实例存在相互依赖的关系(类之间循环嵌套引用)。 2.Spring如何解决循环依赖 首先,先介绍Spring是如何创建Bean的。 (1)createBeanInstance&…

【LoadRunner】博客笔记项目 性能测试报告

文章目录 前言一、博客笔记项目性能测试介绍二、编写性能测试脚本(VUG) 2.1 测试脚本编写步骤 2.2 脚本总代码和结果分析三、创建测试场景(Controller) 3.1 测试场景创建实现步骤四、生成测试报告(Anal…

集合相关知识

string final,不能追加,需要重新new一个 stringbuild,内容 可变,可以重新赋能,能够追加,空间不足创造一个更大的,然后复制过去 stringbufferbuild 线程安全 javac编译,字符串加号…

SpringBoot介绍以及第一个SpringBoot程序

T04BF 👋专栏: 算法|JAVA|MySQL|C语言 🫵 今天你敲代码了吗 文章目录 2.第一个SpringBoot程序2.1Spring Boot介绍2.2使用idea创建Spring Boot程序2.2.1 社区版idea2.2.2专业版idea2.2.3创建SpringBoot项目2.2.4项目代码和目录介绍目录介绍pom文件 2.3Web…

Linux 上 TTY 的起源

注:机翻,未校对。 What is a TTY on Linux? (and How to Use the tty Command) What does the tty command do? It prints the name of the terminal you’re using. TTY stands for “teletypewriter.” What’s the story behind the name of the co…