使用pytorch处理自己的数据集

news2024/11/24 7:53:53

目录

1 返回本地文件中的数据集

2 根据当前已有的数据集创建每一个样本数据对应的标签

3 tensorboard的使用

4 transforms处理数据

tranfroms.Totensor的使用

transforms.Normalize的使用

transforms.Resize的使用

transforms.Compose使用

5 dataset_transforms使用


1 返回本地文件中的数据集

在这个操作中,当前数据集的上一级目录就是当前所有同一数据的label

import os
from torch.utils.data import Dataset
from PIL import Image

class MyDataset(Dataset):
    def __init__(self, root_dir, label_dir):
        """
        :param root_dir: 根目录文件
        :param label_dir: 分类标签目录
        """
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(root_dir, label_dir)
        self.image_path_list = os.listdir(self.path)
    def __getitem__(self, idx):
        """
        :param idx: idx是自己文件夹下的每一个图片索引
        :return: 返回每一个图片对象和其对应的标签,对于返回类型可以直接调用image.show显示或者用于后续图像处理
        """
        img_name = self.image_path_list[idx]
        ever_image_path = os.path.join(self.root_dir, self.label_dir, img_name)
        image = Image.open(ever_image_path)
        label = self.label_dir
        return image, label
    def __len__(self):
        return len(self.image_path_list)

root_dir = 'G:\python_files\深度学习代码库\cats_and_dogs_small\\train'
label_dir = 'cats'
my_data = MyDataset(root_dir, label_dir)
first_pic, label = my_data[0]   # 自动调用__getitem__(self, idx)
first_pic.show()
print("当前图片中动物所属label", label)

F:\Anaconda\envs\py38\python.exe G:/python_files/深度学习代码库/dataset/MyDataSet.py
当前图片中动物所属label cats

2 根据当前已有的数据集创建每一个样本数据对应的标签


import os
from torch.utils.data import Dataset
from PIL import Image

class MyLabelData:
    def __init__(self, root_dir, target_dir, label_dir, label_name):
        """
        :param root_dir: 根目录
        :param target_dir: 生成标签的目录
        :param label_dir: 要生成为标签目录名称
        :param label_name: 生成的标签名称
        """
        self.root_dir = root_dir
        self.target_dir = target_dir
        self.label_dir = label_dir
        self.label_name = label_name
        self.image_name_list = os.listdir(os.path.join(root_dir, target_dir))
    def label(self):
        for name in self.image_name_list:
            file_name = name.split(".jpg", 1)[0]
            label_path = os.path.join(self.root_dir, self.label_dir)
            if not os.path.exists(label_path):
                os.makedirs(label_path)
            with open(os.path.join(label_path, '{}'.format(file_name)), 'w') as f:
                f.write(self.label_name)
                f.close()
root_dir = 'G:\python_files\深度学习代码库\cats_and_dogs_small\\train'
target_dir = 'cats'
label_dir = 'cats_label'
label_name = 'cat'
label = MyLabelData(root_dir, target_dir, label_dir, label_name)
label.label()

这样上面的代码中的训练集目录下的每一个样本都会在train的cats_label目录下创建其对应的分类标签

每一个标签中文件中都有一个cat字符串或者其他动物的分类名称,以确定它到底是哪一个动物

3 tensorboard的使用

# tensorboard --logdir=深度学习代码库/logs --port=2001
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter('logs')
for i in range(100):
    writer.add_scalar('当前的函数表达式y=3*x',i*3,i)
writer.close()
#-----------------------------------------------------------
import numpy as np
from PIL import Image
image_PIL = Image.open('G:\python_files\深度学习代码库\cats_and_dogs_small\\train\cats\cat.1.jpg')
image_numpy = np.array(image_PIL)
print(type(image_numpy))
print(image_numpy.shape)
writer.add_image('cat图片', image_numpy,2, dataformats='HWC')

这里使用tensorboard的作用是为了更好的展示数据,但是对于函数的使用,比如上面的add_image中的参数,最好的方式是点击源码查看其对应的参数类型,然后根据实际需要将它所需的数据类型丢给add_image就好,而在源码中该函数的参数中所要求的图片类型必须是tensor类型或者是numpy,所以想要使用tensorboard展示数据就首先必须使用numpy或者使用transforms.Totensor将其转化为tensor,然后丢给add_image函数

