深度学习之“制作自定义数据”--torch.utils.data.DataLoader重写构造方法。

news2024/9/28 7:26:06

深度学习之“制作自定义数据”–torch.utils.data.DataLoader重写构造方法。

前言:

​ 本文讲述重写torch.utils.data.DataLoader类的构造方法对自定义图片制作类似MNIST数据集格式(image, label),用于自己的Pytorch神经网络模型运行,代码已整理打包上传网盘,文末下载。tensor数据格式(N,C,H,W)

  • N:Batch,批处理大小,表示一个batch中的图像数量

  • C:Channel,通道数,表示一张图像中的通道数

  • H:Height,高度,表示图像垂直维度的像素数

  • W:Width,宽度,表示图像水平维度的像素数

  • 例如下图输出一个批次的训练集数据就是一批次64张图片(N),3维通道数(C),一张图片高度32像素(H),一张图片宽度32像素(W)

在这里插入图片描述

步骤一

​ 对图片整理分类(python代码os库进行对文件夹创建和图片的移动到文件夹),以文件夹名为图片的种类名,如下图所示:

在这里插入图片描述

步骤二

​ 对所有种类文件夹进行遍历读入,将每个(图片的文件路径 )和(对应的标签)写入到txt文本中,结果为trian.txt 和 test.txt,作为训练集合测试集的数据准备。代码为CreateDataset01.py

# -*- coding: utf-8 -*-
# @Time : 2023/1/26/026 18:48
# @Author : LeeSheel
# @File : CreateDataset01.py
# @Project : 深度学习

'''
生成训练集和测试集,保存在txt文件中

本地电脑,只选取出3000张图片为训练集进行模型运行数据
'''

import os
import random
train_ratio = 0.6
test_ratio = 1-train_ratio
train_list, test_list = [],[]  #创建两个个列表,里面存放  图片路径+‘\t’+图片标签
data_list = []

rootdata = r"D:\FreeDesk\大创项目\手写藏文字母识别\手写藏文字母数据\总数据"



for root,dirs,files in os.walk(rootdata):
    # print(root)
    # print(dirs)
    # print(files)
    #拼接每个图片的绝对文件路径:
    for i in range(int(len(files)*train_ratio)):
        # print(files[i])#输出的是每个图片的名称
        # print(root+"---"+files[i])  #shu输出每个每个图片的文件夹路径----图片名称
        # print(os.path.join(root, files[i]))  #拼接路径,
        # print(str(root).split("/")[-1])   #dui对root进行字符串切割,获得最后一个元素,代表每个图片的标签。
        class_flag = str(root).split("\\")[-1]  #biaoqain标签
        data = os.path.join(root, files[i]) + '\t' + str(class_flag) + '\n'
        train_list.append(data)

    for i in range(int(len(files) * train_ratio),len(files)):
        # print(i)
        class_flag = str(root).split("\\")[-1]  # biaoqain标签
        # print(class_flag)
        # print(files[i])
        data = os.path.join(root, files[i]) + '\t' + str(class_flag) + '\n'
        test_list.append(data)

# print(train_list)
random.shuffle(train_list)
random.shuffle(test_list)

with open('train.txt','w',encoding='UTF-8') as f:
    for train_img in train_list:
            f.write(str(train_img))

with open('test.txt','w',encoding='UTF-8') as f:
    for test_img in test_list:
        f.write(test_img)



## 随机抽取3000个作为本地train.txt   以及1000个作为本地test.txt

# from random import sample
#
# print(sample(train_list, 30000)) # 随机抽取5个元素
# local_train_list = sample(train_list, 30000)
# print("dsdfsdfs")
# print(len(local_train_list))
# local_test_list = sample(test_list, 10000)
#
# with open('localtrain.txt','w',encoding='UTF-8') as f:
#     for train_img in local_train_list:
#             f.write(str(train_img))
#
# with open('localtest.txt','w',encoding='UTF-8') as f:
#     for test_img in local_test_list:
#         f.write(test_img)

得到txt结果:(文件路径与标签以空格隔开):

在这里插入图片描述

步骤三

​ 将步骤二得到的train.txt 和 test.txt 转化为train_loader 和 test_loader,重写LoadData类的构造方法,将train.txt文本转为train_dataset ,将test.txt转为test_dataset,最后再使用torch.utils.data.DataLoader()进行转为train_loader 和 test_loader: 就可以用于调用模型训练了。

train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=64,
                                               shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                               batch_size=64,
                                               shuffle=True)


重写LoadData类的构造方法代码(这里的transforms.Normalize()图像标准化,可以使用下文的python代码求出mean和std,填入标准化数值。),步骤三代码为 CreateDataloader02.py

# -*- coding: utf-8 -*-
# @Time : 2023/1/26/026 18:56
# @Author : LeeSheel
# @File : CreateDataloader02.py
# @Project : 深度学习
import torch
from PIL import Image
import torchvision.transforms as transforms
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from torch.utils.data import Dataset


