Pytorch教程(代码逐行解释)

news2024/11/24 13:46:19

0、配准环境教程

1、开始导入相应的包

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

torch是pytorch的简写
torch.utils.data import DataLoader 是用于读取数据的迭代器
torchvision是视觉处理包,datasets导入的是视觉相关的数据集
transforms 是用于图像变换的。

2、下载数据集(准备数据集)

# Download training data from open datasets.
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# Download test data from open datasets.
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

datasets.FashionMNIST,指的是一个数据集,这个数据集用于服饰的识别。FashionMNIST是一个非常流行的图像分类数据集,其中包含10个类别的70000个28x28灰度图像。
当然,pytorch还有很多其他的数据集格式。例如以下的数据集。其他数据集可点击这个连接在这里插入图片描述

3、加载数据集

batch_size = 64

# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break

DataLoader是PyTorch中一个非常有用的模块,它主要用于批量加载数据,特别是当数据集非常大时,DataLoader可以极大地提高数据加载速度并减少内存占用。
DataLoader的主要功能包括:
批量处理数据:DataLoader可以将数据划分为多个批次(batch),每个批次包含一定数量的数据样本,然后一次处理一个批次的数据,这样可以大大减少内存占用。
数据打乱:通过设置shuffle=True参数,DataLoader可以在每个epoch开始时随机打乱数据集的顺序,这样可以增加模型的泛化能力。
batch_size 指的是每次读取的数据的大小,这里设置一次读取64张

4、创建训练的模型

# Get cpu, gpu or mps device for training.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

# Define model
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

model = NeuralNetwork().to(device)
print(model)

super().init()表示调用父类(nn.Module)的 init() 方法
self.flatten = nn.Flatten(),这行代码的作用主要是在神经网络模型中的作用是将输入数据从多维(例如二维或三维)转化为一维,这个操作通常被称为"flatten"。
在这个例子中,该模型预期的输入是一个形状为[batch_size, 28, 28]的张量,即一个包含多个(这里是28*28=784个)特征值的数据集。nn.Flatten()层将这个三维数据转化为一维数组,以便后续的线性层(nn.Linear)能以更高效的方式进行操作。

nn.Sequential 是 PyTorch 中一个用于创建顺序神经网络模型的模块。它是一个有序的容器,可以包含任意数量的其他模块。当你将数据输入到 nn.Sequential 模型时,数据会按照你在容器中定义的顺序通过每个模块。

5、设置优化器以及损失函数

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

损失函数还有很多种,其他的参考点击这个链接
优化器也有很多种,如ASGD,ADAM等等,其他的参考这个链接

6、模型的训练

定义训练的过程

def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % 100 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

从数据集中,每次取一个图像个标签进行训练,然后反向传播,梯度优化,完成训练。
item():.item()是用来从张量中提取标量值的方法。当你调用.item()方法时,如果张量中只有一个元素,那么这个元素会被返回;如果张量中有多个元素,则会抛出一个错误。

def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

correct += (pred.argmax(1) == y).type(torch.float).sum().item():解释:
(pred.argmax(1) == y):首先,这行代码通过argmax(1)获取了每个样本的预测类别。然后,它将预测类别与真实类别进行比较(==)。这将返回一个布尔型的张量,表示每个样本的预测是否正确。
(pred.argmax(1) == y).type(torch.float):接下来,这行代码将布尔型的张量转换为浮点型。在PyTorch中,布尔型的张量会自动转换为浮点型。
(pred.argmax(1) == y).type(torch.float).sum():然后,这行代码计算了所有样本中预测正确的总数。这是通过调用sum()函数实现的,该函数会返回一个张量中所有元素的和。
correct += …:最后,这行代码将预测正确的总数加到了变量correct上。+=是一个累加操作符,它将左侧的变量与右侧的表达式结果相加。

7、定义训练的轮次

epochs = 5
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print("Done!")

8、保存模型

torch.save(model.state_dict(), "model.pth")
print("Saved PyTorch Model State to model.pth")

model.state_dict():解释:
model.state_dict()函数返回一个包含模型所有参数的字典,torch.save()函数则将这个字典保存到磁盘上的一个文件。

9、加载模型

model = NeuralNetwork().to(device)
model.load_state_dict(torch.load("model.pth"))

10、模型的测试

classes = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]

