利用clip模型实现text2draw

news2024/9/19 10:34:30

参考论文

实践

有数据增强的代码

import math
import collections
import CLIP_.clip as clip
import torch
import torch.nn as nn
from torchvision import models, transforms
import numpy as np
import webp
from PIL import Image
import skimage
import torchvision
import pydiffvg
import os
import torch.nn.functional as F

class GeometrymatchLoss(torch.nn.Module):
    def __init__(self, device, reference_images_path):
        super(GeometrymatchLoss, self).__init__()
        self.device = device
        self.model, clip_preprocess = clip.load(
            'ViT-B/32', self.device, jit=False)
        self.model.eval()
        self.preprocess = transforms.Compose(
            [clip_preprocess.transforms[0], clip_preprocess.transforms[-1]])  # clip normalisation
        self.reference_images_feature = self.reference_images_feature(reference_images_path)
        self.reference_images_feature =self.reference_images_feature/ self.reference_images_feature.norm(dim=-1, keepdim=True)
        self.text = clip.tokenize([ "A picture of triangle"]).to(device)
        self.text_features = self.model.encode_text(self.text)
        # self.text_features = self.text_features / self.text_features.norm(dim=-1, keepdim=True)
        print("text_features.requires_grad:",self.text_features.requires_grad)
        self.text_features=self.text_features.detach()
        self.shape_groups=[pydiffvg.ShapeGroup(shape_ids=torch.tensor([0]), fill_color=torch.tensor([0.0, 0.0, 0.0, 1.0]),
                                       stroke_color=torch.tensor([0.0, 0.0, 0.0, 1.0]))]

        # Image Augmentation Transformation
        self.augment_trans = transforms.Compose([
            transforms.RandomPerspective(fill=1, p=1, distortion_scale=0.5),
            transforms.RandomResizedCrop(224, scale=(0.7, 0.9)),
        ])



    def forward(self, t,canvas_width, canvas_height,shapes):

        scene_args = pydiffvg.RenderFunction.serialize_scene(canvas_width, canvas_height, shapes, self.shape_groups)
        # 渲染图像
        render = pydiffvg.RenderFunction.apply
        target = render(canvas_width, canvas_height, 2, 2, 0, None, *scene_args)

        if target.shape[-1] == 4:
            target = self.compose_image_with_white_background(target)
        if t%100==0:
            pydiffvg.imwrite(target.cpu(), f'learn/log_augs/output_{t}.png', gamma=2.2)
        # targets_ = self.preprocess(target.permute(2, 0, 1).unsqueeze(0)).to(self.device)
        img = target.unsqueeze(0)
        img = img.permute(0, 3, 1, 2)
        loss = 0
        NUM_AUGS = 4
        img_augs = []
        for n in range(NUM_AUGS):
            img_augs.append(self.augment_trans(img))
        im_batch = torch.cat(img_augs)
        image_features = self.model.encode_image(im_batch)
        # logit_scale = self.model.logit_scale.exp()
        for n in range(NUM_AUGS):
            loss -= torch.cosine_similarity(self.text_features, image_features[n:n + 1], dim=1)
        return loss


    def compose_image_with_white_background(self, img: torch.tensor) -> torch.tensor:
        if img.shape[-1] == 3:  # return img if it is already rgb
            return img
        # Compose img with white background
        alpha = img[:, :, 3:4]
        img = alpha * img[:, :, :3] + (1 - alpha) * torch.ones(
            img.shape[0], img.shape[1], 3, device=self.device)
        return img

    def read_png_image_from_path(self, path_to_png_image: str) -> torch.tensor:
        numpy_image = skimage.io.imread(path_to_png_image)
        normalized_tensor_image = torch.from_numpy(numpy_image).to(
            torch.float32) / 255.0

        resizer = torchvision.transforms.Resize((224, 224))
        resized_image = resizer(normalized_tensor_image.permute(2, 0, 1)
                                ).permute(1, 2, 0)
        return resized_image

    def reference_images_feature(self, reference_images_path):
        reference_images_num = len(os.listdir(reference_images_path))
        reference_images_feature = []
        for i in range(reference_images_num):
            i_reference_image = self.read_png_image_from_path(os.path.join(reference_images_path, str(i) + ".png"))
            if i_reference_image.shape[-1] == 4:
                i_reference_image = self.compose_image_with_white_background(i_reference_image)
            # targets_ = self.preprocess(i_reference_image.permute(2, 0, 1).unsqueeze(0)).to(self.device)
            i_reference_image_features = self.model.encode_image(i_reference_image.permute(2, 0, 1).unsqueeze(0).to(self.device)).detach()
            reference_images_feature.append(i_reference_image_features)
        return torch.cat(reference_images_feature)


