代码解读:Diffusion Models中的长宽桶技术(Aspect Ratio Bucketing)

news2025/1/12 18:13:40

Diffusion Models专栏文章汇总:入门与实战

前言:自从SDXL提出了长宽桶技术之后,彻底解决了不同长宽比的图像输入问题,现在已经成为训练扩散模型必选的方案。这篇博客从代码详细解读如何在模型训练的时候运用长宽桶技术(Aspect Ratio Bucketing)。

目录

原理解读-原有训练的问题

长宽桶技术(Aspect Ratio Bucketing)

完整代码


原理解读-原有训练的问题

纵横比分桶训练可以极大地提高输出质量,现有图像生成模型的一个常见问题是,它们非常容易生成带有非自然作物的图像。这是因为这些模型被训练成生成方形图像。然而,大多数照片和艺术品都不是方形的。然而,该模型只能同时在相同大小的图像上工作,并且在训练过程中,通常的做法是同时在多个训练样本上操作,以优化所使用gpu的效率。作为妥协,选择正方形图像,在训练过程中,只裁剪出每个图像的中心,然后作为训练样例显示给图像生成模型。

例如,人类通常是没有脚或头的,剑只有一个刀刃,剑柄和剑尖在框架外。因为我们正在创建一个图像生成模型来配合我们的故事叙述体验,所以我们的模型能够产生适当的,未裁剪的角色是很重要的,并且生成的骑士不应该持有延伸到无限的金属状直线。

对裁剪图像进行训练的另一个问题是,它可能导致文本和图像之间的不匹配。例如,带有王冠标签的图像通常在中央裁剪后不再包含王冠,因此君主已经被斩首。我们发现使用随机作物代替中心作物只能略微改善这些问题。使用具有可变图像大小的稳定扩散是可能的,尽管可以注意到,远远超过512x512的原生分辨率往往会引入重复的图像元素,并且非常低的分辨率会产生无法识别的图像。

尽管如此,这向我们表明,在可变大小的图像上训练模型应该是可能的。在单个、可变大小的样本上进行训练是微不足道的,但也非常缓慢,而且由于使用小批量提供的缺乏正则化,更容易产生训练不稳定性。

长宽桶技术(Aspect Ratio Bucketing)

由于这个问题似乎没有现有的解决方案,我们已经为我们的数据集实现了自定义批生成代码,允许创建批处理,其中批处理中的每个项目具有相同的大小,但批处理的图像大小可能不同。

我们通过一种叫做宽高比桶的方法来做到这一点。另一种方法是使用固定的图像大小,缩放每个图像以适应这个固定的大小,并应用在训练期间被掩盖的填充。由于这会导致训练期间不必要的计算,我们没有选择遵循这种替代方法。

在下面,我们描述了我们自定义的宽高比桶的批量生成方案背后的原始想法。

首先,我们必须定义要将数据集的图像排序到哪个存储桶中。为此,我们定义的最大图像尺寸为512x768,最大尺寸为1024。由于最大图像大小为512x768,比512x512大,需要更多的VRAM,因此每个gpu的批处理大小必须降低,这可以通过梯度积累来补偿。

我们通过应用以下算法生成桶:

Set the width to 256.
While the width is less than or equal to 1024:
    Find the largest height such that height is less than or equal to 1024 and that width multiplied by height is less than or equal to 512 * 768.
    Add the resolution given by height and width as a bucket.
    Increase the width by 64.

同样的重复,宽度和高度互换。重复的桶将从列表中删除,并添加一个大小为512x512的桶。

接下来,我们将图像分配到相应的桶中。为此,我们首先将桶分辨率存储在NumPy数组中,并计算每个分辨率的长宽比。对于数据集中的每张图像,我们检索其分辨率并计算长宽比。图像宽高比从桶宽高比数组中减去,使我们能够根据宽高比差的绝对值有效地选择最接近的桶:

image_bucket = argmin(abs(bucket_aspects — image_aspect))

图像的桶号与其数据集中的项目ID相关联。如果图像的宽高比非常极端,甚至与最适合的桶相差太大,则从数据集中修剪图像。