还有一个需要注意的是,使用add_image函数,图片的tensor类型或者numpy类型必须和dataformats的默认数据类型一样,否则根据图片的数据类型修改后面的额dataformatas就好

4 transforms处理数据

tranfroms.Totensor的使用
import numpy as np
from torchvision import transforms
from PIL import Image
tran = transforms.ToTensor()
PIL_image = Image.open('G:\python_files\深度学习代码库\\cats\cat\cat.11.jpg')
tensor_pic = tran(PIL_image)
print(tensor_pic)
print(tensor_pic.shape)
from torch.utils.tensorboard import SummaryWriter
write = SummaryWriter('logs')
write.add_image('Tensor_picture',tensor_pic)

tensor([[[0.9216, 0.9059, 0.8353,  ..., 0.2392, 0.2275, 0.2078],
         [0.9765, 0.9216, 0.8118,  ..., 0.2431, 0.2392, 0.2235],
         [0.9490, 0.8745, 0.7608,  ..., 0.2471, 0.2471, 0.2314],
         ...,
         [0.3490, 0.4902, 0.6667,  ..., 0.7804, 0.7804, 0.7804],
         [0.3412, 0.4431, 0.5216,  ..., 0.7765, 0.7922, 0.7882],
         [0.3490, 0.4510, 0.5294,  ..., 0.7765, 0.7922, 0.7882]],

        [[0.9451, 0.9294, 0.8706,  ..., 0.2980, 0.2863, 0.2667],
         [1.0000, 0.9451, 0.8471,  ..., 0.3020, 0.2980, 0.2824],
         [0.9725, 0.8980, 0.7961,  ..., 0.2980, 0.2980, 0.2824],
         ...,
         [0.3725, 0.5137, 0.6902,  ..., 0.8431, 0.8431, 0.8431],
         [0.3647, 0.4667, 0.5451,  ..., 0.8392, 0.8549, 0.8510],
         [0.3608, 0.4627, 0.5412,  ..., 0.8392, 0.8549, 0.8510]],

        [[0.9294, 0.9137, 0.8588,  ..., 0.2235, 0.2118, 0.1922],
         [0.9922, 0.9373, 0.8353,  ..., 0.2275, 0.2235, 0.2078],
         [0.9725, 0.8980, 0.7922,  ..., 0.2275, 0.2275, 0.2118],
         ...,
         [0.4196, 0.5608, 0.7373,  ..., 0.9412, 0.9412, 0.9333],
         [0.4196, 0.5216, 0.6000,  ..., 0.9373, 0.9529, 0.9412],
         [0.4196, 0.5216, 0.6000,  ..., 0.9373, 0.9529, 0.9412]]])
torch.Size([3, 410, 431])

transforms.Normalize的使用
# 对应三个通道,每一个通道一个平均值和方差
# output[channel] = (input[channel] - mean[channel]) / std[channel]
nor = transforms.Normalize([0.5, 0.5, 0.5],[10, 0.5, 0.5])
print(tensor_pic[0][0][0])
x_nor = nor(tensor_pic)
write.add_image('nor_picture:', x_nor)
print(tensor_pic[0][0][0])
write.close()

打开源码查看

def forward(self, tensor: Tensor) -> Tensor:
    """
    Args:
        tensor (Tensor): Tensor image to be normalized.

    Returns:
        Tensor: Normalized Tensor image.
    """
    return F.normalize(tensor, self.mean, self.std, self.inplace)

必须传入的是tensor数据类型

transforms.Resize的使用
size_tensor = transforms.Resize((512,512))
# 裁剪tensor
tensor_pic_size = size_tensor(tensor_pic)
# 裁剪Image
size_pic = transforms.Resize((512,512))
image_size = size_pic(PIL_image)
print(image_size)
write.add_image('tensor_pic_size',tensor_pic_size)
print(tensor_pic_size.shape)
np_image = np.array(image_size)
print('np_image.shape:', np_image.shape)
write.add_image('image_size', np_image, dataformats='HWC')

调用Resize的时候,需要传入的数据类型的要求,查看源码如下

def forward(self, img):
    """
    Args:
        img (PIL Image or Tensor): Image to be scaled.

    Returns:
        PIL Image or Tensor: Rescaled image.
    """
    return F.resize(img, self.size, self.interpolation)

