【机器学习】人工神经网络优化方法及正则化技术

news2024/9/29 1:24:21

鑫宝Code

🌈个人主页: 鑫宝Code
🔥热门专栏: 闲话杂谈| 炫酷HTML | JavaScript基础
💫个人格言: "如无必要,勿增实体"


文章目录

  • 人工神经网络优化方法及正则化技术
    • 1. 引言
    • 2. 神经网络优化的基础
      • 2.1 损失函数
      • 2.2 梯度下降
    • 3. 高级优化算法
      • 3.1 随机梯度下降(SGD)
      • 3.2 动量法(Momentum)
      • 3.3 AdaGrad
      • 3.4 RMSprop
      • 3.5 Adam
    • 4. 学习率调度
      • 4.1 学习率衰减
      • 4.2 周期性学习率
      • 4.3 热重启
    • 5. 正则化技术
      • 5.1 L1正则化(Lasso)
      • 5.2 L2正则化(Ridge)
      • 5.3 弹性网络(Elastic Net)
      • 5.4 Dropout
      • 5.5 批量归一化(Batch Normalization)
      • 5.6 权重衰减(Weight Decay)
    • 6. 高级正则化技术
      • 6.1 数据增强
      • 6.2 早停(Early Stopping)
      • 6.3 混合精度训练
    • 7. 结论

人工神经网络优化方法及正则化技术

1. 引言

人工神经网络(Artificial Neural Networks,ANN)是机器学习和深度学习中的核心技术之一。为了提高神经网络的性能和泛化能力,研究人员开发了各种优化方法和正则化技术。本文将深入探讨这些方法,帮助读者更好地理解和应用这些重要的技术。

2. 神经网络优化的基础

2.1 损失函数

损失函数是衡量神经网络预测结果与真实值之间差异的指标。常见的损失函数包括:

  • 均方误差(MSE)
  • 交叉熵(Cross-Entropy)
  • Hinge Loss

2.2 梯度下降

梯度下降是优化神经网络的基本方法,它通过计算损失函数相对于网络参数的梯度,并沿着梯度的反方向更新参数,以最小化损失函数。

3. 高级优化算法

3.1 随机梯度下降(SGD)

在这里插入图片描述

SGD是标准梯度下降的变体,每次只使用一个或一小批样本来计算梯度,从而加快训练速度。

for epoch in range(num_epochs):
    for batch in data_loader:
        optimizer.zero_grad()
        loss = loss_function(model(batch), targets)
        loss.backward()
        optimizer.step()

3.2 动量法(Momentum)

动量法通过累积过去的梯度来加速收敛,特别是在处理高曲率、小但一致的梯度时很有效。

v = beta * v - learning_rate * gradient
theta = theta + v

3.3 AdaGrad

AdaGrad自适应地调整学习率,对频繁更新的参数使用较小的学习率,对不经常更新的参数使用较大的学习率。

cache += gradient ** 2
theta -= learning_rate * gradient / (np.sqrt(cache) + epsilon)

3.4 RMSprop

RMSprop是AdaGrad的改进版本,通过使用移动平均来缓解学习率急剧下降的问题。

cache = decay_rate * cache + (1 - decay_rate) * gradient ** 2
theta -= learning_rate * gradient / (np.sqrt(cache) + epsilon)

3.5 Adam

Adam结合了动量法和RMSprop的优点,是目前最流行的优化算法之一。

m = beta1 * m + (1 - beta1) * gradient
v = beta2 * v + (1 - beta2) * (gradient ** 2)
m_hat = m / (1 - beta1 ** t)
v_hat = v / (1 - beta2 ** t)
theta -= learning_rate * m_hat / (np.sqrt(v_hat) + epsilon)

4. 学习率调度

4.1 学习率衰减

随着训练的进行,逐步降低学习率可以帮助模型更好地收敛。

