模型训练中出现loss为NaN怎么办?

news2024/9/20 20:50:55

在这里插入图片描述

文章目录

  • 一、模型训练中出现loss为NaN原因
    • 1. 学习率过高
    • 2. 梯度消失或爆炸
    • 3. 数据不平衡或异常
    • 4. 模型不稳定
    • 5. 过拟合
  • 二、 针对梯度消失或爆炸的解决方案
    • 1. 使用`torch.autograd.detect_anomaly()`
    • 2. 使用 torchviz 可视化计算图
    • 3. 检查梯度的数值范围
    • 4. 调整梯度剪裁
  • 三、更具体的办法
    • 3.1 可能导致梯度爆炸的部分
    • 3.2 解决方案

一、模型训练中出现loss为NaN原因

1. 学习率过高

在训练的某个阶段,学习率可能设置得过高,导致模型参数更新幅度过大,甚至可能出现数值不稳定的情况。你可以尝试降低学习率,并观察训练过程中的变化。

2. 梯度消失或爆炸

如果模型的某些层出现梯度消失或爆炸的问题,可能会导致loss变得异常低。你可以检查梯度的大小,确保它们在合理范围内。

3. 数据不平衡或异常

训练数据中可能存在异常值或分布不平衡的情况,导致模型在某些批次的训练过程中出现异常。你可以检查数据集,确保数据质量。

4. 模型不稳定

模型架构或训练过程中的某些设置可能导致不稳定,比如过深的网络、过复杂的模型等。你可以尝试简化模型架构或添加正则化项。

5. 过拟合

模型可能在某些阶段已经过拟合到训练数据上,导致训练loss异常低而验证loss较高。你可以通过早停法(early stopping)、正则化、数据增强等方法来缓解过拟合问题。
解决方法

  1. 调节学习率:适当降低学习率,观察训练过程中的变化。
  2. 检查梯度:通过torch.autograd检查梯度的大小,确保没有出现梯度消失或爆炸。
  3. 数据检查:确保数据集没有异常值或分布不平衡的情况。
  4. 模型架构:简化模型架构,增加正则化项,如L2正则化、dropout等。
  5. 验证集监控:通过监控验证集的loss和指标,防止过拟合。\

二、 针对梯度消失或爆炸的解决方案

使用 torch.autograd.detect_anomaly() 和相关工具确实可以帮助你检测和排除训练过程中出现的梯度问题。以下是如何在你的代码中使用这些工具来检测异常和可视化梯度的示例。

1. 使用torch.autograd.detect_anomaly()

这个函数可以帮助检测反向传播过程中出现的异常,并输出具体的错误信息和位置。

import torch

# 定义模型
model = MyModel()

# 定义损失函数和优化器
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# 输入数据
inputs = torch.randn(56, 1024, 28, 28)
targets = torch.randint(0, 10, (56,))

# 在训练过程中使用 detect_anomaly
with torch.autograd.detect_anomaly():
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()

2. 使用 torchviz 可视化计算图

torchviz 是一个可以帮助你可视化计算图的工具,这对于调试复杂的模型非常有用。

首先,安装 torchviz:

pip install torchviz

然后,可以使用以下代码来生成和保存计算图:

from torchviz import make_dot

# 定义模型
model = MyModel()

# 输入数据
inputs = torch.randn(56, 1024, 28, 28)

# 获取模型输出
outputs = model(inputs)

# 创建计算图
dot = make_dot(outputs, params=dict(model.named_parameters()))

# 保存计算图
dot.format = 'png'
dot.render('model_graph')

3. 检查梯度的数值范围

你可以在每个训练步骤之后检查模型中各个参数的梯度,以确保梯度的数值范围正常。

# 定义模型
model = MyModel()

# 定义损失函数和优化器
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# 输入数据
inputs = torch.randn(56, 1024, 28, 28)
targets = torch.randint(0, 10, (56,))

