PyTorch读取数据集全部进内存,使网络的训练速度提高10倍以上!!!

news2025/1/23 21:21:54
  • 正常情况下,torch读取数据的时候是Batch Size小批量读取。
  • 首先找到所有数据集的路径保持到一个变量中,之后需要读取哪个数据的时候,就根据这个变量中的路径索引去读取。因为硬件的限制,从硬盘中读取数据到显存中所花的时间要远远大于从内存中读取数据到显存中。因此,如果程序直接是从硬盘上读取数据到显存中,就会非常耗时。
  • 我们可以把对应的图像直接全部都读取到内存中,然后将其转换为CPU类型的Tensor Dataloader,之后迭代的时候,从这个Tensor中迭代,那么网络在读取数据的时候,就是直接从内存中进行读取了,那么提高的速度就不是一点半点儿了,MNIST训练一个epoch的时间可以由原来的50s减少到5s,速度可以说是提高了10倍。

核心代码

## train_set = torchvision.datasets.MNIST(DOWNLOAD_PATH, train=True, download=True,
                                       transform=transform_mnist)       #train_set其实只是读取的一个路径,但是并没有读取数据主体

train_data = train_set.data.float().unsqueeze(1) / 255.0        #读取所有的数据集路径,然后将对应的数据集读取到内存中
train_label = train_set.targets

train_dataset =  TensorDataset(train_data,train_label)      #将原本保存在内存中的变量转化为Tensor
train_loader = DataLoader(train_dataset,batch_size=batch_size,shuffle=True)

完整代码

代码

### import time

import torch
import torch.nn.functional as F
import torchvision
from einops import rearrange
from torch import nn
from torch import optim
from torch.utils.data import TensorDataset, DataLoader


#残差模块,放在每个前馈网络和注意力之后
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x

#layernorm归一化,放在多头注意力层和激活函数层。用绝对位置编码的BERT,layernorm用来自身通道归一化
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)
#放置多头注意力后,因为在于多头注意力使用的矩阵乘法为线性变换,后面跟上由全连接网络构成的FeedForward增加非线性结构
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim)
        )

    def forward(self, x):
        return self.net(x)
#多头注意力层,多个自注意力连起来。使用qkv计算
class Attention(nn.Module):
    def __init__(self, dim, heads=8):
        super().__init__()
        self.heads = heads
        self.scale = dim ** -0.5

        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
        self.to_out = nn.Linear(dim, dim)

    def forward(self, x, mask = None):
        b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x)
        q, k, v = rearrange(qkv, 'b n (qkv h d) -> qkv b h n d', qkv=3, h=h)

        dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale

        if mask is not None:
            mask = F.pad(mask.flatten(1), (1, 0), value = True)
            assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
            mask = mask[:, None, :] * mask[:, :, None]
            dots.masked_fill_(~mask, float('-inf'))
            del mask

        attn = dots.softmax(dim=-1)

        out = torch.einsum('bhij,bhjd->bhid', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out =  self.to_out(out)
        return out

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, mlp_dim):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Residual(PreNorm(dim, Attention(dim, heads = heads))),
                Residual(PreNorm(dim, FeedForward(dim, mlp_dim)))
            ]))

    def forward(self, x, mask=None):
        for attn, ff in self.layers:
            x = attn(x, mask=mask)
            x = ff(x)
        return x
