Pytorch 复习总结 3

news2024/10/4 14:42:34

Pytorch 复习总结,仅供笔者使用,参考教材:

  • 《动手学深度学习》
  • Stanford University: Practical Machine Learning

本文主要内容为:Pytorch 多层感知机。

本文先介绍了多层感知机的用法,再就训练过程中经常出现的过拟合现象提出解决办法。


Pytorch 语法汇总:

  • Pytorch 张量的常见运算、线性代数、高等数学、概率论 部分 见 Pytorch 复习总结1;
  • Pytorch 线性神经网络 部分 见 Pytorch 复习总结2;
  • Pytorch 多层感知机 部分 见 Pytorch 复习总结3;
  • Pytorch 深度学习计算 部分 见 Pytorch 复习总结4;
  • Pytorch 卷积神经网络 部分 见 Pytorch 复习总结5;
  • Pytorch 现代卷积神经网络 部分 见 Pytorch 复习总结6;

目录

  • 一. 多层感知机
    • 1. 读取数据集
    • 2. 神经网络模型
    • 3. 激活函数
    • 4. 损失函数
    • 5. 优化器
    • 6. 训练
  • 二. 过拟合的缓解
    • 1. 权重衰减
    • 2. Dropout

一. 多层感知机

虽然线性模型易于实现和理解、计算成本低、泛化能力强,但是对于一些非线性问题,可能会违反线性模型的单调性。为此,多层感知器引入了隐藏层来克服线性模型的限制,并且加入激活函数以增强网络非线性建模能力。

1. 读取数据集

同 Pytorch 复习总结 2 中 Softmax 回归的数据读取,继续使用 Fashion-MNIST 图像分类数据集:

import torch
import torchvision
from torch.utils import data
from torchvision import transforms

def load_data_fashion_mnist(batch_size, resize=None):
    """下载Fashion-MNIST数据集并将其加载到内存中"""
    trans = [transforms.ToTensor()]
    if resize:
        trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)
    mnist_train = torchvision.datasets.FashionMNIST(
        root="./data", train=True, transform=trans, download=True)
    mnist_test = torchvision.datasets.FashionMNIST(
        root="./data", train=False, transform=trans, download=True)
    return (data.DataLoader(mnist_train, batch_size, shuffle=True),
            data.DataLoader(mnist_test, batch_size, shuffle=False))

batch_size = 256
train_iter, test_iter = load_data_fashion_mnist(batch_size)

2. 神经网络模型

先将输入的图像展平,然后使用 2 个全连接层进行处理,中间的全连接层需要使用激活函数激活,最后一层全连接层作为输出:

from torch import nn
net = nn.Sequential(nn.Flatten(),
                    nn.Linear(784, 256),
                    nn.ReLU(),
                    nn.Linear(256, 10)
)

仍然使用 init_weights() 函数按正态分布初始化所有全连接层的权重:

def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)

net.apply(init_weights)

3. 激活函数

上一节使用了 ReLU 函数进行激活,在实际应用中,还可以使用 sigmoid、tanh 等函数激活。ReLU、sigmoid、tanh 函数的梯度可视化如下:

import torch
from matplotlib import pyplot as plt

x = torch.arange(-8.0, 8.0, 0.1, requires_grad=True)
# y = torch.relu(x)
# y = torch.sigmoid(x)
y = torch.tanh(x)
y.backward(torch.ones_like(x), retain_graph=True)
plt.figure(figsize=(5, 2.5))
plt.plot(x.detach(), x.grad)
plt.show()

4. 损失函数

同 Softmax 回归:

loss = nn.CrossEntropyLoss(reduction='none')

5. 优化器

同 Softmax 回归:

trainer = torch.optim.SGD(net.parameters(), lr=0.1)

6. 训练

同 Softmax 回归,可以将训练过程封装成函数:

def accuracy(y_hat, y):
    """计算预测正确的数量"""
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
        y_hat = y_hat.argmax(axis=1)
    cmp = y_hat.type(y.dtype) == y
    return float(cmp.type(y.dtype).sum())

