项目实例_FashionMNIST_CNN

news2024/12/24 0:59:00

前言

提醒:
文章内容为方便作者自己后日复习与查阅而进行的书写与发布,其中引用内容都会使用链接表明出处(如有侵权问题,请及时联系)。
其中内容多为一次书写,缺少检查与订正,如有问题或其他拓展及意见建议,欢迎评论区讨论交流。

文章目录

  • 前言
  • CNN介绍
  • 数据集介绍
      • FashionMNIST 数据集概述
        • 主要特点:
        • 类别标签:
        • 目标:
        • 为什么使用 FashionMNIST:
        • 图像示例:
        • 数据加载与预处理:
        • 下载和加载 FashionMNIST(PyTorch 代码示例):
        • 数据集使用场景:
      • 数据集可视化
      • 总结:
  • 项目实例
    • 代码


CNN介绍

可见于:
MNIST数据集_CNN
在这里插入图片描述

数据集介绍

FashionMNIST 数据集概述

FashionMNIST 是一个包含 10 个类别的图像数据集,用于训练和测试机器学习模型,特别是在图像分类任务中的应用。它是由 Zalando 提供的,目的是为了给机器学习社区提供一个标准化、简洁但具有挑战性的视觉分类数据集。FashionMNIST 是 MNIST 数据集的一个变种,后者包含手写数字图像。

主要特点:
  • 图像大小:每张图像的分辨率为 28x28 像素,每个像素为灰度值(单通道,值在 0 到 255 之间)。
  • 图像类型:这些图像展示了 10 种不同类型的时尚商品,例如鞋子、T 恤、外套等。
  • 数据集结构
    • 训练集:60,000 张图像
    • 测试集:10,000 张图像
  • 类别:数据集包含 10 个类别,分别对应不同的服装商品(每个类别有对应的标签)。
类别标签:
  • 0: T 恤/上衣
  • 1: 裤子
  • 2: 套头衫
  • 3: 连衣裙
  • 4: 外套
  • 5: 凉鞋
  • 6: 衬衫
  • 7: 运动鞋
  • 8: 包
  • 9: 靴子
目标:

FashionMNIST 的目标是对每张 28x28 的灰度图像进行分类,判定该图像属于哪个类别(例如是 “T 恤”、“裤子” 还是 “运动鞋” 等)。因此,它是一个 多类分类 问题,通常被用来评估各种机器学习模型,尤其是在图像分类任务中的表现。

为什么使用 FashionMNIST:

FashionMNIST 与原始的 MNIST 数据集相似,但相比 MNIST(手写数字),FashionMNIST 的图像内容更加复杂且多样。这使得它在很多机器学习和深度学习领域成为了一个较为简单但富有挑战性的测试集。它被广泛应用于:

  • 深度学习模型的评估:特别是用于测试卷积神经网络(CNN)等模型的性能。
  • 学习和研究:它提供了一个简单且标准化的图像分类数据集,适用于机器学习入门或新模型的验证。
图像示例:

每张图像都是 28x28 像素的灰度图,显示的是一件衣物的图片。图像尺寸较小,适合在初学者的机器学习项目中进行训练,因为它不需要大量的计算资源。

数据加载与预处理:

FashionMNIST 数据集一般会在加载时进行一些标准的预处理步骤,如:

  • 归一化:将像素值从 [0, 255] 范围映射到 [0, 1] 范围,或者进行标准化,常常帮助提升模型的收敛速度。
  • 转换:通常使用 transforms.ToTensor() 将图像转换为 PyTorch 中的 Tensor 格式。
下载和加载 FashionMNIST(PyTorch 代码示例):
import torchvision
from torchvision import transforms

# 加载训练集数据
train_data = torchvision.datasets.FashionMNIST(
    root='data',     # 存储路径
    train=True,      # 训练集
    download=True,   # 下载数据集
    transform=transforms.ToTensor(),  # 转换为 Tensor 格式
)

# 加载测试集数据
test_data = torchvision.datasets.FashionMNIST(
    root='data',
    train=False,     # 测试集
    download=True,
    transform=transforms.ToTensor(),
)

这段代码使用 torchvision.datasets.FashionMNIST 来下载并加载训练和测试数据集,图像会被转换为 PyTorch 的 Tensor 格式。

数据集使用场景:
  • 入门项目:对于初学者,FashionMNIST 是一个非常适合入门的图像分类数据集,因为它相对简单,且具有一定的挑战性。
  • 模型对比与验证:在机器学习和深度学习领域,FashionMNIST 常常作为测试模型性能的标准数据集之一,用来对比不同算法(如支持向量机、KNN、神经网络等)的表现。
  • 神经网络训练:尤其是卷积神经网络(CNN)在图像分类任务中的应用,FashionMNIST 为其提供了一个理想的训练平台。

