pytorch:to()、device()、cuda()将Tensor或模型移动到指定的设备上

news2024/9/28 21:18:33

将Tensor或模型移动到指定的设备上:tensor.to(‘cuda:0’)

  • 最开始读取数据时的tensor变量copy一份到device所指定的GPU上去,之后的运算都在GPU上进行
  • 在做高维特征运算的时候,采用GPU无疑是比用CPU效率更高,如果两个数据中一个加了.cuda()或者.to(device),而另外一个没有加,就会造成类型不匹配而报错。

1. Tensor.to(device)

功能:将Tensor移动到指定的设备上。

以下代码将Tensor移动到GPU上:
device = torch.device(“cuda:0” if torch.cuda.is_available() else “cpu”)

1.1 修改dtype

a = tensor.to(torch.float64).

  • tensor.dtype : torch.float32
  • a.dtype : torch.float64

1.2 改变device:用字符串形式给出

a = tensor.to('cuda:0').

  • tensor.device : device(type=‘cpu’)
  • a.device : device(type=‘cuda’, index=0)

1.3 改变device:用torch.device给出

cuda0 = torch.device('cuda:0') .
b = tensor.to(cuda0) .

  • tensor.device : device(type=‘cpu’)
  • b.device : device(type=‘cuda’, index=0)

1.4 同时改变device和dtype

c = tensor.to('cuda:0',torch.float64) .
other = torch.randn((), dtype=torch.float64, device=cuda0) .
d = tensor.to(other, non_blocking=True) .

  • tensor.device:device(type=‘cpu’)
  • d :tensor([], device=‘cuda:0’, dtype=torch.float64))

2. model.to(device)

功能:将模型移动到指定的设备上。

使用以下代码将模型移动到GPU上:

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 1)
    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x
        
model = Net()
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model.to(device)

如果有多个GPU,使用以下方法:

if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model,device_ids=[0,1,2])
  
model.to(device)

将由GPU保存的模型加载到GPU上。

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)

将由GPU保存的模型加载到CPU上。
torch.load()函数中的map_location参数设置为torch.device('cpu')

device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))
model.to(device)

将由CPU保存的模型加载到GPU上。
torch.load()函数中的map_location参数设置为torch.device('cuda')

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # Choose whatever GPU device number you want
model.to(device)

参考:PyTorch之Tensor.to(device)和model.to(device)

3. .to(device) 和.cuda()的区别

  • .to(device) 可以指定CPU 或者GPU
  • .cuda() 只能指定GPU

图参考:pytorch中.to(device) 和.cuda()的区别
在这里插入图片描述

官方文档:CUDA SEMANTICS

with torch.cuda.device(1):
    # allocates a tensor on GPU 1
    a = torch.tensor([1., 2.], device=cuda)
 
    # transfers a tensor from CPU to GPU 1
    b = torch.tensor([1., 2.]).cuda()
    # a.device and b.device are device(type='cuda', index=1)
 
    # You can also use ``Tensor.to`` to transfer a tensor:
    b2 = torch.tensor([1., 2.]).to(device=cuda)
    # b.device and b2.device are device(type='cuda', index=1)
  • 两个方法都可以达到同样的效果,在pytorch中,即使是有GPU的机器,它也不会自动使用GPU,而是需要在程序中显示指定。
  • 调用model.cuda(),可以将模型加载到GPU上去。这种方法不被提倡,而建议使用model.to(device)的方式,这样可以显示指定需要使用的计算资源,特别是有多个GPU的情况下。

4. CUDA相关信息查询

import torch
print('CUDA版本:',torch.version.cuda)
print('Pytorch版本:',torch.__version__)
print('显卡是否可用:','可用' if(torch.cuda.is_available()) else '不可用')
print('显卡数量:',torch.cuda.device_count())
print('当前显卡的CUDA算力:',torch.cuda.get_device_capability(0))
print('当前显卡型号:',torch.cuda.get_device_name(0))
>>>
CUDA版本: 11.7
Pytorch版本: 1.13.1
显卡是否可用: 可用
显卡数量: 1
当前显卡的CUDA算力: (8, 6)
当前显卡型号: NVIDIA GeForce RTX 3060 Laptop GPU

参考:https://blog.csdn.net/weixin_43845386/article/details/131723010

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1309696.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Vue3安装使用Mock.js--解决跨域

