pytorch基础2-数据集与归一化

news2024/9/24 15:18:30

专题链接:https://blog.csdn.net/qq_33345365/category_12591348.html

本教程翻译自微软教程:https://learn.microsoft.com/en-us/training/paths/pytorch-fundamentals/

初次编辑:2024/3/2;最后编辑:2024/3/2


本教程第一篇介绍pytorch基础和张量操作

这是本教程的第二篇,介绍以下内容:

  1. 数据集与数据加载器
  2. 归一化

另外本人还有pytorch CV相关的教程,见专题:

https://blog.csdn.net/qq_33345365/category_12578430.html


数据集与数据加载器 Datasets and Dataloaders

处理数据样本的代码可能会变得复杂且难以维护。通常希望数据集代码与模型训练代码解耦,以获得更好的可读性和模块化性。PyTorch提供了两个数据原语:torch.utils.data.DataLoadertorch.utils.data.Dataset,它们使您能够使用预加载的数据集以及您自己的数据。Dataset存储样本及其对应的标签,而DataLoader则在Dataset周围包装了一个可迭代对象,以便轻松访问样本。

PyTorch库提供了许多预加载的样本数据集(例如FashionMNIST),它们是torch.utils.data.Dataset的子类,并实现了特定于特殊数据的功能。用于模型原型设计(prototyping)和基准测试(benchmark)的样本包括:

  1. 图像数据集
  2. 文本数据集
  3. 音频数据集

加载数据集

此处使用TorchVision来加载Fashion-MNIST数据集。Fashion-MNIST是Zalando的服装图像数据集,包含60,000个训练样本和10,000个测试样本。每个样本包括一个28×28的灰度图像和一个来自10个类别中的关联标签。

  • 每个图像高度为28个像素,宽度为28个像素,总共有784个像素。
  • 这10个类别告诉图像是什么类型,例如:T恤/上衣、裤子、套头衫、连衣裙、包、短靴等。
  • 灰度像素的值介于0到255之间,用于衡量黑白图像的强度。强度值从白色到黑色递增。例如:白色的颜色值为0,而黑色的颜色值为255。

我们使用以下参数加载FashionMNIST数据集:

  • root 是存储训练/测试数据的路径。
  • train 指定训练或测试数据集。
  • download=True 如果数据在root中不可用,则从互联网下载数据。
  • transformtarget_transform 指定特征和标签的转换。

加载训练和测试数据集的代码如下所示:

import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
import matplotlib.pyplot as plt

training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

数据集的迭代与可视化

可以像列表一样手动索引Datasets: training_data[index]。此处使用matplotlib对训练数据中的一些样本进行可视化。代码如下所示:

labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()

得到如下图片:

在这里插入图片描述

使用数据加载器(DataLoaders)准备自定义数据

Dataset 逐个样本检索数据集的特征和标签。在训练模型时,通常希望以“minibatch”的形式传递样本,在每个周期重混洗(reshuffle)数据以减少模型过拟合,并使用Python的多进程加速数据检索。

在机器学习中,需要指定数据集中的特征和标签。**特征(Feature)**是输入,**标签(label)**是输出。我们训练特征,然后训练模型以预测标签。

  • 特征是图像像素中的模式。
  • 标签是10个类别类型:T恤,凉鞋,连衣裙等。

DataLoader是一个可迭代对象,用简单的API抽象了这种复杂性。为了使用DataLoader,我们需要设置以下参数:

  • data 用于训练模型的训练数据,以及用于评估模型的测试数据。
  • batch size 每个批次要处理的记录数。
  • shuffle 按索引随机抽取数据样本。

使用DataLoader处理两个数据集的代码如下所示:

from torch.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

使用DataLoader进行迭代

我们已将数据集加载到 DataLoader 中,现在可以根据需要迭代数据集。 下面的每次迭代都会返回一个批次的 train_featurestrain_labels(分别包含 batch_size=64 个特征和标签)。由于我们指定了 shuffle=True,在遍历完所有批次后,数据将被混洗(shuffle),以更精细地控制数据加载顺序。

展示图片和标签的代码如下所示:

# Display image and label.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
label_name = list(labels_map.values())[label]
print(f"Label: {label_name}")

得到图片:

在这里插入图片描述

归一化

标准化是一种常见的数据预处理技术,用于对数据进行缩放或转换,以确保每个特征都有相等的学习贡献。例如,灰度图像中的每个像素的值都介于0和255之间,这些是特征。如果一个像素值为17,另一个像素为197,则像素重要性的分布将不均匀,因为较高的像素值会偏离学习。标准化改变了数据的范围,而不会扭曲其特征之间的区别。这种预处理是为了避免:

  • 减少预测精度
  • 模型学习困难
  • 特征数据范围的不利分布

Transforms

