[PyTorch][chapter 50][创建自己的数据集 2]

news2024/11/25 20:39:48

前言:

      这里主要针对图像数据进行预处理.定义了一个 class Pokemon(Dataset) 类,实现

图像数据集加载,划分的基本方法.
 


目录:

  1.      整体框架
  2.       __init__ 
  3.      load_images
  4.       save_csv
  5.      divide_data
  6.      __len__
  7.      denormalize
  8.     __getitem__
  9.    main
  10.    ImageFolder 

     


一  整体框架

       我们需要创建一个自定义的数据集类,该类必须继承自Dataset类,

      重点实现以下三个方法:

      __init__

   __len__()

       __getitem__()


二  __init__ 

      实现了图像数据集的加载

      根据mode 进行划分

    
    def __init__(self, root, resize, mode,fileName):
        #初始化函数
        super(Pokemon, self).__init__()
        
        self.root = root
        self.resize = resize
        self.name2label ={}
        
        #遍历目录
        path = os.path.join(root)
        #用子目录文件夹名字作为分类key
        for name in sorted(os.listdir(path)):
            subDir = os.path.join(root, name)
            if not os.path.isdir(subDir):
                continue
            else:
                self.name2label[name] = len(self.name2label.keys())
            
        
        csv_path = os.path.join(self.root, fileName)
        print("\n csv_path:  ",csv_path)
        if not os.path.exists(csv_path):
            images = self.load_images()
            self.save_csv(fileName, images)
        
        self.images, self.labels = self.load_csv(fileName)
        self.divide_data(mode)


三 load_images

    加载指定目录下面的图片,

   把图片路径保存到列表里面

  def load_images(self):
        images =[]
        for name in self.name2label.keys():
            #pokeon\\newtwoo\\00001.png
            #返回所有匹配的文件路径列表。它只有一个参数pathname,定义了文件路径匹配规则,这里可以是绝对路径,也可以是相对路径。下面是使用glob.glob的例子:
            pngPath = os.path.join(self.root, name,'*.png')
            jpgPath = os.path.join(self.root, name,'*.jpg')
            jpegPath = os.path.join(self.root, name,'*.jpeg')
            
            
            png = glob.glob(pngPath)
            jpg =glob.glob(jpgPath)
            jpeg = glob.glob(jpegPath)
         
            images +=jpg
            images +=jpeg
            images +=png
        print("\n images ",len(images))
        random.shuffle(images)
        return images

四    save_csv

       图片路径,标签保存到csv 文件里面

   

       #image, label
    def save_csv(self, fileName, images):
        
        path = os.path.join(self.root, fileName)
        csvfile = open(path,mode='w',newline='')
        writer = csv.writer(csvfile)
        
        for img in images:
            
            name = img.split(os.sep)[-2]
            
            label = self.name2label[name]
            
            writer.writerow([img, label])

        csvfile.close()


四  load_csv

    加载 csv 文件

    def load_csv(self, fileName):
        
        path = os.path.join(self.root, fileName)
        csvfile = open(path,mode='r',newline='')
        
        reader = csv.reader(csvfile)
        images =[]
        labels =[]
        for row in reader:
            
            img, label = row
            label = int(label)
            images.append(img)
            labels.append(label)
            
        m = len(images)
        n = len(labels)
        print("\n number images: %d number labels: %d"%(m,n))
        return  images,labels

五  divide_data

   数据集划分

    训练集: 60%

    验证集: 20%

    测试机:20%

    def divide_data(self,mode):
        
        N = len(self.images)
        if 'train' == mode: #0->60%
            start = 0
            end = int(0.6*N)
        elif 'val' == mode:#60%->80%
            start = int(0.6*N)
            end = int(0.8*N)
        else:#80%->100%
            start = int(0.8*N)
            end = N
            
        
        self.images = self.images[start:end]
        self.labels = self.labels[start:end]
        m = len(self.images )

        print("\n number divide images: %d "%(m))
        