<PIL.Image.Image image mode=RGB size=512x512 at 0x1A72B1E7D00>
torch.Size([3, 512, 512])
np_image.shape: (512, 512, 3)

transforms.Compose使用
nor = transforms.Normalize([0.5, 0.5, 0.5],[10, 0.5, 0.5])
trans_resize_2 = transforms.Resize((64,64))
trans_to_tensor = transforms.ToTensor()
trans_compose = transforms.Compose([trans_resize_2, trans_to_tensor])
tensor_pic_compose = trans_compose(PIL_image)
write.add_image('tensor_pic_compose',tensor_pic_compose,dataformats='CHW')

5 dataset_transforms使用

from torch.utils.data import DataLoader
from torchvision import  transforms
import torchvision
data_transform = transforms.Compose([transforms.ToTensor()])
train_data = torchvision.datasets.CIFAR10('./data', train=True, download=True)
test_data = torchvision.datasets.CIFAR10('./data', train=False, download=True)
print("train_data", train_data)
print("train_data[0]", train_data[0])
print("train_data.classes", train_data.classes)
image, label = train_data[0]
print("label ",label)
image.show()
print("train_data.classes[label]", train_data.classes[label])

train_data Dataset CIFAR10
    Number of datapoints: 50000
    Root location: ./data
    Split: Train
train_data[0] (<PIL.Image.Image image mode=RGB size=32x32 at 0x144ED58D970>, 6)
train_data.classes ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
label  6
train_data.classes[label] frog

#%%
from torchvision import transforms
import torchvision
# 将整个数据集转化为tensor类型
data_transform1 = transforms.Compose([transforms.ToTensor()])
train_data = torchvision.datasets.CIFAR10('./data', train=True, transform=data_transform1, download=True)
test_data1 = torchvision.datasets.CIFAR10('./data', train=False, transform=data_transform1, download=True)
from torch.utils.tensorboard import SummaryWriter
write = SummaryWriter('batch_picture')
for i in range(10):
    tensor_pic, label = train_data[i]  # 经过前面的transforms成了tensor
    print(tensor_pic.shape)
    write.add_image('batch_picture', tensor_pic, i)
write.close()

Files already downloaded and verified
Files already downloaded and verified
torch.Size([3, 32, 32])
torch.Size([3, 32, 32])
torch.Size([3, 32, 32])
torch.Size([3, 32, 32])
torch.Size([3, 32, 32])
torch.Size([3, 32, 32])
torch.Size([3, 32, 32])
torch.Size([3, 32, 32])
torch.Size([3, 32, 32])
torch.Size([3, 32, 32])

from torchvision import transforms
import torchvision
# 将整个数据集转化为tensor类型
data_transform = transforms.Compose([transforms.ToTensor()])
train_data = torchvision.datasets.CIFAR10('./data', train=True, transform=data_transform, download=True)
test_data = torchvision.datasets.CIFAR10('./data', train=False, transform=data_transform, download=True)
# dataLoad会将一个batch中的图片和图片的Label分别放在一起,形成对应
train_data_load = DataLoader(dataset=train_data, shuffle=True, batch_size=64,)
from torch.utils.tensorboard import SummaryWriter
write = SummaryWriter('dataLoad')
# 遍历整个load,一次遍历的图片是64个
for batch_id, data in enumerate(train_data_load):
    batch_image, batch_label = data
    print("batch_id",batch_id)
    print("image.shape", batch_image.shape)
    print("label.shape", batch_label.shape)
    write.add_images('batch_load_picture', batch_image, batch_id, dataformats='NCHW')
write.close()

batch_id 765
image.shape torch.Size([64, 3, 32, 32])
label.shape torch.Size([64])
batch_id 766
image.shape torch.Size([64, 3, 32, 32])
label.shape torch.Size([64])
batch_id 767
image.shape torch.Size([64, 3, 32, 32])
label.shape torch.Size([64])
batch_id 768
image.shape torch.Size([64, 3, 32, 32])
label.shape torch.Size([64])
batch_id 769
image.shape torch.Size([64, 3, 32, 32])
label.shape torch.Size([64])
batch_id 770
image.shape torch.Size([64, 3, 32, 32])
label.shape torch.Size([64])
batch_id 771
image.shape torch.Size([64, 3, 32, 32])
label.shape torch.Size([64])
batch_id 772
image.shape torch.Size([64, 3, 32, 32])
label.shape torch.Size([64])
batch_id 773
image.shape torch.Size([64, 3, 32, 32])

