Few Shot Classification小知识——数据集的加载

news2025/1/12 10:55:29

概述

Few-shot classification(小样本分类)是机器学习和人工智能的一个子领域,解决的问题是在训练数据非常有限的情况下,学习对新样本进行分类。在传统的监督学习中,模型需要在包含大量标记样本的数据集上进行训练,每个类别都有丰富的标记样本。然而,在实际应用中,获得如此大量的标记数据可能会非常困难或昂贵。
目前网上对于入门few shot的十分少,博主之前对于episode这些也十分不明白,在看了一些资料和代码后才逐渐理解小样本是怎样进行训练的。对此,博主首先对其中的数据集加载部分进行了总结,希望能够对各位读者有一些启发。

步骤

1.修改文件结构

- data_name
--- images
----- folder_name1
------- img1.png
------- img2.png
----- folder_name2
--- meta
----- classes.txt  
----- fsl_train.txt
----- fsl_test.txt
----- fsl_train_class.txt
----- fsl_test_class.txt

其中folder_name1,folder_name2是文件夹的名字,通常是分类名称,有些可能也是下标数字(1-100的数字)

2.找到图像的标签文件classes.txt

classes.txt里面含有图像全部的类别,如果没有需要自己构建一个,标签文件的内容大致如下:

class_name1
class_name2
class_name3

3.使用代码生成文件

生成的文件包括:fsl_train.txt,fsl_test.txt,fsl_train_class.txt,fsl_test_class.txt文件
该代码目前支持的情况有:

  • (1) folder_name为类别名称
  • (2) folder_name为类别名称对应的下标,从1开始
  • (3) folder_name文件夹下面的图片名称全部是数字,没有其他符号
  • (4) folder_name文件夹下面的图片名称什么符号都有

文件大致内容为:fsl_train.txt:
在这里插入图片描述
fsl_train_class.txt
在这里插入图片描述

代码为:

def make_file(img_root_path, names, path, is_num):
    """
    :param img_root_path: 图像文件夹
    :param names: 对应的图像文件名称
    :param path: 要保存的路径
    :param is_num: 图像文件名称是否是数字
    """
    with open(path,"w") as f:
        for name in names:
            img_dir = os.path.join(img_root_path,str(name))
            img_names = os.listdir(img_dir)
            if is_num:
                sort_img_names = sorted(img_names,key=lambda s: int(s.split('.')[0]))
            else:
                sort_img_names = sorted(img_names)
            for img_name in sort_img_names:
                img_path = os.path.join(img_dir,img_name).replace(img_root_path + "/","")
                f.write(f"{img_path}\n")
            
def generate_split_dataset(data_root, train_num, is_imgs_id, is_img_name_num):
    """
    :param data_root: 数据集目录
    :param train_num: 用于训练的类别数目
    :param is_imgs_id: 图像文件夹名称是否是下标
    :param is_img_name_num: 图像名字是否是数字 
    :return: None
    """
    class_path = os.path.join(data_root,"meta", "classes.txt")
    class_list = list_from_file(class_path)
    if is_imgs_id:
    	# 下标从1开始,可以根据自己的需要修改
        id2class = {i + 1 : _class for i, _class in enumerate(class_list)}
    else:
        id2class = {i: _class for i, _class in enumerate(class_list)}
    # class2id = {_class : i + 1 for i, _class in enumerate(class_list)}
    # 选择train_num个类作为训练集的,其他作为测试的
    train_class_ids = random.sample(range(1, len(class_list) + 1),train_num)
    test_class_ids = []
    for id in range(1, len(class_list) + 1):
        if id not in train_class_ids:
            test_class_ids.append(id)
    # 获得images文件夹的名称
    if is_imgs_id:
        train_class_name = train_class_ids
        test_class_name = test_class_ids
    else:
        train_class_name = [id2class[id] for id in train_class_ids]
        test_class_name = [id2class[id] for id in test_class_ids]
    # 顺序排序
    train_class_name = sorted(train_class_name)
    test_class_name = sorted(test_class_name)
    train_class_save_path = os.path.join(data_root, "meta", "fsl_train_class.txt")
    test_class_save_path = os.path.join(data_root, "meta" , "fsl_test_class.txt")
    with open(train_class_save_path, "w") as f:
        for cls_name in train_class_name:
            f.write(f"{str(cls_name)}\n")

    with open(test_class_save_path, "w") as f:
        for cls_name in test_class_name:
            f.write(f"{str(cls_name)}\n")

    # 将这些数据保存在fsl_train.txt中,格式为:class_name/img_name
    img_root_path = os.path.join(data_root,"images")
    train_imgs_name_path = os.path.join(data_root, "meta", "fsl_train.txt")
    test_imgs_name_path = os.path.join(data_root, "meta", "fsl_test.txt")
    make_file(img_root_path, train_class_name, train_imgs_name_path,is_img_name_num)
    make_file(img_root_path, test_class_name,test_imgs_name_path, is_img_name_num)

