0基础学习PyTorch——GPU上训练和推理

news2025/1/23 17:36:33

大纲

  • 创建设备
  • 训练
  • 推理
  • 总结

在《Windows Subsystem for Linux——支持cuda能力》一文中,我们让开发环境支持cuda能力。现在我们要基于《0基础学习PyTorch——时尚分类(Fashion MNIST)训练和推理》,将代码修改成支持cuda的训练和推理。

创建设备

我们首先需要依据环境是否支持cuda来创建相应设备。

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

因为我们开发环境WSL已经支持了cuda,所以此时我们创建的是GPU设备。

训练

训练的过程有两处修改:

  • 将模型实例化到GPU上。
model = GarmentClassifier().to(device) # model = GarmentClassifier()
  • 将数据移动到GPU上。
inputs, labels = data  # 获取输入数据和对应的标签
inputs, labels = inputs.to(device), labels.to(device)  # 将数据移动到GPU上

完整代码如下

from datetime import datetime
import torch
import torchvision
import torchvision.transforms as transforms
from garmentclassifier import GarmentClassifier

# 定义图像转换操作:将图像转换为张量,并进行归一化处理
transform = transforms.Compose(
    [transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))]) # 对图像的每个通道进行标准化,使得每个通道的像素值具有零均值和单位标准差

# 加载FashionMNIST训练数据集,并应用定义的图像转换操作
training_set = torchvision.datasets.FashionMNIST('./data', train=True, transform=transform)

# 创建数据加载器,用于批量加载训练数据,batch_size为4,数据顺序随机打乱
trainloader = torch.utils.data.DataLoader(training_set, batch_size=4, shuffle=True)

# 将模型移动到GPU上
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 实例化模型并移动到GPU上
model = GarmentClassifier().to(device)

# 定义损失函数为交叉熵损失
loss_fn = torch.nn.CrossEntropyLoss()
# 定义优化器为随机梯度下降(SGD),学习率为0.001,动量为0.9
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# 训练模型,训练2个epoch
for epoch in range(2):
    running_loss = 0.0  # 初始化累计损失
    # 枚举数据加载器中的数据,i是批次索引,data是当前批次的数据
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data  # 获取输入数据和对应的标签
        
        inputs, labels = inputs.to(device), labels.to(device)  # 将数据移动到GPU上
        
        optimizer.zero_grad()  # 清空梯度
        outputs = model(inputs)  # 前向传播,计算模型输出
        loss = loss_fn(outputs, labels)  # 计算损失
        loss.backward()  # 反向传播,计算梯度
        optimizer.step()  # 更新模型参数
        running_loss += loss.item()  # 累加损失

        # 每2000个批次打印一次平均损失
        if i % 2000 == 1999:
            print(f'[{epoch + 1}, {i + 1}] loss: {running_loss / 2000}')
            running_loss = 0.0  # 重置累计损失
      
# 获取当前时间戳,格式为 'YYYYMMDD_HHMMSS'
timestamp = datetime.now().strftime('%Y%m%d%H%M%S.pth')

# 定义模型保存路径,包含时间戳
model_path = 'model_{}'.format(timestamp)      

# 保存模型的状态字典到指定路径
torch.save(model.state_dict(), model_path)

在这里插入图片描述

推理

GPU上算出的模型不一定非要在GPU上推理,也可以在CPU上推理。
但是本文我们就是希望模型在GPU上推理,则可以对代码做如下修改。

  • 将模型实例化到GPU上。
model = GarmentClassifier().to(device)  # model = GarmentClassifier()
  • 将数据移动到GPU上。
image = image.to(device)  # 将图像移动到GPU上

完整代码如下

import os
import glob
import torch
import torchvision.transforms as transforms
from PIL import Image
from datetime import datetime
from garmentclassifier import GarmentClassifier

def get_latest_model_path(directory, pattern="model_*.pth"):
    # 获取目录下所有符合模式的文件
    model_files = glob.glob(os.path.join(directory, pattern))
    if not model_files:
        raise FileNotFoundError("No model files found in the directory.")
    
    # 找到最新的模型文件
    latest_model_file = max(model_files, key=os.path.getmtime)
    return latest_model_file

