AI学习记录 -transformer 中对于torch和numpy常用函数的使用方式

news2024/10/2 5:22:47

在transformer源码中,使用了很多矩阵变换的方法,这些方法太多了,了解底层也没啥意义,也不是啥特别复杂的算法。

所以争取一句话描述这些方法,对照着看transformer的时候,可以衔接自己的思维链。

torch.unsqueeze

在指定维度增加一个维度

import torch

# 创建一个 1D 张量
tensor_1d = torch.tensor([1, 2, 3, 4])
print("原始张量:", tensor_1d)

# 在维度 0 处增加一个维度
tensor_2d = torch.unsqueeze(tensor_1d, dim=0)
print("在维度 0 处增加维度:", tensor_2d)

# 在维度 1 处增加一个维度
tensor_2d_alt = torch.unsqueeze(tensor_1d, dim=1)
print("在维度 1 处增加维度:", tensor_2d_alt)

打印:

原始张量: tensor([1, 2, 3, 4])
在维度 0 处增加维度: tensor([[1, 2, 3, 4]])
在维度 1 处增加维度: tensor([
        [1],
        [2],
        [3],
        [4]
 ])

.shape[n]

获取指定维度的形状

torch.zeros

创建全 0 张量

import torch
# 创建一个 2x3 的全零张量
zero_tensor = torch.zeros(2, 3)
print(zero_tensor)
tensor([[0., 0., 0.],
        [0., 0., 0.]])

torch.arange

可以生成有数值叠加的数组

import torch

# 创建从 0 到 9 的一维张量
tensor_1d = torch.arange(10)
print(tensor_1d)

# 创建从 1 到 9 的一维张量,步长为 2
tensor_1d_step = torch.arange(1, 10, 2)
print(tensor_1d_step)

输出:

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
tensor([1, 3, 5, 7, 9])

torch.transpose

交换张量的指定维度

import torch

# 创建一个 2x3 的张量
tensor = torch.tensor([[1, 2, 3],
                       [4, 5, 6]])
print("原始张量:")
print(tensor)

# 转置张量
transposed_tensor = torch.transpose(tensor, 0, 1)
print("转置后的张量:")
print(transposed_tensor)

打印:

原始张量:
tensor([[1, 2, 3],
        [4, 5, 6]])
转置后的张量:
tensor([[1, 4],
        [2, 5],
        [3, 6]])

torch.exp

1、将数组每一项 转成 e 的 x 次方。
2、激活函数:在神经网络中,通常用于计算 softmax 函数的一部分。

import torch

# 创建一个张量
tensor = torch.tensor([0.0, 1.0, 2.0])
# 计算每个元素的指数
exp_tensor = torch.exp(tensor)
print(exp_tensor)

输出

tensor([ 1.0000,  2.7183,  7.3891])

tensor.eq(0)

将整数数组转成true和false

import torch

# 创建一个张量
tensor = torch.tensor([1, 0, 2, 0, 3])

# 使用 .eq(0) 检查哪些元素等于 0
zero_check = tensor.eq(0)
print(zero_check)

输出

tensor([False,  True, False,  True, False])

numpy.ones

创建一个指定形状的数组,并将所有元素初始化为 1。

import numpy as np

# 创建一个 2x3 的数组,所有元素为 1
ones_array = np.ones((2, 3))
print(ones_array)

# 创建一个 3 维的数组
ones_3d_array = np.ones((2, 2, 2), dtype=int)
print(ones_3d_array)

输出

[[1. 1. 1.]
 [1. 1. 1.]]

[[[1 1]
  [1 1]]

 [[1 1]
  [1 1]]]

numpy.triu

返回一个矩阵的上三角部分,其他元素设为 0。

import numpy as np

# 创建一个 3x3 的矩阵
matrix = np.array([[1, 2, 3],
                   [4, 5, 6],
                   [7, 8, 9]])

# 获取上三角部分
upper_triangle = np.triu(matrix)
print(upper_triangle)

