深度学习 | 常见问题及对策(过拟合、欠拟合、正则化)

news2025/1/19 11:19:47

 

 

 1、训练常见问题

 

1.1、模型架构设计

         

        关于隐藏层的一个万能近似定理:

        Universal Approximation Theorem:一个具有足够多的隐藏节点的多层前馈神经网络,可以逼近任意连续的函数。(Cybenko, 1989)—— 必须包含至少一种有挤压性质的激活函数。

        

 

1.2、宽度 / 深度

        

1.3、过拟合

         Overfitting:模型在训练数据上表现良好,在测试数据上不佳

        泛化能力:训练后的模型应用到新的、未知的数据上的能力

        产生原因:通常是由模型复杂度过高导致的

        

1.4、欠拟合

        Underfitting:学习能力不足,无法学习到数据集中的“一般规律

        产生原因:模型学习能力较弱,而数据复杂度较高的情况

相互关系:

        

 

 1.5、过拟合应对策略

 

         本质都是数据和模型匹配问题。可以从以下三种方法入手:

                数据复杂度

                模型复杂度

                训练策略

 

        

         

         

        

        

        

        


 

2、正则化

 

        正则化:对学习算法的修改,目的是减少泛化误差,而不是训练误差。

        

没有免费午餐定理:脱离具体问题,空谈什么模型更好没有意义。

        没有一种算法或者模型能够在所有的场景中都表现良好。

        正则化是一种权衡过拟合和欠拟合的手段。

        


 

2. 1、L2正则化

 

         通过给模型的损失函数添加一个模型参数的平方和的惩罚项来实现正则化。

        

         L 为范数。

         在机器学习中也被称为 岭回归。

 

2.2、L1正则化

        通过在损失函数中加入对模型参数权值矩阵中各元素绝对值之和的惩罚项,来限制模型参数的值。

        

       在机器学习中也被称为 LASSO回归。

 

L1正则化和L2正则化空间解释:

        机器学习 | 过拟合与正则化、模型泛化与评价指标-CSDN博客

L1正则化和L2正则化异同对比:

         L1正则化更倾向于产生稀疏解,适于特征选择。

        L2正则化更倾向于小的非零权值,更适用于优化问题,使得权值更加平滑。

         

 


 

2.3、范数惩罚

 

        将L1和L2正则化扩展到一般情况, λ 表示系数,越大表示惩罚程度越大。

         

        神经网络中一般参数会包含两种,y = f ( w X + b )

        w 是仿射变换和偏置 b ,通常情况下我们只考虑对参数 w 作惩罚,这是由于在拟合偏置 b 时所需数据量比较少就可以拟合的很好了。


 

2.4、权重衰减 Weight Decay

         

         直接修改最优化过程中参数迭代的方程:

        

         当使用随机梯度下降 SGD 时,就等价于 L2正则化。

        


2.5、Dropout方法

 

工作原理:

         在训练过程中随机删除(即将其权重设为零)一些神经元,从而使模型不能够依赖于某些特定的特征。

        只用在训练期间,不用在测试期间。

         

主要步骤:        

        指定一个保留比例p;

        每层每个神经元,以p 的概率保留,以1-p 的概率将权重设为零;

        训练中使用保留的神经元进行前向、反向传播;

        测试过程,将所有权重乘以p。

 

 直观理解:

        相当于把一个网络拆分;

        由多个子网络构成集成学习,bagging。

        

 神经网络中的使用:

        在训练网络的每个单元都要添加概率计算。

        添加 r ~

        

 

 为什么能减少过拟合:

        本质是Bagging集成学习,平均化作用;

        减少神经元之间复杂的关系,迫使模型寻找显著的特征;

        类似性别在生物进化中的角色。

        

优缺点:

        可以有效地减少过拟合,简单方便,实用有效。

        降低训练效率(多了分拆、计算概率),损失函数不够明确。

         

 

代码实现:

# 导入必要的库
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

# 随机数种子
torch.manual_seed(2333)

# 定义超参数
num_samples = 20 # 样本数
hidden_size = 200 # 隐藏层大小
num_epochs = 500  # 训练轮数

数据生成

# 生成训练集
x_train = torch.unsqueeze(torch.linspace(-1, 1, num_samples), 1)
y_train = x_train + 0.3 * torch.randn(num_samples, 1)