首先使用axios发送请求到模拟服务器上,再将mock.js模拟服务器数据返回给客户端。打包工具使用的是vite。 1.安装 npm i axios -S npm i mockjs --save-dev npm i vite-plugin-mock --save-dev 2.在vite.config.js文件中配置vite-plugin-mock等消息 import { viteMo…

数据库——存储过程及游标

智能2112杨阳 一、目的与要求: 1、掌握存储过程的工作原理、定义及操作方法 2、掌握函数的工作原理、定义及操作方法 3、掌握游标的工作原理、定义及操作方法 二、内容: 1. 创建存储过程,用来自动统计给定订单号的订单总金额 源码&…

.NET 反射优化的经验分享

比如针对 GetCustomAttributes 通过反射获取属性的优化,以下例子 // dotnet run -c Release -f net7.0 --filter "*" --runtimes net7.0 net8.0public class Tests{public object[] GetCustomAttributes() => typeof(C).GetCustomAttributes(typeof(MyAttribute…

Windows server 2016 如何禁止系统自动更新

1.打开“运行”,输入cmd,点击“确定”。 2.输入sconfig,然后按回车键。 3.输入5,然后按回车键。 4.示例需要设置为手动更新,即输入M,然后按回车键。 5.出现提示信息,点击“确定”即可。

第四十一篇:移动端调试工具

1.下载工具 npm install vconsole 2.在main.js里全局引用 > import Vconsole from vconsole > new Vconsole()

Python面向对象三大特征(python系列20)

1.封装 定义: 数据角度:将基本数据类型复合成一个自定义类型。 作用:可读性更高,将数据与对数据的操作相关联。 行为角度:对类外提供必要的功能,隐藏实现的细节 作用:让调用者不必了解实现代码&…

ABAP 明细alv跳转到汇总alv一般模板

需求描述:做开发的同时,经常会有遇到,根据明细表进行逻辑汇总,在两个屏幕进行跳转,然后按钮还要做功能的情况,我这边记录一下最简单点模板,给新手可以直接复制使用的。 一、源代码 TYPE-POOLS…

基于JAVAEE技术校园车辆管理系统论文

摘 要 现代经济快节奏发展以及不断完善升级的信息化技术,让传统数据信息的管理升级为软件存储,归纳,集中处理数据信息的管理方式。本校园车辆管理系统就是在这样的大环境下诞生,其可以帮助管理者在短时间内处理完毕庞大的数据信息…

Mozilla 推出 Solo:借助 AI 帮助零编程用户创建网站

Mozilla 近日推出名为 Solo 的全新项目,面向没有任何编程经验的用户,通过融入 AI 能力,所创建的网站可以媲美专业开发者的开发效果。 Mozilla 表示该项目主要针对中小型企业、个体户,在官方演示中,用户只需要输入文本、…

Linux完成mysql数据库的备份与恢复

背景: 在进行数据报表的测试过程中,为了让我们的测试数据更加真实,因此我们需要同步生产数据到测试环境。方式有很多种,我这里介绍的是通过Linux完成数据同步。 备份数据: 执行命令:mysqldump -uxxx -pxxx…

HAAS 哈斯机床 读写刀补数据

哈斯机床不管是串口机床还是网口机床 都提供了Q命令 可以使用Q命令 进行刀具补偿的读取和写入 最多支持200把刀的 读取和写入

外贸SOHO建站怎么做?海洋建站方法策略?

外贸SOHO建站多少钱?外贸自助建站系统有哪些? 随着全球化的加速发展,外贸SOHO已经成为越来越多创业者的选择。然而,要想在竞争激烈的外贸市场中脱颖而出,一个专业的外贸网站是必不可少的。接下来海洋建站将探讨外贸SO…

jsp文件引用的css修改后刷新不生效问题

问题 在对 JavaWeb 项目修改的过程中,发现修改了 jsp 文件引入的 css 文件的代码后页面的样式没有更新的问题。 原因 导致这个问题的原因可能是因为浏览器缓存的问题。 解决方法 下面介绍两种解决方法,供大家参考: 1、给 link 标签的 c…

图文并茂讲VLAN,一遍就能理解

图文并茂讲VLAN,一遍就能理解 弱电行业圈2019-03-19 10:12 vlan的应用在网络项目中是非常广泛的,基本上大部分的项目都需要划分vlan,前几天我们讲到vlan的配置,有朋友就提到有没有更基础一些的内容,今天我们就从基础…

【LeetCode刷题】--172.阶乘后的零

172.阶乘后的零 方法&#xff1a; class Solution {public int trailingZeroes(int n) {int ans 0;for(int i 5;i<n;i5){for(int x i; x % 50; x/5){ans;}}return ans;} }进一步优化&#xff1a; class Solution {public int trailingZeroes(int n) {int ans 0;while (n…

今日开幕!飞凌嵌入式受邀参加2023年瑞萨技术交流日全国巡回展

来源&#xff1a;飞凌嵌入式官网 2023年瑞萨技术交流日全国巡回展&#xff08;广州站&#xff09;今日开幕&#xff0c;飞凌嵌入式再次受邀参加&#xff0c;并与来自新能源、自动化、工业物联网以及人工智能等领域的精英们共同探讨前沿技术。 在今日的巡展现场&#xff0c;飞凌…

绝地求生:PGC2023胜者组D2下半场:17天霸成功晋级,TL、NH跌入最后机会组

第四场 第一名&#xff1a;LGC 第二名&#xff1a;T5 第三名&#xff1a;FaZe 17仅剩两人&#xff0c;T5踩住高点&#xff0c;sujiu前顶时被T5架枪位击倒&#xff0c;小鬼的盾牌没能挡住对方的雷遗憾第五出局。然而T5自己也进圈不易&#xff0c;仅剩两人。 LG独狼卡住T5却忽…

STM32-02-STM32基础知识

文章目录 STM32基础知识1. STM32F103系统架构2. STM32寻址范围3. 存储器映射4. 寄存器映射 STM32基础知识 1. STM32F103系统架构 STM32F103 STM32F103是ST公司基于ARM授权Cortex M3内核而设计的一款芯片&#xff0c;而Cortex M内核使用的是ARM v7-M架构&#xff0c;是为了替代…

新生儿智力检测的关键:培养潜能、关注发展

引言&#xff1a; 新生儿期是智力发展的关键时期&#xff0c;而科学的智力检测可以帮助父母更好地了解宝宝的认知水平和发展潜力。然而&#xff0c;在进行新生儿智力检测时&#xff0c;需要特别注意一些关键事项&#xff0c;以确保测试的准确性和对宝宝的尊重。本文将深入探讨…

【ECharts】从零实现echarts地图完整代码(纯前端,包含地图资源)

最终效果 标题环境搭建 这里忽略创建vue项目的操作过程&#xff0c;请自行搭建 vue2 项目、less 环境 安装下载 echarts 这里我们选择npm下载 npm install echarts安装成功后&#xff0c;在 main.js 中把echarts配置到this上 // 引入 echarts import * as Echarts from ech…