Back Propagation 反向传播

news2024/11/26 8:42:57

文章目录

    • 3、Back Propagation 反向传播
      • 3.1 引出算法
      • 3.2 非线性函数
      • 3.3 算法步骤
        • 3.3.1 例子
        • 3.3.2 作业1
        • 3.3.3 作业2
      • 3.4 Tensor in PyTorch
      • 3.5 PyTorch实现线性模型
      • 3.6 作业3

3、Back Propagation 反向传播

B站视频教程传送门:PyTorch深度学习实践 - 反向传播

3.1 引出算法

对于简单(单一)的权重 W ,我们可以将它封装到 Neuron(神经元)中,然后对权重W进行更新即可:

在这里插入图片描述

但是,往往实际情况中会涉及多个权重(例如下图:会有几十到上百个),我们几乎不可能做到写出所有的解析式:

在这里插入图片描述

矩阵论参考:The Matrix Cookbook - http://matrixcookbook.com

所以,引出反向传播(Back Propagation)算法:

在这里插入图片描述

当然,也可以将括号里面展开,如下如所示:

在这里插入图片描述

注意:不断地进行线性变化,不管有多少层,最后都会统一成一种形式。

所以:不能够化简(展开),因为增加的这些权重毫无意义。

3.2 非线性函数

所以我们需要引进非线性函数,每一层都需要一个非线性函数

在这里插入图片描述

3.3 算法步骤

在具体讲解反向传播之间,先回顾一下链式求导法则

在这里插入图片描述

(1)创建计算图表 Create Computational Graph (Forward)

在这里插入图片描述

(2)本地梯度 Local Gradient

在这里插入图片描述

(3)给出连续节点的梯度 Given gradient from successive node

在这里插入图片描述

(4)使用链式规则来计算梯度 Use chain rule to compute the gradient (Backward)

在这里插入图片描述

(5)Example: 𝑓 = 𝑥 ∙ 𝜔, 𝑥 = 2, 𝜔 = 3

在这里插入图片描述

在这里插入图片描述

3.3.1 例子

在这里插入图片描述

3.3.2 作业1

在这里插入图片描述

解答如下:

在这里插入图片描述

3.3.3 作业2

在这里插入图片描述

解答如下:

在这里插入图片描述

3.4 Tensor in PyTorch

在这里插入图片描述

3.5 PyTorch实现线性模型

import torch

x_data = [1.0, 2.0, 3.0, 4.0]
y_data = [2.0, 4.0, 6.0, 8.0]

w = torch.Tensor([1.0])
w.requires_grad = True  # 计算梯度


def forward(x):
    return x * w


def loss(x, y):
    y_pred = forward(x)
    return (y_pred - y) ** 2


print("predict (before training)", 4, forward(4).item())

for epoch in range(100):
    for x, y in zip(x_data, y_data):
        l = loss(x, y)
        l.backward()
        print('\tgrad:', x, y, w.grad.item())
        w.data -= 0.01 * w.grad.data

        w.grad.data.zero_()

    print("progress:", epoch, l.item())

print("predict (after training)", 4, forward(4).item())
predict (before training) 4 4.0
	grad: 1.0 2.0 -2.0
	grad: 2.0 4.0 -7.840000152587891
	grad: 3.0 6.0 -16.228801727294922
	grad: 4.0 8.0 -23.657981872558594
progress: 0 8.745314598083496
...
	grad: 1.0 2.0 -2.384185791015625e-07
	grad: 2.0 4.0 -9.5367431640625e-07
	grad: 3.0 6.0 -2.86102294921875e-06
	grad: 4.0 8.0 -3.814697265625e-06
progress: 99 2.2737367544323206e-13
predict (after training) 4 7.999999523162842

3.6 作业3

在这里插入图片描述

解答如下:

在这里插入图片描述

代码实现:

import torch

x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

w1 = torch.Tensor([1.0])  # 初始权值
w1.requires_grad = True  # 计算梯度,默认是不计算的
w2 = torch.Tensor([1.0])
w2.requires_grad = True
b = torch.Tensor([1.0])
b.requires_grad = True


def forward(x):
    return w1 * x ** 2 + w2 * x + b


def loss(x, y):  # 构建计算图
    y_pred = forward(x)
    return (y_pred - y) ** 2


print('Predict (before training)', 4, forward(4))