六      __len__

    返回数据集大小

    def __len__(self):
        #总的数据
        N = len(self.images)
        return N

七  denormalize

   图像数据 标准后,当需要显示原图片的时候,需要反标准化

   def denormalize(self,x_hat):
        
        #x_hat =(x-mean)/std
        #x = x_hat*std+mean
        #x: [c,h,w]
        #mean: [3]=>[3,1,1]
        
        mean=[0.485, 0.456, 0.406]
        std=[0.229, 0.224, 0.225]
        
        mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
        std =  torch.tensor(std).unsqueeze(1).unsqueeze(1)
        
        x =x_hat*std+mean
        
        return x
        

八  __getitem__

   根据指定的索引获取对应的图片,以及标签值

        
    
    def __getitem__(self, index):
        #返回当前index 对应的图片数据
         #self.images, self.labels
         #idx ~[0,N]
         
         img_path = self.images[index] #图片路径
         label = self.labels[index] #图片标签
         #print("\n img_path",img_path)
         tf = transforms.Compose([  
                          lambda x:Image.open(x).convert('RGB'),
                          transforms.Resize((int(self.resize*1.25) , int(self.resize*1.25))), 
                          transforms.RandomRotation(15), 
                          transforms.ToTensor(),
                          transforms.CenterCrop(self.resize),
                          transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                               std=[0.229, 0.224, 0.225])
                          ])

         img =  tf(img_path)
         label = torch.tensor(label)
         #print("\n index ",index, "\t img ",img.shape,"\t label ",label)
         return img, label

九  main 

1 先定义一个class Pokemon(Dataset): 类,并实现上面的方法

2    数据集的迭代加载,以及通过visdom 工具加载显示

def main():
    root ='pokemon'
    resize =224
    mode = 'test' #数据集分为三种 tain,val,test
    csvfile ='data.csv'
    db = Pokemon(root, resize, mode,csvfile)

    viz = visdom.Visdom()
   
    
    # datetime转字符串
    time.time() #显示当前的时间戳
    curtime = time.strftime('%H:%M:%S') #结构化输出当前的时间

   
    
    
    
    BATCH_SIZE = 32
    loader = DataLoader(dataset = db, batch_size = BATCH_SIZE,shuffle = True)
  
    for step, (batchX, batchY) in enumerate(loader):
            print( '| Step: ', step, '| batch x: ',batchX.shape, '| batch y: ', batchY.shape)
            viz.images(db.denormalize(batchX),nrow=8, win='batchX',opts=dict(title=curtime))
            viz.text(str(batchY.numpy()),win='batchY',opts=dict(title='label'))
            time.sleep(10)
    

    
if __name__ == "__main__" :
    main()

十  ImageFolder 

  自己的图像数据集如果有规律的话,可以直接用PyTorch API 函数实现 Pokemon

类的功能

from torchvision.datasets import ImageFolder
from torchvision import transforms
 
imgMean =[0.485, 0.456, 0.406]
imgStd = [0.229, 0.224, 0.225]
normalize=transforms.Normalize(mean=imgMean,std=imgStd)
transform=transforms.Compose([
    transforms.RandomCrop(180),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(), #将图片转换为Tensor,归一化至[0,1]
    normalize
])
 
dataset=ImageFolder('./data/train',transform=transform)

参考:

torchvision.datasets.ImageFolder使用详解_☞源仔的博客-CSDN博客

课时102 自定义数据集实战-5_哔哩哔哩_bilibili

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/878185.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

[NepCTF 2023] crypto 复现

这个赛很不理想,啥都不会。 拿了WP看了几个题,记录一下 random_RSA 这题不会是正常情况,我认为。对于论文题,不知道就是不知道,基本没有可能自己去完成论文。 题目不长,只有两个菜单,共可交…