输出

[[1 2 3]
 [0 5 6]
 [0 0 9]]

torch.from_numpy

将 NumPy 数组转换为 PyTorch 张量。

import numpy as np
import torch

# 创建一个 NumPy 数组
numpy_array = np.array([[1, 2, 3], [4, 5, 6]])

# 将 NumPy 数组转换为 PyTorch 张量
torch_tensor = torch.from_numpy(numpy_array)
print(torch_tensor)

torch.matmul 点乘,点乘的性质在某些场景下可以代表批量点积,在transformer中会体现

不管多少维度的矩阵,都是最后两个维度进行矩阵乘法

import torch

# 创建张量 A 和 B
A = torch.randn(2, 3, 4)  # 形状为 (2, 3, 4)
B = torch.randn(2, 4, 5)  # 形状为 (2, 4, 5)

# 使用 torch.matmul 进行矩阵乘法
C = torch.matmul(A, B)  # 结果形状为 (2, 3, 5)

print(C.shape)  # 输出: torch.Size([2, 3, 5])

torch.dot 点积

import torch

# 创建两个一维张量
a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([4.0, 5.0, 6.0])

# 计算点积
dot_product = torch.dot(a, b)

print(dot_product)  # 输出: tensor(32.)

numpy.sqrt

对传入的数组进行逐元素的平方根操作

import numpy as np

# 创建一个数组
arr = np.array([1, 4, 9, 16])

# 计算平方根
sqrt_arr = np.sqrt(arr)

print(sqrt_arr)  # 输出: [1. 2. 3. 4.]

torch.masked_fill_

将矩阵中为 True 的位置替换为 -1e9(也就是负无穷大)

import torch

# 创建一个示例张量
scores = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])

# 创建一个掩码张量(布尔型)
mask = torch.tensor([[False, True, False], [False, False, True]])

# 使用 masked_fill_ 来掩盖某些元素
scores.masked_fill_(mask, -1e9)

print(scores)
# 输出:
# tensor([[ 1.0000e+00, -1.0000e+09,  3.0000e+00],
#         [ 4.0000e+00,  5.0000e+00, -1.0000e+09]])

torch.view

重新调整张量的展示方向,不会改变形状

import torch

# 创建一个 2x3 的张量
tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])

# 使用 view 重新调整为 3x2 的张量
reshaped_tensor = tensor.view(3, 2)

print(reshaped_tensor)
# 输出:
# tensor([[1, 2],
#         [3, 4],
#         [5, 6]])

torch.repeat

在指定维度重复张量并设置重复次数

import torch

# 创建一个 2x2 的张量
tensor = torch.tensor([[1, 2], [3, 4]])

# 第1维度重复2词,第2维度重复3次
repeated_tensor = tensor.repeat(2, 3)

print(repeated_tensor)
# 输出:
# tensor([[1, 2, 1, 2, 1, 2],
#         [3, 4, 3, 4, 3, 4],
#         [1, 2, 1, 2, 1, 2],
#         [3, 4, 3, 4, 3, 4]])

tensor.reshape

调整矩阵的展示方向,不会改变形状

import torch

# 创建一个 2x3 的张量
tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])

# 使用 reshape 重新调整为 3x2 的张量
reshaped_tensor = tensor.reshape(3, 2)

print(reshaped_tensor)
# 输出:
# tensor([[1, 2],
#         [3, 4],
#         [5, 6]])

nn.ModuleList

有点像数组存了很多函数,然后函数返回当作是下个函数的输入

import torch
import torch.nn as nn

# 定义一个简单的 DecoderLayer 类,继承自 nn.Module
class DecoderLayer(nn.Module):
    def __init__(self):
        super(DecoderLayer, self).__init__()
        self.linear = nn.Linear(10, 10)
    
    def forward(self, x):
        return self.linear(x)