for epoch in range(100):
    l = loss(1, 2)  # 为了在for循环之前定义l,以便之后的输出,无实际意义
    for x, y in zip(x_data, y_data):
        l = loss(x, y)
        l.backward()
        print('\tgrad:', x, y, w1.grad.item(), w2.grad.item(), b.grad.item())
        w1.data = w1.data - 0.01 * w1.grad.data  # 注意这里的grad是一个tensor,所以要取他的data
        w2.data = w2.data - 0.01 * w2.grad.data
        b.data = b.data - 0.01 * b.grad.data
        w1.grad.data.zero_()  # 释放之前计算的梯度
        w2.grad.data.zero_()
        b.grad.data.zero_()
    print('Epoch:', epoch, l.item())

print('Predict (after training)', 4, forward(4).item())
Predict (before training) 4 tensor([21.], grad_fn=<AddBackward0>)
	grad: 1.0 2.0 2.0 2.0 2.0
	grad: 2.0 4.0 22.880001068115234 11.440000534057617 5.720000267028809
	grad: 3.0 6.0 77.04720306396484 25.682401657104492 8.560800552368164
Epoch: 0 18.321826934814453
...
	grad: 1.0 2.0 0.31661415100097656 0.31661415100097656 0.31661415100097656
	grad: 2.0 4.0 -1.7297439575195312 -0.8648719787597656 -0.4324359893798828
	grad: 3.0 6.0 1.4307546615600586 0.47691822052001953 0.15897274017333984
Epoch: 99 0.00631808303296566
Predict (after training) 4 8.544171333312988

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/92544.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

联通数科一面+二面+面谈 经验分享 base济南

联通数科一面二面面谈 10.8 投递简历&#xff08;大数据开发岗位 base西安 有成都岗&#xff1f; 我怎么没看到&#xff09; 10.10-12 笔试 11.05 一面 有五六个面试官 问了问题的有两个 介绍了下项目&#xff0c;问了些每个组件的基础知识&#xff0c;都是大数据的基本八股…

深聊性能测试,从入门到放弃之: Windows系统性能监控(三)任务管理器介绍及使用。

任务管理器1、引言2、任务管理器2.1 打开方式2.2 介绍2.2.1 定义2.2.2 进程2.2.3 性能2.2.4 应用历史记录2.2.5 启动2.2.6 用户2.2.7 详细信息2.2.8 服务3、总结1、引言 关于Windows系统性能监控的前两篇 《Windows系统性能监控(一) 性能监视器介绍及使用》《Windows系统性能…

深圳IB学校哪家强?入学标准如何?

我们都知道&#xff0c;孩子就读哪所学校&#xff0c;学校开设什么课程&#xff0c;这会直接影响孩子的留学之路和未来。 所以一般情况下&#xff0c;选择国际学校都是分两步走&#xff0c;先是需要是根据孩子的个性特点去选择能够适应的课程体系&#xff0c;再根据学校开设课程…

【论文阅读】inception v1学习总结

【论文阅读总结】inception v1总结1. 摘要2. 序言3. 文献综述4.动机和高层考虑4.1提高深度神经网络性能的最直接方法4.1.1 增加模型的大小4.1.2 解决增加模型大小导致的缺点思路5.结构详述5.1 Inception架构的主要思想5.2 原生inception块问题5.3 解决通道数增加问题5.4 1*1卷积…

Golang 【basic_leaming】基本数据类型

阅读目录Golang 数据类型介绍整型特殊整型unsafe.Sizeof 查看内存所占用大小int 不同长度直接的转换数字字面量语法&#xff08;Number literals syntax&#xff09;&#xff08;了解&#xff09;浮点型布尔值字符串字符串转义符多行字符串字符串的常用操作修改字符串byte 和 r…

PYNQ -z2 与 PC主板网口直连

文章目录1.下载映像文件并烧录到板子2. 将usb和网线连接到电脑上3. 使用xshell新建串口通信连接到板子4. 查看板子ip ifconfig命令5. 将pc的网络改成与板子同一个网段6. 通过ip地址访问1.下载映像文件并烧录到板子 可以参考 烧录镜像 2. 将usb和网线连接到电脑上 3. 使用xshe…

jsp+ssm计算机毕业设计毕业论文管理系统【附源码】

项目运行 环境配置&#xff1a; Jdk1.8 Tomcat7.0 Mysql HBuilderX&#xff08;Webstorm也行&#xff09; Eclispe&#xff08;IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持&#xff09;。 项目技术&#xff1a; JSPSSM mybatis Maven等等组成&#xff0c;B/S模式 Mave…

【函数】你上街买菜用的着函数吗?

前言 函数是什么 每个C程序都至少有一个函数&#xff0c;即main主函数 &#xff0c;如果程序的任务比较简单&#xff0c;全部的代码都写在main函数中&#xff0c;但是在实际开发中&#xff0c;程序的任务往往比较复杂&#xff0c;如果全部的代码都写在main函数中&#xff0c;…

