PyTorch深度学习实战(7)——批大小对神经网络训练的影响

news2024/11/17 17:29:51

PyTorch深度学习实战(7)——批大小对神经网络性能的影响

    • 0. 前言
    • 1. 批处理概念
    • 2. 批处理的优势
    • 3. 批大小对神经网络性能的影响
      • 3.1 批大小为 32
      • 3.2 批大小为 30,000
    • 小结
    • 系列链接

0. 前言

在神经网络中,批( batch )是指一次输入网络进行训练或推断的一组样本。批处理( batch processing )是指将这一组样本同时输入网络进行计算的操作。本节中首先介绍批( Batch )的基本概念,并且介绍批大小在神经网络训练过程中的影响。

1. 批处理概念

在深度学习中,批( Batch )是指一次输入神经网络的一组样本,批处理的思想是将多个样本同时输入网络进行计算,通过并行化的方式提高计算效率。
在神经网络训练中,训练数据集 Fashion MNIST 中每批包含 32 个数据样本,每个 epoch 中权重的更新次数较多,每个 epoch 会更新 1,875 次权重( 60,000/32 ≈ 1,875,其中 60,000 是训练图像的数量)。此外,并没有考虑模型对未见过的数据集(验证数据集)上的性能。

在本节中,我们将进行以下实验:

  • 当训练批大小为 32 时,观察模型在训练和验证数据上的损失和准确度
  • 当训练批大小为 30,000 时,观察模型在训练和验证数据上的损失和准确度

如果需要使用验证数据,计算模型在验证数据集上的损失和准确度,需要获取验证数据。

2. 批处理的优势

在神经网络训练过程中,批处理具有以下优势:

  • 提高计算效率:通过同时处理多个样本,可以充分利用现代计算硬件(如 GPU )的并行计算能力,加快模型的训练和推断速度。相比逐个样本处理,批处理可以在单次计算中同时完成多个样本的前向传播和反向传播操作。
  • 稳定梯度计算:批处理能够提供对更多样本的梯度计算,从而减小梯度的随机性,使得训练更加稳定。相比于单个样本的梯度更新,批处理可以平均多个样本的梯度,从而减少了噪声的影响。
  • 有效利用内存:批处理可以将多个样本一次性加载到内存中,减少了数据读取和存储的开销。特别是对于训练集较大的情况下,批处理能够显著降低I/O等方面的负载,提高整体的训练效率。

在训练过程中,通常会将数据集按照随机顺序划分成多个批次,并进行多次迭代训练。这样可以增强模型的泛化能力,并降低对特定顺序的依赖性。

3. 批大小对神经网络性能的影响

在实际模型训练过程中,批大小( batch size )是一个重要的超参数,需要根据具体的任务和资源情况进行选择。较大的批大小可以加速计算,但会增加内存需求;而较小的批大小有利于更好地探索样本空间,但可能导致梯度估计不稳定。

3.1 批大小为 32

我们已经在神经网络训练一节中构建了在训练期间批大小为 32 模型,本节中,我们将重点说明如何处理验证数据集。

(1) 下载并导入训练图像和目标值:

from torchvision import datasets
import torch
data_folder = './data/FMNIST'
fmnist = datasets.FashionMNIST(data_folder, download=True, train=True)

tr_images = fmnist.data
tr_targets = fmnist.targets

(2) 与训练图像类似,需要在调用 FashionMNIST 方法时通过指定 train = False 下载并导入验证数据集:

val_fmnist = datasets.FashionMNIST(data_folder, download=True, train=False)
val_images = val_fmnist.data
val_targets = val_fmnist.targets

(3) 导入相关库并定义可用设备:

import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
device = 'cuda' if torch.cuda.is_available() else 'cpu'

(4) 定义数据集类 FashionMNIST、用于对一批数据进行训练的函数 train_batch()、计算准确率 accuracy(),然后定义模型架构、损失函数和优化器 get_model()

class FMNISTDataset(Dataset):
    def __init__(self, x, y):
        x = x.float()/255
        x = x.view(-1,28*28)
        self.x, self.y = x, y 
    def __getitem__(self, ix):
        x, y = self.x[ix], self.y[ix] 
        return x.to(device), y.to(device)
    def __len__(self): 
        return len(self.x)

from torch.optim import SGD, Adam
def get_model():
    model = nn.Sequential(
        nn.Linear(28 * 28, 1000),
        nn.ReLU(),
        nn.Linear(1000, 10)
    ).to(device)

    loss_fn = nn.CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=1e-2)
    return model, loss_fn, optimizer

def train_batch(x, y, model, optimizer, batch_loss_fn):
    model.train()
    prediction = model(x)
    batch_loss = batch_loss_fn(prediction, y)
    batch_loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return batch_loss.item()

