PyTorch的torchvision内置数据集使用,transform+pytorch联合使用

news2025/2/5 19:45:01

一、PyTorch的torchvision内置数据集介绍

我们前面的文章里谈到的数据集是我们自己找的一些自定义数据集。那么在Pytorch中存在2种数据集(Dataset),即内置数据集(Built-in dataset)和自定义数据集(Custom dataset)。该2种数据集在使用时有所区别,我们之前课程里学习的就是自定义数据集,这里就不多讲了。

而pytorch的内置数据集(Built-in dataset)是由pytorch的torchvision这个库提供的,我们可以直接在代码中下载下来使用,包含了各种各样的机器训练所需要的数据集类型

1、首先去官网了解一下

官方文档在这:torchvision — Torchvision 0.19 documentation

这是最新的pytorch官方文档,可能每一年都会变,目前2024年最新版的就是这样

往下滑,找到这里【Dataset】的【Built-in dataset】,这个就是内置数据集的官方文档

然后就能看到各种各样的数据集,看不懂可以下载一下网页翻译的插件,或者用微信的截图,可以实时翻译任何界面的英文

然后大概整体介绍一下我们常用的一些数据集:

  • MNIST:手写数字图像数据集,用于图像分类任务。
  • CIFAR:包含10个类别、60000张32x32的彩色图像数据集,用于图像分类任务。
  • COCO:通用物体检测、分割、关键点检测数据集,包含超过330k个图像和2.5M个目标实例的大规模数据集。
  • ImageNet:包含超过1400万张图像,用于图像分类和物体检测等任务。
  • Penn-Fudan Database for Pedestrian Detection and Segmentation:用于行人检测和分割任务的数据集。
  • STL-10:包含100k张96x96的彩色图像数据集,用于图像分类任务。
  • Cityscapes:包含5000张精细注释的城市街道场景图像,用于语义分割任务。
  • SQUAD:用于机器阅读理解任务的数据集。

其他的比较复杂以后再说,目前就这些可以了解一下,比如点进去CIFAR10可以看到,是一个对图片物品的识别

点击MNIST,这是来自各个人的手写文字的图片数据集

2、了解数据集的三大分类:训练集、验证集、测试集

数据集具体分为三大类:训练集、验证集、测试集

我们可以用以前读书时的上课知识、课后练习题、期末考试来比喻这三类数据集:

假设我们有已知的T1、T2、T3三个点,对应的 “坐标” 是数据集里【带有标签的样本对】:(x1, y1)、(x2, y2)、(x3, y3),x代表输入、y代表标签

这两个坐标会分别得到两个点T1、T2数据,当我们用这两个数据进行机器训练时就会得到一个函数式(假设):y = ax + b,那么这个【y = ax + b】就是【模型

然后我们当我们把T3的【x3】输入,根据【模型】:【y = ax + b】得出 \hat{y}这个 \hat{y}是机器模型训练得到的结果,而我们已知的T3坐标是【x3,y3】(就比如说 [0013035.jpg] 这个图片对应的是 [蚂蚁] 这个标签)

那么现在将\hat{y}】与【y3】进行对比,就可以得出这个模型的好坏。

然后我们刚刚知道,(x1, y1)、(x2, y2)参与了模型训练,因此就是【训练集】;而(x3, y3)没有参与训练,只是在最后对(x1, y1)、(x2, y2)训练出来的模型进行测试,他就是【测试集】

【训练集】【测试集】有严格划分,训练集只参与训练,相当于我们做的课后练习、模拟试卷,而测试集只参与测试,相当于高考试卷,你总不能把高考题直接给学生做了再参加高考吧?

那么【验证集】又是什么?

我们前面的例子里,训练机器模型的时候,我们根据T1、T2获得了【y=ax+b】这么个模型式子,那么a、b就是【参数】

在实际机器训练中,我们还需要一种参数叫【超参数】,它考虑得是更加精细的方面的可能,通过调整这些超参数,才能让模型更加精确、误差更小

但是随着业务越来越复杂,需要调整的超参数越来越多、越来越难,人为调整已经不太适用

