PyTorch常用代码段汇总

news2024/12/30 2:32:54

本文是PyTorch常用代码段合集,涵盖基本配置、张量处理、模型定义与操作、数据处理、模型训练与测试等5个方面,还给出了多个值得注意的Tips,内容非常全面。

PyTorch最好的资料是官方文档。本文是PyTorch常用代码段,在参考资料[1](张皓:PyTorch Cookbook)的基础上做了一些修补,方便使用时查阅

基本配置

导入包和版本查询

import torch
import torch.nn as nn
import torchvision
print(torch.__version__)
print(torch.version.cuda)
print(torch.backends.cudnn.version())
print(torch.cuda.get_device_name(0))

可复现性

在硬件设备(CPU、GPU)不同时,完全的可复现性无法保证,即使随机种子相同。但是,在同一个设备上,应该保证可复现性。具体做法是,在程序开始的时候固定torch的随机种子,同时也把numpy的随机种子固定。

np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

显卡设置

  • 如果只需要一张显卡
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  • 如果需要指定多张显卡,比如0,1号显卡。
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
  • 也可以在命令行运行代码时设置显卡:
CUDA_VISIBLE_DEVICES=0,1 python train.py
  • 清除显存
torch.cuda.empty_cache()
  • 也可以使用在命令行重置GPU的指令
nvidia-smi --gpu-reset -i [gpu_id]

张量(Tensor)处理

张量的数据类型

PyTorch有9种CPU张量类型和9种GPU张量类型。
在这里插入图片描述

张量基本信息

tensor = torch.randn(3,4,5)
print(tensor.type())  # 数据类型
print(tensor.size())  # 张量的shape,是个元组
print(tensor.dim())   # 维度的数量

命名张量

张量命名是一个非常有用的方法,这样可以方便地使用维度的名字来做索引或其他操作,大大提高了可读性、易用性,防止出错。

# 在PyTorch 1.3之前,需要使用注释
# Tensor[N, C, H, W]
images = torch.randn(32, 3, 56, 56)
images.sum(dim=1)
images.select(dim=1, index=0)

# PyTorch 1.3之后
NCHW = [‘N’, ‘C’, ‘H’, ‘W’]
images = torch.randn(32, 3, 56, 56, names=NCHW)
images.sum('C')
images.select('C', index=0)
# 也可以这么设置
tensor = torch.rand(3,4,1,2,names=('C', 'N', 'H', 'W'))
# 使用align_to可以对维度方便地排序
tensor = tensor.align_to('N', 'C', 'H', 'W')

数据类型转换

# 设置默认类型,pytorch中的FloatTensor远远快于DoubleTensor
torch.set_default_tensor_type(torch.FloatTensor)
# 类型转换
tensor = tensor.cuda()
tensor = tensor.cpu()
tensor = tensor.float()
tensor = tensor.long()

torch.Tensor与np.ndarray转换

除了CharTensor,其他所有CPU上的张量都支持转换为numpy格式然后再转换回来。

ndarray = tensor.cpu().numpy()
tensor = torch.from_numpy(ndarray).float()
tensor = torch.from_numpy(ndarray.copy()).float() # If ndarray has negative stride.

Torch.tensor与PIL.Image转换

# pytorch中的张量默认采用[N, C, H, W]的顺序,并且数据范围在[0,1],需要进行转置和规范化
# torch.Tensor -> PIL.Image
image = PIL.Image.fromarray(torch.clamp(tensor*255, min=0, max=255).byte().permute(1,2,0).cpu().numpy())
image = torchvision.transforms.functional.to_pil_image(tensor)  # Equivalently way

# PIL.Image -> torch.Tensor
path = r'./figure.jpg'
tensor = torch.from_numpy(np.asarray(PIL.Image.open(path))).permute(2,0,1).float() / 255
tensor = torchvision.transforms.functional.to_tensor(PIL.Image.open(path)) # Equivalently way

np.ndarray与PIL.Image的转换

image = PIL.Image.fromarray(ndarray.astype(np.uint8))
ndarray = np.asarray(PIL.Image.open(path))

从只包含一个元素的张量中提取值

value = torch.rand(1).item()

张量形变

# 在将卷积层输入全连接层的情况下通常需要对张量做形变处理,
# 相比torch.view,torch.reshape可以自动处理输入张量不连续的情况

tensor = torch.rand(2,3,4)
shape = (6, 4)
tensor = torch.reshape(tensor, shape)

打乱顺序

tensor = tensor[torch.randperm(tensor.size(0))]  # 打乱第一个维度

水平翻转

# pytorch不支持tensor[::-1]这样的负步长操作,水平翻转可以通过张量索引实现
# 假设张量的维度为[N, D, H, W].

tensor = tensor[:,:,:,torch.arange(tensor.size(3) - 1, -1, -1).long()]

复制张量

# Operation                
 |  New/Shared memory | Still in computation graph |tensor.clone()           
  # |        New         |          Yes               |tensor.detach()           # |      Shared        |          No                |tensor.detach.clone()()   # |        New         |          No                |

张量拼接

