[高光谱]使用PyTorch的dataloader加载高光谱数据

news2024/12/24 11:28:43

本文实验的部分代码参考

Hyperspectral-Classificationicon-default.png?t=N4P3https://github.com/eecn/Hyperspectral-Classification如果对dataloader的工作原理不太清楚可以参见

[Pytorch]DataSet和DataLoader逐句详解icon-default.png?t=N4P3https://blog.csdn.net/weixin_37878740/article/details/129350390?spm=1001.2014.3001.5501

一、原理解析

        常见的高光谱数据维.mat格式,由数据文件gt(ground-truth)文件组成,图像数据和标签数据。这里以印度松数据为例,图像数据的尺寸为145*145*200,标签数据的尺寸为145*145*1。

         本文的实验代码主要思想如下:

                ①获取高光谱数据集gt标签集

                ②按一定比例将数据集切割为训练集、测试集、验证集

                ③将训练集和验证集装入dataloader

二、获取高光谱数据

#  解析高光谱数据
def get_dataset(target_folder,dataset_name):
    palette = None
    
    #  拼接文件路径
    folder = target_folder + '/' + dataset_name
    
    #  打开数据文件
    if dataset_name == 'IndianPines':
        img = open_file(folder + '/Indian_pines_corrected.mat')
        img = img['indian_pines_corrected'] #选择矩阵
        
        rgb_bands = (43, 21, 11)  # AVIRIS sensor
        gt = open_file(folder + '/Indian_pines_gt.mat')['indian_pines_gt']
        #  设置标签
        label_values = ["Undefined", "Alfalfa", "Corn-notill", "Corn-mintill",
                        "Corn", "Grass-pasture", "Grass-trees",
                        "Grass-pasture-mowed", "Hay-windrowed", "Oats",
                        "Soybean-notill", "Soybean-mintill", "Soybean-clean",
                        "Wheat", "Woods", "Buildings-Grass-Trees-Drives",
                        "Stone-Steel-Towers"]
        ignored_labels = [0]
    
    #  设置背景标签
    nan_mask = np.isnan(img.sum(axis=-1))
    img[nan_mask] = 0
    gt[nan_mask] = 0
    ignored_labels.append(0)
    
    #  数据格式转换
    ignored_labels = list(set(ignored_labels))
    img = np.asarray(img, dtype='float32')
    data = img.reshape(np.prod(img.shape[:2]), np.prod(img.shape[2:]))
    data  = preprocessing.minmax_scale(data)
    img = data.reshape(img.shape)
    return img, gt, label_values, ignored_labels, rgb_bands, palette

        这里仅适配了印度松,有其他数据集需求的可以自行修改内部的参数。

        该函数会从.mat文件中获取图像文件和gt文件,并将相关信息打包返回,其中,读取文件的函数为:open_file(.)

#  打开高光谱文件
def open_file(dataset):
    _, ext = os.path.splitext(dataset)
    ext = ext.lower()
    # 根据格式不同打开文件
    if ext == '.mat':
        return io.loadmat(dataset)
    elif ext == '.tif' or ext == '.tiff':
        return imageio.imread(dataset)
    elif ext == '.hdr':
        img = spectral.open_image(dataset)
        return img.load()
    else:
        raise ValueError("Unknown file format: {}".format(ext))

        在主函数中调用如下:

DataSetName = 'IndianPines'
target_folder = 'Dataset'

img, gt, LABEL_VALUES, IGNORED_LABELS, RGB_BANDS, 
            palette = get_dataset(target_folder,DataSetName)

二、DataSet类

        在使用DataSet类加载数据集前,我们需要将数据集进行随机划分,这里直接调用了原项目的sample_gt(.)函数对gt进行分割。