def train_net(net, train_iter, test_iter, loss, num_epochs, trainer):
    for epoch in range(num_epochs):     # 迭代训练轮次
        net.train()                     # 将模型设置为训练模式

        train_loss_sum = 0.0            # 训练损失总和
        train_acc_sum = 0.0             # 训练准确度总和
        sample_num = 0                  # 样本数

        for X, y in train_iter:
            y_hat = net(X)
            l = loss(y_hat, y)
            trainer.zero_grad()
            l.mean().backward()
            trainer.step()

            train_loss_sum += l.sum()
            train_acc_sum += accuracy(y_hat, y)
            sample_num += y.numel()

        train_loss = train_loss_sum / sample_num
        train_acc = train_acc_sum / sample_num

        net.eval()                      # 将模型设置为评估模式
        test_acc_sum = 0.0
        test_sample_num = 0
        for X, y in test_iter:
            test_acc_sum += accuracy(net(X), y)
            test_sample_num += y.numel()
        test_acc = test_acc_sum / test_sample_num

        print(f'epoch {epoch + 1}, '
            f'train loss {train_loss:.4f}, train acc {train_acc:.4f}, '
            f'test acc {test_acc:.4f}')
    
num_epochs = 10
train_net(net, train_iter, test_iter, loss, num_epochs, trainer)

二. 过拟合的缓解

当模型过于复杂、训练数据太少、迭代轮数太多时,就会出现过拟合现象。解决过拟合的方法有很多:

  • 增加数据量:增加训练数据可以帮助模型更好地学习数据的真实规律,减少过拟合的发生;
  • 简化模型:降低模型的复杂度,可以通过减少模型的参数数量、使用正则化等方法来实现;
  • 交叉验证:使用交叉验证来评估模型的泛化能力,选择最优的模型;
  • 提前停止:即 Dropout,在训练过程中监控模型在验证集上的表现,当验证集误差不再下降甚至开始上升时,及时停止训练,防止模型过拟合;
  • 集成学习:使用集成学习方法(如随机森林、梯度提升树等)降低模型的方差,提高泛化能力。

下面介绍几种常用的正则化方法。

1. 权重衰减

权重衰减 (Weight Decay) 通过向损失函数中添加一个惩罚项来减小模型复杂度,以防止过拟合。惩罚项也叫 正则项,通常是权重的平方和(即 L2 范数)或权重的绝对值和(即 L1 范数)乘以一个正则化系数。

以线性回归的损失函数 L ( w , b ) L(\mathbf{w}, b) L(w,b) 为例,使用优化器训练时,在损失函数 L ( w , b ) L(\mathbf{w}, b) L(w,b) 上添加 L2 范数如下:
L ( w , b ) + λ 2 ∥ w ∥ 2 = 1 n ∑ i = 1 n 1 2 ( w ⊤ x ( i ) + b − y ( i ) ) 2 + λ 2 ∥ w ∥ 2 L(\mathbf{w}, b)+\frac{\lambda}{2}\|\mathbf{w}\|^2\\ =\frac{1}{n} \sum_{i=1}^n \frac{1}{2}\left(\mathbf{w}^{\top} \mathbf{x}^{(i)}+b-y^{(i)}\right)^2+\frac{\lambda}{2}\|\mathbf{w}\|^2\\ L(w,b)+2λw2=n1i=1n21(wx(i)+by(i))2+2λw2

损失函数中没有添加偏置 b b b 的惩罚项,因为一般情况下,网络输出层的偏置项不需要正则化。代入 w \mathbf{w} w 的参数更新表达式为:
w ← ( 1 − η λ ) w − η ∣ B ∣ ∑ i ∈ B x ( i ) ( w ⊤ x ( i ) + b − y ( i ) ) \mathbf{w} \leftarrow(1-\eta \lambda) \mathbf{w}-\frac{\eta}{|\mathcal{B}|} \sum_{i \in \mathcal{B}} \mathbf{x}^{(i)}\left(\mathbf{w}^{\top} \mathbf{x}^{(i)}+b-y^{(i)}\right) w(1ηλ)wBηiBx(i)(wx(i)+by(i))

要想对模型进行权重衰减,只需要在实例化优化器时通过 weight_decay 指定权重衰减参数。默认情况下,PyTorch 同时衰减权重和偏移:

trainer = torch.optim.SGD(net.parameters(), lr=lr)

如果想要只衰减权重,需要指定参数:

params_to_optimize = [
    {"params": net[0].weight, 'weight_decay': wd},
    {"params":net[0].bias}
]
trainer = torch.optim.SGD([
        {"params":net[0].weight,'weight_decay': wd},
        {"params":net[0].bias}], lr=lr)

2. Dropout

Dropout 通过在训练过程中随机地将网络 内部 的一部分神经元的输出设置为零,即以一定的概率 “丢弃” 这些神经元。这样可以防止神经元在训练过程中过于依赖其他神经元,从而降低了网络对特定神经元的依赖性,使得网络更具鲁棒性:
在这里插入图片描述

