【PyTorch深度学习实践】03_反向传播

news2025/1/16 12:58:58

文章目录

    • 1.计算图
    • 2.反向传播
      • 2.1 链式求导法则
      • 2.2 反向传播过程
    • 3.Pytorch中前馈和反馈的计算
      • 3.1 Tensor
      • 3.2 代码演示

对于简单的模型,梯度变换可以用解析式表达进行手算,但是复杂模型(很多w,b)的损失函数需要挨个写解析式,非常不好计算(主要是求导)。因此,可以考虑使用某种算法,把整个网络看做一个计算图,在图上传播整个梯度。这种算法,就称为反向传播算法。

转载:梯度下降法是通用的优化算法,反向传播法是梯度下降法在深度神经网络上的具体实现方式。

1.计算图

单层
在这里插入图片描述

需要注意的是,神经网络的训练本质,就是对每层的w和b进行训练。

在这里插入图片描述每一层的结束都需要引入非线性的激活函数。
如果不加入激活函数,那么无论多少层,得到的结果都是线性的。

在这里插入图片描述

2.反向传播

2.1 链式求导法则

进行反向传播的关键就是链式求导。反向传播其实就是计算图中的梯度求解,通过链式求导得到L对x和w的导数(梯度),再根据更新规则进行更新。

链式求导的规则,非常形象的图:
在这里插入图片描述

2.2 反向传播过程

1.构建计算图(前馈)
在这里插入图片描述

2. 求输出关于x和w的梯度
在这里插入图片描述
3. 损失L关于输出z的偏导
在这里插入图片描述
4. 运用链式求导法则,求L关于x和w的偏导(反馈)

在这里插入图片描述
一个简单线性模型(仿射模型)的前馈+反馈过程
在这里插入图片描述

3.Pytorch中前馈和反馈的计算

3.1 Tensor

参考博客
Tensor本身是一个类,里面包含两个比较重要的成员data(比如权重值)和grad(损失函数对权重的导数)

3.2 代码演示

import torch
import matplotlib.pyplot as plt

x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

w = torch.Tensor([1.0])
w.requires_grad = True

def forward(x):      # 前馈过程
    return x * w            # 因为w是tensor,因此x会自动类型转换变成tensor,输出的结果也变为tensor

def loss(x, y):             # 损失函数
    y_pred = forward(x)
    return (y_pred - y) ** 2

loss_list = []
epoch_list = []

print("predict(before training)",4, forward(4).item())   # 因为数值是一维标量,所以可以直接用item取,不是标量(如向量,矩阵)得用data

for epoch in range(100):
    for x, y in zip(x_data, y_data):
        l = loss(x, y)  # 前馈:前馈并计算损失函数
        l.backward()  # 反馈:张量自带的成员函数,会自动反向传播算梯度,把计算连路上所有需要的梯度都求出来,算完会释放这个计算图,每次都会创建新的计算图
        print('\tgrad:', x, y, w.grad.item())
        w.data = w.data - 0.03 * w.grad.data   # 更新:.data得到的也是张量,但是只是数值改变的运算。不取data会构建计算图,占用内存
        
        w.grad.data.zero_()  # 把权重的梯度数据清零,不然后面几轮会累加计算

    epoch_list.append(epoch)
    loss_list.append(l.item())
    print("progress:", epoch, l.item())

print("predict(after training)", 4, forward(4).item())

plt.plot(epoch_list, loss_list)
plt.xlabel('epoch')
plt.ylabel('loss')
plt.show()

得到的图像如下:

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/151254.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【黑马】瑞吉外卖-Day03、04笔记

瑞吉外卖Day03、04 公共字段自动填充 使用MybatisPlus实现 问题分析 代码实现 Mybatis Plus公共字段自动填充,也就是在插入或者更新的时候为指定字段赋予指定的值,使用它的好处就是可以统一对这些字段进行处理,避免了重复代码。 实现步骤…

【学习】life long learning

文章目录life long learningLLL的难点评估二、LLL的三个解法1、Selective Synaptic Plasticity选择性突触可塑性为什么会有灾难性遗忘呢?GEM2、Additional Neural Resource Allocation额外的神经资源分配packNet&CPG3、memory replyCurriculum Learninglife lon…

SAP 字段仍作为视图字段在视图中使用 | 更改表结构重新生成 CDS View「实例」

错误信息 Field ZPDAUSER-ZUSERID is still being used as a view field in view ZV_PDA_USER视图 ZPDAUSER-ZUSERID 仍作为视图字段在视图 ZV_PDA_USER 使用 错误原因 当前表被 CDS View 引用,由 CDS View 生成的「视图」已占用当前表的相关字段然而生成的视图又…

实战5:基于 Pytorch 搭建 Faster-R-CNN 实现飞机目标检测(代码+数据)

任务描述: 通过一个飞机检测的案例来对目标检测的基本概念进行介绍并且实现一个简单的目标检测方法。数据集:使用从COCO数据集抽取的飞机数据集mini-airplane,数据集中的数据均为正常的图片。https://download.csdn.net/download/qq_38735017/87374251运行环境:操作系统:l…

Day4 基于XML的Spring应用

总结java依赖注入的方式set方法注入List、map和properties的注入通过构造方法注入ref是reference的缩写,需要引用其他bean的id,value用于注入普通属性值。自定义标签和其他标签的引用自定义标签beansbeanimportalias其他标签用于引用其他命名空间1 bean的…

