Image和Video在同一个Dataloader中交错加载联合训练

news2025/1/7 19:13:26

单卡实现

本文主要从两个方面进行展开:
1.将两个或多个dataset组合成pytorch中的一个ConcatDataset.这个dataset将会作为pytorch中Dataloader的输入。
2.覆盖重写RandomSampler修改batch产生过程,以确保在第一个batch中产生第一个任务的数据(image),在第二个batch中产生下一个任务的数据(video)。

下述定义了一个BatchSchedulerSampler类,实现了一个新的sampler iterator。首先,通过为每一个单独的dataset创建RandomSampler;接着,在每一个dataset iter中获取对应的sample index;最后,创建新的sample index list

import math
import torch
from torch.utils.data.sampler import RandomSampler
from torch.utils.data.dataset import ConcatDataset
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from dataset.K600 import *


class BatchSchedulerSampler(torch.utils.data.sampler.Sampler):
    """
    iterate over tasks and provide a random batch per task in each mini-batch
    """
    def __init__(self, dataset, batch_size):
        self.dataset = dataset
        self.batch_size = batch_size
        self.number_of_datasets = len(dataset.datasets)
        self.largest_dataset_size = max([len(cur_dataset.samples) for cur_dataset in dataset.datasets])

    def __len__(self):
        return self.batch_size * math.ceil(self.largest_dataset_size / self.batch_size) * len(self.dataset.datasets)

    def __iter__(self):
        samplers_list = []
        sampler_iterators = []
        for dataset_idx in range(self.number_of_datasets):
            cur_dataset = self.dataset.datasets[dataset_idx]
            sampler = RandomSampler(cur_dataset)
            samplers_list.append(sampler)
            cur_sampler_iterator = sampler.__iter__()
            sampler_iterators.append(cur_sampler_iterator)

        push_index_val = [0] + self.dataset.cumulative_sizes[:-1]
        step = self.batch_size * self.number_of_datasets
        samples_to_grab = self.batch_size
        # for this case we want to get all samples in dataset, this force us to resample from the smaller datasets
        epoch_samples = self.largest_dataset_size * self.number_of_datasets

        final_samples_list = []  # this is a list of indexes from the combined dataset
        for _ in range(0, epoch_samples, step):
            for i in range(self.number_of_datasets):
                cur_batch_sampler = sampler_iterators[i]
                cur_samples = []
                for _ in range(samples_to_grab):
                    try:
                        cur_sample_org = cur_batch_sampler.__next__()
                        cur_sample = cur_sample_org + push_index_val[i]
                        cur_samples.append(cur_sample)
                    except StopIteration:
                        # got to the end of iterator - restart the iterator and continue to get samples
                        # until reaching "epoch_samples"
                        sampler_iterators[i] = samplers_list[i].__iter__()
                        cur_batch_sampler = sampler_iterators[i]
                        cur_sample_org = cur_batch_sampler.__next__()
                        cur_sample = cur_sample_org + push_index_val[i]
                        cur_samples.append(cur_sample)
                final_samples_list.extend(cur_samples)

        return iter(final_samples_list)


if __name__ == "__main__":
    image_dataset = ImageFolder(root='/mnt/workspace/data/imagenet/data/newtrain', transform=transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]))
    video_dataset = VideoFolder(root='/mnt/workspace/data/k600/train_videos', transform=transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]))
    joint_dataset = ConcatDataset([image_dataset, video_dataset])

    batch_size = 8
    dataloader = torch.utils.data.DataLoader(dataset=joint_dataset,
                sampler=BatchSchedulerSampler(dataset=joint_dataset, batch_size=batch_size),
                batch_size=batch_size, shuffle=False)

    num_epochs = 1
    for epoch in range(num_epochs):
        for inputs, labels in dataloader:
            print(inputs.shape)
'''
torch.Size([8, 3, 224, 224])
torch.Size([8, 3, 16, 224, 224])
torch.Size([8, 3, 224, 224])
torch.Size([8, 3, 16, 224, 224])
'''

DDP多卡实现

