神经网络全连接层数学推导

news2025/1/10 23:48:53

全连接层分析

对于神经网络为什么都能产生很好的效果虽然其是一个黑盒,但是我们也可以对其中的一些数学推导有一定的了解。

数学背景

目标函数为 f = ∣ ∣ m a x ( X W , 0 ) − Y ∣ ∣ F 2 ,求 ∂ f ∂ W , ∂ f ∂ X , ∂ f ∂ Y 目标函数为f = ||max(XW,0)-Y||^{2}_{F},求\frac{\partial f}{\partial W},\frac{\partial f}{\partial X},\frac{\partial f}{\partial Y} 目标函数为f=∣∣max(XW,0)YF2,求Wf,Xf,Yf

公式证明

解: 首先我们假设替换变量 f = ∣ ∣ S − Y ∣ ∣ F 2 S = m a x ( Z , 0 ) Z = X W 根据假设变量,易得: ∂ f ∂ S = ∂ f ∂ Y = 2 ( S − Y ) ( 1 ) 要求 Z 分量的偏导: ∂ f = t r ( ( ∂ f ∂ S ) T d S ) = t r ( ( ∂ f ∂ S ) T d m a x ( Z , 0 ) ) = t r ( ( ∂ f ∂ S ) T ⊙ m a x ′ ( Z , 0 ) d Z ) 所以: ∂ f ∂ Z = m a x ′ ( Z , 0 ) T ⊙ ∂ f ∂ S = 2 ∗ m a x ′ ( Z , 0 ) ⊙ ( S − Y ) = 2 ∗ m a x ′ ( Z , 0 ) ⊙ ( m a x ( Z , 0 ) − Y ) ( 2 ) 要求 W 分量的偏导 : ∂ f = t r ( ( ∂ f ∂ Z T d Z ) = t r ( ( ∂ f ∂ Z ) T d ( X W ) ) = t r ( ( X T f Z ) d W ) 所以: ∂ f ∂ W = X T ∂ f ∂ Z = 2 ∗ X T m a x ′ ( Z , 0 ) ⊙ ( m a x ( Z , 0 ) − Y ) ( 3 ) 要求 X 分量的偏导 : ∂ f = t r ( ( ∂ f ∂ Z T d Z ) = t r ( ( ∂ f ∂ Z ) T d ( X W ) ) = t r ( ( f Z ) W T d X ) 所以: ∂ f ∂ X = W T ∂ f ∂ Z = 2 m a x ′ ( Z , 0 ) ⊙ ( m a x ( Z , 0 ) − Y ) W T \begin{align} 解:&首先我们假设替换变量 \\ &f=||S-Y||^{2}_{F} \\ &S=max(Z,0) \\ &Z=XW \\ &根据假设变量,易得: \\ \frac{\partial f}{\partial S} &=\frac{\partial f}{\partial Y}= 2(S-Y) \\\\ &(1)要求Z分量的偏导: \\ \partial f&=tr((\frac{\partial f}{\partial S})^{T}dS) \\ &=tr((\frac{\partial f}{\partial S})^{T}dmax(Z,0))\\ &=tr((\frac{\partial f}{\partial S})^{T}\odot max\prime(Z,0)dZ) \\ &所以:\\ \frac{\partial f}{\partial Z}&=max\prime(Z,0)^{T}\odot \frac{\partial f}{\partial S} \\ &=2*max\prime(Z,0)\odot(S-Y) \\ &=2*max\prime(Z,0)\odot(max(Z,0)-Y) \\ \\ &(2)要求W分量的偏导:\\ \partial f&=tr((\frac{\partial f}{\partial Z}^{T}dZ) \\ &=tr((\frac{\partial f}{\partial Z})^{T}d(XW)) \\ &=tr((X^{T}\frac{f}{Z})dW) \\ &所以:\\ \frac{\partial f}{\partial W}&=X^{T}\frac{\partial f}{\partial Z}\\ &=2*X^{T}max\prime(Z,0)\odot(max(Z,0)-Y)\\ \\ &(3)要求X分量的偏导:\\ \partial f&=tr((\frac{\partial f}{\partial Z}^{T}dZ) \\ &=tr((\frac{\partial f}{\partial Z})^{T}d(XW)) \\ &=tr((\frac{f}{Z})W^{T}dX) \\ &所以:\\ \frac{\partial f}{\partial X}&=W^{T}\frac{\partial f}{\partial Z}\\ &=2max\prime(Z,0)\odot(max(Z,0)-Y)W^{T} \end{align} 解:SffZffWffXf首先我们假设替换变量f=∣∣SYF2S=max(Z,0)Z=XW根据假设变量,易得:=Yf=2(SY)(1)要求Z分量的偏导:=tr((Sf)TdS)=tr((Sf)Tdmax(Z,0))=tr((Sf)Tmax(Z,0)dZ)所以:=max(Z,0)TSf=2max(Z,0)(SY)=2max(Z,0)(max(Z,0)Y)(2)要求W分量的偏导:=tr((ZfTdZ)=tr((Zf)Td(XW))=tr((XTZf)dW)所以:=XTZf=2XTmax(Z,0)(max(Z,0)Y)(3)要求X分量的偏导:=tr((ZfTdZ)=tr((Zf)Td(XW))=tr((Zf)WTdX)所以:=WTZf=2max(Z,0)(max(Z,0)Y)WT

全连接ReLU

公式推导

首先一个全连接ReLU神经网络,一个隐藏层,没有bias,用来从x预测y,使用L2 Loss。

h = X W 1 h relu  = max ⁡ ( 0 , h ) Y pred  = h relu  W 2 f = ∥ Y − Y pred  ∥ F 2 \begin{array}{l} h=X W_{1} \\ h_{\text {relu }}=\max (0, h) \\ Y_{\text {pred }}=h_{\text {relu }} W_{2} \\ f=\left\|Y-Y_{\text {pred }}\right\|_{F}^{2} \end{array} h=XW1hrelu =max(0,h)Ypred =hrelu W2f=YYpred F2
其网络连接示意图如下所示:

全连接ReLU示意图

对于 W 1 W_1 W1 W 2 W_2 W2的偏导,由上数学背景易得为:
∂ f ∂ Y p r e d = 2 ( Y − Y p r e d ) ∂ f ∂ W 2 = h r e l u T . 2 ( Y − Y p r e d ) ∂ f ∂ h r e l u = ∂ f ∂ Y p r e ( W 2 T ) ∂ f ∂ h = m a x ′ ( 0 , h ) ∂ f ∂ h r e l u ∂ f ∂ W 1 = X T ∂ f ∂ h \begin{align} \frac{\partial f}{\partial Y_{pred}}&=2(Y-Y_{pred}) \\ \frac{\partial f}{\partial W_{2}} &= h_{relu}^T.2(Y-Y_{pred}) \\ \frac{\partial f}{\partial h_{relu}}&= \frac{\partial f}{\partial Y_{pre}}(W_2^T)\\ \frac{\partial f}{\partial h}&= max\prime(0,h)\frac{\partial f}{\partial h_{relu}}\\ \frac{\partial f}{\partial W_{1}} &= X^{T}\frac{\partial f}{\partial h} \end{align} YpredfW2fhrelufhfW1f=2(YYpred)=hreluT.2(YYpred)=Ypref(W2T)=max(0,h)hreluf=XThf

Numpy实现

import numpy as np
import torch
N,D_in,H,D_out = 64,1000,100,10
#随机数据
x = np.random.randn(N,D_in)
y = np.random.randn(N,D_out)
w1= np.random.randn(D_in,H)
w2= np.random.randn(H,D_out)

#学习率
learning_rate = 1e-6

for it in range(501):
    #Forward pass
    h = x.dot(w1)                #N*H
    h_relu = np.maximum(h,0)     #N*H
    Y_pred = h_relu.dot(w2)      #N*D_out
    
    #compute loss
    #numpy.square()函数返回一个新数组,该数组的元素值为源数组元素的平方。 源阵列保持不变。
    loss = np.square(y-Y_pred).sum()
    #print(loss)
    if it%50==0:
        print(it,loss)
    
    #Backward pass
    #compute the gradient
    grad_Y_pre = 2.0*(Y_pred - y)
    grad_w2 = h_relu.T.dot(grad_Y_pre)
    grad_h_relu = grad_Y_pre.dot(w2.T)
    grad_h = grad_h_relu.copy()
    grad_h[h<0] = 0
    grad_w1 = x.T.dot(grad_h)
    
    #update weights of w1 and w2
    w1 -= learning_rate * grad_w1
    w2 -= learning_rate * grad_w2

全连接层练习2

h = X W 1 + b 1 h sigmoid  = sigmoid ⁡ ( h ) Y pred  = h sigmoid  W 2 + b 2 f = ∥ Y − Y pred  ∥ F 2 \begin{array}{l} h=X W_{1}+b_{1} \\ h_{\text {sigmoid }}=\operatorname{sigmoid}(h) \\ Y_{\text {pred }}=h_{\text {sigmoid }} W_{2}+b_{2} \\ f=\left\|Y-Y_{\text {pred }}\right\|_{F}^{2} \end{array} h=XW1+b1hsigmoid =sigmoid(h)Ypred =hsigmoid W2+b2f=YYpred F2
![[激活函数#sigmoid函数]]

由上公式易得:
∂ f ∂ Y p r e d = 2 ( Y − Y p r e d ) ∂ f ∂ h s i g m o i d = ∂ f ∂ Y p r e d W 2 T ∂ f ∂ h = ∂ f ∂ h s i g m o i d s i g m o i d ′ ( x ) = ∂ f ∂ h s i g m o i d s i g m o i d ( x ) ( 1 − s i g m o i d ( x ) ) ∂ f ∂ W 2 = h s i g m o i d T ∂ f ∂ Y p r e d ∂ f ∂ b 2 = ∂ f ∂ Y p r e d ∂ f ∂ W 1 = X T ∂ f ∂ h s i g m o i d ∂ f ∂ b 1 = ∂ f ∂ h \begin{align} \frac{\partial f}{\partial Y_{pred}} &=2(Y-Y_{pred}) \\ \frac{\partial f}{\partial h_{sigmoid}}&=\frac{\partial f}{\partial Y_{pred}}W_2^{T} \\ \frac{\partial f}{\partial h}&=\frac{\partial f}{\partial h_{sigmoid}}sigmoid\prime(x) \\ &=\frac{\partial f}{\partial h_{sigmoid}}sigmoid(x)(1-sigmoid(x)) \\ \frac{\partial f}{\partial W_2}&=h_{sigmoid}^{T}\frac{\partial f}{\partial Y_{pred}} \\ \frac{\partial f}{\partial b_{2}}&=\frac{\partial f}{\partial Y_{pred}}\\ \frac{\partial f}{\partial W_1}&=X^{T}\frac{\partial f}{\partial h_{sigmoid}}\\ \frac{\partial f}{\partial b_1}&=\frac{\partial f}{\partial h} \end{align} YpredfhsigmoidfhfW2fb2fW1fb1f=2(YYpred)=YpredfW2T=hsigmoidfsigmoid(x)=hsigmoidfsigmoid(x)(1sigmoid(x))=hsigmoidTYpredf=Ypredf=XThsigmoidf=hf

%matplotlib inline
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

torch.manual_seed(1)    # reproducible

x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)  # x data (tensor), shape=(100, 1)
y = x.pow(2) + 0.2*torch.rand(x.size())


plt.scatter(x.numpy(), y.numpy())

class Net(torch.nn.Module):
    def __init__(self, n_feature, n_hidden, n_output):
        super(Net, self).__init__()
        self.linear1 = torch.nn.Linear(n_feature,n_hidden,bias=True)
        self.linear2 = torch.nn.Linear(n_hidden,n_output,bias=True)

    def forward(self, x):
        y_pred = self.linear2(torch.sigmoid(self.linear1(x)))
        return y_pred
net = Net(n_feature=1, n_hidden=20, n_output=1)     # define the network
print(net)  # net architecture
optimizer = torch.optim.SGD(net.parameters(), lr=0.2)
loss_func = torch.nn.MSELoss()  # this is for regression mean squared loss

plt.ion()   # something about plotting

for t in range(201):
    prediction = net(x)     # input x and predict based on x
    loss = loss_func(prediction, y)     # must be (1. nn output, 2. target)

    optimizer.zero_grad()   # clear gradients for next train
    loss.backward()         # backpropagation, compute gradients
    optimizer.step()        # apply gradients

    if t % 20 == 0:
        # plot and show learning process
        plt.cla()
        plt.scatter(x.numpy(), y.numpy())
        plt.plot(x.numpy(), prediction.data.numpy(), 'r-', lw=5)
        plt.text(0.5, 0, 't = %d, Loss=%.4f' % (t, loss.data.numpy()), fontdict={'size': 20, 'color':  'red'})
        plt.pause(0.1)
        plt.show()

plt.ioff()
plt.show()


请添加图片描述
请添加图片描述

请添加图片描述

请添加图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/505929.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

SpringBoot项目如何打包成exe应用程序

准备 准备工作&#xff1a; 一个jar包&#xff0c;没有bug能正常启动的jar包 exe4j&#xff0c;一个将jar转换成exe的工具 链接: https://pan.baidu.com/s/1m1qA31Z8MEcWWkp9qe8AiA 提取码: f1wt inno setup&#xff0c;一个将依赖和exe一起打成一个安装程序的工具 链接:…

设计模式——代理模式(静态代理、JDK动态代理、CGLIB动态代理)

是什么&#xff1f; 如果因为某些原因访问对象不适合&#xff0c;或者不能直接引用目标对象&#xff0c;这个时候就需要给该对象提供一个代理以控制对该对象的访问&#xff0c;代理对象作为访问对象和目标对象之间的中介&#xff1b; Java中的代理按照代理类生成时机不同又分…

婴儿摇篮语音播放芯片,高品质MP3音乐播放芯片,WT2003H

婴儿摇篮是一种用于帮助婴儿入睡的设备。传统的婴儿摇篮通常只是简单的摇晃&#xff0c;但是带有语音播报芯片的婴儿摇篮则可以更好地模拟妈妈的声音&#xff0c;从而更有效地帮助婴儿入睡。 如果您正在寻找高品质音乐摇篮方案&#xff0c;那么WT2003H语音播放芯片&#xff0c…

5月7日 2H55min|5月8日8H50min|时间轴复盘|14:00~14:30

5月8日 todo list list4 40min ✅ |实际上用了50+50 list6 40min ✅ |实际上用了30+60 阅读+听力连做 100min ✅ 口语 day01 ✅ 口语 day02 口语 day03

6、并发事务控制MVCC汇总

1.并发事务控制 单版本控制-锁 先来看锁&#xff0c;锁用独占的方式来保证在只有一个版本的情况下事务之间相互隔离&#xff0c;所以锁可以理解为单版本控制。 在 MySQL 事务中&#xff0c;锁的实现与隔离级别有关系&#xff0c;在 RR&#xff08;Repeatable Read&#xff0…

vCenter Server 8.0U1 OVF:在 Fusion 和 Workstation 中快速部署 vCSA

vCenter Server 8.0U1 OVF&#xff1a;在 Fusion 和 Workstation 中快速部署 vCSA vCenter Server 8.0U1 系列更新 请访问原文链接&#xff1a;https://sysin.org/blog/vmware-vcenter-8-ovf/&#xff0c;查看最新版。原创作品&#xff0c;转载请保留出处。 作者主页&#x…

Win上通过Jconsole查看Java程序资源占用情况(教程总结,一篇就够了)

最近需要读取一个大文件&#xff0c;为了判断有没有读取到内存中&#xff0c;需要一个能查看jar包占用内存的工具&#xff0c;一顿面向百度后&#xff0c;发现了jdk自带的工具Jconsole&#xff0c;将教程分享给大家 一、介绍 JConsole 是一个内置 Java 性能分析器&#xff0c;…

手把手教你使用unisat 交易市场|BRC20|Unisat

开始前先熟悉下这张平台市场标注图&#xff0c;能让你跟得心应手&#xff01; 一、查看实时成交信息&#xff08;已 moon 为例子&#xff09; 搜索进入Token界面&#xff0c;点击 Orders 可以看到 1w 枚成交 87.86U&#xff08;单价 30 聪&#xff0c;大约 0.008786u&#xf…

牛客网剑指offer|中等题day2|JZ76删除链表中的重复节点、JZ23链表中环的入口节点、JZ24 反转链表(简单)

JZ76删除链表中的重复节点 链接&#xff1a;删除链表中重复的结点_牛客题霸_牛客网 参考代码&#xff1a; 自己好像还是偏向双指针这种想法&#xff0c;所以用了两个指针&#xff0c;这样感觉更好理解一些。 对了&#xff0c;去重有两种&#xff0c;我一开始写成了简单的那种&a…

MGV3001_ZG_当贝纯净桌面-线刷固件包

MGV3001_ZG_当贝纯净桌面-线刷固件包-内有教程及短接点 特点&#xff1a; 1、适用于对应型号的电视盒子刷机&#xff1b; 2、开放原厂固件屏蔽的市场安装和u盘安装apk&#xff1b; 3、修改dns&#xff0c;三网通用&#xff1b; 4、大量精简内置的没用的软件&#xff0c;运…

【标准化方法】(3) Group Normalization 原理解析、代码复现,附Pytorch代码

今天和各位分享一下深度学习中常用的标准化方法&#xff0c;Group Normalization 数据分组归一化&#xff0c;向大家介绍一下数学原理&#xff0c;并用 Pytorch 复现。 Group Normalization 论文地址&#xff1a;https://arxiv.org/pdf/1803.08494.pdf 1. 原理介绍 在目标检测…

Javascript - Cookie的获取和保存应用

在之前的博客介绍了如何利用 Selenium去搭建 cookie池&#xff0c;进行自动化登录、获取信息等。那什么是cookie呢&#xff1f;它的作用又是什么呢&#xff1f; 这里&#xff0c;再重复简单介绍一下。 cookie 是浏览器储存在用户电脑上的一小段文本文件。该文件里存了加密后的用…

LeetCode之回溯算法

文章目录 思想&框架1.组合/子集和排列问题2.组合应用问题 组合/子集问题1. lc77 组合2. lc216 组合总和III3. lc39 组合总和4. lc40 组合总和II5. lc78 子集6. lc90 子集II 排列1. 全排列I2. 全排列II 组合问题的应用1.lc17 电话号码的字母组合2.lc131 分割回文串3. lc19 复…

集约式智能自动化办公,实在智能门户开启政企数字化转型新范式

导语&#xff1a; 随着数字化和智能化的快速发展&#xff0c;数字技术已经深入到各个行业和领域。实在智能基于数字员工在行业的深厚理解和丰富的实践经验&#xff0c;打造一站式的智能化统一平台——智能门户&#xff0c;打破了技术壁垒和系统数据之间的割裂感&#xff0c;实现…

软考A计划-重点考点-专题五(计算机网络知识)

点击跳转专栏>Unity3D特效百例点击跳转专栏>案例项目实战源码点击跳转专栏>游戏脚本-辅助自动化点击跳转专栏>Android控件全解手册点击跳转专栏>Scratch编程案例 &#x1f449;关于作者 专注于Android/Unity和各种游戏开发技巧&#xff0c;以及各种资源分享&am…

Apache Sentry

官方 说明 Sentry是一种用于在Hadoop集群中控制和管理访问权限的工具。因此&#xff0c;CDH的Sentry指的是Cloudera Distribution for Hadoop中集成的Sentry组件&#xff0c;用于管理Hadoop集群中的访问控制和权限管理。 作用 Sentry是一个用于管理Hadoop集群中的访问权限的…

基于C++实现旅行线路设计

访问【WRITE-BUG数字空间】_[内附完整源码和文档] 系统根据风险评估&#xff0c;为旅客设计一条符合旅行策略的旅行线路并输出&#xff0c;系统能查询当前时刻旅客所处的地点和状态&#xff08;停留城市/所在交通工具&#xff09;。 实验内容和实验环境描述 1.1 实验内容 城…

【吐槽贴】项目经理如何进行高效沟通?

“项目最大的风险就是都觉得没有风险。” 这还是跟同行聊天时开玩笑的一句话&#xff0c;最近我却深有体会。一直以为一切正常的项目&#xff0c;最近却接连出了问题&#xff0c;复盘才发现几个关键性问题都出在沟通方面&#xff0c;还一直认为沟通能力是自己的优势。这次主要踩…

使用java-timeseries库,使用arima算法预测时间序列(

项目地址&#xff1a; GitHub - signaflo/java-timeseries: Time series analysis in Java maven&#xff1a; <dependency><groupId>com.github.signaflo</groupId><artifactId>timeseries</artifactId><version>0.4</version> &…

【剖析STL】String

1.什么是STL&#xff1f; 标准模板库&#xff08;Standard Template Library&#xff0c;STL&#xff09;是惠普实验室开发的一系列软件的统称。它是由Alexander Stepanov、Meng Lee和David R Musser在惠普实验室工作时所开发出来的。虽说它主要出现到C中&#xff0c;但在被引…