DataLoader与Dataset

news2024/9/23 11:18:26

一、人民币二分类在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

二、DataLoader 与 Dataset

DataLoader

torch.utils.data.DataLoader

功能:构建可迭代的数据装载器
(只标注了较为重要的)
• dataset: Dataset类,决定数据从哪读取及如何读取
• batchsize : 批大小
• num_works: 是否多进程读取数据
• shuffle: 每个epoch是否乱序
• drop_last:当样本数不能被batchsize整除时,是否舍弃最后一批数据

DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    sampler=None,
    batch_sampler=None,
    num_workers=0,
    collate_fn=None,
    pin_memory=False,
    drop_last=False,
    timeout=0,
    worker_init_fn=None,
    multiprocessing_context=None
)
  • Epoch: 所有训练样本都已输入到模型中,称为一个Epoch
  • Iteration:一批样本输入到模型中,称之为一个Iteration
  • Batchsize:批大小,决定一个Epoch有多少个Iteration

样本总数:80, Batchsize:8
1 Epoch = 10 Iteration

样本总数:87, Batchsize:8
1 Epoch = 10 Iteration ? drop_last = True
1 Epoch = 11 Iteration ? drop_last = False

根据给定的样本总数和批大小,可以计算出一个Epoch中的Iteration数量。

  1. 样本总数为80,批大小为8:
    • 一个Epoch中的Iteration数量 = 样本总数 / 批大小 = 80 / 8 = 10
  2. 样本总数为87,批大小为8,且设置drop_last = True
    • 一个Epoch中的Iteration数量 = 样本总数 // 批大小 = 87 // 8 = 10
  3. 样本总数为87,批大小为8,且设置drop_last = False
    • 一个Epoch中的Iteration数量 = (样本总数 + 批大小 - 1) // 批大小 = (87 + 8 - 1) // 8 = 11

在第3种情况下,由于样本总数无法被批大小整除,因此在最后一个Epoch中会有一个额外的Iteration来处理剩余的样本。

Dataset

torch.utils.data.Dataset

功能:Dataset抽象类,所有自定义的Dataset需要继承它,并且复写__getitem__()

getitem :接收一个索引,返回一个样本

class Dataset(object):
    def __getitem__(self, index):
        raise NotImplementedError
    def __add__(self, other):
        return ConcatDataset([self, other])

上述代码定义了一个名为Dataset的类,该类是一个抽象基类。它包含了两个特殊方法:

  1. __getitem__(self, index)方法:这是一个抽象方法,需要在子类中实现。它用于根据给定的索引index返回对应的数据样本。在这里,抛出了NotImplementedError异常,表示子类必须覆盖这个方法来提供具体的实现。
  2. __add__(self, other)方法:这是一个特殊方法,用于实现对象的加法操作。在这里,它返回一个ConcatDataset对象,该对象将当前的self和另一个other数据集合并在一起。__add__方法的返回值是一个ConcatDataset对象,表示将当前数据集和另一个数据集进行连接。ConcatDataset是PyTorch中的一个类,用于将多个数据集连接在一起,以便在训练过程中一起使用。

四、模型训练

# -*- coding: utf-8 -*-
"""
# @file name  : train_lenet.py
# @author     : siuserjy
# @date       : 2024-01-03 20:50:38
# @brief      : 人民币分类模型训练
"""
import os

# 获取当前文件的目录路径
BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# 导入必要的库和模块
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
from matplotlib import pyplot as plt

# 定义lenet.py和common_tools.py文件的路径并检查文件是否存在
path_lenet = os.path.abspath(os.path.join(BASE_DIR, "..", "..", "model", "lenet.py"))
path_tools = os.path.abspath(os.path.join(BASE_DIR, "..", "..", "tools", "common_tools.py"))
assert os.path.exists(path_lenet), "{}不存在,请将lenet.py文件放到 {}".format(path_lenet, os.path.dirname(path_lenet))
assert os.path.exists(path_tools), "{}不存在,请将common_tools.py文件放到 {}".format(path_tools, os.path.dirname(path_tools))

# 将自定义模块所在的目录添加到Python路径中
import sys
hello_pytorch_DIR = os.path.abspath(os.path.dirname(__file__) + os.path.sep + ".." + os.path.sep + "..")
sys.path.append(hello_pytorch_DIR)

# 从自定义模块导入所需内容
from model.lenet import LeNet
from tools.my_dataset import RMBDataset
from tools.common_tools import set_seed

# 设置随机种子
set_seed()

# 定义人民币数据集的标签
rmb_label = {"1": 0, "100": 1}