# 在训练过程中使用 detect_anomaly
with torch.autograd.detect_anomaly():
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()

    # 检查梯度数值范围
    for name, param in model.named_parameters():
        if param.grad is not None:
            grad_min = param.grad.min().item()
            grad_max = param.grad.max().item()
            print(f'{name} - grad_min: {grad_min}, grad_max: {grad_max}')
    
    optimizer.step()

4. 调整梯度剪裁

在训练过程中,可以使用梯度剪裁来防止梯度爆炸。以下是如何在 PyTorch 中实现梯度剪裁的示例:

# 定义模型
model = MyModel()

# 定义损失函数和优化器
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# 输入数据
inputs = torch.randn(56, 1024, 28, 28)
targets = torch.randint(0, 10, (56,))

# 在训练过程中使用 detect_anomaly
with torch.autograd.detect_anomaly():
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    
    # 梯度剪裁
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    
    optimizer.step()

通过以上方法,可以更好地检测和调试训练过程中出现的梯度问题,提高模型的训练稳定性和效率。如果在使用过程中发现任何异常或需要进一步调试,请随时提供更多细节。

三、更具体的办法

3.1 可能导致梯度爆炸的部分

  1. ReLU 激活函数的使用:激活函数可参考激活函数汇总
    ReLU 是一种常见的激活函数,但如果输入有较大的正值,经过 ReLU 之后,这些值会直接传递下去,可能导致后续层的梯度爆炸。考虑使用其他激活函数,如 Leaky ReLU 或 SELU,它们在某些情况下对梯度更友好。

    embedding_spatial_786 = torch.nn.functional.relu(embedding_spatial_786, inplace=True)
    
  2. 特征插值:
    插值操作可能会生成较大的值,尤其是在上采样过程中。如果插值后的值过大,可能会导致梯度爆炸。
    upsample_feat = F.interpolate(feat_high, scale_factor=2., mode=‘nearest’)

  3. 特征拼接:
    多个特征拼接后,如果这些特征值过大,会导致拼接后的张量值过大,进而影响后续层的梯度。

    inner_out = self.fpn_blocks[len(proj_feats) - 1 - idx](torch.concat([upsample_feat, feat_low], dim=1))
    
  4. 全连接层:
    全连接层的权重初始化方式可能会导致梯度爆炸。确保使用了合适的初始化方法,如 Xavier 初始化或 He 初始化。

  5. 权重共享:
    如果多个部分共享权重,需要确保这些共享权重不会导致梯度的累积效应。

3.2 解决方案

  1. 梯度剪裁:
    在反向传播过程中使用梯度剪裁,可以防止梯度爆炸。你可以在 optimizer.step() 之前加上梯度剪裁。

    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    
  2. 使用更稳定的激活函数:
    尝试使用 Leaky ReLU 或 SELU,它们在某些情况下对梯度更友好。

  3. 检查权重初始化:
    确保所有层的权重初始化方式合理,避免初始值过大。

    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    
  4. 监控梯度值:
    在每次反向传播后,监控梯度的值,确保梯度不会爆炸。

    for name, param in model.named_parameters():
        if param.grad is not None:
            grad_min = param.grad.min().item()
            grad_max = param.grad.max().item()
            print(f'{name} - grad_min: {grad_min}, grad_max: {grad_max}')
    

Enjoy~

∼ O n e   p e r s o n   g o   f a s t e r ,   a   g r o u p   o f   p e o p l e   c a n   g o   f u r t h e r ∼ \sim_{One\ person\ go\ faster,\ a\ group\ of\ people\ can\ go\ further}\sim One person go faster, a group of people can go further

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1934995.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

鸿蒙语言基础类库:【@system.request (上传下载)】

上传下载 说明: 从API Version 6开始,该接口不再维护,推荐使用新接口[ohos.request]。本模块首批接口从API version 4开始支持。后续版本的新增接口,采用上角标单独标记接口的起始版本。 导入模块 import request from system.re…