那么就产生了【验证集】,借助它可以获得【最优超参数组合】

【训练集】产生模型,【验证集】评估模型并找出一组最有超参数,然后再交给【训练集】产出新的模型......如此反复,最终将最优化的模型交给【测试集】测试

3、pytorch下载并使用数据集

回到代码 ,现在我们先用【CIFAR-10】这个数据集为例子讲解,点开官方文档,这次点左边不要点右边

然后可以看到,【CIFAR-10】这个数据集需要传入这么几个参数,具体作用我写在图片了

执行一下代码进行下载:

import torchvision

# 这里我们设置train_set是训练集数据、test_set是测试集数据
# 然后因为我根目录已经有一个"dataset"目录放自定义数据集了,我就指定在根目录再创一个文件夹叫“dataset2”,用来放下载的内置数据集
# 然后【训练集】的train参数设置为true,【测试集】就是false,最后允许下载
train_set = torchvision.datasets.CIFAR10("./dataset2", train=True, download=True)
test_set = torchvision.datasets.CIFAR10("./dataset2", train=False, download=True)

(网速因人而异吧,我这个巨特么慢)

那么有迅雷或者百度网盘的超级会员的可以试一下去迅雷或者百度网盘下载,然后复制粘贴进项目里就行了,没有的话其实也是一样,可以安心下载完挂后台去干别的了,至少一小时起步

感谢,来自该博主文章:CIFAR-10 / CIFAR-100数据集(官网/网盘下载)_cifar100pytorch下载-CSDN博客

链接:百度网盘 请输入提取码
提取码:uwis

N万年之后......终于下载安装完毕,那么而我们在下载的代码下面输出一下数据集看看就会发现成功了,这里注意,即使我们下载完了数据集,这两句代码还是要留着,download参数也不用换成False,因为它会自动检测已经下载存在,就不会重复下载了

import torchvision

# 这里我们设置train_set是训练集数据、test_set是测试集数据
# 然后因为我根目录已经有一个"dataset"目录放自定义数据集了,我就指定在根目录再创一个文件夹叫“dataset2”,用来放下载的内置数据集
# 然后【训练集】的train参数设置为true,【测试集】就是false,最后允许下载
train_set = torchvision.datasets.CIFAR10("./dataset2", train=True, download=True)
test_set = torchvision.datasets.CIFAR10("./dataset2", train=False, download=True)

# 打印训练集和测试集的数据量
print(train_set[0])

然后在PyTorch的内置数据集中,无论是CIFAR-10、MNIST还是其他数据集,当你索引数据集对象时,都会返回一个元组,其中第一个元素是【数据】,第二个元素是【标签的索引】。

那么我们知道python可以多个变量同时接收多个返回值,那么就可以设两个变量来接收【数据】和【标签的索引】,变量名随便起,比如:

那么我们将代码放到【python控制台】运行的话,就会看到数据集有这么个属性:【classes】,这个属性是列表list,而我们刚刚获取到的【标签的索引】就是对应【classes】列表里的索引,那么【数据集.classes[ 标签的索引 ]】就可以获取到【标签】了

4、结合transforms,把数据集转化成transforms的tensor型数据

transform有个函数叫【Compose】,它可以将多个操作组合起来,相当于打包,然后形成一个操作流程,一次性依次对图像进行所有操作

例如,假设你有一个图像,你想要先将其转换为灰度图像,然后再将其大小调整为32x32。你可以使用【torchvision.transforms.Compose】来组合这两个转换操作:

from torchvision import transforms

# 创建两个转换操作
to_grayscale = transforms.Grayscale(num_output_channels=1) # 调灰度的操作
resize = transforms.Resize((32, 32)) # 裁剪图像大小的操作

# 组合这两个转换操作,现在就能用它代表这两个操作了
transform = transforms.Compose([to_grayscale, resize])

# 使用组合的转换操作对图像一次性进行两个操作转换
image = transform(image)

那么有了【transform转换操作】之后,翻看上面我们讲过,pytorch的【CIFAR-10】接收的参数里有一个叫【transform=?】,就是用来接收【transform转换操作】的,在两个 “下载代码” 那里的第三个参数位置加上【transform = transform转换操作】