4.构建basedataset类

basedataset类是一个用于加载含有类别名称的文件,代码为:

import copy
from abc import ABCMeta, abstractmethod
from typing import Dict, List, Mapping, Optional, Sequence, Union
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import numpy as np
import os.path as osp
from PIL import Image
import torch

from util import tools

from mmpretrain.evaluation import Accuracy

class BaseFewShotDataset(Dataset, metaclass=ABCMeta):
    def __init__(self,
                 pipeline,
                 data_prefix: str,
                 classes: Optional[Union[str, List[str]]] = None,
                 ann_file: Optional[str] = None) -> None:
        super().__init__()

        self.ann_file = ann_file
        self.data_prefix = data_prefix
        self.pipeline = pipeline
        self.CLASSES = self.get_classes(classes)
        self.data_infos = self.load_annotations()
        self.data_infos_class_dict = {i: [] for i in range(len(self.CLASSES))}
        for idx, data_info in enumerate(self.data_infos):
            self.data_infos_class_dict[data_info['gt_label'].item()].append(
                idx)

    def load_image_from_file(self,info_dict):
        img_prefix = info_dict['img_prefix']
        img_name = info_dict['img_info']['filename']
        img_file = osp.join(img_prefix,f"{img_name}")
        img_data = Image.open(img_file).convert('RGB')
        return img_data

    @abstractmethod
    def load_annotations(self):
        pass

    @property
    def class_to_idx(self) -> Mapping:
        return {_class: i for i, _class in enumerate(self.CLASSES)}

    def prepare_data(self, idx: int) -> Dict:
        results = copy.deepcopy(self.data_infos[idx])
        imgs_data = self.load_image_from_file(results)
        data = {
            "img" : self.pipeline(imgs_data),
            "gt_label" : torch.tensor(self.data_infos[idx]['gt_label'])
        }
        return data

    def sample_shots_by_class_id(self, class_id: int,
                                 num_shots: int) -> List[int]:
        all_shot_ids = self.data_infos_class_dict[class_id]
        return np.random.choice(
            all_shot_ids, num_shots, replace=False).tolist()

    def __len__(self) -> int:
        return len(self.data_infos)

    def __getitem__(self, idx: int) -> Dict:
        return self.prepare_data(idx)

    @classmethod
    def get_classes(cls,
                    classes: Union[Sequence[str],
                                   str] = None) -> Sequence[str]:
        if isinstance(classes, str):
            class_names = tools.list_from_file(classes)
        elif isinstance(classes, (tuple, list)):
            class_names = classes
        else:
            raise ValueError(f'Unsupported type {type(classes)} of classes.')

        return class_names

5.构建通用的少样本数据集加载类UniversalFewShotDataset

该文件的作用主要是将数据从标签文件中拿出来,加载数据。
代码如下:

from datasets.base import BaseFewShotDataset
from typing_extensions import Literal
from typing import List, Optional, Sequence, Union
from util import tools
import os
import os.path as osp
import numpy as np
import torchvision.transforms as transforms
class UniversalFewShotDataset(BaseFewShotDataset):
    def __init__(self,
                 img_size,
                 subset: Literal['train', 'test', 'val'] = 'train',
                 *args,
                 **kwargs):
        if isinstance(subset, str):
            subset = [subset]
        for subset_ in subset:
            assert subset_ in ['train', 'test', 'val']
        self.subset = subset
        self.file_format = file_format
        # 归一化参数
        norm_params = {'mean': [0.485, 0.456, 0.406],
                       'std': [0.229, 0.224, 0.225]}
        # 对数据进行处理
        if subset[0] == 'train':
            pipeline = transforms.Compose([
                transforms.RandomResizedCrop(img_size),
                transforms.RandomHorizontalFlip(),
                transforms.ColorJitter(brightness=0.4,contrast=0.4,saturation=0.4),
                transforms.ToTensor(),
                transforms.Normalize(**norm_params)
        ])
        elif subset[0] == 'test':
            pipeline = transforms.Compose([
                transforms.Resize(size=int(img_size * 1.15)),
                transforms.CenterCrop(size=img_size),
                transforms.ToTensor(),
                transforms.Normalize(**norm_params)
            ])
        super().__init__(pipeline=pipeline, *args, **kwargs)

    def get_classes(
            self,
            classes: Optional[Union[Sequence[str], str]] = None) -> Sequence[str]:
        class_names = tools.list_from_file(classes)
        return class_names
    
	# 加载标签文件
    def load_annotations(self) -> List:
        data_infos = []
        ann_file = self.ann_file
        with open(ann_file) as f:
            for i, line in enumerate(f):
                class_name, filename = line.strip().split('/')
                gt_label = self.class_to_idx[class_name]
                info = {
                    'img_prefix':
                    osp.join(self.data_prefix, 'images', class_name),
                    'img_info': {
                        'filename': filename
                    },
                    'gt_label':
                    np.array(gt_label, dtype=np.int64)
                }
                data_infos.append(info)
        return data_infos