由于我们在多个GPU上进行训练,在每个epoch之前,我们对数据集进行了分片,以确保每个GPU在大小相等的不同子集上工作。为此,我们首先复制数据集中的项目id列表并对它们进行洗牌。如果这个复制的列表不能被gpu数量乘以批大小整除,则会对列表进行修剪,并删除最后的项以使其可整除。

然后,我们根据当前进程的全局排名选择1/world_size*bsz项id的不同子集。自定义批处理生成的其余部分将从这些过程中的任何一个过程中进行描述,并对数据集项id的子集进行操作。

对于当前的分片,每个bucket的列表是通过迭代打乱的数据集项目ID列表并将ID分配给分配给图像的bucket对应的列表来创建的。

处理完所有图像后,我们遍历每个bucket的列表。如果它的长度不能被批大小整除,则根据需要删除列表上的最后一个元素以使其可整除,并将它们添加到单独的捕获所有桶中。由于保证整个分片大小包含许多可被批大小整除的元素,因此保证生成一个长度可被批大小整除的所有bucket。

当请求批处理时,我们从加权分布中随机抽取一个桶。桶的权重设置为桶的大小除以所有剩余桶的大小。这确保了即使有大小差异很大的桶,自定义批生成在训练期间不会引入强烈的偏差,根据图像大小显示图像。如果在没有加权的情况下选择桶,那么小的桶将在训练过程中早期清空,只有最大的桶将在训练结束时保留。按大小对桶进行加权可以避免这种情况。

最后从所选的桶中取出一批项。取走的项目从桶中移除。如果桶现在为空,则在epoch的剩余时间内删除它。所选的项id和所选桶的分辨率现在被传递给图像加载函数。

完整代码

import numpy as np
import pickle
import time

def get_prng(seed):
    return np.random.RandomState(seed)