数据并不总是以最终处理过的形式呈现,这种形式适合训练机器学习算法。可以使用transforms来操作数据,使其适合训练。

所有 TorchVision 数据集都有两个参数(transform用于修改特征,target_transform用于修改标签),这些参数接受包含转换逻辑的可调用对象。torchvision.transforms 模块提供了几种常用的转换。

FashionMNIST 的特征是 PIL 图像格式,标签是整数。 对于训练,需要将特征转换为归一化的张量,将标签转换为 one-hot 编码的张量。 为了进行这些转换,将使用 ToTensorLambda

from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

ds = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
    target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)

解释上述代码:

ToTensor()

ToTensor 将PIL图像或NumPy ndarray转换为FloatTensor,并将图像的像素强度值缩放到范围 [0., 1.]内。

Lambda transforms

Lambda transforms 应用任何用户定义的lambda函数。在这里,我们定义了一个函数来将整数转换为一个one-hot编码的张量。它首先创建一个大小为10的零张量(数据集中标签的数量),然后调用scatter,根据标签y的索引为其分配值为1。也可以使用torch.nn.functional.one_hot的方式来实现这一点。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1483508.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

swagger在java中的基本使用

自动生成接口文档&#xff0c;和在线接口测试的框架。 导入依赖 <!-- knife4j对swagger进行一个封装--><dependency><groupId>com.github.xiaoymin</groupId><artifactId>knife4j-spring-boot-starter</artifactId><versi…

StarCoder2本地部署上手体验:程序猿要下岗了吗?

StarCoder2简介 ServiceNow、Hugging Face 和 NVIDIA 于2月28日宣布发布 StarCoder2&#xff0c;这是一个用于代码生成的开放式大型语言模型系列&#xff0c;为性能、透明度和成本效益设定了新标准。StarCoder2 是与 BigCode 社区合作开发的&#xff0c;由ServiceNow和 Huggin…

labview数组精讲

题主经过写文章一段时间的发现,许多同学对该软件的理解和编程能力是不太一样的,有些知识相对一些同学较为简单,但是有些同学提问就比较困难。那么针对这个问题,题主打算出一期说白话系列的专栏,在该栏目中用最通俗的大白话和例子去让大家深刻了解这个软件的功能和摸透他的…

Py2neo查询neo4j周杰伦数据库中的节点、关系和路径教程

文章目录 py2neo介绍连接Neo4j数据库py2neo查询图数据库neo4j数据概览使用NodeMatcher查询节点使用RelationshipMatcher查询关系 通过执行Cypher语句查询 py2neo介绍 Neo4j是一款开源图数据库&#xff0c;Py2neo提供了使用Python语言访问Neo4j的接口。本文介绍了使用Py2neo的N…

钉钉机器人发送折线图卡片 工具类代码

钉钉机器人 “创建并投放卡片 接口 ” 可以 发送折线图、柱状图 官方文档&#xff1a;创建并投放卡片 - 钉钉开放平台 0依赖、1模板、2机器人放到内部应用、3放开这个权限 、4工具类、5调用工具类 拼接入参 卡片模板 自己看文档创建&#xff0c;卡片模板的id 有用 0、依赖…

vue2 设置keepAlive之后怎么刷新页面数据

场景&#xff1a;移动端有 A、B、C 三个页面&#xff0c;A、B 页面路由设置了keepAlive属性&#xff0c;有下面两个场景&#xff1a; 1、A 页面 --> B 页面&#xff0c;B 页面刷新。 2、C 页面 --> B页面&#xff0c;B 页面不刷新。 一、分为以下两个情况讨论&#xf…

Linux安装Nginx配置Keepalived高可用

Vmwaire 安装 Linux 解决启动没有IP地址问题 cd /etc/sysconfig/network-scripts vi ifcfg-ens33# 重启linux reboot # 再次查看ip ip addrLinux 镜像地址下载 ps: 发现阿里有一个工具箱&#xff0c;里面有各种镜像 阿里镜像地址 https://developer.aliyun.com/mirror/ 安装…

计算机设计大赛 深度学习图像修复算法 - opencv python 机器视觉

文章目录 0 前言2 什么是图像内容填充修复3 原理分析3.1 第一步&#xff1a;将图像理解为一个概率分布的样本3.2 补全图像 3.3 快速生成假图像3.4 生成对抗网络(Generative Adversarial Net, GAN) 的架构3.5 使用G(z)生成伪图像 4 在Tensorflow上构建DCGANs最后 0 前言 &#…

windows安装部署node.js以及搭建运行第一个Vue项目

一、官网下载安装包 官网地址&#xff1a;https://nodejs.org/zh-cn/download/ 二、安装程序 1、安装过程 如果有C/C编程的需求&#xff0c;勾选一下下图所示的部分&#xff0c;没有的话除了选择一下node.js安装路径&#xff0c;直接一路next 2、测试安装是否成功 【winR】…

