Dataloader数据集的制作

news2025/1/13 10:12:31

数据集Dataloader制作

如何自定义数据集:

  • 1.数据和标签的目录结构先搞定(得知道到哪读数据)
  • 2.写好读取数据和标签路径的函数(根据自己数据集情况来写)
  • 3.完成单个数据与标签读取函数(给dataloader举一个例子)

咱们以花朵数据集为例:

  • 原来数据集都是以文件夹为类别ID,现在咱们换一个套路,用txt文件指定数据路径与标签(实际情况基本都这样)
  • 这回咱们的任务就是在txt文件中获取图像路径与标签,然后把他们交给dataloader
  • 核心代码非常简单,按照对应格式传递需要的数据和标签就可以啦
    在这里插入图片描述
import os
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import torch
from torch import nn
import torch.optim as optim
import torchvision
#pip install torchvision
from torchvision import transforms, models, datasets
#https://pytorch.org/docs/stable/torchvision/index.html
import imageio
import time
import warnings
import random
import sys
import copy
import json
from PIL import Image

先来分细节整明白咱一会要干啥!

任务1:读取txt文件中的路径和标签

  • 第一个小任务,从标注文件中读取数据和标签
  • 至于你准备存成什么格式,都可以的,一会能取出来东西就行
def load_annotations(ann_file):
    data_infos = []
    with open(ann_file) as f:
        samples = [x.strip().split(' ')for x in f.readlines]
def load_annotations(ann_file):
    data_infos = {}
    with open(ann_file) as f:
        samples = [x.strip().split(' ') for x in f.readlines()]
        for filename, gt_label in samples:
            data_infos[filename] = np.array(gt_label, dtype=np.int64)
    return data_infos

在这里插入图片描述

任务2:分别把数据和标签都存在list里

  • 不是我非让你存list里,因为dataloader到时候会在这里取数据
  • 按照人家要求来,不要耍个性,让整list咱就给人家整
img_label = load_annotations('./flower_data/train.txt')
image_name = list(img_label.keys())
label = list(img_label.values())

任务3:图像数据路径得完整

  • 因为一会咱得用这个路径去读数据,所以路径得加上前缀
  • 以后大家任务不同,数据不同,怎么加你看着来就行,反正得能读到图像
data_dir = './flower_data/'
train_dir = data_dir + '/train_filelist'
valid_dir = data_dir + '/val_filelist'
image_path = [os.path.join(train_dir,img) for img in image_name]
image_path

任务4:把上面那几个事得写在一起

  • 1.注意要使用from torch.utils.data import Dataset, DataLoader
  • 2.类名定义class FlowerDataset(Dataset),其中FlowerDataset可以改成自己的名字
  • 3.def init(self, root_dir, ann_file, transform=None):咱们要根据自己任务重写
  • 4.def getitem(self, idx):根据自己任务,返回图像数据和标签数据
from torch.utils.data import Dataset, DataLoader
class FlowerDataset(Dataset):
    def __init__(self, root_dir, ann_file, transform=None):
        self.ann_file = ann_file
        self.root_dir = root_dir
        self.img_label = self.load_annotations()
        self.img = [os.path.join(self.root_dir,img) for img in list(self.img_label.keys())]
        self.label = [label for label in list(self.img_label.values())]
        self.transform = transform
        
    def __len__(self):
        return len(self.img)
 
    def __getitem__(self, idx):
        image = Image.open(self.img[idx])
        label = self.label[idx]
        if self.transform:
            image = self.transform(image)
        label = torch.from_numpy(np.array(label))
        return image, label
    def load_annotations(self):
        data_infos = {}
        with open(self.ann_file) as f:
            samples = [x.strip().split(' ') for x in f.readlines()]
            for filename, gt_label in samples:
                data_infos[filename] = np.array(gt_label, dtype=np.int64)
        return data_infos

任务5:数据预处理(transform)

  • 1.预处理的事都在上面的__getitem__中完成,需要对图像和标签咋咋地的,要整啥事,都在上面整
  • 2.返回的数据和标签就是建模时模型的输入和损失函数中标签的输入,一定整明白自己模型要啥
  • 3.预处理这个事是你定的,不同的数据需要的方法也不一样,下面给出的是比较通用的方法