# 定义网络时,我们创建 n_layers 个 DecoderLayer
n_layers = 3
decoder_layers = nn.ModuleList([DecoderLayer() for _ in range(n_layers)])

# 输入数据
x = torch.randn(5, 10)  # 假设输入张量大小为 (5, 10)

# 逐个层进行前向传播
for layer in decoder_layers:
    x = layer(x)

print(x)

torch.gt

对比两个张量里面每个数字的大小

import torch

# 创建两个张量
a = torch.tensor([1, 2, 3, 4])
b = torch.tensor([2, 2, 2, 2])

# 比较 a 是否大于 b
result = torch.gt(a, b)

print(result)
# 输出:
# tensor([False, False, True, True])

torch 中什么是连续 torch 中什么是不连续

可以用一种迭代方式构造出想要的矩阵,那就是连续的 。可以用一种迭代方式构造不出想要的矩阵,那就是连续的。

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2183822.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

python访问SQL数据库

Step 1 创建一个数据库 Step 2 安装mysql-connector-python pip install mysql-connector-pythonStep 3 访问mysql,并查询表 import mysql.connectordef connect_and_query():try:# 连接到MySQL数据库connection mysql.connector.connect(hostlocalhost, # 数据库主机…

闯关训练一:Linux基础

闯关任务:完成SSH连接与端口映射并运行hello_world.py 1.创建开发机 2.SSH连接 3. VS-Code 连接 选择 Linux 平台 ,输入密码 ,选择进入文件夹 4.端口映射 按照下文安装Docs pip install gradio 运行server.py import gradio as grdef …

Python核心知识:pip使用方法大全

什么是 pip? pip 是 Python 的包管理工具,允许用户安装、升级和管理 Python 的第三方库和依赖。它极大地简化了开发过程,使开发者可以轻松地获取并安装所需的软件包。pip 已成为 Python 项目中最常见的包管理工具,并且自 Python …

【Linux】几种常见配置文件介绍

配置文件目录 linux 系统中有很多配置文件目录/etc/systemd/system、/lib/systemd/system 以及/usr/lib/systemd/system 等,这三者有什么样的关系呢? 以下是网络上找的资料汇总,并加了一些操作验证。方便后期使用 介绍 目录/lib/systemd/s…

虚拟机窗口顶部和底部出现白边(鸿蒙开发)

预览窗口顶部和底部出现白边 问题描述:预览窗口顶部和底部都有白边,导致无法全屏显示 解决方法: 官方文档:https://developer.huawei.com/consumer/cn/doc/harmonyos-faqs-V5/faqs-previewer-operating-6-V5 这里官方文档给了两种…

【有啥问啥】AI中的数据融合(Data Fusion):让数据“1+1>2”

AI中的数据融合(Data Fusion):让数据“11>2” 引言 在人工智能(AI)的浪潮中,数据作为驱动创新的核心要素,其重要性不言而喻。随着物联网(IoT)、传感器技术和云计算的…

基于单片机远程家电控制系统设计

本设计基于单片机的远程家电控制系统,以STC89C52单片机为核心,通过液晶LCD1602实时显示并控制,利用ESP8266WiFi模块实现本地与云平台的连接,最终实现远程对于灯光,热水器等家电的开关控制。同时,系统设有防…

pdf怎么编辑修改内容?详细介绍6款pdf编辑器功能

■ pdf怎么编辑修改内容? PDF(Portable Document Format)作为一种广泛使用的文件格式,具有特点包括兼容性强、易于传输、文件安全性高、跨平台性、可读性强、完整性、可搜索性、安全性、可压缩性。 PDF文件本身是不可以直接进行编…

认知杂谈73《成年人的修炼:勇敢前行,积极向上》

内容摘要: 成长是成年人的必修课,它要求我们不断学习、面对挑战、做出选择、调整行动。成长的必要性在于适应社会、实现自我价值。实现成长的策略包括自我掌舵、自救、为结果负责、保持积极心态。 追求艺术或商业目标、自己解决问题、承担责任、换个角度…

