[Pytorch] 前向传播和反向传播示例

news2025/3/10 5:35:27

目录

简介

神经网络训练基本步骤

1. 计算图

2. 前向传播 Forward

3. 计算损失Loss 【损失函数】

4. 反向传播 Backward

5. 使用学习率更新权重【优化器】

样例代码

样例结果

样例图解


简介

PyTorch是一个基于Torch的Python开源机器学习库,用于自然语言处理等应用程序。Pytorch提供了两个高级功能:

  1. 具有强大的GPU加速的张量计算(Numpy的替代品)
  2. 包含自动求导系统的深度神经网络

神经网络训练基本步骤

1. 计算图

组成:由节点和边组成,节点分为Tensor和Function(运算)

  • Tensor分为叶子节点和非叶节点
  • Pytorch计算图是动态图

2. 前向传播 Forward

操作:根据输入数据进行推测。创建Fucntion后可以立即执行,不需要等到计算图定义好之后再执行。

3. 计算损失Loss 【损失函数】

操作:计算前向推测结果与真实值之间的误差

4. 反向传播 Backward

操作:将Loss向输入侧进行反向传播,对所有需要进行梯度计算的所有变量 leaf node Tensor x(requires_grad=True),计算梯度 dLoss/dx,并将其积累到梯度x.grad中备用, 即:x.grad = x.grad + dLoss/dx

5. 使用学习率更新权重【优化器】

操作:使用优化器对x的值进行更新。优化器会根据用户设置的学习率以及x.grad来更新x。

如:随机梯度下降SGD,x = x - learning_rate * x.grad

样例代码

def test_training_pipeline():    
    # ============================================================================ 1.创建计算图
    # ============================================================================ 2.前向传播(即时计算)
    input_data = [[4, 4, 4, 4],
                  [9, 9, 9, 9]]  # 2x4
    input = torch.tensor(input_data, dtype=torch.float32, requires_grad=True)
    output = torch.sqrt(input)
    print("\n### 前向传播推测结果:\n", output)

    # ============================================================================ 3.计算Loss
    target_data = [1, 2, 3, 4]
    target = torch.tensor(target_data, dtype=torch.float32)
    
    loss_fn = torch.nn.MSELoss()
    loss = loss_fn(input=output, target=target)
    print("\n### loss:\n", loss)
       
    # ============================================================================ 4.反向传播
    loss.backward()
    print("\n### input_grad:\n", input.grad)
    
    # ============================================================================ 5.更新input
    optim = torch.optim.SGD([input], lr=0.001)
    print("\n### input before optim.step():\n", input)
    optim.step()
    print("\n### input after optim.step():\n", input)

样例结果

样例图解

图解和手动计算前向传播和反向传播。

参考

理解Pytorch的loss.backward()和optimizer.step() - 知乎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/386280.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

四、发布确认

