pytorch进阶学习(二):使用DataLoader读取自己的数据集

news2025/1/12 7:46:02
上一节使用的是官方数据集fashionminist进行训练,这节课使用自己搜集的数据集来进行数据的获取和训练。

所需资源

教学视频:https://www.bilibili.com/video/BV1by4y1b7hX/?spm_id_from=333.1007.top_right_bar_window_history.content.click&vd_source=e482aea0f5ebf492c0b0220fb64f98d3

pytorch进阶学习(一):https://blog.csdn.net/weixin_45662399/article/details/129737499?spm=1001.2014.3001.5501

课程准备:本节课需要用到3个Python文件和一个数据集文件,代码后面我都会给出,zip需要自己下载,,数据集zip文件和所需的三个代码文件可以在“leo在这”进行下载。

一、数据集文件准备

1.1 项目文件结构

下图为我的项目文件在远程服务器上的目录,需要新建一个dataset中为上传的自己的数据集。

1.2 上传数据集到服务器

数据集文件解压后如下所示,有6个子文件夹,对应6个类别。

我们先把dataset.zip上传到服务器中的代码项目文件夹中一定要找到服务器中项目的路径,不要传错位置!

我的项目目录在服务器中的路径为“tmp/pycharm_932”,数据集zip文件即下载到tmp/pycharm_932/dataset目录下。

1.3 解压zip文件

回到服务器控制台在红,先使用cd命令定位到tmp/pycharm_932/dataset路径下,然后使用unzip 'dataset.zip'命令解压压缩文件。

可以看到服务器完成了解压。

最后把zip文件从dataset文件夹中删去即可,最终解压好的文件如下。

1.4 代码框架解读

  1. 'CreateDataset.py' 用于把数据集文件夹中的所有图片文件生成一个TXT文件,其中存放着所有图片的路径和图片对应的标签。

  1. 'CreateDataLoader.py' 用于把上一步生成的TXT文件中的信息提取出来,进行图片信息的打包,生成一个dataloader

  1. 最后在'加载自己的数据.py' 文件中对dataloader进行使用,并且完成网络的训练和测试。

二、运行CreateDataset.py

作用:该代码可以生成训练集和测试集中每张图片的路径和标签,保存在TXT文件中。后续即可以从文件中调用每一张图片,进行读取。

2.1 代码实现

'''
生成训练集和测试集,保存在txt文件中
'''

import os
import random

#60%当训练集
train_ratio = 0.6
#剩下的当测试集
test_ratio = 1-train_ratio

rootdata = r"dataset"

train_list, test_list = [],[]
data_list = []

class_flag = -1
for a,b,c in os.walk(rootdata):
    print(a)
    for i in range(len(c)):
        data_list.append(os.path.join(a,c[i]))

    for i in range(0,int(len(c)*train_ratio)):
        train_data = os.path.join(a, c[i])+'\t'+str(class_flag)+'\n'
        train_list.append(train_data)

    for i in range(int(len(c) * train_ratio),len(c)):
        test_data = os.path.join(a, c[i]) + '\t' + str(class_flag)+'\n'
        test_list.append(test_data)

    class_flag += 1

print(train_list)
random.shuffle(train_list)
random.shuffle(test_list)

with open('train.txt','w',encoding='UTF-8') as f:
    for train_img in train_list:
        f.write(str(train_img))

with open('test.txt','w',encoding='UTF-8') as f:
    for test_img in test_list:
        f.write(test_img)

2.2 运行结果

可以看到此时服务器中文件管理器中已经有了test.txt和train.txt两个文件。

生成了train.txt和test.txt两个文件,里面保存了每张图片的对应相对路径和类别标签,标签是以int型进行存储。打开test.txt文件可以发现里面的内容为测试集所有图片路径以及其标签。

2.3 代码要点解析

  1. 对训练集和测试集的划分比例为6:4

rootdata为数据集文件保存的根目录,为dataset文件夹。

#60%当训练集
train_ratio = 0.6
#剩下的当测试集
test_ratio = 1-train_ratio
rootdata = r"dataset"

2. 读取文件夹

  • a读取文件夹根目录,再使用c[i]读取每个图片的名称,使用os.path.join进行拼接,实现每张图片的相对路径的path存取。可以看到a为dataset加上下面类别子文件夹。

