[高光谱]PyTorch使用CNN对高光谱图像进行分类

news2025/1/10 12:56:22

项目原地址:

Hyperspectral-Classificationicon-default.png?t=N6B9https://github.com/eecn/Hyperspectral-ClassificationDataLoader讲解:

[高光谱]使用PyTorch的dataloader加载高光谱数据icon-default.png?t=N6B9https://blog.csdn.net/weixin_37878740/article/details/130929358

一、模型加载

        在原始项目中,提供了14种模型可供选择,从最简单的SVM到3D-CNN,这里以2D-CNN为例,在原项目中需要将model属性设置为:sharma。

         模型通过一个get_model(.)函数获得,该函数一共四个返回(model, optimizer, loss, hyperparams;分别为:模型,迭代器,损失函数,超参数),输入为模型类别。

        进入函数内部,找到对应的函数体如下:

elif name == 'sharma':
        kwargs.setdefault('batch_size', 60)        #batch_szie
        epoch = kwargs.setdefault('epoch', 30)     #迭代数
        lr = kwargs.setdefault('lr', 0.05)         #学习率
        center_pixel = True                        #是否开启中心像素模型
        # We assume patch_size = 64
        kwargs.setdefault('patch_size', 64)        #patch_szie,即图像块大小
        model = SharmaEtAl(n_bands, n_classes, patch_size=kwargs['patch_size'])  #模型本体
        optimizer = optim.SGD(model.parameters(), lr=lr, weight_decay=0.0005)    #迭代器
        criterion = nn.CrossEntropyLoss(weight=kwargs['weights'])            #交叉熵损失函数
        kwargs.setdefault('scheduler', optim.lr_scheduler.MultiStepLR(optimizer, milestones=[epoch // 2, (5 * epoch) // 6], gamma=0.1))

        这里设置了一部分超参数,同时设置了patch_size为64(此概念可以参见dataloader篇),采用的损失函数为常见的交叉熵损失函数,而模型本体则是使用SharmaEtAl(.)进行加载。

二、模型本体

        跳转至SharmaEtAl(nn.Module),其继承自nn.model,输入参数3个,分别为:输入通道数、分类数、图块尺寸。

def __init__(self, input_channels, n_classes, patch_size=64):

  该网络的结构如图,模型中里面包含3个卷积、2个bn、2个池化和2个全连接,如下:

# 卷积层1
self.conv1 = nn.Conv3d(1, 96, (input_channels, 6, 6), stride=(1,2,2))
self.conv1_bn = nn.BatchNorm3d(96)
self.pool1 = nn.MaxPool3d((1, 2, 2))
# 卷积层2
self.conv2 = nn.Conv3d(1, 256, (96, 3, 3), stride=(1,2,2))
self.conv2_bn = nn.BatchNorm3d(256)
self.pool2 = nn.MaxPool3d((1, 2, 2))
# 卷积层3
self.conv3 = nn.Conv3d(1, 512, (256, 3, 3), stride=(1,1,1))

# 展平函数
self.features_size = self._get_final_flattened_size()

# 由两个全连接组成的分类器
self.fc1 = nn.Linear(self.features_size, 1024)
self.dropout = nn.Dropout(p=0.5)
self.fc2 = nn.Linear(1024, n_classes)

        其中的展平函数_get_final_flattened_size(.),并不实际参与前向传递,仅计算转换后的通道数。

    def _get_final_flattened_size(self):
        with torch.no_grad():
            x = torch.zeros((1, 1, self.input_channels,
                             self.patch_size, self.patch_size))
            x = F.relu(self.conv1_bn(self.conv1(x)))
            x = self.pool1(x)
            print(x.size())
            b, t, c, w, h = x.size()
            x = x.view(b, 1, t*c, w, h) 
            x = F.relu(self.conv2_bn(self.conv2(x)))
            x = self.pool2(x)
            print(x.size())
            b, t, c, w, h = x.size()
            x = x.view(b, 1, t*c, w, h) 
            x = F.relu(self.conv3(x))
            print(x.size())
            _, t, c, w, h = x.size()
        return t * c * w * h

        实际的前向传递如下:

    def forward(self, x):
        # 卷积块1
        x = F.relu(self.conv1_bn(self.conv1(x)))
        x = self.pool1(x)
        # 获取tensor尺寸
        b, t, c, w, h = x.size()
        # 调整tensor尺寸
        x = x.view(b, 1, t*c, w, h) 
        # 卷积块2
        x = F.relu(self.conv2_bn(self.conv2(x)))
        x = self.pool2(x)
        # 获取tensor尺寸
        b, t, c, w, h = x.size()
        # 调整tensor尺寸
        x = x.view(b, 1, t*c, w, h) 
        # 卷积块3
        x = F.relu(self.conv3(x))
        # 调整tensor尺寸
        x = x.view(-1, self.features_size)
        # 分类器
        x = self.fc1(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x

三、训练与测试

        主函数中,训练和测试结构如下:

        try:
            train(model, optimizer, loss, train_loader, hyperparams['epoch'],
                  scheduler=hyperparams['scheduler'], device=hyperparams['device'],
                  supervision=hyperparams['supervision'], val_loader=val_loader,
                  display=viz)
        except KeyboardInterrupt:
            # Allow the user to stop the training
            pass

        probabilities = test(model, img, hyperparams)
        prediction = np.argmax(probabilities, axis=-1)

        训练被封装在train(.)函数中,测试封装在test(.)函数中,下面逐一来看。

        首先是train函数,这里省去外围部分,仅看核心的循环控制段。

# 外循环控制,用于控制轮次(epoch)
for e in tqdm(range(1, epoch + 1), desc="Training the network"):
        # 进入训练模式
        net.train()
        avg_loss = 0.

        # 从dataloader中取出图像(data)和标签(target)
        for batch_idx, (data, target) in tqdm(enumerate(data_loader), total=len(data_loader)):
            # 如果是GPU模式则需要转换为cuda格式
            data, target = data.to(device), target.to(device)

        #---实际的训练部分---#
            # 冻结梯度
            optimizer.zero_grad()
            # 训练模式(监督训练/半监督训练)
            if supervision == 'full':
                # 前向传递
                output = net(data)
                #target = target - 1
                # 交叉熵损失函数
                loss = criterion(output, target)
            elif supervision == 'semi':
                outs = net(data)
                output, rec = outs
                #target = target - 1
                loss = criterion[0](output, target) + net.aux_loss_weight * criterion[1](rec, data)
        #---实际的训练部分---#
            # 损失函数反向传递
            loss.backward()
            # 迭代器步进
            optimizer.step()

            # 记录损失函数
            avg_loss += loss.item()
            losses[iter_] = loss.item()
            mean_losses[iter_] = np.mean(losses[max(0, iter_ - 100):iter_ + 1])

            iter_ += 1
            del(data, target, loss, output)

        接下来是test函数,与train不同的是,其参数为:model, img, hyperparams。其中img,是一整张高光谱图像,而不是由DataSet块采样后的图像块。故其结构也与train大不相同。

        在进行测试的时候,需要一个滑动窗口(sliding_window)函数将其进行切块以满足图像输入的要求。同时还需要一个grouper函数将其组装为batch送入神经网络中。所以我们可以看到循环控制的最外层实际上就是上面两个函数来组成的。

    # 图像切块
    iterations = count_sliding_window(img, **kwargs) // batch_size
    for batch in tqdm(grouper(batch_size, sliding_window(img, **kwargs)),
                      total=(iterations),
                      desc="Inference on the image"
                      ):
        #  锁定梯度
        with torch.no_grad():
            #  逐像素模式
            if patch_size == 1:
                data = [b[0][0, 0] for b in batch]
                data = np.copy(data)
                data = torch.from_numpy(data)
            # 其他模式
            else:
                data = [b[0] for b in batch]
                data = np.copy(data)
                data = data.transpose(0, 3, 1, 2)
                data = torch.from_numpy(data)
                data = data.unsqueeze(1)

            indices = [b[1:] for b in batch]
            # 类型转换
            data = data.to(device)
            # 前向传递
            output = net(data)
            if isinstance(output, tuple):
                output = output[0]
            output = output.to('cpu')

            if patch_size == 1 or center_pixel:
                output = output.numpy()
            else:
                output = np.transpose(output.numpy(), (0, 2, 3, 1))
            for (x, y, w, h), out in zip(indices, output):

            # 将得到的像素平装回原尺寸
                if center_pixel:
                    probs[x + w // 2, y + h // 2] += out
                else:
                    probs[x:x + w, y:y + h] += out
    return probs

        这个函数会使用上述的两个函数,将图像切割成可以放入神经网络的尺寸并逐个进行前向传递,最后将得到的所有像素的结果按照原来的尺寸组成一个结果矩阵返回。

        最后,这个结果由一个argmax函数得到其概率最大的预测结果:

prediction = np.argmax(probabilities, axis=-1)

四、结果计算

        在完成上述步骤后,由metrics(.)函数计算最终的模型结果:

run_results = metrics(prediction, test_gt, ignored_labels=hyperparams['ignored_labels'], n_classes=N_CLASSES)

        其函数体如下:

def metrics(prediction, target, ignored_labels=[], n_classes=None):
    """Compute and print metrics (accuracy, confusion matrix and F1 scores).

    Args:
        prediction: list of predicted labels
        target: list of target labels
        ignored_labels (optional): list of labels to ignore, e.g. 0 for undef
        n_classes (optional): number of classes, max(target) by default
    Returns:
        accuracy, F1 score by class, confusion matrix
    """
    ignored_mask = np.zeros(target.shape[:2], dtype=np.bool)
    for l in ignored_labels:
        ignored_mask[target == l] = True
    ignored_mask = ~ignored_mask
    #target = target[ignored_mask] -1
    target = target[ignored_mask]
    prediction = prediction[ignored_mask]

    results = {}

    n_classes = np.max(target) + 1 if n_classes is None else n_classes

    cm = confusion_matrix(
        target,
        prediction,
        labels=range(n_classes))

    results["Confusion matrix"] = cm

    # Compute global accuracy
    total = np.sum(cm)
    accuracy = sum([cm[x][x] for x in range(len(cm))])
    accuracy *= 100 / float(total)

    results["Accuracy"] = accuracy

    # Compute F1 score
    F1scores = np.zeros(len(cm))
    for i in range(len(cm)):
        try:
            F1 = 2. * cm[i, i] / (np.sum(cm[i, :]) + np.sum(cm[:, i]))
        except ZeroDivisionError:
            F1 = 0.
        F1scores[i] = F1

    results["F1 scores"] = F1scores

    # Compute kappa coefficient
    pa = np.trace(cm) / float(total)
    pe = np.sum(np.sum(cm, axis=0) * np.sum(cm, axis=1)) / \
        float(total * total)
    kappa = (pa - pe) / (1 - pe)
    results["Kappa"] = kappa

    return results

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/878307.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

怎样才能免费使用Qt开发闭源商业软件?

Qt 是一个跨平台的应用程序开发框架,其使用遵循 GNU Lesser General Public License(LGPL)开源许可协议。根据 LGPL 许可协议,您可以将 Qt 用于闭源商业软件,但是您需要满足以下条件: 1. 在您的软件中使用…

Python实现SSA智能麻雀搜索算法优化卷积神经网络分类模型(CNN分类算法)项目实战

说明:这是一个机器学习实战项目(附带数据代码文档视频讲解),如需数据代码文档视频讲解可以直接到文章最后获取。 1.项目背景 麻雀搜索算法(Sparrow Search Algorithm, SSA)是一种新型的群智能优化算法,在2020年提出&a…

Synopsys EDA数字设计与仿真

参考如下文章安装Synopsys EDA开发工具 https://blog.csdn.net/tugouxp/article/details/132255002?csdn_share_tail%7B%22type%22%3A%22blog%22%2C%22rType%22%3A%22article%22%2C%22rId%22%3A%22132255002%22%2C%22source%22%3A%22tugouxp%22%7D Synopsys EDA工具的结构 下…

【Android Framework系列】第10章 PMS之Hook实现广播的调用

1 前言 前面章节我们学习了【Android Framework系列】第4章 PMS原理我们了解了PMS原理,【Android Framework系列】第9章 AMS之Hook实现登录页跳转我们知道AMS可以Hook拦截下来实现未注册Activity页面的跳转,本章节我们来尝试一下HookPMS实现广播的发送。…

React入门 jsx学习笔记

一、JSX介绍 概念:JSX是 JavaScript XML(HTML)的缩写,表示在 JS 代码中书写 HTML 结构 作用:在React中创建HTML结构(页面UI结构) 优势: 采用类似于HTML的语法,降低学…

chatserver服务器开发笔记

chatserver服务器开发笔记 1 chatserver2 开发环境3 编译 1 chatserver 集群聊天服务器和客户端代码,基于muduo、redis、mysql实现。 学习于https://fixbug.ke.qq.com/ 本人已经挂github:https://github.com/ZixinChen-S/chatserver/tree/main 需要该项…

ASR 语音识别接口封装和分析

这个文档主要是介绍一下我自己封装了 6 家厂商的短语音识别和实时流语音识别接口的一个包,以及对这些接口的一个对比。分别是,阿里,快商通,百度,腾讯,科大,字节。 zxmfke/asrfactory (github.c…

神秘的ip地址8.8.8.8,到底是什么类型的DNS服务器?

下午好,我的网工朋友。 DNS,咱们网工配置网络连接或者路由器时,高低得和这玩意儿打交道吧。 它是互联网中用于将人类可读的域名(例如http://www.example.com)转换为计算机可理解的IP地址(例如192.0.2.1&a…

机器学习中基本的数据结构说明

数据维度或数据结构 当我们在机器学习或深度学习的领域内处理数据,我们通常会遇到四种主要的数据结构:标量,向量,矩阵和张量。理解这些基本数据结构是非常重要的,因为它们是机器学习算法和神经网络的核心。下面是对这…

AssetBundle总结

文章目录 目的打包过程打包时的分组策略和压缩方式资源的载入和卸载其它:Manifest、校验、视图工具思维导图 前言: 大佬文章链接(据此总结的) 目的 避免软件因资源占用空间太大,导致运行缓慢 避免每次更新资源&#…

【STM32】高效开发工具CubeMonitor快速上手

工欲善其事必先利其器。拥有一个辅助测试工具,能极大提高开发项目的效率。STM32CubeMonitor系列工具能够实时读取和呈现其变量,从而在运行时帮助微调和诊断STM32应用,类似于一个简单的示波器。它是一款基于流程的图形化编程工具,类…

Michael.W基于Foundry精读Openzeppelin第26期——ERC1820Implementer.sol

Michael.W基于Foundry精读Openzeppelin第26期——ERC1820Implementer.sol 0. 版本0.1 ERC1820Implementer.sol 1. 目标合约2. 代码精读2.1 _registerInterfaceForAddress(bytes32 interfaceHash, address account) internal2.2 canImplementInterfaceForAddress(bytes32 interf…

Java的反射机制、Lambda表达式和枚举

Java的反射机制、Lambda表达式和枚举 文章目录 Java的反射机制、Lambda表达式和枚举1.反射机制反射的概念、用途、优缺点反射相关的类及使用(重要!!)相关类Class类:代表类实体,表示类和接口Field类&#xf…

URLSearchParams:JavaScript中的URL查询参数处理工具

文章目录 导言:一、URLSearchParams的来历二、URLSearchParams的作用三、URLSearchParams的方法和属性四、使用示例五、注意事项六、结论参考资料 导言: 在Web开发中,处理URL查询参数是一项常见的任务。为了简化这一过程,JavaScr…

论文阅读——Adversarial Eigen Attack on Black-Box Models

Adversarial Eigen Attack on Black-Box Models 作者:Linjun Zhou, Linjun Zhou 攻击类别:黑盒(基于梯度信息),白盒模型的预训练模型可获得,但训练数据和微调预训练模型的数据不可得&#xff…

龙迅LT86102UXE产品概括,HDMI2.0转一分二HDMI2.0/1.4,支持音频剥离,支持4K60HZ

LT86102UXE 1. 一般说明 龙迅 LT86102UXE HDMI2.0 分路器具有符合 HDMI2.0/1.4 规范的 1:2 分路器、最大 6Gbps 高速数据速率、自适应均衡 RX 输入和预加重 TX 输出(用于支持长电缆应用)、内部 TX 通道交换以实现灵活的 PCB 布线。 LT86102…

A Survey for In-context Learning

A Survey for In-context Learning 摘要: 随着大语言模型(LLMs)能力的增长,上下文学习(ICL)已经成为一个NLP新的范式,因为LLMs仅基于几个训练样本让内容本身增强。现在已经成为一个新的趋势去探索ICL来评价和extrapolate LLMs的能力。在这篇…

vue3+element-plus表格默认排序default-sort失效问题

场景 在使用动态数据渲染的场景&#xff0c;el-table设置默认属性default-sort失效。 原因 el-table的default-sort属性是针对静态数据的&#xff0c;如果是动态数据&#xff0c;default-sort则无法监听到。 案例&#xff1a;静态数据 <template><el-table:data&…

一文分析多少杠杆最高

在进行加杠杆操作时&#xff0c;合约产品通常会有一定的杠杆比例限制&#xff0c;这是由监管机构或交易平台设定的。杠杆比例限制的目的是为了控制风险&#xff0c;避免过度杠杆化导致的潜在损失过大。 不同的交易平台和合约产品可以有不同的杠杆比例限制。一般来说&#xff0…

堆叠注入进阶--(buuctf-随便注、GYCTF-black_list)【多方法详解】

了解一下 堆叠注入基础知识及其他题目&#xff1a; SQL-堆叠注入 终于有时间来填填坑了 Buuctf-随便注 算是堆叠注入中非常经典的题目了。 随便试试就能看到黑名单&#xff1a; 没了select&#xff0c;其实大概率就是堆叠注入 先探测一下&#xff1a; 1;show databases;…