# 定义图像转换操作:将图像转换为张量,并进行归一化处理
transform = transforms.Compose([
    transforms.Resize((28, 28)),  # 调整图像大小为28x28
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# 将模型移动到GPU上
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 实例化模型并移动到GPU上
model = GarmentClassifier().to(device)  

# 加载训练好的模型
model_path = get_latest_model_path('./')  # 获取最新的模型文件
model.load_state_dict(torch.load(model_path, weights_only=False)) # 加载模型参数
model.eval()  # 设置模型为评估模式

# 从本地加载图像
image_path = 'shoe.jpg'  # 替换为实际的图像路径
image = Image.open(image_path).convert('L')  # 将图像转换为灰度图

# 预处理图像
image = transform(image)
image = image.unsqueeze(0)  # 增加一个批次维度
image = image.to(device)  # 将图像移动到GPU上

# 推理(预测)
with torch.no_grad():  # 在推理过程中不需要计算梯度
    outputs = model(image)  # 前向传播,计算模型输出
    _, predicted = torch.max(outputs, 1)  # 获取预测结果

# 定义类别名称
classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
           'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot')

# 打印预测结果
print(f'Predicted label: {classes[predicted.item()]}')

在这里插入图片描述

总结

  • 依据系统是否支持cuda来生成设备。
  • 模型和数据都要移动到相同的设备上。
  • 模型是由CPU还是GPU训练的,并不影响推理使用CPU还是GPU。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2172861.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

[sql-03] 求阅读至少两章的人数

准备数据 CREATE TABLE book_read (bookid varchar(150) NOT NULL COMMENT 书籍ID,username varchar(150) DEFAULT NULL COMMENT 用户名,seq varchar(150) comment 章节ID ) ENGINEInnoDB DEFAULT CHARSETutf8mb4 COMMENT 用户阅读表insert into book_read values(《太子日子》…

MindSearch 部署到Github Codespace 和 Hugging Face Space

和原有的CPU版本相比区别是把internstudio换成了github codespace。 教程是https://github.com/InternLM/Tutorial/blob/camp3/docs/L2/MindSearch/readme_github.md 复现步骤: 根据教材安装环境和创建硅基流动 API 然后启动前后端 然后按照教材部署到 Huggi…

安宝特案例 | 某知名日系汽车制造厂,借助AR实现智慧化转型

案例介绍 在全球制造业加速数字化的背景下,工厂的生产管理与设备维护效率愈发重要。 某知名日系汽车制造厂当前面临着设备的实时监控、故障维护,以及跨地域的管理协作等挑战,由于场地分散和突发状况的不可预知性,传统方式已无法…

计算机的错误计算(一百零六)

摘要 探讨含有变元负的整数次方的多项式的计算精度问题。 计算机的错误计算(一百零五)给出了一个传统多项式的错误计算案例;本节探讨含有变元负的整数次方的多项式的计算精度问题。 例1. 已知 计算 若在Python下计算,则有&…

猎板PCB大讲堂:PCB谐振效应及其对设计的影响

在PCB设计中,谐振效应是一个不可忽视的问题,它可能导致信号完整性问题、电源分配系统(PDS)工作异常,甚至成为EMI辐射源。以下是关于PCB谐振效应的一些详细信息: 1. 谐振产生的原因: - PCB中…

d2l | 目标检测数据集:RuntimeError: No such operator image::read_file

目录 1 存在的问题2 可能的解决方案3 最终的解决方案3.1 方案一(我已弃用)3.2 方案二(基于方案一)3.3 方案三(基于方案一) 1 存在的问题 李沐老师提供的读取香蕉数据集的函数如下: def…

Ubuntu系统设置bond双网卡

这里我的服务器是Ubuntu 22.04.3 LTS,是高阶版本,设置网卡需要通过netplan 根据你的Ubuntu版本(如使用Netplan或/etc/network/interfaces),选择相应的配置方法。 我这边以root用户登录进服务器,就不需要普通用户每次在命令前添加sudo 1.通常/etc/netplan下配置文件名形…

IDEA开发SpringBoot项目基础入门教程。包括Spring Boot简介、IDEA创建相关工程及工程结构介绍、书写配置文件、Bean对象管理等内容

文章目录 0. 关于本文1. 概述1.1 Spring简介1.2 Spring Boot简介1.3 传统的开发方式1.3.1 简述1.3.2 缺点 1.4 Spring Boot的优点 2. 创建一个简单的Spring Boot应用程序2.1 在IDEA创建项目2.2 pom配置文件内容2.3 启动类2.4 创建Controller 3. 从Maven工程创建Spring Boot工程…

数据结构~二叉搜索树

文章目录 一、二叉搜索树的概念二、二叉搜索树的结构二叉搜索树的性能分析二叉搜索树的插入二叉搜索树的查找二叉搜索树的删除 三、二叉搜索树key和key/value使用场景四、二叉搜索树的练习将二叉搜索树就地转化为已排序的双向循环链表从前序与中序遍历序列构造二叉树二叉树的前…

jmeter-请求参数加密-MD5加密

方法1 :使用jmeter自带的函数助手digest Tool(工具)---Function Helper Dialog(函数助手对话框) 第一个参数是要md5加密的值,第二个参数是保存加密后值的变量 ( 此处变量是从txt文件导入的,所以使用的是${wd} ) …

excel统计分析(1):列联表分析与卡方检验

列联表:用于展示两个或多个分类变量之间频数关系的表格。——常用于描述性分析卡方检验:通过实际频数和期望频数(零假设为真情况下的频数),反映了观察频数与期望频数之间的差异程度,来评估两个变量是否独立…

Metasploit渗透测试之服务端漏洞利用

简介 在之前的文章中,我们学习了目标的IP地址,端口,服务,操作系统等信息的收集。信息收集过程中最大的收获是服务器或系统的操作系统信息。这些信息对后续的渗透目标机器非常有用,因为我们可以快速查找系统上运行的服…

System Timer (STM)

文章目录 1. 介绍2. 功能特性3. 应用场景4. 功能介绍4.1 TIME0 ~TIME6计数器精度与定时范围4.2 比较器工作原理4.3 中断处理 5. Ifx Demo5.1 STM_Interrupt_1_KIT_TC277_TFT5.2 STM_System_Time_1_KIT_TC275_LK5.3 SMU_Reset_Alarm_1_KIT_TC275_LK 1. 介绍 Ifx TC37x拥有3个自…

前端大模型入门:使用Transformers.js实现纯网页版RAG(一)

我将使用两篇文章的篇幅,教大家如何实现一个在网页中运行的RAG系统。本文将其前一半功能:深度搜索。 通过这篇文章,你可以了解如何在网页中利用模型实现文本相似度计算、问答匹配功能,所有的推理都在浏览器端本地执行,…

C语言-IO

一,阻塞IO与非阻塞IO 简介: IO的本质是基于操作系统接口来控制底层的硬件之间数据传输,并且在操作系统中实现了多种不同的 IO 方式(模型),比较常见的有下列三种 阻塞型IO模型 非阻塞型IO模型 多路复用IO模型 在 C 语言中&#…

牛客SQL练习详解 02:条件查询

牛客SQL练习详解 02:条件查询 1、基础排序sql36 查找后排序sql37 查找后多列排序sql38 查找后降序排列 2、基础操作符sql6 查找学校时北大的学生信息sql7 查找年龄大于24岁的用户信息sql8 查找某个年龄段的用户信息sql9 查找chuchu 3、高级操作符sql11 高级操作符练…

认知杂谈91《菜鸟的自我修炼:减少过度干预》

内容摘要:          在投资和生活中,动作过多往往因情绪波动和缺乏计划而引发亏损。历史上的安史之乱和现代投资中的频繁交易都是例证。要管理情绪,首先要认识自己的情绪模式,然后改变消极的思考方式,并通过合…

『USB3.0Cypress』QT基于cyusb_linux_1.0.5开发上位机

文章目录 1.CyUSB Suite2.搭建开发环境3.Cyusb的应用4.疑问解决5.传送门1.CyUSB Suite CyUSB Suite for Linux是一个围绕现有开源用户空间USB库libusb的wrapper。CyUSB套件通过围绕libusb的简化包装器以及在下载固件后提供用于测试外围设备的基础设施,让您快速入门。换句话说…

1.6 物理层

欢迎大家订阅【计算机网络】学习专栏,开启你的计算机网络学习之旅! 文章目录 前言1 物理层的基本概念1.1 定义1.2 作用1.3 物理层的主要任务 2 数据通信的基础知识2.1 常用术语2.2 信号2.3 码元2.4 信道2.5 数据通信系统模型 3 信道的极限容量3.1 基本术…

LabVIEW提高开发效率技巧----合理管理程序架构

在LabVIEW开发中,合理管理程序架构是保持项目可维护性和扩展性的关键。随着项目复杂度的增加,良好的架构设计可以避免代码混乱,并且便于后期的修改和扩展。以下是两种常见且有效的架构管理方式: 1. 面向对象编程(OOP&a…