Pytorch指定数据加载器使用子进程

news2024/9/20 1:00:25
torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True,num_workers=4, pin_memory=True)

num_workers 参数是 DataLoader 类的一个参数,它指定了数据加载器使用的子进程数量。通过增加 num_workers 的数量,可以并行地读取和预处理数据,从而提高数据加载的速度。

通常情况下,增加 num_workers 的数量可以提高数据加载的效率,因为它可以使数据加载和预处理工作在多个进程中同时进行。然而,当 num_workers 的数量超过一定阈值时,增加更多的进程可能不会再带来更多的性能提升,甚至可能会导致性能下降。

这是因为增加 num_workers 的数量也会增加进程间通信的开销。当 num_workers 的数量过多时,进程间通信的开销可能会超过并行化所带来的收益,从而导致性能下降。

此外,还需要考虑到计算机硬件的限制。如果你的计算机 CPU 核心数量有限,增加 num_workers 的数量也可能会导致性能下降,因为每个进程需要占用 CPU 核心资源。

因此,对于 num_workers 参数的设置,需要根据具体情况进行调整和优化。通常情况下,一个合理的 num_workers 值应该在 2 到 8 之间,具体取决于你的计算机硬件配置和数据集大小等因素。在实际应用中,可以通过尝试不同的 num_workers 值来找到最优的配置。

综上所述,当 num_workers 的值从 4 增加到 8 时,如果你的计算机硬件配置和数据集大小等因素没有发生变化,那么两者之间的性能差异可能会很小,或者甚至没有显著差异。

测试代码如下

import torch
import torchvision
import matplotlib.pyplot as plt
import torchvision.models as models
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
import time