C语言之指针的奥秘(三)

一、字符指针变量 在指针的类型中&#xff0c;有字符指针char*&#xff0c;一般使用&#xff1a; #include<stdio.h> int main() {char ch w;char* p &ch;*p w;return 0; } 还有一种方式&#xff1a; #include<stdio.h> int main() {const char* p &qu…

Android 10.0 Launcher3拖拽图标进入hotseat自适应布局功能实现一

1.前言 在10.0的系统rom定制化开发中&#xff0c;在对于launcher3的一些开发定制中&#xff0c;在对hotseat的一些开发中&#xff0c;需要实现动态hotseat居中 的功能&#xff0c;就是在拖拽图标进入和拖出hotseat&#xff0c;都可以保持hotseat居中的功能&#xff0c;接下来分…

html2canvas + jspdf 纯前端HTML导出PDF的实现与问题

前言 这几天接到一个需求&#xff0c;富文本编辑器的内容不仅要展示出来&#xff0c;还要实现展示的内容导出pdf文件。一开始导出pdf的功能是由后端来做的&#xff0c;然后发现对于宽度太大的图片&#xff0c;导出的pdf文件里部分图片内容被遮盖了&#xff0c;但在前端是正常显…

S参数入门

一、说明 S参数全称为散射参数&#xff0c;主要用来作为描述线性无源互联结构的一种行为模型&#xff0c;来源于网络分析方法。网络分析法是一种频域方法&#xff0c;在一组离散的频率点上&#xff0c;通过在输入和输出端口得到的参量完全描述线性时不变系统&#xff08;定义参…

园区AR导航系统构建详解:从三维地图构建到AR融合导航的实现

随着现代园区规模的不断扩大与功能的日益复杂&#xff0c;传统的二维地图导航已难以满足访客高效、精准定位的需求。园区内部错综复杂的布局、频繁变更的商户位置常常让访客感到迷茫&#xff0c;造成寻路上的时间浪费。园区AR导航系统以创新的技术手段&#xff0c;破解了私域地…

对redis进行深入学习

目录 1. 什么是redis&#xff1f;1.1 为什么使用redis作为缓存&#xff1f;1.1.0 数据库&#xff08;MySQL&#xff09;与 redis1. 存储介质不同&#xff08;408选手应该都懂hh&#xff09;2. 数据结构优化3. I/O模型差异4. CPU缓存友好性5. 单线程与多线程差异6. 持久化与缓存…

Volatility:分析MS10-061攻击

1、概述 # 1&#xff09;什么是 Volatility Volatility是开源的Windows&#xff0c;Linux&#xff0c;MaC&#xff0c;Android的内存取证分析工具。基于Python开发而成&#xff0c;可以分析内存中的各种数据。Volatility支持对32位或64位Wnidows、Linux、Mac、Android操作系统…

2024算力基础设施安全架构设计与思考(免费下载)

算网安全体系是将数据中心集群、算力枢纽、一体化大数据中心三个层级的安全需求进行工程化解耦&#xff0c;从国家安全角度统筹设计&#xff0c;通过安全 服务化方式&#xff0c;依托威胁情报和指挥协同通道将三层四级安全体系串联贯通&#xff0c;达成一体化大数据安全目标。 …

插画插件:成都亚恒丰创教育科技有限公司

【插画插件&#xff1a;数字创意时代的艺术加速器】 在数字化浪潮汹涌的今天&#xff0c;视觉艺术以其独特的魅力穿梭于互联网的每一个角落&#xff0c;成为连接人心、传递情感与信息的桥梁。而在这股创意洪流中&#xff0c;插画插件以其高效、便捷、个性化的特点&#xff0c;…

1219:马走日