model.eval()
x, y = test_data[0][0], test_data[0][1]
with torch.no_grad():
    x = x.to(device)
    pred = model(x)
    predicted, actual = classes[pred[0].argmax(0)], classes[y]
    print(f'Predicted: "{predicted}", Actual: "{actual}"')

所有的完整代码:

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

# Download training data from open datasets.
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# Download test data from open datasets.
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

batch_size = 64

# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break

# Get cpu, gpu or mps device for training.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

# Define model
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

model = NeuralNetwork().to(device)
print(model)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % 100 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

epochs = 5
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print("Done!")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1208877.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

java实现冒泡排序

冒泡排序是一种简单的排序算法&#xff0c;以下是Java实现示例代码&#xff1a; public static void bubbleSort(int[] array) {int n array.length;for (int i 0; i < n - 1; i) {for (int j 0; j < n - i - 1; j) {// 如果前面的元素比后面的元素大&#xff0c;就交…

H5游戏源码分享-网页版2048小游戏

H5游戏源码分享-网页版2048小游戏 玩过的都懂 <!DOCTYPE html> <html> <head><meta charset"utf-8"><title>分享2048到朋友圈&#xff0c;将免费参加南山郡8.17号啤酒狂欢节&#xff01;</title><link href"style/main…

Spring核心

Spring Framework Spring的两个核心IOC控制反转IOC容器依赖注入DIIOC容器实现注解管理BeanBean对象定义Bean对象获取 AOP面向切面编程 添加依赖入门案例注解通过Spring创建Java bean对象 xml管理Bean案例main下创建bean.XMl文件 DI依赖注入案例创建Spring配置文件 bean-di.xml …

25.4 MySQL 函数

1. 函数的介绍 1.1 函数简介 在编程中, 函数是一种组织代码的方式, 用于执行特定任务. 它是一段可以被重复使用的代码块, 通常接受一些输入(参数)然后返回一个输出. 函数可以帮助开发者将大型程序分解为更小的, 更易于管理的部分, 提高代码的可读性和可维护性.函数在编程语言…

线程有哪些状态

线程的生命周期 线程在Java中有以下几种状态&#xff1a; 新建&#xff08;New&#xff09;&#xff1a;初始化状态就绪&#xff08;Runnable&#xff09;&#xff1a;可运行、运行状态阻塞&#xff08;Blocked&#xff09;&#xff1a;等待状态&#xff0c;无时限等待&#…

Vue3-TypeScript-Threejs:导入外部的glb格式3D模型

一、直接上代码&#xff0c;在vue3-typescript-threejs 项目 导入外部的glb格式3D模型 极简代码&#xff0c;快速理解 <template><div ref"container"></div></template><script lang"ts" setup>import { onMounted, ref …

您的计算机已被Mallox勒索病毒感染?恢复您的数据的方法在这里!

尊敬的读者&#xff1a; 随着科技的迅速发展&#xff0c;网络安全问题日益凸显&#xff0c;其中勒索病毒是一种极具威胁性的恶意软件。在这些勒索病毒中&#xff0c;.mallox 勒索病毒尤为突出&#xff0c;它能够加密用户的数据文件&#xff0c;要求支付赎金才能解密。本文将介…

高效使用 PyMongo 进行 MongoDB 查询和插入操作

插入到集合中&#xff1a; 要将记录&#xff08;在MongoDB中称为文档&#xff09;插入到集合中&#xff0c;使用insert_one()方法。insert_one()方法的第一个参数是一个包含文档中每个字段的名称和值的字典。 import pymongomyclient pymongo.MongoClient("mongodb://l…

笔试题之指针和数组的精讲

&#x1d649;&#x1d65e;&#x1d658;&#x1d65a;!!&#x1f44f;&#x1f3fb;‧✧̣̥̇‧✦&#x1f44f;&#x1f3fb;‧✧̣̥̇‧✦ &#x1f44f;&#x1f3fb;‧✧̣̥̇:Solitary-walk ⸝⋆ ━━━┓ - 个性标签 - &#xff1a;来于“云”的“羽球人”。…

Ubuntu 和 Windows 文件互传

FTP 服务 FTP 采用 Internet 标准文件传输协议 FTP 的用户界面&#xff0c; 向用户提供了一组用来管理计算机之间文件传输的应用程序。在开发的过程中会频繁的在 Windows 和 Ubuntu 下进行文件传输&#xff0c;比如在 Windwos 下进行代码编写&#xff0c;然后将编写好的代码拿到…

