Pytorch学习笔记#2: 搭建神经网络训练MNIST手写数字数据集

news2024/9/23 15:19:41

学习自https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html

导入并预处理数据集

pytorch中数据导入和预处理主要用torch.utils.data.DataLoader 和 torch.utils.data.Dataset
Dataset 存储样本及其相应的标签,DataLoader在数据上生成一个可迭代对象(Dataset stores the samples and their corresponding labels, and DataLoader wraps an iterable around the Dataset.)

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

# Download training data from open datasets.
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# Download test data from open datasets.
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

将数据集作为参数传递给 DataLoader。 这在我们的数据集上包装了一个可迭代对象,并支持自动批处理、采样、混洗和多进程数据加载。并且每一个batch大小为64。

batch_size = 64

# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break

搭建神经网络

MNIST手写数字数据集的图片是2828的,所以第一层的输入为2828。
因为识别结果是0~9这10种,所以最后一层的输出就是10个。

我们需要定义神经网络结构,这部分在__init__(self)部分实现。
且我们需要forward部分定义网络正向传播的方法。

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )
    
    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

model = NeuralNetwork().to(device)
print(model)

训练模型

首先,我们需要先定义损失函数和优化器(优化梯度下降算法)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) # lr为学习率

在一次循环中,神经网络通过forward进行预测(我们写的forward函数),然后再利用预测误差。通过反向传播来进行梯度下降(pytorch帮我们实现)。

def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

开始训练!

epochs = 5
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print("Done!")

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/392447.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

刷题小抄4-数组

在Python中数组的功能由列表来实现,本文主要介绍一些力扣上关于数组的题目解法 寻找数组中重复的数字 题目链接 题目大意: 给出一个数组,数组长度为n,数组里的数字在[0,n-1]范围以内,数字可以重复,寻找出数组中任意一个重复的数字,返回结果 解法一 该题最基础的思路是使用字…