dataset/擦花
dataset/桔皮
dataset/碰伤
dataset/横条压凹
dataset/不导电
dataset/漏底

和c[i]进行拼接后即可完成每一张图片的定位。

dataset/碰伤/碰伤20180906142721对照样本.jpg

  • 然后使用class_flag进行图片类别标签的存储,从0开始依次增加,一共6个类别,故取值为【0,5】。

  • path和label之间使用\t进行分割,即一个tab的距离。

for a,b,c in os.walk(rootdata):
    print(a)
    for i in range(len(c)):
        data_list.append(os.path.join(a,c[i]))

    for i in range(0,int(len(c)*train_ratio)):
        train_data = os.path.join(a, c[i])+'\t'+str(class_flag)+'\n'
        train_list.append(train_data)

    for i in range(int(len(c) * train_ratio),len(c)):
        test_data = os.path.join(a, c[i]) + '\t' + str(class_flag)+'\n'
        test_list.append(test_data)

    class_flag += 1

3. 写入文件

  • 使用shuffle打乱数据集的顺序,把上面生成的每一张图片的“路径+标签”转为字符串形式,然后存入txt文件中。

  • 因为文件名有中文,所以使用utf-8编码形式。

random.shuffle(train_list)
random.shuffle(test_list)

with open('train.txt','w',encoding='UTF-8') as f:
    for train_img in train_list:
        f.write(str(train_img))

with open('test.txt','w',encoding='UTF-8') as f:
    for test_img in test_list:
        f.write(test_img)

三、CreateDataLoader.py

作用:完成对上一步dataset中生成的txt文件中对图片和标签信息的读取,将图片进行打包放入图片加载器DataLoader中。

使用系统带的数据集如下代码所示,将带的数据集存在training_data中,将training_data作为参数传入DataLoader中。

training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# 下面是测试集,同样需要下载
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

可以看到我们代码是使用LoadData类自己新建了一个数据集train_dataset,然后把train_dataset传入DataLoader中。

train_dataset = LoadData("train.txt", True)
 # 传入dataloader中
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,batch_size=10,shuffle=True)

3.1 代码实现

import torch
from PIL import Image

import torchvision.transforms as transforms
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

from torch.utils.data import Dataset

# 数据归一化与标准化

# 图像标准化
transform_BZ= transforms.Normalize(
    mean=[0.5, 0.5, 0.5],# 取决于数据集
    std=[0.5, 0.5, 0.5]
)

#读取TXT文件
class LoadData(Dataset):
    def __init__(self, txt_path, train_flag=True):
        self.imgs_info = self.get_images(txt_path)
        self.train_flag = train_flag

        self.train_tf = transforms.Compose([
                transforms.Resize(224),
                transforms.RandomHorizontalFlip(),
                transforms.RandomVerticalFlip(),
                transforms.ToTensor(),
                transform_BZ
            ])
        self.val_tf = transforms.Compose([
                transforms.Resize(224),
                transforms.ToTensor(),
                transform_BZ
            ])

    def get_images(self, txt_path):
        with open(txt_path, 'r', encoding='utf-8') as f:
            imgs_info = f.readlines()
            imgs_info = list(map(lambda x:x.strip().split('\t'), imgs_info))
        return imgs_info

    def padding_black(self, img):
        w, h  = img.size
        scale = 224. / max(w, h)
        img_fg = img.resize([int(x) for x in [w * scale, h * scale]])
        size_fg = img_fg.size
        size_bg = 224
        img_bg = Image.new("RGB", (size_bg, size_bg))
        img_bg.paste(img_fg, ((size_bg - size_fg[0]) // 2,
                              (size_bg - size_fg[1]) // 2))
        img = img_bg
        return img

    def __getitem__(self, index):
        img_path, label = self.imgs_info[index]
        img = Image.open(img_path)
        img = img.convert('RGB')
        img = self.padding_black(img)
        if self.train_flag:
            img = self.train_tf(img)
        else:
            img = self.val_tf(img)
        label = int(label)

        return img, label

    def __len__(self):
        return len(self.imgs_info)


if __name__ == "__main__":
    train_dataset = LoadData("train.txt", True)
    print("数据个数:", len(train_dataset))
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=10,
                                               shuffle=True)
    for image, label in train_loader:
        print(image.shape)
        print(image)
        # img = transform_BZ(image)
        # print(img)
        print(label)

    # test_dataset = Data_Loader("test.txt", False)
    # print("数据个数:", len(test_dataset))
    # test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
    #                                            batch_size=10,
    #                                            shuffle=True)
    # for image, label in test_loader:
    #     print(image.shape)
    #     print(label)

3.2 运行结果

我学校服务器跑大约等了两分钟,然后跑完后出现如下结果。

可以看到数据集中有380张图片,图片大小为224*224,dataloader中图片每10个为一组,tensor中为10个图片的标签。

3.3 代码要点解析

3.3.1 class LoadData init方法

  1. 该类的初始化方法,定义了两个变量,img_info为get_images方法获取的信息,是一个list,保存着图片的路径和标签;train_flag为标志点,标志是否为训练集,TRUE为时=是,否则为测试集。

  1. train_tf和val_tf使用compose完成对图片样式的变换,如定义大小为224*224,随机水平展平,正则化等。

3.3.2 get_images方法

通过txt文件的路径读取到txt中信息,使用‘\t’分割图片路径和图片标签,并且保存在imgs)info的列表中。

