[PyTorch][chapter 8][李宏毅深度学习][Back propagation]

news2024/9/24 17:12:39

前言:

              反向传播算法(英:Backpropagation algorithm,简称:BP算法)是一种监督学习算法,常被用来训练多层感知机。 它用于计算梯度计算中,降低误差。

      

目录:

  1.     链式法则
  2.     模型简介(Model)
  3.     损失函数,梯度
  4.     手写例子
  5.     min-batch

一  链式法则

      链式法则是反向传播算法里面的核心。

     case1: y=g(x),z=h(y), x,y,z 都是scalar

                       

                     \frac{dz }{dx}=\frac{dz }{dy}\frac{dy }{dx}        

      case2:  x=g(s),y=h(s),z=k(x,y),s,x,y,z 都是scalar

                   

                       \frac{dz}{ds}=\frac{dz}{dy}\frac{dy}{ds}+\frac{dz}{dx}\frac{dx}{ds}

      case3:   x,y,z 都是向量vector

                   x\rightarrow y\rightarrow z

                    \frac{dz }{dx}=\frac{dz }{dy}\frac{dy }{dx}


二  模型(Model)

以常用的网络模型DNN 为例:

 激活函数为 \sigma

 总的层数为 L


三    损失函数,梯度

       3.1 损失函数

           J(w,b)=||a^{L}-y||_2^{2}

       3.2 梯度更新

               梯度计算分为两步:

   Forward pass, Backward pass

         a Forward pass

               假设 \delta^{l}=\frac{\partial J}{\partial z^l}:

            利用微分和迹的关系很容易得到

         

          b  Backward pass  

               假设为最后一层L

                 \delta^{L}=(\frac{\partial a^L}{\partial z^L})^T\frac{\partial J}{\partial a^L}

                       =diag(\sigma^{'}(z^{L}))(a^{L}-\hat{y})

                      =(a^{L}-\hat{y})\odot \sigma{'}(z^{L})

            我们用数学归纳法,第L层的\delta^{L}已经求出, 假设第l+1层的\delta^{l+1}已经求出来了,那么我们如何求出第l层的\delta^{l}呢?

                \delta^{l}=\frac{\partial J}{\partial z^{l}}

                    =(\frac{\partial z^{l+1}}{\partial z^{l}})^T\frac{\partial J}{\partial z^{l+1}}

                    =(\frac{\partial z^{l+1}}{\partial a^l}\frac{\partial a^{l}}{\partial z^l})^T \delta^{l+1}

                    =(diag(\sigma^{'}(z^l)(w^{l+1})^T)\delta^{l+1}

                    =(w^{l+1})^T\delta^{t+1}\odot \sigma^{'}(z^l)


四   简单DNN 网络例子

 4.1 说明:

          这里面随机生成5张图形,分别对应手写数字1,2,3,4,5。

简单的了解一下如何快速搭建一个DNN Model, 梯度如何计算,更新的.

 

# -*- coding: utf-8 -*-
"""
Created on Fri Dec 15 17:21:35 2023

@author: chengxf2
"""

import torch 
from torch import nn
from torch import optim


class DNN(nn.Module):
    
    '''
    它是一个序列容器,是nn.Module的子类。 
    `nn.Sequential` 中的层是有顺序的,而且严格按照其顺序执行
    相邻两个层连接必须保证前一个层的输出与后一个层的输入相匹配。
    '''
    def __init__(self):
        
        super(DNN, self).__init__()
        
        self.net = nn.Sequential(
            nn.Linear(in_features=28*28, out_features=500),
            nn.Sigmoid(),
            nn.Linear(in_features=500, out_features=10),
            nn.Sigmoid()
            )

    def forward(self, input):
        
        output = self.net(input)
        
        return output


def train():
    
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    
    model = DNN()
    criteon = torch.nn.CrossEntropyLoss(reduction='mean')
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    batch_size= 5
    data = torch.rand((batch_size,28*28))
    epochs = 2
    target = torch.tensor([0,1,2,3,4])
    target = target.to(device)
    
    for epoch in range(epochs):
        
        yHat = model(data)
        loss = criteon(yHat, target)
        loss.backward()
        print("\n loss ",loss)
        
        optimizer.step()
        

if __name__ == "__main__":
    train()
    
    
    

 


五  min-batch

  在深度学习训练中,数据集我们通常采用min-batch 方案

    我们采用随机梯度方法,是为了加快运算速度。

但是GPU 可以并行运算,所以可以采用min-batch 方法进行梯度计算。

   使用min-batch 有个限制:

    1: 硬件限制 batch 不能超过硬件大小

    2:    batch 不能太大,否则容易陷入到局部极小值点,采用小的batch 可以有一定的随机性

每次出发点都不一样,一定概率跳过局部极小值点

参考:

7: Backpropagation_哔哩哔哩_bilibili

https://www.cnblogs.com/pinard/p/6422831.html

CSDN

8-1: “Hello world” of deep learning_哔哩哔哩_bilibili

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1324493.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

godot 报错Unable to initialize Vulkan video driver解决

版本 godot 4.2.1 现象 godot4.2.1 默认使用vulkan驱动,如果再不支持vulkan驱动的主机上,进入引擎编辑器将报错如下 解决 启动参数添加 –rendering-driver opengl3 即可进入引擎编辑器 此时运行项目仍然会报错无法初始化驱动 在项目设置中配置编…

Apache Tomcat httpoxy 安全漏洞 CVE-2016-5388 已亲自复现

Apache Tomcat httpoxy 安全漏洞 CVE-2016-5388 已亲自复现 漏洞名称漏洞描述影响版本 漏洞复现环境搭建漏洞利用修复建议 总结 漏洞名称 漏洞描述 在Apache Tomcat中发现了一个被归类为关键的漏洞,该漏洞在8.5.4(Application Server Soft ware)以下。受影响的是组…

Python---IP 地址的介绍

1. IP 地址的概念 IP 地址就是标识网络中设备的一个地址,好比现实生活中的家庭地址。 网络中的设备效果图: 2. IP 地址的表现形式 说明: IP 地址分为两类: IPv4 和 IPv6 IPv4 是目前使用的ip地址 IPv6 是未来使用的ip地址 IPv4 是由点分十进制组成 …

跟着我学Python进阶篇:01.试用Python完成一些简单问题

往期文章 跟着我学Python基础篇:01.初露端倪 跟着我学Python基础篇:02.数字与字符串编程 跟着我学Python基础篇:03.选择结构 跟着我学Python基础篇:04.循环 跟着我学Python基础篇:05.函数 跟着我学Python基础篇&#…

类和对象(中篇)

类的六个默认成员函数 如果一个类中什么成员都没有,简称为空类。 空类中真的什么都没有吗?并不是,任何类在什么都不写时,编译器会自动生成以下6个默认成员函数。 默认成员函数: 用户没有显式实现,编译器会…

Linux上随机输出谚语的程序fortune

概要: Linux上有一个随机输出谚语的程序叫fortune 手册对它的描述是:输出一个随机的、充满希望的、有趣的谚语 本篇所用的系统是Ubuntu22.04 一、fortune的安装 sudo apt install fortune-mod 二、fortune的使用 1、示例一 这个谚语是什么意思啊…

[DNS网络] 网页无法打开、显示不全、加载卡顿缓慢 | 解决方案

[网络故障] 网页无法打开、显示不全、加载卡顿缓慢 | 解决方案 问题描述 最近,我在使用CSDN插件浏览 MOOC 网站时,遇到了一些网络故障。具体表现为: MOOC 中国大学慕课网:www.icourse163.org点击CSDN插件首页的 MOOC&#xff08…

SLAM算法与工程实践——RTKLIB编译

SLAM算法与工程实践系列文章 下面是SLAM算法与工程实践系列文章的总链接,本人发表这个系列的文章链接均收录于此 SLAM算法与工程实践系列文章链接 下面是专栏地址: SLAM算法与工程实践系列专栏 文章目录 SLAM算法与工程实践系列文章SLAM算法与工程实践…

linux 内核的 lru_list 的结构

在linux的slab分配的入口slab_alloc有一个传入参数lru,它的作用是使每个slab对象在unused,但可能后面继续使用的时候,不需要free,可以先放在lru_list上。lru_list的结构为: struct list_lru {struct list_lru_node *n…

网工内推 | 上市公司中级网工,思科、华为认证优先,有带薪年假

01 新晨科技 招聘岗位:中级网络工程师 职责描述: 1. 负责公司网络系统的规划、设计、实施、维护和优化; 2. 负责网络设备的选型、采购、安装、配置和调试; 3. 负责网络安全策略的制定和实施,保障公司网络安全&#xf…

fastjson1.2.24 反序列化漏洞(CVE-2017-18349)分析

FastJson在< 1.2.24 版本中存在反序列化漏洞&#xff0c;主要原因FastJson支持的两个特性&#xff1a; fastjson反序列化时&#xff0c;JSON字符串中的type字段&#xff0c;用来表明指定反序列化的目标恶意对象类。fastjson反序列化时&#xff0c;字符串时会自动调用恶意对…

国标28181平台只能连接视频监控吗?

在一些视频监控项目中&#xff0c;国标28181平台成为了必不可少的工具。这个平台的主要作用在于将分布在不同区域的视频监控录像机、摄像头等设备进行联网管理&#xff0c;同时还能将视频监控连接到上一级的国标监控平台。 可以说&#xff0c;国标监控平台是一个非常重要的承上…

原子学习笔记3——使用tslib库

一、tslib介绍 tslib 是专门为触摸屏设备所开发的 Linux 应用层函数库&#xff0c;并且是开源。 tslib 为触摸屏驱动和应用层之间的适配层&#xff0c;它把应用程序中读取触摸屏 struct input_event 类型数据&#xff08;这是输入设备上报给应用层的原始数据&#xff09;并进行…

动力电池系统介绍(十四)——热管理系统

动力电池系统介绍&#xff08;十四&#xff09; 一、梗概二、座舱热管理&#xff08;汽车空调&#xff09;2.1 空调制冷2.2 空调制热2.2.1 传统燃油汽车空调制热2.2.2 新能源汽车空调制热 三、动力系统热管理3.1 燃油车发动机热管理3.1.1 冷却系统3.1.2 润滑系统3.1.3 进排气系…

网络编程day5

作业 1> 使用select完成TCP客户端程序 //client #include<myhead.h> #define CLINET_IP "192.168.125.79" #define CLINET_PORT 9999 #define SERVE_IP "192.168.125.79" #define SERVE_PORT 8888 int main(int argc, const char *argv[]) {/…

音视频直播核心技术介绍

直播流程 采集&#xff1a; 是视频直播开始的第一个环节&#xff0c;用户可以通过不同的终端采集视频&#xff0c;比如 iOS、Android、Mac、Windows 等。 前处理&#xff1a;主要就是美颜美型技术&#xff0c;以及还有加水印、模糊、去噪、滤镜等图像处理技术等等。 编码&#…

文件消失但是有占用内存的恢复方法

文件消失但占用内存是一个常见的问题&#xff0c;通常是由于文件系统错误或病毒攻击引起的。在这种情况下&#xff0c;文件虽然从目录结构中消失&#xff0c;但它们仍然占用存储空间。本文将分析这一问题的原因&#xff0c;并探讨解决该问题的几种方法。 文件消失但占用内存的原…

Text2SQL学习整理(四)将预训练语言模型引入WikiSQL任务

导语 上篇博客&#xff1a;Text2SQL学习整理&#xff08;三&#xff09;&#xff1a;SQLNet与TypeSQL模型简要介绍了WikiSQL数据集提出后两个早期的baseline&#xff0c;那时候像BERT之类的预训练语言模型还未在各种NLP任务中广泛应用&#xff0c;因而作者基本都是使用Bi-LSTM…

头部首发优志愿头部u_sign生成与TLS指纹处理! + 数据可视化技术讲解【Python爬虫】

目录 针对大学名称 大学排名, 综合指数,学校情况等数据进行爬取 找对应得数据包 请求发现数据有加密 发现加密参数 搜索加密参数&#xff0c;好进行分析 分析过程 数据可视化 针对大学名称 大学排名, 综合指数,学校情况等数据进行爬取 首先进行鼠标右键&#xff0c;进行…

制作PPT找了一个校徽是方形的,如何裁剪为圆形的。

问题描述&#xff1a;制作PPT找了一个校徽是方形的&#xff0c;如何裁剪为圆形的。 问题解决&#xff1a;使用一个在线圆形裁剪软件即可。 网址为&#xff1a; https://crop-circle.imageonline.co/cn/#google_vignette