#将图像切割成一个个图像块,组成序列化的数据输入Transformer执行图像分类任务。
class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels=3):
        super().__init__()
        assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
        num_patches = (image_size // patch_size) ** 2
        patch_dim = channels * patch_size ** 2

        self.patch_size = patch_size

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.patch_to_embedding = nn.Linear(patch_dim, dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.transformer = Transformer(dim, depth, heads, mlp_dim)

        self.to_cls_token = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.Linear(dim, mlp_dim),
            nn.GELU(),
            nn.Linear(mlp_dim, num_classes)
        )

    def forward(self, img, mask=None):
        p = self.patch_size

        x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
        x = self.patch_to_embedding(x)

        cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding
        x = self.transformer(x, mask)

        x = self.to_cls_token(x[:, 0])
        return self.mlp_head(x)

def train_epoch(model, optimizer, data_loader, loss_history):
    total_samples = len(data_loader.dataset)
    model = model.cuda()
    model.train()

    for i, (data, target) in enumerate(data_loader):
        # print("data.shape:", data.shape)
        data=data.cuda()
        target=target.cuda()
        optimizer.zero_grad()
        output = F.log_softmax(model(data), dim=1)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()

        if i % 100 == 0:
            print('[' +  '{:5}'.format(i * len(data)) + '/' + '{:5}'.format(total_samples) +
                  ' (' + '{:3.0f}'.format(100 * i / len(data_loader)) + '%)]  Loss: ' +
                  '{:6.4f}'.format(loss.item()))
            loss_history.append(loss.item())


def evaluate(model, data_loader, loss_history):
    model.eval()

    total_samples = len(data_loader.dataset)
    correct_samples = 0
    total_loss = 0
    model = model.cuda()
    with torch.no_grad():
        for data, target in data_loader:
            data = data.cuda()
            target = target.cuda()
            output = F.log_softmax(model(data), dim=1)
            loss = F.nll_loss(output, target, reduction='sum')
            _, pred = torch.max(output, dim=1)

            total_loss += loss.item()
            correct_samples += pred.eq(target).sum()

    avg_loss = total_loss / total_samples
    loss_history.append(avg_loss)
    print('\nAverage test loss: ' + '{:.4f}'.format(avg_loss) +
          '  Accuracy:' + '{:5}'.format(correct_samples) + '/' +
          '{:5}'.format(total_samples) + ' (' +
          '{:4.2f}'.format(100.0 * correct_samples / total_samples) + '%)\n')





# 设置随机种子
torch.manual_seed(42)

# 定义MNIST数据集的下载路径
DOWNLOAD_PATH = './data/mnist'

# 定义每个batch中包含的样本数量
batch_size = 5000

# 定义数据集预处理操作
transform_mnist = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,))
])

# 设置数据加载器的线程数
num_workers = 0

# 读入训练集和测试集数据
train_set = torchvision.datasets.MNIST(DOWNLOAD_PATH, train=True, download=True, transform=transform_mnist)
test_set = torchvision.datasets.MNIST(DOWNLOAD_PATH, train=False, download=True, transform=transform_mnist)

# 将数据集转换为张量
train_data = train_set.data.float().unsqueeze(1) / 255.0
train_labels = train_set.targets
test_data = test_set.data.float().unsqueeze(1) / 255.0
test_labels = test_set.targets

# 将数据集包装成Dataset对象
train_dataset = TensorDataset(train_data, train_labels)
test_dataset = TensorDataset(test_data, test_labels)

# 创建训练集和测试集的DataLoader
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)





N_EPOCHS = 100

'''
patch大小为 7x7(对于 28x28 图像,这意味着每个图像 4 x 4 = 16 个patch)、10 个可能的目标类别(0 到 9)和 1 个颜色通道(因为图像是灰度)。
在网络参数方面,使用了 64 个单元的维度,6 个 Transformer 块的深度,8 个 Transformer 头,MLP 使用 128 维度。'''
model = ViT(image_size=28, patch_size=7, num_classes=10, channels=1,
            dim=64, depth=6, heads=8, mlp_dim=128)
optimizer = optim.Adam(model.parameters(), lr=0.003)

train_loss_history, test_loss_history = [], []
for epoch in range(1, N_EPOCHS + 1):
    start_time = time.time()
    print('Epoch:', epoch)
    train_epoch(model, optimizer, train_loader, train_loss_history)
    evaluate(model, test_loader, test_loss_history)

    print('Execution time:', '{:5.2f}'.format(time.time() - start_time), 'seconds')

速度对比

  • RTX2060 使用transformer 训练一个epoch的时间

从硬盘中读取数据

Batch_Size time
300 62s
500 59s
1000 57s
3000 48s
5000 47s
10000 50s

