logistic回归的参数梯度更新方法的个人理解

news2024/11/27 12:57:52

logistic回归参数更新看了几篇博文,感觉理解不透彻,所以自己写一下,希望能有更深的理解。logistic回归输入是一个线性函数 W x + b \boldsymbol{W}\boldsymbol{x}+\boldsymbol{b} Wx+b,为了简单理解,考虑batchsize为1的情况。这时输入 x \boldsymbol{x} x为一个 n × 1 n\times1 n×1的向量,标签 y \boldsymbol{y} y我们采用oneHot编码为一个 m × 1 m\times1 m×1的向量,显然\boldsymbol{b}也是一个 m × 1 m\times1 m×1的向量,参数 W \boldsymbol{W} W为一个 m × n m\times n m×n的矩阵。若 n = 4 n=4 n=4 m = 3 m=3 m=3,我们用图形表示logistic回归如下:
在这里插入图片描述
这里的标签 y \boldsymbol{y} y采用onehot编码,长度为3,如果类别编号为1,则其编码为 { 1 , 0 , 0 } T \{1,0,0\}^T {1,0,0}T,对应上图的话,就是 y ∗ 1 = 1 y_*^1=1 y1=1 y ∗ 2 = 0 y_*^2=0 y2=0 y ∗ 3 = 0 y_*^3=0 y3=0。损失函数 L L L就是 y 1 y^1 y1 y ∗ 1 y_*^1 y1的交叉熵损失+ y 2 y^2 y2 y ∗ 2 y_*^2 y2的交叉熵损失+ y 3 y^3 y3 y ∗ 3 y_*^3 y3的交叉熵损失。
L = ∑ i = 1 3 y ∗ i log ⁡ y i = y ∗ 1 log ⁡ y 1 + y ∗ 2 log ⁡ y 2 + y ∗ 3 log ⁡ y 3 \begin{aligned} L&=\sum_{i=1}^3y^i_*\log{y^i}\\ &=y^1_*\log{y^1}+y^2_*\log{y^2}+y^3_*\log{y^3} \end{aligned} L=i=13yilogyi=y1logy1+y2logy2+y3logy3
上式中:
y 1 = e z 1 e z 1 + e z 2 + e z 3 y 2 = e z 2 e z 1 + e z 2 + e z 3 y 3 = e z 3 e z 1 + e z 2 + e z 3 \begin{aligned} y^1&=\frac{e^{z^1}}{e^{z^1}+e^{z^2}+e^{z^3}}\\ y^2&=\frac{e^{z^2}}{e^{z^1}+e^{z^2}+e^{z^3}}\\ y^3&=\frac{e^{z^3}}{e^{z^1}+e^{z^2}+e^{z^3}}\\ \end{aligned} y1y2y3=ez1+ez2+ez3ez1=ez1+ez2+ez3ez2=ez1+ez2+ez3ez3

z 1 = w 1 T x + b 1 z 2 = w 2 T x + b 2 z 3 = w 3 T x + b 3 \begin{aligned} z^1=\boldsymbol{w_1}^T \boldsymbol{x}+b_1\\ z^2=\boldsymbol{w_2}^T \boldsymbol{x}+b_2\\ z^3=\boldsymbol{w_3}^T \boldsymbol{x}+b_3 \end{aligned} z1=w1Tx+b1z2=w2Tx+b2z3=w3Tx+b3
其中, w 1 = { w 11 , w 12 , w 13 , w 14 } T \boldsymbol{w_1}=\{w_{11},w_{12},w_{13},w_{14}\}^T w1={w11,w12,w13,w14}T x = { x 1 , x 2 , x 3 , x 4 } T \boldsymbol{x}=\{x_{1},x_{2},x_{3},x_{4}\}^T x={x1,x2,x3,x4}T因此:

损失函数 L L L w 1 \boldsymbol{w_1} w1求导:
∂ L ∂ w 1 = ∂ L ∂ y 1 ∂ y 1 ∂ z 1 ∂ z 1 ∂ w 1 + ∂ L ∂ y 2 ∂ y 2 ∂ z 1 ∂ z 1 ∂ w 1 + ∂ L ∂ y 3 ∂ y 3 ∂ z 1 ∂ z 1 ∂ w 1 = y 1 ∗ y 1 × y 1 ( 1 − y 1 ) × x − y 2 ∗ y 2 × y 1 y 2 × x − y 3 ∗ y 3 × y 1 y 3 × x = ( y 1 ∗ ( 1 − y 1 ) − y 2 ∗ y 1 − y 3 ∗ y 1 ) x = ( y 1 ∗ − y 1 ( y 1 ∗ + y 2 ∗ + y 3 ∗ ) ) x = ( y 1 ∗ − y 1 ) x \begin{aligned} \frac{\partial L}{\partial \boldsymbol{w_1}}&=\frac{\partial L}{\partial y_1}\frac{\partial y_1}{\partial z^1}\frac{\partial z^1}{\partial \boldsymbol{w_1}}+\frac{\partial L}{\partial y_2}\frac{\partial y_2}{\partial z^1}\frac{\partial z^1}{\partial \boldsymbol{w_1}}+\frac{\partial L}{\partial y_3}\frac{\partial y_3}{\partial z^1}\frac{\partial z^1}{\partial \boldsymbol{w_1}}\\ &=\frac{y_1^*}{y_1}\times y_1(1-y_1)\times \boldsymbol{x}-\frac{y_2^*}{y_2}\times y_1y_2\times \boldsymbol{x}-\frac{y_3^*}{y_3}\times y_1y_3\times \boldsymbol{x}\\ &=(y_1^*(1-y_1)-y_2^*y_1-y_3^*y_1)\boldsymbol{x}\\ &=(y_1^*-y_1(y_1^*+y_2^*+y_3^*))\boldsymbol{x}\\ &=(y_1^*-y_1)\boldsymbol{x}\\ \end{aligned} w1L=y1Lz1y1w1z1+y2Lz1y2w1z1+y3Lz1y3w1z1=y1y1×y1(1y1)×xy2y2×y1y2×xy3y3×y1y3×x=(y1(1y1)y2y1y3y1)x=(y1y1(y1+y2+y3))x=(y1y1)x
注意 ( y 1 ∗ + y 2 ∗ + y 3 ∗ ) (y_1^*+y_2^*+y_3^*) (y1+y2+y3)是标签onehot编码的三个值,和正好为1。同理可得到剩下的两个导数:
∂ L ∂ w 2 = ( y 2 ∗ − y 2 ) x ∂ L ∂ w 3 = ( y 3 ∗ − y 3 ) x \frac{\partial L}{\partial \boldsymbol{w_2}} = (y_2^*-y_2)\boldsymbol{x}\\ \frac{\partial L}{\partial \boldsymbol{w_3}} = (y_3^*-y_3)\boldsymbol{x} w2L=(y2y2)xw3L=(y3y3)x
交叉熵损失函数 L L L关于 w \boldsymbol{w} w的梯度为:
[ ( y 1 ∗ − y 1 ) x 1 ( y 2 ∗ − y 2 ) x 1      ( y 3 ∗ − y 3 ) x 1 ( y 1 ∗ − y 1 ) x 2 ( y 2 ∗ − y 2 ) x 2      ( y 3 ∗ − y 3 ) x 2 ( y 1 ∗ − y 1 ) x 3 ( y 2 ∗ − y 2 ) x 3      ( y 3 ∗ − y 3 ) x 3 ( y 1 ∗ − y 1 ) x 4 ( y 2 ∗ − y 2 ) x 4      ( y 3 ∗ − y 3 ) x 4 ( y 1 ∗ − y 1 ) x 5 ( y 2 ∗ − y 2 ) x 5      ( y 3 ∗ − y 3 ) x 5 ] T \left[ \begin{aligned} &(y_1^*-y_1)x1&(y_2^*-y_2)x1\space\space\space\space&(y_3^*-y_3)x1\\ &(y_1^*-y_1)x2&(y_2^*-y_2)x2\space\space\space\space&(y_3^*-y_3)x2\\ &(y_1^*-y_1)x3&(y_2^*-y_2)x3\space\space\space\space&(y_3^*-y_3)x3\\ &(y_1^*-y_1)x4&(y_2^*-y_2)x4\space\space\space\space&(y_3^*-y_3)x4\\ &(y_1^*-y_1)x5&(y_2^*-y_2)x5\space\space\space\space&(y_3^*-y_3)x5\\ \end{aligned} \right]^T (y1y1)x1(y1y1)x2(y1y1)x3(y1y1)x4(y1y1)x5(y2y2)x1    (y2y2)x2    (y2y2)x3    (y2y2)x4    (y2y2)x5    (y3y3)x1(y3y3)x2(y3y3)x3(y3y3)x4(y3y3)x5 T
这样交叉熵损失函数 L L L关于 w \boldsymbol{w} w的梯度用numpy的外积计算表示为:
∂ L ∂ w = n u m p y . o u t e r ( x , y ∗ − y ) \frac{\partial L}{\partial \boldsymbol{w}}=numpy.outer(\boldsymbol{x},\boldsymbol{y^*}-\boldsymbol{y}) wL=numpy.outer(x,yy)
用同样的方法可以推导出:
∂ L ∂ b = y ∗ − y \frac{\partial L}{\partial \boldsymbol{b}}=\boldsymbol{y^*}-\boldsymbol{y} bL=yy

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/402653.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ChatGPT的影响力和未来发展