def sample_gt(gt, train_size, mode='random'):
    indices = np.nonzero(gt)
    X = list(zip(*indices)) # x,y features
    y = gt[indices].ravel() # classes
    train_gt = np.zeros_like(gt)
    test_gt = np.zeros_like(gt)
    if train_size > 1:
       train_size = int(train_size)
    
    if mode == 'random':
       train_indices, test_indices = sklearn.model_selection.train_test_split(X, train_size=train_size, stratify=y)
       train_indices = [list(t) for t in zip(*train_indices)]
       test_indices = [list(t) for t in zip(*test_indices)]
       train_gt[tuple(train_indices)] = gt[tuple(train_indices)]
       test_gt[tuple(test_indices)] = gt[tuple(test_indices)]
    elif mode == 'fixed':
       print("Sampling {} with train size = {}".format(mode, train_size))
       train_indices, test_indices = [], []
       for c in np.unique(gt):
           if c == 0:
              continue
           indices = np.nonzero(gt == c)
           X = list(zip(*indices)) # x,y features

           train, test = sklearn.model_selection.train_test_split(X, train_size=train_size)
           train_indices += train
           test_indices += test
       train_indices = [list(t) for t in zip(*train_indices)]
       test_indices = [list(t) for t in zip(*test_indices)]
       train_gt[train_indices] = gt[train_indices]
       test_gt[test_indices] = gt[test_indices]

    elif mode == 'disjoint':
        train_gt = np.copy(gt)
        test_gt = np.copy(gt)
        for c in np.unique(gt):
            mask = gt == c
            for x in range(gt.shape[0]):
                first_half_count = np.count_nonzero(mask[:x, :])
                second_half_count = np.count_nonzero(mask[x:, :])
                try:
                    ratio = first_half_count / second_half_count
                    if ratio > 0.9 * train_size and ratio < 1.1 * train_size:
                        break
                except ZeroDivisionError:
                    continue
            mask[:x, :] = 0
            train_gt[mask] = 0

        test_gt[train_gt > 0] = 0
    else:
        raise ValueError("{} sampling is not implemented yet.".format(mode))
    return train_gt, test_gt

        主函数调用如下:

#--训练集占比
SAMPLE_PERCENTAGE = 0.1

#--数据集划分
train_gt, test_gt = sample_gt(gt,SAMPLE_PERCENTAGE,mode='random')
train_gt, val_gt = sample_gt(train_gt, 0.95, mode='random')

        随后将划分好的数据集放入DataSet类中,DataSet类共计9个参数,分别代表:

data-高光谱数据集;
gt-标签集;
patch_size-邻居个数(即感受野,影响提取的每个块大小);
ignored_labels - 需要忽略的类别;
flip_augmentation - 是否使用随机折叠;
radiation_augmentation - 是否使用随机噪声;
mixture_augmentation - 是否对光谱进行随机混合
center_pixel - 设置为True以仅考虑中心像素的标签
supervision - 训练模式,可选'full'-全监督 或 'semi'-半监督

        DataSet如下:

#  高光谱dataset类
class HyperX(torch.utils.data.Dataset):
    
    def __init__(self,data,gt,patch_size,ignored_labels,flip_augmentation,radiation_augmentation,mixture_augmentation,center_pixel,supervision):
        super().__init__()
        self.data = data
        self.label = gt
        self.patch_size = patch_size
        self.ignored_labels = ignored_labels
        self.flip_augmentation = flip_augmentation
        self.radiation_augmentation = radiation_augmentation
        self.mixture_augmentation = mixture_augmentation
        self.center_pixel = center_pixel
        supervision = supervision
        
        # 监督模式
        if supervision == 'full':
            mask = np.ones_like(gt)
            for l in self.ignored_labels:
                mask[gt == l] = 0
        #  半监督模式
        elif supervision == 'semi':
            mask = np.ones_like(gt)
        
        x_pos, y_pos = np.nonzero(mask)
        p = self.patch_size // 2
        self.indices = np.array([(x,y) for x,y in zip(x_pos, y_pos) if x > p-1 and x < data.shape[0] - p and y > p-1 and y < data.shape[1] - p])
        self.labels = [self.label[x,y] for x,y in self.indices]
        np.random.shuffle(self.indices)
        
    @staticmethod   #静态方法
    def flip(*arrays):
        horizontal = np.random.random() > 0.5
        vertical = np.random.random() > 0.5
        if horizontal:
            arrays = [np.fliplr(arr) for arr in arrays]
        if vertical:
            arrays = [np.flipud(arr) for arr in arrays]
        return arrays
    
    @staticmethod
    def radiation_noise(data, alpha_range=(0.9, 1.1), beta=1/25):
        alpha = np.random.uniform(*alpha_range)
        noise = np.random.normal(loc=0., scale=1.0, size=data.shape)
        return alpha * data + beta * noise

    def mixture_noise(self, data, label, beta=1/25):
        alpha1, alpha2 = np.random.uniform(0.01, 1., size=2)
        noise = np.random.normal(loc=0., scale=1.0, size=data.shape)
        data2 = np.zeros_like(data)
        for  idx, value in np.ndenumerate(label):
            if value not in self.ignored_labels:
                l_indices = np.nonzero(self.labels == value)[0]
                l_indice = np.random.choice(l_indices)
                assert(self.labels[l_indice] == value)
                x, y = self.indices[l_indice]
                data2[idx] = self.data[x,y]
        return (alpha1 * data + alpha2 * data2) / (alpha1 + alpha2) + beta * noise
    
    #  获得长度数据
    def __len__(self):
        return len(self.indices)
    
    #  获得元素
    def __getitem__(self, i):
        x,y = self.indices[i]
        x1,y1 = x-self.patch_size // 2, y-self.patch_size // 2
        x2,y2 = x1+self.patch_size, y1+self.patch_size
        
        data = self.data[x1:x2,y1:y2]
        label = self.label[x1:x2,y1:y2]
        
        #  选择数据增强模式
        if self.flip_augmentation and self.patch_size > 1:  #
            data, label = self.flip(data, label)
        if self.radiation_augmentation and np.random.random() < 0.1:
                data = self.radiation_noise(data)
        if self.mixture_augmentation and np.random.random() < 0.2:
                data = self.mixture_noise(data, label)
        
        #  mat->np->tensor
        data = np.asarray(np.copy(data).transpose((2, 0, 1)), dtype='float32')
        label = np.asarray(np.copy(label), dtype='int64')

        data = torch.from_numpy(data)
        label = torch.from_numpy(label)
        
        #  提取中心标签
        if self.center_pixel and self.patch_size > 1:
            label = label[self.patch_size // 2, self.patch_size // 2]
        
        #  使用不可见光谱时删除未使用部分
        elif self.patch_size == 1:
            data = data[:, 0, 0]
            label = label[0, 0]
        
        #  进行3D卷积时增加一维
        if self.patch_size > 1:
            data = data.unsqueeze(0)
            
        return data,label

        dataset_collate:

def HyperX_collate(batch):
    datas = []
    labels = []
    for data, label in batch:
        datas.append(data)
        labels.append(label)
    datas = np.array(datas)
    labels = np.array(labels)
    return datas, labels

        在主函数中调用如下:

#  调用dataset
train_dataset = HyperX(img, train_gt,patch_size,IGNORED_LABELS,True,True,True,True,'full')
val_dataset = HyperX(img, val_gt,patch_size,IGNORED_LABELS,True,True,True,True,'full')

#  调用dataloader
train_loader = DataLoader(train_dataset,batch_size=batch_size,pin_memory=True,shuffle=True)
val_loader = DataLoader(val_dataset,batch_size=batch_size,pin_memory=True,shuffle=True)

三、数据展示

#  可视化展示
for item in train_dataset:
    img,label = item
    img = torch.squeeze(img,0)  #除去第0维度
    img = img.permute(1,2,0)    #调整通道位置
    print('tensor尺寸:{}'.format(img.shape))
    img = img.numpy()           #转换为numpy
    view1 = spy.imshow(data=img, bands=RGB_BANDS, title="train")  # 图像显示
    print('标签编号:{}'.format(label.numpy()))

        邻居个数patch_size设置为9,运行后得到如下结果:

                 

四、模拟训练

    print("模拟训练")
    for epoch in range(3): 
        step = 0  
        for data in train_loader:
            imgs, labels = data
            print(imgs.shape)
            print(labels.shape)
            img = imgs[0]
            img = torch.squeeze(img,0).permute(1,2,0).numpy()  #通道调整和numpy转换
            view1 = spy.imshow(data=img, bands=RGB_BANDS, title="train")  # 图像显示
        step=step+1
    input("按任意键继续")

         测试结果如下:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/585177.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

使用Nextcloud搭建私人云盘,并内网穿透实现公网远程访问

文章目录 摘要视频教程1. 环境搭建2. 测试局域网访问3. 内网穿透3.1 ubuntu本地安装cpolar3.2 创建隧道3.3 测试公网访问 4 配置固定http公网地址4.1 保留一个二级子域名4.1 配置固定二级子域名4.3 测试访问公网固定二级子域名 转载自cpolar极点云的文章&#xff1a;使用Nextcl…

好程序员:如果你从6月份开始学Java编程......

现在学习Java编程还来得及&#xff1f;好程序员可以明确的告诉你&#xff1a;当然了~新手入门快&#xff0c;非常容易学。Java计算机语言也是一门面向对象的语言&#xff0c;更加符合人类的思想&#xff0c;所求皆对象&#xff0c;并没有指针等一些难理解的知识。Java覆盖面宽、…

2023最新软件测试面试题大全(包含答案)

前言 在我认为&#xff0c;对于测试面试以及进阶的最佳学习方法莫过于刷题博客书籍视频总结&#xff0c;前几者博主将淋漓尽致地挥毫于这篇博客文章中&#xff0c;至于总结在于个人&#xff0c;实际上越到后面你会发现面试并不难&#xff0c;其次就是在刷题的过程中有没有去思…

Power BI许可证差异(免费、Pro、PPU、Embedded、Premium)

不可否认&#xff0c;在商业BI软件中Power BI是最强大的&#xff0c;在2023年的Gartner的魔力象限中Power BI又是第一名Microsoft named a Leader in the 2023 Gartner Magic Quadrant™ for Analytics and BI PlatformsI[1] image.png 目前还没有使用Power BI的&#xff0c;甚…

Microsoft Build 发布,开发者可能关注的重点→

又是一年一度的 Microsoft Build 了&#xff0c;你有和我一样熬夜看了吗&#xff1f;如果没有&#xff0c;那么你就错过了一场精彩的技术盛宴。本次的 Microsoft Build&#xff0c;有非常多的干货&#xff0c;围绕打造 Copilot 应用展开。我会将基于 Data AI 比较重要的内容列…

2.1. 类与对象

在 Java 中&#xff0c;类和对象是面向对象编程的基本构建块。类是一种模板&#xff0c;用于定义对象的属性和行为。对象是类的实例&#xff0c;具有类定义的属性和行为。 2.1.1. 类的定义 要定义一个类&#xff0c;可以使用以下语法&#xff1a; class ClassName {// 成员变…

Java学习路线(17)——日志框架

一、日志技术概述 &#xff08;1&#xff09;概念&#xff1a; 日志是一种将系统运行信息封装至文件的一种记录载体。 &#xff08;2&#xff09;优势&#xff1a; 输出语句日志技术输出位置只能是控制台文件或数据库取消日志需要修改代码达成无需修改代码多线程性能较差性能较…

AcrelEMS企业微电网能效管理系统-强化电力需求侧管理,缓解电力系统峰值压力

摘要 近年来全国用电负荷特别是居民用电负荷的快速增长&#xff0c;全国范围内夏季、冬季用电负荷“双峰”特征日益突出&#xff0c;极端气候现象多发增加了电力安全供应的压力。具有随机性、波动性、间歇性特征的可再生能源大规模接入电网对电力系统的稳定性带来新的挑战&…

财务共享服务中心建设流程是什么样的?

财务共享是当今众多企业在数智化转型道路上的首选模式&#xff0c;财务共享服务中心由于具备“标准化、流程化、资源共享、信息化”的特点&#xff0c;一改传统财务分散的运作模式&#xff0c;将资源集中共享&#xff0c;大大提升了财务管理效率&#xff0c;也为企业管理打下良…

Loki安装使用方式

Distributor 收到 HTTP 请求&#xff0c;用于存储流数据 通过 hash 环对数据流进行 hash Distributor将数据流发送到对应的Ingester及其副本上 Ingester 新建 Chunk 或将数据追加到已有Chunk 上 Distributor通过 HTTP连接发送响应信息 Loki 日志系统由以下3个部分组成&#xf…

每日一题——删除字符串中的所有相邻重复项

每日一题 删除字符串中的所有相邻重复项 题目链接 思路 这是一道用栈解决的典型题目 我们先来看看栈的基本性质&#xff1a; 栈&#xff1a;是一种特殊的线性表&#xff0c;其只允许在固定的一端进行插入和删除元素的操作。进行数据插入和删除操作的一端称为栈顶&#xff0c…

解决Github上传或者下载时失败的问题

总是出现push不到GitHub的问题, 这里来记录一下每次的解决方法 文章目录 2023年05月28日出现问题2023年05月28日再次出现问题2023年05月29日出现问题 2023年05月28日出现问题 push代码时出现如下图所示的错误 Failed to connect to 127.0.0.1 port 1080 after 2052 ms: Conne…

08.Stack和Queue

栈&#xff1a;先进后出 队列&#xff1a;先进先出 JVM的栈就是平常所说的一块内存。 此处所说的栈是数据结构 1. 栈(Stack) 1.1 概念 栈&#xff1a;一种特殊的线性表&#xff0c;其只允许在固定的一端进行插入和删除元素操作。进行数据插入和删除操作的一端称为栈 顶&…

迪赛智慧数——柱状图(基本柱状图):全国美食类门店TOP10地区

效果图 中国“最会吃”的省份&#xff0c;广东榜上有名&#xff0c;看看有你的家乡吗? 广东美食门店开店数量保持较好状态&#xff0c;重庆、上海、北京、天津开店率在5%以上。广东拥有众多的美食文化&#xff0c;美食门店数量也是全国最多的省份有99万家美食门店&#xff0c…

SpringBoot SSE服务端主动推送事件详解

一、SSE概述 1、SSE简介 SSE(Server Sent Event)&#xff0c;直译为服务器发送事件&#xff0c;也就是服务器主动发送事件&#xff0c;客户端可以获取到服务器发送的事件。 我们常见的 http 交互方式是客户端发起请求&#xff0c;服务端响应&#xff0c;然后一次请求完毕。但是…

Centos 7安装python 3.9.10

概述 Python是一种高级编程语言&#xff0c;它具有简单易学、可读性强、代码简洁等特点。Python由Guido van Rossum于1991年创造&#xff0c;最初被用作一种教学语言&#xff0c;但现在已经成为一种通用的编程语言。 Python支持多种编程范式&#xff0c;包括面向对象编程、函数…

如何运用R语言进行Meta分析在【文献计量分析、贝叶斯、机器学习等】多技术的融合

Meta分析是针对某一科研问题&#xff0c;根据明确的搜索策略、选择筛选文献标准、采用严格的评价方法&#xff0c;对来源不同的研究成果进行收集、合并及定量统计分析的方法&#xff0c;最早出现于“循证医学”&#xff0c;现已广泛应用于农林生态&#xff0c;资源环境等方面。…

【MySQL】MySql的底层数据结构

文章目录 前言索引结构及查找算法不适合做MySql的数据结构及其原因 一、BTree和BTree的引出1.1 BTree数据结构2.2 BTree数据结构 二、计算m阶&#xff0c;即BTree该取多少合适总结 前言 索引结构及查找算法 一个sql语句在mysql里究竟是如何运行的呢&#xff1f;又是怎么去查找…

如何在Linux系统中使用SCP命令传输文件和文件夹?

在Linux系统中&#xff0c;SCP&#xff08;Secure Copy&#xff09;是一种用于在本地和远程主机之间安全传输文件和文件夹的命令行工具。它基于SSH协议&#xff0c;并提供了加密和身份验证机制&#xff0c;确保数据的安全性和完整性。 本文将详细介绍如何使用SCP命令在Linux系统…

如何通过pytest进行更改自动化测试用例的执行顺序?

前言 在自动化测试中&#xff0c;自动化测试用例设计原则就是执行过程时不能存在依赖顺序&#xff0c;那么如果测试用例需要按照指定顺序执行&#xff0c;这个时候应该怎么做呢&#xff1f;目前单元测试框架中unittest没有办法改变测试用例的执行顺序&#xff0c;但是另一个单…