在多卡训练中使用分布式数据并行(DDP)时,你需要重写 DistributedSampler 而不是 RandomSampler,以确保每个进程都能正确地获取数据子集。以下是如何实现一个 BatchSchedulerDistributedSampler 类来支持多卡训练的示例

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2272123.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

n8n - AI自动化工作流

文章目录 一、关于 n8n关键能力n8n 是什么意思 二、快速上手 一、关于 n8n n8n是一个具有原生AI功能的工作流自动化平台,它为技术团队提供了代码的灵活性和无代码的速度。凭借400多种集成、原生人工智能功能和公平代码许可证,n8n可让您构建强大的自动化…

cursor 使用技巧

一、创建项目前期步骤 1.先给AI设定一个对应项目经理角色, 2.然后跟AI沟通项目功能,生成功能设计文件:README.md README.md项目功能 3.再让AI总结写出开发项目规则文件: .cursorrules 是技术栈进行限定,比如使用什…

xinput1_3.dll丢失修复方法。方法1-方法6

总结 xinput1_3.dll的核心作用 xinput1_3.dll作为Microsoft DirectX库的关键组件,对于游戏控制器的支持起着至关重要的作用。它不仅提供了设备兼容性、多控制器管理和反馈机制等核心功能,还通过XInput API简化了开发人员对控制器状态的检索和设备特性的…

【C++】P2550 [AHOI2001] 彩票摇奖

博客主页: [小ᶻ☡꙳ᵃⁱᵍᶜ꙳] 本文专栏: C 文章目录 💯前言💯题目描述输入格式:输出格式:输入输出样例: 💯题解思路1. 问题解析 💯我的实现实现逻辑问题分析 💯老…

01:C语言的本质

