笔记:BLIP源码之(1)数据集预处理【仅考虑Image-Text Retrieval on COCO】

news2025/1/11 2:57:29

BLIP:Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generat 论文的两个贡献如下:

  1. 从模型的角度:提出了 Encoder-Decoder (MED) 的多模态混合

An MED can operate either as a unimodal encoder, or an image-grounded text encoder, or an image-grounded text decoder.

  1. 从数据的角度:提出了 Captioning and Filtering (CapFilt)

We finetune a pre-trained MED into two modules: a captioner to produce synthetic captions given web images, and a filter to remove noisy captions from both the original web texts and the synthetic texts.

Image-Text Retrieval 任务 on COCO:

1. 先看处理训练集的类

定义了一个处理训练集的类,继承PyTorch中用于处理数据集的基类Dataset,通常情况下,自定义的Dataset类需要实现两个方法:__ len____ getitem__

  • __ len__方法返回数据集的大小,即数据集中样本的总数
  • __getitem__方法用于根据给定的索引返回数据集中对应位置的样本。
class coco_karpathy_train(Dataset):
    def __init__(self, transform, image_root, ann_root, max_words=30, prompt=''):        
        '''省略部分代码'''
        
        # 给每个图像进行编号,编号方式:
        # image_id:n
        self.img_ids = {}  
        n = 0
        for ann in self.annotation:
            img_id = ann['image_id']
            if img_id not in self.img_ids.keys():
                self.img_ids[img_id] = n
                n += 1    
                
    # 之前用函数加载了annotation文件:
    # self.annotation = json.load(open(os.path.join(ann_root,filename),'r'))
    # self.annotation是一个数组,数组中的每个元素是一个dict,如:
    # [{"caption": "A woman wearing a net on her head cutting a cake. ",
    # "image": "val2014/COCO_val2014_000000522418.jpg", "image_id": "coco_522418"}, 
    def __len__(self):
        return len(self.annotation)
    
    def __getitem__(self, index):    
        ann = self.annotation[index]
        image_path = os.path.join(self.image_root,ann['image'])   
        # Image是一个Python图像处理库,常用于图像的加载、处理和保存操作。     
        image = Image.open(image_path).convert('RGB')   
        # 对图像对变换
        image = self.transform(image)
        # prompt + 对caption进行预处理后 得到新的caption
        caption = self.prompt+pre_caption(ann['caption'], self.max_words) 
		# 返回transform后的图形、处理后的caption、图像对应的编号
        return image, caption, self.img_ids[ann['image_id']] 

附上pre_caption函数代码:

def pre_caption(caption,max_words=50):
	# 把这些符号:.!\"()*#:;~ 替换为空格,并且将caption全部转换为小写字母
    caption = re.sub(
        r"([.!\"()*#:;~])",       
        ' ',
        caption.lower(),
    )
    # 将连续出现两个或更多空格的地方替换为单个空格
    caption = re.sub(
        r"\s{2,}",
        ' ',
        caption,
    )
    # 去掉caption末尾的换行符
    caption = caption.rstrip('\n') 
    # 去掉caption 两边的空格
    caption = caption.strip(' ')
    #truncate caption
    caption_words = caption.split(' ')
    if len(caption_words)>max_words: # 如果超过了max_words,就只取前max_words个单词
        caption = ' '.join(caption_words[:max_words])
            
    return caption

2. 对图像进行数据增强

# 定义 normalize
# transforms.Normalize()函数接受两个参数,分别是均值(mean)和标准差(std)
# 均值(mean)和标准差(std) 这些参数是根据训练数据集的特征计算得出的。
# 分别对应三个通道(R、G、B)
normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))