Springboot 在 redis 中使用 BloomFilter 布隆过滤器机制

一、导入SpringBoot依赖 在pom.xml文件中&#xff0c;引入Spring Boot和Redis相关依赖 <!-- Google Guava 使用google的guava布隆过滤器实现--><dependency><groupId>com.google.guava</groupId><artifactId>guava</artifactId><vers…

IPv4分组

4.3.1 IPv4分组 IP协议定义数据传送的基本单元——IP分组及其确切的数据格式 1. IPv4分组的格式 IPv4分组由首部和数据部分&#xff08;TCP、UDP段&#xff09;组成&#xff0c;其中首部分为固定部分&#xff08;20字节&#xff09;和可选字段&#xff08;长度可变&#xff0…

行业追踪,2023-08-14

自动复盘 2023-08-14 凡所有相&#xff0c;皆是虚妄。若见诸相非相&#xff0c;即见如来。 k 线图是最好的老师&#xff0c;每天持续发布板块的rps排名&#xff0c;追踪板块&#xff0c;板块来开仓&#xff0c;板块去清仓&#xff0c;丢弃自以为是的想法&#xff0c;板块去留让…

深入浅出 栈和队列(附加循环队列、双端队列)

栈和队列 一、栈 概念与特性二、Stack 集合类及模拟实现1、Java集合中的 Stack2、Stack 模拟实现 三、栈、虚拟机栈、栈帧有什么区别&#xff1f;四、队列 概念与特性五、Queue集合类及模拟实现1、Queue的底层结构&#xff08;1&#xff09;顺序结构&#xff08;2&#xff09;链…

做海外游戏推广有哪些条件?

做海外游戏推广需要充分准备和一系列条件的支持。以下是一些关键条件&#xff1a; 市场调研和策略制定&#xff1a;了解目标市场的文化、玩家偏好、竞争格局等是必要的。根据调研结果制定适合的推广策略。 本地化&#xff1a;将游戏内容、界面、语言、货币等进行本地化&#…

计算两个字符串之间的编辑距离【支持多字节字符串】