'''
注意torch.cat和torch.stack的区别在于torch.cat沿着给定的维度拼接,
而torch.stack会新增一维。例如当参数是3个10x5的张量,torch.cat的结果是30x5的张量,
而torch.stack的结果是3x10x5的张量。
'''
tensor = torch.cat(list_of_tensors, dim=0)
tensor = torch.stack(list_of_tensors, dim=0)

将整数标签转为one-hot编码

# pytorch的标记默认从0开始
tensor = torch.tensor([0, 2, 1, 3])
N = tensor.size(0)
num_classes = 4
one_hot = torch.zeros(N, num_classes).long()
one_hot.scatter_(dim=1, index=torch.unsqueeze(tensor, dim=1), src=torch.ones(N, num_classes).long())

得到非零元素

torch.nonzero(tensor)               # index of non-zero elements
torch.nonzero(tensor==0)            # index of zero elements
torch.nonzero(tensor).size(0)       # number of non-zero elements
torch.nonzero(tensor == 0).size(0)  # number of zero elements

判断两个张量相等

torch.allclose(tensor1, tensor2)  # float tensor
torch.equal(tensor1, tensor2)     # int tensor

张量扩展

# Expand tensor of shape 64*512 to shape 64*512*7*7.
tensor = torch.rand(64,512)
torch.reshape(tensor, (64, 512, 1, 1)).expand(64, 512, 7, 7)

矩阵乘法

# Matrix multiplcation: (m*n) * (n*p) * -> (m*p).
result = torch.mm(tensor1, tensor2)

# Batch matrix multiplication: (b*m*n) * (b*n*p) -> (b*m*p)
result = torch.bmm(tensor1, tensor2)

# Element-wise multiplication.
result = tensor1 * tensor2

计算两组数据之间的两两欧式距离

利用广播机制

dist = torch.sqrt(torch.sum((X1[:,None,:] - X2) ** 2, dim=2))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/779012.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【AutoSAR 架构介绍】

AutoSAR简介 AUTOSAR是Automotive Open System Architecture(汽车开放系统架构)的首字母缩写,是一家致力于制定汽车电子软件标准的联盟。 AUTOSAR是由全球汽车制造商、部件供应商及其他电子、半导体和软件系统公司联合建立,各成…

ubuntu 静态IP设置

ubuntu 静态IP设置: 1.输入: sudo vim /etc/netplan/01-network-manager-all.yaml Let NetworkManager manage all devices on this system network: ethernets: ens33: dhcp4: no addresses: [192.168.1.119/24] gateway4: 192.168.1.1 nameservers: …

代码随想录额外题目| 数组02 ●189旋转数组 ●724寻找数组中心索引

#189旋转数组 很快写出来但是用了个新数组&#xff0c;不好 void rotate(vector<int>& nums, int k) {vector<int> res(nums.size(),0);for(int i0;i<nums.size();i){int newiik;if(newi>nums.size()-1) newinewi%nums.size();res[newi]nums[i];}numsr…

结构型设计模式之桥接模式【设计模式系列】

系列文章目录 C技能系列 Linux通信架构系列 C高性能优化编程系列 深入理解软件架构设计系列 高级C并发线程编程 设计模式系列 期待你的关注哦&#xff01;&#xff01;&#xff01; 现在的一切都是为将来的梦想编织翅膀&#xff0c;让梦想在现实中展翅高飞。 Now everythi…

Vue3状态管理库Pinia——核心概念(Store、State、Getter、Action)

个人简介 &#x1f440;个人主页&#xff1a; 前端杂货铺 &#x1f64b;‍♂️学习方向&#xff1a; 主攻前端方向&#xff0c;正逐渐往全干发展 &#x1f4c3;个人状态&#xff1a; 研发工程师&#xff0c;现效力于中国工业软件事业 &#x1f680;人生格言&#xff1a; 积跬步…

行为型模式 - 迭代器模式

概述 定义&#xff1a; 提供一个对象来顺序访问聚合对象中的一系列数据&#xff0c;而不暴露聚合对象的内部表示。 结构 迭代器模式主要包含以下角色&#xff1a; 抽象聚合&#xff08;Aggregate&#xff09;角色&#xff1a;定义存储、添加、删除聚合元素以及创建迭代器对象…

Mind+积木编程控制小水泵给宠物喂水

前期用scratch&#xff0c;带着小朋友做了大鱼吃小鱼、桌面弹球、小学生计算器3个作品&#xff0c;小朋友收获不小。关键是小家伙感兴趣&#xff0c;做出来后给家人炫耀了一圈后&#xff0c;兴趣大增&#xff0c;嚷嚷着要做更好玩的。 最近&#xff0c;娃妈从抖音上买了个小猫喝…

JMeter 配置环境变量步骤

通过给 JMeter 配置环境变量&#xff0c;可以快捷的打开 JMeter&#xff1a; 打开终端。执行 jmeter。 配置环境变量的方法如下。 Mac 和 Linux 系统 1、在 ~/.bashrc 中加如下内容&#xff1a; export JMETER_HOMEJMeter所在目录 export PATH$JAVA_HOME/bin:$PATH:.:$JME…