def read_png_image_from_path(path_to_png_image: str) -> torch.tensor:
    if path_to_png_image.endswith('.webp'):
        numpy_image = np.array(webp.load_image(path_to_png_image))
    else:
        numpy_image = skimage.io.imread(path_to_png_image)
    normalized_tensor_image = torch.from_numpy(numpy_image).to(
        torch.float32) / 255.0

    resizer = torchvision.transforms.Resize((224, 224))
    resized_image = resizer(normalized_tensor_image.permute(2, 0, 1)
                            ).permute(1, 2, 0)
    return resized_image


if __name__ == '__main__':
    torch.autograd.set_detect_anomaly(True)
    from tqdm import tqdm
    def get_bezier_circle(radius: float = 80,
                          segments: int = 4,
                          bias: np.array = np.asarray([100., 100.])):
        deg = torch.arange(0, segments * 3 + 1) * 2 * np.pi / (segments * 3 + 1)
        points = torch.stack((torch.cos(deg), torch.sin(deg))).T
        points = points * radius + torch.tensor(bias).unsqueeze(dim=0)
        points = points.type(torch.FloatTensor).contiguous()
        return points
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    matchLoss = GeometrymatchLoss(device, "reference_images/")
    # print(matchLoss.reference_images_feature.shape)
    # img1 = read_png_image_from_path('learn/output.png')
    canvas_width, canvas_height = 224, 224
    num_segments=4

    points1 = get_bezier_circle()

    path = pydiffvg.Path(num_control_points=torch.tensor(num_segments * [2] + [0],dtype=torch.int32), points=points1, stroke_width=torch.tensor(2.0),
                          is_closed=True)
    shapes=[path]
    path.points.requires_grad = True
    print(id(path.points))
    print(id(points1))
    points_vars = []
    points_vars.append(path.points)
    points_optim = torch.optim.Adam(points_vars, lr=1)
    pbar = tqdm(range(100000))
    print(points1)
    for t in pbar:
        # print(t)
        points_optim.zero_grad()
        # print("match_loss:", match_loss)
        match_loss = matchLoss(t,224, 224, shapes)

        match_loss.backward()
        # print(path.points.grad)
        points_optim.step()
        pbar.set_postfix({"match_loss": f"{match_loss.item()}"})
        # print(points_vars[0])


    pass

迭代1000轮次后生成的结果
在这里插入图片描述

没有图像增强

import math
import collections
import CLIP_.clip as clip
import torch
import torch.nn as nn
from torchvision import models, transforms
import numpy as np
import webp
from PIL import Image
import skimage
import torchvision
import pydiffvg
import os
import torch.nn.functional as F

