交叉熵损失函数与参数更新计算实例(节点分类为例)

news2024/11/26 19:53:57

交叉熵损失与参数更新

数据准备

对于下面这样一个图网络网络:
在这里插入图片描述

假如我们得到了节点i的嵌入表示 z i z_i zi数据如下:
i d , x 0 , x 1 , x 2 , x 3 1 , 0.5 , 0.6 , 0.7 , 0.8 2 , 0.3 , 0.8 , 0.3 , 0.4 3 , 0.7 , 0.9 , 0.6 , 0.9 4 , 0.2 , 0.1 , 0.2 , 0.3 5 , 0.8 , 0.4 , 0.3 , 0.2 id,x_0,x_1,x_2,x_3\\ 1,0.5,0.6,0.7,0.8\\ 2,0.3,0.8,0.3,0.4\\ 3,0.7,0.9,0.6,0.9\\ 4,0.2,0.1,0.2,0.3\\ 5,0.8,0.4,0.3,0.2\\ id,x0,x1,x2,x31,0.5,0.6,0.7,0.82,0.3,0.8,0.3,0.43,0.7,0.9,0.6,0.94,0.2,0.1,0.2,0.35,0.8,0.4,0.3,0.2
为了方便说明,我们来处理一个对节点进行有监督分类的问题。
假设我们要对节点的嵌入表示进行分类
真实的类别如下:
1,3,属于第0类
2,4,属于第1类
5,属于第2类

分类层

我们对每个节点经过一个全连接层,我们随机初始化w0,w1,w2三个4维(嵌入向量维数)的权重向量(结果保留两位有效数字,下同)。
w 0 = [ 0.17 , 0.4 , − 0.14 , 0.51 ] w_0 = [0.17,0.4,-0.14,0.51] w0=[0.17,0.4,0.14,0.51]
w 1 = [ 0.75 , − 0.04 , 0.67 , − 0.18 ] w_1 = [0.75,-0.04,0.67,-0.18] w1=[0.75,0.04,0.67,0.18]
w 2 = [ 0.53 , − 0.04 , 0.4 , 0.77 ] w_2 = [0.53,-0.04,0.4,0.77] w2=[0.53,0.04,0.4,0.77]
b 0 , b 1 , b 2 = 0.05 , − 0.11 , − 0.32 b_0,b_1,b_2 = 0.05,-0.11,-0.32 b0,b1,b2=0.05,0.11,0.32
w i , b i w_i,b_i wi,bi对应将节点向量转化为节点属于i类的过程的一些权重;

于是对节点 z 1 z_1 z1,我们得到:

h 1 = [ z 1 w 0 + b 0 , z 1 w 1 + b 1 , z 1 w 2 + b 2 ] = [ 0.68 , 0.57 , 0.82 ] h_1 = [z_1w_0+b_0,z_1w_1+b_1,z_1w_2+b_2]\\ = [0.68, 0.57, 0.82] h1=[z1w0+b0,z1w1+b1,z1w2+b2]=[0.68,0.57,0.82]

类似地,我们得到
h 2 = [ 0.58 , 0.21 , 0.23 ] h 3 = [ 0.9 , 0.62 , 0.95 ] h 4 = [ 0.25 , 0.12 , 0.09 ] h 5 = [ 0.47 , 0.55 , 0.4 ] h_2 = [0.58, 0.21, 0.23] \\ h_3 = [0.9, 0.62, 0.95] \\ h_4 = [0.25, 0.12, 0.09]\\ h_5 = [0.47, 0.55, 0.4]\\ h2=[0.58,0.21,0.23]h3=[0.9,0.62,0.95]h4=[0.25,0.12,0.09]h5=[0.47,0.55,0.4]

softmax矩阵

softmax函数作为一种归一化函数,可以将一组任意实数转换为一个概率分布,常用于多分类问题,其表达式为:

softmax ( z i ) = e x p ( z i ) ∑ j = 1 K e x p ( z j ) , i = 1 , … , K \text{softmax}(z_i) = \frac{exp(z_i)}{\sum_{j=1}^K exp(z_j)}, \quad i=1,\ldots,K softmax(zi)=j=1Kexp(zj)exp(zi),i=1,,K
K为分类的类别个数,z_i为实际上是向量z的第i个分量,分类问题中,对于向量 z z z而言,softmax的函数值也就是 z z z属于第 i i i类的概率。

在这里,以 h 1 h_1 h1为例,softmax的表达式可以写成:

softmax ( h 1 [ i ] ) = e x p ( h 1 [ i ] ) ∑ j = 1 K e x p ( h 1 [ j ] ) , i = 1 , 2 , 3 \text{softmax}(h_{1}[i]) = \frac{exp(h_1[i])}{\sum_{j=1}^K exp(h_1[j])}, \quad i=1,2,3 softmax(h1[i])=j=1Kexp(h1[j])exp(h1[i]),i=1,2,3
h 1 [ i ] h_1[i] h1[i]表示 h 1 h_1 h1中第i个分量。

于是我们将 h 1 , h 2 , . . . , h 5 h_1,h_2,...,h_5 h1,h2,...,h5每个向量传入softmax函数,得到节点属于各类别的概率分布:

p 1 = [ 0.33 , 0.29 , 0.38 ] p 2 = [ 0.42 , 0.29 , 0.29 ] p 3 = [ 0.36 , 0.27 , 0.37 ] p 4 = [ 0.37 , 0.32 , 0.31 ] p 5 = [ 0.33 , 0.36 , 0.31 ] p_1 = [0.33, 0.29, 0.38]\\ p_2 = [0.42, 0.29, 0.29] \\ p_3 = [0.36, 0.27, 0.37] \\ p_4 = [0.37, 0.32, 0.31] \\ p_5 = [0.33, 0.36, 0.31]\\ p1=[0.33,0.29,0.38]p2=[0.42,0.29,0.29]p3=[0.36,0.27,0.37]p4=[0.37,0.32,0.31]p5=[0.33,0.36,0.31]
于是根据我们上面所提到的, p 1 p_1 p1中最大的是第3列,也就是说,根据我们的结果,节点1属于第2类的概率最大。

交叉熵损失

如是我们可以得出,节点2,4属于第0类,节点5属于第1类,节点1,3属于第2类,这个结果和实际分类相差比较大,所以参数w和b需要重新训练。

为此我们引入交叉熵损失函数:

L = − 1 N ∑ i = 1 , 2 , . . , N l o g e x p ( h i [ s ] ) ∑ q = 0 , 1 , 2 e x p ( h i [ q ] ) L = - \frac{1}{N} \sum_{i = 1,2,..,N} log \frac{exp(h_i[s])}{ \sum_{q = 0,1,2} exp(h_i[q])} L=N1i=1,2,..,Nlogq=0,1,2exp(hi[q])exp(hi[s])

其中N是节点数量5,s表示节点i所属的真实类别。
我们可以看到后面这个分式实质上就是节点i在真实类别s上对应的softmax函数值;所以实际上损失函数的目标,也就是让节点i被分到真实类别的概率最大化。

L = − 1 N ∑ i = 1 , 2 , . . , N l o g ( p i [ s ] ) L = - \frac{1}{N} \sum_{i = 1,2,..,N}log(p_i[s]) L=N1i=1,2,..,Nlog(pi[s])

之前提到的节点的真实的类别如下:
1,3,属于第0类
2,4,属于第1类
5,属于第2类