pytorch安装GPU版本 (Cuda12.1)教程: Windows、Mac和Linux系统下GPU版PyTorch(CUDA 12.1)快速安装

&#x1f337;&#x1f341; 博主 libin9iOak带您 Go to New World.✨&#x1f341; &#x1f984; 个人主页——libin9iOak的博客&#x1f390; &#x1f433; 《面试题大全》 文章图文并茂&#x1f995;生动形象&#x1f996;简单易学&#xff01;欢迎大家来踩踩~&#x1f33…

【单调栈 +前缀和】AcWing 4738. 快乐子数组

原题链接 原题链接 相关算法概念介绍 前缀和&#xff08;Prefix Sum&#xff09; 前缀和是指将数组中从开头位置到当前位置的所有元素累加得到的新数组。通常&#xff0c;我们使用一个额外的数组来保存这些累加和&#xff0c;这个数组被称为前缀和数组。对于原始数组A&…

Appium+python自动化(十七)- - Monkey

1、Monkey简介 在Android的官方自动化测试领域有一只非常著名的“猴子”叫Monkey&#xff0c;这只“猴子”一旦启动&#xff0c;就会让被测的Android应用程序像猴子一样活蹦乱跳&#xff0c;到处乱跑。人们常用这只“猴子”来对被测程序进行压力测试&#xff0c;检查和评估被测…

快速排序QuickSort

目录 1.Hoare法 2.挖坑法 3.前后指针法 4.快排分治 5.关于快排 6.关于快排的优化 7.总体实现 总结&#xff1a; 快速排序是Hoare于1962年提出的一种二叉树结构的交换排序方法 其基本思想为&#xff1a;任取待排序元素序列中的某元素作为基准值&#xff0c;按照该排序码…

《5.linux驱动开发-第2部分-5.2.字符设备驱动基础》5.2.5.用开发板来调试模块

1. 首先 开发板 可以运行 Uboot 2. Ubuntu 安装好了 t f t p(启动内核zImage) 和 NFS &#xff08;挂载 根文件系统&#xff09; 3. 提前 制作好了 根文件系统&#xff08;2022年做的&#xff0c;早就忘记 怎么做了&#xff09; 4.内核 需要设置 nfs 作为根文件系统 启动…

聊聊spring-cloud的负载均衡

聊聊spring-cloud的负载均衡 1. 选择合适的负载均衡算法2. 合理设置超时时间3. 缓存服务实例列表4. 使用断路器5. 使用缓存Spring Cloud负载均衡组件对比RibbonLoadBalancerWebClient对比 总结 在微服务架构中&#xff0c;负载均衡是非常重要的一个环节&#xff0c;可以有效地提…

ES6基础知识三:对象新增了哪些扩展?

一、属性的简写 ES6中&#xff0c;当对象键名与对应值名相等的时候&#xff0c;可以进行简写 const baz {foo:foo}// 等同于 const baz {foo}方法也能够进行简写 const o {method() {return "Hello!";} };// 等同于const o {method: function() {return "…

C# List 详解四

目录 18.FindLast(Predicate) 19.FindLastIndex(Int32, Int32, Predicate) 20.FindLastIndex(Int32, Predicate) 21.FindLastIndex(Predicate) 22.ForEach(Action) 23.GetEnumerator() 24.GetHashCode() 25.GetRange(Int32, Int32) C#…

协作实现时序数据高效流转链路 | 7.20 IoTDB X RocketMQ 技术沙龙线上直播回顾

7 月 20 日&#xff0c;IoTDB X RocketMQ 技术沙龙线上直播圆满结束。工业物联网时序数据库研发商天谋科技、云原生事件流平台 Apache RocketMQ 社区的四位技术专家&#xff0c;针对实时数据接入、多样数据处理与系统的高扩展、高可靠特性的数据流转处理平台实现难点&#xff0…

计算机服务器被devos勒索病毒攻击怎么解决,数据库解密恢复方式

科学技术的发展为企业的生产运行提供了极大的便利性&#xff0c;但随之而来的网络安全也应该引起人们的重视。近期&#xff0c;我们收到很多企业的求助&#xff0c;企业的计算机服务器内的数据库被devos后缀勒索病毒攻击&#xff0c;导致企业许多工作无法正常运行。Devos后缀勒…

89、简述RabbitMQ的架构设计

简述RabbitMQ的架构设计 BrokerQueueExchangeRoutingKeyBinding信道架构设计图 Broker RabbitMQ的服务节点 Queue 队列&#xff0c;是RabbitMQ的内部对象&#xff0c;用于存储消息。RabbitMQ中消息只能存储在队列中。生产者投递消息到队列&#xff0c;消费者从队列中获取消息…

Sql Developer日期显示格式问题

sqldeveloper模式日期显示不是很美观 并且使用日期条件查询需要将月份转为中文&#xff0c;系统兼容性差 容易以前如下报错 ORA-01861: 文字与格式字符串不匹配 01861. 00000 - "literal does not match format string"-- sqldeveloper 中执行日期条件 &#xff08;…