/*** 计算两个字符串之间的编辑距离【支持多字节字符串】** param string $str1 求编辑距离中的其中一个字符串* param string $str2 求编辑距离中的另一个字符串** return int*/ function levenshtein_copy(string $str1, string $str2): int {$arr1 mb_str_split($str1);$ar…

IK分词器升级,MySQL热更新助一臂之力

ik分词器采用MySQL热更新 ​ 官方所给的IK分词器只支持远程文本文件热更新&#xff0c;不支持采用MySQL热更新&#xff0c;没关系&#xff0c;这难不倒伟大的博主&#xff0c;给哈哈哈。今天就来和大家讲一下如何采用MySQL做热更新IK分词器的词库。 一、建立数据库表 CREATE…

20个常考的前端算法题,你全都会吗?

现在面试中&#xff0c;算法出现的频率越来越高了&#xff0c;大厂基本必考 今天给大家带来20个常见的前端算法题&#xff0c;重要的地方已添加注释&#xff0c;如有不正确的地方&#xff0c;欢迎多多指正&#x1f495; 1、两数之和 题目&#xff1a;给定一个数组 nums 和一…

d3dcompiler43.dll缺失怎么修复?dll缺失解决方法分享

在使用电脑过程中&#xff0c;我们有时会遇到一些系统文件的问题&#xff0c;其中一个常见的问题是d3dcompiler43.dll文件的损坏或丢失。当这个文件出现问题时&#xff0c;可能会导致应用程序无法正常运行或图形渲染出现异常。最近我也遇到了这个问题&#xff0c;以下是我修复d…

ClickHouse(十八):Clickhouse Integration系列表引擎

进入正文前&#xff0c;感谢宝子们订阅专题、点赞、评论、收藏&#xff01;关注IT贫道&#xff0c;获取高质量博客内容&#xff01; &#x1f3e1;个人主页&#xff1a;含各种IT体系技术&#xff0c;IT贫道_Apache Doris,大数据OLAP体系技术栈,Kerberos安全认证-CSDN博客 &…

UE4拾取物品高亮显示

UE4系列文章目录 文章目录 UE4系列文章目录前言一、如何实现 前言 先看下效果&#xff0c;当角色靠近背包然后看向背包&#xff0c;背包就会高亮显示。 一、如何实现 1.为选中物品创建蓝图接口 在“内容” 窗口中&#xff0c;鼠标右键选择“蓝图”->蓝图接口&#xff0c…

P13-CNN学习1.3-ResNet(神之一手~)

论文地址:CVPR 2016 Open Access Repository https://arxiv.org/pdf/1512.03385.pdf Abstract 翻译 深层的神经网络越来越难以训练。我们提供了一个残差学习框架用来训练那些非常深的神经网络。我们重新定义了网络的学习方式&#xff0c;让网络可以直接学习输入信息与输出信息…

乐鑫ESP32S3串口下载出现奇怪问题解决方法

正在学习ESP32S3&#xff0c;有一个原厂BOX开发板&#xff0c;使用虚拟机&#xff0c;安装 debian11 &#xff0c;安装IDF4.4.5版本工具。下载box示例代码。 进入example,idf.py set-target esp32s3, idf.py flash 下载时&#xff0c;出现错误&#xff1a; Wrote 22224 bytes…

【Unity实战系列】如何把你的二次元老婆/老公导入Unity进行二创并且进行二次元渲染?(附模型网站分享)

君兮_的个人主页 即使走的再远&#xff0c;也勿忘启程时的初心 C/C 游戏开发 Hello,米娜桑们&#xff0c;这里是君兮_&#xff0c;在正式开始讲主线知识之前&#xff0c;我们先来讲点有趣且有用的东西。 我知道&#xff0c;除了很多想从事游戏开发行业的人以外&#xff0c;还…

试岗第一天问题

1、公司的一个项目拉下来 &#xff0c;npm i 不管用显示 后面百度 使用了一个方法 虽然解决 但是在增加别的依赖不行&#xff0c;后面发现是node版本过高&#xff0c;更换node版本解决。 2、使用插件动态的使数字从0到100&#xff08;vue-animate-number插件&#xff09; 第一…

Redis之删除策略

文章目录 前言一、过期数据二、数据删除策略2.1定时删除2.2惰性删除2.3 定期删除2.4 删除策略比对 三、逐出算法3.1影响数据逐出的相关配置 总结 前言 Redis的常用删除策略 一、过期数据 Redis是一种内存级数据库&#xff0c;所有数据均存放在内存中&#xff0c;内存中的数据可…

Python 图形界面框架TkInter(第八篇:理解pack布局)

前言 tkinter图形用户界面框架提供了3种布局方式&#xff0c;分别是 1、pack 2、grid 3、place 介绍下pack布局方式&#xff0c;这是我们最常用的布局方式&#xff0c;理解了pack布局&#xff0c;绝大多数需求都能满足。 第一次使用pack&#xff08;&#xff09; import …

大模型相关知识

一. embedding 简单来说&#xff0c;embedding就是用一个低维的向量表示一个物体&#xff0c;可以是一个词&#xff0c;或是一个商品&#xff0c;或是一个电影等等。这个embedding向量的性质是能使距离相近的向量对应的物体有相近的含义&#xff0c;比如 Embedding(复仇者联盟)…

湖南大学计算机考研分析

关注我们的微信公众号 姚哥计算机考研 更多详情欢迎咨询 24计算机考研|上岸指南 湖南大学 湖南大学计算机考研招生学院是信息科学与工程学院。目前均已出拟录取名单。 湖南大学信息科学与工程学院内设国家示范性软件学院、国家保密学院和湘江人工智能学院&#xff0c;计算机…