def accuracy(x, y, model):
    model.eval()
    # 使用 with torch.no_grad() 与 @torch.no_grad 效果相同
    # 在 with 作用域内的并不计算梯度
    with torch.no_grad():
        prediction = model(x)
    max_values, argmaxes = prediction.max(-1)
    is_correct = argmaxes == y
    return is_correct.cpu().numpy().tolist()

(5) 定义用于获取数据的函数 get_data(),此函数将返回批大小为 32 的训练数据和批大小为验证数据长度的验证数据集(验证数据不会用于训练模型,只用于了解模型在的未见过数据上的准确率):

def get_data(): 
    train = FMNISTDataset(tr_images, tr_targets) 
    trn_dl = DataLoader(train, batch_size=32, shuffle=True)
    val = FMNISTDataset(val_images, val_targets) 
    val_dl = DataLoader(val, batch_size=len(val_images), shuffle=False)
    return trn_dl, val_dl

在以上代码中,除了 train 对象之外,我们还创建了一个名为 valFMNISTDataset 类的对象。此外,用于验证模型的 DataLoader (val_dl) 批大小为 len (val_images),而 trn_dl 的批大小为 32。接下来,我们将根据模型的训练时间和准确率来了解不同 batch_size 对模型的影响。

(6) 定义计算验证数据损失的函数 val_loss

@torch.no_grad()
def val_loss(x, y, model, loss_fn):
    prediction = model(x)
    val_loss = loss_fn(prediction, y)
    return val_loss.item()

应用 @torch.no_grad() 声明无需训练模型而只获取预测结果,通过损失函数 loss_fn() 传递预测结果并返回损失值 val_loss.item()

(7) 获取训练和验证数据加载器,并初始化模型、损失函数和优化器:

trn_dl, val_dl = get_data()
model, loss_fn, optimizer = get_model()

(8) 训练模型,首先初始化包含训练准确率、验证准确率和损失值的列表:

train_losses, train_accuracies = [], []
val_losses, val_accuracies = [], []

循环 10epoch 并初始化包含给定 epoch 内各批训练数据的准确率和损失的列表:

for epoch in range(10):
    print(epoch)
    train_epoch_losses, train_epoch_accuracies = [], []

遍历一批训练数据并计算一个 epoch 内的准确率( train_epoch_accuracy )和损失值( train_epoch_loss):

    for ix, batch in enumerate(iter(trn_dl)):
        x, y = batch
        batch_loss = train_batch(x, y, model, optimizer, loss_fn)
        train_epoch_losses.append(batch_loss) 
    train_epoch_loss = np.array(train_epoch_losses).mean()

计算验证数据的损失值和准确率(验证数据的批大小等于验证数据的长度):

    for ix, batch in enumerate(iter(trn_dl)):
        x, y = batch
        is_correct = accuracy(x, y, model)
        train_epoch_accuracies.extend(is_correct)
    train_epoch_accuracy = np.mean(train_epoch_accuracies)

在以上的代码中,使用 val_loss 函数计算验证数据的损失值,并存储在 validation_loss 变量中。此外,关于验证数据样本是否预测正确的结果存储在 val_is_correct 列表中,使用 val_epoch_accuracy 变量表示验证数据集的平均准确率。

最后,将训练和验证数据集的准确率和损失值添加相应列表中,以查看模型训练的改进:

    train_losses.append(train_epoch_loss)
    train_accuracies.append(train_epoch_accuracy)
    val_losses.append(validation_loss)
    val_accuracies.append(val_epoch_accuracy)

可视化模型训练过程中训练和验证数据集中的准确率和损失值的变化:

epochs = np.arange(10)+1
import matplotlib.ticker as mtick
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
plt.subplot(121)
plt.plot(epochs, train_losses, 'bo', label='Training loss')
plt.plot(epochs, val_losses, 'r', label='Validation loss')
plt.gca().xaxis.set_major_locator(mticker.MultipleLocator(1))
plt.title('Training and validation loss when batch size is 32')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid('off')
plt.subplot(122)
plt.plot(epochs, train_accuracies, 'bo', label='Training accuracy')
plt.plot(epochs, val_accuracies, 'r', label='Validation accuracy')
plt.gca().xaxis.set_major_locator(mticker.MultipleLocator(1))
plt.title('Training and validation accuracy when batch size is 32')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.gca().set_yticklabels(['{:.0f}%'.format(x*100) for x in plt.gca().get_yticks()]) 
plt.legend()
plt.grid('off')
plt.show()

批大小为 32

如上图所示,当批大小为 32 时,在 epoch 结束时,训练和验证数据集上准确率约为 85%。接下来,改变 get_data() 函数中的 DataLoaderbatch_size 参数,以观察不同批大小对模型准确率的影响。