data_transforms = {
    'train': 
        transforms.Compose([
        transforms.Resize(64),
        transforms.RandomRotation(45),#随机旋转,-45到45度之间随机选
        transforms.CenterCrop(64),#从中心开始裁剪
        transforms.RandomHorizontalFlip(p=0.5),#随机水平翻转 选择一个概率概率
        transforms.RandomVerticalFlip(p=0.5),#随机垂直翻转
        transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),#参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相
        transforms.RandomGrayscale(p=0.025),#概率转换成灰度率,3通道就是R=G=B
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])#均值,标准差
    ]),
    'valid': 
        transforms.Compose([
        transforms.Resize(64),
        transforms.CenterCrop(64),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

任务6:根据写好的class FlowerDataset(Dataset):来实例化咱们的dataloader

  • 1.构建数据集:分别创建训练和验证用的数据集(如果需要测试集也一样的方法)
  • 2.用Torch给的DataLoader方法来实例化(batch啥的自己定,根据你的显存来选合适的)
  • 3.打印看看数据里面是不是有东西了
train_dataset = FlowerDataset(root_dir=train_dir, ann_file = './flower_data/train.txt', transform=data_transforms['train'])
val_dataset = FlowerDataset(root_dir=valid_dir, ann_file = './flower_data/val.txt', transform=data_transforms['valid'])
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=True)

任务7:用之前先试试,整个数据和标签对应下,看看对不对

  • 1.别着急往模型里传,对不对都不知道呢
  • 2.用这个方法:iter(train_loader).next()来试试,得到的数据和标签是啥
  • 3.看不出来就把图画出来,标签打印出来,确保自己整的数据集没啥问题
image, label = iter(train_loader).next()
sample = image[0].squeeze()
sample = sample.permute((1, 2, 0)).numpy()
sample *= [0.229, 0.224, 0.225]
sample += [0.485, 0.456, 0.406]
plt.imshow(sample)
plt.show()
print('Label is: {}'.format(label[0].numpy()))
image, label = iter(val_loader).next()
sample = image[0].squeeze()
sample = sample.permute((1, 2, 0)).numpy()
sample *= [0.229, 0.224, 0.225]
sample += [0.485, 0.456, 0.406]
plt.imshow(sample)
plt.show()
print('Label is: {}'.format(label[0].numpy()))

任务8:咋用就是你来定了,把模型啥的整好往里面传吧

  • 下面这些事之前都唠过了,按照自己习惯的方法整就得了
dataloaders = {'train':train_loader,'valid':val_loader}
model_name = 'resnet'  #可选的比较多 ['resnet', 'alexnet', 'vgg', 'squeezenet', 'densenet', 'inception']
#是否用人家训练好的特征来做
feature_extract = True 
# 是否用GPU训练
train_on_gpu = torch.cuda.is_available()

if not train_on_gpu:
    print('CUDA is not available.  Training on CPU ...')
else:
    print('CUDA is available!  Training on GPU ...')
    
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_ft = models.resnet18()
model_ft
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Sequential(nn.Linear(num_ftrs, 102))
input_size = 64
model_ft
# 优化器设置
optimizer_ft = optim.Adam(model_ft.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)#学习率每7个epoch衰减成原来的1/10
criterion = nn.CrossEntropyLoss()
def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, is_inception=False, filename='best.pth'):
    since = time.time()
    best_acc = 0
    model.to(device)

    val_acc_history = []
    train_acc_history = []
    train_losses = []
    valid_losses = []
    LRs = [optimizer.param_groups[0]['lr']]

    best_model_wts = copy.deepcopy(model.state_dict())

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # 训练和验证
        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()  # 训练
            else:
                model.eval()   # 验证

            running_loss = 0.0
            running_corrects = 0

            # 把数据都取个遍
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # 清零
                optimizer.zero_grad()
                # 只有训练的时候计算和更新梯度
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    _, preds = torch.max(outputs, 1)
                    #print(loss)

                    # 训练阶段更新权重
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # 计算损失
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
            
            
            time_elapsed = time.time() - since
            print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
            

            # 得到最好那次的模型
            if phase == 'valid' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                state = {
                  'state_dict': model.state_dict(),#字典里key就是各层的名字,值就是训练好的权重
                  'best_acc': best_acc,
                  'optimizer' : optimizer.state_dict(),#优化器的状态信息
                }
                torch.save(state, filename)
            if phase == 'valid':
                val_acc_history.append(epoch_acc)
                valid_losses.append(epoch_loss)
                scheduler.step(epoch_loss)#学习率衰减
            if phase == 'train':
                train_acc_history.append(epoch_acc)
                train_losses.append(epoch_loss)
        
        print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))
        LRs.append(optimizer.param_groups[0]['lr'])
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # 训练完后用最好的一次当做模型最终的结果,等着一会测试
    model.load_state_dict(best_model_wts)
    return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs 