使用add_images对所有批次的数据进行展示

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1164769.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

YOLOv5:修改backbone为SPPCSPC

YOLOv5&#xff1a;修改backbone为SPPCSPC 前言前提条件相关介绍SPPCSPCYOLOv5修改backbone为SPPCSPC修改common.py修改yolo.py修改yolov5.yaml配置 参考 前言 记录在YOLOv5修改backbone操作&#xff0c;方便自己查阅。由于本人水平有限&#xff0c;难免出现错漏&#xff0c;敬…

一文读懂最小相位滤波器和线性相位滤波器

一文读懂最小相位滤波器和线性相位滤波器 1. 举例说明2. 最小相位定义2.1 最小相位多项式2.2 最大相位滤波器2.3 最小相位意味最快的衰减2.4 最小相位/全通分解 3. 建立最小相位系统 前一篇博客 《一文读懂滤波器的线性相位&#xff0c;全通滤波器&#xff0c;群延迟》 详细解…

curl(二)HTTP协议和头

一 HTTP协议相关 ① 强制发出请求的http1.0 7.29 版本默认是http1.1 ② 查看当前curl版本是否支持http2 方式2: curl --version看Features 补充&#xff1a; 7.33.0 版本才引入 http2,才能使用curl发出http2.0版本的请求 ③ 强制发送http3 说明&#xff1a; 了解即可 二…

DevChat:超越编码的未来 - 优势、安装、使用、以及未来前景

目录 前言1 DevChat的优势1.1 精确的上下文控制1.2 灵活的提示管理1.3 上下文构建1.4 提前准备好的提示模板1.5 高级命令控制 2 安装DevChat插件3 使用DevChat插件3.1 代码生成3.2 文档撰写3.3 解释代码3.4 解决问题3.5 版本控制 4 DevChat的未来前景结语 前言 在软件开发领域…

Docker 多阶段构建的原理及构建过程展示

Docker多阶段构建是一个优秀的技术&#xff0c;可以显著减少 Docker 镜像的大小&#xff0c;从而加快镜像的构建速度&#xff0c;并减少镜像的传输时间和存储空间。本文将详细介绍 Docker 多阶段构建的原理、用途以及示例。 Docker 多阶段构建的原理 在传统的 Docker 镜像构建…

阿里面试:让代码不腐烂,DDD是怎么做的?

说在前面 在40岁老架构师 尼恩的读者交流群(50)中&#xff0c;最近有小伙伴拿到了一线互联网企业如阿里、滴滴、极兔、有赞、希音、百度、网易、美团的面试资格&#xff0c;遇到很多很重要的面试题&#xff1a; 谈谈你的高并发落地经验&#xff1f; 谈谈你对DDD的理解&#xf…

CIM与MES

CIM系统&#xff0c;全称计算机集成制造系统&#xff08;Computer-Integrated Manufacturing&#xff09;&#xff0c;是一种集成了计算机技术、网络通讯技术和软件系统的制造自动化框架。CIM的主要目标是整合制造过程中的所有活动&#xff0c;包括生产管理、设备管理和品质管理…

物流小程序制作教程:从零到有,详细解析

随着互联网的快速发展&#xff0c;物流行业也逐渐实现了数字化转型。为了满足消费者对更加便捷、高效的服务需求&#xff0c;许多物流企业选择制作自己的小程序。本文将通过乔拓云网后台&#xff0c;带你轻松搭建物流小程序&#xff0c;主要分为以下几个部分&#xff1a; 一、进…

设置echarts折线图虚线

itemStyle:{normal: {lineStyle: { type: solid}}}itemStyle:{normal: {lineStyle: { type: dashed}}}放到每个红框里面

梯度消失和梯度爆炸的原因

梯度消失和梯度爆炸 梯度爆炸和梯度消失本质上是因为梯度反向传播中的连乘效应。 梯度下降算法 举一个简单的例子,函数表达式为loss 2w^2 4w,如下图 ​​​​​​​ ​​​​​​​ 为了求得w的最优值,使得loss最小,从上图很容易看出来当w -1时,loss最小…