通常情况下,Dropout 只在训练过程中使用,不在推理阶段使用,因为推理时模型需要产生确定性的输出。

Dropout 需要在网络中添加 Dropout 层,一般位于激活函数后,并且给定 dropout 概率:

dropout1, dropout2 = 0.2, 0.5

net = nn.Sequential(nn.Flatten(),
        nn.Linear(784, 256),
        nn.ReLU(),
        nn.Dropout(dropout1),
        nn.Linear(256, 256),
        nn.ReLU(),
        nn.Dropout(dropout2),
        nn.Linear(256, 10)
)

def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)

net.apply(init_weights)

Dropout 概率的设置技巧是靠近输入层的地方设置较低的概率,远离输入层的地方设置较高的概率。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1462081.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

数据结构 计算结构体大小

一、规则: 操作系统制定对齐量: 64位操作系统,默认8Byte对齐 32位操作系统,默认4Byte对齐 结构体对齐规则: 1.结构体整体的大小,需要是最大成员对齐量的整数倍 2.结构体中每一个成员的偏移量需要存在…

IDEA 2021.3激活

1、打开idea,在设置中查找Settings/Preferences… -> Plugins 内手动添加第三方插件仓库地址:https://plugins.zhile.io搜索:IDE Eval Reset 插件进行安装。应用和使用,如图

合纵连横 – 以 Flink 和 Amazon MSK 构建 Amazon DocumentDB 之间的实时数据同步

在大数据时代,实时数据同步已经有很多地方应用,包括从在线数据库构建实时数据仓库,跨区域数据复制。行业落地场景众多,例如,电商 GMV 数据实时统计,用户行为分析,广告投放效果实时追踪&#xff…

pytorch: ground truth similarity matrix

按照真实标签排序pair-wise相似度矩阵的Pytorch代码 本文仅作留档,用于输出可视化 Inputs: Ground-truths Y ∈ R n 1 \mathbf{Y}\in\mathbb R^{n\times 1} Y∈Rn1, Similarity matrix A ∈ R n n \mathbf{A}\in\mathbb R^{n\times n} A∈RnnOutputs: Block dia…

【无标题】https://www.php.cn/faq/602417.html

https://www.php.cn/faq/602417.htmlTOC 欢迎使用Markdown编辑器 你好! 这是你第一次使用 Markdown编辑器 所展示的欢迎页。如果你想学习如何使用Markdown编辑器, 可以仔细阅读这篇文章,了解一下Markdown的基本语法知识。 新的改变 我们对Markdown编…

Mysql系列之命令行登录、连接工具登录、数据库表常用命令

登录与常用命令 连接工具登录命令行登录数据库1、查看数据库2、指定数据库3、查看当前数据库4、建库语句 数据表1、查看数据表2、查看表结构信息3、查看建表语句4、建表语句 连接工具登录 首先下载mysql连接工具,解压后直接打开软件,按以下步骤操作&…

车载氢气浓度传感器为氢能源车保驾护航

最近,车载氢气浓度传感器成为了一个热门话题。作为一名对科技充满热情的汽车爱好者,我自然也对这个话题产生了浓厚的兴趣。那么,车载氢气浓度传感器到底是什么?它又是如何工作的呢?下面就让我为你一一揭秘。 首先&…

C++ Primer 笔记(总结,摘要,概括)——第7章 类

目录 ​编辑 7.1 定义抽象数据类型 7.1.1 设计Sales_data类 7.1.2 定义改进的Sales_data类 7.1.3 定义类相关的非成员函数 7.1.4 构造函数 7.1.5 拷贝、赋值和析构 7.2 访问控制和封装 7.2.1 友元 7.3 类的其他特性 7.3.1 类成员再探 7.3.2 返回*this的成员函数 7.3.3 类类…

大蟒蛇(Python)笔记(总结,摘要,概括)——第10章 文件和异常

目录 10.1 读取文件 10.1.1 读取文件的全部内容 10.1.2 相对文件路径和绝对文件路径 10.1.3 访问文件中的各行 10.1.4 使用文件的内容 10.1.5 包含100万位的大型文件 10.1.6 圆周率中包含你的生日吗 10.2 写入文件 10.2.1 写入一行 10.2.2 写入多行 10.3 异常 10.3.1 处理Ze…

SpringBoot整合POIExcel: 实现导入导出Excel功能