3.3.3 padding_black方法

如果图片过小,使用padding填充该图片,使其能够成为224*224大小。

3.3.4 getitem方法

  1. 该方法为class loaddata的主方法,使用index下标获取到每一张图片的path和label后,用flag判断为训练集还是验证集,并且采用对应的图片处理措施(train_tf/val_tf)。

  1. 返回的是图片和对应标签。

3.3.5 main方法

  1. 把LoadData生成的数据存入train_dataset中,作为数据集传入torch的data.DataLoader类中,设置batchsize=10.

  1. 10张图片为一组,输出每张图片和标签。

四、加载自己的数据.py

4.1 代码实现

大部分和第一节中的模型一样。

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda, Compose
import matplotlib.pyplot as plt
from CreateDataloader import LoadData

# 定义网络模型
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        # 碾平,将数据碾平为一维
        self.flatten = nn.Flatten()
        # 定义linear_relu_stack,由以下众多层构成
        self.linear_relu_stack = nn.Sequential(
            # 全连接层
            nn.Linear(3*224*224, 512),
            # ReLU激活函数
            nn.ReLU(),
            # 全连接层
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 6),
            nn.ReLU()
        )
    # x为传入数据
    def forward(self, x):
        # x先经过碾平变为1维
        x = self.flatten(x)
        # 随后x经过linear_relu_stack
        logits = self.linear_relu_stack(x)
        # 输出logits
        return logits



# 定义训练函数,需要
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    # 从数据加载器中读取batch(一次读取多少张,即批次数),X(图片数据),y(图片真实标签)。
    for batch, (X, y) in enumerate(dataloader):
        # 将数据存到显卡
        X, y = X.cuda(), y.cuda()

        # 得到预测的结果pred
        pred = model(X)

        # 计算预测的误差
        # print(pred,y)
        loss = loss_fn(pred, y)

        # 反向传播,更新模型参数
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # 每训练100次,输出一次当前信息
        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")


def test(dataloader, model):
    size = len(dataloader.dataset)
    print("size = ",size)
    # 将模型转为验证模式
    model.eval()
    # 初始化test_loss 和 correct, 用来统计每次的误差
    test_loss, correct = 0, 0
    # 测试时模型参数不用更新,所以no_gard()
    # 非训练, 推理期用到
    with torch.no_grad():
        # 加载数据加载器,得到里面的X(图片数据)和y(真实标签)
        for X, y in dataloader:
            # 将数据转到GPU
            X, y = X.cuda(), y.cuda()
            # 将图片传入到模型当中就,得到预测的值pred
            pred = model(X)
            # 计算预测值pred和真实值y的差距
            test_loss += loss_fn(pred, y).item()
            # 统计预测正确的个数
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= size
    correct /= size
    print("correct = ",correct)
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")