class LoadData(Dataset):
    def __init__(self, txt_path, train_flag=True):
        self.imgs_info = self.get_images(txt_path)
        self.train_flag = train_flag

        self.train_tf = transforms.Compose([
            # 随机旋转图片
            transforms.RandomHorizontalFlip(),
            # 将图片尺寸resize到32x32
            transforms.Resize((32, 32)),
            # 将图片转化为Tensor格式
            transforms.ToTensor(),
            # 正则化(当模型出现过拟合的情况时,用来降低模型的复杂度)
            transforms.Normalize((0.96934927, 0.9696228, 0.9695143), (0.124204025, 0.12326231, 0.12356147))  # 图像标准化
            ])

        self.val_tf = transforms.Compose([
            # 将图片尺寸resize到32x32
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize((0.96934927, 0.9696228, 0.9695143), (0.124204025, 0.12326231, 0.12356147))
            ])

    def get_images(self, txt_path):
        with open(txt_path, 'r', encoding='utf-8') as f:
            imgs_info = f.readlines()
            imgs_info = list(map(lambda x:x.strip().split('\t'), imgs_info))
        return imgs_info



    def __getitem__(self, index):
        img_path, label = self.imgs_info[index]
        img = Image.open(img_path)
        img = img.convert('RGB')

        if self.train_flag:
            img = self.train_tf(img)
        else:
            img = self.val_tf(img)
        label = int(label)

        return img, label

    def __len__(self):
        return len(self.imgs_info)



train_dataset = LoadData("train.txt", True)

print("训练接数据个数:", len(train_dataset))
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=64,
                                               shuffle=True)
for image, label in train_loader:
        print(image.shape)
        print(image)
        # img = transform_BZ(image)
        # print(img)
        print(label)
        break

test_dataset = LoadData("test.txt", False)
print("测试集数据个数:", len(test_dataset))
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                               batch_size=64,
                                               shuffle=True)

求图片标准化transforms.Normalize()参数 代码

# -*- coding: utf-8 -*-
# @Time : 2023/1/31/031 18:18
# @Author : LeeSheel
# @File : 计算std和mea.py
# @Project : 深度学习
import numpy as np
import cv2
import os

# img_h, img_w = 32, 32
img_h, img_w = 32, 32  # 经过处理后你的图片的尺寸大小
means, stdevs = [], []
img_list = []

imgs_path = "D:\\0"  # 数据集的路径采用绝对引用
imgs_path_list = os.listdir(imgs_path)

len_ = len(imgs_path_list)
i = 0
for item in imgs_path_list:
    img = cv2.imread(os.path.join(imgs_path, item))
    img = cv2.resize(img, (img_w, img_h))
    img = img[:, :, :, np.newaxis]
    img_list.append(img)
    i += 1
    print(i, '/', len_)

imgs = np.concatenate(img_list, axis=3)
imgs = imgs.astype(np.float32) / 255.

for i in range(3):
    pixels = imgs[:, :, i, :].ravel()  # 拉成一行
    means.append(np.mean(pixels))
    stdevs.append(np.std(pixels))

# BGR --> RGB , CV读取的需要转换,PIL读取的不用转换
means.reverse()
stdevs.reverse()

print("normMean = {}".format(means))
print("normStd = {}".format(stdevs))

代码下载:

链接:https://pan.baidu.com/s/1fa_gdLYXagu65P2uYpepqA?pwd=xx78
提取码:xx78

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/366699.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

大数据Hadoop教程-学习笔记04【数据仓库基础与Apache Hive入门】

视频教程:哔哩哔哩网站:黑马大数据Hadoop入门视频教程 总时长:14:22:04教程资源: https://pan.baidu.com/s/1WYgyI3KgbzKzFD639lA-_g 提取码: 6666【P001-P017】大数据Hadoop教程-学习笔记01【大数据导论与Linux基础】【17p】【P018-P037】大…

Spring boot开启定时任务的三种方式(内含源代码+sql文件)

Spring boot开启定时任务的三种方式(内含源代码sql文件) 源代码sql文件下载链接地址:https://download.csdn.net/download/weixin_46411355/87486580 目录Spring boot开启定时任务的三种方式(内含源代码sql文件)源代码…

【无人机】回波状态网络(ESN)在固定翼无人机非线性控制中的应用(Matlab代码实现)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…

前端常见手写面试题集锦

实现迭代器生成函数 我们说迭代器对象全凭迭代器生成函数帮我们生成。在ES6中,实现一个迭代器生成函数并不是什么难事儿,因为ES6早帮我们考虑好了全套的解决方案,内置了贴心的 生成器 (Generator)供我们使用&#xff…

java面试题-IO流

基础IO1.如何从数据传输方式理解IO流?IO流根据处理数据的类型可以分为字节流和字符流。字节流字节流以字节(8位)为单位读写数据。字节流主要用于读写二进制文件,如图片、音频、视频等。Java中的InputStream和OutputStream就是字节…

写论文不用构建语料库!只需要福昕PDF阅读器高级搜索

