神经网络的工程基础(零)——PyTorch基础

news2025/1/15 18:14:24

相关说明

这篇文章的大部分内容参考自我的新书《解构大语言模型:从线性回归到通用人工智能》,欢迎有兴趣的读者多多支持。
本文涉及到的代码链接如下:regression2chatgpt/ch06_optimizer/gradient_descent.ipynb

本文将介绍PyTorch的基础。它是神经网络领域常用的建模工具。

关于大语言模型的内容,推荐参考这个专栏。

内容大纲

  • 相关说明
  • 一、PyTorch的数据基础:张量(Tensor)
  • 二、张量的基本计算

一、PyTorch的数据基础:张量(Tensor)

工欲善其事,必先利其器。在讨论如何实现梯度下降法之前,首先探讨一下PyTorch这个强大的工具。PyTorch是一种备受欢迎的开源机器学习框架,被广泛用于构建、训练和部署神经网络模型,因具有灵活性、动态计算图和卓越的GPU支持而成为神经网络领域的首选。

PyTorch的基础数据结构是张量。张量的创建方式如程序清单1所示(完整代码)。

程序清单1 张量的创建
 1 |  # 使用tensor封装的函数创建tensor
 2 |  zeros = torch.zeros(2, 3)
 3 |  tensor([[0., 0., 0.],
 4 |          [0., 0., 0.]])
 5 |  
 6 |  ones = torch.ones(2, 3)
 7 |  tensor([[1., 1., 1.],
 8 |          [1., 1., 1.]])
 9 |  
10 |  torch.manual_seed(1024)
11 |  random = torch.rand(3, 4)
12 |  tensor([[0.8090, 0.7935, 0.2099, 0.9279],
13 |          [0.8136, 0.7422, 0.4769, 0.4955],
14 |          [0.3602, 0.1178, 0.7852, 0.0228]])
15 |  
16 |  # 从Python对象创建
17 |  data = [[2, 3, 4], [1, 0, 1]]
18 |  t_data = torch.tensor(data)
19 |  tensor([[2, 3, 4],
20 |          [1, 0, 1]])
21 |  
22 |  ## 从numpy对象创建
23 |  import numpy as np
24 |  
25 |  n_data = np.array(data)
26 |  tn_data = torch.from_numpy(n_data)
27 |  tensor([[2, 3, 4],
28 |          [1, 0, 1]])
29 |  
30 |  ## Numpy bridge,也就是说对numpy对象的改变会传导到tensor
31 |  n_data += 1
32 |  torch.all(torch.from_numpy(n_data) == tn_data)
33 |  tensor(True)

张量的形状(Shape)是至关重要的概念,它定义了张量的维度以及每个维度的大小。在实际应用中,可以通过使用一系列函数来改变张量的形状,使其适应不同的运算需求,如程序清单2所示。

程序清单2 改变张量的形状
 1 |  # 增加或减少数据的维度
 2 |  a = torch.rand(3, 4)  # (3, 4)
 3 |  ## 增加维度
 4 |  b = a.unsqueeze(0)    # (1, 3, 4)
 5 |  ## 减少维度
 6 |  c = b.squeeze(0)      # (3, 4)
 7 |  ## 数据相同,但是维度不同
 8 |  print(torch.all(c.eq(b)))    # tensor(True)
 9 |  print(c.shape == b.shape)    # False
