Reference papers
Practice
Code with data augmentation
import math
import collections
import CLIP_.clip as clip
import torch
import torch.nn as nn
from torchvision import models, transforms
import numpy as np
import webp
from PIL import Image
import skimage.io
import torchvision
import pydiffvg
import os
import torch.nn.functional as F
class GeometrymatchLoss(torch.nn.Module):
    def __init__(self, device, reference_images_path):
        super(GeometrymatchLoss, self).__init__()
        self.device = device
        self.model, clip_preprocess = clip.load('ViT-B/32', self.device, jit=False)
        self.model.eval()
        # keep only Resize and Normalize from CLIP's preprocess (the input is already a tensor)
        self.preprocess = transforms.Compose(
            [clip_preprocess.transforms[0], clip_preprocess.transforms[-1]])
        # note: after this assignment the attribute shadows the reference_images_feature method
        self.reference_images_feature = self.reference_images_feature(reference_images_path)
        self.reference_images_feature = self.reference_images_feature / self.reference_images_feature.norm(dim=-1, keepdim=True)
        self.text = clip.tokenize(["A picture of triangle"]).to(device)
        self.text_features = self.model.encode_text(self.text)
        # self.text_features = self.text_features / self.text_features.norm(dim=-1, keepdim=True)
        print("text_features.requires_grad:", self.text_features.requires_grad)
        self.text_features = self.text_features.detach()
        self.shape_groups = [pydiffvg.ShapeGroup(shape_ids=torch.tensor([0]),
                                                 fill_color=torch.tensor([0.0, 0.0, 0.0, 1.0]),
                                                 stroke_color=torch.tensor([0.0, 0.0, 0.0, 1.0]))]
        # image augmentation applied before the CLIP encoder
        self.augment_trans = transforms.Compose([
            transforms.RandomPerspective(fill=1, p=1, distortion_scale=0.5),
            transforms.RandomResizedCrop(224, scale=(0.7, 0.9)),
        ])

    def forward(self, t, canvas_width, canvas_height, shapes):
        scene_args = pydiffvg.RenderFunction.serialize_scene(canvas_width, canvas_height, shapes, self.shape_groups)
        # render the image: width, height, num_samples_x, num_samples_y, seed, background, *scene_args
        render = pydiffvg.RenderFunction.apply
        target = render(canvas_width, canvas_height, 2, 2, 0, None, *scene_args)
        if target.shape[-1] == 4:
            target = self.compose_image_with_white_background(target)
        if t % 100 == 0:
            pydiffvg.imwrite(target.cpu(), f'learn/log_augs/output_{t}.png', gamma=2.2)
        # targets_ = self.preprocess(target.permute(2, 0, 1).unsqueeze(0)).to(self.device)
        img = target.unsqueeze(0)
        img = img.permute(0, 3, 1, 2)  # HWC -> NCHW
        loss = 0
        NUM_AUGS = 4
        img_augs = []
        for n in range(NUM_AUGS):
            img_augs.append(self.augment_trans(img))
        im_batch = torch.cat(img_augs)
        image_features = self.model.encode_image(im_batch)
        # logit_scale = self.model.logit_scale.exp()
        # accumulate the negative cosine similarity over all augmented views
        for n in range(NUM_AUGS):
            loss -= torch.cosine_similarity(self.text_features, image_features[n:n + 1], dim=1)
        return loss

    def compose_image_with_white_background(self, img: torch.Tensor) -> torch.Tensor:
        if img.shape[-1] == 3:  # return img if it is already rgb
            return img
        # compose img with a white background using the alpha channel
        alpha = img[:, :, 3:4]
        img = alpha * img[:, :, :3] + (1 - alpha) * torch.ones(
            img.shape[0], img.shape[1], 3, device=self.device)
        return img

    def read_png_image_from_path(self, path_to_png_image: str) -> torch.Tensor:
        numpy_image = skimage.io.imread(path_to_png_image)
        normalized_tensor_image = torch.from_numpy(numpy_image).to(torch.float32) / 255.0
        resizer = torchvision.transforms.Resize((224, 224))
        resized_image = resizer(normalized_tensor_image.permute(2, 0, 1)).permute(1, 2, 0)
        return resized_image

    def reference_images_feature(self, reference_images_path):
        reference_images_num = len(os.listdir(reference_images_path))
        reference_images_feature = []
        for i in range(reference_images_num):
            i_reference_image = self.read_png_image_from_path(os.path.join(reference_images_path, str(i) + ".png"))
            if i_reference_image.shape[-1] == 4:
                i_reference_image = self.compose_image_with_white_background(i_reference_image)
            # targets_ = self.preprocess(i_reference_image.permute(2, 0, 1).unsqueeze(0)).to(self.device)
            i_reference_image_features = self.model.encode_image(
                i_reference_image.permute(2, 0, 1).unsqueeze(0).to(self.device)).detach()
            reference_images_feature.append(i_reference_image_features)
        return torch.cat(reference_images_feature)


def read_png_image_from_path(path_to_png_image: str) -> torch.Tensor:
    if path_to_png_image.endswith('.webp'):
        numpy_image = np.array(webp.load_image(path_to_png_image))
    else:
        numpy_image = skimage.io.imread(path_to_png_image)
    normalized_tensor_image = torch.from_numpy(numpy_image).to(torch.float32) / 255.0
    resizer = torchvision.transforms.Resize((224, 224))
    resized_image = resizer(normalized_tensor_image.permute(2, 0, 1)).permute(1, 2, 0)
    return resized_image
if __name__ == '__main__':
    torch.autograd.set_detect_anomaly(True)
    from tqdm import tqdm

    def get_bezier_circle(radius: float = 80,
                          segments: int = 4,
                          bias: np.array = np.asarray([100., 100.])):
        # sample segments * 3 + 1 points evenly spaced on a circle (anchors + Bezier control points)
        deg = torch.arange(0, segments * 3 + 1) * 2 * np.pi / (segments * 3 + 1)
        points = torch.stack((torch.cos(deg), torch.sin(deg))).T
        points = points * radius + torch.tensor(bias).unsqueeze(dim=0)
        points = points.type(torch.FloatTensor).contiguous()
        return points

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    matchLoss = GeometrymatchLoss(device, "reference_images/")
    # print(matchLoss.reference_images_feature.shape)
    # img1 = read_png_image_from_path('learn/output.png')
    canvas_width, canvas_height = 224, 224
    num_segments = 4
    points1 = get_bezier_circle()
    path = pydiffvg.Path(num_control_points=torch.tensor(num_segments * [2] + [0], dtype=torch.int32),
                         points=points1,
                         stroke_width=torch.tensor(2.0),
                         is_closed=True)
    shapes = [path]
    path.points.requires_grad = True
    print(id(path.points))
    print(id(points1))
    points_vars = []
    points_vars.append(path.points)
    points_optim = torch.optim.Adam(points_vars, lr=1)
    pbar = tqdm(range(100000))
    print(points1)
    for t in pbar:
        # print(t)
        points_optim.zero_grad()
        match_loss = matchLoss(t, 224, 224, shapes)
        match_loss.backward()
        # print(path.points.grad)
        points_optim.step()
        pbar.set_postfix({"match_loss": f"{match_loss.item()}"})
        # print(points_vars[0])
Result generated after 1,000 iterations
Code without image augmentation
import math
import collections
import CLIP_.clip as clip
import torch
import torch.nn as nn
from torchvision import models, transforms
import numpy as np
import webp
from PIL import Image
import skimage.io
import torchvision
import pydiffvg
import os
import torch.nn.functional as F
class GeometrymatchLoss(torch.nn.Module):
    def __init__(self, device, reference_images_path):
        super(GeometrymatchLoss, self).__init__()
        self.device = device
        self.model, clip_preprocess = clip.load('ViT-B/32', self.device, jit=False)
        self.model.eval()
        # keep only Resize and Normalize from CLIP's preprocess (the input is already a tensor)
        self.preprocess = transforms.Compose(
            [clip_preprocess.transforms[0], clip_preprocess.transforms[-1]])
        # self.preprocess = transforms.Compose([clip_preprocess.transforms[-1]])  # clip normalisation only
        # note: after this assignment the attribute shadows the reference_images_feature method
        self.reference_images_feature = self.reference_images_feature(reference_images_path)
        self.reference_images_feature = self.reference_images_feature / self.reference_images_feature.norm(dim=-1, keepdim=True)
        self.text = clip.tokenize(["A picture of triangle"]).to(device)
        # self.text = clip.tokenize(["A picture of rectangle", "A picture of triangle", "A picture of circle", "A picture of pentagon", "A picture of five-pointed star"]).to(device)
        self.text_features = self.model.encode_text(self.text)
        self.text_features = self.text_features / self.text_features.norm(dim=-1, keepdim=True)
        print("text_features.requires_grad:", self.text_features.requires_grad)
        self.text_features = self.text_features.detach()
        self.shape_groups = [pydiffvg.ShapeGroup(shape_ids=torch.tensor([0]),
                                                 fill_color=torch.tensor([0.0, 0.0, 0.0, 1.0]),
                                                 stroke_color=torch.tensor([0.0, 0.0, 0.0, 1.0]))]
        # image augmentation transformation (defined but not applied in this variant)
        self.augment_trans = transforms.Compose([
            transforms.RandomPerspective(fill=1, p=1, distortion_scale=0.5),
            transforms.RandomResizedCrop(224, scale=(0.7, 0.9)),
        ])

    def forward(self, t, canvas_width, canvas_height, shapes):
        scene_args = pydiffvg.RenderFunction.serialize_scene(canvas_width, canvas_height, shapes, self.shape_groups)
        # render the image: width, height, num_samples_x, num_samples_y, seed, background, *scene_args
        render = pydiffvg.RenderFunction.apply
        target = render(canvas_width, canvas_height, 2, 2, 0, None, *scene_args)
        if target.shape[-1] == 4:
            target = self.compose_image_with_white_background(target)
        if t % 100 == 0:
            pydiffvg.imwrite(target.cpu(), f'learn/log/output_{t}.png', gamma=2.2)
        # targets_ = self.preprocess(target.permute(2, 0, 1).unsqueeze(0)).to(self.device)
        img = target.unsqueeze(0)
        img = img.permute(0, 3, 1, 2)  # HWC -> NCHW
        loss = 0
        NUM_AUGS = 4
        img_augs = []
        for n in range(NUM_AUGS):
            img_augs.append(self.augment_trans(img))
        im_batch = torch.cat(img_augs)
        # the augmented batch above is built but never used:
        # this variant encodes the un-augmented rendering directly
        image_features = self.model.encode_image(img)
        self.targets_features: torch.Tensor = image_features[0]
        self.targets_features = self.targets_features / self.targets_features.norm(dim=-1, keepdim=True)
        loss -= torch.cosine_similarity(self.text_features, self.targets_features, dim=1)
        return loss

    def compose_image_with_white_background(self, img: torch.Tensor) -> torch.Tensor:
        if img.shape[-1] == 3:  # return img if it is already rgb
            return img
        # compose img with a white background using the alpha channel
        alpha = img[:, :, 3:4]
        img = alpha * img[:, :, :3] + (1 - alpha) * torch.ones(
            img.shape[0], img.shape[1], 3, device=self.device)
        return img

    def read_png_image_from_path(self, path_to_png_image: str) -> torch.Tensor:
        numpy_image = skimage.io.imread(path_to_png_image)
        normalized_tensor_image = torch.from_numpy(numpy_image).to(torch.float32) / 255.0
        resizer = torchvision.transforms.Resize((224, 224))
        resized_image = resizer(normalized_tensor_image.permute(2, 0, 1)).permute(1, 2, 0)
        return resized_image

    def reference_images_feature(self, reference_images_path):
        reference_images_num = len(os.listdir(reference_images_path))
        reference_images_feature = []
        for i in range(reference_images_num):
            i_reference_image = self.read_png_image_from_path(os.path.join(reference_images_path, str(i) + ".png"))
            if i_reference_image.shape[-1] == 4:
                i_reference_image = self.compose_image_with_white_background(i_reference_image)
            # targets_ = self.preprocess(i_reference_image.permute(2, 0, 1).unsqueeze(0)).to(self.device)
            i_reference_image_features = self.model.encode_image(
                i_reference_image.permute(2, 0, 1).unsqueeze(0).to(self.device)).detach()
            reference_images_feature.append(i_reference_image_features)
        return torch.cat(reference_images_feature)


def read_png_image_from_path(path_to_png_image: str) -> torch.Tensor:
    if path_to_png_image.endswith('.webp'):
        numpy_image = np.array(webp.load_image(path_to_png_image))
    else:
        numpy_image = skimage.io.imread(path_to_png_image)
    normalized_tensor_image = torch.from_numpy(numpy_image).to(torch.float32) / 255.0
    resizer = torchvision.transforms.Resize((224, 224))
    resized_image = resizer(normalized_tensor_image.permute(2, 0, 1)).permute(1, 2, 0)
    return resized_image
if __name__ == '__main__':
    torch.autograd.set_detect_anomaly(True)
    from tqdm import tqdm

    def get_bezier_circle(radius: float = 80,
                          segments: int = 4,
                          bias: np.array = np.asarray([100., 100.])):
        # sample segments * 3 + 1 points evenly spaced on a circle (anchors + Bezier control points)
        deg = torch.arange(0, segments * 3 + 1) * 2 * np.pi / (segments * 3 + 1)
        points = torch.stack((torch.cos(deg), torch.sin(deg))).T
        points = points * radius + torch.tensor(bias).unsqueeze(dim=0)
        points = points.type(torch.FloatTensor).contiguous()
        return points

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    matchLoss = GeometrymatchLoss(device, "reference_images/")
    # print(matchLoss.reference_images_feature.shape)
    # img1 = read_png_image_from_path('learn/output.png')
    canvas_width, canvas_height = 224, 224
    num_segments = 4
    points1 = get_bezier_circle()
    path = pydiffvg.Path(num_control_points=torch.tensor(num_segments * [2] + [0], dtype=torch.int32),
                         points=points1,
                         stroke_width=torch.tensor(2.0),
                         is_closed=True)
    shapes = [path]
    path.points.requires_grad = True
    print(id(path.points))
    print(id(points1))
    points_vars = []
    points_vars.append(path.points)
    points_optim = torch.optim.Adam(points_vars, lr=1)
    pbar = tqdm(range(100000))
    print(points1)
    for t in pbar:
        # print(t)
        points_optim.zero_grad()
        match_loss = matchLoss(t, 224, 224, shapes)
        match_loss.backward()
        # print(path.points.grad)
        points_optim.step()
        pbar.set_postfix({"match_loss": f"{match_loss.item()}"})
        # print(points_vars[0])
Result generated after 1,000 iterations
Result generated after 2,000 iterations
Result generated after 4,000 iterations
Result generated after 8,000 iterations
Why the results are poor without image augmentation
Explanation from the paper CLIPDraw: Exploring Text-to-Drawing Synthesis through Language-Image Encoders
Explanation from the paper StyleCLIPDraw: Coupling Content and Style in Text-to-Drawing Translation
My own understanding
Many different images can score a high CLIP similarity against a single text prompt, and to a human some of those images have nothing to do with the text at all. Without augmentation, the optimization can easily drift toward such an adversarial image, i.e. a poor local optimum. If the rendered image is augmented before the loss is computed (random perspective, crops and similar transforms), an image that genuinely matches the text keeps roughly the same similarity under every transform, while an irrelevant image usually loses similarity once it is transformed. Averaging the loss over several augmented views therefore suppresses the influence of these irrelevant images on the final result.
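A minimal sketch of this argument in code (the names clip_model, text_features and augment are placeholders for the CLIP model, the detached text embedding and the RandomPerspective/RandomResizedCrop pipeline built in GeometrymatchLoss above): averaging the negative cosine similarity over several random views only rewards a drawing whose similarity to the prompt survives every view.

import torch

def augmented_clip_loss(img, clip_model, text_features, augment, num_augs=4):
    # img: rendered image as (1, 3, 224, 224); text_features: (1, 512) detached CLIP text embedding
    views = torch.cat([augment(img) for _ in range(num_augs)])  # (num_augs, 3, 224, 224)
    image_features = clip_model.encode_image(views)             # (num_augs, 512)
    # cosine_similarity normalises both sides, so an adversarial drawing that only
    # matches the text under one particular view scores poorly on average
    sims = torch.cosine_similarity(text_features, image_features, dim=1)
    return -sims.mean()

Setting num_augs to 1 and augment to the identity recovers the loss used in the no-augmentation script, which is exactly the case where a single adversarial view can dominate the optimization.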