深度学习——损失函数与BP算法

news2024/11/29 7:42:22

一、损失函数

1. 线性回归损失函数

1.1 MAE损失

MAE(Mean Absolute Error,平均绝对误差)通常也被称为 L1-Loss,通过对预测值和真实值之间的绝对差取平均值来衡量他们之间的差异。MAE的公式如下:

其中:

  • n 是样本的总数。
  • y_{i} 是第 i 个样本的真实值。
  • \hat{y}_{i} 是第 i 个样本的预测值
  • \left | y_{i}- \hat{y}_{i} \right | 是真实值和预测值之间的绝对误差。

特点

  1. 鲁棒性:与均方误差(MSE)相比,MAE对异常值(outliers)更为鲁棒,因为它不会像MSE那样对较大误差平方敏感。
  2. 物理意义直观:MAE以与原始数据相同的单位度量误差,使其易于解释。

应用场景: MAE通常用于需要对误差进行线性度量的情况,尤其是当数据中可能存在异常值时,MAE可以避免对异常值的过度惩罚。

1.2 MSE损失

MSE(Mean Squared Error,均方误差)通常也被称为L2Loss。通过对预测值和真实值之间的误差平方取平均值,来衡量预测值与真实值之间的差异。MSE的公式如下:

其中:

  • n 是样本的总数。
  • y_{i} 是第 i 个样本的真实值。
  • \hat{y}_{i} 是第 i 个样本的预测值。
  • \left ( y_{i}- \hat{y}_{i} \right )^{2} 是真实值和预测值之间的误差平方。

特点

  1. 平方惩罚:因为误差平方,MSE 对较大误差施加更大惩罚,所以 MSE 对异常值更为敏感。
  2. 凸性:MSE 是一个凸函数,这意味着它具有一个唯一的全局最小值,有助于优化问题的求解。

应用场景

MSE被广泛应用在神经网络中。

1.3 SmoothL1Loss

SmoothL1Loss可以做到在损失较小时表现为 L2 损失,而在损失较大时表现为 L1 损失。SmoothL1Loss 的公式如下:

其中,x 表示预测值和真实值之间的误差,即 x=y_{i}-\hat{y}_{i}

所有样本的平均损失为:

特点:

  1. 平滑过渡:当误差较小时,损失函数表现为 L2 Loss(平方惩罚);当误差较大时,损失函数逐渐向 L1 Loss过渡。这种平滑过渡既能对大误差有所控制,又不会对异常值过度敏感。
  2. 稳健性:对于异常值更加稳健,同时在小误差范围内提供了较好的优化效果。

应用场景:

SmoothL1Loss常用于需要对大误差进行一定控制但又不希望完全忽略小误差的回归任务。特别适用于目标检测任务中的边界框回归,如 Faster R-CNN 等算法中。

2. CrossEntropyLoss

交叉熵损失函数,使用在输出层使用softmax激活函数进行多分类时,一般都采用交叉熵损失函数。

对于多分类问题,CrossEntropyLoss 公式如下:

其中:

  • C 是类别的总数。
  •  y 是真实标签的one-hot编码向量,表示真实类别。
  • \hat{y} 是模型的输出(经过 softmax 后的概率分布)。
  • y_{i} 是真实类别的第 i个元素(0 或 1)。
  • \hat{y}_{i} 是预测的类别概率分布中对应类别 i 的概率。

特点:

Softmax 直白来说就是将网络输出的 logits 通过 softmax 函数,就映射成为(0,1)的值,而这些值的累和为1(满足概率的性质),那么我们将它理解成概率,选取概率最大(也就是值对应最大的)节点,作为我们的预测目标类别。

3. BCELoss

二分类交叉熵损失函数,使用在输出层使用sigmoid激活函数进行二分类时。

对于二分类问题,CrossEntropyLoss 的简化版本称为二元交叉熵(Binary Cross-Entropy Loss),公式为:

log的底数一般默认为e,y是真实类别目标,根据公式可知L是一个分段函数 :  

 

以上损失函数是一个样本的损失值,总样本的损失值是求损失均值即可。  