前面已经计算出了节点对应的softmax函数值 z 1 z_1 z1 z 5 z_5 z5
所以,在这里
L = − 1 5 ( l n ( p 1 [ 0 ] p 2 [ 1 ] p 3 [ 0 ] p 4 [ 1 ] p 5 [ 2 ] ) ) = − 1 / 5 ( l n ( 0.33 × 0.29 × 0.36 × 0.32 × 0.31 ) ) L = -\frac{1}{5}(ln(p_1[0]p_2[1]p_3[0]p_4[1]p_5[2]))\\ = -1/5(ln(0.33×0.29×0.36×0.32×0.31)) L=51(ln(p1[0]p2[1]p3[0]p4[1]p5[2]))=1/5(ln(0.33×0.29×0.36×0.32×0.31))

我们最终求出此次的损失函数值为-1.1358

反向传播更新参数

此时,我们已经求出了损失函数,下面要做的就是将损失函数对参数求梯度然后反向传播了

∂ L ∂ w t = ∂ L ∂ h i ∂ h i ∂ p i ∂ p i ∂ w i \frac{\partial L}{\partial w_t} = \frac{\partial L}{\partial h_i}\frac{\partial h_i}{\partial p_i}\frac{\partial p_i}{\partial w_i} wtL=hiLpihiwipi

具体求导过程请参考此处

最终得出的结果为:

∂ L ∂ w i = ∑ n = 1 N ( p n ( i ) − 1 ) x n \frac{\partial L}{\partial w_i} = \sum_{n = 1}^N(p_n(i)-1)x_n\\ wiL=n=1N(pn(i)1)xn
特别地,对于 b i b_i bi,我们的梯度应当为

∂ L ∂ b i = ∑ n = 1 N ( p n ( i ) − 1 ) \frac{\partial L}{\partial b_i} = \sum_{n = 1}^{N}(p_n(i)-1) biL=n=1N(pn(i)1)

其中p(i)指的是向量 x n x_n xn属于第i类的概率(softmax函数值);N为节点数5.

所以
∂ L ∂ w 0 = ( p 1 [ 0 ] − 1 ) z 1 + ( p 2 [ 0 ] − 1 ) z 2 + . . . + ( p 5 [ 0 ] − 1 ) z 5 \frac{\partial L}{\partial w_0} = (p_1[0]-1)z_1 +(p_2[0]-1)z_2 +...+(p_5[0]-1)z_5 w0L=(p1[0]1)z1+(p2[0]1)z2+...+(p5[0]1)z5

这样的话,最终得到的梯度结果是一个4维的向量
∂ L ∂ w 0 = [ − 1.62 , − 1.77 , − 1.29 , − 1.73 ] \frac{\partial L}{\partial w_0} = [-1.62,-1.77,-1.29,-1.73] w0L=[1.62,1.77,1.29,1.73]

梯度下降法更新参数:
W : = W − α ∂ L ∂ W W := W - \alpha\frac{\partial L}{\partial W} W:=WαWL

w 0 = [ 0.17 , 0.4 , − 0.14 , 0.51 ] w_0 = [0.17,0.4,-0.14,0.51] w0=[0.17,0.4,0.14,0.51],假如我们设定学习率 α = 0.2 \alpha = 0.2 α=0.2

w 0 : = w 0 − 0.2 ∂ L ∂ W = [ 0.49 , 0.75 , 0.12 , 0.86 ] w_0 := w_0 - 0.2 \frac{\partial L}{\partial W}\\ = [0.49,0.75,0.12,0.86] w0:=w00.2WL=[0.49,0.75,0.12,0.86]
对其他参数 w 1 w_1 w1, w 2 w_2 w2也作类似操作。

这样就完成了一轮参数更新。
(此处更新参数的计算加入了自己的理解,如有疏漏敬请指正)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/475323.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【计算几何】判断一条线段和一段圆弧是否相交 C++代码实现

文章目录 一、前言二、线段与圆弧的代码表示2.1 线段代码表示2.2 圆弧代码表示 三、实现思路及数学推导3.1 第一步(粗略判断)3.2 第二步3.3 第三步 四、完整代码五、效果展示 一、前言 最近做项目,需要判断一条线段是否和一段圆弧相交&#…

利用Ad Hoc传感器网络上的局部信息组织全球坐标系(Matlab代码实现)

