从零开始的深度学习之旅(3)

news2025/1/13 7:44:29

目录

  • 神经网络的损失函数
  • 1.损失函数的引入
  • 2.损失函数
  • 3.回归:误差平方和SSE
    • 3.1 MSE的使用
    • 3.2 二分类交叉熵损失函数
    • 3.3 极大似然估计推导二分类交叉熵损失
    • 3.4 用tensor实现二分类交叉熵损失
  • 4.多分类交叉熵损失函数
    • 4.1 实现多分类交叉熵损失

神经网络的损失函数

1.损失函数的引入

    在之前的学习中,我们建立神经网络时总是先设定好w与b的值,或者由我们调用的PyTorch类帮助我们随机生成权重向量,接着通过加和求出z ,再在z上嵌套sigmoid或者softmax函数,最终获得神经网络的输出。
 神经网络的计算是从左侧向右侧计算的.这是神经网络的正向传播过程。但这并不是神经网络算法的全流程,这个流程虽然可以输出预测结果,但却无法保证神经网络的输出结果与真实值接近。
  此时,我们就要训练神经网络,求解一组最适合的w和b,令神经网络的输出结果与真实值接近,这就是神经网络模型训练的目标.

2.损失函数

    比如我们做了一个预测房价的实验,预测的房价和真正的房价之间肯定存在差异.当真实值与预测值差异越大时,我们就认为神经网络学习过程中丢失了许多信息,丢失的这部分称为”损失“,因此评估真实值与预测值差异的函数被我们称为“损失函数.

  损失函数

  1.在数学上,表示为以需要求解的权重向量ω为自变量的函数L(ω)。

  2.衡量真实值与预测结果的差异,评价模型学习过程中产生的损失的函数。

  3.如果损失函数的值很小,则说明模型预测值与真实值很接近,模型训练得很好

    我们希望损失函数越小越好,以此,我们将问题转变为求解函数L(ω)的最小值所对应的自变量ω.

3.回归:误差平方和SSE

SSE误差平方和:
在这里插入图片描述
    其中zi(公式的前者)是样本i的真实值,而zihat(公式的后者)是样本i的预测值。对于全部样本的平均损失,则可以写作:
在这里插入图片描述

3.1 MSE的使用

# 按照MSE的公式,pytorch已经写好了函数,直接调用就行

import torch
from torch.nn import MSELoss 

yhat=torch.randn(size=(50,),dtype=torch.float32)
y=torch.randn(size=(50,),dtype=torch.float32)

criterion=MSELoss() 
loss = criterion(yhat,y)
loss

输出结果:
在这里插入图片描述

3.2 二分类交叉熵损失函数

    在这一节中,我们将介绍二分类神经网络的损失函数:二分类交叉熵损失函数,也叫做对数损失.
 大多数时候,除非特殊声明为二分类,否则提到交叉熵损失,我们会默认算法的分类目标是多分类.
 二分类交叉熵损失函数是由极大似然估计推导出来的,对于有m个样本的数据集而言,全部样本上的平均损失写作:在这里插入图片描述
单个样本损失:

在这里插入图片描述

    在公式中,ln是以自然底数为底的对数函数,ω表示求解出来的一组权重(ω在σ里),m是样本的个数,yi是样本i上真实的标签,σi是样本i上基于参数计算出来的sigmoid函数的返回值,xi是样本i各个特征的取值。

3.3 极大似然估计推导二分类交叉熵损失

    极大似然估计,如果一个事件的发生概率很大,那这个事件应该很容易发生。
 寻找相应的权重ω,使得目标事件的发生概率最大,就是极大似然估计的基本方法。


    二分类神经网络的标签是[0,1],样本i在由特征向量xi和权重向量ω组成的预测函数中,样本标签被预测为1的概率为:
在这里插入图片描述


    样本i在由特征向量 和权重向量 组成的预测函数中,样本标签被预测为0的概率为:
在这里插入图片描述


    当P1的值为1的时候,代表样本i的标签被预测为1,当P0的值为1的时候,代表样本i的标签被预测为0。P1与P0 相加是一定等于1的.

将两种概率联合:
单个的概率:
在这里插入图片描述
将P1和P2替换,加上符号得到所有样本的概率:
在这里插入图片描述
对该概率P取以e为底的对数:
在这里插入图片描述
我们将极大值转换为极小值,因此我们对lnP取负:
在这里插入图片描述

3.4 用tensor实现二分类交叉熵损失

import torch
import time