model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=20, filename='best.pth')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/867138.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Flink CDC系列之:TiDB CDC 导入 Elasticsearch

Flink CDC系列之:TiDB CDC 导入 Elasticsearch 一、通过docker 来启动 TiDB 集群二、下载 Flink 和所需要的依赖包三、在TiDB数据库中创建表和准备数据四、启动Flink 集群,再启动 SQL CLI五、在 Flink SQL CLI 中使用 Flink DDL 创建表六、Kibana查看Ela…

页面静态化(模板引擎Freemarker)

1、浏览器请求web服务器 2、服务器渲染页面,渲染的过程就是向jsp页面(模板)内填充数据(模型)。 3、服务器将渲染生成的页面返回给浏览器。 所以模板引擎就是:模板数据输出,Jsp页面就是模板,页面中嵌入的jsp标签就是数据&#x…

【Linux】Reactor模式

​🌠 作者:阿亮joy. 🎆专栏:《学会Linux》 🎇 座右铭:每个优秀的人都有一段沉默的时光,那段时光是付出了很多努力却得不到结果的日子,我们把它叫做扎根 目录 👉Reactor …

生成式人工智能模型:提升营销分析用户体验

使用生成式人工智能来改善分析体验,使业务用户能够询问有关我们数据平台中可用数据的任何信息。 在本文中,我们将解释如何使用新的生成式人工智能模型 ( LLM ) 来改善业务用户在我们的分析平台上的体验。假设我们为零售销售经理提供 Web 应用程序或移动应…

【雕爷学编程】Arduino动手做(24)---水位传感器模块3

37款传感器与模块的提法,在网络上广泛流传,其实Arduino能够兼容的传感器模块肯定是不止37种的。鉴于本人手头积累了一些传感器和执行器模块,依照实践出真知(一定要动手做)的理念,以学习和交流为目的&#x…

斯坦福「小镇」开源AI智能体;小米应用商店将要求AI应用符合资质标准

🦉 AI新闻 🚀 斯坦福「小镇」开源AI智能体 摘要:斯坦福研究人员开源了一个类似《西部世界》的数字化「小镇」,里面有25个AI智能体可以生活、工作、社交。这项研究被视为AGI的重要开端,可能会改变游戏、企业应用领域。网友期待这项技术改善游戏NPC的交互…

【Fegin技术专题】「原生态」打开Fegin之RPC技术的开端,你会使用原生态的Fegin吗?(上)

前提介绍 Feign是SpringCloud中服务消费端的调用框架,通常与ribbon,hystrix等组合使用。由于遗留原因,某些项目中,整个系统并不是SpringCloud项目,甚至不是Spring项目,而使用者关注的重点仅仅是简化http调…

软工导论知识框架(七)面向对象设计

一.设计准则 分析:提取、整理用户需求,建立问题域精确模型。设计:转变需求为系统实现方案,建立求解域模型。 在实际的软件开发过程中分析和设计的界限是模糊的,分析和设计活动是一个多次反复迭代的过程。分析的结果可…

无涯教程-Perl - msgsnd函数