将数据读入到内存中再运行
100 51s
500 12s 1500M 37%
1000 7.18s 1700M 30%
5000 5.84s 3400M 67%
10000 4.65s 5400M 54%

这速度差的真的不是一点半点的

从硬盘中读取数据

Batch_Size time
300 62s
500 59s
1000 57s
3000 48s
5000 47s
10000 50s

将数据读入到内存中再运行
100 51s
500 12s 1500M 37%
1000 7.18s 1700M 30%
5000 5.84s 3400M 67%
10000 4.65s 5400M 54%

这速度差的真的不是一点半点的

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/521770.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

计算机体系结构实验一

计算机体系结构实验一 一.实验目的 ​理解RISC-V的指令执行的数据流和控制信号,熟悉指令流水线的工作过程。 二.实验过程 1.RISC-V的相关指令 实验的模拟器使用RISC-V指令集,为了便于后续分析,首先学习实验中使用的RISC-V指令。 基本RIS…

Cesium最新版使用天地图地形及注记服务

天地图三维地名服务和地形服务需要利用 cesium 开源三维地球API与天地图扩展插件共同使用,目前支持cesuim1.52、1.58、1.63.1。 天地图调用demo: http://lbs.tianditu.gov.cn/docs/#/sanwei/ 注意: demo里的地形服务地址不对,需要自己更换成…

MCU通用移植方案

MCU通用移植方案 目录 MCU通用移植方案前言1 硬件移植2 软件移植2.1 底层移植方法2.1.1 移植原理2.1.2 移植方法 2.2 中间层移植方法2.2.1 移植原理2.2.2 移植方法 2.3 两种移植方法比对 3 结束语 前言 因为项目的需求或者成本控制等因素,我们经常会遇到更换MCU的情…

华硕 PRIME H610M-A D4 i5-12490F 1060电脑 Hackintosh 黑苹果efi引导文件

原文来源于黑果魏叔官网,转载需注明出处。(下载请直接百度黑果魏叔) 硬件型号驱动情况 主板华硕 PRIME H610M-A D4(LPC Controller/eSPI Controller H610芯片组) 处理器12th Gen Intel Core i5-12490F 六核已驱动 内…

Mysql的重要知识点以及问题

查看索引的命令 show index from mytable 索引的原理 索引用来快速地寻找那些具有特定值的记录。如果没有索引,⼀般来说执行查询时遍历整张表。 索引的原理:就是把无序的数据变成有序的查询 把创建了索引的列的内容进行排序 对排序结果生成倒排表…

货运物流小程序开发功能有哪些?

移动互联网的深入发展让网购等线上交易更加盛行,货运快递物流也随之增多,成为我们日常生活的重要组成部分。传统的货运物流管理主要依赖人工,不仅效率慢还容易出错。随着市场的发展以及人们对服务质量要求的提高,现在很多中大型货…

Oracle 12c安装

前言 版本:12c第二版 检查弹出窗口程序,需要安装xmanager,并执行以下命令: xhost 192.168.194.91 安装步骤如下 安装必须的安装包: rpm -q bc binutils compat-libcap1 compat-libstdc-33 glibc glibc-devel ksh libaio libaio…

c++ 多态与虚函数

c中多态分为静态多态和动态多态,静态多态是函数重载,在编译阶段就能确定调用哪个函数。动态多态是由继承产生的,指不同的对象根据所接收的消息(成员函数)做出不同的反应。例如,动物都能发出叫声,但不同的动物能发出不同…

esp32之解析json

文章目录 前言一、json的作用二、json结构三、esp32 json解析安装库解析StaticJsonDocumentDynamicJsonDocument 四、解析今天的北京天气总结 前言 在现代Web开发中,JSON(JavaScript Object Notation)已成为常用的数据传输格式。ESP32是一款…

Netty 爱好者必看!一文详解 ChannelHandler 家族,助你快速掌握 Netty 开发技巧!