3.2 批大小为 30,000

在本节中,将批大小设定为 30,000 (除此之外,代码与上一小节完全相同),以便了解改变批大小对模型训练的影响。
修改 get_data() 使其批大小为 30,000,同时从训练数据集中获取训练 DataLoader

def get_data(): 
    train = FMNISTDataset(tr_images, tr_targets) 
    trn_dl = DataLoader(train, batch_size=30000, shuffle=True)
    val = FMNISTDataset(val_images, val_targets) 
    val_dl = DataLoader(val, batch_size=len(val_images), shuffle=False)
    return trn_dl, val_dl

当批大小为 30,000 时,训练和验证的准确率和损失随 epoch 的变化如下:

准确率和损失变化情况

可以看到模型性能(准确率和损失值)不如批大小为 32 的模型,因为当批大小为 32 时,权重的更新次数较多,当批大小为 30,000 时,每个 epoch 仅会进行 2 次( 60000/30000 )权重更新。因此,批大小越小,权重更新的次数就越多,并且通常在 epoch 数相同的情况下,准确率越好。同时,应注意批大小也不能过小,这可能导致训练时间过长,以及过拟合情况的出现。

小结

批处理是神经网络中的重要操作,通过同时处理多个样本来提高计算效率、稳定梯度估计以及有效利用内存。合理选择批大小可以在保证计算效率的同时,获得良好的训练结果。本节中,主要介绍了批大小在神经网络训练过程中的重要作用。

系列链接

PyTorch深度学习实战(1)——神经网络与模型训练过程详解
PyTorch深度学习实战(2)——PyTorch基础
PyTorch深度学习实战(3)——使用PyTorch构建神经网络
PyTorch深度学习实战(4)——常用激活函数和损失函数详解
PyTorch深度学习实战(5)——计算机视觉基础
PyTorch深度学习实战(6)——神经网络性能优化技术

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/814291.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

微服务体系<2> ribbon

1. 什么是负载均衡 比如说像这样 一个请求打在了nginx上 基于nginx进行负载分流 这就是负载均衡但是负载均衡分 服务端负载均衡和客户端负载均衡 客户端负载均衡 我user 从注册中心拉取服务 拉取order列表,然后发起getOne()调用 这就是客户端负载均衡 特点就是我…

【echarts】用js与echarts数据图表化,折线图、折线图堆叠、柱状图、折柱混合、环形图

echarts 是一个基于 JavaScript 的开源可视化库,用于构建交互式和自定义的图表,使数据更加直观和易于理解,由百度开发并于2018年捐赠给 Apache 软件基金会,后来改名为Apache ECharts 类似的还有Chart.js Chart.js地址&#xff1…

从此告别涂硅脂 利民推出新款CPU固态导热硅脂片:一片26.9元

利民(Thermalright)近日推出了新款Heilos CPU固态导热硅脂片,其中Intel版为26.9元,AMD版售价29.9元。 以往向CPU上涂硅脂,需要先挤一粒绿豆大小的硅脂,然后用塑料片涂匀,操作和清理对新手都极不友好。 该固态导热硅脂片…

string【2】模拟实现string类

string模拟实现 引言(实现概述)string类方法实现默认成员函数构造函数拷贝构造赋值运算符重载析构函数 迭代器beginend 容量size、capacity、emptyreserveresize 访问元素operator[] 修改insert插入字符插入字符串 appendpush_backoperatoreraseclearswa…

Python web实战 | 使用 Flask 实现 Web Socket 聊天室

概要 今天我们学习如何使用 Python 实现 Web Socket,并实现一个实时聊天室的功能。本文的技术栈包括 Python、Flask、Socket.IO 和 HTML/CSS/JavaScript。 什么是 Web Socket? Web Socket 是一种在单个 TCP 连接上进行全双工通信的协议。它是 HTML5 中的…

SAMBA 文件分享相关 笔记

目标说明 在Linux 安装Samba,然后在Windows端映射为网络硬盘 流程 Linux 端命令 apt install samba -y 默认情况下软件会询问是否迁移系统网络设置以搭建协议,选择迁移即可修改配置文件 vim /etc/samba/smb.conf Samba 的配置文件中会带一个名为 prin…

【Mybatis】Mybatis架构简介

文章目录 1.整体架构图2. 基础支撑层2.1 类型转换模块2.2 日志模块2.3 反射工具模块2.4 Binding 模块2.5 数据源模块2.6缓存模块2.7 解析器模块2.8 事务管理模块 3. 核心处理层3.1 配置解析3.2 SQL 解析与 scripting 模块3.3 SQL 执行3.4 插件 4. 接口层 1.整体架构图 MyBatis…

