[Pytorch]DataSet和DataLoader逐句详解

news2024/12/25 12:24:49

        将自己的数据集引入Pytorch是搭建属于自己的神经网络的重要一步,这里我设计了一个简单的实验,结合这个实验代码,我将逐句教会大家如何将数据引入DataLoader。

        这里以目标检测为例,一个batch中包含图片文件、先验框的框体坐标、目标类型,相对而言更加全面。大家亦可根据自己的数据结构和需求进行修改。

一、数据文件分析

        标准的Voc格式是无法直接注入模型的,而如果在训练程序中进行处理即拖慢了运算速度,又难以保证数据集分割的一致性。最好是使用一个独立程序完成数据集的分割、组织、暂存。这里参考了Bubbliiiing的做法,将数据集信息暂存为txt文件。其中一条具体数据的格式如下:

../VOC2007/JPEGImages/0.jpg 166,121,336,323,0 1052,372,1371,924,1
#   文件路径(绝对路径为佳)/ # 先验框框体信息1/  # 先验框框体信息1/

        文件路径框体信息之间采用空格分开;框体信息内部以逗号分开,前4个为坐标信息,最后一个为分类信息

        在随后的程序中,我们将循环读取这个文件中的数据来获取数据集信息。

二、载入数据

        1.定义DataSet超参

                在开始重写DataSet前,我们需要定义一些用来控制DataSet的参数。

#----自定义DataSet,继承自torch.utils.data.DataSet
class MyDataSet(Dataset):
    #----参数定义,输入的参数分别为数据行、输入图像尺寸,类型数
    def __init__(self,file_Lines,inp_Shape,num_Classes):
        super(MyDataSet,self).__init__()
        #  将局部形参变为类内的全局变量
        self.length = len(file_Lines)   #将文件数赋值给length
        self.file_Lines = file_Lines
        self.inp_Shape = inp_Shape
        self.num_Classes = num_Classes

        2.重写len函数

                没什么技巧,因为我们刚刚将数据行(file_Lines)的长度赋值给了self.length,直接返回这个值就能拿到数据集的长度了。这也是设置超参数的意义所在。

def __len__(self):
    return self.length

        3.重写getitem函数

                此函数每次会获取单个文件,DataLoader通过反复调用这个函数最终获取整个数据集,我们写入的,这个函数自带一个index用于控制获取的行数。

                ①分解数据集文件

                        按照我们上面解析的文件格式,我们用split函数分割index行的文本(空格分割),得到的file_item中第1个元素为文件的绝对路径,后续元素为目标先验框的信息。

    def __getitem__(self, index):
        index = index % self.length     #计算batch长度

        file_item = self.file_Lines[index].split()  
        #按空格拆分文件行,其中的元素分别为:文件路径、先验框坐标(n个)

                                file_item的值如下图所示: 

                ②加载图片文件

                        同样没什么技巧,拿到文件路径后直接打开就好了。需要注意的是神经网络的输入需要为固定的形状(图片尺寸和颜色通道数),如果图片为灰阶图则需要将其颜色通道扩充为3个(RGB图)

        img = Image.open(file_item[0])              #打开图片
        img = cvt2RGB(img)                          
        #若图像为灰度图需要先转换为RGB图(神经网络输入为3通道)
#----将图像转换为RGB----#
def cvt2RGB(img):
    if len(np.shape(img)) == 3 and np.shape(img)[2]==3:
        return img                                      #为RGB不需要转换
    else:
        img = img.convert('RGB')
        return img          

                ③拆分框体信息

                        同样同样没什么好说的,遍历分割file_item从1开始的元素就好了

box_info = np.array([np.array(list(map(int,box.split(',')))) for box in file_item[1:]])   
#从文件中加载先验框坐标和类型(从第1个元素开始)

                ④将图片变形

                        这一步也不是必须的,可以选择在开始训练之前对图片信息进行处理。但是在程序中处理需要注意一点,在改变图像的同时需要以同样的比例改变先验框的坐标。

