P18 PyTorch 感知机的梯度推导

news2024/12/23 13:12:30

前言

这里面简单介绍一下单层感知机和多层感知机的模型

参考:

https://www.bilibili.com/video/BV17e4y1q7NG?p=41


一 单层感知机模型

输入

: k 代表网络层数,i 代表输入节点的编号

前向传播

: 权重系数

k: 层数

i: 前一层输入节点编号

j: 当前层输出节点编号

这里: k=1, j =0

损失函数

t: 标签值

Backward: 反向传播更新梯度

# -*- coding: utf-8 -*-
"""
Created on Fri Feb 17 16:56:58 2023

@author: chengxf2
"""

import torch
import torch.nn.functional as F

def forward(w,x):
    print("w 的梯度",w.grad)
    if w.grad is not None: 
        w.grad.data.zero_()
        print("\n 初始化的w ",w)

    a = torch.matmul(x,w.T)
    o= torch.sigmoid(a)
    
    loss = F.mse_loss(o, target)
    print("\n loss ",loss)
    return loss

def  backward(loss, w):
    
    loss.backward()
    print("\n 权重系数 ",w.data, "\n 权重系数梯度 ",w.grad,"\n 权限系数梯度", w.grad.data)
    

    
target = torch.ones(1,1) 
x = torch.randn(1,3)
w = torch.randn(1,3, requires_grad=True)
 
loss = forward(w,x)
backward(loss,w)

权重系数,以及权重系数的梯度,都是放在data 里面


二 多层感知机

2.1 模型

从上面可以看出梯度消失,梯度弥散的原因

当经过多个这样的神经元梯度就会消失了

2.2 矩阵推导

另一种更普遍的方式利用矩阵方式推导

输入

Forward:

损失函数

利用迹和梯度的关系

所以

利用链式求导法则

(点乘 对应 multiply)

所以


# -*- coding: utf-8 -*-
"""
Created on Mon Feb 13 21:28:26 2023

@author: cxf
"""

import torch
import torch.nn.functional as F


def matGrad(z,o,x):
    print("\n o",o)
    a = torch.multiply(o, 1.0-o)
    print("\n a ",a)
    
    b = torch.multiply(a,z)
    
    c = torch.matmul(b,x.T)
    print("\n out \n",c/2.0)

def grad():
    
    x = torch.randn(3,1)
    w = torch.randn(4,3,requires_grad=True)
    a = torch.matmul(w, x)
    o = torch.sigmoid(a)
    
    t = torch.ones(4,1)
    
    loss = F.mse_loss(o, t)
    
    z = o-t
    matGrad(z,o,x)
    loss.backward()
    
    print("\n w grad \n ",w.grad)
    
    

if __name__ == "__main__":
    
    grad()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/359054.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Python opencv进行圆形识别(圆检测)

圆形识别(圆检测)是图像识别中很常见的一种处理方式,最核心的是cv2.HoughCircles这个函数实现的圆形检测。当然还有一些其他的处理过程,以下详述: 1 读入图像 首先需要读取一个图像文件,将其作为一个变量…

Java 字符串

文章目录一、API二、String1. String 构造方法2. String 对象的特点3. 字符串的比较4. 用户登录案例5. 遍历字符串6. 统计字符次数7. 拼接字符串8. 字符串反转三、StringBuilder1. 构造方法2. 添加及反转方法3. 与 String 相互转换4. 拼接字符串升级版5. 字符串反转升级版一、A…

【Java】Spring核心与设计思想

