Things to understand in the diffusers source code


1. Brief notes on code details when training DreamBooth

**When class_labels = timesteps, how does the model's forward pass work? Needs a deeper look.**
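A minimal sketch of what seems to happen inside diffusers' `UNet2DConditionModel.forward` in this case (dims are illustrative, not taken from the script): when the UNet was built with `class_embed_type="timestep"`, the class labels go through the same sinusoidal projection as the timesteps, then a learned embedding, and the result is added onto the time embedding.

```python
import torch
from diffusers.models.embeddings import TimestepEmbedding, Timesteps

# Illustrative dims (SD 1.5 uses 320 / 1280); the exact values don't matter here.
time_proj = Timesteps(num_channels=320, flip_sin_to_cos=True, downscale_freq_shift=0)
time_embedding = TimestepEmbedding(in_channels=320, time_embed_dim=1280)
class_embedding = TimestepEmbedding(in_channels=320, time_embed_dim=1280)

timesteps = torch.randint(0, 1000, (4,))
class_labels = timesteps  # what the training script passes

# Timestep branch: sinusoidal projection -> small MLP.
emb = time_embedding(time_proj(timesteps).float())            # (4, 1280)

# Class branch with class_embed_type="timestep": the labels take the SAME
# sinusoidal projection, then their own TimestepEmbedding MLP.
class_emb = class_embedding(time_proj(class_labels).float())  # (4, 1280)

# The two embeddings are summed; this combined vector conditions every
# ResNet block in the UNet, exactly like a plain time embedding would.
emb = emb + class_emb
```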

The class images are generated with class_prompt, not instance_prompt; a sketch of that step follows below.

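A minimal sketch of the class-image generation step (model name, path, and count are illustrative, not from the script): the generic class_prompt, without the rare subject token, is what drives the sampling, so the resulting images describe the class rather than the specific subject.

```python
import torch
from pathlib import Path
from diffusers import StableDiffusionPipeline

class_prompt = "a photo of a dog"          # generic: no "sks"-style subject token
class_data_dir = Path("class_images")      # illustrative path
class_data_dir.mkdir(parents=True, exist_ok=True)
num_new_images = 50                        # assumed target count

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

for i in range(num_new_images):
    # Each sample comes from the generic prompt only.
    image = pipe(class_prompt).images[0]
    image.save(class_data_dir / f"class_{i}.jpg")
```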

```python
class DreamBoothDataset(Dataset):
    """
    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
    It pre-processes the images and the tokenizes prompts.
    """

    def __init__(
        self,
        instance_data_root,
        instance_prompt,
        tokenizer,
        class_data_root=None,
        class_prompt=None,
        class_num=None,
        size=512,
        center_crop=False,
        encoder_hidden_states=None,
        class_prompt_encoder_hidden_states=None,
        tokenizer_max_length=None,
    ):
        self.size = size
        self.center_crop = center_crop
        self.tokenizer = tokenizer
        self.encoder_hidden_states = encoder_hidden_states
        self.class_prompt_encoder_hidden_states = class_prompt_encoder_hidden_states
        self.tokenizer_max_length = tokenizer_max_length

        self.instance_data_root = Path(instance_data_root)
        if not self.instance_data_root.exists():
            raise ValueError(f"Instance {self.instance_data_root} images root doesn't exists.")

        self.instance_images_path = list(Path(instance_data_root).iterdir())
        self.num_instance_images = len(self.instance_images_path)
        self.instance_prompt = instance_prompt
        self._length = self.num_instance_images

        if class_data_root is not None:
            self.class_data_root = Path(class_data_root)
            self.class_data_root.mkdir(parents=True, exist_ok=True)
            self.class_images_path = list(self.class_data_root.iterdir())
            if class_num is not None:
                self.num_class_images = min(len(self.class_images_path), class_num)
            else:
                self.num_class_images = len(self.class_images_path)
            self._length = max(self.num_class_images, self.num_instance_images)
            self.class_prompt = class_prompt
        else:
            self.class_data_root = None

        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        example = {}
        instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
        instance_image = exif_transpose(instance_image)

        if not instance_image.mode == "RGB":
            instance_image = instance_image.convert("RGB")
        example["instance_images"] = self.image_transforms(instance_image)

        if self.encoder_hidden_states is not None:
            example["instance_prompt_ids"] = self.encoder_hidden_states
        else:
            text_inputs = tokenize_prompt(
                self.tokenizer, self.instance_prompt, tokenizer_max_length=self.tokenizer_max_length
            )
            example["instance_prompt_ids"] = text_inputs.input_ids
            example["instance_attention_mask"] = text_inputs.attention_mask

        if self.class_data_root:
            class_image = Image.open(self.class_images_path[index % self.num_class_images])
            class_image = exif_transpose(class_image)

            if not class_image.mode == "RGB":
                class_image = class_image.convert("RGB")
            example["class_images"] = self.image_transforms(class_image)

            if self.class_prompt_encoder_hidden_states is not None:
                example["class_prompt_ids"] = self.class_prompt_encoder_hidden_states
            else:
                class_text_inputs = tokenize_prompt(
                    self.tokenizer, self.class_prompt, tokenizer_max_length=self.tokenizer_max_length
                )
                example["class_prompt_ids"] = class_text_inputs.input_ids
                example["class_attention_mask"] = class_text_inputs.attention_mask

        return example


def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None):
    if tokenizer_max_length is not None:
        max_length = tokenizer_max_length
    else:
        max_length = tokenizer.model_max_length

    text_inputs = tokenizer(
        prompt,
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_tensors="pt",
    )

    return text_inputs


def collate_fn(examples, with_prior_preservation=False):
    has_attention_mask = "instance_attention_mask" in examples[0]

    input_ids = [example["instance_prompt_ids"] for example in examples]
    pixel_values = [example["instance_images"] for example in examples]

    if has_attention_mask:
        attention_mask = [example["instance_attention_mask"] for example in examples]

    # Concat class and instance examples for prior preservation.
    # We do this to avoid doing two forward passes.
    if with_prior_preservation:
        input_ids += [example["class_prompt_ids"] for example in examples]
        pixel_values += [example["class_images"] for example in examples]

        if has_attention_mask:
            attention_mask += [example["class_attention_mask"] for example in examples]

    pixel_values = torch.stack(pixel_values)
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

    input_ids = torch.cat(input_ids, dim=0)

    batch = {
        "input_ids": input_ids,
        "pixel_values": pixel_values,
    }

    if has_attention_mask:
        attention_mask = torch.cat(attention_mask, dim=0)
        batch["attention_mask"] = attention_mask

    return batch
```
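The concatenation in collate_fn pays off in the training step: one UNet forward covers both halves of the batch, which the loss computation then splits back apart. The chunk-and-weight logic below is abridged from the DreamBooth training loop (the dummy tensors stand in for the real predictions):

```python
import torch
import torch.nn.functional as F

# Dummy stand-ins: 4 instance + 4 class examples stacked by collate_fn.
model_pred = torch.randn(8, 4, 64, 64)
target = torch.randn(8, 4, 64, 64)
prior_loss_weight = 1.0  # args.prior_loss_weight in the script

# Split the joint batch back into instance / class halves.
model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
target, target_prior = torch.chunk(target, 2, dim=0)

instance_loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")

# The prior-preservation term keeps the class distribution from drifting.
loss = instance_loss + prior_loss_weight * prior_loss
```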

How the Dataset and DataLoader are put together
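Roughly how the two pieces are wired together in the script (abridged; `args` stands for the parsed CLI arguments, and `tokenizer` is the one the script already loaded):

```python
import torch

train_dataset = DreamBoothDataset(
    instance_data_root=args.instance_data_dir,
    instance_prompt=args.instance_prompt,
    class_data_root=args.class_data_dir if args.with_prior_preservation else None,
    class_prompt=args.class_prompt,
    tokenizer=tokenizer,
    size=args.resolution,
    center_crop=args.center_crop,
)

# with_prior_preservation is forwarded so collate_fn knows whether to
# append the class examples onto the instance examples.
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.train_batch_size,
    shuffle=True,
    collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
    num_workers=args.dataloader_num_workers,
)
```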
To keep the model from overfitting, and to avoid language drift, the model is first used to generate samples from a plain class prompt.

The text encoder can be fine-tuned as well, but this demands more GPU memory.
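The extra memory comes from training a second network: with --train_text_encoder the text encoder's parameters join the optimizer, so its gradients and optimizer states must also live on the GPU. A sketch abridged from the script (unet and text_encoder are the models it already loaded):

```python
import itertools
import torch

params_to_optimize = (
    itertools.chain(unet.parameters(), text_encoder.parameters())
    if args.train_text_encoder
    else unet.parameters()
)
optimizer = torch.optim.AdamW(params_to_optimize, lr=args.learning_rate)
```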

2. Brief notes on code details when training text-to-image

1. The DataLoader is built as shown below, but why is there no attention_mask when DreamBooth training has one? (Likely because the script simply calls the CLIP text encoder on the fixed-length padded ids; see the training-step sketch after the code.)
2. Training or fine-tuning a model needs paired image-text data. If the text side is missing, BLIP can generate captions for the images (a sketch follows below), but the generated descriptions are not always reliable.
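A minimal BLIP captioning sketch (model name and image path are illustrative); it's worth spot-checking the outputs before training on them:

```python
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration

processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

image = Image.open("sample.jpg").convert("RGB")   # illustrative path
inputs = processor(images=image, return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=30)
caption = processor.decode(out[0], skip_special_tokens=True)
print(caption)  # e.g. a short noun-phrase description; verify before trusting it
```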

```python
    # Get the datasets: you can either provide your own training and evaluation files (see below)
    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).

    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        dataset = load_dataset(
            args.dataset_name,
            args.dataset_config_name,
            cache_dir=args.cache_dir,
            data_dir=args.train_data_dir,
        )
    else:
        data_files = {}
        if args.train_data_dir is not None:
            data_files["train"] = os.path.join(args.train_data_dir, "**")
        dataset = load_dataset(
            "imagefolder",
            data_files=data_files,
            cache_dir=args.cache_dir,
        )
        # See more about loading custom images at
        # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder

    # Preprocessing the datasets.
    # We need to tokenize inputs and targets.
    column_names = dataset["train"].column_names

    # 6. Get the column names for input/target.
    dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
    if args.image_column is None:
        image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
    else:
        image_column = args.image_column
        if image_column not in column_names:
            raise ValueError(
                f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
            )
    if args.caption_column is None:
        caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
    else:
        caption_column = args.caption_column
        if caption_column not in column_names:
            raise ValueError(
                f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
            )

    # Preprocessing the datasets.
    # We need to tokenize input captions and transform the images.
    def tokenize_captions(examples, is_train=True):
        captions = []
        for caption in examples[caption_column]:
            if isinstance(caption, str):
                captions.append(caption)
            elif isinstance(caption, (list, np.ndarray)):
                # take a random caption if there are multiple
                captions.append(random.choice(caption) if is_train else caption[0])
            else:
                raise ValueError(
                    f"Caption column `{caption_column}` should contain either strings or lists of strings."
                )
        inputs = tokenizer(
            captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
        )
        return inputs.input_ids

    # Preprocessing the datasets.
    train_transforms = transforms.Compose(
        [
            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
            transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )

    def preprocess_train(examples):
        images = [image.convert("RGB") for image in examples[image_column]]
        examples["pixel_values"] = [train_transforms(image) for image in images]
        examples["input_ids"] = tokenize_captions(examples)
        # the example now carries 4 keys: image, text, pixel_values, input_ids
        return examples

    with accelerator.main_process_first():
        if args.max_train_samples is not None:
            dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
        # Set the training transforms
        train_dataset = dataset["train"].with_transform(preprocess_train)

    def collate_fn(examples):
        pixel_values = torch.stack([example["pixel_values"] for example in examples])
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
        input_ids = torch.stack([example["input_ids"] for example in examples])
        return {"pixel_values": pixel_values, "input_ids": input_ids}

    # DataLoaders creation:
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=True,
        collate_fn=collate_fn,
        batch_size=args.train_batch_size,
        num_workers=args.dataloader_num_workers,
    )
```
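For context, roughly how one training step consumes the two batch keys (abridged from the training loop; vae, text_encoder, unet, and noise_scheduler are the models the script already loaded, and batch comes from the train_dataloader above). Note that input_ids go into the CLIP text encoder with no attention_mask, which is the likely answer to question 1: the padded ids are encoded as-is.

```python
import torch
import torch.nn.functional as F

# Encode images into the VAE latent space.
latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
latents = latents * vae.config.scaling_factor

# Sample noise and a random timestep per example, then noise the latents.
noise = torch.randn_like(latents)
timesteps = torch.randint(
    0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],), device=latents.device
).long()
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

# No attention_mask: the fixed-length padded input_ids are encoded directly.
encoder_hidden_states = text_encoder(batch["input_ids"])[0]

model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
loss = F.mse_loss(model_pred.float(), noise.float(), reduction="mean")  # epsilon prediction
```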

3. Training ControlNet

The code that builds the DataLoader is shown below:


1. A new conditioning_pixel_values image input is added, used for controllable generation (see the sketch after the code)
2. The inputs still carry no attention_mask; same open question as in section 2


```python
def make_train_dataset(args, tokenizer, accelerator):
    # Get the datasets: you can either provide your own training and evaluation files (see below)
    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).

    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        dataset = load_dataset(
            args.dataset_name,
            args.dataset_config_name,
            cache_dir=args.cache_dir,
        )
    else:
        if args.train_data_dir is not None:
            dataset = load_dataset(
                args.train_data_dir,
                cache_dir=args.cache_dir,
            )
        # See more about loading custom images at
        # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script

    # Preprocessing the datasets.
    # We need to tokenize inputs and targets.
    column_names = dataset["train"].column_names

    # 6. Get the column names for input/target.
    if args.image_column is None:
        image_column = column_names[0]
        logger.info(f"image column defaulting to {image_column}")
    else:
        image_column = args.image_column
        if image_column not in column_names:
            raise ValueError(
                f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
            )

    if args.caption_column is None:
        caption_column = column_names[1]
        logger.info(f"caption column defaulting to {caption_column}")
    else:
        caption_column = args.caption_column
        if caption_column not in column_names:
            raise ValueError(
                f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
            )

    if args.conditioning_image_column is None:
        conditioning_image_column = column_names[2]
        logger.info(f"conditioning image column defaulting to {conditioning_image_column}")
    else:
        conditioning_image_column = args.conditioning_image_column
        if conditioning_image_column not in column_names:
            raise ValueError(
                f"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
            )

    def tokenize_captions(examples, is_train=True):
        captions = []
        for caption in examples[caption_column]:
            if random.random() < args.proportion_empty_prompts:
                captions.append("")
            elif isinstance(caption, str):
                captions.append(caption)
            elif isinstance(caption, (list, np.ndarray)):
                # take a random caption if there are multiple
                captions.append(random.choice(caption) if is_train else caption[0])
            else:
                raise ValueError(
                    f"Caption column `{caption_column}` should contain either strings or lists of strings."
                )
        inputs = tokenizer(
            captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
        )
        return inputs.input_ids

    image_transforms = transforms.Compose(
        [
            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(args.resolution),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )

    conditioning_image_transforms = transforms.Compose(
        [
            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(args.resolution),
            transforms.ToTensor(),
        ]
    )

    def preprocess_train(examples):
        images = [image.convert("RGB") for image in examples[image_column]]
        images = [image_transforms(image) for image in images]

        conditioning_images = [image.convert("RGB") for image in examples[conditioning_image_column]]
        conditioning_images = [conditioning_image_transforms(image) for image in conditioning_images]

        examples["pixel_values"] = images
        examples["conditioning_pixel_values"] = conditioning_images
        examples["input_ids"] = tokenize_captions(examples)

        return examples

    with accelerator.main_process_first():
        if args.max_train_samples is not None:
            dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
        # Set the training transforms
        train_dataset = dataset["train"].with_transform(preprocess_train)

    return train_dataset


def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

    conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples])
    conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float()

    input_ids = torch.stack([example["input_ids"] for example in examples])

    return {
        "pixel_values": pixel_values,
        "conditioning_pixel_values": conditioning_pixel_values,
        "input_ids": input_ids,
    }
```
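And roughly how conditioning_pixel_values is consumed downstream (abridged from the training loop; noisy_latents, timesteps, encoder_hidden_states, and weight_dtype are prepared the same way as in text-to-image training): the ControlNet takes the conditioning image and emits residuals that are injected into the frozen UNet's down and mid blocks.

```python
controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)

# The trainable ControlNet consumes the conditioning image.
down_block_res_samples, mid_block_res_sample = controlnet(
    noisy_latents,
    timesteps,
    encoder_hidden_states=encoder_hidden_states,
    controlnet_cond=controlnet_image,
    return_dict=False,
)

# The frozen UNet receives the ControlNet residuals as additional inputs.
model_pred = unet(
    noisy_latents,
    timesteps,
    encoder_hidden_states=encoder_hidden_states,
    down_block_additional_residuals=down_block_res_samples,
    mid_block_additional_residual=mid_block_res_sample,
).sample
```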