N = 3*pow(10,3)
torch.random.manual_seed(420)
X = torch.rand((N,4),dtype=torch.float32)
w = torch.rand((4,1),dtype=torch.float32,requires_grad=True)
y = torch.randint(low=0,high=2,size=(N,1),dtype=torch.float32)
zhat = torch.mm(X,w)
sigma = torch.sigmoid(zhat)
Loss = -(1/N)*torch.sum((1-y)*torch.log(1-sigma)+y*torch.log(sigma))

在这里插入图片描述

4.多分类交叉熵损失函数

    对于多分类的状况而言,标签不再服从伯努利分布(0-1分布),因此我们可以定义,样本i在由特征向量和权重向量组成的预测函数中,样本标签被预测为类别k的概率为:
在这里插入图片描述
   对于多分类算法而言,σ就是softmax函数返回的对应类别的值。

    假设样本的真实标签为1,我们就希望 P1最大,同理,如果样本的真实标签为其他值,我们就希望其他值所对应的概率最大。二分类可以使用0和1来分类,如果多分类的标签也可以使用0和1来表示就好了,这样我们就可以继续使用真实标签作为指数的方式,如下图方式进行改变

  原本的真实标签y是含有[1, 2, 3]三个分类的列向量,现在我们把它变成了标签矩阵,每个样本对应一个向量.

在这里插入图片描述


  当我们把标签整合为标签矩阵后,我们就可以将单个样本在总共K个分类情况整合为以下的似然函数
在这里插入图片描述


公式简写为:
在这里插入图片描述

所有可能的的概率P求和为:
在这里插入图片描述
  再对整个公式取负,就得到了多分类状况下的损失函数
在这里插入图片描述
在这里插入图片描述

4.1 实现多分类交叉熵损失

import torch
import torch.nn as nn
N = 3*pow(10,2)
torch.random.manual_seed(420)
X = torch.rand((N,4),dtype=torch.float32)
w = torch.rand((4,3),dtype=torch.float32,requires_grad=True)

y = torch.randint(low=0,high=3,size=(N,),dtype=torch.float32)
zhat = torch.mm(X,w)
#从这里开始调用softmax和NLLLoss
logsm = nn.LogSoftmax(dim=1) #实例化
logsigma = logsm(zhat)
criterion = nn.NLLLoss() #实例化
criterion(logsigma,y.long())

输出结果:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/27784.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Excel - 获取帮助信息,查找Sheet中和VBA里的可用函数

Excel获取帮助信息 在使用Excel时,可以点击菜单的Help,可以获取帮助信息或Training。 点击Help帮助信息: 如果你觉得查看不方便,开可以使用浏览器,访问官网线上支持文档: Excel help & learning 而点击…

【微服务】GateWay概念与使用

一、API 网关功能: 路由到指定位置:后台管理系统经常给各个服务发送请求,某一个服务掉线了,我们不可能手动去修改端口号,让它去其他机器找。因此,需要 API 网关,让其帮助我们将请求路由到正确位…

【华为OD机试真题 python】竖直四子棋【2022 Q4 | 200分】

■ 题目描述 【竖直四子棋】 竖直四子棋的棋盘是竖立起来的,双方轮流选择棋盘的一列下子,棋子因重力落到棋盘底部或者其他棋子之上,当一列的棋子放满时,无法再在这列上下子。 一方的4个棋子横、竖或者斜方向连成一线时获胜。 现给定一个棋盘和红蓝对弈双方的下子步骤,…

学会问问题

推荐文档:学会问问题; 目录 三句话原则 你就是孙子 问问题过程 第一步—学会问好 示例如下 第二步—有屁快放 问问题需要加上的前缀或者后缀: 示例如下 第三步—介绍自己的框架 示例如下 第四步—介绍自己的解决思路 示例如下 …

spring cache (Redis方式)

目录前置pom: jar配置文件: application.ymlMyCacheConfig.java效果图前置 会演示springcache的使用方式 项目地址: https://gitee.com/xmaxm/test-code/blob/master/chaim-cache/chaim-spring-cache/chaim-spring-cache-redis/README.md 前置配置 本篇文章是基于上篇文章进行…

Flutter 使用FFI+CustomPainter实现全平台渲染视频

Flutter视频渲染系列 第一章 Android使用Texture渲染视频 第二章 Windows使用Texture渲染视频 第三章 Linux使用Texture渲染视频 第四章 全平台FFICustomPainter渲染视频(本章) 文章目录Flutter视频渲染系列前言一、如何实现1、C/C实现视频采集&#xf…

3. 使用PyTorch深度学习库训练第一个卷积神经网络CNN