class BucketManager:
    def __init__(self, bucket_file, valid_ids=None, max_size=(768,512), divisible=64, step_size=8, min_dim=256, base_res=(512,512), bsz=1, world_size=1, global_rank=0, max_ar_error=4, seed=42, dim_limit=1024, debug=False):
        with open(bucket_file, "rb") as fh:
            self.res_map = pickle.load(fh)
        if valid_ids is not None:
            new_res_map = {}
            valid_ids = set(valid_ids)
            for k, v in self.res_map.items():
                if k in valid_ids:
                    new_res_map[k] = v
            self.res_map = new_res_map
        self.max_size = max_size
        self.f = 8
        self.max_tokens = (max_size[0]/self.f) * (max_size[1]/self.f)
        self.div = divisible
        self.min_dim = min_dim
        self.dim_limit = dim_limit
        self.base_res = base_res
        self.bsz = bsz
        self.world_size = world_size
        self.global_rank = global_rank
        self.max_ar_error = max_ar_error
        self.prng = get_prng(seed)
        epoch_seed = self.prng.tomaxint() % (2**32-1)
        self.epoch_prng = get_prng(epoch_seed) # separate prng for sharding use for increased thread resilience
        self.epoch = None
        self.left_over = None
        self.batch_total = None
        self.batch_delivered = None

        self.debug = debug

        self.gen_buckets()
        self.assign_buckets()
        self.start_epoch()

    def gen_buckets(self):
        if self.debug:
            timer = time.perf_counter()
        resolutions = []
        aspects = []
        w = self.min_dim
        while (w/self.f) * (self.min_dim/self.f) <= self.max_tokens and w <= self.dim_limit:
            h = self.min_dim
            got_base = False
            while (w/self.f) * ((h+self.div)/self.f) <= self.max_tokens and (h+self.div) <= self.dim_limit:
                if w == self.base_res[0] and h == self.base_res[1]:
                    got_base = True
                h += self.div
            if (w != self.base_res[0] or h != self.base_res[1]) and got_base:
                resolutions.append(self.base_res)
                aspects.append(1)
            resolutions.append((w, h))
            aspects.append(float(w)/float(h))
            w += self.div
        h = self.min_dim
        while (h/self.f) * (self.min_dim/self.f) <= self.max_tokens and h <= self.dim_limit:
            w = self.min_dim
            got_base = False
            while (h/self.f) * ((w+self.div)/self.f) <= self.max_tokens and (w+self.div) <= self.dim_limit:
                if w == self.base_res[0] and h == self.base_res[1]:
                    got_base = True
                w += self.div
            resolutions.append((w, h))
            aspects.append(float(w)/float(h))
            h += self.div
        res_map = {}
        for i, res in enumerate(resolutions):
            res_map[res] = aspects[i]
        self.resolutions = sorted(res_map.keys(), key=lambda x: x[0] * 4096 - x[1])
        self.aspects = np.array(list(map(lambda x: res_map[x], self.resolutions)))
        self.resolutions = np.array(self.resolutions)
        if self.debug:
            timer = time.perf_counter() - timer
            print(f"resolutions:\n{self.resolutions}")
            print(f"aspects:\n{self.aspects}")
            print(f"gen_buckets: {timer:.5f}s")

    def assign_buckets(self):
        if self.debug:
            timer = time.perf_counter()
        self.buckets = {}
        self.aspect_errors = []
        skipped = 0
        skip_list = []
        for post_id in self.res_map.keys():
            w, h = self.res_map[post_id]
            aspect = float(w)/float(h)
            bucket_id = np.abs(self.aspects - aspect).argmin()
            if bucket_id not in self.buckets:
                self.buckets[bucket_id] = []
            error = abs(self.aspects[bucket_id] - aspect)
            if error < self.max_ar_error:
                self.buckets[bucket_id].append(post_id)
                if self.debug:
                    self.aspect_errors.append(error)
            else:
                skipped += 1
                skip_list.append(post_id)
        for post_id in skip_list:
            del self.res_map[post_id]
        if self.debug:
            timer = time.perf_counter() - timer
            self.aspect_errors = np.array(self.aspect_errors)
            print(f"skipped images: {skipped}")
            print(f"aspect error: mean {self.aspect_errors.mean()}, median {np.median(self.aspect_errors)}, max {self.aspect_errors.max()}")
            for bucket_id in reversed(sorted(self.buckets.keys(), key=lambda b: len(self.buckets[b]))):
                print(f"bucket {bucket_id}: {self.resolutions[bucket_id]}, aspect {self.aspects[bucket_id]:.5f}, entries {len(self.buckets[bucket_id])}")
            print(f"assign_buckets: {timer:.5f}s")

    def start_epoch(self, world_size=None, global_rank=None):
        if self.debug:
            timer = time.perf_counter()
        if world_size is not None:
            self.world_size = world_size
        if global_rank is not None:
            self.global_rank = global_rank

        # select ids for this epoch/rank
        index = np.array(sorted(list(self.res_map.keys())))
        index_len = index.shape[0]
        index = self.epoch_prng.permutation(index)
        index = index[:index_len - (index_len % (self.bsz * self.world_size))]
        #print("perm", self.global_rank, index[0:16])
        index = index[self.global_rank::self.world_size]
        self.batch_total = index.shape[0] // self.bsz
        assert(index.shape[0] % self.bsz == 0)
        index = set(index)

        self.epoch = {}
        self.left_over = []
        self.batch_delivered = 0
        for bucket_id in sorted(self.buckets.keys()):
            if len(self.buckets[bucket_id]) > 0:
                self.epoch[bucket_id] = np.array([post_id for post_id in self.buckets[bucket_id] if post_id in index], dtype=np.int64)
                self.prng.shuffle(self.epoch[bucket_id])
                self.epoch[bucket_id] = list(self.epoch[bucket_id])
                overhang = len(self.epoch[bucket_id]) % self.bsz
                if overhang != 0:
                    self.left_over.extend(self.epoch[bucket_id][:overhang])
                    self.epoch[bucket_id] = self.epoch[bucket_id][overhang:]
                if len(self.epoch[bucket_id]) == 0:
                    del self.epoch[bucket_id]

        if self.debug:
            timer = time.perf_counter() - timer
            count = 0
            for bucket_id in self.epoch.keys():
                count += len(self.epoch[bucket_id])
            print(f"correct item count: {count == len(index)} ({count} of {len(index)})")
            print(f"start_epoch: {timer:.5f}s")

    def get_batch(self):
        if self.debug:
            timer = time.perf_counter()
        # check if no data left or no epoch initialized
        if self.epoch is None or self.left_over is None or (len(self.left_over) == 0 and not bool(self.epoch)) or self.batch_total == self.batch_delivered:
            self.start_epoch()

        found_batch = False
        batch_data = None
        resolution = self.base_res
        while not found_batch:
            bucket_ids = list(self.epoch.keys())
            if len(self.left_over) >= self.bsz:
                bucket_probs = [len(self.left_over)] + [len(self.epoch[bucket_id]) for bucket_id in bucket_ids]
                bucket_ids = [-1] + bucket_ids
            else:
                bucket_probs = [len(self.epoch[bucket_id]) for bucket_id in bucket_ids]
            bucket_probs = np.array(bucket_probs, dtype=np.float32)
            bucket_lens = bucket_probs
            bucket_probs = bucket_probs / bucket_probs.sum()
            bucket_ids = np.array(bucket_ids, dtype=np.int64)
            if bool(self.epoch):
                chosen_id = int(self.prng.choice(bucket_ids, 1, p=bucket_probs)[0])
            else:
                chosen_id = -1

            if chosen_id == -1:
                # using leftover images that couldn't make it into a bucketed batch and returning them for use with basic square image
                self.prng.shuffle(self.left_over)
                batch_data = self.left_over[:self.bsz]
                self.left_over = self.left_over[self.bsz:]
                found_batch = True
            else:
                if len(self.epoch[chosen_id]) >= self.bsz:
                    # return bucket batch and resolution
                    batch_data = self.epoch[chosen_id][:self.bsz]
                    self.epoch[chosen_id] = self.epoch[chosen_id][self.bsz:]
                    resolution = tuple(self.resolutions[chosen_id])
                    found_batch = True
                    if len(self.epoch[chosen_id]) == 0:
                        del self.epoch[chosen_id]
                else:
                    # can't make a batch from this, not enough images. move them to leftovers and try again
                    self.left_over.extend(self.epoch[chosen_id])
                    del self.epoch[chosen_id]

            assert(found_batch or len(self.left_over) >= self.bsz or bool(self.epoch))

        if self.debug:
            timer = time.perf_counter() - timer
            print(f"bucket probs: " + ", ".join(map(lambda x: f"{x:.2f}", list(bucket_probs*100))))
            print(f"chosen id: {chosen_id}")
            print(f"batch data: {batch_data}")
            print(f"resolution: {resolution}")
            print(f"get_batch: {timer:.5f}s")

        self.batch_delivered += 1
        return (batch_data, resolution)

    def generator(self):
        if self.batch_delivered >= self.batch_total:
            self.start_epoch()
        while self.batch_delivered < self.batch_total:
            yield self.get_batch()