数据集可视化

import torch
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np

# 下载 FashionMNIST 数据集
transform = transforms.Compose([transforms.ToTensor()])  # 只需要转换为 Tensor
train_data = datasets.FashionMNIST(root='data', train=True, download=True, transform=transform)

# 获取前 10 张图像以及对应的标签
images, labels = zip(*[(train_data[i][0], train_data[i][1]) for i in range(10)])

# 类别名称映射
class_names = [
    'T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
    'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'
]

# 设置绘图
fig, axes = plt.subplots(1, 10, figsize=(15, 15))  # 创建 1 行 10 列的子图

for i in range(10):
    ax = axes[i]
    ax.imshow(images[i].squeeze(), cmap='gray')  # squeeze 去掉多余的维度,cmap 为灰度色
    ax.set_title(class_names[labels[i]])  # 标注类别名称
    ax.axis('off')  # 不显示坐标轴

plt.show()  # 显示图像

运行结果:
在这里插入图片描述

总结:

FashionMNIST 是一个标准化的图像分类数据集,由 Zalando 提供。它由 10 类不同的时尚商品构成,训练集包含 60,000 张图像,测试集包含 10,000 张图像。它适用于深度学习、机器学习模型的训练和评估,尤其适合初学者学习和实验。

项目实例

代码

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# 下载训练数据集FashionMNIST
training_data = torchvision.datasets.FashionMNIST(
    root="data",  # 数据集存储位置
    train=True,    # 使用训练集
    download=True, # 如果数据集不存在,则下载
    transform=transforms.ToTensor(),  # 转换为Tensor
)

# 下载测试数据集FashionMNIST
test_data = torchvision.datasets.FashionMNIST(
    root="data",
    train=False,   # 使用测试集
    download=True,
    transform=transforms.ToTensor(),  # 转换为Tensor
)

# 标签的映射字典,数字标签对应的衣物类别名称
labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}

# 可视化FashionMNIST数据集中前9张图片
figure = plt.figure(figsize=(8, 8))  # 创建一个8x8的图像画布
cols, rows = 3, 3  # 设置行列数
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(training_data), size=(1,)).item()  # 随机选取一张图片
    img, label = training_data[sample_idx]  # 获取图片和标签
    figure.add_subplot(rows, cols, i)  # 在画布上添加子图
    plt.title(labels_map[label])  # 设置图片的标题为标签对应的衣物类别
    plt.axis("off")  # 关闭坐标轴显示
    plt.imshow(img.squeeze(), cmap="gray")  # 显示图像,squeeze去掉多余维度,cmap设置为灰度图
plt.show()  # 展示图像

# 设置训练过程中的超参数
num_epochs = 10       # 训练的轮数
batch_size = 32       # 批大小
weight_decay = 1e-4    # 权重衰减(L2正则化)
learning_rate = 0.001 # 学习率

# 创建数据加载器,训练集和测试集
train_dataloader = torch.utils.data.DataLoader(training_data, batch_size=batch_size)
test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=batch_size)

# 打印测试集的一个batch的尺寸和标签类型
for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")  # X是输入图像,N是批大小,C是通道数,H和W是图像的高和宽
    print(f"Shape of y: {y.shape} {y.dtype}")      # y是标签,显示其维度和数据类型
    break  # 只打印一个batch的信息

# 检测是否使用GPU进行训练
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")  # 输出当前使用的设备(CUDA or CPU)

# 定义一个简单的神经网络模型
class NeuralNet(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(NeuralNet, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)   # 第一层全连接层,输入784个特征,输出hidden_size个特征
        self.relu = nn.ReLU()                           # ReLU激活函数
        self.fc2 = nn.Linear(hidden_size, num_classes)  # 第二层全连接层,输出num_classes个类别的预测值

    def forward(self, x):
        out = self.fc1(x)    # 输入通过第一层
        out = self.relu(out)  # ReLU激活
        out = self.fc2(out)   # 输入通过第二层,输出结果
        return out

# 假设输入是28x28的图像,展开为784维,隐藏层大小为500,分类数为10
model = NeuralNet(input_size=784, hidden_size=500, num_classes=10).to(device)

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()  # 使用交叉熵损失函数,适用于多分类问题
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)  # 使用Adam优化器

