Pytorch 张量运算函数(补充)

news2024/11/24 22:28:35
mean()

mean()函数是进行张量均值计算的函数,常用参数可以设置参数dim来进行对应维度的均值计算

以下是使用一个二维张量进行演示的例子

import numpy as np
import torch
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
print(device
)
data1 = torch.randint(0,10,(2,3),dtype=torch.float).to(device)
print(data1)
print(data1.mean())
print(data1.mean(dim = 0))
print(data1.mean(dim = 1))
# mps
# tensor([[2., 8., 2.],
#         [7., 3., 7.]], device='mps:0')
# tensor(4.8333, device='mps:0')
# tensor([4.5000, 5.5000, 4.5000], device='mps:0')
# tensor([4.0000, 5.6667], device='mps:0')

可以看到在不指定dim维度的情况下,mean()函数会对所有张量元素进行求和后的均值计算,结果是一个标量张量

在指定了dim为0后,均值计算会沿着行方向去求每一列的均值

在指定了dim为1后,均值计算会沿着列方向去求每一行的均值

sum()

sum()函数为求和函数,同样类似于mean()函数,可以指定参数dim来进行指定维度上的求和计算

以下同样是一个二维张量的演示例子

import numpy as np
import torch
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
print(device
)
data1 = torch.randint(0,10,(2,3),dtype=torch.float).to(device)
print(data1)
print(data1.sum())
print(data1.sum(dim = 0))
print(data1.sum(dim = 1))
# mps
# tensor([[3., 4., 1.],
#         [8., 6., 0.]], device='mps:0')
# tensor(22., device='mps:0')
# tensor([11., 10.,  1.], device='mps:0')
# tensor([ 8., 14.], device='mps:0')

sum()函数在不指定dim的时候也是对所有张量元素求和计算

在指定了dim为0后,求和计算会沿着行方向去求每一列的和

在指定了dim为1后,求和计算会沿着列方向去求每一行的和

pow()

pow()函数是对张量进行幂次计算的函数,参数为指定指数值exponoent

import numpy as np
import torch
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
print(device
)
data1 = torch.randint(0,10,(2,3),dtype=torch.float).to(device)
print(data1)
print(data1.pow(2))
# mps
# tensor([[2., 5., 9.],
#         [4., 2., 7.]], device='mps:0')
# tensor([[ 4., 25., 81.],
#         [16.,  4., 49.]], device='mps:0')


上面的例子中指定了指数为2,底数为张量中的每个元素值

sqrt()

sqrt()函数是用于对张量进行开二次方根计算的,无需参数设置

import numpy as np
import torch
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
print(device
)
data1 = torch.randint(0,10,(2,3),dtype=torch.float).to(device)
print(data1)
print(data1.sqrt())
# mps
# tensor([[7., 6., 2.],
#         [6., 3., 9.]], device='mps:0')
# tensor([[2.6458, 2.4495, 1.4142],
#         [2.4495, 1.7321, 3.0000]], device='mps:0')



注意,由于sqrt函数无法进行高次方根的计算,所以若有高次方根的计算需求,可以依旧使用pow()函数进行计算,以下为三次方根的计算演示

import numpy as np
import torch
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
print(device
)
data1 = torch.randint(0,10,(2,3),dtype=torch.float).to(device)
print(data1)
print(data1.pow(1/3))
# mps
# tensor([[2., 7., 8.],
#         [1., 0., 1.]], device='mps:0')
# tensor([[1.2599, 1.9129, 2.0000],
#         [1.0000, 0.0000, 1.0000]], device='mps:0')



exp()

exp()函数适用于计算底数为e(约等于2.71828)的幂次计算,同样不需要参数指定,指数值就为张量的每个元素值.

import numpy as np
import torch
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
print(device
)
data1 = torch.randint(0,10,(2,3),dtype=torch.float).to(device)
print(data1)
print(data1.exp())
# mps
# tensor([[4., 3., 9.],
#         [0., 6., 9.]], device='mps:0')
# tensor([[5.4598e+01, 2.0086e+01, 8.1031e+03],
#         [1.0000e+00, 4.0343e+02, 8.1031e+03]], device='mps:0')



log()

log()函数是用于对数计算的函数,底数为e,为了方便更改底数常用的还有log2(底数为2),log10(底数为10)