4. 总结

  • 当输出层使用softmax多分类时,使用交叉熵损失函数;
  • 当输出层使用sigmoid二分类时,使用二分类交叉熵损失函数, 比如在逻辑回归中使用;
  • 当功能为线性回归时,使用smooth L1损失函数或均方差损失-L2 loss;

二、BP算法 

误差反向传播算法(BP)的基本步骤:

  1. 前向传播:正向计算得到预测值。
  2. 计算损失:通过损失函数L\left ( y_{pred},y_{true} \right ) 计算预测值和真实值的差距。
  3. 梯度计算:反向传播的核心是计算损失函数 L 对每个权重和偏置的梯度。
  4. 更新参数:一旦得到每层梯度,就可以使用梯度下降算法来更新每层的权重和偏置,使得损失逐渐减小。
  5. 迭代训练:将前向传播、梯度计算、参数更新的步骤重复多次,直到损失函数收敛或达到预定的停止条件。

1. 前向传播

前向传播(Forward Propagation)把输入数据经过各层神经元的运算并逐层向前传输,一直到输出层为止。

前向传播的主要作用是:

  1. 计算神经网络的输出结果,用于预测或计算损失。
  2. 在反向传播中使用,通过计算损失函数相对于每个参数的梯度来优化网络。

2. 反向传播

反向传播(Back Propagation,简称BP)通过计算损失函数相对于每个参数的梯度来调整权重,使模型在训练数据上的表现逐渐优化。反向传播结合了链式求导法则和梯度下降算法,是神经网络模型训练过程中更新参数的关键步骤。

2.1 原理

利用链式求导法则对每一层进行求导,直到求出输入层x的导数,然后利用导数值进行梯度更新

2.2. 链式法则

链式求导法则(Chain Rule)是微积分中的一个重要法则,用于求复合函数的导数。在深度学习中,链式法则是反向传播算法的基础,这样就可以通过分层的计算求得损失函数相对于每个参数的梯度。

2.3 重要性

反向传播算法极大地提高了多层神经网络训练的效率,使得训练深度模型成为可能。通过链式法则逐层计算梯度,反向传播可以有效地处理复杂的网络结构,确保每一层的参数都能得到合理的调整。

2.4 案例助解

2.4.1 数据准备

整体网络结构及神经元数据和权重参数如下图所示:

2.4.2 神经元计算

所以,我们可以得到如下数据:

计算h1的相关数据:

计算h2的相关数据:  

计算o1的相关数据:

计算o2的相关数据:

 所以,最终的预测结果分别为: 0.7514、0.7729

2.4.3 损失计算

预测值和真实值(target)进行比较计算损失:

得到损失是:0.2984  

2.4.4 梯度计算

接下来,我们进行梯度计算和参数更新

计算 w5 权重的梯度

计算 w7 权重的梯度

计算 w1 权重的梯度

2.4.5 参数更新

现在就可以进行权重更新了:假设学习率是0.5

2.4.6 代码实现
import torch
import torch.nn as nn
import torch.optim as optim


class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.linear1 = nn.Linear(2, 2)
        self.linear2 = nn.Linear(2, 2)

        # 网络参数初始化
        self.linear1.weight.data = torch.tensor([[0.15, 0.20], [0.25, 0.30]])
        self.linear2.weight.data = torch.tensor([[0.40, 0.45], [0.50, 0.55]])
        self.linear1.bias.data = torch.tensor([0.35, 0.35])
        self.linear2.bias.data = torch.tensor([0.60, 0.60])

    def forward(self, x):

        x = self.linear1(x)
        x = torch.sigmoid(x)
        x = self.linear2(x)
        x = torch.sigmoid(x)

        return x