10 |  
11 |  # 变换tensor形状
12 |  data = torch.tensor(range(0, 10))   # tensor([1, 2, 3, 4, 5, 6, 7, 8, 9])
13 |  view1 = data.view(2, 5)
14 |  tensor([[0, 1, 2, 3, 4],
15 |          [5, 6, 7, 8, 9]])
16 |  transpose1 = view1.T
17 |  tensor([[0, 5],
18 |          [1, 6],
19 |          [2, 7],
20 |          [3, 8],
21 |          [4, 9]])
22 |  ## 非毗邻存储的对象不能进行view操作
23 |  print(view1.is_contiguous(), transpose1.is_contiguous()) 
24 |  True False
25 |  ## 下面的操作会报错
26 |  view2 = transpose1.view(1, 10)
  1. 程序清单2的第4—6行使用unsqueeze和squeeze函数来增加或减少张量的维度。需要注意的是,这些操作并不会改变张量实际存储的数据,也不会在实质上改变张量的形状。相反,它们只是在张量的形状中添加或删除一个空的维度。具体的变化可以在第8行和第9行中看到。
  2. 为了改变张量的形状,可以使用view函数,如第12—15行所示。但需要注意的是,view函数只能用在毗邻存储的张量1对象上。非毗邻存储的张量只能使用reshape函数来改变形状。尽管这两个函数在功能上相似,但在计算效率上存在显著差异:相较于 view 函数,reshape 的计算开销要大得多。因此,在实际应用中,最好优先选择使用 view 函数。

二、张量的基本计算

张量的运算分为两种:逐元素操作(Element-Wise Operations)和矩阵乘法,这些计算方法在处理数据和构建神经网络模型时都具有重要作用。程序清单6-3中讨论了这些操作,并介绍了PyTorch中的广播机制(Broadcasting Semantics),它在处理不同形状的张量时起到了重要的作用。

程序清单3 张量的常见运算
 1 |  # 逐元素操作
 2 |  twos = torch.ones(2, 2) * 2
 3 |  tensor([[2., 2.],
 4 |          [2., 2.]])
 5 |  powers = twos ** torch.tensor([[1, 2], [3, 4]])
 6 |  tensor([[ 2.,  4.],
 7 |          [ 8., 16.]])
 8 |  
 9 |  ## tensor广播,tensor broadcasting
10 |  a = torch.tensor(range(1, 7)).view(2, 3)
11 |  tensor([[1, 2, 3],
12 |          [4, 5, 6]])
13 |  b = torch.tensor(range(1, 4)).view(   3)
14 |  tensor([1, 2, 3])
15 |  print(a * b)
16 |  tensor([[ 1,  4,  9],
17 |          [ 4, 10, 18]])    
18 |  ## 关于广播,更复杂的例子
19 |  a =     torch.ones(4, 1, 3, 2)
20 |  b = a * torch.rand(   5, 1, 2)
21 |  print(b.shape)
22 |  torch.Size([4, 5, 3, 2])
23 |  
24 |  # 矩阵运算
25 |  mat1 = torch.randn(3, 4)    # (3, 4)
26 |  mat2 = torch.randn(4, 5)    # (4, 5)
27 |  re = mat1 @ mat2            # (3, 5)
28 |  ## 矩阵运算的广播
29 |  mat1 = torch.randn(5, 1, 3, 4)   # (5, 1, 3, 4)
30 |  mat2 = torch.randn(   8, 4, 5)   # (   8, 4, 5)
31 |  re = mat1 @ mat2                 # (5, 8, 3, 5)
  1. 逐元素操作要求进行运算的两个张量的形状必须相同,如程序清单3中的第2—7行所示。然而,在实际应用中,常常需要对形状不同的张量进行操作。为此,PyTorch引入了广播机制,它允许在一定条件下对形状不同的张量进行逐元素操作,如第9—22行所示。
  2. 广播机制的流程相对复杂,如图1所示,需要注意几个关键步骤。首先,从后向前逐个比较两个张量的维度;接着,对缺失的维度进行扩充(类似于unsqueeze函数的操作);然后,检查广播规则,即两个张量的各分量要么相等,要么其中一个等于1;最后,复制数据,实现广播操作。
  3. 广播机制不仅适用于逐元素操作,它同样影响着张量的矩阵乘法。不同之处在于,当执行矩阵乘法时,广播机制只会作用于前面的维度,而不涉及最后两维,如第29—31行所示。

图1

