神经网络 torch.nn---损失函数与反向传播

news2024/11/26 9:45:51

torch.nn - PyTorch中文文档 (pytorch-cn.readthedocs.io)

torch.nn — PyTorch 2.3 documentation

Loss Function的作用

  • 每次训练神经网络的时候都会有一个目标,也会有一个输出。目标和输出之间的误差,就是用Loss Function来衡量的。所以,误差Loss是越小越好的。

  • 此外,我们可以根据误差Loss,指导输出output接近目标target。即我们可以以target为依据,不断训练神经网络,优化神经网络中各个模块,从而优化output

Loss Function的作用

  1. 计算实际输出和目标之间的差距
  2. 为我们更新输出提供一定的依据,这个提供依据的过程也叫反向传播

nn.L1Loss

创建一个衡量输入x(模型预测输出)和目标y之间差的绝对值的平均值的标准。

class torch.nn.L1Loss(size_average=None, reduce=None, reduction='mean')

参数说明:

  • reduction:默认为 ‘mean’ ,可选meansum

  • reduction='mean'时,计算误差采用公式:

  • reduction='sum'时,计算误差采用公式:

需要注意的是,计算的数据必须为浮点数

程序代码:

import torch
from torch.nn import L1Loss

input=torch.tensor([1,2,3],dtype=torch.float32)
target=torch.tensor([1,2,5],dtype=torch.float32)

input=torch.reshape(input,(1,1,1,3))
target=torch.reshape(target,(1,1,1,3))

loss1=L1Loss()  #reduction='mean'
loss2=L1Loss(reduction='sum')  
result1=loss1(input,target)
result2=loss2(input,target)

print(result1,result2)

输出:

nn.MSELoss

创建一个衡量输入x(模型预测输出)和目标y之间均方误差标准。

  • x 和 y 可以是任意形状,每个包含n个元素。

  • n个元素对应的差值的绝对值求和,得出来的结果除以n

  • 如果在创建MSELoss实例的时候在构造函数中传入size_average=False,那么求出来的平方和将不会除以n

class torch.nn.MSELoss(size_average=None, reduce=None, reduction='mean')

参数说明:

reduction:默认为 ‘mean’ ,可选meansum

  • reduction='mean'时,计算误差采用公式:

  • reduction='sum'时,计算误差采用公式:

程序代码:

import torch
from torch.nn import L1Loss,MSELoss

input = torch.tensor([1,2,3],dtype=torch.float32)
target = torch.tensor([1,2,5],dtype=torch.float32)

input = torch.reshape(input,(1,1,1,3))
target = torch.reshape(target,(1,1,1,3))

loss_mse1 = MSELoss()  #reduction='mean'
loss_mse2 = MSELoss(reduction='sum')
result_mse1 = loss_mse1(input, target)
result_mse2 = loss_mse2(input, target)

print(result_mse1, result_mse2)

输出:

nn.CrossEntropyLoss(交叉熵)

当训练一个分类问题的时候,假设这个分类问题有C个类别,那么有:

 当weight参数被指定的时候,loss的计算公式变为:

计算出的lossmini-batch的大小取了平均。

形状(shape):

  • Input: (N,C)    其中N代表batch_size,C 是类别的数量即数据要分成几类(或有几个标签)。

  • Target: (N)     Nmini-batch的大小,0 <= targets[i] <= C-1

举个例子:

  • 我们对包含了人、狗、猫的图片进行分类,其标签的索引分别为0、1、2。这时候将一张的图片输入神经网络,即目标(target)为1(对应标签索引)。输出结果为[0.1,0.2,0.3],该列表中的数字分别代表分类标签对应的概率。

  • 根据上述分类结果,图片为的概率更大,即0.3。对于该分类的Loss Function,我们可以通过交叉熵去计算,即:

那么如何验证这个公式的合理性呢?根据上面的例子,分类结果越准确,Loss应该越小。这条公式由两个部分组成:

  • 1、log(∑jexp(x[j])

log(∑jexp(x[j])主要作用是控制或限制预测结果的概率分布。比如说,预测出来的人、狗、猫的概率均为0.9,每个结果概率都很高,这显然是不合理的。此时 log(∑jexp(x[j]) 的值会变大,误差loss(x,class)也会随之变大。同时该指标也可以作为分类器性能评判标准。

  • 2、−x[class]:在已知图片类别的情况下,预测出来对应该类别的概率x[class]越高,其预测结果误差越小。

程序代码:

import torch
from torch import nn
from torch.nn import L1Loss

inputs = torch.tensor([1, 2, 3], dtype=torch.float)
targets = torch.tensor([1, 2, 5], dtype=torch.float)

inputs = torch.reshape(inputs, (1, 1, 1, 3))
targets = torch.reshape(targets, (1, 1, 1, 3))

x = torch.tensor([0.1, 0.2, 0.3], dtype=torch.float)
y = torch.tensor([1])
x = torch.reshape(x, (1, 3))

loss_cross_1 = nn.CrossEntropyLoss(reduction='mean')
result_cross_1 = loss_cross_1(x, y)
loss_cross_2 = nn.CrossEntropyLoss(reduction='sum')
result_cross_2 = loss_cross_2(x, y)
print(result_cross_1, result_cross_2)

输出:

反向传播

如何根据Loss Function为更新神经网络数据提供依据?

  • 对于每个卷积核当中的参数,设置一个grad(梯度)。

  • 当我们进行反向传播的时候,对每一个节点的参数都会求出一个对应的梯度。之后我们根据梯度对每一个参数进行优化,最终达到降低Loss的一个目的。比较典型的一个方法——梯度下降法

代码举例:

 result_loss = loss(outputs, targets)
 result_loss.backward()
  • 上面就是反向传播的使用方法,它的主要作用是计算一个grad。使用debug功能并删掉上面这行代码,会发现单纯由result_loss=loss(output,targets)计算出来的结果,是没有grad这个参数的。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1795493.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

超简单白话文机器学习 - 模型检验与评估(含算法介绍,公式,源代码实现以及调包实现)

1. 模型检验 1.1 Holdout交叉验证 1.1.1 算法 在这种交叉验证技术中&#xff0c;整个数据集被随机划分为训练集和验证集。根据经验&#xff0c;整个数据集的近 70% 用作训练集&#xff0c;其余 30% 用作验证集。 优点&#xff1a;可以快速进行区分&#xff0c;仅仅通过一次区…

【cmake】cmake cache

cmake cache是什么 cmake cache是cmake在配置好后生成的一个CMakeCache.txt的文件&#xff0c;里面存储了一堆变量&#xff0c;这些变量一般都是关于项目的配置和环境的。 比如你用的什么编译器&#xff0c;编译器选项&#xff0c;还有项目目录。 例如&#xff08;在cmakelist…

06- 数组的基础知识详细讲解

06- 数组的基础知识详细讲解 一、基本概念 一次性定义多个相同类型的变量&#xff0c;并且给它们分配一片连续的内存。 int arr[5];1.1 初始化 只有在定义的时候赋值&#xff0c;才可以称为初始化。数组只有在初始化的时候才可以统一赋值。 以下是一些示例规则&#xff1a; …

Vivado 设置关联使用第三方编辑器 Notepad++

目录 1.前言2.Vivado关联外部编辑器步骤3.Notepad的一些便捷操作 微信公众号获取更多FPGA相关源码&#xff1a; 1.前言 Vivado软件自带的编辑器超级难用&#xff0c;代码高亮对比不明显&#xff0c;而且白色背景看久了眼睛痛。为了写代码时有更加舒适的体验&#xff0c;可以…

伯克希尔·哈撒韦:“股神”的“登神长阶”

股价跳水大家见过不少&#xff0c;但一秒跌掉62万美元的你见过吗&#xff1f; 今天我们来聊聊“股市”巴菲特的公司——伯克希尔哈撒韦 最近&#xff0c;由于纽交所技术故障&#xff0c;伯克希尔哈撒韦A类股股价上演一秒归“零”&#xff0c;从超过62万美元跌成185.1美元&…

《QT从基础到进阶·四十一》无法解析的外部符号及生成事件加入QT打包命令报错问题

其他无法解析的外部符号&#xff1a; 无法解析的外部符号 "public: virtual struct QMetaObject const * __cdecl ML_AddinManger::metaObject(void)const "… 无法解析的外部符号 “public: virtual void * __cdecl ML_AddinManger::qt_metacast(char const *)” (?…

微信小程序毕业设计-民大食堂用餐综合服务平台系统项目开发实战(附源码+论文)

大家好&#xff01;我是程序猿老A&#xff0c;感谢您阅读本文&#xff0c;欢迎一键三连哦。 &#x1f49e;当前专栏&#xff1a;微信小程序毕业设计 精彩专栏推荐&#x1f447;&#x1f3fb;&#x1f447;&#x1f3fb;&#x1f447;&#x1f3fb; &#x1f380; Python毕业设计…

查看服务器端口是否打开,如何查看服务器端口是否打开

查看服务器端口是否打开&#xff0c;是确保服务器正常运行和网络通信畅通的关键步骤。以下是几个有力的方法&#xff0c;帮助你快速、准确地判断端口状态。 首先&#xff0c;你可以使用telnet命令来检测端口的连通性。telnet是一个网络协议&#xff0c;可以用于远程登录和管理网…

中国新闻网怎么投稿 新闻稿件文章如何发布到中国新闻网上,附中国新闻网价格明细

中国新闻网是中国最具影响力和权威性的新闻门户网站之一。作为广大作者和媒体从业者&#xff0c;怎样向中国新闻网投稿一直是一个备受关注的话题。在这篇文章中&#xff0c;我们将着重介绍媒介库网发稿平台&#xff0c;并分享如何在该平台上成功投稿至中国新闻网。 媒介库网发稿…

代码片段 | Matlab三维图显示[ R T 0 1] 的最佳方法

% 输入N组RT矩阵 N 4; R zeros(3, 3, N); T zeros(3, N); R(:,:,1) [-0.902608 0.250129 0.350335 ; 0.314198 0.939127 0.138996 ;-0.294242 0.235533 -0.926253 ]; T(:,1) [205.877;2796.02; 907.116];R(:,:,2) [-0.123936 0.643885 0.755018 ;0.816604 0.464468 -0.26…

Docker基础篇之本地镜像发布到阿里云

文章目录 1. 本地镜像发布到阿里云的流程2. 阿里云开发平台3. 将自己的本地镜像推送到阿里云 1. 本地镜像发布到阿里云的流程 阿里云ECS Docker生态如下图所示&#xff1a; 2. 阿里云开发平台 在控制台找到容器和镜像服务&#xff1a; 然后创建一个个人实例&#xff1a; 下面…

LeetCode 算法:合并区间c++

原题链接&#x1f517;&#xff1a;合并区间 难度&#xff1a;中等⭐️⭐️ 题目 以数组 intervals 表示若干个区间的集合&#xff0c;其中单个区间为 intervals[i] [starti, endi] 。请你合并所有重叠的区间&#xff0c;并返回 一个不重叠的区间数组&#xff0c;该数组需恰…

java第二十课 —— 面向对象习题

类与对象练习题 编写类 A01&#xff0c;定义方法 max&#xff0c;实现求某个 double 数组的最大值&#xff0c;并返回。 public class Chapter7{public static void main(String[] args){A01 m new A01();double[] doubleArray null;Double res m.max(doubleArray);if(res !…

内地户口转香港身份的7种途径!2024年怎么同时拥有2个身份?一篇说明白

很多人还不知道怎么同时拥有内地身份和香港身份&#xff0c;这里一次性说明白&#xff0c;不同背景情况及政策有可能随时变化&#xff0c;这里分享最近拿香港身份的7种途径。 #01 优才『香港优秀人才计划』 获批准的申请人无须在来港定居前先获得本地雇主聘任。所有申请人均必…

搜索与图论:图中点的层次

搜索与图论&#xff1a;图中点的层次 题目描述参考代码 题目描述 输入样例 4 5 1 2 2 3 3 4 1 3 1 4输出样例 1参考代码 #include <cstring> #include <iostream> #include <algorithm>using namespace std;const int N 100010;int n, m; int h[N], e[N]…

【云岚到家】-day01-项目熟悉-查询区域服务开发

文章目录 1 云岚家政项目概述1.1 简介1.2 项目业务流程1.3 项目业务模块1.4 项目架构及技术栈1.5 学习后掌握能力 2 熟悉项目2.1 熟悉需求2.2 熟悉设计2.2.1 表结构2.2.2 熟悉工程结构2.2.3 jzo2o-foundations2.2.3.1 工程结构2.2.3.2 接口测试 3 开发区域服务模块3.1 流程分析…

【命令scp】Linux不同主机之间拷贝指令scp

scp可以在不同主机之间拷贝文件 # 将本地文件拷贝到远程服务器 scp a.tar changxr192.168.100.100:/home/changxr/cxr

Android——热点开关演讲稿

SoftAP打开与关闭 目录 1.三个名词的解释以及关系 Tethering——网络共享&#xff0c;WiFi热点、蓝牙、USB SoftAp——热点(无线接入点)&#xff0c;临时接入点 Hostapd——Hostapd是用于Linux系统的软件&#xff0c;&#xff0c;支持多种无线认证和加密协议&#xff0c;将任…

Visual Studio和BOM历史渊源

今天看文档无意间碰到了微软对编码格式解释&#xff0c;如下链接&#xff1a; Understanding file encoding in VS Code and PowerShell - PowerShell | Microsoft LearnConfigure file encoding in VS Code and PowerShellhttps://learn.microsoft.com/en-us/powershell/scrip…

锁存器(Latch)的产生与特点

Latch 是什么 Latch 其实就是锁存器&#xff0c;是一种在异步电路系统中&#xff0c;对输入信号电平敏感的单元&#xff0c;用来存储信息。锁存器在数据未锁存时&#xff0c;输出端的信号随输入信号变化&#xff0c;就像信号通过一个缓冲器&#xff0c;一旦锁存信号有效&#…