learning_rate = initial_lr * (decay_rate ** (epoch // decay_steps))

4.2 周期性学习率

周期性地调整学习率可以帮助模型跳出局部最小值。

learning_rate = base_lr + (max_lr - base_lr) * abs(sin(pi * t / (2 * step_size)))

4.3 热重启

热重启技术通过周期性地重置学习率来改善优化过程。

T_cur = epoch % T_i
learning_rate = lr_min + 0.5 * (lr_max - lr_min) * (1 + cos(pi * T_cur / T_i))

5. 正则化技术

正则化是防止过拟合、提高模型泛化能力的重要技术。
在这里插入图片描述

5.1 L1正则化(Lasso)

L1正则化通过在损失函数中添加参数的绝对值和来实现稀疏化。

loss = original_loss + lambda * sum(abs(parameter))

5.2 L2正则化(Ridge)

L2正则化通过在损失函数中添加参数的平方和来防止参数值过大。

loss = original_loss + lambda * sum(parameter ** 2)

5.3 弹性网络(Elastic Net)

弹性网络结合了L1和L2正则化的优点。

loss = original_loss + lambda1 * sum(abs(parameter)) + lambda2 * sum(parameter ** 2)

5.4 Dropout

Dropout是一种强大的正则化技术,通过在训练过程中随机"丢弃"一部分神经元来防止过拟合。

class Dropout(nn.Module):
    def __init__(self, p=0.5):
        super(Dropout, self).__init__()
        self.p = p

    def forward(self, x):
        if self.training:
            mask = torch.bernoulli(torch.ones_like(x) * (1 - self.p))
            return x * mask / (1 - self.p)
        return x

5.5 批量归一化(Batch Normalization)

在这里插入图片描述

批量归一化通过标准化每一层的输入来加速训练并提高模型的稳定性。

class BatchNorm(nn.Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super(BatchNorm, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.gamma = nn.Parameter(torch.ones(num_features))
        self.beta = nn.Parameter(torch.zeros(num_features))
        self.running_mean = torch.zeros(num_features)
        self.running_var = torch.ones(num_features)

    def forward(self, x):
        if self.training:
            mean = x.mean(dim=0)
            var = x.var(dim=0, unbiased=False)
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
        else:
            mean = self.running_mean
            var = self.running_var
        
        x_normalized = (x - mean) / torch.sqrt(var + self.eps)
        return self.gamma * x_normalized + self.beta

5.6 权重衰减(Weight Decay)

权重衰减是L2正则化的一种实现,通过在每次参数更新时减小权重来防止过拟合。

for param in model.parameters():
    param.data -= weight_decay * param.data

6. 高级正则化技术

6.1 数据增强

数据增强通过对训练数据进行变换来增加数据的多样性,从而提高模型的泛化能力。

transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
])

6.2 早停(Early Stopping)

早停通过监控验证集的性能来决定何时停止训练,防止过拟合。

best_val_loss = float('inf')
patience = 10
counter = 0

for epoch in range(num_epochs):
    train(model, train_loader, optimizer, criterion)
    val_loss = validate(model, val_loader, criterion)
    
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        counter = 0
        torch.save(model.state_dict(), 'best_model.pth')
    else:
        counter += 1
        if counter >= patience:
            print("Early stopping")
            break

6.3 混合精度训练

混合精度训练通过使用低精度(如float16)和高精度(如float32)的混合来加速训练并减少内存使用。

scaler = torch.cuda.amp.GradScaler()

for batch in data_loader:
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = loss_function(model(batch), targets)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

7. 结论

人工神经网络的优化和正则化是深度学习中至关重要的主题。通过合理地选择和组合各种优化算法和正则化技术,我们可以显著提高模型的性能和泛化能力。然而,需要注意的是,没有一种通用的方法适用于所有问题。在实际应用中,我们需要根据具体的任务、数据集和计算资源来选择合适的方法,并通过实验来找到最佳的组合。

随着深度学习领域的不断发展,新的优化方法和正则化技术也在不断涌现。保持对最新研究的关注,并在实践中不断尝试和改进,将有助于我们构建更加高效和强大的神经网络模型。

End

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1980538.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

vlunstack-2(复现红日安全-ATT CK实战)

环境搭建 配置信息 DC IP&#xff1a;10.10.10.10 OS&#xff1a;Windows 2012(64) 应用&#xff1a;AD域 WEB IP1&#xff1a;10.10.10.80 IP2&#xff1a;192.168.47.131 OS&#xff1a;Windows 2008(64) 应用&#xff1a;Weblogic 10.3.6MSSQL 2008 PC IP1&#xff1a;10.10…

Scrapy入门篇

免责声明 本文的爬虫知识仅用于合法和合理的数据收集&#xff0c;使用者需遵守相关法律法规及目标网站的爬取规则&#xff0c;尊重数据隐私&#xff0c;合理设置访问频率&#xff0c;不得用于非法目的或侵犯他人权益。因使用网络爬虫产生的任何法律纠纷或损失&#xff0c;由使用…

论文解读:LSM Tree 的魔力,提升写入吞吐量的高效数据存储结构

LSM Tree是一种用于高写入吞吐量的数据库存储引擎&#xff0c;广泛应用于现代分布式数据库系统。其核心思想是将写入操作缓存在内存中&#xff0c;并定期批量写入磁盘&#xff0c;减少磁盘 I/O 操作&#xff0c;提高写入性能。因其高效的写入性能和适应大规模数据的能力&#x…

医院客户满意度调查如何开展

深圳满意度咨询有限公司&#xff08;SSC&#xff09;&#xff08;患者第三方满意度测评&#xff09;服务于国内多家医院&#xff0c;辅助医院提高患者满意度、改善医德医风、提高服务水平&#xff0c;调查项目覆盖了国内150余个城市&#xff0c;通过电话调查、网络问卷、现场访…

图片搜索网站,有大量高清图片,避免版权纠纷

一、简介 1、一个图片搜索网站&#xff0c;所有图片均遵循CC0协议&#xff0c;用户可以免费用于商业用途而无需标注来源。网站上有大量高清图片&#xff0c;基本可以满足用户的各种需求&#xff0c;同时避免了法律风险。提供强大的筛选功能&#xff0c;用户可以按图片方向、尺寸…

python学习之路 - python的函数

目录 一、python函数1、函数介绍2、函数的定义3、函数的参数4、函数的返回值5、函数说明文档6、函数的嵌套调用7、变量的作用域8、综合案例9、函数与方法的区别 二、python函数进阶1、函数多返回值2、函数多种传参方式a、位置参数b、关键字参数c、缺省参数d、不定长参数 3、匿名…

Visual Studio 调试时加载符号慢

什么是调试符号 编译程序时生成的一组特殊字符&#xff0c;并包含有关变量和函数在生成的二进制文件中的位置以及其他服务信息的信息。 该数据集可用于逐步调试程序或检查第三方代码。 调试符号可以添加到可执行文件或库中&#xff0c;但是大多数现代编译器将它们存储为单独的…

fabricjs 实现图像的二值化功能

一、效果图 二、图像二值化的作用 二值化是图像处理中常用的一种方法&#xff0c;其作用是将灰度图像转换为二值图像&#xff0c;即将图像中的像素点根据其灰度值分成两类&#xff1a;黑色和白色。这种处理方法可以帮助我们更清晰地识别图像中的目标&#xff0c;简化图像的复杂…

Lumerical 光纤模式仿真

Lumerical 光纤模式仿真 引言正文步骤 1------创建光纤的纤芯设置名称,位置及尺寸参数设置材料参数旋转结构使其朝向 x 方向放置步骤2------创建包层结构设置名称,位置及尺寸参数设置材料参数旋转结构使其朝向 x 方向放置设置透明度,是我们能够看到纤芯结构设置 FDE Solver设…

本科阶段最后一次竞赛Vlog——2024年智能车大赛智慧医疗组准备全过程——4Bin模型转化过程

本科阶段最后一次竞赛Vlog——2024年智能车大赛智慧医疗组准备全过程——4Bin模型转化过程 ​ 大家好&#xff0c;经过前几期的介绍&#xff0c;对于X3派上的Yolo模型部署&#xff0c;我们已经可以进行到最后一步了 ​ 今天给大家带来&#xff0c;转模型的关键步骤&#xff0…

苹芯科技发布新AI模型,引领全球轻量级AI应用革命

苹芯科技&#xff0c;一家在全球AI技术领域中不断创新的公司&#xff0c;于2月28日宣布推出其最新研发的轻量级AI模型。这款新模型旨在为开发者和企业提供更高效、更易访问的人工智能工具&#xff0c;尤其强调在数据敏感和计算资源受限的环境下的应用潜力。 在谷歌刚刚推出Gemm…

普元MDM主数据管理系统与微软Dynamic CRM系统(新加坡节点)集成案例

一、项目背景 某工程机械集团是中国工程机械产业奠基者、开创者和引领者&#xff0c;是工程机械行业具有全球竞争力、影响力的千亿级龙头企业。公司主要指标始终稳居中国工程机械行业第1位 客户需要将物料和配件等主数据和海外系统进行对接&#xff0c;由SAP PO在中间对接海…

【开发视角】大模型 RAG 检索增强生成究竟是什么

【大白话讲懂】大模型 RAG 检索增强生成 话先说在前面&#xff0c;本文不讲不会讲太多原理&#xff0c;仅面向工程开发&#xff0c;从工作流程的宏观角度进行梳理&#xff0c;旨在快速上手。 RAG 是什么 基本定义 让我们先来解释名词&#xff0c;看看宏观框架。 RAG 的意思…

Opencv调用yolov5的onnx文件时报错记录

Opencv调用yolov5的onnx文件时报错记录 报错内容&#xff1a; Error: Unspecified error (> Node [Powai.onnx]:(onnx_node!/model.24/Pow) parse error: OpenCV(4.6.0) F:\opencv-4.6.0\opencv-4.6.0\modules\dnn\src\onnx\onnx_importer.cpp:601: error: (-215:Assertio…

C++ vector的基本使用(待补全)

std::vector 是C标准模板库(STL)中的一个非常重要的容器类&#xff0c;它提供了一种动态数组的功能。能够存储相同类型的元素序列&#xff0c;并且可以自动管理存储空间的大小&#xff0c;以适应序列大小变化&#xff0c;处理元素集合的时候很灵活 1. vector的定义 构造函数声…

西安电子科技大学2025届毕业生生源信息

2025届本科毕业生专业分布一览表 2025届硕士毕业生专业分布一览表 2025届博士毕业生专业分布一览表 2025届本科毕业生生源地分布 左右滑动查看更多 2025届硕士毕业生生源地分布 2025届博士毕业生生源地分布

小红书笔试-选择题

HTTP/2.0默认长连接。选B ABC 一个类可以实现多个接口&#xff0c;一个接口可以继承一个或多个接口&#xff1a; 这是正确的。Java 支持多重继承的变体&#xff0c;即一个类可以实现多个接口&#xff0c;以获取多个接口中定义的方法。同时&#xff0c;一个接口可以通过 extends…

假如家里太大了,wifi连不上了怎么办

最近有个土豪朋友抱怨&#xff0c;他家里太大了&#xff0c;一个路由器的Wi-Fi信号根本无法覆盖他们家的每个房间&#xff0c;都没办法上网看奥运会比赛了。&#xff08;还好我是穷人&#xff0c;就没有这种烦恼T_T&#xff09;。 然后我问他为何不用一个路由器作主路由器&…

安卓常用控件ListView

文章目录 ListView的常用属性ListView的常用APIListView的简单使用 ListView是一个列表样式的 ViewGroup&#xff0c;将若干 item 按行排列。它是一个很基本的控件也是 Android 中最重要的控件之一。它可以实现多个 View 的垂直排列并支持滚动显示效果。 ListView的常用属性 常…

数论——绝对素数、素数筛法、埃氏筛法、欧拉筛法、最大公约数

绝对素数 绝对素数是指一个素数在其十进制表示下&#xff0c;无论是从左向右读还是从右向左读&#xff0c;所得到的数仍然是素数。 例如&#xff0c;13 是一个素数&#xff0c;从右向左读是 31&#xff0c;31 也是素数&#xff0c;所以 13 是一个绝对素数。 #include <io…