6.构建针对元学习的数据集加载类EpisodicDataset

代码如下:

import numpy as np
from torch import Tensor
from torch.utils.data import Dataset,DataLoader
from functools import partial
import os.path as osp
from typing import Mapping
from util import tools
import json
class EpisodicDataset:
    def __init__(self,
                 dataset: Dataset,
                 num_episodes: int,
                 num_ways: int,
                 num_shots: int,
                 num_queries: int,
                 episodes_seed: int):
        self.dataset = dataset
        self.num_ways = num_ways
        self.num_shots = num_shots
        self.num_queries = num_queries
        self.num_episodes = num_episodes
        self._len = len(self.dataset)
        self.CLASSES = dataset.CLASSES
        self.episodes_seed = episodes_seed
        self.episode_idxes, self.episode_class_ids = \
            self.generate_episodic_idxes()

    def generate_episodic_idxes(self):
        """Generate batch indices for each episodic."""
        episode_idxes, episode_class_ids = [], []
        class_ids = [i for i in range(len(self.CLASSES))]
        # 这一句可以不用
        with tools.local_numpy_seed(self.episodes_seed):
            for _ in range(self.num_episodes):
                np.random.shuffle(class_ids)
                # sample classes
                sampled_cls = class_ids[:self.num_ways]
                episode_class_ids.append(sampled_cls)
                episodic_support_idx = []
                episodic_query_idx = []
                # sample instances of each class
                for i in range(self.num_ways):
                    shots = self.dataset.sample_shots_by_class_id(
                        sampled_cls[i], self.num_shots + self.num_queries)
                    episodic_support_idx += shots[:self.num_shots]
                    episodic_query_idx += shots[self.num_shots:]
                episode_idxes.append({
                    'support': episodic_support_idx,
                    'query': episodic_query_idx
                })
        return episode_idxes, episode_class_ids

    def __getitem__(self, idx: int):
        support_data = [self.dataset[i] for i in self.episode_idxes[idx]['support']]
        query_data = [self.dataset[i] for i in self.episode_idxes[idx]['query']]
        return {
            'support_data':support_data,
            'query_data':query_data
        }

    def __len__(self):
        return self.num_episodes

    def evaluate(self, *args, **kwargs):
        return self.dataset.evaluate(*args, **kwargs)

    def get_episode_class_ids(self, idx: int):
        return self.episode_class_ids[idx]

7.构建自己的配置文件,如:json格式

配置文件除了json,也可以是其他形式的,这里以json格式为例:

{
    "train":{
        "num_episodes":2000,
        "num_ways":10,
        "num_shots":5,
        "num_queries":5,
        "episodes_seed":1001,
        "per_gpu_batch_size":1,
        "per_gpu_workers": 8,
        "epoches": 160,
        "dataset":{
            "name": "vireo_172",
            "img_size": 224,
            "data_prefix":"/home/gaoxingyu/dataset/vireo-172/",
            "classes":"/home/gaoxingyu/dataset/vireo-172/meta/fsl_train_class.txt",
            "ann": "/home/gaoxingyu/dataset/vireo-172/meta/fsl_train.txt"
        }
    }
}

8.编写主程序,进行测试

代码如下:

with open("config.json", 'r', encoding='utf-8') as f:
     f = f.read()
     configs = json.loads(f)
     logger.info(f"Experiment Setting:{configs}")
# 创建数据集
## train_dataset
train_food_dataset = UniversalFewShotDataset(data_prefix=configs['train']['dataset']['data_prefix'],
                         subset="train", classes=configs['train']['dataset']['classes'],
                         img_size=configs['train']['dataset']['img_size'],ann_file=configs['train']['dataset']['ann'])