1、发布确认原理 生产者将信道设置成 confirm 模式,一旦信道进入 confirm 模式,所有在该信道上面发布的消息都将会被指派一个唯一的 ID(从 1 开始),一旦消息被投递到所有匹配的队列之后,broker就会发送一个确认给生产者(包含消息…

某小公司面试记录

记录一次面试过程,还有一些笔试题,挺简单的,排序,去重,this指向,深浅拷贝,微任务的执行顺序,变量提升等。 ES6数组新增的方法 Array.from: 将两类对象转为真正的数组&am…

微信又变天!

大家好,我是良许。 不知道大家有没发现,过去两周,微信又双叒改版了! 这个改版,喜欢看公众号的小伙伴可能会不习惯,作为公众号的作者更为难受,用一个变天来形容都不为过。 微信又搞啥幺蛾子呢…

软件测试---测试分类

一 : 按测试对象划分 1.1 可靠性测试 可靠性(Availability)即可用性,是指系统正常运行的能力或者程度,一般用正常向用户提供软件服务的时间占总时间的百分比表示。 1.2 容错性测试 行李箱 , 四个轮子 , 坏了一个 , 说明这个容错…

如何在香港BGP服务器上进行安全性和隐私性配置?

​  香港BGP服务器是在香港运营的,它是基于BGP多线路的网络拓扑所构建的服务器,主要面向于中国内地和海外地域。香港BGP服务器庞大的市场扩张,引来了国内外企业的眼光。然而,如果想要确保香港BGP服务器上的数据安全可靠&#xf…

Tapdata Cloud 基础课:新功能详解之「微信告警」,更及时的告警通知渠道

【前言】作为中国的 “Fivetran/Airbyte”, Tapdata 是一个以低延迟数据移动为核心优势构建的现代数据平台,内置 60 数据连接器,拥有稳定的实时采集和传输能力、秒级响应的数据实时计算能力、稳定易用的数据实时服务能力,以及低代码可视化操作…

MFC界面控件BCGControlBar v33.4 - 支持Win 11 Mica material主题

BCGControlBar库拥有500多个经过全面设计、测试和充分记录的MFC扩展类。 我们的组件可以轻松地集成到您的应用程序中,并为您节省数百个开发和调试时间。BCGControlBar专业版和BCGSuite for MFC v33.4已正式发布了,该版本包含了对Windows 11 Mica materia…

小Redis:开源一款迷你C++17 KV内存型数据库

A KV high-performance mini-database based on memory and C17 This project is inspired by Redis source code. 部分模仿Redis源码。 https://github.com/ZYunfeii/MiniKV Command line tools Developed command line tool kvctl. value type:string yunfeiubuntu:~/Min…

JavaScript函数之prototype原型和原型链

文章目录1. 原型2. 显式和隐式原型3. 原型链3.1 访问顺序4. instanceof4.1 如何判断1. 原型 函数的prototype属性 每个函数都有一个prototype属性,它默认指向一个Object空对象(即:原型对象)。原型对象中有一个属性constructor&a…

【C++从入门到放弃】类和对象(中)———类的六大默认成员函数

🧑‍💻作者: 情话0.0 📝专栏:《C从入门到放弃》 👦个人简介:一名双非编程菜鸟,在这里分享自己的编程学习笔记,欢迎大家的指正与点赞,谢谢! 类和对…

Python | 蓝桥杯进阶第一卷——字符串

欢迎交流学习~~ 专栏: 蓝桥杯Python组刷题日寄 蓝桥杯进阶系列: 🏆 Python | 蓝桥杯进阶第一卷——字符串 🔎 Python | 蓝桥杯进阶第二卷——递归(待续) 💝 Python | 蓝桥杯进阶第三卷——动态…

论文阅读-End-to-End Open-Domain Question Answering with BERTserini

论文链接:https://aclanthology.org/N19-4013.pdf 目录 摘要 1 简介 2 背景及相关工作 3 系统架构 3.1 Anserini Retriever 3.2 BERT 阅读器 4 实验结果 5演示 6结论 摘要 我们展示了一个端到端的问答系统,它将 BERT 与开源 Anserini 信息检索…

MSYS2安装

最近在学习windows上编译FFmpeg,需要用到msys2,在此记录一下安装和配置过程。 点击如下链接,下载安装包: Index of /msys2/distrib/x86_64/ | 清华大学开源软件镜像站 | Tsinghua Open Source Mirror 我下载的是:ms…

相信人还是相信ChatGPT,龙测首席AI专家给出了意料之外的答案

最近,关于ChatGPT的话题太火了!各大社交软件都是他的消息!从去年12月份ChatGPT横空出世,再到近期百度文心一言、复旦Moss的陆续宣布,点燃了全球对AIGC(内容人工智能自动生成)领域的热情&#xf…

搭建Bitbucket项目管理工具详细教程

目录 1.安装前准备 2.jdk安装 2.1.rpm安装方式: 3.创建bitbucket数据库 4.安装Git 5.安装bitbucket 5.1下载完成上传至服务器的 /usr/atlassian/ 目录下 5.2安装atlassian-bitbucket-7.21.0 5.3安装MySQL驱动 5.4破解激活bitbucket 1.安装前准备 首先查看操…

Python 之网络式编程

一 客户端/服务器架构 即C/S架构,包括 1、硬件C/S架构(打印机) 2、软件B/S架构(web服务) C/S架构与Socket的关系: 我们学习Socket就是为了完成C/S的开发 二 OSI七层 引子:   计算机组成…

【Spark分布式内存计算框架——Spark Streaming】13. 偏移量管理(下)MySQL 存储偏移量

6.3 MySQL 存储偏移量 此处将偏移量数据存储到MySQL表中,数据库及表的DDL和DML语句如下: -- 1. 创建数据库的语句 CREATE DATABASE IF NOT EXISTS db_spark DEFAULT CHARSETutf8mb4 COLLATEutf8mb4_0900_ai_ci; USE db_spark ; -- 2. 创建表的语句 CRE…

蓝牙资讯|2022 年 Q4 全球 TWS 耳机出货量 7900 万部

Canalys 最新数据显示,2022 年第四季度,全球个人智能音频设备出货量下降 26%,跌至 1.1 亿部。所有品类的出货量都面临不一的下滑趋势,甚至是一直支撑市场的 TWS 品类也遭遇 23% 两位数的下降至 7900 万部。 全球市场方面&#x…

MySQL中varchar(M)存储字符串过长

最近写项目&#xff0c;数据库报了一个错&#xff0c;错误原因是MySQL中存储的字符串过长最近在学MySQL的基础&#xff0c;刚好学到了关于varchar类型要存储的字符串是 “<p>12121212121212</p>\n<p><img src\"https://zzjzzjzzjbucket.oss-cn-hangz…

附录5-大事件项目前端

目录 1 前言 2 用到的插件 2.1 截取图像 cropper 2.2 富文本编辑器 tinymce 3 项目结构 4 config.js 5 主页 5.1 iframe 5.2 页面的宽高 5.3 修改文章 6 个人中心-基本资料 7 个人中心-更换头像 8 个人中心-更换密码 9 文章管理-文章分类 10 文章…