if __name__ == "__main__":

    inputs = torch.tensor([[0.05, 0.10]])
    target = torch.tensor([[0.01, 0.99]])

    # 获得网络输出值
    net = Net()
    output = net(inputs)

    # 计算误差
    loss = torch.sum((output - target) ** 2) / 2

    # 优化方法
    optimizer = optim.SGD(net.parameters(), lr=0.5)

    # 梯度清零
    optimizer.zero_grad()

    # 反向传播
    loss.backward()

    # 打印(w1-w8)观察w5、w7、w1 的梯度值是否与手动计算一致
    print(net.linear1.weight.grad.data)
    print(net.linear2.weight.grad.data)

    #更新梯度
    optimizer.step()
   
    # 打印更新后的网络参数
    print(net.state_dict())

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2249605.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

机器学习-决策树(ID3算法及详细计算推导过程)

决策树是一种基于树结构进行决策的机器学习算法 ,以下是关于它的详细介绍: 1.基本原理 决策树通过一系列的条件判断对样本进行分类或预测数值。它从根节点开始,根据不同的属性值逐步将样本划分到不同的分支,直到到达叶节点&…

【AI系统】LLVM 架构设计和原理

LLVM 架构设计和原理 在上一篇文章中,我们详细探讨了 GCC 的编译过程和原理。然而,由于 GCC 存在代码耦合度高、难以进行独立操作以及庞大的代码量等缺点。正是由于对这些问题的意识,人们开始期待新一代编译器的出现。在本节,我们…

浅谈网络 | 应用层之HTTPS协议

目录 对称加密非对称加密数字证书HTTPS 的工作模式重放与篡改 使用 HTTP 协议浏览新闻虽然问题不大,但在更敏感的场景中,例如支付或其他涉及隐私的数据传输,就会面临巨大的安全风险。如果仍然使用普通的 HTTP 协议,数据在网络传输…

基于 JNI + Rust 实现一种高性能 Excel 导出方案(上篇)

每个不曾起舞的日子,都是对生命的辜负。 ——尼采 一、背景:Web 导出 Excel 的场景 Web 导出 Excel 功能在数据处理、分析和共享方面提供了极大的便利,是许多 Web 应用程序中的重要功能。以下是一些典型的场景: 数据报表导出&am…

最新Linux下使用conda配置Java23或17保姆教程(附赠安装包)

随着技术的不断进步,越来越多的开发者开始在Linux环境下进行Java应用的开发。Java 17作为长期支持版本(LTS),提供了许多新特性和性能改进。当然现在最新的是Java23,这个还作为实验版本未广泛使用。对于需要管理多个编程…

RHEL7+Oracle11.2 RAC集群-多路径(multipath+udev)安装步骤

RHEL7Oracle11.2RAC集群-多路径(multipathudev)安装 配置虚拟存储 使用StarWind Management Console软件,配置存储 dggrid1: 1g*3 Dggrid2: 1g*3 Dgsystem: 5g*1 系统表空间,临时表空间,UNDO,参数文件…

PyTorch 模型转换为 ONNX 格式

PyTorch 模型转换为 ONNX 格式 在深度学习领域,模型的可移植性和可解释性是非常重要的。本文将介绍如何使用 PyTorch 训练一个简单的卷积神经网络(CNN)来分类 MNIST 数据集,并将训练好的模型转换为 ONNX 格式。我们还将讨论 PTH …

VM Virutal Box的Ubuntu虚拟机与windows宿主机之间设置共享文件夹(自动挂载,永久有效)

本文参考如下链接 How to access a shared folder in VirtualBox? - Ask Ubuntu (1)安装增强功能(Guest Additions) 首先,在网上下载VBoxGuestAdditions光盘映像文件 下载地址:Index of http://…

CA系统(file.h---申请认证的处理)

#pragma once #ifndef FILEMANAGER_H #define FILEMANAGER_H #include <string> namespace F_ile {// 读取文件&#xff0c;返回文件内容bool readFilename(const std::string& filePath);bool readFilePubilcpath(const std::string& filePath);bool getNameFro…

【Git】Git 命令参考手册

目录 Git 命令参考手册1. 创建仓库1.1 创建一个新的本地仓库1.2 克隆一个仓库1.3 克隆仓库到指定目录 2. 提交更改2.1 显示工作目录中已修改的文件&#xff0c;准备提交2.2 将文件添加到暂存区&#xff0c;准备提交2.3 将所有已修改的文件添加到暂存区&#xff0c;准备提交2.4 …