train_dataset = EpisodicDataset(dataset=train_food_dataset,
                                num_episodes=configs['train']['num_episodes'],
                                num_ways=configs['train']['num_ways'],
                                num_shots=configs['train']['num_shots'],
                                num_queries=configs['train']['num_queries'],
                                episodes_seed=configs['train']['episodes_seed'])
## train dataloader
train_samper = torch.utils.data.distributed.DistributedSampler(train_dataset, rank = local_rank, shuffle=True)
train_data_loader = DataLoader(
    dataset=train_dataset,
    batch_size=configs['train']['per_gpu_batch_size'],
    sampler=train_samper,
    num_workers=configs['train']['per_gpu_workers'],
    collate_fn=partial(collate, samples_per_gpu=1),
    worker_init_fn=worker_init_fn,
    drop_last=True
)
for data in train_data_loader:
	print(data)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/811277.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Linux安装wget

1.第一步登录wget官网下载地址,下载最新的wget的rpm安装包到本地 官网地址:http://mirrors.163.com/centos/7/os/x86_64/Packages/ 2.将下载好的wget的rpm安装包通过Xftp工具上传到Linux服务器对应目录下。 3.cd命令进入到这个wget目录下,再…

【文献分享】动态环境下竟然能实现实时语义RGB-D SLAM??

论文题目:Towards Real-time Semantic RGB-D SLAM in Dynamic Environments 中文题目:动态环境下实时语义RGB-D SLAM研究 作者:Tete Ji, Chen Wang, and Lihua Xie 作者机构:新加坡南洋理工大学电气与电子工程学院 卡内基梅隆大…

用于WINDOWS的HACKRF ONE扫频分析仪

https://github.com/pavsa/hackrf-spectrum-analyzer GitHub - mutability/rtl-sdr: RTL-SDR *very* experimental branch - its probably broken! https://github.com/greatscottgadgets/hackrf hackrf_sweep 用于WINDOWS的HACKRF ONE扫频分析仪 几个星期前,Ha…

Java动态代理(全网最详细,没有之一)

首先你要明白为什么要创建代理??? 例如:我们看下面这张图我们发现,有很多重复的代码,我们就可以创建代理,让代理帮我们干这些事情。 1.想要创建代理,我们就要为这个类写一个接口 pu…

无涯教程-jQuery - Menu组件函数

小部件菜单功能可与JqueryUI中的小部件一起使用。一个简单的菜单显示项目列表。 Menu - 语法 $( "#menu" ).menu(); Menu - 示例 以下是显示菜单用法的简单示例- <!doctype html> <html lang"en"><head><meta charset"utf-…

基于Linux操作系统中的MySQL数据库备份(三十三)

目录 一、概述 二、数据备份的重要性 三、造成数据丢失的原因 1、程序错误 2、人为错误 3、运算失败 4、磁盘故障 5、灾难&#xff08;如火灾、地震&#xff09;和盗窃 四、备份类型 &#xff08;一&#xff09;物理与逻辑角度 1、物理备份 1.1、冷备份 1.2、热备…

人工智能-Dlib+Python实现人脸识别(人脸识别篇)

人脸识别流程 人脸检测,人脸数据提取:首先是检测到人脸保存人脸数据:可以保存到mysql数据库中mysql数据库连接mysql数据库安装mysql数据库操作设置人脸数据标签:(人脸名字),保存到数据库打开摄像头,检测到人脸,提取人脸数据:人脸数据与数据库中的数据对比,1、人脸检…

子组件未抛出事件 父组件如何通过$refs监听子组件中数据的变化

我们平时开发项目会使用一些比较成熟的组件库, 但是在极小的情况下,可能会出现我们需要监听某个属性的变化,使我们的页面根据这个属性发生一些改变,但是偏偏组件库没有把这个属性抛出来,当我们使用watch通过refs监听时,由于生命周期的原因还不能拿到,这时候我们可以这样做,以下…

03-高阶导数_导数判断单调性_导数与极值

高阶导数 前面学的是一阶导数&#xff0c;对导数再次求导就是高阶导数&#xff0c;二阶和二阶以上的导数统称为高阶导数。 导数与函数单调性的关系 极值定理 导数为我们寻找极值提供依据&#xff0c;对于可导函数而言&#xff0c;因为在极值位置必然有函数的导数等于 0。 …

深入篇【C++】手搓模拟实现list类(详细剖析底层实现原理)模拟实现正反向迭代器【容器适配器模式】