描述 此功能使用可选的FLAGS将消息MSG发送到消息队列ID。 语法 以下是此函数的简单语法- msgsnd ID, MSG, FLAGS返回值 该函数在错误时返回0,在成功时返回1。 Perl 中的 msgsnd函数 - 无涯教程网无涯教程网提供描述此功能使用可选的FLAGS将消息MSG发送到消息队列ID。 语法…

接地电阻测试仪的原理和使用事项

接地电阻测试仪(Ground Resistance Tester)是用来测量接地电阻的一种仪器。接地系统是指用于保护人员和设备的设施,它将电流引导到地下,将任何潜在危险的电流导向地面。 接地电阻测试仪的作用是通过测量接地系统的电阻值来评估其…

C++ STL string类模拟实现

目录 string类成员变量 一.构造函数 二.析构函数 三.拷贝构造 四.size(),capacity() 五.operator [ ] 六. operator 七.字符串比较 八.reserve() 九.push_back(),append() 十.operato…

【雕爷学编程】Arduino动手做(12)---霍尔磁场传感器模块5

37款传感器与模块的提法,在网络上广泛流传,其实Arduino能够兼容的传感器模块肯定是不止37种的。鉴于本人手头积累了一些传感器和执行器模块,依照实践出真知(一定要动手做)的理念,以学习和交流为目的&#x…

数据归一化:优化数据处理的必备技巧

文章目录 🍀引言🍀数据归一化的概念🍀数据归一化的应用🍀数据归一化的注意事项与实践建议🍀代码演示🍀在sklearn中使用归一化🍀结语 🍀引言 在当今数据驱动的时代,数据的…

Vue在页面输出JSON对象,测试接口可复制使用

效果图&#xff1a; 数据处理前&#xff1a; 数据处理后&#xff1a; 代码实现&#xff1a; HTML: <el-table height"600" :data"tableData" border style"width: 100%" tooltip-effect"dark" size"mini"><el-…

Django笔记之数据库函数之日期函数

日期函数主要介绍两个大类&#xff0c;Extract() 和 Trunc() Extract() 函数作用是提取日期&#xff0c;比如我们可以提取一个日期字段的年份&#xff0c;月份&#xff0c;日等数据 Trunc() 的作用则是截取&#xff0c;比如 2022-06-18 12:12:12&#xff0c;我们可以根据需求…

深度学习基础知识笔记

深度学习要解决的问题 1 深度学习要解决的问题2 应用领域3 计算机视觉任务4 视觉任务中遇到的问题5 得分函数6 损失函数7 前向传播整体流程8 返向传播计算方法1 梯度下降9 神经网络整体架构 11 神经元个数对结果的影响12 正则化和激活函数1 正则化2 激活函数 13 神经网络过拟合…

人工智能可解释性(二)(梯度计算,积分梯度等)

目录 1.定义 2.详述 2.1局部解释 可视化方法 梯度计算 2.2积分梯度Integrated Gradients&#xff08;梯度计算进阶&#xff09; 2. 3全局解释 2.3.1Activation Maximization 2.3.2GAN,VAE 2. 4用一个可解释模型解释不可解释模型 2. 4.1LIME 局部解释 参考文献 1.定义 可…

access怎么做进销存?借助access开发进销存管理应用

我不太推荐使用Access&#xff0c;因为他的缺点还是比较明显的&#xff1a; 1、软件自身限制 不能用于互联网&#xff1a;使用Access制作好的管理软件&#xff0c;访问页只能在局域网中使用&#xff1b;只能在Windows上运行&#xff1a;Access仅支持windows的运行环境&#x…

从零开始学习 Java:简单易懂的入门指南之多态(十)

多态&包&final&权限修饰符&代码块 第一章 多态1.1 多态的形式1.2 多态的使用场景1.3 多态的定义和前提1.4 多态的运行特点1.5 多态的弊端1.6 引用类型转换1.6.1 为什么要转型1.6.2 向上转型&#xff08;自动转换&#xff09;1.6.3 向下转型&#xff08;强制转换…

【将回声引入信号中】在语音或音频文件中引入混响或简单回声,以研究回声延迟和回波幅度对生成的回波信号感知的影响(Matlab代码实现)

&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜密&#xff0c;逻辑清晰&#xff0c;为了方便读者。 ⛳️座右铭&a…