ChatGPT是一种基于深度学习的自然语言处理技术,它是由OpenAI开发的一种语言模型。ChatGPT作为一个大型语言模型,可以在很多方面对程序职业产生影响。值得注意的是,ChatGPT和其他语言模型一样,只是一种技术工具,它的应用…

Postman 接口传参报错解决与@RequestBody的注解作用记录

文章目录前言一、接口代码1.1 代码说明1.2 测试结果1.3 问题解决1.4 RequestBody 作用前言 记录接口传参报错与解决和RequestBody的作用记录 一、接口代码 1.1 代码说明 以下面测试代码作为例子:前端发送 POST 请求,请求体里面携带 List 集合的字符串…

C++回顾(十九)—— 容器string

19.1 string概述 1、string是STL的字符串类型,通常用来表示字符串。而在使用string之前,字符串通常是 用char * 表示的。string 与char * 都可以用来表示字符串,那么二者有什么区别呢。 2、string和 char * 的比较 (1&#xff09…

【wed前端初级课程】第一章 什么是HTML

什么是WEB前端? 简单来说就是网页,只是这个网页它是由多种技术参与制作的,用来向用户展示的页面。 HTML(超文本标签语言):它决定了网页的结构。 CSS:网页的装饰器。 JavaScript:JavaScrip最初是因为校验…

【Linux系统编程】06:共享内存

共享内存 OVERVIEW共享内存一、文件上锁flock二、共享内存1.关联共享内存ftok2.获取共享内存shmget3.绑定共享内存shmat4.绑定分离shmdt5.控制共享内存shmctl三、亲缘进程间通信1.共享内存写入与读取2.共享内存解绑与删除3.共享内存综合四、非亲缘进程间通信1.通过sleep同步2.通…

Android 进阶——Binder IPC之Native 服务的启动及代理对象的获取详解(六)

文章大纲引言一、Binder线程池的启动1、ProcessState#startThreadPool函数来启动线程池2、IPCThreadState#joinThreadPool 将当前线程进入到线程池中去等待和处理IPC请求二、Service 代理对象的获取1、获取Service Manager 代理对象BpServiceManager2、调用BpServiceManager#ge…

【算法数据结构体系篇class16】:图 拓扑排序

一、图1)由点的集合和边的集合构成2)虽然存在有向图和无向图的概念,但实际上都可以用有向图来表达3)边上可能带有权值二、图结构的表达1)邻接表法 类似哈希表, key就是当前节点。value就是对应有指向的邻接节点2&…

LeetCode——1590. 使数组和能被 P 整除

一、题目 给你一个正整数数组 nums,请你移除 最短 子数组(可以为 空),使得剩余元素的 和 能被 p 整除。 不允许 将整个数组都移除。 请你返回你需要移除的最短子数组的长度,如果无法满足题目要求,返回 -1…

PostgreSQL 数据库大小写规则