这篇博客将介绍如何使用PyTorch深度学习库训练第一个卷积神经网络(CNN)。训练CNN使用 KMNIST 数据集(MNIST digits数据集的替代品,内置在PyTorch中)识别手写平假名字符(handwritten Hiragana characters&am…

图的二种遍历-广度优先遍历和深度优先遍历

图的广度优先遍历 1.树的广度优先遍历 这样一个图中,是如何实现广度优先遍历的呢,首先,从1遍历完成之后,在去遍历2,3,4,最后遍历5 ,6 , 7 , 8。这也就是为什么叫做广度优先遍历,是一层一层的往…

36个数据分析方法与模型

目录一、战略与组织二、质量与生产三、营销服务四、财务管理五、人力资源六、互联网运营好的数据分析师不仅熟练地掌握了分析工具,还掌握了大量的数据分析方法和模型。这样得出的结论不仅具备条理性和逻辑性,而且还更具备结构化和体系化,并保…

Python连接MYSQL、SQL Server、Oracle数据入库一网打尽

描述: Python众所周知用来数据提取,通俗说用来抓数据,将拿到的数据进行数据清洗、加工,分析等等。而其中最重要的部分就是数据爬取、数据入库这两部分了,至于数据分析那就特别考察你的SQL能力,如果是自己设计页面&…

马齿苋多糖偶联顺铂复合物/黄连素偶联顺铂化合物/载顺铂mPEg-PGA纳米微球制备方法

小编今天整理了马齿苋多糖偶联顺铂复合物/黄连素偶联顺铂化合物/载顺铂mPEg-PGA纳米微球制备方法,一起来看! 黄连素偶联顺铂化合物制备方法: 以A549/DDP细胞为研究对象,分别加入12 μg/mL的顺铂,浓度为20 μmol/L,40 μmol/L,80 μmol/L的黄连素12 μg/…

艾美捷EndoGrade卵清蛋白重组示例说明

卵清蛋白是一种优质蛋白质,占蛋清蛋白总量的 54%-69%,卵清蛋白是典型的球蛋白,分子量为 44.5k Da,属含磷糖蛋白,含有四个自由巯基、385 个氨基酸残基。这些氨基酸残基相互缠绕折叠形成具有高度二级结构的球型结构&…

spring cache (默认方式)

目录前置pom配置示列代码效果图部分源码关键类流程代码描述 (此类无用, 只是备注源码的逻辑)前置 什么是springcache: 通过注解就能实现缓存功能, 简化在业务中去操作缓存 Spring Cache只是提供了一层抽象, 底层可以切换不同的cache实现. 通过CacheManager接口来统一不同的缓存…

大数据培训课程MapTask工作机制

MapTask工作机制 MapTask工作机制如图4-12所示。 图4-12 MapTask工作机制 (1)Read阶段:MapTask通过用户编写的RecordReader,从输入InputSplit中解析出一个个key/value。 (2)Map阶段:该节点主要…

java面试强基(9)

字符串拼接用“” 还是 StringBuilder? ​ Java 语言本身并不支持运算符重载,“”和“”是专门为 String 类重载过的运算符,也是 Java 中仅有的两个重载过的运算符。 ​ 字符串对象通过“”的字符串拼接方式,实际上是通过 StringBuilder 调…

【MFC】一个最简单的MFC程序(9)

了解完MFC程序的流程后,会有 “果然不需要了解这些东西,直接用就可以了” 的感觉。这应该是MFC的初衷吧——按照框架来,集中精力做应用。但是没有了解呢? 最简单的MFC程序 步骤: 1、创建WIN32应用程序,空…

GoWeb 的 MVC 入门实战案例,基于 Iris 框架实现(附案例全代码)

1、什么是 MVC M 即 Model 模型是指模型表示业务规则。在MVC的三个部件中,模型拥有最多的处理任务。被模型返回的数据是中立的,模型与数据格式无关,这样一个模型能为多个视图提供数据,由于应用于模型的代码只需写一次就可以被多个…

1531_AURIX_TriCore内核架构_任务以及函数

全部学习汇总: GreyZhang/g_tricore_architecture: some learning note about tricore architecture. (github.com) 继续前面的内核架构学习,这次看一下任务以及函数的描述。 1. 在嵌入式系统中,内核以及函数的设计其实是有一定的模型或者说是…

day33 文件上传中间件解析漏洞编辑器安全

前言 先判断中间件,是否有解析漏洞,字典扫描拿到上传点,或者会员中心,有可能存在文件上传的地方,而后测试绕过/验证,根据实际情况判断是白名单、黑名单还是内容其他的绕过,绕过/验证和中间件的…

数字信号处理FFT快速傅立叶变换MATLAB实现——实例

今天做作业的时候发现要对一个信号进行FFT变换,在网上找了半天也没找到个能看懂的(因为我太菜了),后来自己研究了一下,感觉一知半解的 起因是这道作业题 例题-满足奈奎斯特 我画了两个图,一个是原信号经过…