if __name__=='__main__':
    batch_size = 16

    # # 给训练集和测试集分别创建一个数据集加载器
    train_data = LoadData("train.txt", True)
    valid_data = LoadData("test.txt", False)


    train_dataloader = DataLoader(dataset=train_data, num_workers=4, pin_memory=True, batch_size=batch_size, shuffle=True)
    test_dataloader = DataLoader(dataset=valid_data, num_workers=4, pin_memory=True, batch_size=batch_size)

    for X, y in test_dataloader:
        print("Shape of X [N, C, H, W]: ", X.shape)
        print("Shape of y: ", y.shape, y.dtype)
        break


    # 如果显卡可用,则用显卡进行训练
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("Using {} device".format(device))

    # 调用刚定义的模型,将模型转到GPU(如果可用)
    model = NeuralNetwork().to(device)

    print(model)

    # 定义损失函数,计算相差多少,交叉熵,
    loss_fn = nn.CrossEntropyLoss()

    # 定义优化器,用来训练时候优化模型参数,随机梯度下降法
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)  # 初始学习率


    # 一共训练5次
    epochs = 5
    for t in range(epochs):
        print(f"Epoch {t+1}\n-------------------------------")
        train(train_dataloader, model, loss_fn, optimizer)
        test(test_dataloader, model)
    print("Done!")

    # 读取训练好的模型,加载训练好的参数
    model = NeuralNetwork()
    model.load_state_dict(torch.load("model.pth"))

4.2 运行结果

首先打印出来了网络的结构。

之后是训练模型的结果,一共有5个epoch,可以看到准确率不是很高。

4.3 代码解析

num_workers为cpu使用多线程读取数据,pin_memory为不写入虚拟内存

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/419954.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

建立数据驱动,关键字驱动和混合Selenium框架这些你了解吗

一、什么是Selenium框架? Selenium框架是一种代码结构,用于简化代码维护和提高代码可读性。框架涉及将整个代码分成较小的代码段,以测试特定的功能。 该代码的结构使得“数据集”与实际的“测试用例”分开,后者将测试Web应用程序…

【PyTorch】第八节:数据的预处理

作者🕵️‍♂️:让机器理解语言か 专栏🎇:PyTorch 描述🎨:PyTorch 是一个基于 Torch 的 Python 开源机器学习库。 寄语💓:🐾没有白走的路,每一步都算数&#…

【NVIDIA GPU 入门】综述

系列文章目录 文章目录系列文章目录前言一、概述二、GPU架构基础2.1 GPU概述2.2 GPU的架构2.3 自主查询GPU相关信息三、CUDA编程概念3.1 CUDA线程模型3.1 线程层次结构1.引入库2.读入数据总结参考文献前言 GPU作为机器学习的基础运算设备,基本上是无人不知无人不晓。…

【bsauce读论文】PSPRAY-基于时序侧信道的Linux内核堆利用技术