if __name__ == "__main__":
    # prepare a pickle with mapping of dataset IDs to resolutions called resolutions.pkl to use this
    with open("resolutions.pkl", "rb") as fh:
        ids = list(pickle.load(fh).keys())

    counts = np.zeros((len(ids),)).astype(np.int64)
    id_map = {}
    for i, post_id in enumerate(ids):
        id_map[post_id] = i

    bm = BucketManager("resolutions.pkl", debug=True, bsz=8, world_size=8, global_rank=3)
    print("got: " + str(bm.get_batch()))
    print("got: " + str(bm.get_batch()))
    print("got: " + str(bm.get_batch()))
    print("got: " + str(bm.get_batch()))
    print("got: " + str(bm.get_batch()))
    print("got: " + str(bm.get_batch()))
    print("got: " + str(bm.get_batch()))

    bm = BucketManager("resolutions.pkl", bsz=8, world_size=1, global_rank=0, valid_ids=ids[0:16])
    for _ in range(16):
        bm.get_batch()
    print("got from future epoch: " + str(bm.get_batch()))

    bms = []
    for rank in range(16):
        bm = BucketManager("resolutions.pkl", bsz=8, world_size=16, global_rank=rank)
        bms.append(bm)
    for epoch in range(5):
        print(f"epoch {epoch}")
        for i, bm in enumerate(bms):
            print(f"bm {i}")
            first = True
            for ids, res in bm.generator():
                if first and i == 0:
                    #print(ids)
                    first = False
                for post_id in ids:
                    counts[id_map[post_id]] += 1
        print(np.bincount(counts))

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1938302.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

