pytorch2.x 官方quickstart测试

news2024/10/6 10:38:51

文章目录

  • 1.本地环境
  • 2.[安装pytorch](https://pytorch.org/get-started/locally/) (Windows GPU版本)
  • 3. [官方quickstart](https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html)

1.本地环境

D:\python2023>nvidia-smi
Thu Jul 27 23:27:45 2023
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 497.29       Driver Version: 497.29       CUDA Version: 11.5     |
|-------------------------------+----------------------+----------------------+
| GPU  Name            TCC/WDDM | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ... WDDM  | 00000000:03:00.0  On |                  N/A |
| 27%   36C    P8     8W / 120W |    397MiB /  3072MiB |      1%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A      1288    C+G   Insufficient Permissions        N/A      |
|    0   N/A  N/A      3444    C+G   ...y\ShellExperienceHost.exe    N/A      |
|    0   N/A  N/A      7420    C+G   ...nputApp\TextInputHost.exe    N/A      |
|    0   N/A  N/A      7896    C+G   C:\Windows\explorer.exe         N/A      |
|    0   N/A  N/A      8392    C+G   ...b3d8bbwe\WinStore.App.exe    N/A      |
|    0   N/A  N/A      8872    C+G   ...5n1h2txyewy\SearchApp.exe    N/A      |
|    0   N/A  N/A     10860    C+G   ...lPanel\SystemSettings.exe    N/A      |
|    0   N/A  N/A     11536    C+G   ...se6\Application\360se.exe    N/A      |
|    0   N/A  N/A     14264    C+G   ...\qbblinktrial\browser.exe    N/A      |
+-----------------------------------------------------------------------------+

D:\python2023>gcc --version
gcc (x86_64-posix-sjlj-rev0, Built by MinGW-W64 project) 8.1.0
Copyright (C) 2018 Free Software Foundation, Inc.
This is free software; see the source for copying conditions.  There is NO
warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.


D:\python2023>python --version
Python 3.8.5

D:\python2023>

2.安装pytorch (Windows GPU版本)

注意 安装时一定要指定–index-url https://download.pytorch.org/whl/torch/ ,否则安装的是cpu版本,可以访问https://download.pytorch.org/whl/torch/,找到需要的版本如torch-2.0.1+cu117-cp38-cp38-win_amd64.whl 用迅雷下载比较快

pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117 

在这里插入图片描述

3. 官方quickstart

按官方 quickstart拼起來的代码,如果有GPU且安装的GPU版本pytorch则跑GPU上否则CPU(所有CPU),本地测试20CPU与1 个geforce GPU耗时差不多20s

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

#####################
# Download training data from open datasets.
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# Download test data from open datasets.
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)
#####################
batch_size = 64

# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)
for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break
#####################
# Get cpu, gpu or mps device for training.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")


# Define model
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits


model = NeuralNetwork().to(device)
print(model)
#######################################
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)


########################################
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % 100 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")


#######################################
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")


if __name__ == '__main__':
    epochs = 5
    for t in range(epochs):
    	start = time.time()
        print(f"Epoch {t + 1}\n-------------------------------")
        train(train_dataloader, model, loss_fn, optimizer)
        test(test_dataloader, model, loss_fn)
        end = time.time()
        print(f"epoch Done:{end-start}")
    print("Done!")

D:\python2023>nvidia-smi
Fri Jul 28 00:42:59 2023
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 497.29       Driver Version: 497.29       CUDA Version: 11.5     |
|-------------------------------+----------------------+----------------------+
| GPU  Name            TCC/WDDM | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ... WDDM  | 00000000:03:00.0  On |                  N/A |
| 27%   37C    P5     9W / 120W |   1007MiB /  3072MiB |      9%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A     11536    C+G   ...se6\Application\360se.exe    N/A      |
|    0   N/A  N/A     14264    C+G   ...\qbblinktrial\browser.exe    N/A      |
|    0   N/A  N/A     15648      C   ...ython\Python38\python.exe    N/A      |
+-----------------------------------------------------------------------------+

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/800791.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

idea项目依赖全部找不到

目录 1,出错现象2,解决3,其他尝试 1,出错现象 很久没打开的Java项目,打开之后大部分依赖都找不到,出现了所有的含有import语句的文件都会报错和一些注解报红报错,但pom文件中改依赖是确实被引入…

VS2015配置opencv4.1(x86和x64)

1.安装VS2015 vs版本和部门统一,安装C模块即可 2.安装opencv4.1 重点还是配置,安装opencv4.1,装就完事了 3.配置opencv4.1 给整麻了,配了一早上 3.1 在电脑属性中找到“高级系统配置” 3.2 环境变量 3.3 写上x86 和 x64的环…

草稿#systemverilog# 说说Systemverilog中《static》那些事儿(拓展篇)

3)static和automatic可以将一个任务task或者函数function显式地声明成静态或者自动的:一个自动automatic 声明的任务、函数或块内声明的数据缺省情况下具有调用期或激活期内的生命周期,并且具有本地的作用范围; 一个静态static 声…

Java虚拟机——线程与协程

1 Java与线程 目前线程是Java里面进行处理器资源调度的最基本单位。如果日后Loom项目能够为Java引入纤程(Fiber)的话,可能会改变这一点。 1.1 线程的实现 这里先把Java技术的背景放下,以一个通用的应用程序的角度来看线程是如何实现的。 1.1.1 内核线…