# 开始训练模型
total_step = len(train_dataloader)  # 获取训练集的总批次数
for epoch in range(num_epochs):  # 遍历每一个epoch
    for i, (images, labels) in enumerate(train_dataloader):  # 遍历每一个batch
        images = images.reshape(-1, 28*28).to(device)  # 将28x28的图像展开成784维向量,转移到device(GPU/CPU)
        labels = labels.to(device)  # 标签转移到设备上

        # 前向传播
        outputs = model(images)  # 将输入传入模型,得到预测输出
        loss = criterion(outputs, labels)  # 计算损失

        # 反向传播和优化
        optimizer.zero_grad()  # 清零之前的梯度
        loss.backward()  # 计算当前梯度
        optimizer.step()  # 更新模型参数

        # 每100个batch输出一次训练状态
        if (i + 1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{total_step}], Loss: {loss.item():.4f}')

# 训练完成后,在测试集上评估模型的准确率
model.eval()  # 设置模型为评估模式(此时BatchNorm等层使用移动平均值而不是批量值)
with torch.no_grad():  # 不需要计算梯度
    correct = 0
    total = 0
    for images, labels in test_dataloader:
        images = images.reshape(-1, 28*28).to(device)  # 将图像展开为784维
        labels = labels.to(device)  # 标签转移到设备上
        outputs = model(images)  # 获取模型的输出
        _, predicted = torch.max(outputs.data, 1)  # 获取预测类别,outputs.data返回模型的预测结果
        total += labels.size(0)  # 统计总样本数
        correct += (predicted == labels).sum().item()  # 统计预测正确的样本数

    # 输出模型在测试集上的准确率
    print(f'Test Accuracy of the model on the 10000 test images: {100 * correct / total} %')

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2257689.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

使用数据库同步中间件DBSyncer实现不同数据库的数据同步

点击上方蓝字关注我 有去O(ORACLE数据库)、信创、国产化数据库等项目实践的同学应该都遇到过不同数据库之前进行数据迁移的问题,虽然有各种工具可以实现,但是有些工具的部署、使用比较复杂,也有些工具迁移数据效率很低。本文将介绍一款开源且…

SQL汇总数据:聚集函数