图1


  1. 毗邻存储(C Contiguous)是一个与硬件相关的概念。简而言之,毗邻存储意味着数据在内存中是连续存储的,这种存储方式能够显著提升数据的读取和计算速度。张量在内存中的存储细节超出了本书的范围,对此感兴趣的读者可以在PyTorch的官方文档中找到更详细的信息。 ↩︎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1708190.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Centos安装,window、ubuntus双系统基础上安装Centos安装

文章目录 前言一、准备工作二、开始安装1、2、首先选择DATE&TIME2、选择最小安装3、 选择安装位置 总结 前言 因工作需要,我需要在工控机上额外装Centos7系统,不过我是装在机械硬盘上了不知道对性能是否有影响,若有影响,后面…

MyBatis系统学习篇 - MyBatis逆向工程

MyBatis的逆向工程是指根据数据库表结构自动生成对应的Java实体类、Mapper接口和XML映射文件的过程。逆向工程可以帮助开发人员快速生成与数据库表对应的代码,减少手动编写重复代码的工作量。 我们在MyBatis中通过逆向工具来帮我简化繁琐的搭建框架,减少…

macOS上用Qt creator编译并跑shotcut

1 简介 Shotcut是一个开源的跨平台的视频编辑软件,支持WIN/MACOS/LINUX等平台,由于该项目的编译较为麻烦,踩坑几许,因此写此文章记录完整编译构建过程,后续按此法编译,可减少走弯路,提高生产力。…

vue+antd实践:在输入框光标处插入内容

今天来看一个很简单的需求。 需求描述:在输入框光标处,插入指定的内容。 效果如下: 实现思路:刚开始还在想怎么获取光标的位置,但是发现所做的项目是基于vue3antd组件,那么不简单了嘛,只要调…

SpringBoot Redis 扩展高级功能

环境:SpringBoot2.7.16 Redis6.2.1 1. Redis消息发布订阅 Spring Data 为 Redis 提供了专用的消息传递集成,其功能和命名与 Spring Framework 中的 JMS 集成类似。Redis 消息传递大致可分为两个功能区域: 信息发布 信息订阅 这是一个通常…

[XYCTF新生赛]-Reverse:你是真的大学生吗?解析(汇编异或逆向)

无壳 查看ida 没有办法反汇编,只能直接看汇编了。 这里提示有输入,输入到2F地址后,然后从后往前异或,其中先最后一个字符与第一个字符异或。这里其实也有字符串的长度,推测应该是cx自身异或之后传给了cx 完整exp&am…

三分钟轻松搞定内容,2024视频号最新AI自动生成影视解说,,百分之百过原创, 月入1万+

在这个数字时代,我们有幸见证了AI技术对创新的推动。现如今,一个崭新的平台出现了,它能让你用AI软件在短短3分钟内制作完成一段影视解说,而且由于这个平台尚属于新兴,竞争者稀少,提供了一个广阔的机遇天地。…

如何将云服务器上操作系统由centos切换为ubuntu

本文将介绍如何将我们购买的云服务器上之前装的centos切换为ubuntu,云服务器以华为云为例,要切换的ubuntu版本为ubuntu20.04。 参考官方文档:切换操作系统_弹性云服务器 ECS (huaweicloud.com) 首先打开华为云官网,登录后点击右…

多模态MLLM都是怎么实现的(9)-时序LLM是怎么个事儿?

时序预测这东西大家一般不陌生,随便举几个例子 1- 金融,比如预测股票(股市有风险,入市需谨慎),纯用K线做,我个人不太推荐 2- 天气,比如预测云图,天气预报啥的 3- 交通,早晚高峰,堵车啥的,车啥时候加油,啥时候充电之类的 4- 医疗,看你病史和喝酒的剂量建模,看你会…

斯坦福大学ALOHA家务机器人团队发布了最新研究成果—YAY Robot语言交互式操作系统