# 测试集
x_test = torch.unsqueeze(torch.linspace(-1, 1, num_samples), 1)
y_test = x_test + 0.3 *  torch.randn(num_samples, 1)

# 绘制训练集和测试集
plt.scatter(x_train, y_train, c='r', alpha=0.5, label='train')
plt.scatter(x_test, y_test, c='b', alpha=0.5, label='test')
plt.legend(loc='upper left')
plt.ylim((-2, 2))
plt.show()

模型定义

# 定义一个可能会过拟合的网络
net_overfitting = torch.nn.Sequential(
    torch.nn.Linear(1, hidden_size),
    torch.nn.ReLU(),
    torch.nn.Linear(hidden_size, hidden_size),
    torch.nn.ReLU(),
    torch.nn.Linear(hidden_size, 1),
)

# 定义一个包含 Dropout 的网络
net_dropout = torch.nn.Sequential(
    torch.nn.Linear(1, hidden_size),
    torch.nn.Dropout(0.5),  # p=0.5
    torch.nn.ReLU(),
    torch.nn.Linear(hidden_size, hidden_size),
    torch.nn.Dropout(0.5),  # p=0.5
    torch.nn.ReLU(),
    torch.nn.Linear(hidden_size, 1),
)

模型训练

# 定义优化器和损失函数
optimizer_overfitting = torch.optim.Adam(net_overfitting.parameters(), lr=0.01)
optimizer_dropout = torch.optim.Adam(net_dropout.parameters(), lr=0.01)

# 损失函数
criterion = nn.MSELoss()

# 分别进行训练
for i in range(num_epochs):
    # overfitting的网络:预测、损失函数、反向传播
    pred_overfitting = net_overfitting(x_train)
    loss_overfitting = criterion(pred_overfitting, y_train)
    optimizer_overfitting.zero_grad()
    loss_overfitting.backward()
    optimizer_overfitting.step()
    
    # 包含dropout的网络:预测、损失函数、反向传播
    pred_dropout = net_dropout(x_train)
    loss_dropout = criterion(pred_dropout, y_train)
    optimizer_dropout.zero_grad()
    loss_dropout.backward()
    optimizer_dropout.step()

预测和可视化

# 在测试过程中不使用 Dropout
net_overfitting.eval()
net_dropout.eval()

# 预测
test_pred_overfitting = net_overfitting(x_test)
test_pred_dropout = net_dropout(x_test)

# 绘制拟合效果
plt.scatter(x_train, y_train, c='r', alpha=0.3, label='train')
plt.scatter(x_test, y_test, c='b', alpha=0.3, label='test')
plt.plot(x_test, test_pred_overfitting.data.numpy(), 'r-', lw=2, label='overfitting')
plt.plot(x_test, test_pred_dropout.data.numpy(), 'b--', lw=2, label='dropout')
plt.legend(loc='upper left')
plt.ylim((-2, 2))
plt.show()

 

 

 3、梯度消失与梯度爆炸

3.1、梯度的重要性

 

        深度神经网络就是非线性多元函数。

        优化模型就是找到合适权重,最小化损失函数。

        

        

 

3.2、反向传播的内在问题

        链式求导法则本身就是激活函数的偏导数连乘

        

        

3.3、梯度消失

        sigmiod函数:

                梯度不超过 0.5。

        激活函数的导数小于1容易发生梯度消失。

       

3.4、梯度爆炸

        梯度可在更新中累积,变成非常大,导致网络不稳定或者模型溢出。

        原因之一:深层网络;

        原因之二:初始化权重的值过大。

3.5、解决办法

        预训练加微调

        梯度剪切、正则

        ReLU激活函数

        Batchnorm

        残差结构

3.5.1、梯度剪切 —— 针对梯度爆炸

        设置一个梯度剪切闻值,超过则将其强制限制在这个范围之内。

        

3.5.2、ReLU激活函数

        

 

3.5.3、Batchnorm

        对某一个通道进行规范化操作。

        

 

3.5.4、残差结构

         

 


 

4、模型文件的读写

 

4.1、张量的保存和加载