如何根据同一行的ID利用R语言对值进行求和

需求&#xff1a;将属于同一分组的对应的值进行求和或者求平均值 #设置工作目录 > getwd() [1] "C:/Users/86150/Documents" > setwd("C:/Users/86150/Desktop/AA2024/RUF") > list.files() #读取文件 >install.packages("readxl")…

建投数据人力资源系列产品获得欧拉操作系统及华为鲲鹏技术认证书

近日&#xff0c;经欧拉生态创新中心和华为技术有限公司测评&#xff0c;建投数据自主研发的人力资源管理系统、招聘管理系统、绩效管理系统、培训管理系统&#xff0c;完成了基于欧拉操作系统openEuler 22.03、华为鲲鹏Kunpeng 920&#xff08;Taisha 200&#xff09;的兼容性…

SVM 技能测试:25 个 MCQ 用于测试数据科学家的 SVM

SVM 技能测试:25 个 MCQ 用于测试数据科学家的 SVM(2024 年更新) 一、介绍 你可以把机器学习算法想象成一个装满斧头、剑和刀片的军械库。你有各种各样的工具,但你应该学会在正确的时间使用它们。打个比方,将“线性回归或逻辑回归”视为一把能够有效地切片和切块数据但…

uniapp vue3 上传视频组件封装

首先创建一个 components 文件在里面进行组件的创建 下面是 vvideo组件的封装 也就是图片上传组件 只是我的命名是随便起的 <template><!-- 上传视频 --><view class"up-page"><!--视频--><view class"show-box" v-for"…

纯硬件一键开关机电路的工作原理

这是一个一键开关机电路: 当按一下按键然后松开&#xff0c;MOS管导通&#xff0c;VOUT等于电源电压; 当再次按一下按键然后松开&#xff0c;MOS管关闭&#xff0c;VOUT等于0; 下面来分析一下这个电路的工作原理。上电后&#xff0c;输入电压通过R1和R2给电容充电&#xff0c;最…

MySQL通过bin-log恢复数据

MySQL通过bin-log恢复数据 1.bin-log说明2.数据恢复流程2.1 查看是否开启bin-log2.3 查看bin-log2.4 执行数据恢复操作2.5 检查数据是否恢复 1.bin-log说明 mysqldump和bin-log都可以作为MySQL数据库备份的方式&#xff1a; mysqldump 用于将整个或部分数据库导出为可执行的S…

TeraTerm 使用技巧

参考资料 自分がよく使うTeratermマクロによる自動ログインのやり方をまとめてみたよTera Term マクロでログインを自動化してみたTera Term のススメ 目录 简介一. 常用基础设置1.1 语言变更1.2 log设置 二. 小技巧2.1 指定host别名2.2 新开窗口2.3 设置粘贴多行命令时的行间…

【3D编程技巧】如何用四元数旋转矢量在相机空间进行光照计算

这里介绍一个小TIPS&#xff0c;很久没有这么有成就感了。我以前在学3D数学的时候&#xff0c;书上就有一句话&#xff0c;说你把矢量这些东西用久了&#xff0c;就应该形成一种“直觉”&#xff0c;仿佛这些东西就是你的左右手一样。而这次&#xff0c;我居然真的用“直觉”来…

基于上下文自适应可变长熵编码 CAVLC 原理详细分析

CAVLC CAVLC&#xff0c;即Context-Adaptive Variable-Length Coding&#xff0c;是一种用于视频压缩的编码技术&#xff0c;特别是在MPEG-4视频编码标准中使用。CAVLC是一种熵编码方法&#xff0c;它根据视频数据的上下文信息来调整编码长度&#xff0c;以实现更有效的数据压…

【从0到1,训练大模型,从llama3开始】

摘要: 随着大模型越来越多,大家肯定眼花缭乱。不知道选择哪个好,换句话说,不知道哪个才适合自己。 通过社长的实操:chatgpt3.5、gpt4、gpt4o、llama3、通义千问、豆包等大模型,总结是:大家都很好,都能一定程度上的帮助你。 不过怎么说呢,他们什么都懂,但是,什么都不…

sourcetree中常用功能使用方法及gitlab冲突解决