C数据结构与算法——顺序查找和二分查找算法 应用

实验任务 (1) 掌握顺序查找算法的实现; (2) 掌握二分查找算法的实现; (3) 掌握两种查找算法的时间复杂度与适用场合。 实验内容 (1) 基于顺序查找表实现顺序查找和二分查找算法; (2) 使用两个不同大小的查找表进行两次理论和实际性能对比&…

利用STM32为主控以LORA为通讯模块,通过中继器链接MQTT服务器的物联网信息采集处理的信息系统方案

项目的详细方案如下: 硬件组成: STM32主控板:作为项目的主控单元,负责采集终端点位的温湿度信息,并通过LORA通讯模块发送数据到中继器。 LORA通讯模块:作为STM32与中继器之间的无线通信模块,负…

【NLP】语音识别 — GMM, HMM

一、说明 在语音识别的深度学习(DL)时代之前,HMM和GMM是语音识别的两项必学技术。现在,有将HMM与深度学习相结合的混合系统,并且有些系统是免费的HMM。我们现在有更多的设计选择。然而,对于许多生成模型来说…

浅谈集成式电力电容器无功补偿装置的技术特点及应用状况

安科瑞 华楠 摘要:阐述了集成式电力电容器无功补偿装置的组成与应用状况.在与常规电力电容器对比的基础上,分析了集成式电力电容器无功补偿装置的技术特点。通过对集成式无功补偿装置原理结构的分析,探讨了对集成式无功补偿装置的…

Spring Batch教程(四)tasklet使用示例:spring batch的定时任务使用

Spring batch 系列文章 Spring Batch教程(一) 简单的介绍以及通过springbatch将xml文件转成txt文件 Spring Batch教程(二)示例:将txt文件转成xml文件以及读取xml文件内容存储到数据库mysql Spring Batch教程&#xff…

TPU-MLIR编译部署算法

注意: 由于SOPHGO SE5微服务器的CPU是基于ARM架构,以下步骤将在基于x86架构CPU的开发环境中完成 初始化开发环境(基于x86架构CPU的开发环境中完成)模型转换 (基于x86架构CPU的开发环境中完成) 处理后的PP-OCR项目文件将被拷贝至 SE5微服务器 上进行推理…

el-table-column 合并列,切换表格显示,数据错乱问题

由于同一个页面需要通过lable进行切换显示不同的表格结果在切换的时候发现表格列错乱了 正常是这样的 切换错乱的是这样的 序号没有了,已接单协同总数列也不见了 切换回来发现第一个表格 原先的两列被后面的挤压了 代码也没啥毛病,最主要的原因是因为同…

【从零开始学习JAVA | 第三十二篇】 异常(下)新手必学!

目录 前言: Exceptions(异常): 异常的两大作用: 异常的处理方式: 1.JVM默认处理 2.自己捕获异常 3.抛出处理 自定义异常: 异常的优点: 总结: 前言: 前…

LUMEN技术要点总结

LUMEN总结 主题是动态全局光照和Lumen Lumen更像是一个各种GI算法的集大成者。 1. 如何理解lumen及全局光照的实现机制 渲染方程 至今为止所有的实时光照都是按照Render Equation来进行渲染的,我们做得到只是在无限的逼近它。 我们把只进行一次反弹叫做SingleBou…

uni-app 经验分享,从入门到离职(一)——初始 uni-app,快速上手(文末送书福利1.0)

文章目录 📋前言🎯什么是 uni-app🎯创建第一个 uni-app 项目🧩前期工作🧩创建项目(熟悉默认项目、结构)🧩运行项目 📝最后🎯文末送书🔥参与方式 &…

客户方数据库服务器CPU负载高优化案例

客户方数据库服务器CPU负载高优化案例 背景 上周线上服务出现一个问题,打开某个页面,会导致其它接口请求响应超时,排查后发现数据库响应超400s,之前1s就可查到数据。 具体原因是有个大屏统计页面,会实时查看各业务服…

echarts坐标轴名称换行

一、期望效果: 期望超过6个字换行,最多可显示十个字 如图: 二、踩坑: echarts的width和overflow设置后换行无效。(如果其他人有设置有效的 还请说明下) 三、解决方案: 用\n换行&#xf…

Django + Xadmin 数据列表复选框显示为空,怎么修复这个问题?

问题描述: 解决方法: 后续发现的报错: 解决方案: 先根据报错信息定位到源代码: 在该文件顶部写入: from django.core import exceptions然后把: except models.FieldDoesNotExist修改为&…

qt6.5 download for kali/ubuntu ,windows (以及配置选项选择)

download and sign in qt官网 sign in onlion Install 1 2 3 4 5

SpringBoot整合WebService

SpringBoot整合WebService WebService是一个比较旧的远程调用通信框架,现在企业项目中用的比较少,因为它逐步被SpringCloud所取代,它的优势就是能够跨语言平台通信,所以还有点价值,下面来看看如何在SpringBoot项目中使…

Neo4j图数据基本操作

Neo4j 文章目录 Neo4jCQL结点和关系增删改查匹配语句 根据标签匹配节点根据标签和属性匹配节点删除导入数据目前的问题菜谱解决的问题 命令行窗口 neo4j.bat console 导入rdf格式的文件 :GET /rdf/ping CALL n10s.graphconfig.init(); //初始化 call n10s.rdf.import.fetch(&q…