【Linux系列】Chrony时间同步服务器搭建完整指南

1. 简介 Chrony是一个用于Linux系统的高效、精准的时间同步工具&#xff0c;通常用于替代传统的NTP&#xff08;Network Time Protocol&#xff09;服务。Chrony不仅在系统启动时提供快速的时间同步&#xff0c;还能在时钟漂移较大的情况下进行及时调整&#xff0c;因此广泛应…

数据库日志

MySQL中有哪些日志 1&#xff0c;redo log重做日志 redo log是物理机日志&#xff0c;因为它记录的是对数据页的物理修改&#xff0c;而不是SQL语句。 作用是确保事务的持久性&#xff0c;redo log日志记录事务执行后的状态&#xff0c;用来恢复未写入 data file的已提交事务…

【vue for beginner】Vue该怎么学?

&#x1f308;Don’t worry , just coding! 内耗与overthinking只会削弱你的精力&#xff0c;虚度你的光阴&#xff0c;每天迈出一小步&#xff0c;回头时发现已经走了很远。 vue2 和 vue3 Vue2现在正向vue3逐渐更新中&#xff0c;官方vue2已经不再更新。 这个历程和当时的pyt…

【Ubuntu 24.04】How to Install and Use NVM

参考 下载 curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.39.7/install.sh | bash激活 Activate NVM: Once the installation script completes, you need to either close and reopen the terminal or run the following command to use nvm immediately. exp…

SeggisV1.0 遥感影像分割软件【源代码】讲解

在此基础上进行二次开发&#xff0c;开发自己的软件&#xff0c;例如&#xff1a;【1】无人机及个人私有影像识别【2】离线使用【3】变化监测模型集成【4】个人私有分割模型集成等等&#xff0c;不管是您用来个人学习 还是公司研发需求&#xff0c;都相当合适&#xff0c;包您满…

Python轴承故障诊断 (21)基于VMD-CNN-BiTCN的创新诊断模型

往期精彩内容&#xff1a; Python-凯斯西储大学&#xff08;CWRU&#xff09;轴承数据解读与分类处理 Pytorch-LSTM轴承故障一维信号分类(一)-CSDN博客 Pytorch-CNN轴承故障一维信号分类(二)-CSDN博客 Pytorch-Transformer轴承故障一维信号分类(三)-CSDN博客 三十多个开源…

使用docker搭建hysteria2服务端

源链接&#xff1a;https://github.com/apernet/hysteria/discussions/1248 官网地址&#xff1a;https://v2.hysteria.network/zh/docs/getting-started/Installation/ 首选需要安装docker和docker compose 切换到合适的目录 cd /home创建文件夹 mkdir hysteria创建docke…

基于Java实现的潜艇大战游戏

基于Java实现的潜艇大战游戏 一.需求分析 1.1 设计任务 本次游戏课程设计小组成员团队合作的方式&#xff0c;通过游戏总体分析设计&#xff0c;场景画面的绘制&#xff0c;游戏事件的处理&#xff0c;游戏核心算法的分析实现&#xff0c;游戏的碰撞检测&#xff0c;游戏的反…

课题组自主发展了哪些CMAQ模式预报相关的改进技术?

空气污染问题日益受到各级政府以及社会公众的高度重视&#xff0c;从实时的数据监测公布到空气质量数值预报及预报产品的发布&#xff0c;我国在空气质量监测和预报方面取得了一定进展。随着计算机技术的高速发展、空气污染监测手段的提高和人们对大气物理化学过程认识的深入&a…

深入解析下oracle date底层存储方式

之前我们介绍了varchar2和char的数据库底层存储格式&#xff0c;今天我们介绍下date类型的数据存储格式&#xff0c;并通过测试程序快速获取一个日期。 一、环境搭建 1.1&#xff0c;创建表 我们还是创建一个测试表t_code&#xff0c;并插入数据&#xff1a; 1.2&#xff0c;…