# 对训练集进行的transform
transform_train = transforms.Compose([     
			# 根据给定的 image_size 进行scale,以及使用BICUBIC插值方法进行图像的插值填充     
            transforms.RandomResizedCrop(config['image_size'],scale=(min_scale, 1.0),interpolation=InterpolationMode.BICUBIC),
            # 随机水平翻转
            transforms.RandomHorizontalFlip(),
            # 自定义的 RandomAugment 函数,下面会做记录
            # Identity(无操作)、AutoContrast(自动对比度调整)、Brightness(亮度调整)、
            # Sharpness(锐度调整)、Equalize(直方图均衡化)、ShearX(X轴方向的错切变换)、
            # ShearY(Y轴方向的错切变换)、TranslateX(X轴方向的平移变换)、
            # TranslateY(Y轴方向的平移变换)、Rotate(旋转变换)
            RandomAugment(2,5,isPIL=True,augs=['Identity','AutoContrast','Brightness','Sharpness','Equalize',
                                              'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),    
            # 将图像数据转换为PyTorch张量的格式 
            transforms.ToTensor(),
            normalize,
        ]) 

当使用BICUBIC插值方法进行图像插值填充时,原始图像上的像素值被用于计算新图像上每个像素的值。通过计算原始图像中像素的加权平均值,BICUBIC插值可以提供更平滑和连续的图像结果。

额外自定义了数据增强的代码:
在这里插入图片描述

class RandomAugment(object):
    def __init__(self, N=2, M=10, isPIL=False, augs=[]):
        self.N = N
        self.M = M
        # 是否是PIL格式的图像
        self.isPIL = isPIL
        if augs:
            self.augs = augs       
        else:
            self.augs = list(arg_dict.keys())

    def get_random_ops(self):
    	# 从augs这个数组中随机选择N个存储在 sampled_ops 列表中
        sampled_ops = np.random.choice(self.augs, self.N)
        return [(op, 0.5, self.M) for op in sampled_ops]

    def __call__(self, img):
        if self.isPIL:
        # 将PIL图像对象转换为NumPy数组形式
            img = np.array(img)            
        ops = self.get_random_ops()
        for name, prob, level in ops:
        	# 根据概率判断是否应用当前的增强操作
            if np.random.random() > prob:
                continue
            args = arg_dict[name](level)
            # 这个 *args 包括 上一行代码得到的(level, replace_value)
            img = func_dict[name](img, *args) 
        return img

__call__函数是Python中的特殊方法(special method),用于使对象可以像函数一样被调用,当调用该实例时,会自动执行__call__方法,并按照其中的逻辑进行

有很多ops操作,只选择一个记录TranslateX

translate_const = 10
MAX_LEVEL = 10
replace_value = (128, 128, 128)

func_dict = {
   '''省略部分代码'''
      'TranslateX': translate_x_func,
   '''省略部分代码'''
}

def translate_x_func(img, offset, fill=(0, 0, 0)):
 # offset:水平平移的偏移量,表示图像将向右平移的像素数。
 # fill:边界填充的颜色,默认为(0, 0, 0),表示黑色填充
    '''
        same output as PIL.Image.transform
    '''
    # 这个img已经是numpy数组了
    H, W = img.shape[0], img.shape[1]
    # 平移矩阵M
    M = np.float32([[1, 0, -offset], [0, 1, 0]])
    # 对输入图像进行仿射变换,将平移矩阵M应用于图像
    out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
    return out

arg_dict = {
   '''省略部分代码'''
      'TranslateX': translate_level_to_args(
        translate_const, MAX_LEVEL, replace_value
    ),
    '''省略部分代码'''
}

def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
    def level_to_args(level): # 将level转换为一组用于平移操作的参数
    	# 将传入的level除以MAX_LEVEL,然后乘以translate_const,得到一个平移的具体数值
        level = (level / MAX_LEVEL) * float(translate_const)
        # 以50%的概率将平移的数值取反,实现随机选择正向或负向平移
        if np.random.random() > 0.5: level = -level
        return (level, replace_value)
	# 返回 level_to_args 这个函数
    return level_to_args

2. 对于验证集和测试集

val和test的annotation也是list,list中每个元素都是dict,包含两个键值,一个image,一个caption,其中caption是list,如下:

{"image": "val2014/COCO_val2014_000000184613.jpg",
  "caption": ["A child holding a flowered umbrella and petting a yak.",
         "A young man holding an umbrella next to a herd of cattle.",
         "a young boy barefoot holding an umbrella touching the horn of a cow",
         "A young boy with an umbrella who is touching the horn of a cow.",
         "A boy holding an umbrella while standing next to livestock."]}
class coco_karpathy_retrieval_eval(Dataset):
    def __init__(self, transform, image_root, ann_root, split, max_words=30):  
       '''省略部分代码'''
        self.text = []
        # 保存每一张图片的路径的list
        self.image = []
        self.txt2img = {}
        self.img2txt = {}

        txt_id = 0
        # ann就是一个dict,包含"image"和 "caption",img_id 就是索引 index
        for img_id, ann in enumerate(self.annotation):
            self.image.append(ann['image'])
            self.img2txt[img_id] = []
            # 一个图片对应多个caption
            for i, caption in enumerate(ann['caption']):
            	# 对caption做预处理之后,把新的caption 放入text数组中
                self.text.append(pre_caption(caption,max_words))
                # txt_id是每一张图片对应的多个caption的index,这些txt_id放在一个list中:
                # {0 : [0, 1, 2,3,4]}
                self.img2txt[img_id].append(txt_id)
                # {0:0} {1:0} {2:0} {3:0} 表示txt_id到img_id的映射,
                # 多个text可以映射到同一张图片
                self.txt2img[txt_id] = img_id
                txt_id += 1
    '''__len__和  __getitem__的代码省略,和训练集的类一样'''                   

test和val数据集的transform相对于train的简单很多:

transform_test = transforms.Compose([
        transforms.Resize((config['image_size'],config['image_size']),interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        normalize,
        ])  

调用这两个实例就能得到三个数据集:

    elif dataset=='retrieval_coco':          
        train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root'])
        val_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val') 
        test_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'test')          
        return train_dataset, val_dataset, test_dataset    

以上,完成了自定义数据集,接下来则需要做数据集的loader,也就是可迭代的数据加载器

3. 数据集的loader(先不考虑分布式训练)

torch.utils.data.DataLoader是PyTorch中用于数据加载的类。它提供了一种方便的方式来迭代和批量处理数据。

DataLoader的主要作用是将自定义的数据集包装成一个可迭代的数据加载器,以便于在训练或测试过程中以批量的方式加载和处理数据

使用DataLoader可以实现以下功能:

  • 数据批量加载:DataLoader可以指定批量大小(batch size),在每次迭代中返回一个批量大小的数据。
  • 数据并行加载:DataLoader可以使用多个线程并行加载数据,提高数据加载的效率。
  • 数据随机打乱:DataLoader可以对数据进行随机打乱,增加训练的随机性,避免模型对数据的顺序产生依赖。
  • 数据预处理和转换:DataLoader可以通过transform参数传入的数据转换函数对数据进行预处理和转换。
  • 数据加载器迭代:通过迭代DataLoader对象,可以逐批地获取数据,方便模型进行训练或测试。

使用DataLoader需要指定以下参数:

  • dataset:要加载的数据集,通常是自定义的Dataset对象。
  • batch_size:每个批次的样本数量。
  • shuffle:是否在每个时期(epoch)重新打乱数据。
  • num_workers:用于数据加载的线程数。
  • collate_fn:用于批量处理样本的函数。

除了上述的参数,还有:

  • pin_memory:通常情况下,在使用GPU进行训练时,如果主机内存足够,建议将pin_memory设置为True,以提高数据加载到GPU的速度。但如果遇到内存不足的情况,可以将pin_memory设置为False,以节省内存资源。

  • sampler:用于指定数据加载的顺序和采样方式。sampler参数可以接受以下几种类型的取值:
    (1)SequentialSampler:顺序采样器,按照数据集的顺序依次采样数据,不进行随机打乱。
    (2)RandomSampler:随机采样器,在每个时期(epoch)中随机打乱数据,并按照打乱后的顺序进行采样。
    (3)SubsetRandomSampler:子集随机采样器,从给定的索引列表中随机采样数据,适用于对数据集的子集进行采样。
    (4)WeightedRandomSampler:加权随机采样器,根据给定的样本权重进行采样,用于处理类别不平衡的数据集。
    (5)自定义采样器:用户可以自定义采样器类,继承自Sampler,实现自己的数据采样逻辑。

  • drop_last:如果数据集的样本数量无法被批次大小整除,并且drop_last参数设置为True,则最后一个不完整的批次将被丢弃。这通常在训练过程中用于确保每个批次的大小保持一致,以提高训练的效率。

ps:本论文采用的是Pytorch提供的DistributedSampler作为分布式训练的采样器,而如果不是分布式训练,则把sampler设置成了None

    if args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()            
        samplers = create_sampler([train_dataset], [True], num_tasks, global_rank) + [None, None]
    else:
        samplers = [None, None, None]

调用create_loader函数:

train_loader, val_loader, test_loader = 
create_loader([train_dataset, val_dataset, test_dataset],samplers,
               batch_size=[config['batch_size_train']]+[config['batch_size_test']]*2,
               # 工作线程数的列表
               num_workers=[4,4,4],
               is_trains=[True, False, False], 
               #数据集的collate函数列表,用于对每个批次的样本进行处理和组合。如果不需要特定的处理逻辑,可以设置为None
               collate_fns=[None,None,None]) 

create_loader函数:

def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):

	'''也许是我看的代码比较少的原因,看到这样做loader真的感觉很高效,代码简洁、清晰、好看,
	使用zip就可以依次把三个数据集的loader做好,灵活使用if来判断,可以共用代码,并且传入的参数
	也很特别,不是单独一个,而是包含3个元素的list,这样正好对应三个数据集'''
	
    loaders = [] # 用来保存三个数据集的loader
    for dataset,sampler,bs,n_worker,is_train,collate_fn in zip(datasets,samplers,batch_size,num_workers,is_trains,collate_fns):
        if is_train:
            # 如果sampler 是 None,也就是非分布式训练,则随机打乱
            # 否在,在分布式训练下,不需随机打乱
            shuffle = (sampler is None)
            # 训练集会把 最后一个不完整的批次丢掉
            drop_last = True
        else:
            # 在val 和 test 数据集,既不随机打乱数据,也不会丢弃最后一个不完整的批次
            shuffle = False
            drop_last = False
        loader = DataLoader(
            dataset,
            batch_size=bs,
            num_workers=n_worker,
            pin_memory=True,
            sampler=sampler,
            shuffle=shuffle,
            collate_fn=collate_fn,
            drop_last=drop_last,
        )          
        # 把做好的loader加到list中    
        loaders.append(loader)
    return loaders  

4. 有关分布式训练中取样器的代码

create_sampler()函数用于创建分布式训练中的采样器(sampler):

# num_tasks:总任务数,即分布式训练中的进程数
# global_rank:当前进程的全局排名
def create_sampler(datasets, shuffles, num_tasks, global_rank):
    samplers = []
    for dataset,shuffle in zip(datasets,shuffles):
     # 遍历datasets和shuffles列表,对每个数据集创建一个分布式采样器,
     # 并将其添加到samplers列表中
        sampler = torch.utils.data.DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle)
        samplers.append(sampler)
    return samplers   

