PyTorch 创建数据集

news2024/11/13 10:44:46

图片数据和标签数据准备

1.本文所用图片数据在同级文件夹中 ,文件路径为'train/’

在这里插入图片描述

2.标签数据在同级文件,文件路径为'train.csv'

在这里插入图片描述

3。将标签数据提取

train_csv=pd.read_csv('train.csv')

创建继承类

第一步,首先创建数据类对象 此时可以想象为单个数据单元的创建 { 图像,标签}

在这里插入图片描述

继承的是Dataset类 (数据集类)

from torch.utils.data import Dataset
from PIL import Image          //从文件路径中提取图片所需要的函数

class Imagedata(Dataset):        //继承Dataset类
	def __init__(self,df,dir,transform=None):     //往类里传输需要的数据必须在这定义,后面初始化函数才能使用传入的数据,
	                                              //df表示传入的标签数据,dir表示图像数据文件地址,transform是图像增强的处理操作
	      super().__init__()                      //声明后面操作需要用的数据
	      self.df=df                           
	      self.dir=dir
	      self.transform=transform
    def __len__(self):                     //模板函数,没什么卵用
        return len(self.df)
    def __getitem__(self, idex):           //将单个数据和标签整合到一块的初始化函数
        img_id=self.df.iloc[idex,0]        //图片的名称在df文件中,标签也在df的文件中,如下图,为的就是提出图像数据文件中的图片,否则从图片数据文件中一张一张提取出来很难,名称太长
        img=Image.open(self.dir+img_id)   //拿到了图片的整个完整地址  
        img=np.array(img)                //Image提取出来的为image类型,需要转换为numpy数组,才能存储到数据集中
                                         //上面两行也可以换为cv2.imread(dir),直接读取的数据就可以往里面存,避免了数据转换
        label=self.df.iloc[idex,1]       //从df中提取对应的标签,就是同一张图像的标签,由idex固定
        return img,label                 //返回整理好的单个数据单元(图像+标签)
		

在这里插入图片描述

第二步,创造好了单个数据单元对象,那么需要将多个数据单元整合起来构成一个完整的数据集

先将单个数据单元实现,因为上面的代码为类对象代码,并没有实现

train_dataset=ImageDataset(df=train_csv,dir='train/')  //df为标签文件,dir表示你图像存储的文件地址

得到了单个数据单元,那么开始将数据整合,先调用数据整合函数:

from torch.utils.data import DataLoader

通过数据流来整合

train_data=DataLoader(train_dataset,batch_size=32)    //train_dataset 为单个对象     batch_size为设置几个为一小组,为后面的分组训练做准备

那么最后得到的train_data就是带有图像和标签的数据集,可以验证一下:

for img,label in train_data:
    print(img,label)

在这里插入图片描述

图像增强技术(降噪,标准化)

上面没有加入图像增强代码,创建数据集时候,可以先将图像增强后再存入数据集,增强的主要目的就是提高训练准确率,标准化可以使图像在神经网络训练的更快,因为图像的数据明显变小,举个例子,由像素[233,221,222]可以直接变为[2.33,2.21,2.22]

如下使图像增强代码,用的使torchvision,每行代码都有注释

from torchvision import transforms

transform_train = transforms.Compose([transforms.ToTensor(),        //将图像变为Tensor张量,并将图像像素由255-0变为1-0,压缩,并将图像的维度从 (H x W x C) 转换为 (C x H x W)
                                      transforms.Pad(32, padding_mode='symmetric')   //表示在图像的四周各填充 32 个像素。
                                      transforms.RandomHorizontalFlip(),    //以一定的概率对图像进行随机水平翻转。这有助于增加数据的多样性,提高模型的泛化能力。防止拟合
                                      transforms.RandomVerticalFlip(),      //以一定的概率对图像进行随机垂直翻转。同样是为了增加数据多样性
                                      transforms.RandomRotation(10),       //以一定的概率对图像进行随机旋转,旋转角度在 -1010 度之间。增加数据的多样性
                                      transforms.Normalize((0.485, 0.456, 0.406),     //指定每个通道的均值。通常是在 ImageNet 数据集上计算得到的均值。
                                                           (0.229, 0.224, 0.225))])   //指定每个通道的标准差。也是在 ImageNet 数据集上计算得到的标准差。
                                                           

那么在数据单元创建的时候加入,以下是完整代码:

from torch.utils.data import Dataset