目录 💥1 概述 📚2 运行结果 🎉3 参考文献 👨‍💻4 Matlab代码 💥1 概述 知道通信网络中节点的地理位置通常是有用的,但在每个节点上添加GPS接收器或其他复杂的传感器可能会很昂贵。 本文…

系统集成项目管理工程师 笔记(第14章 项目采购管理)

文章目录 采购管理包括如下几个过程14.2 编制采购计划 4621)采购管理计划2)采购工作说明书3)采购文件14.2.3 工作说明书(SOW) 14.3 实施采购 47414.3.2 实施采购的方法和技术 476(1)投标人会议&…

深入篇【C++】类与对象:构造函数+析构函数

深入篇【C】类与对象:构造函数析构函数 ①.构造函数Ⅰ.概念Ⅱ.特性1.函数名和类型相同。2.无返回值,也不用写void。3.自动调用对应的构造函数。4.构造函数可重载5.编译器的无参构造6.编译器的无参构造特性7.声明时可缺省8.构造函数的调用9.默认构造函数 …

进程控制下篇

进程控制下篇 1.进程创建 1.1认识fork / vfork 在linux中fork函数时非常重要的函数&#xff0c;它从已存在进程中创建一个新进程。新进程为子进程&#xff0c;而原进程为父进程 #include<unistd.h> int main() {pid_t i fork;return 0; }当前进程调用fork&#xff0c;…

【VScode】的 安装--配置--使用(中文插件下载不了怎么办?)

&#x1f58a;作者 : D. Star. &#x1f4d8;专栏 : VScode &#x1f606;今日分享 : ”兰因絮果“是世间定律吗&#xff1f; 一段美好爱情开始时你侬我侬、缠缠绵绵&#xff0c;最后却以相看两厌结尾&#xff0c;让人唏嘘。清代词人纳兰容若于是咏出「人生若只如初见&#xff…

后端程序员的前端必备【Vue】 -01 Vue入门