写论文不用构建语料库!只需要福昕PDF阅读器高级搜索 文章目录写论文不用构建语料库!只需要福昕PDF阅读器高级搜索前言:“福昕语料库”使用前的准备:调用“语料库”:前言: 最近论文阅读可以借助NewBing的总…

【算法与数据结构(C语言)】栈和队列

文章目录 目录 前言 一、栈 1.栈的概念及结构 2.栈的实现 入栈 出栈 获取栈顶元素 获取栈中有效元素个数 检测栈是否为空,如果为空返回非零结果,如果不为空返回0 销毁栈 二、队列 1.队列的概念及结构 2.队列的实现 初始化队列 队尾入队列 队头出队列 获…

报表开发难上手?这里有一份 Fastreport 最新中文用户指南,请查收

Fast Reports,Inc.成立于1998年,多年来一直致力于开发快速报表软件,包括应用程序、库和插件。FastReport的报表生成器(VCL平台和.NET平台)、跨平台的多语言脚本引擎FastScript、桌面OLAP FastCube,如今都受到世界各地开…

Typecho COS插件实现网站静态资源存储到COS,降低本地存储负载

Typecho 简介Typecho 是一个简单、强大的轻量级开源博客平台,用于建立个人独立博客。它具有高效的性能,支持多种文件格式,并具有对设备的响应式适配功能。Typecho 相对于其他 CMS 还有一些特殊优势:包括可扩展性、不同数据库之间的…

IDA 实战--(2)熟悉工具

布局介绍 软件启动后会 有几个选项,一般直接选择Go 即可 之后的工作台布局如下 开始分析 分析的第一个,将PE 文件拖入工作区 刚开始接触,我们先保持默认选项,其它选项后面会详细讲解,点击OK 后,等待分析…

软件项目管理知识回顾---软件项目质量和资源管理

软件项目质量和资源管理 5.0质量管理 5.1质量管理模型 1.模型 boehm模型:可移植性,可使用性,可维护性McCall模型ISO体系认证5.2质量成本 1.含义:由于产品第一次不正常运行而产生的附加费用 预防成本和缺陷成本5.3质量管理 1.过程 …

Python opencv进行矩形识别

Python opencv进行矩形识别 图像识别中,圆形和矩形识别是最常用的两种,上一篇讲解了圆形识别,本例讲解矩形识别,最后的结果是可以识别出圆心,4个顶点,如下图: 左边是原始图像,右边是识别结果,在我i5 10400的CPU上,执行时间不到8ms。 识别出结果后,计算任意3个顶点…

【自监督论文阅读笔记】Unsupervised Learning of Dense Visual Representations

Abstract 对比自监督学习已成为无监督视觉表示学习的一种有前途的方法。通常,这些方法学习全局(图像级)表示,这些表示对于同一图像的不同视图(即数据增强的组合)是不变的。然而,许多视觉理解任务…

PDF文件怎么转图片格式?转换有技巧

PDF文件有时为了更美观或者更直观的展现出效果,我们会把它转成图片格式,这样不论是归档总结还是存储起来都会更为高效。有没有合适的转换方法呢?这就来给你们罗列几种我个人用过体验还算不错的方式,大家可以拿来参考一下哈。1.用电…

vm 网络配置

点击NAT设置,配置本台虚拟机ip(注意网关要在同一个网段),配置对应端口 然后添加映射端口: 然后选择网络适配器 选择vm8网卡 配置网卡静态ip #查看网卡 ip addr #修改网卡配置 cd /etc/sysconfig/network-scripts…

DevData Talks | 对谈谷歌云 DORA 布道师,像谷歌一样度量 DevOps 表现

本期 DevData Talks 我们请到来自 Google Cloud 谷歌云的 DORA 研究团队的嘉宾 Nathen Harvey与 Apache DevLake、CNCF DevStream 的海外社区负责人 Maxim 进行对谈。如果您关注 DevOps 的话,也许对这个团队有所耳闻。 DORA 的全称是 DevOps Research and Assessme…

mysql lesson1

常用命令 1:exit 退出mysql 2:uroot pENTER键,再输入密码,不被别人看见 3:完美卸载:双击安装包,手动删除program file中的mysql,手动删除Programedate里的mysql 4:use mysql 使用数据库 5:…

InstallAware Multi-Platform updated

InstallAware Multi-Platform updated 原生ARM:为您的内置设置、IDE和整个工具链添加了Apple macOS和Linux ARM构建。 本地化:引擎内多语言感知,可再分发工具,具有资产隔离功能,使您的IP保持安全。 模板:将…

通信算法复习题纲

通信算法复习题1、当信源发送信号满足以下哪一项条件时,接收端采用最小距离准则进行判决等价于采用最大后验概率准则进行判决?2、OFDM系统的正交性体现在哪个方面?3、模拟信号数字化过程中,哪一步会引入量化噪声?4、OF…

考研复试机试 | c++ | 王道复试班

目录n的阶乘 (清华上机)题目描述代码汉诺塔问题题目:代码:Fibonacci数列 (上交复试)题目代码:二叉树:题目:代码:n的阶乘 (清华上机) …