Pytorch学习笔记——在GPU上进行训练

news2024/11/13 14:44:41

文章目录

      • 1. 环境准备
      • 2. 导入必要库
      • 3. 加载数据集
      • 4. 定义简单的神经网络模型
      • 5. 检查和设置GPU设备
      • 6. 定义损失函数和优化器
      • 7. 训练模型
      • 8. 全部代码展示及运行结果

1. 环境准备

首先,确保PyTorch已经安装,且CUDA(NVIDIA的并行计算平台和编程模型)支持已经正确配置。可以通过以下代码检查CUDA是否可用:

print(torch.cuda.is_available())  # 如果返回True,则CUDA可用

配置PyTorch环境和CUDA支持,可以参考我写的这篇博客
Pytorch学习笔记——环境配置安装

2. 导入必要库

import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Linear, Flatten, Sequential
from torch.utils.data import DataLoader
  • torch 是PyTorch的核心库。
  • torchvision 提供了用于计算机视觉任务的工具,包括数据集和变换。
  • nn 包含了构建神经网络所需的各种模块。
  • DataLoader 用于加载数据集并进行批处理。

3. 加载数据集

使用CIFAR-10数据集进行训练,它是一个常用的小型图像数据集。加载数据集并创建数据加载器。

# 加载数据集
dataset = torchvision.datasets.CIFAR10(root="data1", train=False, transform=torchvision.transforms.ToTensor(), download=True)
# 创建数据加载器
dataloader = DataLoader(dataset, batch_size=64)
  • torchvision.datasets.CIFAR10:下载并加载CIFAR-10数据集。
  • DataLoader:将数据集划分为小批次,并进行数据加载。

4. 定义简单的神经网络模型

定义一个简单的卷积神经网络(CNN)模型:

self.model1 = Sequential(
            Conv2d(3, 32, 5, padding=2),  # 第一次卷积
            MaxPool2d(2),  # 第一次最大池化
            Conv2d(32, 32, 5, padding=2),  # 第二次卷积
            MaxPool2d(2),  # 第二次最大池化
            Conv2d(32, 64, 5, padding=2),  # 第三次卷积
            MaxPool2d(2),  # 第三次最大池化
            Flatten(),    # 展平层
            Linear(1024, 64),  # 第一个全连接层
            Linear(64, 10),  # 第二个全连接层
        )
def forward(self, x):
        x = self.model1(x)
        return x
  • Conv2d:二维卷积层,用于提取图像特征。
  • MaxPool2d:最大池化层,用于下采样。
  • Flatten:将多维输入展平为一维,用于全连接层的输入。
  • Linear:全连接层,用于分类任务。

5. 检查和设置GPU设备

需要检查是否有可用的GPU,并将模型和数据移动到GPU上。

# 检查是否有可用的GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 将模型和数据转移到GPU上
mynn = NN().to(device)
print(mynn)
  • torch.device("cuda"):如果CUDA可用,则使用GPU;否则使用CPU。
  • to(device):将模型转移到指定的设备(CPU或GPU)。

6. 定义损失函数和优化器

定义损失函数和优化器来训练模型:

# 定义损失函数
loss = nn.CrossEntropyLoss().to(device)

# 定义优化器
optim = torch.optim.SGD(mynn.parameters(), lr=0.01)
  • nn.CrossEntropyLoss:适用于分类问题的损失函数。
  • torch.optim.SGD:随机梯度下降优化器。

7. 训练模型

通过多个epoch对模型进行训练。在每个epoch中,进行前向传播、计算损失、反向传播和参数更新:

# 多轮学习 0 - 20 20轮
for epoch in range(20):
    running_loss = 0.0
    for data in dataloader:
        # 确保数据也转移到GPU上
        imgs, targets = data[0].to(device), data[1].to(device)

        optim.zero_grad()  # 清零梯度缓存
        outputs = mynn(imgs)  # 前向传播
        loss_value = loss(outputs, targets)  # 计算损失
        loss_value.backward()  # 反向传播,计算梯度
        optim.step()  # 根据梯度更新权重

        running_loss += loss_value.item()  # 累加损失值

    print(f"Epoch {epoch + 1}, Loss: {running_loss / len(dataloader)}")
    print("------------------------------")
  • optim.zero_grad():清零之前计算的梯度。
  • outputs = mynn(imgs):进行前向传播。
  • loss_value.backward():进行反向传播,计算梯度。
  • optim.step():根据计算的梯度更新权重。
  • running_loss:累加损失值以计算平均损失。

8. 全部代码展示及运行结果