Vue概述与基础入门 1 Vue简介1.1 简介1.2 MVVM 模式的实现者——双向数据绑定模式1.3 其它 MVVM 实现者1.4 为什么要使用 Vue.js1.5 Vue.js 的两大核心要素1.5.1 数据驱动![请添加图片描述](https://img-blog.csdnimg.cn/963aca7d7a4447009a23f6900fdd7ee1.png)1.5.2 组件化 2 …

系统集成项目管理工程师 笔记(第13章 项目合同管理)

文章目录 13.2.1 按信息系统 范围 划分的合同分类 4451、总承包合同2、单项工程承包合同3、分包合同 13.2.2 按项目 付款方式 划分的合同分类 4461、总价合同2、成本补偿合同&#xff08;卖方有利&#xff09;3、工料合同 13.3.1 项目合同的内容 44713.3.2 项目合同签订的注意事…

进程地址空间与页表方面知识点(缺页中断及写时拷贝部分原理)

谢谢阅读&#xff0c;如有错误请大佬留言&#xff01;&#xff01; 目录 谢谢阅读&#xff0c;如有错误请大佬留言&#xff01;&#xff01; 抛出总结 开始介绍 发现问题 进程地址空间&#xff08;虚拟地址&#xff09; 页表 物理内存与进程地址空间映射 缺页中断基本…

Linux操作系统之mysql数据库简介

文章目录 数据库的介绍有关数据库的操作有关数据表的操作C语言访问mysql事务视图索引 数据库的介绍 mysql数据库模型&#xff1a; 关系型数据库与非关系型数据库&#xff1a; 关系型数据库&#xff1a;指采用了关系模型来组织数据的数据库&#xff0c;关系模型就是指二维表格模…

【PCL】—— 点云滤波

文章目录 直通滤波降采样使用统计滤波&#xff08;statisticalOutlierRemoval&#xff09;移除离群点使用条件滤波&#xff08;ConditionalRemoval&#xff09;或 半径滤波&#xff08;RadiusOutlinerRemoval&#xff09;移除离群点 在获取点云数据时&#xff0c;由于设备精度&…

Vue(组件化编程:非单文件组件、单文件组件)

一、组件化编程 1. 对比传统编写与组件化编程&#xff08;下面两个解释图对比可以直观了解&#xff09; 传统组件编写&#xff1a;不同的HTML引入不同的样式和行为文件 组件方式编写&#xff1a;组件单独&#xff0c;复用率高&#xff08;前提组件拆分十分细致&#xff09; 理…

【Fluent】Error: Model information is incompatible with incoming mesh.

一、问题背景 在原有workbench数据文件上&#xff0c;修改几何数据&#xff0c;然后重新划分网格&#xff0c;在更新网格后&#xff0c;workbench就弹出错误Error&#xff01; Model information is incompatible with incoming mesh. 因为当时并不影响我打开fluent求解器&am…

C语言数组介绍和用法

文章目录 前言一、数组的定义二、数组的大小三、数组的访问方法四、使用for循环遍历数组五、数组地址的访问方法六、二维数组七、二维数组的遍历总结 前言 本篇文章将带大家学习C语言中的数组&#xff0c;数组在C语言中是一个比较重要的点&#xff0c;大家需要好好理解并多加使…

Linux Shell 介绍及常用命令汇总

文章目录 Part.I shell 简介Chap.I 概念汇编Chap.II 命令概览 Part.II shell 常用命令大全Chap.I 关于文件和目录Chap.II 关于磁盘和内存Chap.III 关于进程调度 Reference Part.I shell 简介 Chap.I 概念汇编 下面是一些概念 shell 与 bash 的区别与联系&#xff1a;bash 是 b…

2023五一杯B题:快递需求分析问题

题目 网络购物作为一种重要的消费方式&#xff0c;带动着快递服务需求飞速增长&#xff0c;为我国经济发展做出了重要贡献。准确地预测快递运输需求数量对于快递公司布局仓库站点、节约存储成本、规划运输线路等具有重要的意义。附件1、附件2、附件3为国内某快递公司记录的部分…

从力的角度再次比较9-2分布和8-3分布

( A, B )---1*30*2---( 1, 0 )( 0, 1 ) 让网络的输入只有1个节点&#xff0c;AB各由11张二值化的图片组成&#xff0c;让A中有3个0&#xff0c;8个1.B中全是0&#xff0c;排列组合A的所有可能&#xff0c;统计迭代次数的顺序。在前面实验中得到了8-3分布的数据 A-B 迭代次数 …

孔乙己文学,满街长衫,为谁而穿?解构孔乙己文学

鲁迅先生创作《孔乙己》的背景是20世纪初期的中国社会。那时&#xff0c;中国正处于民国的初期&#xff0c;社会动荡不安&#xff0c;人民生活贫困。在这个背景下&#xff0c;鲁迅开始写作并发表了一系列揭露社会黑暗面的作品。《孔乙己》是其中之一&#xff0c;它讲述了一个被…

利用snpEff对基因型VCF文件进行变异注释的详细方法

利用snpEff对VCF文件进行变异注释 群体遗传研究中&#xff0c;在获得SNP位点后,我们需要对SNP位点进行注释&#xff0c;对这些SNP位点进行更深的了解。 snpEff是一个用于对基因组单核苷酸多态性(SNP)进行注释的软件&#xff0c;snpEff软件可以用于对VCF文件进行变异注释&#x…

VC++ | VS2017编译报错-20230428

VC | VS2017编译报错-20230428 文章目录 VC | VS2017编译报错-202304281.报错1-1.解决办法 2.报错2-1.解决办法2-1-1.做如下设置2-1-2.代码调整 1.报错 1>------ 已启动生成: 项目: NvtUSBTool, 配置: Debug Win32 ------ 1>NvtUSBTool.cpp 1>$(PRJ_ROOT_DIR)nvtusbt…