PyTorch 快速入门

news2025/2/1 3:47:10

我们将通过一个简单的示例,快速了解如何使用 PyTorch 进行机器学习任务。PyTorch 是一个开源的机器学习库,它提供了丰富的工具和库,帮助我们轻松地构建、训练和测试神经网络模型。以下是本教程的主要内容:

一、数据处理

PyTorch 提供了两个基本的数据处理工具:torch.utils.data.DataLoadertorch.utils.data.DatasetDataset 用于存储样本及其对应的标签,而 DataLoader 则为 Dataset 提供了一个可迭代的包装器。

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

PyTorch 还提供了许多特定领域的库,如 TorchText、TorchVision 和 TorchAudio,这些库中都包含了各种数据集。在本教程中,我们将使用 TorchVision 中的 FashionMNIST 数据集。每个 TorchVision Dataset 都包含两个参数:transformtarget_transform,分别用于修改样本和标签。

# 下载 FashionMNIST 数据集
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

我们将 Dataset 作为参数传递给 DataLoaderDataLoader 为我们的数据集提供了一个可迭代的包装器,并支持自动批处理、采样、洗牌和多进程数据加载。在这里,我们定义了一个大小为 64 的批次,即数据加载器可迭代对象的每个元素将返回一个包含 64 个特征和标签的批次。

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

二、创建模型

在 PyTorch 中,我们通过创建一个继承自 nn.Module 的类来定义神经网络。我们在 __init__ 函数中定义网络的层,并在 forward 函数中指定数据如何通过网络传递。为了加速神经网络中的操作,我们将网络移动到加速器(如 CUDA、MPS、MTIA 或 XPU)上。如果当前加速器可用,我们将使用它;否则,我们将使用 CPU。

class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

三、优化模型参数

为了训练模型,我们需要一个损失函数和一个优化器。在一个训练循环中,模型会对训练数据集(以批次形式提供)进行预测,并将预测误差反向传播以调整模型的参数。

def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # 计算预测误差
        pred = model(X)
        loss = loss_fn(pred, y)

        # 反向传播
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % 100 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

我们还需要检查模型在测试数据集上的性能,以确保它正在学习。

def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

训练过程会在多个迭代(epoch)中进行。在每个 epoch 中,模型会学习参数以做出更好的预测。我们在每个 epoch 打印模型的准确率和损失;我们希望看到准确率随着每个 epoch 的增加而增加,损失则随着每个 epoch 的增加而减少。

for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print("Done!")

四、保存模型

保存模型的一种常见方法是序列化内部状态字典(包含模型参数)。

torch.save(model.state_dict(), "model.pth")
print("Saved PyTorch Model State to model.pth")

五、加载模型

加载模型的过程包括重新创建模型结构,并将状态字典加载到其中。

model = NeuralNetwork()
model.load_state_dict(torch.load("model.pth"))

现在,我们可以使用这个模型来进行预测。

classes = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]

model.eval()
x, y = test_data[0][0], test_data[0][1]
with torch.no_grad():
    x = x.to(device)
    pred = model(x)
    predicted, actual = classes[pred[0].argmax(0)], classes[y]
    print(f'Predicted: "{predicted}", Actual: "{actual}"')

以上就是使用 PyTorch 进行机器学习任务的基本流程。

六、完整代码

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

# 下载 FashionMNIST 数据集
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

# 创建数据加载器
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

# 定义神经网络模型
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

# 检查是否有可用的加速器
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

# 实例化模型并移动到加速器上
model = NeuralNetwork().to(device)
print(model)

# 定义损失函数和优化器
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# 定义训练函数
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # 计算预测误差
        pred = model(X)
        loss = loss_fn(pred, y)

        # 反向传播
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % 100 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

# 定义测试函数
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

# 训练和测试模型
epochs = 5
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print("Done!")

# 保存模型
torch.save(model.state_dict(), "model.pth")
print("Saved PyTorch Model State to model.pth")

