PyTorch分布式数据加载学习 DistributedSampler

news2024/10/4 20:34:40

[源码解析] PyTorch 分布式(1) --- 数据加载之DistributedSampler - 罗西的思考 - 博客园

初始化

class DistributedSampler(Sampler[T_co]):
    def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None,
                 rank: Optional[int] = None, shuffle: bool = True,
                 seed: int = 0, drop_last: bool = False) -> None:
        # If the dataset length is evenly divisible by # of replicas, then there
        # is no need to drop any data, since the dataset will be split equally.
        if self.drop_last and len(self.dataset) % self.num_replicas != 0:
            # Split to nearest available length that is evenly divisible.
            # This is to ensure each rank receives the same amount of data when
            # using this Sampler.
            self.num_samples = math.ceil(
                (len(self.dataset) - self.num_replicas) / self.num_replicas
            )
        else:
            self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # 向上取整
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle
        self.seed = seed

如果不drop_last,那就

self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)

向上取整。

迭代

在迭代的时候,如果不能整除,那就把indices的前几个样本复制一遍:

    def __iter__(self) -> Iterator[T_co]:
        if self.shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))

        if not self.drop_last:  # 一般进入这里,不会丢掉剩下的训练数据
            # add extra samples to make it evenly divisible
            padding_size = self.total_size - len(indices)
            if padding_size <= len(indices):
                indices += indices[:padding_size]   # 把indices的前几个复制一次
            else:
                indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
        else:
            # remove tail of data to make it evenly divisible.
            indices = indices[:self.total_size]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

最关键的还是这行:

indices = indices[self.rank:self.total_size:self.num_replicas]

规定了每个rank的取数据的索引,起始索引是rank,每间隔num_replicas取一个

shuffle数据集

每次epoch都会shuffle数据集,但是不同进程如何保持shuffle之后数据集一致性

DistributedSampler 使用当前的epoch作为随机数种子,在计算index之前就进行配置,从而保证不同进程都使用同样的随机数种子,这样shuffle出来的数据就能确保一致。

sampler = DistributedSampler(dataset) if is_distributed else None
loader = DataLoader(dataset, shuffle=(sampler is None), ...,
                            sampler=sampler)
	for epoch in range(start_epoch, n_epochs):
    	if is_distributed:
        	sampler.set_epoch(epoch) # 这设置epoch
        train(loader)

设置 random 种子的具体使用是在迭代函数之中:

    def __iter__(self) -> Iterator[T_co]:
        if self.shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch) # 这里设置随机种子
            indices = torch.randperm(len(self.dataset), generator=g).tolist()  # type: ignore[arg-type]
        else:
            indices = list(range(len(self.dataset)))  # type: ignore[arg-type]

在 PyTorch 中,torch.randperm(n) 函数用于生成一个从 0n-1 的随机排列的整数序列。这个函数是非常有用的,尤其是在需要随机打乱数据或索引时,比如在训练机器学习模型时打乱数据顺序,以确保模型训练的泛化能力。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2188652.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

HTML+CSS基础用法介绍五

目录&#xff1a; 结构伪类选择器盒子模型-边框线盒子模型-内边距盒子模型-解决盒子被撑大盒子模型-外边距与版心居中小知识&#xff1a;清除浏览器中所有标签的默认样式内容溢出控制显示方式盒子模型-圆角 &#x1f40e;正片开始 结构伪类选择器 什么是结构伪类选择器&…

18.安卓逆向-frida基础-调试实战2

免责声明&#xff1a;内容仅供学习参考&#xff0c;请合法利用知识&#xff0c;禁止进行违法犯罪活动&#xff01; 内容参考于&#xff1a;图灵Python学院 本人写的内容纯属胡编乱造&#xff0c;全都是合成造假&#xff0c;仅仅只是为了娱乐&#xff0c;请不要盲目相信。 工…

Windows UAC权限详解以及因为权限不对等引发软件工具无法正常使用的实例分析

目录 ​1、什么是UAC&#xff1f; 2、微软为什么要设计UAC&#xff1f; 3、标准用户权限与管理员权限 4、程序到底以哪种权限运行&#xff1f;与哪些因素有关&#xff1f; 4.1、给程序设置以管理员权限运行的属性 4.2、当前登录用户的类型 4.3、如何通过代码判断某个进程…

智能 AI 写作软件:开启创作新纪元

不论你在哪行哪业应该都躲不开写作这件事被。写作已经成为了我们生活和工作中不可或缺的一部分。随着人工智能技术的飞速发展&#xff0c;AI 智能写作工具应运而生。接下来&#xff0c;让我们一起揭开智能ai写作工具的神秘面纱。 1.笔灵AI写作 直通车&#xff1a;https://ibi…

②EtherCAT转ModbusTCP, EtherCAT/Ethernet/IP/Profinet/ModbusTCP协议互转工业串口网关

EtherCAT/Ethernet/IP/Profinet/ModbusTCP协议互转工业串口网关https://item.taobao.com/item.htm?ftt&id822721028899 协议转换通信网关 EtherCAT 转 Modbus TCP (接上一章&#xff09; GW系列型号 配置说明 上载 网线连接电脑到模块上的 WEB 网页设置网口&#…

论文笔记:Online Class-Incremental Continual Learning with Adversarial Shapley Value

这篇工作的focus 是 memory-based approach 1. 挑战/问题&#xff1a; 灾难性遗忘&#xff1a;深度神经网络在学习新任务时往往会忘记先前任务的知识。内存和计算效率&#xff1a;在个人设备上执行深度学习任务时&#xff0c;需要最小化内存占用和计算成本。数据流增量学习&am…