# -*- coding: utf-8 -*-
# @Author: kk
import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Linear, Flatten, Sequential
from torch.utils.data import DataLoader

# 加载数据集
dataset = torchvision.datasets.CIFAR10(root="data1", train=False, transform=torchvision.transforms.ToTensor(), download=True)
# loader加载
dataloader = DataLoader(dataset, batch_size=64)

# 网络
class NN(nn.Module):
    def __init__(self):
        super(NN, self).__init__()
        self.model1 = Sequential(
            Conv2d(3, 32, 5, padding=2),  # 第一次卷积
            MaxPool2d(2),  # 第一次最大池化
            Conv2d(32, 32, 5, padding=2),  # 第二次卷积
            MaxPool2d(2),  # 第二次最大池化
            Conv2d(32, 64, 5, padding=2),  # 第三次卷积
            MaxPool2d(2),  # 第三次最大池化
            Flatten(),    # 展平层
            Linear(1024, 64),  # 第一个全连接层
            Linear(64, 10),  # 第二个全连接层
        )

    def forward(self, x):
        x = self.model1(x)
        return x

# 检查是否有可用的GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 将模型和数据转移到GPU上
mynn = NN().to(device)
print(mynn)
loss = nn.CrossEntropyLoss().to(device)

# 优化器
optim = torch.optim.SGD(mynn.parameters(), lr=0.01)

# 多轮学习  0 - 20  20轮
for epoch in range(20):
    running_loss = 0.0
    for data in dataloader:
        # 确保数据也转移到GPU上
        imgs, targets = data[0].to(device), data[1].to(device)

        optim.zero_grad()  # 清零梯度缓存
        outputs = mynn(imgs)  # 前向传播
        loss_value = loss(outputs, targets)  # 计算损失
        loss_value.backward()  # 反向传播,计算梯度
        optim.step()  # 根据梯度更新权重

        running_loss += loss_value.item()  # 累加损失值

    print(f"Epoch {epoch + 1}, Loss: {running_loss / len(dataloader)}")
    print("------------------------------")

运行结果如下,发现在每一轮过后,Loss在逐渐减小:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1953239.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

go-kratos 学习笔记(6) 数据库gorm使用

数据库是项目的核心,数据库的链接数据是data层的操作,选择了比较简单好用的gorm作为数据库的工具;之前是PHP开发,各种框架都是orm的操作;gorm还是很相似的,使用起来比较顺手 go-kratos官网的实例是ent&…

鸿蒙UI系统组件10——菜单(Menu)

果你也对鸿蒙开发感兴趣,加入“Harmony自习室”吧!扫描下面名片,关注公众号。 Menu是菜单接口,一般用于鼠标右键弹窗、点击弹窗等。 1、创建默认样式的菜单 菜单需要调用bindMenu接口来实现。bindMenu响应绑定组件的点击事件&am…

ModuleNotFoundError: No module named ‘py3langid‘ 以及如何将包安在虚拟环境下

前提:已经安装过改包(pip install py3langid),但仍报错 原因:安装在其他目录下了 解决办法: 1、再次在终端输入pip install py3langid 显示安装位置 Requirement already satisfied: py3langid in c:\…

css大屏设置中间元素四周渐变透明效果

css大屏设置中间元素四周渐变透明效果 四周透明效果: // 设置蒙版上下左右渐变显示mask-image: linear-gradient(to right, rgba(0, 0, 0, 0) 0%, rgba(0, 0, 0, 1) 10%, rgba(0, 0, 0, 1) 90%, rgba(0, 0, 0, 0) 100%),linear-gradient(to bottom, rgba(0, 0, 0…

性能测试的指标及流程

性能测试指标 相应时间 并发数 吞吐量: 点击数 错误率 资源使用率 所有的东西是存在磁盘里的,在代码运行的时候会将磁盘的东西读取到内存里,磁盘IO和网络都是衡量速度,在任务管理器可查看资源使用率 题: 答案&#xf…

创建vue3项目,以及使用示例

1.在根目录下cmd:vue create myobj(没有切换淘宝镜像记得切换,这样创建项目运行快) 2. 3.(按空格键选中,选好回撤就到下一步了) 4. 5. 6. 7. 8. 9. 10. 11. 12. 13.然后输入执行以下两步就已经运行项目了 以…

Java算法之递归算法-如何计算阶乘的值

上一篇学了递归之后,练习一下递归算法。 题目:使用递归算法计算阶乘的值,也就是5!5*4*3*2*1,直接使用循环是非常简单的,这边练习一下递归算法。 先写一下两个条件 基线条件:等于1的时候返回1…