分布式采样器使用torch.utils.data.DistributedSampler类进行创建,需要指定数据集总任务数当前进程的全局排名是否进行洗牌

如果是要进行分布式训练,则需要获得总进程数以及进程排名,最后调用create_sampler函数:

    if args.distributed:
    	# 获得分布式训练环境中的总进程数
        num_tasks = utils.get_world_size()
        # 获取当前进程在分布式训练环境中的排名
        # 这样可以了解当前进程在整个分布式训练中的位置和角色,以便进行相应的操作和通信。
        global_rank = utils.get_rank()  
        # 对训练集做sampler,验证集和测试集不需要    
        samplers = create_sampler([train_dataset], [True], num_tasks, global_rank) + [None, None]

检查环境以及获得进程数:

def is_dist_avail_and_initialized():
    # 检查当前环境是否支持分布式训练
    if not dist.is_available():
        return False
    # 检查是否已经初始化了分布式训练环境
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        # 如果环境不支持或者未初始化,则默认进程数为1
        # 表示当前环境中只有一个进程
        return 1
    # 获取分布式训练环境中的总进程数,并返回该值
    return dist.get_world_size()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/552504.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Js常识三

文章目录 作用域GCclosure变量和函数提升函数参数 作用域 GC Js Gc 算法 引用计数(已淘汰)标记清除 closure 一句话:内层函数 外层函数的变量 闭包作用:私有化数据,or 私有化状态 变量和函数提升 Js 祖传var变…