train_set = torchvision.datasets.CIFAR10("./dataset2", train=True, transform=dataset_transfrom, download=True)
test_set = torchvision.datasets.CIFAR10("./dataset2", train=False, transform=dataset_transfrom, download=True)

5、最后结合tensorboard打印一下图像

既然现在图像数据集已经成了tensor类型数据,那就可以用tensorboard打印一下图像了

完整代码:

import torchvision
from torch.utils.tensorboard import SummaryWriter

# torchvision.transforms.Compose是一个函数,它可以将多个图像转换操作组合在一起
dataset_transfrom = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor() # 将图像转换为【tensor(张量)】操作
    # 还可以添加一些裁剪、转灰度......等等图像操作
])

# 这里我们设置train_set是训练集数据、test_set是测试集数据
# 然后因为我根目录已经有一个"dataset"目录放自定义数据集了,我就指定在根目录再创一个文件夹叫“dataset2”,用来放下载的内置数据集
# 然后【训练集】的train参数设置为true,【测试集】就是false,并开启transform数据转换,最后允许下载
train_set = torchvision.datasets.CIFAR10("./dataset2", train=True, transform=dataset_transfrom, download=True)
test_set = torchvision.datasets.CIFAR10("./dataset2", train=False, transform=dataset_transfrom, download=True)

# 输出一些tensor类型的数据集数据
print(train_set[0])

# 最后利用tensorboard打印出照片
write = SummaryWriter("CIFAR-10")
for i in range(10):
    # 遍历10张图片
    img, target = train_set[i]
    write.add_image("【CIFAR-10】_img", img, i)

write.close()

记得关闭上一个tensorboard的终端

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2056646.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

架站点云自动拼接

southLidar pro 软件里面的架站点云无目标、无传感器的点云自动拼接算法,该算法的特征是速度快,精度高、稳定性高,大部分的场景都能一键自动拼接成功。速度、稳定性:比RealWorks 12、SCENE 2019等软件都快。精度:高于S…

python-docx 实现 Word 办公自动化

前言:当我们需要批量生成一些合同文件或者简历等。如果手工处理对于我们来说不仅工作量巨大,而且难免会出现一些问题。这个时候运用python处理word实现自动生成文件可极大的提高工作效率。 python-docx是python的第三方插件,用来处理word文件…

UPS快递查询|利用API对接国际物流轨迹

聚合国内外1500家快递公司的物流信息查询服务,使用API接口查询UPS快递的便捷步骤,首先选择专业的数据平台的快递API接口:https://www.tanshuapi.com/market/detail-68 以下示例是参考的示例代码: import requestsurl "http…

gstreamer系列 -- 获取媒体信息

Basic tutorial 9: Media information gathering

PyCharm单步调试

1、先在入口设置断点,再点击爬虫图标(shift F9)开始调试 调试图标如图: 2、蓝色光标表示当前运行在这行 3、快捷键 F7:进入当前行函数 F8:单步 F9:全速运行

语言基础/分析和实践 CC++ 位域结构数据类型

文章目录 概述位域和结构体的关系位域/位段的概念位域定义的语法位段结构的利弊 结构字段的定义和存储顺序小端系统上的结构字段存储大端系统上的结构字段存储小结(承上启下) 位域结构的存储(对齐、填充、跨字节)位域结构的Bit位序…

admob 广告分析

1、测试广告集成,官方文档 https://developers.google.com/admob/android/quick-start?hlzh-cn dependencies {implementation("com.google.android.gms:play-services-ads:23.3.0") }2、广告集成,集成测试激励广告。 public class MainAct…

学习大数据DAY41 Hive 分区表创建

目录 分区表 分区表应用场景 oracle 分区表种类 oracle 分区-范围分区 oracle 分区-列表分区 oracle 分区-散列分区 oracle 分区-组合分区 oracle 分区-分区表操作 hive 分区-创建分区表 hive 分区-分区表操作 hive 分区-动态分区表配置 上机练习 分区表 分区是将一…

常见古典密码介绍