windows下实现mongodb备份还原

添加环境变量 把mongodb安装目录下的bin路径添加到环境变量的path路径: 备份库 打开CMD,执行以下命令: mongodump -u test -p test -d test -o D://backup_mongodb//20220706 –gzip 参数说明: -u 用户名 -p 密码 -d 需要备份的库名称…

GraphHopper路劲规划导航(Android源码调试运行)

本文主要记录在运行graphhopper安卓版路径规划导航源码的步骤和遇到的问题。 成功运行了程序,但是路劲规划一直不成功,问题一开始是服务地址,后来又是key的问题,在这个项目中涉及到了graphhopper、mapbox、mapilion的key&#xff…

map、foreach、filter这些方法你还不知道什么时候该用哪个吗?那就看过来

forEach:‌主要用于遍历数组并对每个元素执行某种操作,‌通常用于改变当前数组里的值。‌它不会返回新数组,‌而是直接在原数组上进行操作。‌forEach方法不支持return、‌break、‌continue等语句,‌因为这些语句在forEach中不会…

多线程实例-线程池

线程池,就是把线程提前从系统中申请好,放到一个地方,后面需要使用线程的时候,直接从这个地方取,而不是从系统重新申请,线程用完之后也回到刚才的地方。 线程池的优点:降低线程创建和销毁的开销…

MICA:面向复杂嵌入式系统的混合关键性部署框架

背景 在嵌入式场景中,虽然 Linux 已经得到了广泛应用,但并不能覆盖所有需求,例如高实时、高可靠、高安全的场合。这些场合往往是实时操作系统的用武之地。有些应用场景既需要 Linux 的管理能力、丰富的生态,又需要实时操作系统的高…

戴尔vostro15-3568硬盘升级+系统重装

硬盘升级 原2.5机械硬盘换成了SATA2.5的固态硬盘 按F2进入bios后看到的电池信息如下: 需要重新换一个电池 系统重装 步骤如下 1.U盘需要格式化成 NTFS 类型的,并且从官网下载后介质 2.BV1z3411K7AD b站这个视频前三步可以参考设置

八、桥接模式

文章目录 1 基本介绍2 案例2.1 OperatingSystem 抽象类2.2 LinuxOS 类2.3 WindowsOS 类2.4 FileOperation 类2.5 FileAppender 类2.6 FileReplicator 类2.7 Client 类2.8 Client 类运行结果2.9 总结 3 各角色之间的关系3.1 角色3.1.1 Implementor ( 实现者 )3.1.2 ConcreteImpl…

微信答题小程序产品研发-UI界面设计

高保真原型虽然已经很接近产品形态了,但毕竟还不能够直接交付给开发。这时就需要UI设计师依据之前的原型设计,进一步细化和实现界面的视觉元素,包括整体视觉风格、颜色、字体、图标、按钮以及交互细节优化等。 UI设计不仅关系到用户的直观感…

1.c#(winform)编程环境安装

目录 安装vs创建应用帮助查看器安装与使用( msdn) 安装vs 安装什么版本看个人心情,或者公司开发需求需要 而本栏全程使用vs2022进行开发c#,着重讲解winform桌面应用开发 使用***.net framework***开发 那先去官网安装企业版的vs…

这一文,关于Java泛型的点点滴滴 一

作为一个 Java 程序员&#xff0c;用到泛型最多的&#xff0c;我估计应该就是这一行代码&#xff1a; List<String> list new ArrayList<>();这也是所有 Java 程序员的泛型之路开始的地方啊。 不过本文讲泛型&#xff0c;先不从这里开始讲&#xff0c;而是再往前…

CVPR 2024最佳论文分享:Mip-Splatting: 无混叠3D高斯溅射

本推文详细介绍了CVPR 2024最佳论文提名《Mip-Splatting: Alias-free 3D Gaussian Splatting》。该论文的第一作者为 Zehao Yu&#xff08;图宾根大学在读博士&#xff0c;导师&#xff1a;Andreas Geiger &#xff09;。论文提出了一种名为Mip-Splatting的方法&#xff0c;用于…

树和二叉树(不用看课程)

1. 树 1.1 树的概念与结构 树是⼀种非线性的数据结构&#xff0c;它是由 n&#xff08;n>0&#xff09; 个有限结点组成⼀个具有层次关系的集合。把它叫做树是因为它看起来像⼀棵倒挂的树&#xff0c;也就是说它是根朝上&#xff0c;而叶朝下的。 • 有⼀个特殊的结点&am…