C语言结构体初级

目录 一、为什么要用结构体 二、使用结构体的具体形式 1.结构体类型的声明(main函数外部) 2.结构体变量的定义(在main函数内或者外) 3.结构体变量的初始化 4.结构体成员的访问 5.结构体的传参 跑了这么久,再坚…

分布式软件架构——单体架构

序言 当一项大工程需要大量人员共同开发,并保证它们分布在网络中的大量服务器节点能够同时运行,那么随着项目规模的增大、运行时间变长,它必然会受到墨菲定律的无情打击。 Murphy’s Law:Anything that can go wrong will go wro…

Qt文件系统源码分析—第四篇QLockFile

深度 本文主要分析Windows平台,Mac、Linux暂不涉及 本文只分析到Win32 API/Windows Com组件/STL库函数层次,再下层代码不做探究 本文QT版本5.15.2 类关系图 QTemporaryFile继承QFile QFile、QSaveFile继承QFileDevice QFileDevice继承QIODevice Q…

法规标准-ISO 17361标准解读

ISO 17361是做什么的? ISO 17361全称为智能交通系统-车道偏离警告系统性能要求和测试程序,其中主要描述了LDWS系统的功能要求及测试要求 系统功能 车道偏离警告系统的功能元件应符合图中的要求,抑制请求、车速检测、驾驶员偏好和其他附加功…