第5集丨webpack 江湖 —— 项目发布 和 source map

目录 一、webpack项目发布1.1 新增发布(build)命令1.2 优化js和图片文件的存放路径1.3 执行1.4 效果 二、clean-webpack-plugin插件2.1 安装2.2 配置2.3 执行 三、source map3.1 配置3.2 生成的source map文件 四、定义符4.1 配置4.2 使用 五、工程附件汇总5.1 webpack.config.…

大麦订单一键生成 仿大麦订单生成

后台一键生成链接,独立后台管理 教程:修改数据库config/Conn 不会可以看源码里有教程 下载程序:https://pan.baidu.com/s/16lN3gvRIZm7pqhvVMYYecQ?pwd6zw3

强化学习(EfficientZero)(应用于图像和声音)

目录 摘要 1.背景介绍 2.MCTS(蒙特卡洛树搜索)(推理类模型,棋类效果应用好,控制好像也不错) 3.MUZERO 4.EfficientZero(基于MUZERO) 展望 参考文献 摘要 在文中,基于…

【雕爷学编程】MicroPython动手做(20)——掌控板之三轴加速度5

知识点:什么是掌控板? 掌控板是一块普及STEAM创客教育、人工智能教育、机器人编程教育的开源智能硬件。它集成ESP-32高性能双核芯片,支持WiFi和蓝牙双模通信,可作为物联网节点,实现物联网应用。同时掌控板上集成了OLED…

数学建模 好文章和资源推荐

数学建模入门篇(0基础必看,全是自己的经验) 【竞赛|数学建模】Part 1:什么是数学建模和各模块介绍 0基础小白,如何入门数学建模? 数学建模入门篇(0基础必看,全是自己的经验) 什么是数学建模 重申了一下题目&#xff…

基于SpringBoot+Vue的地方废物回收机构管理系统设计与实现(源码+LW+部署文档等)

博主介绍: 大家好,我是一名在Java圈混迹十余年的程序员,精通Java编程语言,同时也熟练掌握微信小程序、Python和Android等技术,能够为大家提供全方位的技术支持和交流。 我擅长在JavaWeb、SSH、SSM、SpringBoot等框架…

并发编程——线程池

1.概述 如果并发的线程过多,而且执行的时间都非常短,如果这样,每次都要创建线程就会大大降低效率,我们可以通过线程池来解决,JDK5增加了内置线程池ThreadPollExecutor。 2.线程池的优点 1.重复利用,降低…

【LeetCode】单链表——刷题

你曾经灼热的眼眶,是人生中少数的笨拙又可贵的时刻。 文章目录 1.反转单链表 题目思路及图解 代码中需要注意的问题 2.移除链表元素 题目思路及图解 代码中需要注意的问题 大家好,我是纪宁。 这篇文章分享给大家一些经典的单链表leetcode笔试题的…

【Unity 实用插件篇】 | 行为状态机StateMachine,规范化的管理对象行为

前言【Unity 实用插件篇】 | 行为状态机StateMachine 学习使用一、StateMachine行为状态机 介绍二、StateMachine 结构分析三、StateMachine状态机详细使用流程3.1 第一步:创建状态机Transition Table SO3.2 第二步:创建对应状态的 State SO3.3 第三步:创建状态的切换条件 C…

Hadoop学习指南:探索大数据时代的重要组成——运行环境搭建

Hadoop运行环境搭建(开发重点) 模板虚拟机环境准备 数据来源层 安装模板虚拟机,IP地址192.168.10.100、主机名称hadoop100、内存4G、硬盘50G hadoop100 虚拟机配置要求如下(本文Linux系统全部以CentOS-7.5-x86-1804为例&#…

小研究 - Java 虚拟机实现原理分析

针对虚拟机的底层实现原理及相关实现过程,讨论了 Java 语言的跨平台原理以及相关工作机制,分析了 JVM 底层各数据区内存管理过程,阐述了 JVM 在 Java 语言中的核心作用以及重要地位。 目录 1 概述 2 Java 平台分层原理 3 虚拟机工作原理 …

CDC一键入湖:当 Apache Hudi DeltaStreamer 遇见 Serverless Spark

文章目录 1. 整体架构2. 环境准备3. 配置全局变量4. 创建专属工作目录和存储桶5. 创建 EMR Serverless Execution Role6. 创建 EMR Serverless Application7. 提交 Apache Hudi DeltaStreamer CDC 作业7.1 准备作业描述文件7.2 提交作业7.3 监控作业7.4 错误检索7.5 停止作业 8…

仿转转闲鱼链接后台生成

教程:修改数据库账号密码直接使用。 源码带有教程! 下载程序:https://pan.baidu.com/s/16lN3gvRIZm7pqhvVMYYecQ?pwd6zw3