深入篇【C】手搓模拟实现list类(详细剖析底层实现原理&#xff09;&& 模拟实现正反向迭代器【容器适配器模式】 Ⅰ.迭代器实现1.一个模板参数2.两个模板参数3.三个模板参数 Ⅱ.反向迭代器实现1.容器适配器模式 Ⅲ.list模拟实现1.定义结点2.封装结点3.构造/拷贝4.迭代器…

【Python】Web学习笔记_flask(1)——模拟登录

安装flask pip3 install flask 第一部分内容&#xff1a; 1、主页面输出hello world 2、根据不同用户名参数输出用户信息 3、模拟登录 from flask import Flask,url_for,redirectappFlask(__name__)app.route(/) def index():return hello worldapp.route(/user/<uname…

linux_进程状态

目录 一. 概念铺设 状态是什么&#xff1f; 传统操作系统的状态转换图 二. 传统操作系统状态 1. 运行 2. 阻塞 3. 挂起 三. linux 中的进程状态 1. 总体介绍 2. R 3. S 4. D kill -9 D vs S 5. T kill T vs S 6. Z 什么是僵尸状态&#xff1f; 僵尸进程的危害 …

hadoop部署配置

端口名称 Hadoop2.x Hadoop3.x NameNode内部通信端口 8020 / 9000 8020 / 9000/9820 NameNode HTTP UI 50070 9870 MapReduce查看执行任务端口 8088 8088 历史服务器通信端口 19888 19888 端口名称Hadoop2.xHadoop3.xNameNode内部通信端口8020 / 90008020 / 9000/9820NameNode…

延长周末,获得高质量休息:工作与学习党的生活策略

&#x1f337;&#x1f341; 博主猫头虎 带您 Go to New World.✨&#x1f341; &#x1f984; 博客首页——猫头虎的博客&#x1f390; &#x1f433;《面试题大全专栏》 文章图文并茂&#x1f995;生动形象&#x1f996;简单易学&#xff01;欢迎大家来踩踩~&#x1f33a; &a…

【Linux】多线程的补充

1 线程安全的单例模式 1.1 什么是单例模式 单例模式是一种 "经典的, 常用的, 常考的" 设计模式. 1.2 什么是设计模式 IT行业这么火, 涌入的人很多. 俗话说林子大了啥鸟都有. 大佬和菜鸡们两极分化的越来越严重. 为了让菜鸡们不太拖大佬的后腿, 于是大佬们针对一些…

从源码角度配合网络编程函数API 分析下 三握手四挥手都做了什么

首先我们先说下网络编程API&#xff1a; 数据在网络上通信&#xff0c;通信的双方一个是 客户端&#xff0c; 一个是 服务器 更具体来说&#xff0c;不是 客户端和服务器这两个机器在 经由互联网 进行通信&#xff0c; 而是 客户端上的某一进程 与 服务器端的某一进程 进…

Vue2 第七节 Vue监测数据更新原理

&#xff08;1&#xff09;Vue会监视data中所有层次的数据 &#xff08;2&#xff09;如何监测对象中的数据 通过setter实现监视&#xff0c;且要在new Vue时传入要监测的数据对象中后追加的属性&#xff0c;Vue默认不做响应式处理如果要给后添加的属性做响应式&#xff0c;使…

【笔记】PyTorch DDP 与 Ring-AllReduce

转载请注明出处&#xff1a;小锋学长生活大爆炸[xfxuezhang.cn] 文内若有错误&#xff0c;欢迎指出&#xff01; 今天我想跟大家分享的是一篇虽然有点老&#xff0c;但是很经典的文章&#xff0c;这是一个在分布式训练中会用到的一项技术&#xff0c; 实际上叫ringallreduce。 …

Hyper-v 设置静态IP 搭建集群

背景 最近想在本机WIN11上创建几个Centos用于做几个试验&#xff0c;之前一直用VMWare&#xff0c;需要安装额外的软件&#xff0c;正好win自带虚拟机功能&#xff0c;只需要在功能中安装Hyper-v就可以使用。 新建虚拟机 虚拟机交换器 Hyper-V 虚拟交换机是基于软件的第 2 层…

P5691 [NOI2001] 方程的解数

题目 思路 暴搜显然会TLE&#xff0c;所以这时候就应该请出DFS的伙伴——折半搜索&#xff08;meet in the middle&#xff09;了 折半搜索的思路就是先搜完后一半后&#xff0c;借助这一半的数据来搜索前一半&#xff0c;效率是原来的2倍 这个题怎么才能折半搜索呢&#xff1…