SpringBoot整合POIExcel: 实现导入导出Excel功能 SpringBoot整合POIExcel: 实现导入导出Excel功能摘要引言依赖Poi包结构读取Excel表格读取Excel表格写入Excel表格 实战测试导入表格导出表格代码实现细节 博主 默语带您 Go to New World. ✍ 个人主页—— 默语 的博客&#x1f…

C++面向对象程序设计-北京大学-郭炜【课程笔记(四)】

C面向对象程序设计-北京大学-郭炜【课程笔记(四)】 1、this指针1.1、this指针的作用1.2、this指针和静态成员函数 2、静态成员变量和静态成员函数2.1、基本概念2.2、基本概念总结2.3、如何访问静态成员2.4、静态成员变量的使用场景(重要&…

stm32——hal库学习笔记(ADC)

这里写目录标题 一、ADC简介(了解)1.1,什么是ADC?1.2,常见的ADC类型1.3,并联比较型工作示意图1.4,逐次逼近型工作示意图1.5,ADC的特性参数1.6,STM32各系列ADC的主要特性 …

(done) Positive Semidefinite Matrices 什么是半正定矩阵?如何证明一个矩阵是半正定矩阵? 可以使用特征值

参考视频:https://www.bilibili.com/video/BV1Vg41197ew/?vd_source7a1a0bc74158c6993c7355c5490fc600 参考资料(半正定矩阵的定义):https://baike.baidu.com/item/%E5%8D%8A%E6%AD%A3%E5%AE%9A%E7%9F%A9%E9%98%B5/2152711?frge_ala 看看半正定矩阵的…

《最新出炉》系列初窥篇-Python+Playwright自动化测试-21-处理鼠标拖拽-番外篇

1.简介 前边宏哥拖拽有提到那个反爬虫机制,加了各种参数,以及加载js脚本文件还是有问题,偶尔宏哥好像发现了解决问题的办法,看到了黎明的曙光,宏哥就说试一下看看行不行,万一实现了。结果宏哥试了结果真的…

从零开始学习Netty - 学习笔记 - NIO基础 - 网络编程: Selector

4.网络编程 4.1.非阻塞 VS 阻塞 在网络编程中,**阻塞(Blocking)和非阻塞(Non-blocking)**是两种不同的编程模型,描述了程序在进行网络通信时的行为方式。 阻塞(Blocking)&#xff1…

Kubernetes 卷存储 NFS | nfs搭建配置 原理介绍 nfs作为存储卷使用

1、NFS介绍 NFS(Network File System)是一种分布式文件系统协议,允许客户端远程访问服务器上的文件,实现数据共享。它整合多个存储设备为统一文件系统,方便数据存储和管理,支持负载均衡和故障转移&#xf…

《咸鱼之王》简单拆解图(持续更新)

文章目录 一、 介绍二、 角色设定阿咸咸将 三、游戏拆解 一、 介绍 《咸鱼之王》是一款由阿咸工作室开发的手机游戏,战斗方式为回合制卡牌对战,同时玩家点击屏幕可以为阵容提供助攻。该游戏于2021年3月4日公测。 在游戏中,玩家将化身主角阿…

测试环境搭建整套大数据系统(四:ubuntu22.4创建普通用户)

一:创建用户,修改密码,增加sudo权限。 useradd dolphinscheduler #输入密码 passwd dolphinscheduler # 配置 sudo 免密 sed -i $adolphinscheduler ALL(ALL) NOPASSWD: NOPASSWD: ALL /etc/sudoers sed -i s/Defaults requirett/#Defa…

【Java面试】MQ(Message Queue)消息队列

目录 一、MQ介绍二、MQ的使用1应用解耦2异步处理3流量削峰4日志处理5消息通讯三、使用 MQ 的缺陷1.系统可用性降低:2.系统复杂性变高3.一致性问题四、常用的 MQActiveMQ:RabbitMQ:RocketMQ:Kafka:五、如何保证MQ的高可用?ActiveMQ:RabbitMQ:RocketMQ:Kafka:六、如何保…

Python环境下基于门控双注意力机制的滚动轴承剩余使用寿命RUL预测(Tensorflow模块)

机械设备的寿命是其从开始工作持续运行直至故障出现的整个时间段,以滚动轴承为例,其寿命为开始转动直到滚动体或是内外圈等元件出现首次出现故障前。目前主流的滚动轴承RUL预测分类方法包含两种:一是基于物理模型的RUL预测方法,二…