cv_bridge连接自定义版本的opencv

在ros noetic版本中&#xff0c;默认的cv_bridge依赖的opencv版本为4.2.0&#xff0c;若要升级opencv版本&#xff0c;则无法使用cv_bridge&#xff0c;所以需要重新自编译cv_bridge。 一. 编译cv_bridge 1.通过网站 https://github.com/ros-perception/vision_opencv/tree/n…

MYSQL的优化学习,从原理到索引,在到事务和锁机制,最后的主从复制、读写分离和分库分表

mysql的优化学习 为什么选择Mysql不选择其他的数据库&#xff1f;还有哪些&#xff0c;有什么区别&#xff1f; Mysql&#xff1a;开源免费版本可用&#xff0c;适用于中小型应用 Oracle&#xff1a;适用于大型企业级应用&#xff0c;复杂的业务场景和大量数据的处理&#xf…

Acwing 每日一题 空调 差分 贪心

&#x1f468;‍&#x1f3eb; 空调 &#x1f468;‍&#x1f3eb; 参考题解 import java.util.Scanner;public class Main {static int N (int) 1e5 10;static int[] a new int[N];static int n;public static void main(String[] args){Scanner sc new Scanner(System.…

智能家居控制系统(51单片机)

smart_home_control_system 51单片机课设&#xff0c;智能家居控制系统 使用及转载请标明出处&#xff08;最好点个赞及star哈哈&#xff09; Github地址&#xff0c;带有PPT及流程图 Gitee码云地址&#xff0c;带有PPT及流程图 ​ 以STC89C52为主控芯片&#xff0c;以矩阵键…

【OpenCV C++】Mat img.total() 和img.cols * img.rows 意思一样吗?二者完全相等吗?

文章目录 1 结论及区别2 Mat img的属性 介绍1 结论及区别 在大多数情况下,img.total() 和 img.cols * img.rows 是相等的,但并不总是完全相等的。下面是它们的含义和一些区别: 1.img.total() 表示图像中像素的总数,即图像的总像素数量。2.img.cols * img.rows 也表示图像中…

C++学习笔记:二叉搜索树

二叉搜索树 什么是二叉搜索树?搜索二叉树的操作查找插入删除 二叉搜索树的应用二叉搜索树的代码实现K模型:KV模型 二叉搜索树的性能怎么样? 什么是二叉搜索树? 二叉搜索树又称二叉排序树&#xff0c;它或者是一棵空树&#xff0c;或者是具有以下性质的二叉树: 若它的左子树…

【go语言开发】swagger安装和使用

本文主要介绍go-swagger的安装和使用&#xff0c;首先介绍如何安装swagger&#xff0c;测试是否成功&#xff1b;然后列出常用的注释和给出使用例子&#xff1b;最后生成接口文档&#xff0c;并在浏览器上测试 文章目录 安装注释说明常用注释参考例子 文档生成格式化文档生成do…

代码随想录算法训练营第二十八天|93.复原IP地址 、78.子集、90.子集II

文章目录 [1.复原 IP 地址](https://leetcode.cn/problems/restore-ip-addresses/description/)2.子集[3.子集 II](https://leetcode.cn/problems/subsets-ii/) 1.复原 IP 地址 切割问题可以使用回溯&#xff0c;本题分别两步&#xff0c;切割字符串和判断IP 切割逻辑如下&…

微信小程序 --- 分包加载

分包加载 1. 什么是分包加载 什么是分包加载 ❓ 小程序的代码通常是由许多页面、组件以及资源等组成&#xff0c;随着小程序功能的增加&#xff0c;代码量也会逐渐增加&#xff0c;体积过大就会导致用户打开速度变慢&#xff0c;影响用户的使用体验。 分包加载是一种小程序…

MATLAB图像噪声添加与滤波

在 MATLAB 中添加图像噪声和进行滤波通常使用以下函数&#xff1a; 添加噪声&#xff1a;可以使用imnoise函数向图像添加各种类型的噪声&#xff0c;如高斯噪声、椒盐噪声等。 滤波&#xff1a;可以使用各种滤波器对图像进行滤波处理&#xff0c;例如中值滤波、高斯滤波等。 …

(三)电机控制之方波驱动无刷直流电机(BLDC)与正弦波驱动无刷直流电机(PMSM)的详细对比

电流控制方式和波形&#xff1a; 方波驱动&#xff1a;在每个换相周期内&#xff0c;定子绕组中的电流被切换为高或低两个状态&#xff0c;形成矩形波。通常采用六步换向法&#xff0c;即每60度电角度换相一次&#xff0c;从而产生转矩。正弦波驱动&#xff1a;定子绕组中流过的…