class ImageDataset(Dataset):
    def __init__(self, df, dir, transform=None): 
        super().__init__()
        
        self.df = df
        self.dir = dir
        self.transform = transform
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        img_id = self.df.iloc[idx,0]
        img_path = self.dir + img_id
        image = cv2.imread(img_path)            //这里用了cv2直接读取图片,避免了转换numpy
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)   //opencv里的数据增强
        label = self.df.iloc[idx,1]
        
        if self.transform is not None:
            image = self.transform(image)
        return image, label


-----------------------图像增强技术------------------------
from torchvision import transforms
transform_train = transforms.Compose([transforms.ToTensor(),
                                      transforms.Pad(32, padding_mode='symmetric'),
                                      transforms.RandomHorizontalFlip(),
                                      transforms.RandomVerticalFlip(),
                                      transforms.RandomRotation(10),
                                      transforms.Normalize((0.485, 0.456, 0.406),
                                                           (0.229, 0.224, 0.225))])
transform_test = transforms.Compose([transforms.ToTensor(),
                                     transforms.Pad(32, padding_mode='symmetric'),
                                        transforms.Normalize((0.485, 0.456, 0.406),
                                                           (0.229, 0.224, 0.225))])


from torch.utils.data import DataLoader
dataset_train = ImageDataset(df=train_df, img_dir='train/',transform=transform_train)
loader_train = DataLoader(dataset=dataset_train, batch_size=32, shuffle=True)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2108680.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【PyTorch】基础环境如何打开

前期安装可以基于这个视频,本文是为了给自己存档如何打开pycharm和jupyter notebookPyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】_哔哩哔哩_bilibili Pycharm 配置 新建项目的时候选择解释器pytorch-gpu即可。 Jupyte…

【C++ 第二十二章】C++的类型转换

1.C语言中的类型转换 在C语言中,如果赋值运算符左右两侧类型不同,或者形参与实参类型不匹配,或者返回值类型与接收返回值类型不一致时,就需要发生类型转化,C语言中总共有两种形式的类型转换:隐式类型转换和…

CDA数据分析一级考试备考攻略

一、了解考试内容和结构 CDA一级考试主要涉及的内容包括:数据分析概述与职业操守、数据结构、数据库基础与数据模型、数据可视化分析与报表制作、Power BI应用、业务数据分析与报告编写等。 CDA Level Ⅰ 认证考试大纲:https://www.cdaglobal.com/certification.h…

一文还原时序数据库 IoTDB 在 TPCx-IoT 的测试全流程!

在云服务硬件环境下,IoTDB 写入、查询、利用资源能力均表现出色! 之前,我们为大家介绍了基于 IoTDB 的企业级产品 TimechoDB,在 TPCx-IoT 基准测试中打破世界纪录,取得的双指标第一成绩,和选择 TPCx-IoT 的…

【Python机器学习】核心数、进程、线程、超线程、L1、L2、L3级缓存