系统安全 - 大数据组件的安全及防护

文章目录 导图1. Hadoop的安全风险2. 常见攻击方式3. Hadoop的自带安全功能4. Apache Knox和Apache Ranger等安全框架5. 安全策略建议 导图 1. Hadoop的安全风险 Hadoop最初设计为在可信网络中运行&#xff0c;因此默认安全性较低。常见的安全风险包括&#xff1a; 未经授权的…

探索未来:揭秘pymqtt,AI与物联网的新桥梁

文章目录 探索未来&#xff1a;揭秘pymqtt&#xff0c;AI与物联网的新桥梁背景&#xff1a;为什么选择pymqtt&#xff1f;什么是pymqtt&#xff1f;如何安装pymqtt&#xff1f;简单的库函数使用方法1. 配置MQTT连接2. 创建Mqtt对象3. 发布消息4. 订阅主题5. 运行MQTT客户端 场景…

Java项目实战II基于Java+Spring Boot+MySQL的小徐影城管理系统设计与实现(源码+数据库+文档)

目录 一、前言 二、技术介绍 三、系统实现 四、文档参考 五、核心代码 六、源码获取 全栈码农以及毕业设计实战开发&#xff0c;CSDN平台Java领域新星创作者 一、前言 随着文化娱乐产业的快速发展&#xff0c;影城管理面临着日益复杂的挑战&#xff0c;包括票务管理、座…

Redis操作常用API

说明&#xff1a;Redis应用于java项目中&#xff0c;操作Redis数据可以使用API&#xff0c;相较于命令行更方便。使用前&#xff0c;需先添加依赖。 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-data-re…

HIKVISION 海康威视对讲服务配置平台弱口令

漏洞描述 杭州海康威视系统技术有限公司对讲服务配置平台存在弱口令 漏洞复现 FOFA "document.write(TITLE_SYSTEM);" POC admin #账号 12345 #密码 登录成功

利用Spring Boot打造新闻推荐解决方案

1系统概述 1.1 研究背景 如今互联网高速发展&#xff0c;网络遍布全球&#xff0c;通过互联网发布的消息能快而方便的传播到世界每个角落&#xff0c;并且互联网上能传播的信息也很广&#xff0c;比如文字、图片、声音、视频等。从而&#xff0c;这种种好处使得互联网成了信息传…

Kotlin基本知识

Kotlin是一种现代的静态类型编程语言&#xff0c;由JetBrains公司在2010年推出&#xff0c;并被Google在2019年宣布为Android开发的首选语言。 超过 50% 的专业 Android 开发者使用 Kotlin 作为主要语言&#xff0c;而只有 30% 使用 Java 作为主要语言。 70% 以 Kotlin 为主要语…

Redis数据库与GO(二):list,set

一、list&#xff08;列表&#xff09; list&#xff08;列表&#xff09;是简单的字符串列表&#xff0c;按照插入顺序排序。你可以添加一个元素到列表的头部(左边)或者尾部(右边)。List本质是个链表&#xff0c; list是一个双向链表&#xff0c;其元素是有序的&#xff0c;元…

【含文档】基于Springboot+Vue的护肤品推荐系统(含源码+数据库+lw)

1.开发环境 开发系统:Windows10/11 架构模式:MVC/前后端分离 JDK版本: Java JDK1.8 开发工具:IDEA 数据库版本: mysql5.7或8.0 数据库可视化工具: navicat 服务器: SpringBoot自带 apache tomcat 主要技术: Java,Springboot,mybatis,mysql,vue 2.视频演示地址 3.功能 系统定…

ctfshow-web入门(信息收集,持续更新中。。)

写在之前:近期打了个比赛,备受打击,入手了vip账号进修,加油! 文章目录 ctfshow-web1查看源代码ctfshow-web2burp抓包ctfshow-web3burp抓包ctfshow-web4访问robots.txtctfshow-web5dirscarch扫描PHPS文件泄露ctfshow-web6dirscarch扫描ctfshow-web7dirscarch扫描ctfshow-w…

力扣 简单 101.对称二叉树

文章目录 题目介绍解题思路 题目介绍 解题思路 在上题【100. 相同的树】的基础上稍加改动,将根节点的左右子树看成左右两个树 递归判断左边的右子树和右边的左子树以及左边的左子树和右边的右子树是否都相同 class Solution {public boolean isSymmetric(TreeNode root) {re…

1.2.2 计算机网络的分层结构(下)

水平视角 YSCS协议&#xff08;压缩传输协议&#xff09; 发送方先压缩然后接收方再解压。 为什么要分层&#xff1f;为什么要制定协议&#xff1f; 计算机网路功能负责->采用分层结构&#xff0c;将诸多功能合理地划分在不同层次->对等层之间制定协议&#xff0c;以…

10.4今日错题解析(软考)

目录 前言系统开发基础——概要设计与详细设计系统开发基础——开发模型 前言 这是用来记录我备考软考设计师的错题的&#xff0c;今天知识点为概要设计与详细设计、开发模型&#xff0c;大部分错题摘自希赛中的题目&#xff0c;但相关解析是原创&#xff0c;有自己的思考&…

SpringBoot3+Vue3开发后台管理系统脚手架

后台管理系统脚手架 介绍 在快速迭代的软件开发世界里&#xff0c;时间就是生产力&#xff0c;效率决定成败。对于构建复杂而庞大的后台系统而言&#xff0c;一个高效、可定制的后台脚手架&#xff08;Backend Scaffold&#xff09;无疑是开发者的得力助手。 脚手架 后台脚…