pytorch将数据与模型都放到GPU上训练

news2024/12/28 9:56:18

默认是CPU,如果想要用GPU需要:

  1. 安装配置cuda,然后更新/下载支持gpu版本的pytorch,可以参考:https://blog.csdn.net/weixin_35757704/article/details/124315569
  2. 设置device:
    device = torch.device('cuda' if torch.cuda.is_available else 'cpu')
    
    然后将数据与模型后面都额外加上.to(device)即可

示例程序

import torch
import torch.nn as nn


# 一个简单的模型
class LinearRegressionModel(nn.Module):
    def __init__(self, input_shape, output_shape):
        super(LinearRegressionModel, self).__init__()
        self.linear = nn.Linear(input_shape, output_shape)

    def forward(self, x):
        out = self.linear(x)
        return out


def main():
    x_train = torch.randn(100, 4)  # 生成训练特征
    y_train = torch.randn(100, 1)  # 生成label
    model = LinearRegressionModel(x_train.shape[1], 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # 优化函数
    criterion = nn.MSELoss()  # 损失函数
    for epoch in range(100):
        optimizer.zero_grad()
        outputs = model(x_train)
        loss = criterion(outputs, y_train)
        loss.backward()
        optimizer.step()


if __name__ == '__main__':
    main()

修改为GPU版本:

import torch
import torch.nn as nn


# 一个简单的模型
class LinearRegressionModel(nn.Module):
    def __init__(self, input_shape, output_shape):
        super(LinearRegressionModel, self).__init__()
        self.linear = nn.Linear(input_shape, output_shape)

    def forward(self, x):
        out = self.linear(x)
        return out


def main():
    # 1. 设置device
    device = torch.device('cuda' if torch.cuda.is_available else 'cpu')
    # 2. 数据与模型后都加 .to(device) 即可
    x_train = torch.randn(100, 4).to(device)  # 生成训练特征
    y_train = torch.randn(100, 1).to(device)  # 生成label
    model = LinearRegressionModel(x_train.shape[1], 1).to(device)  # next(transformer.parameters()).device

    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # 优化函数
    criterion = nn.MSELoss()  # 损失函数
    for epoch in range(100):
        optimizer.zero_grad()
        outputs = model(x_train)
        loss = criterion(outputs, y_train)
        loss.backward()
        optimizer.step()


if __name__ == '__main__':
    main()

修改后:

  1. 查看变量的位置:可以使用x_train.device查看tensor变量的位置
  2. 查看模型的位置:可以使用next(model.parameters()).device查看模型的位置

注意:不在同一个位置上的变量之间无法计算,模型无法使用不在同一个位置的数据

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2266865.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

leetcode-42.接雨水-day19

思路:将一整个区域分成若干个小区域看,高低起伏 1.记录每个板子上的雨水数量,最后相加求和。h板高 2.-->由于每个板子区间能装多少水取决于他的最大前缀板高和最大后缀板高, 3. 然后根据短板效应,则每个板子区间最多…

Postman接口测试02|接口用例设计

目录 六、接口用例设计 1、接口测试的测试点(测试维度) 1️⃣功能测试 2️⃣性能测试 3️⃣安全测试 2、设计方法与思路 3、单接口测试用例 4、业务场景测试用例 1️⃣分析测试点 2️⃣添加员工 3️⃣查询员工、修改员工 4️⃣删除员工、查询…

自定义kali:增加60+常用渗透工具,哥斯拉特战版,cs魔改应有尽有,菜单栏启动

前言: 集合了六十多个工具,有师傅说需要,特搞来,我是脚本小子我爱用 介绍: 主要在菜单增加了非常多别人现成的工具,工具名单: 信息收集: 密探渗透测试工具 水泽 ehole 灯塔 …

数据结构(Java)——链表

1.概念及结构 链表是一种 物理存储结构上非连续 存储结构,数据元素的 逻辑顺序 是通过链表中的 引用链接 次序实现的 。 2.分类 链表的结构非常多样,以下情况组合起来就有 8 种链表结构: (1)单向或者双向 (…

Linux 文件的特殊权限—Sticky Bit(SBIT)权限

本文为Ubuntu Linux操作系统- 第十九期~~ 其他特殊权限: 【SUID 权限】和【SGID 权限】 更多Linux 相关内容请点击👉【Linux专栏】~ 主页:【练小杰的CSDN】 文章目录 Sticky(SBIT)权限基本概念Sticky Bit 的表示方式举例 设置和取…

PPT画图——如何设置导致图片为600dpi

winr,输入regedit打开注册表 按路径找,HKEY_CURRENT_USER\Software\Microsoft\Office\XX.0\PowerPoint\Options(xx为版本号,16.0 or 15.0或则其他)。名称命名:ExportBitmapResolution 保存即可,…

小米汽车加速出海,官网建设引领海外市场布局!

面对国内市场的饱和态势,中国企业出海步伐纷纷加速,小米也是其中的一员。Canalys数据显示,2024年第三季度,小米以13.8%的市场份额占比,实现了连续17个季度位居全球前三的成绩。 据“36 氪汽车”报道,小米汽…

Cocos Creator 试玩广告开发 第二弹

上一篇的项目是2d的,现在谈谈对于3d试玩项目的一些经历。 相对于2d来说,3d的项目更接近于Unity的开发,但是也有很多不一样的地方,具体的也可以参考Cocos给他官方示例。 Unity 开发者入门 Cocos Creator 快速指南 | Cocos Creator…

CTFshow—爆破

Web21 直接访问页面的话会弹窗需要输入密码验证,抓个包看看,发现是Authorization认证,Authorization请求头用于验证是否有从服务器访问所需数据的权限。 把Authorization后面的数据进行base64解码,就是我们刚刚输入的账号密码。 …

docker-开源nocodb,使用已有数据库

使用已有数据库 创建本地数据库 数据库:nocodb 用户:nocodb 密码:xxxxxx修改docker-compose.yml 默认网关的 IP 地址是 172.17.0.1(适用于 bridge 网络模式)version: "2.1" services:nocodb:environment:…

UGUI简单动画制作

一、最终效果 UI简单动画制作 二、制作过程 1、打开动画制作窗口 2、新建一个动画 3、给一个对象制作动画 4、创建动画控制器进行不同动画变换控制 5、书写脚本,通过按钮来进行不同动画切换 using System.Collections; using System.Collections.Generic; using U…

[SAP ABAP] 程序备份

备份当前程序到本地的方式如下: 1.复制粘贴 Ctrl A 、Ctrl V 2.【实用程序】|【更多实用程序】|【上载/下载】|【下载】 ​ 3.快捷键,支持多种格式导出(.abap .html .pdf 等) 在事务码SE38(ABAP编辑器)屏幕右下角,点击【Options选项】图…

代码随想录Day51 99. 岛屿数量,99. 岛屿数量,100. 岛屿的最大面积。

1.岛屿数量深搜 卡码网题目链接(ACM模式)(opens new window) 题目描述: 给定一个由 1(陆地)和 0(水)组成的矩阵,你需要计算岛屿的数量。岛屿由水平方向或垂直方向上相邻的陆地连接…

【漏洞复现】CVE-2022-41678 Arbitrary JMX Service Invocation with Web Interface

漏洞信息 NVD - cve-2022-41678 Apache ActiveMQ prior to 5.16.5, 5.17.3, there is a authenticated RCE exists in the Jolokia /api/jolokia. 组件影响版本安全版本Apache:ActiveMQ< 5.16.6> 5.16.6Apache:ActiveMQ5.17.0 - 5.17.4> 5.17.4&#xff0c;> 6.…

Bash 脚本教程

注&#xff1a;本文为 “Bash 脚本编写” 相关文章合辑。 BASH 脚本编写教程 as good as well于 2017-08-04 22:04:28 发布 这里有个老 American 写的 BASH 脚本编写教程&#xff0c;非常不错&#xff0c;至少没接触过 BASH 的也能看懂&#xff01; 建立一个脚本 Linux 中有…

操作系统(26)数据一致性控制

前言 操作系统数据一致性控制是确保在计算机系统中&#xff0c;数据在不同的操作和处理过程中始终保持正确和完整的一种机制。 一、数据一致性的重要性 在当今数字化的时代&#xff0c;操作系统作为计算机系统的核心&#xff0c;负责管理和协调各种资源&#xff0c;以确保计算机…

48页PPT|2024智慧仓储解决方案解读

本文概述了智慧物流仓储建设方案的行业洞察、业务蓝图及建设方案。首先&#xff0c;从政策层面分析了2012年至2020年间国家发布的促进仓储业、物流业转型升级的政策&#xff0c;这些政策强调了自动化、标准化、信息化水平的提升&#xff0c;以及智能化立体仓库的建设&#xff0…

Windows和Linux安全配置和加固

一.A模块基础设施设置/安全加固 A-1.登录加固 1.密码策略 a.最小密码长度不少于8个字符&#xff0c;将密码长度最小值的属性配置界面截图。 练习用的WindowsServer2008,系统左下角开始 > 管理工具 > 本地安全策略 > 账户策略 > 密码策略 > 密码最小长度&#…

EleutherAI/pythia-70m

EleutherAI/pythia-70m” 是由 EleutherAI 开发的一个小型开源语言模型&#xff0c;它是 Pythia Scaling Suite 系列中参数量最小的模型&#xff0c;拥有大约 7000 万个参数。这个模型主要旨在促进对语言模型可解释性的研究&#xff1b; Pythia Scaling Suite是为促进可解释性…

Linux系统编程——详解页表

目录 一、前言 二、深入理解页表 三、页表的实际组成 四、总结&#xff1a; 一、前言 页表是我们之前在讲到程序地址空间的时候说到的&#xff0c;它是物理内存到进程程序地址空间的一个桥梁&#xff0c;通过它物理内存的数据和代码才能映射到进程的程序地址空间中&#xff…