如何知道自己电脑的CPU是几核的,打开任务管理器(同时按下:Esc键、SHIFT键、CTRL键) 然后,点击任务管理器左上角的性能选项,观察右下角中的内核:后面的数字,就是你CPU的核心数,下图中我的是16个核心的。 需要注意的是,下面的逻辑处理器:32 表示支持 32 线程(即超线…

【爬虫软件】批量采集短视频博主的主页作品

用python开发的DY爬虫采集软件,可自动按博主抓取其已发布视频数据。 软件界面: 采集结果: 日志记录: 软件说明: 演示视频: https://www.bilibili.com/video/BV1Kb42187qf 讲解文章: https://www.bi…

2024数学建模国赛选题建议+团队助攻资料

目录 一、题目特点和选题建议 二、模型选择 1、评价模型 2、预测模型 3、分类模型 4、优化模型 5、统计分析模型 三、white学长团队助攻资料 1、助攻代码 2、成品论文PDF版 3、成品论文word版 9月5日晚18:00就要公布题目了,根据历年竞赛题目…

QT: Unable to create a debugging engine.

1.问题场景: 第一次安装QT,没有配置debug功能 打开控制面板》程序》找到Kit 重启电脑即可 2.问题场景: qt原本一直好好的,突然有天打开运行调试版本,提示Unable to create a debugging engine.错误。这个是指无法创…

CIOE中国光博会&电巢科技即将联办“智能消费电子创新发展论坛”

在科技浪潮汹涌澎湃的当下,从通信领域的高速光传输,到消费电子中的高清显示与先进成像技术,光电技术的应用范围不断拓展且日益深化。而AIGC 凭借其丰富的内容供给与个性化反馈能力,正为新一代消费电子及智能穿戴产品开辟崭新的发展…

具身智能猜想 ——机器人进化

设想一个机器人进化的仿真模拟环境,可以通过 “基因突变” 产生新功能,让机器人逐步进化。以下是这个进化系统的关键要素和可能的实现步骤: 1. 仿真环境 虚拟世界:创建一个包含多样化任务和挑战的虚拟环境,如探索、抓…

uniapp 实现tabbar图标凸起

实现tabbar图标凸起有两种,第一种是自定义tabbar,第二种就是使用官方的tabbar跟api实现,自定义在体验中不如原生的tabbar,所以我下面展示的是使用官方的tabbar跟api实现 效果如图: 左边是未选中中间的凸起&#xff0c…

深入解密 Elasticsearch 查询优化:巧用 Profile 工具/API 提升性能

1、Elasticsearch Profile 工具介绍 在使用 ES 进行检索查询时,我们常常要去优化一些复杂的查询语句,这里 ES 结合 lucene 的生态制作了 Profile API 和图形化的 Profile 分析界面以供用户使用。 这里我们来简单讲解一下这个工具 API,希望能给…

全双工语音交互

文章目录 微软小冰全双工字节大模型语音交互[Language Model Can Listen While Speaking](https://arxiv.org/html/2408.02622v1) 微软小冰全双工 全双工的定义:一路持续的听,upload audio;一路持续的输出,download audio&#xf…

C#中的Graphics类和SetQuality()自定义方法

在 C# 中,Graphics 类是 System.Drawing 命名空间的一部分,它提供了一组方法和属性,用于在 Windows Forms 应用程序中进行二维绘图。Graphics 对象可以绘制文本、线条、曲线、形状和图像,并可以对它们进行变换和剪辑。 Graphics …

【JAVA入门】Day33 - Collections

【JAVA入门】Day33 - Collections 文章目录 【JAVA入门】Day33 - Collections Collections 是集合的工具类。其包含的方法如下表所示,其中前两个方法最为常用。 以下代码演示了如何创建集合并批量添加数据,然后打乱集合元素顺序,然后用二分法…

数据结构:(LeetCode203)移除链表元素

给你一个链表的头节点 head 和一个整数 val ,请你删除链表中所有满足 Node.val val 的节点,并返回 新的头节点 。 示例 1: 输入:head [1,2,6,3,4,5,6], val 6 输出:[1,2,3,4,5]示例 2: 输入&#xff1…

代码随想录:343. 整数拆分

343. 整数拆分 class Solution { public:int integerBreak(int n) {int dp[100]{0};//拆分i的最大乘积为dp[i]dp[1]1;//初始化&#xff0c;主要是为了dp[2]初始for(int i2;i<n;i){for(int j1;j<i;j){ dp[i]max(dp[i],max(j,dp[j])*max(i-j,dp[i-j]));//取最大值&#x…

深入Linux轻量级进程管理:线程创建、线程ID解析与进程地址空间页表探究

&#x1f351;个人主页&#xff1a;Jupiter. &#x1f680; 所属专栏&#xff1a;Linux从入门到进阶 欢迎大家点赞收藏评论&#x1f60a; 目录 &#x1f6b2;Linux线程控制&#x1f40f;POSIX线程库&#x1f415;创建线程&#x1f41f;指令查看轻量级进程指令&#xff1a;ps -a…

Python(TensorFlow)和Java及C++受激发射损耗导图

&#x1f3af;要点 神经网络监督去噪预测算法聚焦荧光团和检测模拟平台伪影消除算法性能优化方法自动化多尺度囊泡动力学成像生物研究多维分析统计物距粒子概率算法 Python和MATLAB图像降噪算法 消除噪声的一种方法是将原始图像与表示低通滤波器或平滑操作的掩模进行卷积。…

汇编伪指令 GNU 风格(24)

先来看看关于标号的内容。 这里的局部标号是需要注意的。 全局标号&#xff0c;以及注释 可以不看。 来看一个例子&#xff1b; 这里的 BSYM 我不知道是什么意思。 在来看看关于伪操作的内容&#xff0c; 一般是以 . 开头的。 这是基本的一些操作。 然后是举例&#xff1a; …