#include<bits/stdc.h> using namespace std; int vis[8][2]{-2,1,-1,2,1,2,2,1,2,-1,1,-2,-1,-2,-2,-1};//构造偏移量数组 int t,n,m,x,y,ans;//棋盘总共由(n)(m)个点 bool st[100][100];//如果st[i][j]0 表示i,j这个坐标没有走过 st[a][b]1表示a,b这个坐标走过 void d…

【05】LLaMA-Factory微调大模型——初尝微调模型

上文【04】LLaMA-Factory微调大模型——数据准备介绍了如何准备指令监督微调数据&#xff0c;为后续的微调模型提供高质量、格式规范的数据支撑。本文将正式进入模型微调阶段&#xff0c;构建法律垂直应用大模型。 一、硬件依赖 LLaMA-Factory框架对硬件和软件的依赖可见以下…

GPT-4o大语言模型优化、本地私有化部署、从0-1搭建、智能体构建

原文链接&#xff1a;GPT-4o大语言模型优化、本地私有化部署、从0-1搭建、智能体构建https://mp.weixin.qq.com/s?__bizMzUzNTczMDMxMg&mid2247608565&idx3&snd4e9d447efd82e8dd8192f7573886dab&chksmfa826912cdf5e00414e01626b52bab83a96199a6bf69cbbef7f7fe…

C语言 | Leetcode C语言题解之第257题二叉树的所有路径

题目&#xff1a; 题解&#xff1a; char** binaryTreePaths(struct TreeNode* root, int* returnSize) {char** paths (char**)malloc(sizeof(char*) * 1001);*returnSize 0;if (root NULL) {return paths;}struct TreeNode** node_queue (struct TreeNode**)malloc(size…

Mysql中的几种常见日志

引言 本文是对Mysql中几种常见日志及其作用的介绍 一、error log&#xff08;错误日志&#xff09; MySQL 中的 error log&#xff08;错误日志&#xff09;是一种非常重要的日志类型&#xff0c;它记录了 MySQL 服务器在启动、运行及关闭过程中遇到的所有重要事件、错误信…

python爬虫实现简单的代理ip池

python爬虫实现简单的代理ip池 我们在普通的爬虫过程中经常遇到一些网站对ip进行封锁的 下面演示一下普通的爬虫程序 使用requests.get爬取数据 这段代码是爬取豆瓣排行榜的数据&#xff0c;使用f12来查看请求的url和数据格式 代码 def requestData():# 爬取数据的urlur…

[Maven] 打包编译本地Jar包报错的几种解决办法

目录 方式1&#xff1a;通过scope指定 方式2&#xff1a;通过新建lib 方式3&#xff1a;通过build节点打包依赖​​​​​​​ 方式4&#xff1a;安装Jar包到本地 方式5&#xff1a;发布到远程私有仓库 方式6&#xff1a;删除_remote.repositories 方式7&#xff1a;打包…

Leetcode二分搜索法浅析

文章目录 1.二分搜索法1.1什么是二分搜索法&#xff1f;1.2解法思路 1.二分搜索法 题目原文&#xff1a; 给定一个 n 个元素有序的&#xff08;升序&#xff09;整型数组 nums 和一个目标值 target &#xff0c;写一个函数搜索 nums 中的 target&#xff0c;如果目标值存在返…

TCP重传机制详解

1.什么是TCP重传机制 在 TCP 中&#xff0c;当发送端的数据到达接收主机时&#xff0c;接收端主机会返回⼀个确认应答消息&#xff0c;表示已收到消息。 但是如果传输的过程中&#xff0c;数据包丢失了&#xff0c;就会使⽤重传机制来解决。TCP的重传机制是为了保证数据传输的…

决策树回归(Decision Tree Regression)

理论知识推导 决策树回归是一种非参数监督学习方法&#xff0c;用于回归问题。它通过将数据集划分成较小的子集来建立模型&#xff0c;并在这些子集上构建简单的预测模型&#xff08;通常是恒定值&#xff09;。下面是决策树回归的数学推导过程&#xff1a; 实施步骤与参数解读…