# 设置训练参数
MAX_EPOCH = 10
BATCH_SIZE = 16
LR = 0.01
log_interval = 10
val_interval = 1

# ============================ step 1/5 数据 ============================

# 设置数据集路径
split_dir = os.path.abspath(os.path.join(BASE_DIR, "..", "..", "data", "rmb_split"))
if not os.path.exists(split_dir):
    raise Exception(r"数据 {} 不存在, 回到lesson-06\1_split_dataset.py生成数据".format(split_dir))

# 设置训练集和验证集路径
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")

# 设置图像的均值和标准差
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

# 设置训练集的数据预处理
train_transform = transforms.Compose([
    transforms.Resize((32, 32)),  # 将图像大小调整为32x32
    transforms.RandomCrop(32, padding=4),  # 随机裁剪32x32大小的图像
    transforms.ToTensor(),  # 将图像转换为Tensor格式
    transforms.Normalize(norm_mean, norm_std),  # 标准化图像
])

# 设置验证集的数据预处理
valid_transform = transforms.Compose([
    transforms.Resize((32, 32)),  # 将图像大小调整为32x32
    transforms.ToTensor(),  # 将图像转换为Tensor格式
    transforms.Normalize(norm_mean, norm_std),  # 标准化图像
])

# 构建训练集和验证集的数据集实例
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)

# 构建训练集和验证集的DataLoader
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)

# ============================ step 2/5 模型 ============================
# 构建LeNet模型实例
net = LeNet(classes=2)
net.initialize_weights()

# ============================ step 3/5 损失函数 ============================
# 设置损失函数
criterion = nn.CrossEntropyLoss()


# ============================ step 4/5 优化器 ============================
# 设置优化器
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)

# 设置学习率下降策略
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)


# ============================ step 5/5 训练 ============================
train_curve = list()  # 记录训练集的loss值
valid_curve = list()  # 记录验证集的loss值

for epoch in range(MAX_EPOCH):  # 迭代训练多个epoch

    loss_mean = 0.  # 记录每个epoch的平均loss值
    correct = 0.  # 记录分类正确的样本数量
    total = 0.  # 记录总样本数量

    net.train()  # 将模型设置为训练模式
    for i, data in enumerate(train_loader):  # 遍历训练集数据

        # forward
        inputs, labels = data  # 获取输入数据和标签
        outputs = net(inputs)  # 将输入数据输入模型,得到输出结果

        # backward
        optimizer.zero_grad()  # 将模型参数的梯度置零
        loss = criterion(outputs, labels)  # 计算损失值
        loss.backward()  # 反向传播,计算梯度

        # update weights
        optimizer.step()  # 更新模型参数

        # 统计分类情况
        _, predicted = torch.max(outputs.data, 1)  # 获取预测结果
        total += labels.size(0)  # 累计总样本数量
        correct += (predicted == labels).squeeze().sum().numpy()  # 累计分类正确的样本数量

        # 打印训练信息
        loss_mean += loss.item()  # 累计每个batch的loss值
        train_curve.append(loss.item())  # 将每个batch的loss值记录下来
        if (i+1) % log_interval == 0:  # 每隔一定的batch数打印一次训练信息
            loss_mean = loss_mean / log_interval  # 计算平均loss值
            print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total))
            loss_mean = 0.  # 重置loss_mean

    scheduler.step()  # 更新学习率

    # validate the model
    if (epoch+1) % val_interval == 0:  # 每隔一定的epoch数进行一次验证

        correct_val = 0.  # 记录验证集分类正确的样本数量
        total_val = 0.  # 记录验证集总样本数量
        loss_val = 0.  # 记录验证集的loss值
        net.eval()  # 将模型设置为评估模式
        with torch.no_grad():  # 不计算梯度
            for j, data in enumerate(valid_loader):  # 遍历验证集数据
                inputs, labels = data  # 获取输入数据和标签
                outputs = net(inputs)  # 将输入数据输入模型,得到输出结果
                loss = criterion(outputs, labels)  # 计算损失值

                _, predicted = torch.max(outputs.data, 1)  # 获取预测结果
                total_val += labels.size(0)  # 累计验证集总样本数量
                correct_val += (predicted == labels).squeeze().sum().numpy()  # 累计验证集分类正确的样本数量

                loss_val += loss.item()  # 累计验证集的loss值

            loss_val_epoch = loss_val / len(valid_loader)  # 计算验证集每个epoch的平均loss值
            valid_curve.append(loss_val_epoch)  # 将验证集每个epoch的平均loss值记录下来
            print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, j+1, len(valid_loader), loss_val_epoch, correct_val / total_val))