我们经常需要汇总数据而无需实际检索出这些数据,为此SQL提供了专门的函数。使用这些函数,SQL查询能够高效地检索数据,以便进行分析和报表生成。这类检索的例子包括: 确定表中行数(或者满足某个条件或包含某个特定值的…

HTML颜色-HTML脚本

HTML脚本 js使得HTML页面具有更强的动态和交互性 HTML<script>标签 标签用于定义客户端脚本&#xff0c;比如javascript 可包含脚本语句&#xff0c;也可以通过src属性指向外部的脚本文件 JavaScript最常用于图片操作&#xff0c;表单验证以及动态的内容更新 HTML<n…

ASP.NET Core8.0学习笔记(二十五)——EF Core Include导航数据加载之预加载与过滤

一、导航属性数据加载 1.在EF Core中可以使用导航属性来加载相关实体。 2.加载实体的三种方式&#xff1a; (1)预先加载&#xff1a;直接在查询主体时就把对应的依赖实体查出来&#xff08;作为初始查询的一部分&#xff09; (2)显式加载&#xff1a;使用代码指示稍后显式的从…

【工具变量】上市公司企业过度负债数据(2000-2022年)

一、计算方式&#xff1a;参考C刊《投资研究》汪昌云&#xff08;2022&#xff09;老师的研究&#xff0c;将实际负债率与 Tobit 回归得到的目标负债率之差认定为过度负债率&#xff0c;该种方式认为目标负债率的驱动因素包括公司特征与行业因素&#xff0c;较为全面&#xff0…

分布式数据库中间件-Sharding-JDBC

文章目录 Sharding-JDBCSharding-JDBC介绍Sharding-JDBC的作用什么是分库分表分库分表的方式分库分表带来的问题事务一致性问题跨节点关联查询跨节点分页、排序函数主键重复 Sharding-JDBC 入门&#xff08;水平分表&#xff09;需求说明环境搭建编写代码流程分析其他配置方式概…

FPGA 16 ,Verilog中的位宽:深入理解与应用

目录 前言 一. 位宽的基本概念 二. 位宽的定义方法 1. 使用向量变量定义位宽 ① 向量类型及位宽指定 ② 位宽范围及位索引含义 ③ 存储数据与字节数据 2. 使用常量参数定义位宽 3. 使用宏定义位宽 4. 使用[:][-:]操作符定义位宽 1. 详细解释 : 操作符 -: 操作符 …

HTML:表格重点

用表格就用table caption为该表上部信息&#xff0c;用来说明表的作用 thead为表头主要信息&#xff0c;效果加粗 tbody为表格中的主体内容 tr是 table row 表格的行 td是table data th是table heading表格标题 &#xff0c;一般表格第一行的数据都是table heading

hbuilder 本地插件配置

插件存放路径&#xff0c;项目根目录nativeplugins下&#xff0c;没有就新建。 aar文件存放路径\nativeplugins\pda-module\android package.json存放路径\nativeplugins\module\ 配置package.json文件 { "name": "本地插件", "id": &quo…

大模型应用的数字能源数据集

除了尚须时日的量子计算解决算力效率和能源问题&#xff0c;以及正在路上的超越transformer的全新模型架构外&#xff0c;无疑是“数据集”&#xff0c;准确讲是“高质量大规模多样性的数据集”。数据集是大模型发展的核心要素之一&#xff0c;是大计算的标的物&#xff0c;是实…

飞书解除复制,下载文件限制终极方案

1.通过移除copy 事件&#xff0c;可以复制文档内容&#xff0c;但是飞书表格增加了键盘按键事件&#xff0c;表格无法复制&#xff0c;下载 2.通过chrome插件&#xff0c;可以复制clould document converter 可以实现下载飞书文档&#xff0c;但是无法下载表格 而且无法识别自定…

Java面试题精选:设计模式(二)

1、装饰器模式与代理模式的区别 1&#xff09;代理模式(Proxy Design Pattern ) 原始定义是&#xff1a;让你能够提供对象的替代品或其占位符。代理控制着对于原对象的访问&#xff0c;并允许将请求提交给对象前后进行一些处理。 代理模式的适用场景 功能增强 当需要对一个对…

自然语言处理:从入门到精通全指引

一、引言 自然语言处理&#xff08;NLP&#xff09;作为人工智能领域的关键分支&#xff0c;旨在让计算机理解、生成和处理人类语言&#xff0c;近年来取得了令人瞩目的成就&#xff0c;在智能客服、机器翻译、文本分析、语音助手等众多领域发挥着重要作用。从入门到精通自然语…

Typora 修改默认的高亮颜色

shift F12 参考 怎么给typora添加颜色&#xff1f;

(1)Quartus中如何在外设FLASH中固化jic文件

&#xff08;1&#xff09;在产生jic文件前&#xff0c;必须已经综合通过&#xff0c;生成了sof文件 &#xff08;2&#xff09;点击file-convert Programming Files... &#xff08;3&#xff09;文件类型选择jic文件&#xff0c;flsh型号设定为EPCS128 &#xff08;4&#…

OpenAI2024-12D-3:Sora 发布,谁更胜一筹——Sora 与可灵的全面前瞻对比

藏了一年&#xff0c;终于OpenAI在12天活动的第三天&#xff0c;正式发布了其全新创意工具——Sora&#xff0c;这款工具凭借其强大的文本到视频生成能力和高度的创作自由度&#xff0c;迅速吸引了广大创作者的目光。与此同时&#xff0c;已经在视频创作领域有着成熟表现的可灵…

重生之我在异世界学智力题(4)

大家好&#xff0c;这里是小编的博客频道 小编的博客&#xff1a;就爱学编程 很高兴在CSDN这个大家庭与大家相识&#xff0c;希望能在这里与大家共同进步&#xff0c;共同收获更好的自己&#xff01;&#xff01;&#xff01; 本文目录 引言渡河问题&#xff08;1&#xff09;问…

福州大学《2024年812自动控制原理真题》 (完整版)

本文内容&#xff0c;全部选自自动化考研联盟的&#xff1a;《福州大学812自控考研资料》的真题篇。后续会持续更新更多学校&#xff0c;更多年份的真题&#xff0c;记得关注哦~ 目录 2024年真题 Part1&#xff1a;2024年完整版真题 2024年真题

实现盘盈单自动化处理:吉客云与金蝶云星空数据对接

盘盈单103v2对接其他入库&#xff1a;吉客云数据集成到金蝶云星空 在企业信息化管理中&#xff0c;数据的高效流转和准确性至关重要。本文将分享一个实际案例&#xff0c;展示如何通过轻易云数据集成平台&#xff0c;将吉客云的数据无缝对接到金蝶云星空&#xff0c;实现盘盈单…

Meta Llama 3.3 70B:性能卓越且成本效益的新选择

Meta Llama 3.3 70B&#xff1a;性能卓越且成本效益的新选择 引言 在人工智能领域&#xff0c;大型语言模型一直是研究和应用的热点。Meta公司最近发布了其最新的Llama系列模型——Llama 3.3 70B&#xff0c;这是一个具有70亿参数的生成式AI模型&#xff0c;它在性能上与4050…