添加至缓存&#xff1a;等于git add 提交&#xff1a;等于git commit 拉取/获取&#xff1a;等于git pull ,在每次要新增代码或者提交代码前需要先拉取一遍服务器中最新的代码&#xff0c;防止服务器有其他人更新了代码&#xff0c;但我们自己本地的代码在我们更新前跟服务器不…

邮件安全篇:企业电子邮件安全涉及哪些方面?

1. 邮件安全概述 企业邮件安全涉及多个方面&#xff0c;旨在保护电子邮件通信的机密性、完整性和可用性&#xff0c;防止数据泄露、欺诈、滥用及其他安全威胁。本文从身份验证与防伪、数据加密、反垃圾邮件和反恶意软件防护、邮件内容过滤与审计、访问控制与权限管理、邮件存储…

面试题 17.14.最小K个数

题目&#xff1a;如下图 答案&#xff1a;如下图 /*** Note: The returned array must be malloced, assume caller calls free().*/ void AdjustDown(int* a,int n,int root) {int parent root;int child parent * 2 1;//默认左孩子是大的&#xff0c;将其与右孩子比较&am…

《机器学习》读书笔记:总结“第5章 神经网络”中的概念

&#x1f4a0;神经网络&#xff08;neural network&#xff09; 神经网络是由具有适应性的简单单元组成的广泛并行互联的网络&#xff0c;它的组织能够模拟生物神经系统对真实世界物体所作出的交互反应。 神经网络中最基本的成分是 神经元(neuron / unit)&#xff0c;即上述定…

机械臂泡水维修|机器人雨后进水维修措施

如果机器人不慎被水淹&#xff0c;别慌&#xff01;我们为你准备了一份紧急的机械臂泡水维修抢修指南&#xff0c;帮助你解决这个问题。 【机器人浸水被淹后紧急维修抢修&#xff5c;如何处理&#xff1f;】 机械臂被淹进水后维修处理方式 1. 机械手淹水后断电断网 首先&am…

Hive分布式SQL计算平台

Hive分布式SQL计算平台 一、Hive 概述二、Hive架构三、Hive客户端 1、Hive有哪些客户端可以使用2、Hive第三方客户端 四、Hive使用语法 1、数据库操作2、内部表&#xff0c;外部表3、数据的导入与导出4、分区表5、分桶表6、复杂类型操作7、数据抽样8、Virtual Columns 虚拟列9…

压缩视频大小的方法 怎么减少视频内存大小 几个简单方法

随着4K、8K高清视频的流行&#xff0c;我们越来越容易遇到视频文件体积过大&#xff0c;导致存储空间不足、传输速度缓慢等问题。视频压缩成为解决这一问题的有效途径&#xff0c;但如何在减小文件大小的同时&#xff0c;保证视频质量不受影响呢&#xff1f;本文将为你揭晓答案…

(10)深入理解pandas的核心数据结构:DataFrame高效数据清洗技巧

目录 前言1. DataFrame数据清洗1.1 处理缺失值&#xff08;NaNs&#xff09;1.1.1 数据准备1.1.2 读取数据1.1.3 查找具有 null 值或缺失值的行和列1.1.4 计算每列缺失值的总数1.1.5 删除包含 null 值或缺失值的行1.1.6 利用 .fillna&#xff08;&#xff09; 方法用Portfolio …

OpenCV Mat类简介,Mat对象创建与赋值 C++实现

在 C 中&#xff0c;OpenCV 提供了一个强大的类 Mat 来表示图像和多维矩阵。Mat 类是 OpenCV 中最基础和最常用的类&#xff0c;用于存储和操作图像数据。 文章目录 Mat类简介Mat 类的定义Mat 类的构造函数 代码示例深拷贝示例赋值示例浅拷贝示例 Mat类简介 Mat 类是一个多维…

【雷丰阳-谷粒商城 】【分布式高级篇-微服务架构篇】【29】Sentinel

持续学习&持续更新中… 守破离 【雷丰阳-谷粒商城 】【分布式高级篇-微服务架构篇】【29】Sentinel 简介熔断降级什么是熔断什么是降级相同点不同点 整合Sentinel自定义sentinel流控返回数据使用Sentinel来保护feign远程调用自定义资源给网关整合Sentinel参考 简介 熔断降…