文章目录 Vigenre 密码变异凯撒摩斯密码栅栏密码加密方式一加密方式二 Caesar和ROT13的区别ROT13加密原理ROT13查找表 Vigenre 密码 由于频率分析法可以有效的破解单表替换密码,法国密码学家维吉尼亚于1586年提出一种多表替换密码,   即维吉尼亚密码&…

什么是局域网管理软件?这款局域网管理软件简直太好用了丨好物分享

在信息技术日新月异的今天,企业的内部网络管理如同古代战场上的排兵布阵,需有精良之器以应对复杂多变的局势。 局域网,作为企业内部信息交流与资源共享的重要平台,其管理效率与安全性直接影响到企业的运营与发展。 一、局域网管理…

Docker-安装软件

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、安装MySQL(一)拉取MySQL镜像(二)运行MySQL容器(1)数据卷概念 (三&#x…

开发 LLM 支持的应用程序:Azure 上的 Llama 2(5/n)

微软与 Meta 联手在 Azure 上提供 Meta 的开源大型语言模型 (LLM) Llama 2,打开了闸门!信不信由你,这是一件大事。 亚马逊的 AWS 于去年年底发布了 Amazon SageMaker Jumpstart,它与 Azure 类似,支持部署开源模型供公共…

常见而不容忽视,电器中微型紧固件的四大影响

技术和创新催生了数以百万计的电器,它们已成为每个家庭和人类日常使用的一部分。从微波炉和冰箱到笔记本电脑和智能手机,这些部件的技术影响正在迅速增长,成为现代生活的一部分。 在许多人的不经意间,这些功能强大的发明为我们许多…

[Linux#42][线程] 锁的接口 | 原理 | 封装与运用 | 线程安全

互斥量 mutex • 大部分情况,线程使用的数据都是局部变量,变量的地址空间在线程栈空间 内,这种情况,变量归属单个线程,其他线程无法获得这种变量。 • 但有时候,很多变量都需要在线程间共享,这…

代码随想录算法训练营第二十一天(二叉树 八)

今天是二叉树复习最后一天! 力扣题部分: 669. 修剪二叉搜索树 题目链接:. - 力扣(LeetCode) 题面: 给你二叉搜索树的根节点 root ,同时给定最小边界low 和最大边界 high。通过修剪二叉搜索树,使得所有节点的值在[low…

使用 Dify 和 AI 大模型理解视频内容:Qwen 2 VL 72B

接下来的几篇相关的文章,聊聊使用 Dify 和 AI 大模型理解视频内容。 本篇作为第一篇内容,以昨天出圈的“黑神话悟空制作人采访视频”为例,先来聊聊经常被国外厂商拿来对比的国产模型:千问系列,以及它的内测版。 写在…

Linux非VP扩容方案

Linux系统非VP扩容方案 描述:现有虚拟机磁盘1TB 容量不够,需要扩容。 采用:https://bbs.sangfor.com.cn/forum.php?modviewthread&tid110403 扩容失败。原因是没有VP 和LV 解决方案: 1,查看分区 cat /proc/p…

鸿蒙内核源码分析(中断概念篇) | 海公公的日常工作

关于中断部分系列篇将用三篇详细说明整个过程. 中断概念篇 中断概念很多,比如中断控制器,中断源,中断向量,中断共享,中断处理程序等等.本篇做一次整理.先了解透概念才好理解中断过程.本篇的主角是海公公,用…

全国计算机二级C语言笔试试题及答案

一、选择题(每小题2分,共70分)   下列各题A)、B)、C)、D)四个选项中,只有一个选项是正确的。 请将正确选项填涂在答题卡相应位置上,答在试卷上不得分。   (1)下列叙述中正确的是 A)线性表的链式存储结构与顺序存储结构所需要的存储空间是相同的 …

day06-SpringBootWeb请求响应

前言 在上一次的课程中,我们开发了springbootweb的入门程序。 基于SpringBoot的方式开发一个web应用,浏览器发起请求 /hello 后 ,给浏览器返回字符串 “Hello World ~”。 其实呢,是我们在浏览器发起请求,请求了我们的…