IMAGDressing is a new virtual dressing framework jointly developed by Nanjing University of Science and Technology, Wuhan University of Technology, Tencent AI Lab, and Nanjing University.
The project aims to improve the online shopping experience through virtual try-on (VTON) technology that renders garments realistically on a person.
IMAGDressing defines a new virtual dressing task focused on generating freely editable human images with a fixed garment and optional conditions, and designs a comprehensive affinity metric to evaluate the consistency between the generated image and the reference garment.
In addition, IMAGDressing-v1 incorporates a garment UNet that captures semantic features from CLIP and texture features from a VAE, and introduces a hybrid attention module consisting of frozen self-attention and trainable cross-attention to inject the garment features into a frozen denoising UNet, so that users can still control different scenes through text.
GitHub project: https://github.com/muzishen/IMAGDressing
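To make the hybrid attention idea concrete, here is a minimal, hypothetical sketch (not the project's actual implementation; the class and argument names are assumptions) of combining a frozen self-attention branch with a trainable cross-attention branch that attends to garment features:

import torch
import torch.nn as nn

class HybridAttentionSketch(nn.Module):
    """Toy illustration: frozen self-attention plus trainable garment cross-attention."""
    def __init__(self, dim, num_heads=8):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        # Freeze the self-attention branch; only the cross-attention branch is trained.
        for p in self.self_attn.parameters():
            p.requires_grad = False

    def forward(self, hidden_states, garment_features, scale=1.0):
        # Frozen self-attention over the denoising UNet's own features.
        sa, _ = self.self_attn(hidden_states, hidden_states, hidden_states)
        # Trainable cross-attention that injects features cached by the garment UNet.
        ca, _ = self.cross_attn(hidden_states, garment_features, garment_features)
        return sa + scale * ca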
I. Environment Setup
1. Python environment
Python 3.10 or later is recommended.
2. Installing pip packages
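For example, a fresh conda environment can be created first (the environment name here is arbitrary):
conda create -n IMAGDressing python=3.10
conda activate IMAGDressing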
pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu118
pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
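After installation, you can quickly confirm that the CUDA build of PyTorch is active:
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"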
3. Downloading the IMAGDressing model
git lfs install
git clone https://huggingface.co/feishen29/IMAGDressing
4. Downloading the sd-vae-ft-mse model
git lfs install
git clone https://huggingface.co/stabilityai/sd-vae-ft-mse
5. Downloading the Realistic_Vision_V4.0_noVAE model
git lfs install
git clone https://huggingface.co/SG161222/Realistic_Vision_V4.0_noVAE
6. Downloading the IP-Adapter-FaceID model
git lfs install
git clone https://huggingface.co/h94/IP-Adapter-FaceID
7. Downloading the control_v11p_sd15_openpose model
git lfs install
git clone https://huggingface.co/lllyasviel/control_v11p_sd15_openpose
8. Downloading the IP-Adapter model
git lfs install
git clone https://huggingface.co/h94/IP-Adapter
9. Downloading the IDM-VTON model
git lfs install
git clone https://huggingface.co/spaces/yisol/IDM-VTON
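The test script below references these weights through "path/to/..." placeholders, so it is worth checking the local clones before running. A minimal check, assuming every repository was cloned under a single ckpt/ directory (this layout is an assumption; adjust the paths to wherever you cloned them):

import os

# Hypothetical layout: every repository cloned under ./ckpt; adjust to your setup.
required = [
    "ckpt/IMAGDressing",
    "ckpt/sd-vae-ft-mse",
    "ckpt/Realistic_Vision_V4.0_noVAE",
    "ckpt/IP-Adapter-FaceID",
    "ckpt/control_v11p_sd15_openpose",
    "ckpt/IP-Adapter",
    "ckpt/IDM-VTON",
]
missing = [p for p in required if not os.path.isdir(p)]
print("Missing:", missing if missing else "none, all model folders found")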
II. Functional Testing
1. Command-line test:
(1) Python test with a specified garment (replace the "path/to/..." placeholders with the local paths of the models downloaded above):
import os
import torch
from PIL import Image
from diffusers import UNet2DConditionModel, AutoencoderKL, DDIMScheduler
from torchvision import transforms
from transformers import CLIPImageProcessor
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from adapter.attention_processor import CacheAttnProcessor2_0, RefSAttnProcessor2_0, CAttnProcessor2_0
import argparse
from adapter.resampler import Resampler
from dressing_sd.pipelines.IMAGDressing_v1_pipeline import IMAGDressing_v1
def resize_img(input_image, max_side=640, min_side=512, size=None,
               pad_to_max_side=False, mode=Image.BILINEAR, base_pixel_number=64):
    # Scale so the short side reaches min_side, clamp the long side to max_side,
    # then round both sides down to a multiple of base_pixel_number.
    w, h = input_image.size
    ratio = min_side / min(h, w)
    w, h = round(ratio * w), round(ratio * h)
    ratio = max_side / max(h, w)
    input_image = input_image.resize([round(ratio * w), round(ratio * h)], mode)
    w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number
    h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number
    input_image = input_image.resize([w_resize_new, h_resize_new], mode)
    return input_image
def image_grid(imgs, rows, cols):
    # Paste the images into a rows x cols grid for side-by-side comparison.
    assert len(imgs) == rows * cols
    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid
def prepare(args):
    generator = torch.Generator(device=args.device).manual_seed(42)

    # Frozen components: VAE, tokenizer, text encoder, CLIP image encoder, and the denoising UNet.
    vae = AutoencoderKL.from_pretrained("path/to/sd-vae-ft-mse").to(dtype=torch.float16, device=args.device)
    tokenizer = CLIPTokenizer.from_pretrained("path/to/Realistic_Vision_V4.0_noVAE", subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained("path/to/Realistic_Vision_V4.0_noVAE", subfolder="text_encoder").to(dtype=torch.float16, device=args.device)
    image_encoder = CLIPVisionModelWithProjection.from_pretrained("path/to/IP-Adapter", subfolder="models/image_encoder").to(dtype=torch.float16, device=args.device)
    unet = UNet2DConditionModel.from_pretrained("path/to/Realistic_Vision_V4.0_noVAE", subfolder="unet").to(dtype=torch.float16, device=args.device)

    # Resampler that projects CLIP image embeddings into the UNet's cross-attention space.
    image_proj = Resampler(
        dim=unet.config.cross_attention_dim,
        depth=4,
        dim_head=64,
        heads=12,
        num_queries=16,
        embedding_dim=image_encoder.config.hidden_size,
        output_dim=unet.config.cross_attention_dim,
        ff_mult=4,
    ).to(dtype=torch.float16, device=args.device)

    # Replace the denoising UNet's attention processors: reference self-attention for attn1,
    # garment cross-attention for attn2.
    attn_procs = {}
    for name in unet.attn_processors.keys():
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]
        if cross_attention_dim is None:
            attn_procs[name] = RefSAttnProcessor2_0(name, hidden_size)
        else:
            attn_procs[name] = CAttnProcessor2_0(name, hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
    unet.set_attn_processor(attn_procs)
    adapter_modules = torch.nn.ModuleList(unet.attn_processors.values()).to(dtype=torch.float16, device=args.device)

    # The garment (reference) UNet caches its self-attention features for the main UNet to read.
    ref_unet = UNet2DConditionModel.from_pretrained("path/to/Realistic_Vision_V4.0_noVAE", subfolder="unet").to(dtype=torch.float16, device=args.device)
    ref_unet.set_attn_processor({name: CacheAttnProcessor2_0() for name in ref_unet.attn_processors.keys()})

    # Load the IMAGDressing checkpoint and split it into the sub-module state dicts.
    model_sd = torch.load(args.model_ckpt, map_location="cpu")["module"]
    ref_unet_dict = {}
    unet_dict = {}
    image_proj_dict = {}
    adapter_modules_dict = {}
    for k, v in model_sd.items():
        if k.startswith("ref_unet"):
            ref_unet_dict[k.replace("ref_unet.", "")] = v
        elif k.startswith("unet"):
            unet_dict[k.replace("unet.", "")] = v
        elif k.startswith("proj"):
            image_proj_dict[k.replace("proj.", "")] = v
        elif k.startswith("adapter_modules"):
            adapter_modules_dict[k.replace("adapter_modules.", "")] = v
    ref_unet.load_state_dict(ref_unet_dict)
    image_proj.load_state_dict(image_proj_dict)
    adapter_modules.load_state_dict(adapter_modules_dict)

    noise_scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", steps_offset=1)
    pipe = IMAGDressing_v1(unet=unet, reference_unet=ref_unet, vae=vae, tokenizer=tokenizer,
                           text_encoder=text_encoder, image_encoder=image_encoder,
                           ImgProj=image_proj, scheduler=noise_scheduler,
                           safety_checker=StableDiffusionSafetyChecker,
                           feature_extractor=CLIPImageProcessor())
    return pipe, generator
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='IMAGDressing_v1')
    parser.add_argument('--model_ckpt', default="path/to/IMAGDressing-v1_512.pt", type=str)
    parser.add_argument('--cloth_path', type=str, required=True)
    parser.add_argument('--output_path', type=str, default="./output_sd_base")
    parser.add_argument('--device', type=str, default="cuda:0")
    args = parser.parse_args()

    output_path = args.output_path
    os.makedirs(output_path, exist_ok=True)

    pipe, generator = prepare(args)
    print('====================== Pipe loaded successfully ===================')

    num_samples = 1
    clip_image_processor = CLIPImageProcessor()
    img_transform = transforms.Compose([
        transforms.Resize([640, 512], interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])

    prompt = 'A beautiful woman, best quality, high quality'
    null_prompt = ''
    negative_prompt = 'bare, naked, nude, undressed, monochrome, lowres, bad anatomy, worst quality, low quality'

    # Prepare the garment image for both the VAE branch and the CLIP image encoder branch.
    clothes_img = Image.open(args.cloth_path).convert("RGB")
    clothes_img = resize_img(clothes_img)
    vae_clothes = img_transform(clothes_img).unsqueeze(0).to(args.device)
    ref_clip_image = clip_image_processor(images=clothes_img, return_tensors="pt").pixel_values.to(args.device)

    output = pipe(ref_image=vae_clothes, prompt=prompt, ref_clip_image=ref_clip_image,
                  null_prompt=null_prompt, negative_prompt=negative_prompt,
                  width=512, height=640, num_images_per_prompt=num_samples,
                  guidance_scale=7.5, image_scale=1.0, generator=generator,
                  num_inference_steps=50).images

    # Save the garment and the generated result side by side as a 1x2 grid.
    save_output = [clothes_img.resize((512, 640), Image.BICUBIC)]
    save_output.append(output[0])
    grid = image_grid(save_output, 1, 2)
    grid.save(os.path.join(output_path, os.path.basename(args.cloth_path)))
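Assuming the script above is saved as inference_IMAGdressing.py (the filename and the garment path below are only examples), it can be run from the command line as:
python inference_IMAGdressing.py --model_ckpt path/to/IMAGDressing-v1_512.pt --cloth_path path/to/garment.jpg --output_path ./output_sd_base
The result is written to the output directory as a 1x2 grid, with the garment on the left and the generated dressed image on the right.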
To be continued...
For more details, follow: 杰哥新技术