import numpy as np
import torch
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
print(device
)
data1 = torch.randint(0,10,(2,3),dtype=torch.float).to(device)
print(data1)
print(data1.log())
print(data1.log2())
print(data1.log10())
# mps
# tensor([[0., 7., 8.],
#         [7., 0., 1.]], device='mps:0')
# tensor([[  -inf, 1.9459, 2.0794],
#         [1.9459,   -inf, 0.0000]], device='mps:0')
# tensor([[  -inf, 2.8074, 3.0000],
#         [2.8074,   -inf, 0.0000]], device='mps:0')
# tensor([[  -inf, 0.8451, 0.9031],
#         [0.8451,   -inf, 0.0000]], device='mps:0')

上面分别演示了log,log2,log10也就是底数分别为e,2,10的对数计算结果

在实际情况中我们可能对底数的选择更加灵活,如果要计算任意底数的对数,这里我们就可以用到下面的公式进行计算

\log_b(x) = \frac{\log_c(x)}{\log_c(b)}

这里对任意底数b进行对数计算,都可以转换成另一底数但是真数分别为原真数和广播后的原底数的商

import numpy as np
import torch
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
print(device
)
data1 = torch.randint(0,10,(2,3),dtype=torch.float).to(device)
print(data1)
print(data1.log()/torch.full_like(data1,3).log())
# mps
# tensor([[9., 1., 6.],
#         [4., 2., 3.]], device='mps:0')
# tensor([[2.0000, 0.0000, 1.6309],
#         [1.2619, 0.6309, 1.0000]], device='mps:0')



本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2069041.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【数据管理】数据治理

目录 1、相关概念 2、数据治理和管理职责语境关系图 3、业务驱动因素 4、目标和原则 5、 数据治理和数据管理的关系 6、数据治理组织 7、数据管理职能 8、数据制度 9、数据资产估值 1、相关概念 1)战略(Stategy):定义、交流和驱动数据战略和数…

[数据集][目标检测]电力场景输电线异物检测数据集VOC+YOLO格式2060张1类别

数据集格式:Pascal VOC格式YOLO格式(不包含分割路径的txt文件,仅仅包含jpg图片以及对应的VOC格式xml文件和yolo格式txt文件) 图片数量(jpg文件个数):2060 标注数量(xml文件个数):2060 标注数量(txt文件个数):2060 标注…

电脑丢失dll文件一键修复之dll确实损坏影响电脑运行

在使用电脑过程中,DLL文件丢失或损坏是一个常见的问题,它可能导致程序无法正常运行,甚至影响整个系统的稳定性。本文将详细介绍如何一键修复丢失的DLL文件,探讨常见的DLL丢失报错原因,并提供详细的修复步骤和预防措施。…

sklearn回归树