Nginx配置整合:基本概念、命令、反向代理、负载均衡、动静分离、高可用

一、基本概念 1.什么是Nginx Nginx是一个高性能的HTTP和反向代理服务器&#xff0c;也是一个IMAP/POP3/SMTP代理server。其特点是占有内存少。并发能力强&#xff0c;其并发能力确实在同类型的网页server中表现较好。 http服务器 Web服务器是指驻留于因特网上某种类型计算机的程…

热门的Java开源项目

1 JCSprout https://github.com/crossoverJie/JCSprout Star 17084 这是一个还处于萌芽阶段的 Java 核心知识库。分为常用集合、Java多线程、JVM、分布式相关、常用框架等内容 2 arthas https://github.com/alibaba/arthas Star 6836 Arthas旨在帮助开发人员解决Java应用程…

高级网络应用复习——三层交换DHCP中继(带命令)

作者简介&#xff1a;一名在校云计算网络运维学生、每天分享网络运维的学习经验、和学习笔记。 座右铭&#xff1a;低头赶路&#xff0c;敬事如仪 个人主页&#xff1a;网络豆的主页​​​​​​ 目录 前言 一.知识点总结 二.DHCP中继实验 实验要求 实验命令 三层交换…

腾讯会议人数上限进不去?

很多用户都在使用腾讯会议来进行线上会议&#xff0c;因此经常会出现人数到达上限进不去的情况&#xff0c;非常令人头疼&#xff0c;那这时候要怎么办呢&#xff1f;下面就来看看解决办法。 腾讯会议人数上限进不去怎么办&#xff1f; 答&#xff1a;只能等待有人退出后再加入…

数据结构入门-单调队列

数据结构入门-单调队列 原理介绍 双向队列 思考一下&#xff1a;对于数组nums&#xff0c;我们想知道max(nums[i-k],...,nums[i])如何高效处理&#xff1f; 单调队列 单调队列&#xff0c;即从队首到队尾单调的队列。 #mermaid-svg-6PhVAHBib6ohdcIC {font-family:"tre…

从mask rcnn到mask scoring rcnn

mrcnn (mask rcnn) 不足:框架没有解决实例分割评分 对mask head输出的K(类别)个mask,选择哪个mask作为最终的输出,取决于分类支路置信度最高的类别。也就是用分类置信度来衡量mask质量,这会导致下图的现象: 左侧四幅图像显示出良好的检测结果,具有高分类分数但掩模质…

【spring系列】SPI详解

1.什么是SPI SPI全称Service Provider Interface&#xff0c;是Java提供的一套用来被第三方实现或者扩展的接口&#xff0c;它可以用来启用框架扩展和替换组件。 SPI的作用就是为这些被扩展的API寻找服务实现。2.SPI和API的使用场景 API &#xff08;Application Programming …

大数据Kudu(七):Kudu分区策略

文章目录 Kudu分区策略 一、​​​​​​​Partition By Range - 范围分区

最通俗易懂的 JAVA slf4j,log4j,log4j2,logback 关系与区别以及完整集成案例

最近在工作中&#xff0c;发现接触到的很多小伙伴分不清楚logback slf4j 以及log4j 的关系&#xff0c;有的人认为是一个东西&#xff0c;有的人认为是完全没关系&#xff0c;或者说有关系但是不清楚具体是什么区别和联系&#xff0c;今天咱们就简单梳理下他们之间的联系和区别…

项目式学习法(PBL)如何让你快速成为行业专家【一杯咖啡谈项目】

项目人人都是主角&#xff0c;没有旁观者。我们每个人也应当好PM&#xff0c;这就离不开学习提升自己&#xff0c;&#xff0c;如此&#xff0c;方能更好推动经济社会高质量发展。 1、项目式学习是什么&#xff1f; 关于项目式学习&#xff0c;目前国内外还没有个统一的定义&…

【python】 json字符串转对象

目录 一&#xff1a;json对象转换为json字符串 二&#xff1a;json字符串转换为json对象 三&#xff1a;json字符串{"name":"lily","sno":1001} 四&#xff1a;python面向对象程序设计 一&#xff1a;json对象转换为json字符串 import json…

Elasticsearch 安装及启动【Linux】

一、下载安装包 1.下载 Elasticsearch 官网下载地址&#xff1a;https://www.elastic.co/cn/downloads/past-releases#elasticsearch 2.下载 Kibana Kibana 数据可视化平台可以选择性安装 官网下载地址&#xff1a;https://www.elastic.co/cn/downloads/past-releases#kiban…