Fashion MNIST数据集介绍及基于Pytorch下载数据集

news2024/12/25 10:32:10

Fashion MNIST数据集介绍及基于Pytorch下载数据集


🌵文章目录🌵

  • 🌳引言🌳
  • 🌳Fashion MNIST数据集简介🌳
    • Fashion MNIST数据集的类别说明
    • Fashion MNIST数据集图片示例
  • 🌳基于PyTorch下载Fashion MNIST数据集🌳
  • 🌳使用Fashion MNIST数据集进行图像分类任务🌳
  • 🌳小结🌳
  • 🌳结尾🌳


🌳引言🌳

Fashion MNIST是深度学习和机器学习领域中一个非常流行且实用的数据集。它为初学者和研究者提供了一个挑战性的任务,以磨练他们的图像分类技能。本文将深入探讨Fashion MNIST数据集的背景、目的、使用方法和示例代码,帮助您更好地了解如何利用这个数据集进行图像分类任务。


🌳Fashion MNIST数据集简介🌳

Fashion MNIST是一个包含10个类别的服饰分类数据集,每个类别有7000个28x28像素的灰度图像。与MNIST数据集相比,Fashion MNIST在图像质量和多样性方面具有更高的挑战性,因为它包含了更多的背景和不同的视角。

Fashion MNIST数据集的类别说明


标签说明
0T恤(T-shirt)
1裤子(Trouser)
2套头衫(Pullover)
3连衣裙(Dress)
4外套(Coat)
5凉鞋(Sandal)
6衬衫(Shirt)
7运动鞋(Sneaker)
8包(Bag)
9靴子(Ankle boot)

Fashion MNIST数据集图片示例


在这里插入图片描述

图1 数据集示例


🌳基于PyTorch下载Fashion MNIST数据集🌳

在开始使用Fashion MNIST数据集之前,您需要先将其下载到本地计算机上。以下是使用Python和Pytorch库下载数据集的步骤:

  1. 确保已经安装了Python和Pytorch。您可以从Pytorch官网下载并安装最新版本的Pytorch。
  2. 导入所需的库:
import torch
from torchvision import datasets, transforms
  1. 下载训练数据集:
train_data = torchvision.datasets.FashionMNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)

这将在当前目录下创建一个名为“data”的文件夹,并将训练数据集下载到其中。如果您已经拥有数据集,并且想要跳过下载过程,请将download参数设置为False

  1. 下载测试数据集:
test_data = torchvision.datasets.FashionMNIST(root='./data', train=False, transform=transforms.ToTensor(), download=True)

同样,这将在“data”文件夹中提供测试数据集。

  1. 现在您已经成功下载了Fashion MNIST数据集,您可以使用Pytorch的数据加载器(DataLoader)来轻松加载数据。例如,以下代码将创建一个训练数据加载器:
from torch.utils.data import DataLoader

train_loader = DataLoader(train_data, batch_size=32, shuffle=True)

这里,我们将批量大小设置为32,并启用了随机打乱功能。您可以根据需要调整这些参数。类似地,您可以为测试数据集创建一个加载器:

test_loader = DataLoader(test_data, batch_size=32, shuffle=False)

🌳使用Fashion MNIST数据集进行图像分类任务🌳

一旦您下载并准备好了数据集,就可以开始构建和训练图像分类模型了。以下是一个使用PyTorch构建简单卷积神经网络(CNN)进行图像分类的示例代码:

  1. 导入所需的库和模块:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
  1. 划分训练和测试数据集:

首先,我们需要将Fashion MNIST数据集划分为训练集和测试集。以下是一个简单的示例代码,用于将数据分为训练集和测试集:

# 将数据转换为Tensor格式并进行归一化处理(将像素值缩放到0-1之间)
transform = transforms.ToTensor()
train_data = TensorDataset(torch.tensor(train_data.data), train_data.targets) # targets表示图像对应的类别标签(0-9)
test_data = TensorDataset(torch.tensor(test_data.data), test_data.targets) # targets表示图像对应的类别标签(0-9)
train_loader = DataLoader(train_data, batch_size=32, shuffle=True) # 创建训练数据加载器,设置批量大小为32并【启用】随机打乱功能
test_loader = DataLoader(test_data, batch_size=32, shuffle=False) # 创建测试数据加载器,设置批量大小为32并【禁用】随机打乱功能

3. 定义模型结构:
现在,我们可以定义一个简单的卷积神经网络(CNN)模型,用于图像分类任务。以下是一个示例代码,展示了如何使用PyTorch构建一个包含两个卷积层、一个全连接层的CNN模型:


```python
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.view(-1, 64 * 7 * 7) # 将卷积后的特征图展平,以便输入全连接层
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1) # 使用log_softmax激活函数进行分类概率计算
  1. 训练模型:

接下来,我们将使用训练数据集对模型进行训练。以下是一个示例代码,展示了如何定义损失函数和优化器,以及如何训练模型:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 检查是否有可用的GPU,并定义设备(CPU或GPU)
model = SimpleCNN().to(device) # 将模型移动到设备上(CPU或GPU)
criterion = nn.CrossEntropyLoss() # 定义损失函数为交叉熵损失函数
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9) # 定义优化器为随机梯度下降(SGD)优化器,设置学习率为0.001,动量为0.9

# 训练模型
num_epochs = 10 # 设置训练轮数为10轮
for epoch in range(num_epochs):
    model.train() # 设置模型为训练模式
    running_loss = 0.0
    for i, data in enumerate(train_loader): # 使用训练数据加载器逐批获取数据和标签
        inputs, labels = data[0].to(device), data[1].to(device) # 将数据和标签移动到设备上(CPU或GPU)
        optimizer.zero_grad() # 将梯度清零
        outputs = model(inputs) # 前向传播,获取预测输出
        loss = criterion(outputs, labels) # 计算损失值
        loss.backward() # 反向传播,计算梯度值
        optimizer.step() # 更新权重参数
        running_loss += loss.item() # 累加损失值
    print('Epoch %d loss: %.3f' % (epoch + 1, running_loss / len(train_loader))) # 输出当前轮次的平均损失值
  1. 测试模型:

经过训练后,我们需要使用测试数据集评估模型的性能。以下是一个示例代码,展示了如何使用测试数据加载器评估模型:

model.eval() # 设置模型为评估模式,关闭dropout和batch normalization等在训练模式下的特殊操作
correct = 0
total = 0
with torch.no_grad(): # 不需要计算梯度,以提高评估速度
    for data in test_loader: # 使用测试数据加载器逐批获取数据和标签
        images, labels = data[0].to(device), data[1].to(device) # 将数据和标签移动到设备上(CPU或GPU)
        outputs = model(images) # 前向传播,获取预测输出
        _, predicted = torch.max(outputs.data, 1) # 获取最大概率对应的类别标签作为预测结果
        total += labels.size(0) # 统计样本总数
        correct += (predicted == labels).sum().item() # 统计正确分类的样本数量

print('Accuracy of the network on the test images: %d %%' % (
    100 * correct / total)) # 输出模型在测试数据集上的准确率

🌳小结🌳

Fashion MNIST是一个流行的机器学习数据集,主要用于服饰分类任务。它包含10个类别的7000个28x28像素的灰度图像,挑战性较高,因为涉及更多背景和视角。通过PyTorch,可以轻松下载并使用此数据集。一旦数据集准备好,可以使用CNN等模型进行图像分类。本文详细介绍了Fashion MNIST的背景、目的、使用方法和示例代码,为初学者和研究者提供了实用的指导和资源。


🌳结尾🌳

亲爱的读者,首先感谢抽出宝贵的时间来阅读我们的博客。我们真诚地欢迎您留下评论和意见💬
俗话说,当局者迷,旁观者清。的客观视角对于我们发现博文的不足、提升内容质量起着不可替代的作用。
如果博文给您带来了些许帮助,那么,希望能为我们点个免费的赞👍👍/收藏👇👇您的支持和鼓励👏👏是我们持续创作✍️✍️的动力
我们会持续努力创作✍️✍️,并不断优化博文质量👨‍💻👨‍💻,只为给带来更佳的阅读体验。
如果有任何疑问或建议,请随时在评论区留言,我们将竭诚为你解答~
愿我们共同成长🌱🌳,共享智慧的果实🍎🍏!


万分感谢🙏🙏点赞👍👍、收藏⭐🌟、评论💬🗯️、关注❤️💚~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1428820.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ElasticSearch-SpringBoot整合ElasticSearch

六、SpringBoot整合ElasticSearch 1、浏览官方文档 1、查找跟ES客户端相关的文档 使用Java REST Client 选择Java Hight Level REST Client 2、创建项目的准备 1.找到原生的依赖 2.找到对象 3.分析这个类里面的方法 3、正式创建项目 1.创建工程 2.导入依赖 注意依赖版本…

如何使用本地私有NuGet服务器

写在前面 上一篇介绍了如何在本地搭建一个NuGet服务器, 本文将介绍如何使用本地私有NuGet服务器。 操作步骤 1.新建一个.Net类库项目 2.打包类库 操作后会生成一个.nupkg文件,当然也可以用dotnet pack命令来执行打包。 3.推送至本地NuGet服务器 打开命…

RISC-V指令格式

RISC-V指令格式 1 RISC-V指令集命名规范2 RISC-V指令集组成2.1 基础整数指令集2.2 扩展指令集 3 RISC-V指令格式3.1 指令表述3.2 指令格式 本文属于《 RISC-V指令集基础系列教程》之一,欢迎查看其它文章。 1 RISC-V指令集命名规范 前面提到过RV32I,这是…

C#,哥伦布数(Golomb Number)的算法与源代码

1 哥伦布数(Golomb Number) 哥伦布数(Golomb Number)是一个自然数的非减量序列,使得n在序列中正好出现G(n)次。前几个15的G(n)值为:1 2 2 3 3 4 4 4 5 5 5 6…

【深度学习】基于PyTorch架构神经网络学习总结(基础概念基本网络搭建)

神经网络整体架构 类似于人体的神经元 神经网络工作原来为层次结构,一层一层的变换数据。如上述示例有4层,1层输入层、2层隐藏层、1层输出层神经元:数据的量或矩阵的大小,如上述示例中输入层中有三个神经元代表输入数据有3个特征…

网络异常案例四_IP异常

问题现象 终端设备离线,现场根据设备ip,ping不通。查看路由器。 同一个路由器显示的终端设备(走同一个wifi模块接入),包含不同网段的ip。 现场是基于三层的无线漫游,多个路由器wifi配置了相同的ssid信息&a…

SpringBoot+Vue实现各种文件预览(附源码)

👨‍💻作者简介:在笑大学牲 🎟️个人主页:无所谓^_^ ps:点赞是免费的,却可以让写博客的作者开心好几天😎 项目运行效果 前言 在做项目时,文件的上传和预览必不可少。继上…

国标GB/T 28181详解:GB/T28181状态信息报送流程

目 录 一、状态信息报送 二、状态信息报送的基本要求 三、命令流程 1、流程图 2、流程描述 四、协议接口 五、产品说明 六、状态信息报送的作用 七、参考 在国标GBT28181中,定义了状态信息报送的流程,当源设备(包括网关、SIP 设备、SIP 客…

面试经典150题 -- 区间(总结)

总的链接 : 面试经典 150 题 - 学习计划 - 力扣(LeetCode)全球极客挚爱的技术成长平台最经典 150 题,掌握面试所有知识点https://leetcode.cn/studyplan/top-interview-150/ 228 汇总区间 直接用双指针模拟即可 ; class Solution { public…

华为数通方向HCIP-DataCom H12-821题库(单选题:401-420)

第401题 R1的配置如图所示,此时在R1查看FIB表时,关于目的网段192.168.1.0/24的下跳是以下哪一项? A、10.0.23.3 B、10.0.12.2 C、10.0.23.2 D、10.0.12.1 【答案】A 【答案解析】 该题目考查的是路由的递归查询和 RIB 以及 FIB 的关系。在 RIB 中,静态路由写的是什么,下…

【React】react组件传参

【React】react组件传参 一、props:父组件向子组件传参1、将普通的参数作为props传递2、将jsx作为props传递(组件插槽) 二、自定义事件:子父组件向父组件传参三、context进行多级组件传参四、redux全局状态管理 一、props&#xf…

C++ pair+map+set+multimap+multiset+AVL树+红黑树(深度剖析)

文章目录 1. 前言2. 关联式容器3. pair——键值对4. 树形结构的关联式容器4.1 set4.1.1 set 的介绍4.1.2 set 的使用 4.2 map4.2.1 map 的介绍4.2.2 map 的使用 4.3 multiset4.3.1 multiset 的介绍4.3.2 multiset 的使用 4.4 multimap4.4.1 multimap 的介绍4.4.2 multimap 的使…

利用Dynamo进行模型版本对比

你好,这里是 BIM 的乐趣,我是九哥~ 今天我们来聊一个老生常谈的话题——模型版本对比。 先来看一段视频演示: Dynamo模型版本对比 比较同一个模型的不同版本,找出新增,删除以及更改的内容,虽然感觉上实现…

SpringBoot+Redis如何实现用户输入错误密码后限制登录(含源码)

点击下载《SpringBootRedis如何实现用户输入错误密码后限制登录(含源码)》 1. 引言 在当今的网络环境中,保障用户账户的安全性是非常重要的。为了防止暴力破解和恶意攻击,我们需要在用户尝试登录失败一定次数后限制其登录。这不…

MongoDB从入门到实战之MongoDB快速入门

前言 上一章节主要概述了MongoDB的优劣势、应用场景和发展史。这一章节将快速的概述一下MongoDB的基本概念,带领大家快速入门MongoDB这个文档型的NoSQL数据库。 MongoDB从入门到实战的相关教程 MongoDB从入门到实战之MongoDB简介👉 MongoDB从入门到实战…

图像处理之《可逆重缩放网络及其扩展》论文精读

一、文章摘要 图像重缩放是一种常用的双向操作,它首先将高分辨率图像缩小以适应各种显示器或存储和带宽友好,然后将相应的低分辨率图像放大以恢复原始分辨率或放大图像中的细节。然而,非单射下采样映射丢弃了高频内容,导致逆恢复…

算法练习-二叉树的节点个数【完全/普通二叉树】(思路+流程图+代码)

难度参考 难度:中等 分类:二叉树 难度与分类由我所参与的培训课程提供,但需要注意的是,难度与分类仅供参考。且所在课程未提供测试平台,故实现代码主要为自行测试的那种,以下内容均为个人笔记,旨…

基于WordPress开发微信小程序2:决定开发一个wordpress主题

上一篇:基于WordPress开发微信小程序1:搭建Wordpress-CSDN博客 很快发现一个问题,如果使用别人的主题模板,多多少少存在麻烦,所以一咬牙,决定自己开发一个主题模板,并且开源在gitee上&#xff…

Javascript | JS如何断点测试(WebStorm)

JavaScript的断点与之前所学到的Java和python在jetbrain系列编辑器中的断点debug不太一样,往常我们在编写python的时候用pycharm的时候是直接断点进入debug的,就像下面这样 只要直接在代码中断点,然后运行debug功能即可 但是在WebStorm中不是…

网络流数据集处理(深度学习数据处理基础)

一、数据集处理 处理数据集是一个文件夹 一个文件夹处理的,将原网络流数据集 放入一个文件夹 处理转换成 Json文件。(数据预处理)然后将这些文件处理成目标文件格式 再分割成训练集和测试集。每次运行只会处理一个文件夹。 运行train.py 导入…