会议:USENIX Security’23 作者:来自 Seoul National University 的 Yoochan Lee、Byoungyoung Lee 等人。 主要内容:由于Linux内核的堆分配器SLUB开启的freelist随机化保护,所以堆相关的内核漏洞利用成功率较低(平均…

BEV(一)---lift splat shoot

1. 算法原理 1.1 2D坐标与3D坐标的关系 如图,已知世界坐标系上的某点P(Xc, Yc, Zc)经过相机的内参矩阵可以获得唯一的图像坐标p(x, y),但是反过来已知图像上某点p&…

软考初级程序员--学习

1、十进制 转 二进制 1.1、整数十进制87 转换为 二进制为 1010111 1.2 、小数十进制0.125 转为 二进制 为 0.001 使用乘2取整法,一直乘到没有小数 2、二进制 转 十进制 2.1、二进制1010111 转换为 十进制 2.2、 二进制小数0.001 转 十进制 3、循环队列 计算长度通用…

周赛341(模拟、双指针、树上DP)

文章目录周赛341[6376. 一最多的行](https://leetcode.cn/problems/row-with-maximum-ones/)暴力模拟[6350. 找出可整除性得分最大的整数](https://leetcode.cn/problems/find-the-maximum-divisibility-score/)暴力模拟[6375. 构造有效字符串的最少插入数](https://leetcode.c…

JVM系统优化实践(15):GC可视化工具实践

您好,我是湘王,这是我的CSDN博客,欢迎您来,欢迎您再来~ 线上系统的JVM监测要么使用jstat、jmap、jhat等工具查看JVM状态,或者使用监控系统,如Zabbix、Prometheus、Open-FaIcon、Ganglia等。作为…

pyg的NeighborLoader和LinkNeighborLoader

NeighborLoader 1 数据格式要求 需要传入加载的属性值: class NeighborLoader(data: Union[Data, HeteroData, Tuple[FeatureStore, GraphStore]], num_neighbors: Union[List[int], Dict[Tuple[str, str, str], List[int]]], input_nodes: Union[Tensor, None…

进程调度的基本过程

进程调度的基本过程🔎 进程是什么🔎 进程管理🔎 进程中结构体的属性进程标识符(PID)内存指针文件描述符表结构体中与进程调度相关的属性进程的状态进程的优先级进程的上下文进程的记账信息🔎 总结🔎 结尾🔎…

(第十四届蓝桥真题) 整数删除(线段树+二分)

样例输入: 5 3 1 4 2 8 7 样例输出: 17 分析:这道题我想的比较复杂,不过复杂度还是够用的,我是用线段树二分来做的。 我们用线段树维护所有位置的最小值,那么我们每次删除一个数之前先求一遍最小值&a…

停车场管理系统文件录入(C++版)

❤️作者主页:微凉秋意 ✅作者简介:后端领域优质创作者🏆,CSDN内容合伙人🏆,阿里云专家博主🏆 文章目录一、案例需求描述1.1、汽车信息模块1.2、普通用户模块1.3、管理员用户模块二、案例分析三…

mysql:使用终端操作数据库

登录进入终端: mysql -u root -p 展示数据库 SHOW DATABASES; 创建数据库: CREATE DATABASE IF NOT EXISTS RUNOOB_TEST DEFAULT CHARSET utf8 COLLATE utf8_general_ci; 1. 如果数据库不存在则创建,存在则不创建。 2. 创建RUNOOB_TEST数据库…

ElasticSearch安装、启动、操作及概念简介

ElasticSearch快速入门 文件链接:https://pan.baidu.com/s/15kJtcHY-RAY3wzpJZIn4-w?pwd0k5a 提取码:0k5a 有些软件对于安装路径有一定的要求,例如:路径中不能有空格,不能有中文,不能有特殊符号&#xf…

JUC并发编程之ReentrantLock

1. 非公平锁实现原理 加锁解锁流程 构造器默认实现的是非公平锁 public ReentrantLock() {sync new NonfairSync();}NonfairSync 继承 Sync, Sync 继承 AbstractQueuedSynchronizer 没有竞争时 第一个竞争出现时 Thread-1 执行了 CAS 尝试将state 由 0 改为 1&…

Stable Diffusion免费(三个月)通过阿里云轻松部署服务

温馨提示:划重点,活动入口在这里喔,不要迷路了。 其实我就在AIGC_有没有一种可能,其实你早就在AIGC了?阿里云邀请你,体验一把AIGC级的毕加索、达芬奇、梵高等大师作画的快感。阿里云将提供免费云产品资源&…

如何使用evosuite为指定被测方法生成测试用例

目录 省流版本 准备工作 环境 evosuite获取 检验环境 参数解释 怎样表示被测方法 怎样指向被测类 其他参数 参考 省流版本 java -jar .\target\depd\evosuite-1.1.0.jar -generateTests -Dtarget_method"isLenient()Z" -class com.google.gson.stream.…

Midjourney教程(二)——Prompt基本结构

Midjourney教程——Prompt基本结构 Basic Prompt 基础版本的prompt仅仅包含图片的描述,能够满足普通的需求,如下图所示 Advanced Prompt 高级版本的prompt主要包含三个部分,如下图所示 Image Prompts(可选) prompt第一部分是Image&#x…

TCP/IP协议详解

一.引言TCP/IP 是 TCP 和 IP 两种协议群的统称,具体来说,IP 或 ICMP、TCP 或 UDP、TELNET 或 FTP、以及 HTTP 等都属于 TCP/IP 协议二.计算机网络体系结构分层计算机网络体系结构分层计算机网络体系结构分层不难看出,TCP/IP 与 OSI 在分层模块…

【C语言】迷宫问题

【C语言】迷宫问题一. 题目描述二. 思想2.1 算法---回溯算法2.2 思路分析图解三. 代码实现3.1 二维数组的实现3.2 上下左右四个方向的判断3.4 用栈记录坐标的实现3.5 完整代码四. 总结一. 题目描述 牛客网链接:https://www.nowcoder.com/questionTerminal/cf2490605…