if __name__ == '__main__':
    mp.freeze_support()
    train_on_gpu = torch.cuda.is_available()
    if not train_on_gpu:
        print('CUDA is not available. Training on CPU...')
    else:
        print('CUDA is available! Training on GPU...')

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    batch_size = 4
    # 设置数据预处理的转换
    transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize((512,512)),  # 调整图像大小为 224x224
        torchvision.transforms.ToTensor(),  # 转换为张量
        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # 归一化
    ])
    dataset = torchvision.datasets.ImageFolder('C:\\Users\\ASUS\\PycharmProjects\\pythonProject1\\cats_and_dogs_train',
                                                     transform=transform)


    val_ratio = 0.2
    val_size = int(len(dataset) * val_ratio)
    train_size = len(dataset) - val_size
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

    train_dataset = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True,num_workers=4, pin_memory=True)
    val_dataset = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True,num_workers=4, pin_memory=True)

    model = models.resnet18()

    num_classes = 2
    for param in model.parameters():
        param.requires_grad = False

    model.fc = nn.Sequential(
        nn.Dropout(),
        nn.Linear(model.fc.in_features, num_classes),
        nn.LogSoftmax(dim=1)
    )
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss().to(device)
    model.to(device)

    filename = "recognize_cats_and_dogs.pt"

    def save_checkpoint(epoch, model, optimizer, filename):
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
        }
        torch.save(checkpoint, filename)

    num_epochs = 3
    train_loss = []
    for epoch in range(num_epochs):
        running_loss = 0
        correct = 0
        total = 0
        epoch_start_time = time.time()
        for i, (inputs, labels) in enumerate(train_dataset):
            # 将数据放到设备上
            inputs, labels = inputs.to(device), labels.to(device)
            # 前向计算
            outputs = model(inputs)
            # 计算损失和梯度
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            # 更新模型参数
            optimizer.step()
            # 记录损失和准确率
            running_loss += loss.item()
            train_loss.append(loss.item())
            _, predicted = torch.max(outputs.data, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
        accuracy_train = 100 * correct / total
        # 在测试集上计算准确率
        with torch.no_grad():
            running_loss_test = 0
            correct_test = 0
            total_test = 0
            for inputs, labels in val_dataset:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                running_loss_test += loss.item()

                _, predicted = torch.max(outputs.data, 1)
                correct_test += (predicted == labels).sum().item()
                total_test += labels.size(0)
            accuracy_test = 100 * correct_test / total_test
            # 输出每个 epoch 的损失和准确率
        epoch_end_time = time.time()
        epoch_time = epoch_end_time - epoch_start_time
        print("Epoch [{}/{}], Time: {:.4f}s, Loss: {:.4f}, Train Accuracy: {:.2f}%, Loss: {:.4f}, Test Accuracy: {:.2f}%"
              .format(epoch + 1, num_epochs,epoch_time,running_loss / len(val_dataset),
                      accuracy_train, running_loss_test / len(val_dataset), accuracy_test))
        save_checkpoint(epoch, model, optimizer, filename)

    plt.plot(train_loss, label='Train Loss')
    # 添加图例和标签
    plt.legend()
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Training Loss')

    # 显示图形
    plt.show()

不同num_workers的结果如下

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1133019.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

多伦多 Pwn2Own 大赛首日战报!三星 Galaxy S23 被黑两次

Bleeping Computer 网站披露,加拿大多伦多举行的 Pwn2Own 2023 黑客大赛的第一天,网络安全研究人员就成功两次攻破三星 Galaxy S23。 大会现场,研究人员还“演示"了针对小米 13 Pro 智能手机、打印机、智能扬声器、网络附加存储 (NAS) 设…

Ubuntu卸载或重置防火墙规则

Ubuntu卸载或重置防火墙规则 1、开启防火墙后查看对应规则编号,然后进行删除 sudo ufw status numbered ——查看所有规则编号id sudo ufw delete 2 ——删除对应id的规则(比如删除2号规则) 2、按规则来删除。 例如,如果你使用s…

Meetup 回顾|Data Infra 研究社第十六期(含资料发布)

本文整理于上周六(10月21日)Data Infra 第 16 期的活动内容。本次活动由 Databend 研发工程师-王旭东为大家带来了一场主题为《Databend hash join spill 设计与实现》的分享,让我们一起回顾一下吧~ 以下是本次活动的相关视频、资料及文字&a…

【算法】模拟退火算法(SAA,Simulated Annealing Algorithm)

模拟退火算法(SAA)简介 模拟退火算法(SAA,Simulated Annealing Algorithm)的灵感来源于工艺铸造流程中的退火处理,随着铸造温度升高,分子运动趋于无序,徐徐冷却后,分子运…

【数据分享】2014-2022年我国淘宝村点位数据(Excel格式/Shp格式)

电子商务是过去一二十年我国发展最快的行业,其中又以淘宝为代表,淘宝的发展壮大带动了一大批服务淘宝电子商务的村庄,这些村庄被称为淘宝村! 截至到目前,阿里研究院梳理并公布了2014-2022年共9个年份的淘宝村名单&…

2.AUTOSAR SWC设计概述

1.SWC概述 SWC,全称Software Components,运行在RTE之上,属于应用算法逻辑这一层,如下图: 由1.AUTOSAR的架构及方法论中我们了解到该框架的提出就是为了减少平台移植成本、加快研发效率;这也就是说在AUTOSAR框架下,SWC作为组件是需要被重用的,意味着一个…

数据预处理(超详细)

import pandas as pd import numpy as np【例5-1】使用read_csv函数读取CSV文件。 df1 pd.read_csv("sunspots.csv")#读取CSV文件到DataFrame中 print(df1.sample(5))df2 pd.read_table("sunspots.csv",sep ",")#使用read_table,…

人工智能基础_机器学习003_有监督机器学习_sklearn中线性方程和正规方程的计算_使用sklearn解算八元一次方程---人工智能工作笔记0042

然后我们再来看看,如何使用sklearn,来进行正规方程的运算,当然这里 首先要安装sklearn,这里如何安装sklearn就不说了,自己查一下 首先我们还是来计算前面的八元一次方程的解,但是这次我们不用np.linalg.solve这个 解线性方程的方式,也不用 直接 解正规方程的方式: 也就是上面…

接口自动化测试实践

接口自动化概述 Python接口自动化测试零基础入门到精通(2023最新版) 众所周知,接口自动化测试有着如下特点: 低投入,高产出。 比较容易实现自动化。 和UI自动化测试相比更加稳定。 如何做好一个接口自动化测试项目呢…

华媒舍:怎样利用KOL出文营销推广打造出超级影响力?

利用KOL(Key Opinion Leader)出文营销推广已成为很多个人和企业提高影响力的重要方法。根据恰当的思路与技巧,你也能轻松吸引大批粉丝并打造无敌的存在影响力。下面我们就以科谱的形式详细介绍怎样利用KOL 出文营销推广,帮助自己做…

SD-WAN让跨境网络访问更快、更安全!

目前许多外贸企业都面临着跨境网络不稳定、不安全的问题,给业务合作带来了很多困扰。但是,现在有一个解决方案能够帮助您解决这些问题,让您的跨境网络访问更快、更安全,那就是SD-WAN! 首先,让我们来看看SD-…

微机原理:逻辑运算指令、移位指令

文章目录 一、逻辑运算指令1、取反运算指令2、与运算指令3、或运算指令4、异或运算 二、移位指令1、开环移位指令算术左移:SHL、SAL算术右移:SAR逻辑右移:SHR 2、闭环移位指令含进位的循环左移:RCL含进位的循环右移:RC…

lunar-1.5.jar

公历农历转换包 https://mvnrepository.com/artifact/com.github.heqiao2010/lunar <!-- https://mvnrepository.com/artifact/com.github.heqiao2010/lunar --> <dependency> <groupId>com.github.heqiao2010</groupId> <artifactId>l…

使用文件附件

文件附件在peoplesoft中非常常见 This chapter provides an overview of the file attachment functions and discusses: Developing applications that use file attachment functions. Application development considerations. Application deployment and system configu…

基于 Appium 的 Android UI 自动化测试!

自动化测试是研发人员进行质量保障的重要一环&#xff0c;良好的自动化测试机制能够让开发者及早发现编码中的逻辑缺陷&#xff0c;将风险前置。日常研发中&#xff0c;由于快速迭代的原因&#xff0c;我们经常需要在各个业务线上进行主流程回归测试&#xff0c;目前这种测试大…

Kafka入门04——原理分析

目录 01理解Topic和Partition Topic(主题) Partition(分区) 02理解消息分发 消息发送到分区 消费者订阅和消费指定分区 总结 03再均衡(rebalance) 再均衡的触发 分区分配策略 RangeAssignor(范围分区) RoundRobinAssignor(轮询分区) StickyAssignor(粘性分区) Re…

软件测试面试1000问(含文档)

前前后后面试了有20多家的公司吧&#xff0c;最近抽空把当时的录音整理了下&#xff0c;然后给大家分享下 开头都是差不多&#xff0c;就让做一个自我介绍&#xff0c;这个不用再给大家普及了吧 同时&#xff0c;我也准备了一份软件测试视频教程&#xff08;含接口、自动化、…

进阶课4——随机森林

1.定义 随机森林是一种集成学习方法&#xff0c;它利用多棵树对样本进行训练并预测。 随机森林指的是利用多棵树对样本进行训练并预测的一种分类器&#xff0c;每棵树都由随机选择的一部分特征进行训练和构建。通过多棵树的集成&#xff0c;可以增加模型的多样性和泛化能力。…

MTK AEE_EXP调试方法及user版本打开方案

一、AEE介绍 AEE (Android Exception Engine)是安卓的一个异常捕获和调试信息生成机制。 手机发生错误(异常重启/卡死)时生成db文件(一种被加密过的二进制文件)用来保存和记录异常发生时候的全部内存信息,经过调试和仿真这些信息,能够追踪到异常的缘由。 二、调试方法…

深度学习_6_实战_点集最优直线解_代码解析

问题描述&#xff1a; 上述题目的意思为&#xff0c;人工造出一些数据点&#xff0c;对我们的模型y Xw b ∈进行训练&#xff0c;其中标准模型如下&#xff1a; 其中W和X都为张量&#xff0c;我们训练的模型越接近题目给出的标准模型越好 训练过程如下&#xff1a; 人造数…