import torch
a = torch.rand(6)
a
tensor([0.8608, 0.6997, 0.4133, 0.6113, 0.5393, 0.8223])
torch.save(a,"model/tensor_a")
torch.load("model/tensor_a")
tensor([0.8608, 0.6997, 0.4133, 0.6113, 0.5393, 0.8223])
a = torch.rand(6)
b = torch.rand(6)
c = torch.rand(6)
[a,b,c]
[tensor([0.6443, 0.6780, 0.9844, 0.3475, 0.3763, 0.9680]),
 tensor([0.0351, 0.3652, 0.9474, 0.5658, 0.5001, 0.7580]),
 tensor([0.5543, 0.2713, 0.3125, 0.0378, 0.0676, 0.2208])]
torch.save([a,b,c],"model/tensor_abc")
torch.load("model/tensor_abc")
[tensor([0.6443, 0.6780, 0.9844, 0.3475, 0.3763, 0.9680]),
 tensor([0.0351, 0.3652, 0.9474, 0.5658, 0.5001, 0.7580]),
 tensor([0.5543, 0.2713, 0.3125, 0.0378, 0.0676, 0.2208])]
tensor_dict= {'a':a,'b':b,'c':c}
tensor_dict
{'a': tensor([0.6443, 0.6780, 0.9844, 0.3475, 0.3763, 0.9680]),
 'b': tensor([0.0351, 0.3652, 0.9474, 0.5658, 0.5001, 0.7580]),
 'c': tensor([0.5543, 0.2713, 0.3125, 0.0378, 0.0676, 0.2208])}
torch.save(tensor_dict,"model/tensor_dict_abc")
torch.load("model/tensor_dict_abc")
{'a': tensor([0.6443, 0.6780, 0.9844, 0.3475, 0.3763, 0.9680]),
 'b': tensor([0.0351, 0.3652, 0.9474, 0.5658, 0.5001, 0.7580]),
 'c': tensor([0.5543, 0.2713, 0.3125, 0.0378, 0.0676, 0.2208])}

 


 

4.2、模型的保存与加载

from torchvision import datasets
from torchvision import transforms
import torch.nn as nn 
import torch.optim as optim

# 定义 MLP 网络  继承nn.Module
class MLP(nn.Module):
    
    # 初始化方法
    # input_size输入数据的维度    
    # hidden_size 隐藏层的大小
    # num_classes 输出分类的数量
    def __init__(self, input_size, hidden_size, num_classes):
        # 调用父类的初始化方法
        super(MLP, self).__init__()
        # 定义第1个全连接层  
        self.fc1 = nn.Linear(input_size, hidden_size)
        # 定义激活函数
        self.relu = nn.ReLU()
        # 定义第2个全连接层
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        # 定义第3个全连接层
        self.fc3 = nn.Linear(hidden_size, num_classes)
        
    # 定义forward函数
    # x 输入的数据
    def forward(self, x):
        # 第一层运算
        out = self.fc1(x)
        # 将上一步结果送给激活函数
        out = self.relu(out)
        # 将上一步结果送给fc2
        out = self.fc2(out)
        # 同样将结果送给激活函数
        out = self.relu(out)
        # 将上一步结果传递给fc3
        out = self.fc3(out)
        # 返回结果
        return out
    
# 定义参数    
input_size = 28 * 28  # 输入大小
hidden_size = 512  # 隐藏层大小
num_classes = 10  # 输出大小(类别数) 

# 初始化MLP    
model = MLP(input_size, hidden_size, num_classes)

4.2.1、方式1(推荐)

# 保存模型参数
torch.save(model.state_dict(),"model/mlp_state_dict.pth")
# 读取保存的模型参数
mlp_state_dict = torch.load("model/mlp_state_dict.pth")

# 新实例化一个MLP模型
model_load = MLP(input_size,hidden_size,num_classes)

# 调用load_state_dict方法 传入读取的参数
model_load.load_state_dict(mlp_state_dict)
<All keys matched successfully>

 

4.2.2、方式2

# 保存整个模型
torch.save(model,"model/mlp_model.pth")
# 加载整个模型
mlp_load = torch.load("model/mlp_model.pth")

 

4.2.3、方式3 : checkpoint(推荐)

# 保存参数
torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            ...
            }, PATH)