# 绘制训练曲线和验证曲线
train_x = range(len(train_curve))  # 训练曲线的x轴
train_y = train_curve  # 训练曲线的y轴

train_iters = len(train_loader)  # 训练集的迭代次数
valid_x = np.arange(1, len(valid_curve)+1) * train_iters*val_interval - 1  # 验证曲线的x轴,将epoch转换为iteration
valid_y = valid_curve  # 验证曲线的y轴

plt.plot(train_x, train_y, label='Train')  # 绘制训练曲线
plt.plot(valid_x, valid_y, label='Valid')  # 绘制验证曲线

plt.legend(loc='upper right')  # 设置图例位置
plt.ylabel('loss value')  # 设置y轴标签
plt.xlabel('Iteration')  # 设置x轴标签
plt.show()  # 显示图像


# ============================ inference ============================


# 设置基本路径
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
test_dir = os.path.join(BASE_DIR, "test_data")

# 创建测试数据集
test_data = RMBDataset(data_dir=test_dir, transform=valid_transform)

# 创建验证数据加载器
valid_loader = DataLoader(dataset=test_data, batch_size=1)

# 遍历验证数据集
for i, data in enumerate(valid_loader):
    # 前向传播
    inputs, labels = data
    outputs = net(inputs)
    _, predicted = torch.max(outputs.data, 1)

    # 判断预测结果是1元还是100元
    rmb = 1 if predicted.numpy()[0] == 0 else 100

    # 打印模型获得的金额
    print("模型获得{}元".format(rmb))


在这里插入图片描述

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1353557.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

12.Harbor构建私有镜像仓库