说明:内容来自菜菜的sklearn机器学习和ai生成 回归树 调用对象的参数 class sklearn.tree.DecisionTreeRegressor (criterion’mse’, splitter’best’, max_depthNone, min_samples_split2, min_samples_leaf1, min_weight_fraction_leaf0.0, max_featuresNone…

大数据基础:数仓架构演变

文章目录 数仓架构演变 一、传统离线大数据架构 二、​​​​​​Lambda架构 三、Kappa架构 四、​​​​​​​​​​​​​​混合架构 五、湖仓一体架构 六、流批一体架构 数仓架构演变 20世纪70年代,MIT(麻省理工)的研究员致力于研究一种优化的技术架构&…

Linux shell编程学习笔记75:sed命令——沧海横流任我行(下)

0 前言 在 Linux shell编程学习笔记73:sed命令——沧海横流任我行(上)-CSDN博客文章浏览阅读684次,点赞32次,收藏24次。在大数据时代,我们要面对大量数据,有时需要对数据进行替换、删除、新增、…

OpenCV几何图像变换(9)仿射变换函数warpAffine()的使用

操作系统:ubuntu22.04 OpenCV版本:OpenCV4.9 IDE:Visual Studio Code 编程语言:C11 算法描述 函数是应用一个仿射变换到图像上。 warpAffine 函数使用指定的矩阵对源图像进行仿射变换: dst ( x , y ) src ( M 11 x M 12 y M…

Elasticsearch:使用 ELSER 进行语义搜索 - sparse_vector

Elastic Learned Sparse EncodeR(或 ELSER)是由 Elastic 训练的 NLP 模型,可让你使用稀疏向量表示执行语义搜索。语义搜索不是根据搜索词进行文字匹配,而是根据搜索查询的意图和上下文含义检索结果。 本教程中的说明向你展示了如…

[医疗 AI ] 3D TransUNet:通过 Vision Transformer 推进医学图像分割

[医疗 AI ] 3D TransUNet:通过 Vision Transformer 推进医学图像分割’ 论文地址 - https://arxiv.org/pdf/2310.07781 0. 摘要 医学图像分割在推进医疗保健系统的疾病诊断和治疗计划中起着至关重要的作用。U 形架构,俗称 U-Net,已被证明在…

提高实时多媒体传输效率的三大方法

实时多媒体数据传输面临的挑战 实时多媒体数据的传输具有数据量巨大、对时延和时延抖动高度敏感及能容忍丢分组的特点。然而,当今互联网的网络层协议提供的仅是一种“尽最大努力服务”,对分组的端到端时延、时延抖动和分组丢失率等指标不做任何承诺。这…

MySQL的延迟复制

目录 1 MySQL 延迟复制介绍 1.1 延迟复制语法: 1.2 延迟复制可用于多种用途: 1.3 延迟复制的有关的参数 1.4 延迟复制的操作 2 MySQL 延迟复制 实操 2.1 实验环境 2.2 对 SLAVE --MySQL-3 进行延迟复制操作 2.3 停止相关进程的原因 2.4 实验测试 2.5 动…

Variomes:支持基因组变异筛选的高召回率搜索引擎

《Bioinformatics》2022 Variomes: https://candy.hesge.ch/Variomes Source code: https://github.com/variomes/sibtm-variomes SynVar: https://goldorak.hesge.ch/synvar 文章摘要(Abstract) 动机(Mot…

读软件开发安全之道:概念、设计与实施07密码学(上)

1. 加密工具 1.1. 加密工具之所以没有得到充分使用,就是因为人们往往认为密码学是一个准入门槛极高的专业领域 1.2. 如今的加密学大部分都源自纯数学,所以只要能够正确使用,加密学确实行之有效 1.2.1. 不代表这些算法本身确实无法破解&…

机器学习 | 基于wine数据集的KMeans聚类和PCA降维案例

KMeans聚类:K均值聚类是一种无监督的学习算法,它试图根据数据的相似性对数据进行聚类。无监督学习意味着不需要预测结果,算法只是试图在数据中找到模式。在k均值聚类中,我们指定希望将数据分组到的聚类数。该算法将每个观察随机分…

四大消息队列:Kafka、ActiveMQ、RabbitMQ、RocketMQ对比

四大消息队列:Kafka、ActiveMQ、RabbitMQ、RocketMQ对比 1. 社区活跃度2. 持久化消息3. 技术实现4. 高并发性能5. RabbitMQ与Kafka对比 💖The Begin💖点点关注,收藏不迷路💖 在软件开发中,消息队列&#xf…

【Redis】Redis数据结构——Hash 哈希

哈希 命令hsethgethexistshdelhkeyshvalshgetallhmgethlenhsetnxhincrbyhincrbyfloat命令小结 内部编码使用场景缓存⽅式对⽐ ⼏乎所有的主流编程语⾔都提供了哈希(hash)类型,它们的叫法可能是哈希、字典、关联数组、映射。在 Redis 中&#…

Python furl库:一键搞定复杂URL操作

更多Python学习内容:ipengtao.com 在Web开发和数据处理的过程中,URL的解析、修改和构建是不可避免的操作。然而,直接操作URL字符串不仅繁琐,而且容易出错。Python的furl库提供了一种简单且强大的方法来处理URL,使得URL…

简易的 Websocket + 心跳机制 + 尝试重连

文章目录 演示大纲基础 WebSocket前端: 添加心跳机制前端: 尝试重新连接历史代码 还没有写完,bug 是有的,我在想解决办法了… 演示 大纲 基础的 webSocket 连接前后端:添加心跳机制后端无心跳反应,前端尝试重新连接设置重新连接…

Java 日常反常识踩坑

作者:若渝 本文主要是日常业务开发中自身碰到过跟常识不一致的坑,问题虽然基础,但却可能造成比较大的线上问题。 一、转 BigDecimal 类型时精度丢失 public class Test { public static void main(String[] args) { BigDecimal bi…

算法-分隔链表

一、题目描述 (一) 题目 给你一个链表的头节点 head 和一个特定值 x ,请你对链表进行分隔,使得所有 小于 x 的节点都出现在 大于或等于 x 的节点之前。你应当保留两个分区中每个节点的初始相对位置。 (二) 示例 示例 1: 输入:…