# 加载模型
model = NeuralNetwork()
model.load_state_dict(torch.load("model.pth"))

# 使用模型进行预测
classes = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]

model.eval()
x, y = test_data[0][0], test_data[0][1]
with torch.no_grad():
    x = x.to(device)
    pred = model(x)
    predicted, actual = classes[pred[0].argmax(0)], classes[y]
    print(f'Predicted: "{predicted}", Actual: "{actual}"')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2289147.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【浏览器 - Mac实时调试iOS手机浏览器页面】

最近开发个项目,需要在 Mac 电脑上调试 iOS 手机设备上的 Chrome 浏览器,并查看Chrome网页上的 console 信息,本来以为要安装一些插件,没想到直接使用Mac上的Safari 直接可以调试,再此记录下,分享给需要的伙…

PyQt5之QtDesigner的若干配置和使用

1.描述 QtDesigner是一个可视化工具,可以通过该工具设计页面 2.简单使用 1.下载PyQt5-tools pip install pyqt5-tools 2.打开designer.exe文件 我采用的是虚拟环境,该文件位于C:\Users\24715\anaconda3\envs\pyqt\Lib\site-packages\qt5_applicatio…

侯捷C++day01

一个类该准备什么样的数据、函数。才能满足使用这个类人的需求。 inline关键字是建议编译器做inline处理。 private只有本类可以看到。 C创建对象会自动调用构造函数。不可能在程序中显示调用构造函数。不带指针的类多半不用写析构函数。 以下两个重载构造函数会发生错误 不允许…

CTF-web: phar反序列化+数据库伪造 [DASCTF2024最后一战 strange_php]

step 1 如何触发反序列化? 漏洞入口在 welcome.php case delete: // 获取删除留言的路径,优先使用 POST 请求中的路径,否则使用会话中的路径 $message $_POST[message_path] ? $_POST[message_path] : $_SESSION[message_path]; $msg $userMes…

Win11下帝国时代2无法启动解决方法

鼠标右键点图标,选择属性 点开始,输入启用和关闭

GSI快速收录服务:让你的网站内容“上架”谷歌

辛苦制作的内容无法被谷歌抓取和展示,导致访客无法找到你的网站,这是会让人丧失信心的事情。GSI快速收录服务就是为了解决这种问题而存在的。无论是新上线的页面,还是长期未被收录的内容,通过我们的技术支持,都能迅速被…

mysql_init和mysql_real_connect的形象化认识

解析总结 1. mysql_init 的作用 mysql_init 用于初始化一个 MYSQL 结构体,为后续数据库连接和操作做准备。该结构体存储连接配置及状态信息,是 MySQL C API 的核心句柄。 示例: MYSQL *conn mysql_init(NULL); // 初始化连接句柄2. mysql_…

python学opencv|读取图像(四十九)原理探究:使用cv2.bitwise()系列函数实现图像按位运算

【0】基础定义 按位与运算:两个等长度二进制数上下对齐,全1取1,其余取0。 按位或运算:两个等长度二进制数上下对齐,有1取1,其余取0。 按位异或运算: 两个等长度二进制数上下对齐,相…

基础项目实战——学生管理系统(c++)

目录 前言一、功能菜单界面二、类与结构体的实现三、录入学生信息四、删除学生信息五、更改学生信息六、查找学生信息七、统计学生人数八、保存学生信息九、读取学生信息十、打印所有学生信息十一、退出系统十二、文件拆分结语 前言 这一期我们来一起学习我们在大学做过的课程…

春节期间,景区和酒店如何合理用工?

春节期间,景区和酒店如何合理用工? 春节期间,旅游市场将迎来高峰期。景区与酒店,作为旅游产业链中的两大核心环节,承载着无数游客的欢乐与期待。然而,也隐藏着用工管理的巨大挑战。如何合理安排人力资源&a…

Linux Samba 低版本漏洞(远程控制)复现与剖析