class GeometrymatchLoss(torch.nn.Module):
    def __init__(self, device, reference_images_path):
        super(GeometrymatchLoss, self).__init__()
        self.device = device
        self.model, clip_preprocess = clip.load(
            'ViT-B/32', self.device, jit=False)
        self.model.eval()
        self.preprocess = transforms.Compose(
            [clip_preprocess.transforms[0], clip_preprocess.transforms[-1]])  # clip normalisation
        # self.preprocess = transforms.Compose([clip_preprocess.transforms[-1]])  # clip normalisation
        self.reference_images_feature = self.reference_images_feature(reference_images_path)
        self.reference_images_feature =self.reference_images_feature/ self.reference_images_feature.norm(dim=-1, keepdim=True)
        self.text = clip.tokenize([ "A picture of triangle"]).to(device)
        # self.text = clip.tokenize(["A picture of rectangle", "A picture of triangle", "A picture of circle", "A picture of pentagon","A picture of five-pointed star"]).to(device)
        self.text_features = self.model.encode_text(self.text)
        self.text_features = self.text_features / self.text_features.norm(dim=-1, keepdim=True)
        print("text_features.requires_grad:",self.text_features.requires_grad)
        self.text_features=self.text_features.detach()
        self.shape_groups=[pydiffvg.ShapeGroup(shape_ids=torch.tensor([0]), fill_color=torch.tensor([0.0, 0.0, 0.0, 1.0]),
                                       stroke_color=torch.tensor([0.0, 0.0, 0.0, 1.0]))]

        # Image Augmentation Transformation
        self.augment_trans = transforms.Compose([
            transforms.RandomPerspective(fill=1, p=1, distortion_scale=0.5),
            transforms.RandomResizedCrop(224, scale=(0.7, 0.9)),
        ])



    def forward(self, t,canvas_width, canvas_height,shapes):

        scene_args = pydiffvg.RenderFunction.serialize_scene(canvas_width, canvas_height, shapes, self.shape_groups)
        # 渲染图像
        render = pydiffvg.RenderFunction.apply
        target = render(canvas_width, canvas_height, 2, 2, 0, None, *scene_args)

        if target.shape[-1] == 4:
            target = self.compose_image_with_white_background(target)
        if t%100==0:
            pydiffvg.imwrite(target.cpu(), f'learn/log/output_{t}.png', gamma=2.2)
        # targets_ = self.preprocess(target.permute(2, 0, 1).unsqueeze(0)).to(self.device)
        img = target.unsqueeze(0)
        img = img.permute(0, 3, 1, 2)
        loss = 0
        NUM_AUGS = 4
        img_augs = []
        for n in range(NUM_AUGS):
            img_augs.append(self.augment_trans(img))
        im_batch = torch.cat(img_augs)
        image_features = self.model.encode_image(img)
        self.targets_features: torch.tensor=image_features[0]
        self.targets_features = self.targets_features / self.targets_features.norm(dim=-1, keepdim=True)
        loss -= torch.cosine_similarity(self.text_features, self.targets_features, dim=1)

        return loss


    def compose_image_with_white_background(self, img: torch.tensor) -> torch.tensor:
        if img.shape[-1] == 3:  # return img if it is already rgb
            return img
        # Compose img with white background
        alpha = img[:, :, 3:4]
        img = alpha * img[:, :, :3] + (1 - alpha) * torch.ones(
            img.shape[0], img.shape[1], 3, device=self.device)
        return img

    def read_png_image_from_path(self, path_to_png_image: str) -> torch.tensor:
        numpy_image = skimage.io.imread(path_to_png_image)
        normalized_tensor_image = torch.from_numpy(numpy_image).to(
            torch.float32) / 255.0

        resizer = torchvision.transforms.Resize((224, 224))
        resized_image = resizer(normalized_tensor_image.permute(2, 0, 1)
                                ).permute(1, 2, 0)
        return resized_image

    def reference_images_feature(self, reference_images_path):
        reference_images_num = len(os.listdir(reference_images_path))
        reference_images_feature = []
        for i in range(reference_images_num):
            i_reference_image = self.read_png_image_from_path(os.path.join(reference_images_path, str(i) + ".png"))
            if i_reference_image.shape[-1] == 4:
                i_reference_image = self.compose_image_with_white_background(i_reference_image)
            # targets_ = self.preprocess(i_reference_image.permute(2, 0, 1).unsqueeze(0)).to(self.device)
            i_reference_image_features = self.model.encode_image(i_reference_image.permute(2, 0, 1).unsqueeze(0).to(self.device)).detach()
            reference_images_feature.append(i_reference_image_features)
        return torch.cat(reference_images_feature)


def read_png_image_from_path(path_to_png_image: str) -> torch.tensor:
    if path_to_png_image.endswith('.webp'):
        numpy_image = np.array(webp.load_image(path_to_png_image))
    else:
        numpy_image = skimage.io.imread(path_to_png_image)
    normalized_tensor_image = torch.from_numpy(numpy_image).to(
        torch.float32) / 255.0

    resizer = torchvision.transforms.Resize((224, 224))
    resized_image = resizer(normalized_tensor_image.permute(2, 0, 1)
                            ).permute(1, 2, 0)
    return resized_image


if __name__ == '__main__':
    torch.autograd.set_detect_anomaly(True)
    from tqdm import tqdm
    def get_bezier_circle(radius: float = 80,
                          segments: int = 4,
                          bias: np.array = np.asarray([100., 100.])):
        deg = torch.arange(0, segments * 3 + 1) * 2 * np.pi / (segments * 3 + 1)
        points = torch.stack((torch.cos(deg), torch.sin(deg))).T
        points = points * radius + torch.tensor(bias).unsqueeze(dim=0)
        points = points.type(torch.FloatTensor).contiguous()
        return points
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    matchLoss = GeometrymatchLoss(device, "reference_images/")
    # print(matchLoss.reference_images_feature.shape)
    # img1 = read_png_image_from_path('learn/output.png')
    canvas_width, canvas_height = 224, 224
    num_segments=4

    points1 = get_bezier_circle()

    path = pydiffvg.Path(num_control_points=torch.tensor(num_segments * [2] + [0],dtype=torch.int32), points=points1, stroke_width=torch.tensor(2.0),
                          is_closed=True)
    shapes=[path]
    path.points.requires_grad = True
    print(id(path.points))
    print(id(points1))
    points_vars = []
    points_vars.append(path.points)
    points_optim = torch.optim.Adam(points_vars, lr=1)
    pbar = tqdm(range(100000))
    print(points1)
    for t in pbar:
        # print(t)
        points_optim.zero_grad()
        # print("match_loss:", match_loss)
        match_loss = matchLoss(t,224, 224, shapes)

        match_loss.backward()
        # print(path.points.grad)
        points_optim.step()
        pbar.set_postfix({"match_loss": f"{match_loss.item()}"})
        # print(points_vars[0])


    pass