OpenAI o1:使用限额提高,o1 模型解析

OpenAI 最新公告 OpenAI 近日宣布对 o1-mini 和 o1-preview 的消息使用限额进行了提升,让 Plus 和 Team 用户可以更频繁地体验 o1 系统。具体来说,o1-mini 的限额从每周 50 条增加到了每天 50 条,而 o1-preview 的限额则从每周 30 条提升到了…

【算法】链表:21.合并两个有序链表(easy)

系列专栏 《分治》 《模拟》 《Linux》 目录 1、题目链接 2、题目介绍 3、解法(双指针) 4、代码 1、题目链接 21. 合并两个有序链表 - 力扣(LeetCode) 2、题目介绍 3、解法(双指针) 推荐一篇题解…

Arduino UNO R3自学笔记13 之 Arduino使用LM35如何测量温度?

注意:学习和写作过程中,部分资料搜集于互联网,如有侵权请联系删除。 前言:学习使用传感器测温。 1.LM35介绍 一般来讲当知道需求,就可以 通过既定要求的条件来筛选需要的器件,多方面的因素最终选定了器件…

c语言实例

大家好,欢迎来到无限大的频道 今天给大家带来的是c语言 题目描述 创建一个双链表,并将链表中的数据输出到控制台,输入要查找的数据,将查找到的数据删除,并且显示删除后的链表 下面是一个用C语言实现的双链表&#…

数据结构-4.2.串的定义和基本操作

一.串的定义: 1.单/双引号不是字符串里的内容,他只是一个边界符,用来表示字符串的头和尾; 2.空串也是字符串的子串,空串长度为0; 3.字符的编号是从1开始,不是0; 4.空格也是字符&a…

Windows 11 安装配置 Git 教程

目录 Git Windows 11 环境安装配置 Git Git Git是一个开源的分布式版本控制系统,由Linus Torvalds创建,用于有效、高速地处理从小到大的项目版本管理。Git是目前世界上最流行的版本控制系统,广泛应用于软件开发中。 以下是Git的一些关键特…

Python空间地表联动贝叶斯地震风险计算模型

🎯要点 使用贝叶斯推断模型兼顾路径和场地效应,量化传统地理统计曲线拟合技术。使用破裂和场地特征等地质信息以及事件间残差和事件内残差描述数学模型模型使用欧几里得距离度量、角距离度量和土壤差异性度量确定贝叶斯先验分布和后验分布参数&#xff…

使用Qt实现实时数据动态绘制的折线图示例

基于Qt的 QChartView 和定时器来动态绘制折线图。它通过动画的方式逐步将数据点添加到图表上,并动态更新坐标轴的范围,提供了一个可以实时更新数据的折线图应用。以下是对代码的详细介绍及其功能解析: 代码概述 该程序使用Qt的 QChartView…

【Python报错已解决】 Encountered error while trying to install package.> lxml

🎬 鸽芷咕:个人主页 🔥 个人专栏: 《C干货基地》《粉丝福利》 ⛺️生活的理想,就是为了理想的生活! 专栏介绍 在软件开发和日常使用中,BUG是不可避免的。本专栏致力于为广大开发者和技术爱好者提供一个关于BUG解决的经…

黑马linux笔记(转载)

学习链接 视频链接:黑马程序员新版Linux零基础快速入门到精通 原文链接:黑马程序员新版Linux零基础快速入门到精通——学习笔记 黑马Linux笔记 文章目录 学习链接01初识Linux1.1、操作系统概述1.1.1、硬件和软件1.1.2、操作系统1.1.3、常见操作系统 1.…

10/01赛后总结

T1隔离 题目传送门:隔离http://bbcoj.cn/contest/1027/problem/1 实在是太刁钻了,有两种情况没有考虑: 1.隔离后做完全部的是再回去 2.在路程上花的时间如果大于在隔离一次花的时间,那还不如隔离,然后做完全部的事…