C语言的本质 1、ARM架构与汇编2、局部变量初始化与空间分配2.1、局部变量的初始化2.1、局部变量数组初始化 3、全局变量/静态变量初始化化与空间分配4、堆空间 1、ARM架构与汇编 ARM简要架构如下:CPU,ARM(能读能写),Flash(能读&a…

8086汇编(16位汇编)学习笔记10.寄存器总结

8086汇编(16位汇编)学习笔记10.寄存器总结-C/C基础-断点社区-专业的老牌游戏安全技术交流社区 - BpSend.net 寄存器 8086CPU有14个寄存器 它们的名称为: AX、BX、CX、DX、SI、DI、SP、BP、 IP**、CS、DS、ES、**SS、PSW。 8086CPU所有的寄存器都是16位的&#…

如何在 Ubuntu 22.04 上安装 Cassandra NoSQL 数据库教程

简介 本教程将向你介绍如何在 Ubuntu 22.04 上安装 Cassandra NoSQL 数据库。 Apache Cassandra 是一个分布式的 NoSQL 数据库,旨在处理跨多个普通服务器的大量数据,并提供高可用性,没有单点故障。Apache Cassandra 是一个高度可扩展的分布…

uni-app:实现普通选择器,时间选择器,日期选择器,多列选择器

效果 选择前效果 1、时间选择器 2、日期选择器 3、普通选择器 4、多列选择器 选择后效果 代码 <template><!-- 时间选择器 --><view class"line"><view classitem1><view classleft>时间</view><view class"right&quo…

centos,789使用mamba快速安装R及语言包devtools

如何进入R语言运行环境请参考&#xff1a;Centos7_miniconda_devtools安装_R语言入门之R包的安装_r语言devtools包怎么安装-CSDN博客 在R里面使用安装devtools经常遇到依赖问题&#xff0c;排除过程过于费时&#xff0c;使用conda安装包等待时间长等。下面演示centos,789都是一…

STM32第十一课:STM32-基于标准库的42步进电机的简单IO控制(附电机教程,看到即赚到)

一&#xff1a;步进电机简介 步进电机又称为脉冲电机&#xff0c;简而言之&#xff0c;就是一步一步前进的电机。基于最基本的电磁铁原理,它是一种可以自由回转的电磁铁,其动作原理是依靠气隙磁导的变化来产生电磁转矩&#xff0c;步进电机的角位移量与输入的脉冲个数严格成正…

kafka使用以及基于zookeeper集群搭建集群环境

一、环境介绍 zookeeper下载地址&#xff1a;https://zookeeper.apache.org/releases.html kafka下载地址&#xff1a;https://kafka.apache.org/downloads 192.168.142.129 apache-zookeeper-3.8.4-bin.tar.gz kafka_2.13-3.6.0.tgz 192.168.142.130 apache-zookee…

JSON结构快捷转XML结构API集成指南

JSON结构快捷转XML结构API集成指南 引言 在当今的软件开发世界中&#xff0c;数据交换格式的选择对于系统的互操作性和效率至关重要。JSON&#xff08;JavaScript Object Notation&#xff09;和XML&#xff08;eXtensible Markup Language&#xff09;是两种广泛使用的数据表…

Android14 CTS-R6和GTS-12-R2不能同时测试的解决方法

背景 Android14 CTS r6和GTS 12-r1之后&#xff0c;tf-console默认会带起OLC Server&#xff0c;看起来olc server可能是想适配ATS(android-test-station)&#xff0c;一种网页版可视化、可配置的跑XTS的方式。这种网页版ATS对测试人员是比较友好的&#xff0c;网页上简单配置下…

BurpSuite工具安装

BurpSuite介绍&#xff1a; BurpSuite是由PortSwigger开发的一款集成化的Web应用安全检测工具&#xff0c;广泛应用于Web应用的漏洞扫描和攻击模拟&#xff0c;主要用于抓包该包(消息拦截与构造) 一、Burp suite安装 windows系统需要提前配置好java环境&#xff0c;前面博客…

Win11+WLS Ubuntu 鸿蒙开发环境搭建(一)

参考文章 Windows11安装linux子系统 WSL子系统迁移、备份与导入全攻略 如何扩展 WSL 2 虚拟硬盘的大小 Win10安装的WSL子系统占用磁盘空间过大如何释放 《Ubuntu — 调整文件系统大小命令resize2fs》 penHarmony南向开发笔记&#xff08;一&#xff09;开发环境搭建 一&a…

flink cdc oceanbase(binlog模式)

接上文&#xff1a;一文说清flink从编码到部署上线 环境&#xff1a;①操作系统&#xff1a;阿里龙蜥 7.9&#xff08;平替CentOS7.9&#xff09;&#xff1b;②CPU&#xff1a;x86&#xff1b;③用户&#xff1a;root。 预研初衷&#xff1a;现在很多项目有国产化的要求&#…

和为0的四元组-蛮力枚举(C语言实现)

目录 一、问题描述 二、蛮力枚举思路 1.初始化&#xff1a; 2.遍历所有可能的四元组&#xff1a; 3.检查和&#xff1a; 4.避免重复&#xff1a; 5.更新计数器&#xff1a; 三、代码实现 四、运行结果 五、 算法复杂度分析 一、问题描述 给定一个整数数组 nums&…

某xx到家app逆向

去官网下载app即可 https://www.dongjiaotn.com/#/home查壳 360的壳子 直接脱壳即可 抓包 请求地址 https://api.gateway.znjztfn.cn/server/user/index 请求参数 {"lng": "xxxx","lat": "xxxx","city_id": "1376&…

docker搭建gitlab和jenkins

搭建gitlab 搭建gitlab首先需要一个gitlab的镜像 其次最好为他设置一个单独的目录 然后编写一个docker-compose文件 version: 3.1 services:gitlab:image: gitlab_zh:latest //此处为你的镜像名称container_name: gitlab //容器名称restart: always …

嵌入式linux中socket控制与实现

一、概述 1、首先网络,一看到这个词,我们就会想到IP地址和端口号,那IP地址和端口各有什么作用呢? (1)IP地址如身份证一样,是标识的电脑的,一台电脑只有一个IP地址。 (2)端口提供了一种访问通道,服务器一般都是通过知名端口号来识别某个服务。例如,对于每个TCP/IP实…