JavaEE初阶(18)(JVM简介:发展史,运行流程、类加载:类加载的基本流程,双亲委派模型、垃圾回收相关:死亡对象的判断算法,垃圾回收算法,垃圾收集器)

接上次博客&#xff1a;初阶JavaEE&#xff08;17&#xff09;Linux 基本使用和 web 程序部署-CSDN博客 目录 JVM 简介 JVM 发展史 JVM 运行流程 JVM的内存区域划分 JVM 执行流程 堆 堆的作用 JVM参数设置 堆的组成 垃圾回收 堆内存管理 类加载 类加载的基本流…

Windows conan环境搭建

Windows conan环境搭建 1 安装conan1.1 安装依赖软件1.1.1 python安装1.1.2 git bash安装1.1.3 安装Visual Studio Community 20191.1.3.1 选择安装的组件1.1.3.2 选择要支持的工具以及对应的SDK 1.1.4 vscode安装 1.3 验证conan功能1.4 查看conancenter是否包含poco包1.5 查看…

SQL 日期函数

在数据库中&#xff0c;日期和时间是经常需要处理的数据类型之一。SQL提供了许多内置的日期函数&#xff0c;用于对日期和时间进行操作、计算和比较。这些函数可以帮助我们提取日期的各个部分&#xff08;如年份、月份、日、小时、分钟等&#xff09;&#xff0c;执行日期的转换…

第一百七十二回 SegmentedButton组件

文章目录 1. 概念介绍2. 使用方法2.1 SegmentedButton2.2 ButtonSegment 3. 代码与效果3.1 示例代码3.2 运行效果 4. 内容总结 我们在上一章回中介绍了"SearchBar组件"相关的内容&#xff0c;本章回中将 介绍SegmentedButton组件.闲话休提&#xff0c;让我们一起Tal…

抖斗音_快块手直播间获客助手+采集脚本+引流软件功能介绍

软件功能&#xff1a; 支持同时采集多个直播间&#xff0c;弹幕&#xff0c;关*注&#xff0c;礼*物&#xff0c;进直播间&#xff0c;部分用户手*号,粉*丝团采集 不支持采集匿*名直播间 设备需求&#xff1a; 电脑&#xff08;win10系统&#xff09; 文章分享者&#xff1…

【Linux】第十五站:环境变量

文章目录 一、进程相关的一些概念1.一些常见的概念2.对于并发3.**进程切换** 二、环境变量1.PATH环境变量2.HOME环境变量3.SHELL环境变量4.env5.系统调用接口与环境变量6.什么是环境变量&#xff1f;7.命令行参数8.main函数的第三个命令行参数9.如何验证环境变量是可以被继承的…

java实现选择排序

图解 以下是Java实现选择排序的示例代码&#xff1a; public class SelectionSort {public static void selectionSort(int[] arr) {int n arr.length;// 遍历未排序部分的数组for (int i 0; i < n - 1; i) {// 在未排序部分中查找最小元素的下标int minIndex i;for (in…

MySQL 人脸向量,欧几里得距离相似查询

前言 如标题&#xff0c;就是通过提取的人脸特征向量&#xff0c;写一个欧几里得 SQL 语句&#xff0c;查询数据库里相似度排前 TOP_K 个的数据记录。做法虽然另类&#xff0c;业务层市面上有现成的面部检索 API&#xff0c;技术层现在有向量数据库。 用 MySQL 关系型存储 128 …

新学期帮娃把拖延症戒了!这个时间管理器太太太有用啦!

十个孩子九个拖延~ 不要唠叨&#xff0c;不要指责 时间流逝一眼可见&#xff0c;打败拖延症&#xff01; 赶紧把这款时间管理器用上 当当狸时间管理器 说起孩子没有时间观念、拖延症 每个老母亲都有一肚子苦水要倒&#xff5e;&#xff5e; 市面上有很多计时器&#xff0…

【k8s集群搭建(一):基于虚拟机的linux的k8s集群搭建_超详细_解决并记录全过程步骤以及自己的踩坑记录】

虚拟机准备3台Linux系统 k8s集群安装 每一台机器需要安装以下内容&#xff1a; docker:容器运行环境 kubelet:控制机器中所有资源 bubelctl:命令行 kubeladm:初始化集群的工具 Docker安装 安装一些必要的包&#xff0c;yum-util 提供yum-config-manager功能&#xff0c;另两…