# 加载参数
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()

 


参考

Deep-Learning-Code: 《深度学习必修课:进击算法工程师》配套代码 - Gitee.com

哔哩哔哩_bilibili

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1331045.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

mysql自增序列 关于mysql线程安全 独享内存 溢出 分析

1 MySQL锁概述 锁是计算机协调多个进程或线程并发访问某一资源的机制。如何保证数据并发访问的一致性、有效性是所有数据库必须解决的一个问题&#xff0c;锁冲突也是影响数据库并发访问性能的一个重要因素。 相对其他数据库而言&#xff0c;MySQL的锁机制比较简单&#xff0c…

C++中的内存锁定

内存锁定(memory locking)是确保进程保留在主内存中并且免于分页的一种方法。在实时环境中&#xff0c;系统必须能够保证将进程锁定在内存中&#xff0c;以减少数据访问、指令获取、进程之间的缓冲区传递等的延迟。锁定内存中进程的地址空间有助于确保应用程序的响应时间满足实…

二维码初体验 com.google.zxing 实现

文章目录 一、概述二、实现效果1. 完整版本2. 简化版本 三、源码结构四、完整代码 一、概述 Java 操作二维码的开源项目很多&#xff0c;如 SwetakeQRCode、BarCode4j、Zxing 等&#xff0c;这边以Zxing 为例进行介绍。 二、实现效果 1. 完整版本 选择需要生成QR原始文件&a…

大模型工具_awesome-chatgpt-prompts-zh

https://github.com/PlexPt/awesome-chatgpt-prompts-zh 1 功能 整体功能&#xff0c;想解决什么问题 ChatGPT 中文调教指南&#xff1a;提供一些常用的使用场景及对应的 Prompt 提示 当前解决了什么问题&#xff0c;哪些问题解决不了 针对想解决实际问题&#xff0c;但不知道…

图像识别中的 Vision Transformers (ViT)

引言 Vision Transformers (ViT) 最近已成为卷积神经网络(CNN) 的竞争替代品&#xff0c;而卷积神经网络 (CNN) 目前在不同的图像识别计算机视觉任务中处于最先进的水平。ViT 模型在计算效率和准确性方面比当前最先进的 (CNN) 模型高出近 4 倍。 Transformer 模型已成为自然语…

Diffusion扩散模型学习:图片高斯加噪

高斯分布即正态分布&#xff1b;图片高斯加噪即把图片矩阵每个值和一个高斯分布的矩阵上的对应值相加 1、高斯分布 np.random.normal 一维&#xff1a; import numpy as np import matplotlib.pyplot as pltdef generate_gaussian_noise(mean, std_dev, size):noise np.ran…

小白入门之安装NodeJS

重生之我在大四学JAVA 第五章 安装NodeJS 如果你在购买我闲鱼的程序&#xff0c;请尽量使用node14版本 修改安装路径 接着傻瓜式NEXT 测试是否安装成功 如果上面没提示版本号&#xff0c;就按照前两章配置环境变量步骤配置下环境变量 设置镜像地址 npm config set re…

一种简单的自编码器PyTorch代码实现

1. 引言 对于许多新接触深度学习爱好者来说&#xff0c;玩AutoEncoder总是很有趣的&#xff0c;因为它具有简单的处理逻辑、简易的网络架构&#xff0c;方便可视化潜在的特征空间。在本文中&#xff0c;我将从头开始介绍一个简单的AutoEncoder模型&#xff0c;以及一些可视化潜…

全渠道在线客服系统支持的沟通渠道:多渠道整合与无缝对接

我们在挑选客服系统的时候&#xff0c;经常会看到有些客服产品会强调自己是“全渠道客服系统”&#xff0c;那什么是全渠道客服系统呢&#xff1f; 1、什么是全渠道客服系统&#xff1f; 简单来讲&#xff0c;它是指能把某个客户在不同渠道的互动历史放到一起集中展现&#x…

rqt_graph使用说明

其中右边的&#xff1a;/rosout是一个topic 也就是一个话题 /rosout是一个topic 也是一个话题 可以看到凡是在rqt_graph里面用长方形标识的全都是话题 通过观察可以发现&#xff1a;凡是用椭圆标识的全都是节点 如果切换为Nodes only视图会发现&#xff1a; 所说的no…