迭代1000轮次后生成的结果
在这里插入图片描述
迭代2000轮次后生成的结果
在这里插入图片描述
迭代4000轮次后生成的结果
在这里插入图片描述
迭代8000轮次后生成的结果
在这里插入图片描述

无图像增强效果不好的原因分析

论文CLIPDraw: Exploring Text-to-Drawing Synthesisthrough Language-Image Encoders解释

在这里插入图片描述

论文StyleCLIPDraw: Coupling Content and Style in Text-to-Drawing Translation解释

在这里插入图片描述

个人理解

因为有很多图片可以和一个文本相匹配,对于我们人来说这些图片有一个根本和文本不相关,如果进行图像增强大概率会得到局部最优值。在计算损失函数之前对图片先进行增强,透过透视等变换,相关的图片不论如何变换和文本的相似度基本不会降低,而不相关的图像变换完之后一般会让相似度降低,这样就可以防止不相关图片对实验结果的影响。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2091818.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于单片机的楼宇消防控制系统设计

本设计基于单片机的楼宇消防控制系统,主要包括温湿度检测模块、空气质量检测模块、火焰检测模块、ZigBee通信模块、报警模块和自动喷水模块。首先,系统通过温湿度检测模块实时监测楼道内的温湿度状况,以便及时掌握火灾发生前的环境变化。其次…

足底筋膜炎怎么治疗效果好

足底筋膜炎的症状 足底筋膜炎是一种常见的足部疾病,主要表现为足底区域(尤其是脚跟附近)的疼痛和不适。这种疼痛在早晨起床或长时间休息后初次站立时尤为明显,被形象地称为“晨间痛”。随着行走时间的增加,疼痛可能会…

直击源头!劳保鞋厂家揭秘机械制造业防护安全鞋挑选秘籍

在机械制造业这一高风险、高强度的行业中,选择合适的劳保鞋对于保障工人的安全至关重要。作为劳保鞋的生产厂家,我们深知一双优质的防护鞋能为工人提供怎样的保护。今天百华小编和大家从多个维度看一下机械制造业是如何挑选防护安全鞋的挑选秘籍&#xf…

四款远程控制分享!你pick哪一款?

远程控制软件已经成为我们日常生活中不可或缺的一部分,无论是远程办公、技术支持还是家庭娱乐,这些软件都扮演着重要的角色。今天,我们就来聊聊几款市面上比较热门的远程控制软件在电脑远程操作上都有哪些表现呢?让我们一探究竟。…

ArkTS语法题

1. 下面示例中会导致编译报错的有? A. let x: number null; B. let x: number | null null; C. let y: string null; D. let y: string 100; 看来GPT对这种标准概念选择,也没有统一的说法。 - 文心…

【3.8】贪心算法-解无重叠区间

一、题目 给定一个区间的集合 intervals ,其中 intervals[i] [starti, endi] 。返回 需要移除区间的最小数量,使剩余区间互不重叠 。 示例 1: 输入: intervals [[1,2],[2,3],[3,4],[1,3]] 输出: 1 解释: 移除 [1,3] 后,剩下的区间没有重叠…

ImportError: Missing optional dependency ‘openpyxl‘.报错已解决

🎬 鸽芷咕:个人主页 🔥 个人专栏: 《C干货基地》《粉丝福利》 ⛺️生活的理想,就是为了理想的生活! 引言: 在开发过程中,你是否遇到过导入模块时出现ImportError: Missing optional dependency openpyxl的…

绘制YOLOv9训练结果的mAP0.5变化曲线

本文绘制mAP0.5在训练过程中的变化曲线(Python脚本),用以比较不同算法的收敛速度,最终精度等,并且能够在论文中直观的展示改进效果。 以下是比较了三个模型的mAP0.5变化曲线,数据来源是直接读取三个训练完…