ER图设计神器,帮你省时省力,高效完成工作!

ER图&#xff08;Entity-Relationship Diagram&#xff09;工具用于设计数据库模型&#xff0c;通常用于表示数据实体、关系和属性之间的关系。以下是10个好用的ER图工具。 一、Lucidchart Lucidchart 是一款基于云的协作式图表设计工具&#xff0c;它允许用户创建、编辑和共享…

SAP发票及复制控制

一、概述 众所周知&#xff0c;SAP为业财一体化的ERP管理系统&#xff0c;因此财务发票必不可少&#xff0c;很多外贸企业还会用到形式发票。 发票相关的配置主要包含&#xff1a;发票类型、复制控制、发票定价以及自动过账的配置。 二、系统配置 1. 发票类型 1.1 概述 发…

Latex安装使用教程

在论文投稿时有些期刊要求使用Latex格式&#xff0c;比如博主现在就遇到了这个问题&#xff0c;木有办法&#xff0c;老老实实的学呗。大家可以去官网下载&#xff0c;但官网的界面设计属实有些一言难尽&#xff0c;因此我们可以使用国内的镜像。 LaTeX 基于 TeX&#xff0c;主…

2023年双十一第2波红包活动淘宝天猫京东双11红包领取优惠券跨店满多少减多少规则?

本文为大家提供众多福利&#xff1a; 2023年淘宝/天猫双十一红包第2波活动时间与领取入口&#xff0c;最高23888超级红包 2023年京东双十一红包第2波活动时间与领取入口&#xff0c;最高11111京享红包 草柴APP领淘宝/天猫、京东大额内部隐藏优惠券&#xff0c;拿购物返利 美…

Content-Type 值有哪些?

1、application/x-www-form-urlencoded 最常见 POST 提交数据的方式。 浏览器的原生 form 表单&#xff0c;如果不设置 enctype 属性&#xff0c;那么最终就会以 application/x-www-form-urlencoded 方式提交数据。 <form action"http://www.haha/ads/sds?name小草莓…

根据Aurora发送时序,造Aurora 发送数据包

首先Aurora采用AXIS接口 由于后续需要进行AXIS接口 不同时钟域的数据位宽转换&#xff08;64bit和256bit之间的转换&#xff09;&#xff0c;因此分两次走。 第一种方法&#xff1a;采用AXIS数据位宽转换IP AXIS跨时钟域IP 第二种方法&#xff1a;逻辑完成 下面记录逻辑…

ERP集成WMS仓储管理系统提升仓储效率

随着企业业务规模的不断扩张&#xff0c;仓储管理逐渐凸显其重要性。传统的仓储管理方式&#xff0c;随着业务量的增长&#xff0c;可能带来种种问题&#xff0c;如数据不准确、效率低下、成本增高等。在这样的背景下&#xff0c;ERP集成WMS仓储管理系统的出现成为企业的及时雨…

麒麟KYLINIOS软件仓库搭建01-新创建软件仓库服务器

原文链接&#xff1a;麒麟KYLINIOS软件仓库搭建01-新创建软件仓库服务器 hello&#xff0c;大家好啊&#xff0c;今天给大家带来麒麟桌面操作系统软件仓库搭建的文章01-新创建软件仓库服务器&#xff0c;本篇文章主要给大家介绍了如何在麒麟桌面操作系统2203-x86版本上搭建内网…

Redis高可用解决方案之Redis集群,和Spring Cloud集成实战

专栏集锦&#xff0c;大佬们可以收藏以备不时之需 Spring Cloud实战专栏&#xff1a;https://blog.csdn.net/superdangbo/category_9270827.html Python 实战专栏&#xff1a;https://blog.csdn.net/superdangbo/category_9271194.html Logback 详解专栏&#xff1a;https:/…

Java网站如何集成支付宝当面付,企业个人都能使用的支付(比较简单)

创建应用 这个得先去登录 - 支付宝创建应用 相关配置设置 maven配置 <dependency><groupId>com.alipay.sdk</groupId><artifactId>alipay-sdk-java</artifactId><version>4.38.10.ALL</version></dependency> 支付服务代码 …