系列文章目录
PyTorch|Dataset与DataLoader使用、构建自定义数据集
文章目录
- 系列文章目录
- 一、Transforms
- 二、构建神经网络模型
- 三、模型分层
- (一)模型输入
- (二)nn.Flatten
- (三)nn.Linear
- (四)非线性激活函数nn.ReLU
- (五)nn.Sequential
- (六)nn.Softmax
- (七)模型参数
- 四、nn.Module源码
- (一)init函数
- (二)register_buffer函数
- (三)register_parameter函数
- (四)add_module函数、register_module函数、get_submodule函数
- (五)get_parameter函数、get_buffer函数
- (六)_apply函数和apply函数
- (七)cuda函数、xpu函数、cpu函数
- (八)type函数、float函数、double函数、half函数、bfloat16函数
- (九)to函数、to_empty函数
- (十)__getattr__函数、parameters函数、buffers函数、modules函数
- (十一)_save_to_state_dict函数、state_dict函数、_load_from_state_dict函数、load_state_dict函数
- (十二)train函数、eval函数
- (十三)requires_grad_函数、zero_grad函数
一、Transforms
数据并不总是以训练机器学习算法所需的最终处理形式出现。Transforms是对数据的特征和标签等进行变换,使其满足神经网络的输入要求。Transforms函数一般是在Dataset中定义好,然后通过get_item应用。
- transform:修改特征
- target_transform:修改标签
FashionMNIST数据集为PILlmage格式,标签为整数。神经网络的训练需要将特征归一化张量,标签是单热编码张量。为了进行这些转换,我们使用ToTensor和Lambda:
- ToTensor将PIL图像或NumPy数组转换为FloatTensor,并在范围内缩放图像的像素强度值[0,1]。
- Lambda变换应用于任何用户定义的Lambda函数。这里定义了一个函数来将整数转换为一个独热编码张量。它首先创建一个大小为10的零张量(我们数据集中的标签数量),并调用scatter_,它在标签y给出的索引上赋值1。
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
ds = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor(),
target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
二、构建神经网络模型
神经网络由对数据执行操作的层/模块组成。torch nn命名空间提供了构建自己的神经网络所需的所有构建块。PyTorch中的每个模块都是n. module的子类。
构建一个神经网络来对FashionMNIST数据集中的图像进行分类:
导入相关库:
import os
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
确定训练的设备:
device = (
"cuda"
if torch.cuda.is_available()
else "mps"
if torch.backends.mps.is_available()
else "cpu"
)
print(f"Using {device} device")
定义分类模型网络:
所有的层、网络、模型都需要继承自nn.Module父类,并且通常需要定义两个方法:
- init方法:创建子模块,初始化神经网络层
- forward方法:对输入数据的前向运算
class NeuralNetwork(nn.Module):
def __init__(self):
super().__init__() #调用父类init方法
self.flatten = nn.Flatten()
# 线性relu堆叠模块
self.linear_relu_stack = nn.Sequential( #串联模块
nn.Linear(28*28, 512),
nn.ReLU(),
nn.Linear(512, 512),
nn.ReLU(),
nn.Linear(512, 10),
)
def forward(self, x):
x = self.flatten(x) #维度展开,铺平
logits = self.linear_relu_stack(x)
return logits
调用分类网络:
model = NeuralNetwork().to(device)
print(model)
使用分类网络:
X = torch.rand(1, 28, 28, device=device)
logits = model(X)
pred_probab = nn.Softmax(dim=1)(logits)
y_pred = pred_probab.argmax(1)
print(f"Predicted class: {y_pred}")
三、模型分层
(一)模型输入
采用3张大小为28x28的图像的样本minibatch输入到网络中:
input_image = torch.rand(3,28,28) #随机生成一个维度为(3,28,28)的张量
print(input_image.size())
输入张量的大小:
(二)nn.Flatten
从start_dim维度到end_dim维度进行铺平,默认从第一维到最后一维(最终只保留第0维和其他维共两个维度)
调用flatten之后维度就从3x28x28转换为了3x784:
flatten = nn.Flatten()
flat_image = flatten(input_image)
print(flat_image.size())
(三)nn.Linear
nn.Linear包含输入维度、输出维度、偏置、设备、类型等参数,Linear层还包括weight和bias属性。
将维度为[3,784]的数据输入到线性层中,返回输出维度为[3,20]。
layer1 = nn.Linear(in_features=28*28, out_features=20)
hidden1 = layer1(flat_image)
print(hidden1.size())
(四)非线性激活函数nn.ReLU
print(f"Before ReLU: {hidden1}\n\n")
hidden1 = nn.ReLU()(hidden1)
print(f"After ReLU: {hidden1}")
(五)nn.Sequential
nn.Sequential是一个关于模块的堆叠容器。
seq_modules = nn.Sequential(
flatten,
layer1,
nn.ReLU(),
nn.Linear(20, 10)
)
input_image = torch.rand(3,28,28)
logits = seq_modules(input_image)
(六)nn.Softmax
实例归一化:
softmax = nn.Softmax(dim=1)
pred_probab = softmax(logits)
(七)模型参数
print(f"Model structure: {model}\n\n")
for name, param in model.named_parameters():
print(f"Layer: {name} | Size: {param.size()} | Values : {param[:2]} \n")
由于第1层和第3层是relu函数,不含参数,所以不会打印出来。
四、nn.Module源码
(一)init函数
默认情况下training=True
(二)register_buffer函数
神经网络中的“buffer”: 通常指代在网络中的某些层或操作中存储或缓存的临时数据。这些缓冲区可以在网络的前向传播和反向传播过程中被使用,以帮助网络进行参数更新和计算梯度。例如:
- Batch Normalization 中的均值和方差: 在批量归一化层中,通常会在训练过程中计算每个批次的输入数据的均值和方差,并将它们存储在缓冲区中。这些均值和方差用于标准化输入数据,以便提高网络的训练效果。
- 滑动平均(Exponential Moving Average,EMA): 在一些优化算法(如 Momentum、Adam 等)中,会使用滑动平均来估计参数的移动均值,以稳定优化过程。这些移动均值通常存储在缓冲区中,并在每次迭代中更新。
- 卷积层的权重和偏置: 在卷积神经网络中,卷积层的权重和偏置通常存储在缓冲区中。这些参数在网络的训练过程中被更新,并在前向传播和反向传播中被使用。
- 循环神经网络(RNN)中的隐藏状态: 在循环神经网络中,隐藏状态通常被存储在缓冲区中,并在每个时间步被更新。这些隐藏状态在网络的每个时间步被传递和使用。
register_buffer函数的作用: 定义一组参数,该组参数在模型训练时不会更新(即调用 optimizer.step() 后该组参数不会变化,只可人为地改变它们的值),但是保存模型时,该组参数又作为模型参数不可或缺的一部分被保存。
(三)register_parameter函数
Parameter和Buffer的区分:
- 模型中需要进行更新的参数注册为Parameter,不需要进行更新的参数注册为buffer
- 模型保存的参数是 model.state_dict() 返回的 OrderDict
- 模型进行设备移动时,模型中注册的参数(Parameter和buffer)会同时进行移动
register_parameter函数主要用于注册一个可训练更新的参数: 将一个不可训练的类型Tensor转换成可以训练的类型parameter,并将这个parameter绑定到这个module里面,相当于变成了模型的一部分,成为了模型中可以根据训练进行变化的参数。
使用实例:
(四)add_module函数、register_module函数、get_submodule函数
add_module函数:往当前module中再去增加一个子模块,这个子模块会加入到self._modules字典中
register_module函数:用于注册模块
get_submodule函数:用于获取当前模块中的子模块
(五)get_parameter函数、get_buffer函数
get_parameter函数:根据字符串获得对应的模型参数
get_buffer函数:根据字符串获得对应的buffer
(六)_apply函数和apply函数
_apply函数:
- 对所有的子模块调用某个function
- 对所有的参数调用某个function
- 对buffer变量调用某个function
apply函数:
- 在模型参数初始化时会用到apply函数,主要作用是递归的将某个function运用到子模块上
(七)cuda函数、xpu函数、cpu函数
cuda函数是将所有的模型参数及buffer变量移动到gpu上
xpu函数、cpu函数类似,是将所有的模型参数及buffer变量移动到xpu、cpu上
这三个函数本质上都是使用的_apply函数
(八)type函数、float函数、double函数、half函数、bfloat16函数
- type函数:将所有的参数、buffer都转化一个数据类型
- float函数、double函数、half函数、bfloat16函数都是实现对于浮点数的转换(都是转换函数,但是只针对浮点类型)
(九)to函数、to_empty函数
- to_empty函数:将当前模型中的参数、buffer都移动到一个设备上,但是不会拷贝存储空间
- to函数有许多种用法:
(十)__getattr__函数、parameters函数、buffers函数、modules函数
__getattr__函数中的所有_parameters、_buffers、_modules没有对子模块进行遍历,只会去对当前模块进行查找
而parameters函数、buffers函数及modules函数都是递归的,会返回当前module及子module的参数或者buffers。
(十一)_save_to_state_dict函数、state_dict函数、_load_from_state_dict函数、load_state_dict函数
-
_save_to_state_dict函数:把当前module的所有参数及buffers放入一个字典中
-
state_dict函数:对当前module及子module的所有参数及buffers放入一个字典中
-
_load_from_state_dict函数:从一个state_dict中得到参数及buffers中然后载入到当前的模型中
-
load_state_dict函数:递归载入所有参数及buffers
(十二)train函数、eval函数
train函数:参数设置成true就说明已经将该模型设置为训练模式(包括子模块)
eval函数:将模型设置为评估模式,其实就是将train函数的参数设置为false
(十三)requires_grad_函数、zero_grad函数
自动微分是否需要在这些参数上记录操作,换句话说就是是否需要对这个模型记录导数值
zero_grad函数用于清理之前的累计梯度,在训练中一般不用对模型参数进行清理,优化器中会有用到。
参考:
PyTorch官方教程
7、深入剖析PyTorch nn.Module源码
8、深入剖析PyTorch的state_dict、parameters、modules源码