PostgreSQL 数据库对大小写的处理规则如下: 严格区分大小写默认把所有 SQL 语句都转换成小写再执行加双引号的 SQL 语句除外 如果想要成功执行名称中带有大写字母的对象,则需要把对象名称加上双引号。 验证如下: 想要创建数据库 IZone&…

Windows WSL配置ubuntu环境并登录

一、Windows WSL配置ubuntu环境1、管理员运行cmd,执行以下命令启用“适用于 Linux 的 Windows 子系统”dism.exe /online /enable-feature /featurename:Microsoft-Windows-Subsystem-Linux /all /norestart2、管理员运行cmd,执行以下命令启用“虚拟机功…

浅谈ChatGPT

ChatGPT概述 ChatGPT是一种自然语言处理模型,ChatGPT全称Chat Generative Pre-trained Transformer,由OpenAI开发。它使用了基于Transformer的神经网络架构,可以理解和生成自然语言文本。ChatGPT是当前最强大和最先进的预训练语言模型之一&a…

windows应用(vc++2022)MFC基础到实战(3)-基础(3)

目录框架调用代码MFC 对象之间的关系访问其他对象CWinApp:应用程序类initInstance 成员函数运行成员函数OnIdle 成员函数ExitInstance 成员函数CWinApp 和 MFC 应用程序向导特殊 CWinApp 服务Shell 注册文件管理器拖放CWinAppEx 类用于创建 OLE 应用程序的操作顺序用…

【算法题目】【Python】一文刷遍贪心算法题目

文章目录介绍分配饼干K 次取反后最大化的数组和柠檬水找零摆动序列单调递增的数字介绍 贪心算法是一种基于贪心思想的算法,它每次选择当前最优的解决方案,从而得到全局最优解。具体来说,贪心算法在每一步都做出局部最优选择,希望…

Flutter——Isolate主线机制

简述 在DartFlutter应用程序启动时,会启动一个主线程其实也就是Root Isolate,在Root Isolate内部运行一个EventLoop事件循环。所以所有的Dart代码都是运行在Isolate之中的,它就像是机器上的一个小空间,具有自己的私有内存块和一个运行事件循…

Linux下LED灯驱动模板详解

一、地址映射我们先了解MMU,全称是Memory Manage Unit。在老版本的Linux中要求处理器必须有MMU,但是现在Linux内核已经支持五MMU。MMU主要完成的功能如下:1、完成虚拟空间到物理空间的映射2、内存保护,设置存储器的访问权限&#…

【Linux学习笔记】mmap-共享内存进程通信 vs 有名信号量和无名信号量

mmap和信号量实现进程间通信相关mmap1. mmap 使用的注意事项2. mmap的两种映射3. mmap调用接口以及参数4. 使用存储映射区实现父子进程间通信(有名)父子进程通信的三种方式unlink5. 创建匿名存储映射区6. 通过存储映射区实现非血缘关系进程间的通信信号量…

SiteSucker for macOS + CRACK

SiteSucker for macOS CRACK SiteSucker是一个简单的macOS应用程序,允许您下载网站。它还可以将网站、网页、背景图片、视频和许多其他文件复制到Mac的硬盘上。 SiteSucker是一个Macintosh应用程序,可以自动下载Internet上的网页。它通过将网站的页面、…

遥感影像道路提取算法——SGCN

论文介绍 Split Depth-wise Separable Graph Convolution Network for Road Extraction in Complex Environment from High-resolution Remote Sensing Imagery(TGRS) 用于从高分辨率遥感图像(TGRS)中提取复杂环境中道路的分割深…

java对象的创建与内存分配机制

文章目录对象的创建与内存分配机制对象的创建类加载检查分配内存初始化零值设置对象头指向init方法其他:指针压缩对象内存分配对象在栈上分配对象在Eden区中分配大对象直接分配到老年代长期存活的对象进入老年代对象动态年龄判断老年代空间分配担保机制对象的内存回…

Spring的核心模块:Bean的生命周期(内含依赖循环+业务场景)。

Bean的生命周期前言为什么要学习Bean的生命周期前置知识Spring Post-processor(后置处理器)Aware接口简单介绍Bean的实例化过程为什么会有bean的实例化?过程Bean的初始化阶段为什么会有Bean的初始化?Bean的初始化目的是什么&#…