文章目录Spring核心与设计思想1. Spring是什么1.1 什么是容器1.2 什么是IOC1.2.1 传统程序开发1.2.2 控制反转式程序开发1.2.3 对比总结规律1.3 理解Spring IOC1.4 DI概念说明Spring核心与设计思想 1. Spring是什么 我们通常所说的Spring指的是Spring Framework(S…

工业4.0和工业物联网如何协同工作

虽然许多公司已经接受了工业物联网,但他们现在必须接受工业4.0对数据驱动的数字化转型的承诺。随着制造业、能源、公用事业和供应链应用迅速采用工业物联网(IIoT),这些行业的新现实正在形成。工业物联网提供了企业管理数千个活动部件所需的数据类型&…

二、最基本的vuex的使用

二、最基本的vuex的使用: 学习任何技术,先找到没有用这个技术时,给我们带来了什么麻烦 而这个新技术是怎么帮我们解决这些问题的。 理解方式: state:装数据的一个对象 mutations:装方法的一个对象&#…

FPGA纯Verilog实现任意尺寸图像缩放,串口指令控制切换,贴近真实项目,提供工程源码和技术支持

目录1、前言2、目前主流的FPGA图像缩放方案3、本方案的优越性4、详细设计方案5、vivado工程详解6、上板调试验证并演示7、福利:工程源码获取1、前言 代码使用纯verilog实现,没有任何ip,可在Xilinx、Intel、国产FPGA间任意移植; 图…

steam搬砖信息差项目,新手1周拿到结果!

项目具体是什么呢? 项目简单概括通过选品软件自动分析出此商品国内外商品价格,计算出利润率,选择出有利润销量好的商品,在以最低价格上架到国内buff的平台里,既能快速的卖出,还能获利。 主要利润在于商品…

力扣刷题|216.组合总和 III、17.电话号码的字母组合

文章目录LeetCode 216.组合总和题目链接🔗思路LeetCode 17.电话号码的字母组合题目链接🔗思路LeetCode 216.组合总和 题目链接🔗 LeetCode 216.组合总和 思路 本题就是在[1,2,3,4,5,6,7,8,9]这个集合中找到和为n的k个数的组合。 相对于7…

2 月 25 日,论道京城 | 云原生开源项目应用实践报名开启

在数字化转型的浪潮中,云原生已经逐渐成为人们关注的焦点。开源社区作为云原生技术创新的根据地,为云原生的产业发展打造了丰富的技术生态圈,也在广泛的实践中源源不断地创造着新的机遇。想知道云原生存储技术实现了怎样的突破吗?…

51单片机开发环境搭建 - VS Code 从编写到烧录

我安装并测试成功的环境: 操作系统:Windows 10 (22H2)单片机:STC89C52RCPython version: 3.7.6 在这之前,给51单片机写程序是用 Keil 5(编写编译)、STC-ISP(烧录),由于…

第六章.卷积神经网络(CNN)—卷积层(Convolution)池化层(Pooling)

第六章.卷积神经网络(CNN) 6.1 卷积层(Convolution)&池化层(Pooling) 1.整体结构 以5层神经网络的实现为例: 1).基于全连接层(Affine)的网络 全连接层:相邻层的所有神经元之间都有连接 2).常见的CNN的网络 3).全连接层存在的问题 数据的形状容易被…

VSCode Remote-SSH配置免密登录踩坑

VSCode Remote-SSH配置免密登录踩坑1. 参考2. 基本流程2.1 机器A(Windows客户端)2.2 机器B(Linux服务器)2.3 机器A(Windows客户端)的VSCode设置3. 踩坑总结相关教程很多,但要么冗余,…

Teradata退出中国,您可以相信中国数据库!

继Adobe、Tableau、Salesforce之后,2023年2月15日,数仓软件巨头Teradata宣布将逐步结束在中国的直接运营。数仓界的“黄埔军校”仓皇撤出中国市场给出的理由非常含蓄:Teradata对中国当前和未来商业环境的慎重评估,我们做了一个艰难…

阅读笔记5——深度可分离卷积

一、标准卷积 标准卷积在卷积时,同时考虑了特征图的区域和通道信息。 标准卷积的过程如图1-1所示,假设输入特征图的channel3,则每个卷积核的channel都为3,每个卷积核的3个channel对应提取输入特征图的3个channel的特征&#xff08…

抖音怎样报白?报白需要审核哪些资料呢

抖音怎样报白?报白需要审核哪些资料呢 抖音报白需要什么资料,翡翠原石产品如何开通报白#报白#小店报白#小店运营#抖音#抖音小店运营 文/专栏作家百收 随着抖音在国内流行起来,抖音上每天会有大量的视频更新,越来越多的年轻人也加…

Java数据结构中链表分割及链表排序使用快速排序、归并排序、集合排序、迭代、递归,刷题的重点总结

本篇主要介绍在单链表进行分割,单链表进行分隔并使用快速排序、归并排序、集合排序、迭代、递归等方法的总结,愿各位大佬喜欢~~ 86. 分隔链表 - 力扣(LeetCode) 148. 排序链表 - 力扣(LeetCode) 目录 一…

CAS概述

目录一、CAS与原子类1.1 CAS1.2 乐观锁与悲观锁1.3 原子操作类二、 synchronized优化2.1 轻量级锁2.2 轻量级锁-无竞争2.3 轻量级锁-锁膨胀2.4 重量级锁-自旋2.5 偏向锁2.6 synchronized-其他优化一、CAS与原子类 1.1 CAS CAS(一种不断尝试)即Compare …

2023年正在使用的设计资源网站分享

这篇文章,也将整理出我今年一直都在使用的设计资源网站!作为设计师一定是离不开优质的资源网站的,我自己的话会每天都花一两个小时的时间去浏览自己的收藏的这些资源网站。哪怕只是简单的浏览,也可以在无形中增加自己对设计的“设…

rocketmq延时消息自定义配置

概述 使用的是开源版本的rocketmq4.9.4 rocketmq也是支持延时消息的。 rocketmq一般是4个部分: nameserver:保存路由信息broker:保存消息生产者:生产消息消费者:消费消息 延时消息的处理是在其中的broker中。 但是…

华为认证含金量如何?

一本证书是否有用,还要看它是否被市场所认可。 我们说华为认证HCIP有用,很大一部分还取决于它极高的适用性和权威性。华为是国内最大的生产销售通信设备的民营通信科技公司。 自2013年起,国家对网络安全极度重视,相继把国外的网…