深度学习中的并行策略概述:2 Data Parallelism

news2024/12/25 22:45:42

深度学习中的并行策略概述:2 Data Parallelism
在这里插入图片描述
数据并行(Data Parallelism)的核心在于将模型的数据处理过程并行化。具体来说,面对大规模数据批次时,将其拆分为较小的子批次,并在多个计算设备上同时进行处理。每个设备负责处理一个子批次,实现并行计算。处理完成后,将各个设备上的计算结果汇总,以便对模型进行统一更新。由于其在深度学习中的普遍应用,数据并行成为了一种广泛支持的并行计算策略,并在主流框架中得到了良好的实现。

以下代码展示了如何在PyTorch中使用nn.DataParallel和DistributedDataParallel实现数据并行,以加速模型的训练过程。

使用nn.DataParallel实现数据并行

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# 假设我们有一个简单的数据集类
class SimpleDataset(Dataset):
    def __init__(self, data, target):
        self.data = data
        self.target = target

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.target[idx]

# 假设我们有一个简单的神经网络模型
class SimpleModel(nn.Module):
    def __init__(self, input_dim):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(input_dim, 1)

    def forward(self, x):
        return torch.sigmoid(self.fc(x))

# 假设我们有一些数据
n_sample = 100
n_dim = 10
batch_size = 10
X = torch.randn(n_sample, n_dim)
Y = torch.randint(0, 2, (n_sample,)).float()
dataset = SimpleDataset(X, Y)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# 初始化模型
device_ids = [0, 1, 2]  # 指定使用的GPU编号
model = SimpleModel(n_dim).to(device_ids[0])
model = nn.DataParallel(model, device_ids=device_ids)

# 定义优化器和损失函数
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.BCELoss()

# 训练模型
for epoch in range(10):
    for batch_idx, (inputs, targets) in enumerate(data_loader):
        inputs, targets = inputs.to('cuda'), targets.to('cuda')
        outputs = model(inputs)
        loss = criterion(outputs, targets.unsqueeze(1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item()}')

使用DistributedDataParallel实现数据并行

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# 假设我们有一个简单的数据集类
class SimpleDataset(Dataset):
    def __init__(self, data, target):
        self.data = data
        self.target = target

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.target[idx]

# 假设我们有一个简单的神经网络模型
class SimpleModel(nn.Module):
    def __init__(self, input_dim):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(input_dim, 1)

    def forward(self, x):
        return torch.sigmoid(self.fc(x))

# 初始化进程组
def init_process(rank, world_size, backend='nccl'):
    dist.init_process_group(backend, rank=rank, world_size=world_size)

# 训练函数
def train(rank, world_size):
    init_process(rank, world_size)
    torch.cuda.set_device(rank)
    model = SimpleModel(10).to(rank)
    model = DDP(model, device_ids=[rank])

    dataset = SimpleDataset(torch.randn(100, 10), torch.randint(0, 2, (100,)).float())
    sampler = torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    data_loader = DataLoader(dataset, batch_size=10, sampler=sampler)

    optimizer = optim.SGD(model.parameters(), lr=0.01)
    criterion = nn.BCELoss()

    for epoch in range(10):
        for inputs, targets in data_loader:
            inputs, targets = inputs.to(rank), targets.to(rank)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets.unsqueeze(1))
            loss.backward()
            optimizer.step()

if __name__ == "__main__":
    world_size = 4
    torch.multiprocessing.spawn(train, args=(world_size,), nprocs=world_size, join=True)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2265503.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

OneCode:开启高效编程新时代——企业定制出码手册

一、概述 OneCode 的 DSM(领域特定建模)出码模块是一个强大的工具,它支持多种建模方式,并具有强大的模型转换与集成能力,能够提升开发效率和代码质量,同时方便团队协作与知识传承,还具备方便的仿…

《Web 应用项目开发:从构思到上线的全过程》

目录 一、引言 二、项目启动与需求分析 三、设计阶段 四、技术选型 五、开发阶段 六、测试阶段 七、部署与上线 八、维护与更新 九、总结 一、引言 在数字化浪潮席卷全球的当下,Web 应用如繁星般在互联网的苍穹中闪烁,它们形态各异&#xff0c…

中小学教室多媒体电脑安全登录解决方案

中小学教室多媒体电脑面临学生随意登录的问题,主要涉及到设备使用、网络安全、教学秩序等多个方面。以下是对这一问题的详细分析: 一、设备使用问题 1. 设备损坏风险 学生随意登录可能导致多媒体电脑设备过度使用,增加设备损坏的风险。不当…

Odoo 免费开源 ERP:通过 JavaScript 创建对话框窗口的技术实践分享

作者 | 老杨 出品 | 上海开源智造软件有限公司(OSCG) 概述 在本文中,我们将深入研讨如何于 Odoo 18 中构建 JavaScript(JS)对话框或弹出窗口。对话框乃是展现重要讯息、确认用户操作以及警示用户留意警告或错误的行…

OOP面向对象编程:类与类之间的关系

OOP面向对象编程:类与类之间的关系 三大关系:复合(适配器设计模式)、委托(桥接设计模式)、继承 8、1复合Composition has-a -> 适配器模式 一个类里面含有另一个类的对象 —> 复合关系 has-a 适配器设…