img,new_box = self.resize_img_withBox(img,box_info,self.inp_Shape)

                        这里给出一个无损变换大小的函数,若不指定参数则不变化。 

    def resize_img_withBox(self,img, box, size=[0,0]):      #输入参数分别为:原图、先验框列表、变形后的图片大小
        iw,ih = img.size
        w,h = size
        new_box=[]
        #  若没有指定大小则不需要变形,若指定了大小则进行变形
        if size!=(0,0):
            scale = min(w/iw,h/ih)                          #获取变形比例
            nw = int(iw*scale)                              #计算变形后的长宽
            nh = int(ih*scale)
            
            dx = (w-nw)//2
            dy = (h-nh)//2
            #  图像变形
            img = img.resize((nw,nh),Image.BICUBIC)
            new_img = Image.new('RGB',size,(128,128,128))   #创建一张灰色背景
            new_img.paste(img,((w-nw)//2,(h-nh)//2))        #将变形后的图片贴进背景中央
            #  先验框变形
            if len(box)>0:
                np.random.shuffle(box)
                box[:, [0,2]] = box[:, [0,2]]*nw/iw + dx
                box[:, [1,3]] = box[:, [1,3]]*nh/ih + dy
                box[:, 0:2][box[:, 0:2]<0] = 0
                box[:, 2][box[:, 2]>w] = w
                box[:, 3][box[:, 3]>h] = h
                box_w = box[:, 2] - box[:, 0]
                box_h = box[:, 3] - box[:, 1]
                new_box = box[np.logical_and(box_w>1, box_h>1)] # discard invalid box
        return new_img,new_box

                ⑤变换图片的通道

                        标准的图片通道为RGB,而在pytorch中图片的通道为BGR,所以我们需要对通道进行调整,同时为其附加batch通道。

img = np.transpose(preprocess_input(np.array(img, dtype=np.float32)), (2, 0, 1))

                        前面的函数是一个增强函数,会给RGB三个通道加上不同的权值,至于权值则是一个默认权值(我也不知道为什么用这个)

#----为图像加权,这是一般默认的参数----#
def preprocess_input(image):
    image   = np.array(image,dtype = np.float32)[:, :, ::-1]
    mean    = [0.40789655, 0.44719303, 0.47026116]
    std     = [0.2886383, 0.27408165, 0.27809834]
    return (image / 255. - mean) / std

                ⑥拆分坐标信息和分类信息

                        如题,将坐标信息和分类信息从先验框信息组中进行分割

        #  拆分先验框坐标和类型
        box_data = np.zeros((len(new_box),5))
        if(len(box_info)>0):
            box_data[:len(box_info)] = box_info
        box = box_data[:,:4]
        label = box_data[:,-1]

                完成上述步骤后,将得到的数据返回即可,完整的getitem函数如下:

    def __getitem__(self, index):
        index = index % self.length     #计算batch长度
        #  读取文件
        file_item = self.file_Lines[index].split()  #按空格拆分文件行,其中的元素分别为:文件路径、先验框坐标(n个)
        img = Image.open(file_item[0])              #打开图片
        img = cvt2RGB(img)                          #若图像为灰度图需要先转换为RGB图(神经网络输入为3通道)
        box_info = np.array([np.array(list(map(int,box.split(',')))) for box in file_item[1:]])   #从文件中加载先验框坐标和类型(从第1个元素开始)
        #  对图像进行变形(含先验框变形)
        img,new_box = self.resize_img_withBox(img,box_info,self.inp_Shape)
        #  将图像进行加权
        #img = np.transpose(preprocess_input(np.array(img, dtype=np.float32)), (2, 0, 1))
        img = np.transpose(np.array(img))
        #  拆分先验框坐标和类型
        box_data = np.zeros((len(new_box),5))
        if(len(box_info)>0):
            box_data[:len(box_info)] = box_info
        box = box_data[:,:4]
        label = box_data[:,-1]
        
        return img,box,label

三、数据封包

        我们在训练时肯定不能这样一个一个训练,一般情况我们训练时会将这些数据打包成一个个的patch送给迭代器,而collate_fn就是做这个的,需要注意collate_fn并不是DataSet类的成员

        这个函数使dataloader自动使用的,其中的images、bboxes、labels 将会在训练过程中用到,这里我们只要确保将数据装入对应的容器中即可。

# DataLoader中collate_fn使用
def my_collate(batch):
    images = []
    bboxes = []
    labels = []
    for img, box, label in batch:
        images.append(img)
        bboxes.append(box)
        labels.append(label)
    images = np.array(images)
    return images, bboxes, labels

四、调用

        ①读取数据集文件(txt)

    file_path = "数据集文件的路径"
    with open(file_path) as f:
        train_lines = f.readlines()

        ②实例化DataSet

    train_dataset = MyDataSet(train_lines,input_shape,num_classes)
    train_loader = DataLoader(train_dataset, shuffle = True, batch_size = 32, num_workers = 1, collate_fn=my_collate)

                其中num_workers是线程数;batch_size是单个batch的大小;collate_fn指向我们刚刚重写的collate_fn;shuffle表示是否打乱数据集的顺序。

        ③训练

                当然,这里不是真的训练,我们用一个展示函数来代替训练。

    print("开始打印结果")
    for item in train_dataset:
        img, box, label = item
        img = img.transpose(1,2,0)
        print(img.shape)
        im = Image.fromarray(np.uint8(img))
        im.show()
        input("按任意键继续")

         这个昆虫数据集太恶心了就不给大家看了

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/389946.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【C语言进阶:指针的进阶】你真分得清sizeof和strlen?

本章重点内容&#xff1a; 字符指针指针数组数组指针数组传参和指针传参函数指针函数指针数组指向函数指针数组的指针回调函数指针和数组面试题的解析这篇博客 FLASH 将带大家一起来练习一些容易让人凌乱的题目&#xff0c;通过这些题目来进一步加深和巩固对数组&#xff0c;指…

基于注解@Transactional事务基本用法

关于概念性的放在最下面,熟读几遍 在使用时候也没多关注,总是加个Transactional 初识下 一般查询 Transactional(propagation Propagation.SUPPORTS) 增删改 Transactional(propagation Propagation.REQUIRED) 当然不能这么马虎 Spring中关于事务的传播 一个个测试,事…

计算机网络第八版——第一章课后题答案(超详细)

第一章 该答案为博主在网络上整理&#xff0c;排版不易&#xff0c;希望大家多多点赞支持。后续将会持续更新&#xff08;可以给博主点个关注~ 【1-01】计算机网络可以向用户提供哪些服务&#xff1f; 解答&#xff1a;这道题没有现成的标准答案&#xff0c;因为可以从不同的…

操作系统——15.FCFS、SJF、HRRN调度算法

这节我们来看一下进程调度的FCFS、SJF、HRRN调度算法 目录 1.概述 2.先来先服务算法&#xff08;FCFS&#xff0c;First Come First Serve&#xff09; 3.短作业优先算法&#xff08;SJF&#xff0c;Shortest Job First&#xff09; 4.高响应比优先算法&#xff08;HRRN&…

Jackson CVE-2018-5968 反序列化漏洞

0x00 前言 同CVE-2017-15095一样&#xff0c;是CVE-2017-7525黑名单绕过的漏洞&#xff0c;主要还是看一下绕过的调用链利用方式。 可以先看&#xff1a; Jackson 反序列化漏洞原理 或者直接看总结也可以&#xff1a; Jackson总结 影响版本&#xff1a;至2.8.11和2.9.x至…

【数据结构】AVL平衡二叉树底层原理以及二叉树的演进之多叉树

1.AVL平衡二叉树底层原理 背景 二叉查找树左右子树极度不平衡&#xff0c;退化成为链表时候&#xff0c;相当于全表扫描&#xff0c;时间复杂度就变为了O(n) 插入速度没影响&#xff0c;但是查询速度变慢&#xff0c;比单链表都慢&#xff0c;每次都要判断左右子树是否为空 需…

java Spring JdbcTemplate配合mysql实现数据批量修改

其实这个操作和批量添加挺像的 调的同一个方法 首先 我们看数据库结构 这是我本地的 mysql 里面有一个test数据库 里面有一张user_list表 然后创建一个java项目 然后 引入对应的JAR包 在src下创建 dao 目录 在下面创建一个接口 叫 BookDao 参考代码如下 package dao;impo…

进程控制~

进程控制 &#xff08;创建、终止&#xff0c;等待&#xff0c;程序替换&#xff09; 进程创建&#xff1a; pid_t fork();父子进程&#xff0c;数据独有&#xff0c;代码共享&#xff0c;各有各的地址 pit_t vfork();父进程阻塞&#xff0c;直到子进程exit退出或者程序替换之…

电力系统稳定性的定义与分类

1电力系统稳定性的定义与分类 IEEE给出电力系统稳定性定义&#xff1a;电力系统稳定性是指电力系统这样的一种能力—对于给定的初始运行状态&#xff0c;经历物理扰动后&#xff0c;系统能够重新获得运行平衡点的状态&#xff0c;同时绝大多数系统变量有界&#xff0c;因此整个…

自己写一个简单的IOC

什么是SpringIOC&#xff1f; 答&#xff1a;IOC即控制反转&#xff0c;就是我们不在手动的去new一个对象&#xff0c;而是将创建对象的权力交给Spring去管理&#xff0c;我们想要一个User类型的对象&#xff0c;就只需要定义一个User类型的变量user1&#xff0c;然后让Spring去…

蓝桥杯-刷题统计

蓝桥杯-刷题统计1、问题描述2、解题思路3、代码实现3.1 方案一&#xff1a;累加方法(超时)3.2 方案二1、问题描述 小明决定从下周一开始努力刷题准备蓝桥杯竞赛。他计划周一至周五每天做 a 道题目, 周六和周日每天做 b 道题目。请你帮小明计算, 按照计划他将在 第几天实现做题数…

KNN学习报告

原理 KNN算法就是在其表征空间中&#xff0c;求K个最邻近的点。根据已知的这几个点对其进行分类。如果其特征参数只有一个&#xff0c;那么就是一维空间。如果其特征参数只有两个&#xff0c;那么就是二维空间。如果其特征参数只有三个&#xff0c;那么就是三维空间。如果其特征…

软件设计师教程(七)计算机系统知识-操作系统知识

软件设计师教程 软件设计师教程&#xff08;一&#xff09;计算机系统知识-计算机系统基础知识 软件设计师教程&#xff08;二&#xff09;计算机系统知识-计算机体系结构 软件设计师教程&#xff08;三&#xff09;计算机系统知识-计算机体系结构 软件设计师教程&#xff08;…

Redis十大类型——Hash常见操作

Redis十大类型——Hash常见操作命令操作简列存放及获取获取健值对长度元素查找列出健值对对数字进行操作赋值hsetnx很明显咯它也是以健值对方式存在的&#xff0c;只不过value也就是值&#xff0c;在这里也变成了一组简直对。 &#x1f34a;个&#x1f330;&#xff1a; 想必多…

【Linux】P3 用户与用户组

用户与用户组root 超级管理员设置超级管理员密码切换到超级管理员sudo 临时使用超级权限用户与用户组用户组管理用户管理getentroot 超级管理员 设置超级管理员密码 登陆后不会自动开启 root 访问权限&#xff0c;需要首先执行如下步骤设定 root 超级管理员密码 1、解除 roo…

【C++】string的使用及其模拟实现

文章目录1. STL的介绍1.1 STL的六大组件1.2 STL的版本1.3 STL的缺陷2. string的使用2.1 为什么要学习string类&#xff1f;2.2 常见构造2.3 Iterator迭代器2.4 Capacity2.5 Modifiers2.6 String operations3. string的模拟实现3.1 构造函数3.2 拷贝构造函数3.3 赋值运算符重载和…

DevOps实战50讲-(2)Jenkins配置

1. Docker镜像方式安装拉取Jenkins镜像docker pull jenkins/jenkins编写docker-compose.ymlversion: "3.1" services:jenkins:image: jenkins/jenkinscontainer_name: jenkinsports:- 8080:8080- 50000:50000volumes:- ./data/:/var/jenkins_home/首次启动会因为数据…

iis之web服务器搭建、部署(详细)~千锋

目录 Web服务器 部署web服务器 实验一 发布一个静态网站 实验二 一台服务器同时发布多个web站点 网站类型 Web服务器 也叫网页服务或HTTP服务器web服务器使用的协议是HTTPHTTP协议端口号&#xff1a;TCP 80、HTTPS协议端口号&#xff1a;TCP 443Web服务器发布软件&…

【备战面试】每日10道面试题打卡-Day4

本篇总结的是Java集合知识相关的面试题&#xff0c;后续也会更新其他相关内容 文章目录1、HashMap在JDK1.7和JDK1.8中有哪些不同&#xff1f;2、HashMap 的长度为什么是2的幂次方&#xff1f;3、HashMap的扩容操作是怎么实现的&#xff1f;4、HashMap是怎么解决哈希冲突的&…

Android 基础知识4-3.5 RadioButton(单选按钮)Checkbox(复选框)详解

一、RadioButton&#xff08;单选按钮&#xff09; 1.1、简介 RadioButton表示单选按钮&#xff0c;是button的子类&#xff0c;每一个按钮都有选择和未选中两种状态&#xff0c;经常与RadioGroup一起使用&#xff0c;否则不能实现其单选功能。RadioGroup继承自LinearLayout&a…