sqli-labs 第八关 多命通关攻略(Python3 自动化实现布尔盲注)

sqli-labs 第八关 多命通关攻略(Python3 自动化实现布尔盲注)描述判断注入类型正常输入不正常输入错误输入爆破方式的可行性铺垫函数 IF()关于 MySQL 数据类型之间转换的小小礼物(仅部分)函数 ASCII()ASCII 表(可显示字…

火山引擎 DataTester:5 个优化思路,构建高性能 A/B 实验平台

导读:DataTester 是由火山引擎推出的 A/B 测试平台,覆盖推荐、广告、搜索、UI、产品功能等业务应用场景,提供从 A/B 实验设计、实验创建、指标计算、统计分析到最终评估上线等贯穿整个 A/B 实验生命周期的服务。DataTester 经过了字节跳动业务…

vivo 故障定位平台的探索与实践

作者:vivo 互联网服务器团队- Liu Xin、Yu Dan 本文基于故障定位项目的实践,围绕根因定位算法的原理进行展开介绍。鉴于算法有一定的复杂度,本文通过图文的方式进行说明,希望即使是不懂技术的同学也能理解。 一、背景介绍 1.1 程…

2023最新连锁店软件排名,国内十大连锁店管理软件新鲜出炉!

普通的数据工具、人工管理难以满足连锁店老板们的需求,正所谓“有需求就有市场”,随着连锁店、加盟店如雨后春笋般在城市里出现,连锁店软件也越来越多。究竟哪一款连锁店管理软件,才能满足老板们的需求?小编收集了国内…

9/365 java 数组 内存

1.数组 声明: int[] a;//首选 int a[];//一般不用 创建: int[] a new int[10]; // 需指定数组大小 初始化: 静态初始化: int[] a {8,9,10}; String[] s {new String("hello"), new String("world")…

南邮研究生考试历年真题知识点总结

下边的知识点是我在做南京邮电大学考研历年真题时遇到自己不会的题时整理出来的。第九部分是做mooc课后习题时整理出来的,希望对各位同学有所帮助。 md文档网址:https://gitee.com/infiniteStars/wang-dao-408-notes/blob/master/考研笔记/南邮数据结构知…

内存函数:学习笔记7

目录 一.前言 二. memcpy模拟实现 三. memmove模拟实现 四.memcmp模拟实现 一.前言 计算机内存的实质就是以字节为编号单元的二进制序列集合,操作内存时我们应具有这样的视角。 二. memcpy模拟实现 库函数memcpy函数首部:void *memcpy( void *dest, …

量子计算机“九章”

1.中国量子计算机“九章”实现量子霸权 2020年12月,中国科学技术大学宣布该校成功构建光子量子计算原型机“九章”。“九章”是中国科学技术大学潘建伟团队、中科院上海微系统所和国家并行计算机工程技术研究中心合作完成。“九章”的名字是来源于中国历史上最重要…

算法设计与分析-分支限界法习题

7-1 布线问题印刷电路板将布线区域划分成 nm 个方格阵列,要求确定连接方格阵列中的方格a 点到方格b 的最短布线方案。在布线时,电路只能沿直线布线,为了避免线路相交,已布了线的方格做了封锁标记,其他线路不允许穿过被…

大数据开发之利剑 -- TDengine

前言 在大数据技术全球爆炸的时代,以及大数据在各行各业的实际应用,大数据的快速发展就像计算机和互联网一样,很可能成为新一轮的技术革命。数据处理、机器学习、AI等新兴技术诞生,会改变数据世界的许多算法和理论基础&#xff0c…

DAY-1 | Java数据结构之链表:删除无头单链表中等于给定值 val 的所有节点

目录 一、题干 🔗力扣203. 移除链表元素 二、题解 1、思路 2、完整代码 一、题干 🔗力扣203. 移除链表元素 二、题解 1、思路 题干的意思是,要删除链表中所有指定的元素。最暴力的方法是,依次遍历链表中的各个节点&#x…

Day3 XML方式的Spring应用

全文总结基于XML配置1、学习了bean标签,2、三种配置bean的方式:1、静态工厂;2、实例工厂和3、自定义实现factorybean1 SpringBean 的配置类inin-method 与构造方法不同,构造方法是创建对象的,等对象创建以后使用inin-m…

2023年全国最新消防设施操作员精选真题及答案

百分百题库提供消防设施操作员考试试题、消防设施操作员考试预测题、消防设施操作员考试真题、消防设施操作员证考试题库等,提供在线做题刷题,在线模拟考试,助你考试轻松过关。 1、对外观目测判断,下列哪种情况不应报废?(  ) A、铭牌标志脱落 B、瓶…

pyqt5加载matplotlib图形

matplotlib的图形处理非常强大。今天花了很长时间才将matplotlib图形嵌入到pyqt5中。在这里记录一下,便于以后查寻。有些可能还理解不到位。开始要导入几个模块:from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvasFigur…

tomcat下载和配置(简单,详细)

下载 官网:http://tomcat.apache.org/ 找到需要的版本,点击download 在download页面,选择需要下载的。(分为压缩版和安装版,我比较推荐压缩版,省事解压缩就好) 配置 首先!&…