[CTF/网络安全] 攻防世界 simple_js 解题详析

[CTF/网络安全] 攻防世界 simple_js 解题详析 代码分析代码漏洞姿势String[fromCharCode]总结 题目描述:小宁发现了一个网页,但却一直输不对密码。(Flag格式为 Cyberpeace{xxxxxxxxx} ) 页面源代码: 代码分析 function dechiffre(pass_enc){…

StarRocks 集群模式搭建

一、StarRocks 集群模型搭建 上篇文章对 StarRocks 进行了简单的介绍及使用 Docker 进行了快速体验,本篇文章进行StarRocks 集群模型的搭建,下面是上篇文章的地址: StarRocks 极速全场景 MPP 数据库介绍及使用 部署规划 host主机名角色192.…

求解包含约束的最优化问题:拉格朗日乘子法和KKT条件

文章目录 无约束等式约束不等式约束KKT条件 无约束 之前梯度类算法中介绍的最速下降法、牛顿法和拟牛顿法,可以直接使用的条件之一为:决策变量都是无约束的。 用数学语言描述的话,可以表达为:决策变量为 x ( x 1 , x 2 , ⋅ ⋅…

LeetCode104. 二叉树的最大深度(递归非递归)

写在前面: 题目链接:LeetCode104.二叉树的最大深度 编程语言:C 题目难度:简单 一、题目描述 给定一个二叉树,找出其最大深度。 二叉树的深度为根节点到最远叶子节点的最长路径上的节点数。 说明: 叶子节点是指没有子…

You Only Look Once:Unified,Real-Time Object Detection总结笔记

一、论文思想 1.将一个图像分成S*S个网格(grid cell),如果某个object的中心落在这个网格中,则这个网络就负责预测这个object。 2.每个网格要预测B个bounding box,每个bounding box除了要预测位置之外,还要…

微服务技术(SpringCloud、Docker、RabbitMQ)

目录 一、微服务技术简介 二、服务拆分及远程调用 1.Eureka注册中心 2.Nacos注册中心 3.Nacos配置管理 4.http客户端Feign 三、统一网关Gateway 四、Docker 五、异步通信技术 六、ElasticSearch 一、微服务技术简介 微服务是分布式架构(分布式&#xff…

Lesson14---卷积神经网络

14.1 深度学习基础 14.1.1 深度学习的基本思想 特征工程:尽可能选择和构建出好的特征,使得机器学习算法能够达到最佳性能。是机器学习的上限,而算法就是逼近这个上限传统的机器学习特证工程 依靠人工方式提取和设计特征需要大量的专业知识…

低代码系统前端实践之vue-element-admin运行demo

文章目录 1、简介2、实践功能3、实践过程3.0 下载运行demo3.1.1 解决执行npm install或出现以下报错(删掉组件tui-editor相关即可)3.1.2 解决执行npm run dev或出现no module body-parser(安装body-parser即可)3.1.3 解决执行npm run dev或出现error:0308010C:digital envelope…

RK3568平台开发系列讲解(驱动基础篇)RK平台I2C的使用

🚀返回专栏总目录 文章目录 一、I2C 使用情况二、定义和注册 I2C 设备三、定义和注册 I2C 驱动3.1 I2C 驱动定义3.2 I2C 驱动注册3.3 通过 I2C 收发数据沉淀、分享、成长,让自己和他人都能有所收获!😄 📢本篇将对RK I2C 的使用进行学习。 配置 I2C 可分为两大步骤: 定…

【Linux C】GCC编译 GDB调试 从入门到放弃 (gcc调试选项详解、gdb调试、条件断点、远程调试、脚本化调试)

阅读本文可能需要一些基础,比如:C语言基础、Linux基础操作、vim、防火墙等。篇幅有限,本文讲的“比较浅显”。 通过本文你将学会: gcc编译gdb调试 少年你渴望力量吗👇👇👇 一、使用GCC编译C程序…

Antd 下拉面板的位置计算错误

项目场景: 公司使用无界微前端集成ERP项目应用(可惜没跟着走一边无界,难受),某些子应用使用时,发现antd的弹窗弹出的位置不对。如下图: 问题描述 无界微前端嵌入的子应用中的antd的下拉框位置…

【谷粒商城笔记】基于docker的mysql、redis环境配置

0.系统 宝塔 v7.5.1 Centos v8.2 1. 安装Docker 直接yum install docker会提示找不到 > docker-client-latest \ docker-common \ docker-latest \ docker-latest-logrotate \ docker-logrotate \ docker-engine Loaded plugins: fastestmirror No Match for argument: …

Prometheus如何优化远程读写的性能

Prometheus如何优化远程读写的性能 场景 为了解决prometheus本地存储带来的单点问题,我们一般在高可用监控架构中会使用远程存储,并通过配置prometheus的remote_write和remote_read来对接 远程写优化:remote_write 用户可以在Prometheus配…

码上行动:零基础学会Python编程(文末送书)

🤵‍♂️ 个人主页:艾派森的个人主页 ✍🏻作者简介:Python学习者 🐋 希望大家多多支持,我们一起进步!😄 如果文章对你有帮助的话, 欢迎评论 💬点赞&#x1f4…

Day3 字符串中找出连续最长的数字串、数组中出现次数超过一半的数字

✨个人主页: 北 海 🎉所属专栏: C/C相关题解 🎃操作环境: Visual Studio 2019 版本 16.11.17 文章目录 选择题1、进程管理2、计算机组成原理 编程题1、字符串中找出连续最长的数字串2、数组中出现次数超过一半的数字 选…