ALOHA YAY 演示视频-智能佳 斯坦福的ALOHA家务机器人团队,发布了最新研究成果—Yell At Your Robot(简称YAY),有了它,机器人的“翻车”动作,只要喊句话就能纠正了! 标ALOHA2协作平台题 而且机器…

Shell脚本基本命令

文件名后缀.sh 编写shell脚本一定要说明一下在#!/bin/bash在进行编写。命令选项空格隔开。Shell脚本是解释的语言,bash 文件名即可打印出编写的脚本。chmod给权限命令。如 chmod 0777 文件名意思是给最高权限。 注意:count赋值不能加空格。取消变量可在变…

基于Spring 框架中的@Async 注解实现异步任务

Async 是 Spring 框架中的一个注解,用于实现方法级别的异步执行。使用 Async 可以让你的代码在非当前线程中执行,从而提高应用的并发性能。 1、 启用异步支持 在 Spring 应用的主配置类或任何其他配置类上添加 EnableAsync 注解来开启异步任务的支持 …

13.Redis之数据库管理redis客户端JAVA客户端

1.数据库管理 mysql 中有一个重要的概念,database 1个 mysql 服务器上可以有很多个 database1个 database 上可以有很多个 表mysql 上可以随心所欲的 创建/删除 数据库~~ Redis 提供了⼏个⾯向 Redis 数据库的操作,分别是 dbsize、select、flushdb、flushall 命令…

2024年中国金融行业网络安全市场全景图

网络安全一直是国家安全的核心组成部分,特别是在金融行业,金融机构拥有大量的敏感数据,包括个人信息、交易记录、财务报告等,这些数据的安全直接关系到消费者的利益和金融市场的稳定,因此金融行业在网络安全建设领域一…

Java—选择排序

选择排序是一种简单但高效的排序算法。它的基本思想是从未排序的部分中选择最小(或最大)的元素,并将其放置在已排序部分的末尾。 实现步骤 具体实现选择排序的步骤如下: 遍历数组:从数组的第一个元素开始&#xff0…

cesium绘制区域编辑

npm 安装也是可以的 #默认安装最新的 yarn add cesium#卸载插件 yarn remove cesium#安装指定版本的 yarn add cesium1.96.0#安装指定版本到测试环境 yarn add cesium1.96.0 -D yarn install turf/turf <template><div id"cesiumContainer"></div&…

从零开始学React--环境搭建

React官网 快速入门 – React 中文文档 1.搭建环境 下载nodejs,双击安装 nodejs下载地址 更新npm npm install -g npm 设置npm源&#xff0c;加快下载速度 npm config set registry https://registry.npmmirror.com 创建一个react应用 npx create-react-app react-ba…

QT 自定义协议TCP传输文件

后面附带实例的下载地址 一、将文件看做是由:文件头+文件内容组成,其中文件头包含文件的一些信息:文件名称、文件大小等。 二、文件头单独发送,文件内容切块发送。 三、每次发送信息格式:发送内容大小、发送内容类型(文件头或是文件块内容)、文件块内容。 四、效果展…

项目十三:搜狗——python爬虫实战案例

根据文章项目十二&#xff1a;简单的python基础爬虫训练-CSDN博客的简单应用&#xff0c;这一次来升级我们的技术&#xff0c;那么继续往下看&#xff0c;希望对技术有好运。 还是老样子&#xff0c;按流程走&#xff0c;一条龙服务&#xff0c;嘿嘿。 第一步&#xff1a;导入…

设置AXI主寄存器切片和AXI数据FIFO

设置AXI主寄存器切片和AXI数据FIFO 打开MHS文件&#xff0c;并为每个AXI主机设置启用寄存器切片/启用数据FIFO。到 确定正确的设置&#xff0c;使用下表中的信息搜索MHS。 进行搜索时&#xff0c;将<intf_name>替换为相关的BUS_INTERFACE名称。 例如&#xff0c;BUS_INTE…