Flink1.14.* 各种算子在StreamTask控制下如何调用的源码

前言:一、StreamTask执行算子的生命周期二、 Source的streamTask用的是SourceStreamTask三、基础转换操作,窗口用的是OneInputStreamTask1、初始化OneInputStreamTask2、StreamTask运行invoke调用的是StreamTask的processInput方法3、从缓冲区获取数据放…

从0到DevOps(1)-初步了解DevOps和容器

DevOps从提出以来陆续成为行业普遍实践,目前是数字化生产普遍不可或缺的信息底座。本系列文章旨在系统性的阐述与认识DevOps, 了解企业实践里DevOps的实际面貌。 什么是DevOps? DevOps 是一套实践、工具和文化理念,为实现用户不断的软件功能和可用性要…

学会这5个AI变现方法,让你在小红书上轻松赚钱!

大家好!最近AI真是大火,尤其是ChatGPT、Midjourney这些AI工具,感觉不搞点AI相关的内容,都跟不上潮流啦! 作为一个深耕小红书的内容创作者,我发现AI其实在小红书上有着巨大的变现潜力。 那么,如…

C--四种排序方法的补充

上一篇文章因为时间原因只写了三种,这一篇来补充第四种,第四种的代码更多,所需要理解的也是更多的。 堆排序 想要学会堆排序,你必须了解二叉树的内容。堆排序的排序速度也是非常的快。 这里都已大堆为例 1.向上调整算法&#…

JavaWeb - Spring Boot

Spring 官网​​​​​Spring | Home Spring Boot Spring Boot是一个由Pivotal团队提供的开源框架,旨在简化Spring应用的初始搭建以及开发过程。在Spring Boot项目中,通常会有Controller、Service、Mapper和Entity等层次结构。下面将详细介绍这些层次的…

Mac 安装Hadoop教程

1. 引言 本教程旨在介绍在Mac 电脑上安装Hadoop,便于编程开发人员对大数据技术的熟悉和掌握。 2.前提条件 2.1 安装JDK 想要在你的Mac电脑上安装Hadoop,你必须首先安装JDK。具体安装步骤这里就不详细描述了。你可参考Mac 下载JDK8。 2.2 配置ssh环境…

三分钟讲明白怎么用Fusion360和3D打印做模具

前言 模具,这东西听起来好像很常见,但是听到价格又很高大上,但是现在好消息是你可以在家里用3D打印方式实现一个模具,虽然是一个学习级的简易模具但是符合模具的9成要素 这里我们设计一个可以把热熔胶变成实物的模具 如何实现 1首…

生成密码c++

需求 目前需要实现生成8位密码,密码要求至少包含一位数字,一位大写字母,一位小写字母,一位特殊字符。如果用户第一次使用还没有输入密码,密码则为系统随机生成。 用户输入密码,符合规则则将默认密码覆盖掉…

重生之我们在ES顶端相遇第10 章- 分分分词器的基本使用

文章目录 思维导图0. 前言1. 光速上手1.1 指定分词器1.2 测试分词器 2. 分词流程(重要)2.1 基本介绍2.2 深入如何测试分词器 3. 自定义一个简单的分词器 思维导图 0. 前言 分词器在 ES 搜索使用中非常关键,一个好的分词器能够提高搜索的质量,让用户搜索…

进程间的通信(无名管道)

进程间通信 IPC InterProcess Communication 1.进程间通信方式 1.早期的进程间通信: 无名管道(pipe)、有名管道(fifo)、信号(signal) 2.system V PIC: 共享内存(share memory)、信号灯集(semaphore)、消息队列(message queue) 3.BSD: 套接字(socket) 2.无…

AI壁纸套装,单月变现7000+,手把手教你,别说你还不会

介绍 这种类型的手机壁纸,平板壁纸,电脑壁纸,甚至是手表壁纸,流量都很不错,尤其是深受一些女性的喜欢。 变现能力也不错,而且变现方式也多种多样。 今天就一步一步的教大家如何制作这种壁纸,怕…

本地部署 Flux.1 最强文生图大模型!Comfyui 一键安装

前言 最近,由前 Stability AI员工创立的黑森林实验室推出了开源文生图大模型–FLUX.1横空出世。 FLUX.1在文字生成、复杂指令遵循和人手生成上具备优势。以下是其生成图像示例,可以看到即使是生成大段的文字、多个人物,也没有出现字符、人手…