1 Channel 接口的生命周期 Channel 定义了一组和 ChannelInboundHandler API 密切相关的简单但功能强大的状态模型 1.1 Channel 的状态 状 态描 述ChannelUnregisteredChannel 已经被创建,但还未注册到 EventLoopChannelRegisteredChannel 已经被注册到了 EventL…

Wealth 开源的账本响应式网站系统免费部署

演示网站: https://wealth.willin.wang 前置准备 首先需要注册一个 Github 账号,Fork 这个开源项目: https://github.com/willin/wealth (欢迎 Star) 然后使用 Github 账号分别注册 Vercel 和 Planetscale&#xf…

【Linux 】 ps命令详解,查看进程pid

文章目录 ps概述ps语法指定pid进行查看 ps概述 ps 命令是最常用的监控进程的命令,通过此命令可以查看系统中所有运行进程的详细信息。 ps 命令有多种不同的使用方法,这常常给初学者带来困惑。在各种 Linux 论坛上,询问 ps 命令语法的帖子屡…

双向链表--C语言实现数据结构

本期带大家一起用C语言实现双向链表🌈🌈🌈 文章目录 一、链表的概念🌎二、链表中数据元素的构成🌎 🌍三、链表的结构🌎 🌍 🌏四、 双向带哨兵位循环链表的实现&#x1f3…

ROS2 中 使用奥比中光 Orbbec Astra Pro 深度相机

本文将以 Ubuntu 20.04 和 ROS2 foxy 环境为例,详细介绍如何在 ROS2 中使用奥比中光 Orbbec Astra Pro 深度相机。在这一篇文章中,你会学到如何创建工作空间,使用 usb_cam 功能包,编译安装使用 ros_astra_camera 等。 文章目录 1.…

年薪50万的程序员和一般的中学教师相比,被亲戚看不起

我是一名程序员,已经工作五年,年薪大概有50万左右。然后,亲戚家的孩子是博士生,在一所中学教书,自己一年的工资可以抵达五六年的薪资,不过还是被亲戚给鄙视了。 很多人都持有不同的观点。我自己是一名程序…

vue-事件修饰符+键盘事件

事件修饰符 1、prevent&#xff1a; 阻止默认事件&#xff08;或在方法中使用e.preventDefault()&#xff09; <a hrefhttps://blog.csdn.net/weixin_52993364?typeblog click.preventshowInfo>点我</a> 说明&#xff1a;这样点击后就不会发生地址的跳转 2、s…

Linux查找指令 时间查看

date 我们在windows中想要看一下时间&#xff0c;我们可以直接在显示器上看到&#xff0c;但是如果我们用的是linux远程登录软件我们像查看一下时间&#xff0c;我们应该怎么做&#xff1f; 我们直接输入date&#xff0c;我们就可以看到当前的时间&#xff0c;不过这个是系统按…

蚁群算法ACS处理旅行商问题TSP【Java实现】

1. 介绍 蚁群算法是一种群体智能算法&#xff0c;模拟了蚂蚁寻找食物时的行为&#xff0c;通过蚂蚁之间的信息交流和合作&#xff0c;最终实现全局最优解的寻找【是否找得到和迭代次数有关】。 蚁群算法的基本思想是将搜索空间看作一个由节点组成的图&#xff0c;每个节点代表…

Linux awk [-v] {print} 命令

AWK 是一种处理文本文件的语言&#xff0c;是一个强大的文本分析工具。 语法&#xff1a;语法&#xff1a;awk 条件1 {动作 1} 条件 2 {动作 2} … 文件名 awk是处理文本文件的语言&#xff0c;所以要传入文本数据供其处理&#xff08;文件逐行读入&#xff09;&#xff0c;…

合宙Air780e C-SDK开发

Air78e简介 AirXXXE系列模组&#xff0c;是合宙通信基于移芯EC618平台设计研发的新款4G Cat.1模组。 Air780e的资料点击这里打开。 Air78e开发板简介 一代 IPEX 天线连接器&#xff08;选配&#xff09;4G 弹簧天线一个下载/调试串口&#xff0c;两个通用串口IO 口默认电平…