集成 jacoco 插件,查看单元测试覆盖率

文章目录 前言集成 jacoco 插件,查看单元测试覆盖率1. 添加pom2. 配置完成、执行扫描3. 执行结果4. 单元测试报告 前言 如果您觉得有用的话,记得给博主点个赞,评论,收藏一键三连啊,写作不易啊^ _ ^。   而且听说点赞…

下载运行Vue开源项目vue-pure-admin

git地址:GitHub - pure-admin/vue-pure-admin: 全面ESMVue3ViteElement-PlusTypeScript编写的一款后台管理系统(兼容移动端) 安装pnpm npm install -g pnpm # 国内 淘宝 镜像源 pnpm config set registry https://registry.npmmirror.com/…

创建用于预测序列的人工智能模型,设计模型架构。

上一篇:《创建用于预测序列的人工智能模型,设计数据集》 序言:在前一篇中,我们创建了用于训练人工智能模型的数据集。接下来,就要设计模型的架构了。其实,人工智能模型的开发关键并不在于代码量&#xff0…

ubuntu22.04安装PPOCRLabel

可使用的模型参考模型列表,ppocr版本这里PPOCR版本作为预训练模型: (经常用放在这里) 基础电脑配置: cunda12.4 ubuntu22.04系统 pytorch2.5.0 (python3.10不能运行,python3.8我之前可以正…

Linux网络——TCP的运用

系列文章目录 文章目录 系列文章目录一、服务端实现1.1 创建套接字socket1.2 指定网络接口并bind2.3 设置监听状态listen2.4 获取新链接accept2.5 接收数据并处理(服务)2.6 整体代码 二、客户端实现2.1 创建套接字socket2.2 指定网络接口2.3 发起链接con…

江苏捷科云:可视化平台助力制造企业智能化管理

公司简介 江苏捷科云信息科技有限公司(以下简称“捷科”)是一家专注于云平台、云储存、云管理等产品领域的创新型企业,集研发、生产和销售于一体,致力于在网络技术领域打造尖端品牌。在推动制造业企业数字化转型的进程中&#xf…

消息队列(一)消息队列的工作流程

什么是消息队列 首先,代入一个场景,我现在做一个多系统的集成,分别有系统A、B、C、D四个系统,A系统因为使用产生了业务数据,B、C、D需要使用这些数据做相关的业务处理和运算,最基本的做法就是通过接口通信…

施耐德变频器ATV320系列技术优势:创新与安全并重

在工业自动化领域,追求高效、安全与智能已成为不可阻挡的趋势。施耐德变频器ATV320系列凭借其强大的设计标准和全球认证,成为能够帮助企业降低安装成本,提高设备性能的创新解决方案。 【全球认证,品质保障】ATV320 系列秉持施耐德…

WEB入门——文件上传漏洞

文件上传漏洞 一、文件上传漏洞 1.1常见的WebShell有哪些?1.2 一句话木马演示1.2 文件上传漏洞可以利用需满足三个条件1.3 文件上传导致的危害 二、常用工具 2.1 搭建upload-labs环境2.2 工具准备 三、文件上传绕过 3.1 客户端绕过 3.1.1 实战练习 :upl…

Android 蓝牙开发-传输数据

概述 传统蓝牙是通过建立REFCCOM sockect来进行通信的,类似于socket通信,一台设备需要开放服务器套接字并处于listen状态,而另一台设备使用服务器的MAC地址发起连接。连接建立后,服务器和客户端就都通过对BluetoothSocket进行读写…

红米Note 9 Pro5G刷小米官方系统

前言 刷机有2种方式:线刷 和 卡刷。 线刷 线刷:需要用电脑刷机工具,例如:XiaoMiFlash.exe,通过电脑和数据线对设备进行刷机。 适用场景: 系统损坏无法开机。恢复官方出厂固件。刷机失败导致软砖、硬砖的…

html + css 淘宝网实战

之前有小伙伴说,淘宝那么牛逼你会写代码,能帮我做一个一样的淘宝网站吗,好呀,看我接下来如何给你做一个淘宝首页。hahh,开个玩笑。。。学习而已。 在进行html css编写之前 先了解下网页的组成和网页元素的尺寸吧 1.网页的组成 …

SOME/IP 协议详解——信息格式

文章目录 1. 头部格式1.1 消息 ID(Message ID)1.2 长度(Length)1.3 请求 ID(Request ID)1.4 协议版本(Protocol Version):1.5 接口版本(Interface Version&am…

使用QML实现播放器进度条效果

使用QML实现播放进度效果 QML Slider介绍 直接上DEMO如下: Slider {width: 300;height: 20;orientation: Qt.Vertical; //决定slider是横还是竖 默认是HorizontalstepSize: 0.1;value: 0.2;tickmarksEnabled: true; //显示刻度}效果图如下 那么我先改变滑块跟滚轮…

Android——自定义按钮button

项目中经常高频使用按钮,要求:可设置颜色,有圆角且有按下效果的Button 一、自定义按钮button button的代码为 package com.fslihua.clickeffectimport android.annotation.SuppressLint import android.content.Context import android.gra…