目录 前言 漏洞介绍 漏洞原理 产生条件 漏洞影响 防御措施 复现过程 结语 前言 在网络安全的复杂生态中,系统漏洞的探索与防范始终是保障数字世界安全稳定运行的关键所在。Linux Samba 作为一款在网络共享服务领域应用极为广泛的软件,其低版本中…

【Numpy核心编程攻略:Python数据处理、分析详解与科学计算】1.27 线性代数王国:矩阵分解实战指南

1.27 线性代数王国:矩阵分解实战指南 #mermaid-svg-JWrp2JAP9qkdS2A7 {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-JWrp2JAP9qkdS2A7 .error-icon{fill:#552222;}#mermaid-svg-JWrp2JAP9qkdS2A7 .erro…

初二回娘家

昨天下午在相亲相爱一家人群里聊天,今天来娘家拜年。 聊天结束后,开始准备今天的菜肴,梳理了一下,凉菜,热菜,碗菜。 上次做菜,粉丝感觉泡的不透,有的硬,这次使用开水浸泡…

【Block总结】PKI 模块,无膨胀多尺度卷积,增强特征提取的能力|即插即用

论文信息 标题: Poly Kernel Inception Network for Remote Sensing Detection 作者: Xinhao Cai, Qiuxia Lai, Yuwei Wang, Wenguan Wang, Zeren Sun, Yazhou Yao 论文链接:https://arxiv.org/pdf/2403.06258 代码链接:https://github.com/NUST-Mac…

Blazor-@bind

数据绑定 带有 value属性的标记都可以使用bind 绑定&#xff0c;<div>、<span>等非输入标记&#xff0c;无法使用bind 指令的&#xff0c;默认绑定了 onchange 事件&#xff0c;onchange 事件是指在输入框中输入内容之后&#xff0c;当失去焦点时执行。 page &qu…

架构技能(六):软件设计(下)

我们知道&#xff0c;软件设计包括软件的整体架构设计和模块的详细设计。 在上一篇文章&#xff08;见 《架构技能&#xff08;五&#xff09;&#xff1a;软件设计&#xff08;上&#xff09;》&#xff09;谈了软件的整体架构设计&#xff0c;今天聊一下模块的详细设计。 模…

C++并发编程指南07

文章目录 [TOC]5.1 内存模型5.1.1 对象和内存位置图5.1 分解一个 struct&#xff0c;展示不同对象的内存位置 5.1.2 对象、内存位置和并发5.1.3 修改顺序示例代码 5.2 原子操作和原子类型5.2.1 标准原子类型标准库中的原子类型特殊的原子类型备选名称内存顺序参数 5.2.2 std::a…

MySQL 容器已经停止(但仍然存在),但希望重新启动它,并使它的 3306 端口映射到宿主机的 3306 端口是不可行的

重新启动容器并映射端口是不行的 由于你已经有一个名为 mysql-container 的 MySQL 容器&#xff0c;你可以使用 docker start 启动它。想要让3306 端口映射到宿主机是不行的&#xff0c;实际上&#xff0c;端口映射是在容器启动时指定的。你无法在容器已经创建的情况下直接修改…

春晚舞台上的人形机器人:科技与文化的奇妙融合

文章目录 人形机器人Unitree H1的“硬核”实力传统文化与现代科技的创新融合网友热议与文化共鸣未来展望&#xff1a;科技与文化的更多可能结语 2025 年央视春晚的舞台&#xff0c;无疑是全球华人目光聚焦的焦点。就在这个盛大的舞台上&#xff0c;一场名为《秧BOT》的创意融合…

将pandas.core.series.Series类型的小数转化成百分数

大年初二&#xff0c;大家过年好&#xff0c;蛇年行大运&#xff01; 今天在编写一个代码的时候&#xff0c;使用 import pandas as pd产生了pandas.core.series.Series类型的数据&#xff0c;里面有小数&#xff0c;样式如下&#xff1a; 目的&#xff1a;将这些小数转化为百…