1、阿里云容器镜像服务-个人版 细心的同学可能已经发现,在前面的部署过程中,前面所有的部署步骤中需要的镜像都是从阿里云的镜像仓库中下载。 因为网络原因,有的镜像可能下载比较慢,有点可能下载不了,所以为了加速镜像下载,我都统一将镜像推送到阿里云的镜像仓库(个人…

【操作系统xv6】学习记录--实验1 Lab: Xv6 and Unix utilities--未完

ref:https://pdos.csail.mit.edu/6.828/2020/xv6.html 实验:Lab: Xv6 and Unix utilities 环境搭建 实验环境搭建:https://blog.csdn.net/qq_45512097/article/details/126741793 搭建了1天,大家自求多福吧,哎。~搞环境真是折磨…

Unity | 渡鸦避难所-5 | 角色和摄像机之间的遮挡物半透明

1 前言 角色在地图上移动到岩石后面时,完全被岩石遮挡,玩家只能看到岩石。这逻辑看起来没问题,但并不是玩家想要看到的画面,玩家更希望关注角色的状态 为了避免角色被遮挡,可以使用 Cinemachine Collider 功能&#x…

多模态——旷视大模型Vary更细粒度的视觉感知实现文档级OCR或图表理解

概述 现代大型视觉语言模型(LVLMs),例如CLIP,使用一个共同的视觉词汇,以适应多样的视觉任务。然而,在处理一些需要更精细和密集视觉感知的特殊任务时,例如文档级OCR或图表理解,尤其…

Java多线程详解

进程 进程是程序的执行实例,而在进程的执行过程中,它需要操作和管理一系列的数据。这个数据集合通常包括程序的代码、程序计数器、寄存器、堆栈、数据段和其他与程序执行相关的信息。这些数据共同构成了一个进程的上下文(context&#xff09…

案例088:基于微信小程序的校车购票平台设计与实现

文末获取源码 开发语言:Java 框架:SSM JDK版本:JDK1.8 数据库:mysql 5.7 开发软件:eclipse/myeclipse/idea Maven包:Maven3.5.4 小程序框架:uniapp 小程序开发软件:HBuilder X 小程序…

【深入浅出Docker原理及实战】「原理实战体系」零基础+全方位带你学习探索Docker容器开发实战指南(Docker-compose使用全解 一)

Docker-compose使用全解 Compose介绍Compose的作用和职能 Compose和Docker兼容性安装docker-compose添加可执行权限 Docker Compose常用配置imagebuildcontext上下文指定镜像名args构建环境变量 commanddepends_onports特殊映射关系 volumesenvironment Docker Compose命令详解…

适合 C++ 新手学习的开源项目——在 GitHub 学编程

作者:HelloGitHub-小鱼干 俗话说:万事开头难,学习编程也是一样。在 HelloGitHub 的群里,经常遇到有小伙伴询问编程语言如何入门方面的问题,如: 我要学习某一门编程语言,有什么开源项目可以推荐…

juniper EX系列交换机 包过滤(Packet Filtering)配置

Juniper EX交换机支持基于物理端口、VLAN和三层VLAN接口的包过滤技术: 在二层过滤下支持: ■ Ingress port firewall filter ■ Ingress VLAN firewall filter ■ Egress VLAN firewall filter 在三层过滤下支持: ■ Ingress port firew…

项目经理面试10问

今天我们来说说项目经理专业面试的十条经验总结。如果你认真阅读并思考,相信对在屏幕前的你会有所帮助和启发。 1、请做一下自我介绍 自我介绍很重要。无论面试什么岗位,面试官通常都会问你一个最常见的问题:“请做一下自我介绍。” 在准备…

搜维尔科技:【简报】第九届元宇宙数字人设计大赛,报名已经进入白热化阶段!

随着元宇宙时代的来临,数字人设计成为了创新前沿领域之一。为了提高大学生元宇宙虚拟人角色策划与美术设计的专业核心能力,我们特别举办了这场元宇宙数字人设计赛道,赛道主题为「AI人工智能科技」 ,只要与「AI人工智能科技」相关的…

三菱plc的点动控制循环(小灯闪烁,单控气缸循环)

以为前一段时间小编做了一个气缸定时循环的程序,根据程序有不足之处,所以小编写下这篇文章,将网络上的plc小灯控制进行总结!如果对你有帮助,不要忘了点赞收藏!如果有更加好的梯形图,欢迎评论&am…

搭建FTP服务器

目录 一、FTP 1.1 FTP简介 1.2 FTP服务器搭建 1.2.1 前提 1.2.2 创建组 1.2.3 创建用户 1.2.4 安装FTP服务器 1.2.5 配置FTP服务器 1.2.6 配置FTP的文件夹权限 1.2.7 连接测试 1.2.8 允许外部访问 二、计算机端口介绍 2.1 端口简介 2.2 开启端口 2.3 端口相关 2…

第一至四批专精特新“小巨人”企业信息库

第一至四批专精特新“小巨人”企业信息库 1、指标:专精特新公示批次、企业名称、登记状态、法定代表人、注册资本、实缴资本、成立日期、核准日期、营业期限、所属省份、所属城市、所属区县、电话、更多电话、邮箱、更多邮箱、统一社会信用代码、纳税人识别号 注册…

Java后端开发——Spring实验

文章目录 Java后端开发——Spring实验一、Spring入门1.创建项目,Spring依赖包。2.创建JavaBean:HelloSpring3.编写applicationContext.xml配置文件4.测试:启动Spring,获取Hello示例。 二、Spring基于XML装配实验1.创建JavaBean类&…

业务中台-UAT测试用例示例

今天我来和大家分享一下我们在业务中台UAT测试用例的案例,这个案例的编写方式是参考了其他项目来编写的。这个测试用例主要分为两个部分:用例目录和测试具体内容。 对于UAT测试用例,我们理解应该存在两种不同的编写方式,一种是功…

c语言-浮点型数据在内存中的存储

目录 前言一、浮点数存储例子二、浮点数在内存的存储格式2.1 32位浮点数存储格式2.2 64位浮点数存储格式 三、IEEE 754对有效数字M和指数E的规定3.1 对存储有效数字M的规定3.2 对存储指数E的规定3.2.1 E在32位浮点数的存储格式3.2.2 E在64位浮点数的存储格式 3.3 对读取有效数M…

Python 箱线图的绘制(Matplotlib篇-13)

Python 箱线图的绘制(Matplotlib篇-13)         🍹博主 侯小啾 感谢您的支持与信赖。☀️ 🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ�…

Python贪吃蛇小游戏(PyGame)

文章目录 写在前面PyGame入门贪吃蛇注意事项写在后面 写在前面 本期内容:基于pygame的贪吃蛇小游戏 实验环境 python3.11及以上pycharmpygame 安装pygame的命令: pip install -i https://pypi.tuna.tsinghua.edu.cn/simple pygamePyGame入门 pygam…

UI5与后端的文件交互(二)

文章目录 前言一、开发Action1. 创建Structure2. BEDF添加Action3. class中实现Action 二、修改UI5 项目1. 添加一个按钮2. 定义事件函数 三、测试及解析1. 测试2. js中提取到的excel流数据3. 后端解析 前言 这系列文章详细记录在Fiori应用中如何在前端和后端之间使用文件进行…