[java Spring JdbcTemplate配合mysql实现数据批量删除

之前的文章 java Spring JdbcTemplate配合mysql实现数据批量添加和文章java Spring JdbcTemplate配合mysql实现数据批量修改 先后讲解了 mysql数据库的批量添加和批量删除操作 会了这两个操作之后 批量删除就不要太简单 我们看到数据库 这里 我们用的是mysql工具 这里 我们有…

Java——单词接龙

题目链接 leetcode在线oj题——单词接龙 题目描述 字典 wordList 中从单词 beginWord 和 endWord 的 转换序列 是一个按下述规格形成的序列 beginWord -> s1 -> s2 -> … -> sk&#xff1a; 每一对相邻的单词只差一个字母。 对于 1 < i < k 时&#xff…

QML Text详解

1.简介 文本项可以显示普通文本和富文本。 2.示例 示例1&#xff1a;一个简单的text&#xff0c;可以设置字体颜色、大小等。 Window {visible: truewidth: 400height: 400title: qsTr("Hello World")Rectangle{width: 200height: 200border.width: 2Text {text: …

(flutter)黑苹果系统 Xcode iOS flutter 跑通真机模拟器 此oc clover 彼oc swift

前段时间写了关于flutter的一系列基础知识和入门的一些坑&#xff0c;中间把ios端的项目编译部署等工作一带而过&#xff0c;这里我觉得还是有必要专门写一篇文章来讲讲这个&#xff0c;顺便把环境问题也一起说了。 我们都知道开发ios应用需要用到苹果电脑&#xff0c;即使flu…

【NLP】Word2Vec 介绍

Word2Vec 是一种非常流行的自然语言处理技术&#xff0c;它将每个单词表示为高维向量&#xff0c;并且通过向量之间的相似度来表示单词之间的语义关系。 1 One-Hot 编码&#x1f342; 在自然语言处理任务中&#xff0c;我们需要将文本转换为计算机可以理解的形式&#xff0c;即…

ChatGPT后劲很大,问题也是

ChatGPT亮相即封神&#xff0c;最初的访客是程序员、工程师、AI从业者、投资人&#xff0c;最后是无数懵懂又好奇的普通人&#xff1a;ChatGPT是什么&#xff1f;自己会被ChatGPT取代吗&#xff1f;看待ChatGPT的立场也是两个极端&#xff1a; 快乐&#xff0c;是因为ChatGPT太…

科普| 什么是云原生?

“新冠疫情从根本上改变了商业模式&#xff0c;工作流向在线迁移的速度比以往任何时候都要快。越来越多的公司和消费者依靠电子商务“ B2B”和B2C”&#xff0c;以及网上银行促进创新以满足日益增长的客户需求&#xff0c;云原生技术在其中发挥重要作用&#xff0c;同时也加速了…

vm centos7搭建k8s集群

关闭防火墙&#xff0c;三台systemctl stop firewalld关闭selinux&#xff0c;三台sed -i s/enforcing/disabled/ /etc/selinux/config关闭swap&#xff0c;三台swapoff -a设置主机名&#xff0c;三台hostnamectl set-hostname 主机名&#xff0c;三个主机名分别设置成k8s-mast…

JavaScript新手学习手册-基础代码(一)

什么是JavaScript&#xff1f; 百度百科 什么是控制台&#xff1f; 网页➡快捷键F12 进入Console就是控制台&#xff0c;它的作用与开发软件相同&#xff0c;可以进行代码的编写在紫色位置进行编写&#xff0c;另外console.log()方法所打印的内容都是在此进行输出。 一&#…

Spark Join

Spark Join关联形式内关联外关联左外关联右外关联全外关联左半/逆关联关联机制NLJSMJHJ分发模式Join 选择等值 Join不等值 JoinJoin 按照关联形式&#xff08;Join Types&#xff09;划分 : 内关联、外关联、左关联、右关联 Join 按实现机制划分 : NLJ (Nested Loop Join) 、S…

【操作系统原理实验】页面替换策略模拟实现

选择一种高级语言如C/C等&#xff0c;编写一个页面替换算法的模拟实现程序。1) 设计内存管理相关数据结构&#xff1b;2) 随机生成一个页面请求序列&#xff1b;3) 设置内存管理模拟的关键参数&#xff1b;4) 实现该页面置换算法&#xff1b;5) 模拟实现给定配置请求序列的换页…

【python socket】实现websocket服务端

一、获取握手信息首先通过如下代码&#xff0c;我们使用socket来获取客户端的握手信息import socketsock socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) sock.bind(("127.0.0.1", 8002)) sock.li…

启动项管理工具Autoruns使用实验(20)

实验目的 &#xff08;1&#xff09;了解注册表的相关知识&#xff1b; &#xff08;2&#xff09;了解程序在开机过程中的自启动&#xff1b; &#xff08;3&#xff09;掌握Autoruns在注册表和启动项方面的功能&#xff1b;预备知识 注册表是windows操作系统中的一个核心数据…

Android Framework-Android启动过程

第一个系统进程&#xff08;init&#xff09; Android设备的启动必须经历3个阶段&#xff0c;即Boot Loader、Linux Kernel和Android系统服务&#xff0c;默认情况下它们都有各自的启动界面。严格来说&#xff0c;Android系统实际上是运行于Linux内核之上的一系列“服务进程”…

元宇宙XR应用,如何迎接大规模普及的时代?

未来&#xff0c;具有互动性、沉浸感的元宇宙/XR应用将逐渐成为主流&#xff0c;这个趋势已毋庸置疑。 然而&#xff0c;在大趋势下&#xff0c;大众终端用户普遍设备能力不足、网络传输时延、GPU算力分配限制等技术挑战&#xff0c;依然是元宇宙/XR应用在大众广泛渗透的瓶颈。…

【vulhub漏洞复现】Fastjson 1.2.24反序列化漏洞

一、漏洞详情Fastjson 是一个 Java 库&#xff0c;可以将 Java 对象转换为 JSON 格式&#xff0c;也可以将 JSON 字符串转换为 Java 对象。漏洞成因&#xff1a;目标网站在解析 json 时&#xff0c;未对 json 内容进行验证&#xff0c;直接将 json 解析成 java 对象并执行&…

国产数字源表在压力传感器电阻测量上的应用

压力传感器分类压力传感器(Pressure Transducer)是能感受压力信号&#xff0c;并能按照一定的规律将压力信号转换成可用的输出的电信号的器件或装置,压力传感器通常由压力敏感元件和信号处理单元组成。常见的压力传感器有四种:应变式压力传感器、压阻式压力传感器、电容式压力传…

OpenMMLab 目标检测

OpenMMLab 目标检测1. 目标检测简介1.1 滑窗2. 基础知识2.1 边界框&#xff08;Bounding Box&#xff09;3. 两阶段目标检测算法3.1 多尺度检测技术4. 单阶段目标检测算法4.1 YOLO: You Only Look Once (2015)4.2 SSD: Single Shot MultiBox Detetor (2016)5. 无锚框目标检测算…

Nginx的搭建与核心配置

目录 一.Nginx是什么&#xff1f; 1.Nginx概述 2.Nginx模块与作用 3.Nginx三大作用&#xff1a;反向代理、负载均衡、动静分离 二.Nginx和Apache的差异 三.安装Nginx 1.编译安装 2.yum安装 四.Nginx的信号使用 五.Nginx的核心配置指令 1.访问状态统计配置 2.基于授…