SpringSecurity安全框架 ——认证与授权

目录 一、简介 1.1 什么是Spring Security 1.2 工作原理 1.3 为什么选择Spring Security 1.4 HttpSecurity 介绍&#x1f31f; 二、用户认证 2.1 导入依赖与配置 2.2 用户对象UserDetails 2.3 业务对象UserDetailsService 2.4 SecurityConfig配置 2.4.1 BCryptPasswo…

【数据结构入门精讲 | 第八篇】一文讲清全部排序算法(2)

在上一篇文章中我们介绍了冒泡排序、快速排序等算法&#xff0c;这一篇我们接着对排序算法的学习。 目录 归并排序堆排序选择排序计数排序基数排序排序总结 归并排序 归并排序是建立在归并操作上的一种有效&#xff0c;稳定的排序算法&#xff0c;该算法是采用分治法&#xff…

MySQL报错:1054 - Unknown column ‘xx‘ in ‘field list的解决方法

我在操作MySQL遇到1054报错&#xff0c;报错内容&#xff1a;1054 - Unknown column Cindy in field list&#xff0c;下面演示解决方法&#xff0c;非常简单。 根据箭头指示&#xff0c;Cindy对应的应该是VARCHAR文本数字类型&#xff0c;字符串要用引号&#xff0c;所以解决方…

【C语言】打印内存数据

C语言&#xff0c;用函数封装&#xff1a;16进制打印unsigned char *p指向的内存&#xff0c;长度为int l。16个字节&#xff0c;换一次行。16个字节用一个字符串缓存&#xff0c;一次打印。 以下是一个使用函数封装的C语言代码&#xff0c;用于以16进制格式打印unsigned char …

MySQL 事务的ACID特性

MySQL事务是什么&#xff0c;它就是一组数据库的操作&#xff0c;是访问数据库的程序单元&#xff0c;事务中可能包含一个或者多个 SQL 语句。这些SQL 语句要么都执行、要么都不执行。我们知道&#xff0c;在MySQL 中&#xff0c;有不同的存储引擎&#xff0c;有的存储引擎比如…

省时攻略:快速获得Creo安装包,释放创意天才!

不要再在网上浪费时间寻找Creo的安装包了&#xff0c;一键下载安装&#xff0c; 你要的一切都可以在这里找到&#xff01;我们深知在海量的信息中寻找合适的软件包并非易事&#xff0c;而且往往还伴随着繁琐的安装过程。然而&#xff0c;现在有了我们&#xff0c;一切变得轻松简…

【飞凌 OK113i-C 全志T113-i开发板】一些有用的常用的命令测试

一些有用的常用的命令测试 一、系统信息查询 可以查询板子的内核信息、CPU处理器信息、环境变量等 二、CPU频率 从上面的系统信息查询到&#xff0c;这是一颗具有两个ARMv7结构A7内核的处理器&#xff0c;主频最高1.2GHz 可以通过命令查看当前支持的频率以及目前所使用主频 …

爬虫工作量由小到大的思维转变---<第二十三章 Scrapy开始很快,越来越慢(医病篇)>

诊断篇https://blog.csdn.net/m0_56758840/article/details/135170994?ops_request_misc%257B%2522request%255Fid%2522%253A%2522170333243316800180644102%2522%252C%2522scm%2522%253A%252220140713.130102334.pc%255Fall.%2522%257D&request_id1703332433168001806441…

更改WiseAlign软件界面图标方法

更改WiseAlign软件界面图标方法 未替换时 首先将图片转换为BMP格式&#xff0c;在搜索栏处输入画图&#xff0c;点击打开画图工具 按住图标拖动到画布内&#xff0c;或是直接CtrlV将图标复制到画布内 点击文件&#xff0c;再点击另存为 保存类型选择“24位位图&#xff08;*.bm…

SpringBoot3-基础特性

文章目录 自定义 banner自定义 SpringApplicationFluentBuilder APIProfiles指定环境环境激活环境包含Profile 分组Profile 配置文件 外部化配置配置优先级 外部配置导入配置属性占位符 单元测试-JUnit5测试组件测试注解断言嵌套测试参数化测试 自定义 banner banner 就是启动…