基于图卷积神经网络(GCN)的高光谱图像分类详细教程(含python代码)

news2024/11/17 13:39:59

目录

一、背景

二、基于卷积神经网络的代码实现

1、安装依赖库

2、建立图卷积神经网络

3、建立数据的边

4、训练模型

5、可视化

三、项目代码


一、背景

图卷积神经网络(Graph Convolutional Networks, GCNs)在高光谱图像分类中是一种有效的方法,特别适用于处理具有复杂空间关系的数据。高光谱图像通常包含数百个甚至数千个连续的频谱波段,每个波段对应一个光谱特征,这使得传统的卷积神经网络在处理高光谱图像时面临困难,因为它们无法有效地捕获像素之间的空间关系。

GCNs通过利用图结构来解决这一问题,将像素(或者像素附近的区域)视为图中的节点,并利用这些节点之间的关系进行特征学习和分类。以下是GCNs在高光谱图像分类中的一些关键点和优势:

  1. 图结构建模:将高光谱图像中的像素视为图中的节点,像素之间的空间关系(例如邻近关系)作为图的边,这样就能够在整个图上利用节点的局部和全局信息。

  2. 卷积操作:GCN引入了图卷积操作,允许在图结构上进行类似于传统卷积神经网络中的卷积操作。这种操作可以捕获节点及其邻居的特征,并利用这些信息来提取更有意义的特征表示。

  3. 特征学习:通过多层的图卷积操作,GCNs能够逐步学习出更加抽象和高级的特征表示,这对于高光谱数据的复杂特征提取尤为重要。

  4. 分类器:最后一层通常是一个分类器,用于将学习到的特征映射到类别标签空间,从而进行分类。

  5. 适应性:GCNs在处理高光谱图像时具有很强的适应性和灵活性,能够处理不同大小和分辨率的图像,以及不同数量和配置的频谱波段。

总体来说,图卷积神经网络通过充分利用高光谱图像中像素之间的空间关系,有效地提升了分类性能,并在遥感图像分析和其他高维数据的处理中展现出了广阔的应用前景。

二、基于卷积神经网络的代码实现

下面我们以IP数据集为例子进行展开讲解。

1、安装依赖库
matplotlib==3.3.4
networkx==2.1
numpy==1.19.5
pandas==1.1.5
scikit_learn==1.5.1
scipy==1.5.4
seaborn==0.11.2
spectral==0.22.4
torch==1.7.1+cu110
torch_geometric==2.0.2
tqdm==4.62.3
2、建立图卷积神经网络
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
import torch
import torch.nn as nn

class GCN(torch.nn.Module):
    def __init__(self, num_node_features, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(num_node_features, 32)
        self.conv1_bn_relu = nn.Sequential(
            nn.BatchNorm1d(32),
            nn.ReLU()
        )

        self.conv2 = GCNConv(32, 64)
        self.conv2_bn_relu = nn.Sequential(
            nn.BatchNorm1d(64),
            nn.ReLU()
        )

        self.cls = nn.Sequential(
            nn.Linear(64, num_classes),
        )

    def forward(self, edge, data):
        x, edge_index = data, edge
        x = self.conv1_bn_relu(self.conv1(x, edge_index))
        x = self.conv2_bn_relu(self.conv2(x, edge_index))

        return self.cls(x)
3、建立数据的边

首先进行PCA数据降维:

X_pca = applyPCA(X, numComponents=pca_components)

然后将无标签数据进行剔除:

    X_pca = X_pca.reshape(-1,pca_components)
    y = y.ravel()
    mask = y == 0

    # 剔除无标签的数据
    data = X_pca[~mask]
    label = y[~mask]

划分训练验证集(训练70%):

X_train, X_test, y_train, y_test = splitTrainTestSet(range(len(data)),label,trainRatio=0.7)

最后建立所有样本的边(这里取最近邻的样本为3):

Edge_build(data,k=3)

4、训练模型

加载数据和模型:

    X_train_index,X_test_index = utils.create_train_test('./data/'+patch_+'/train_index.txt',
                                                          './data/'+patch_+'/test_index.txt')
    data,label = utils.create_features('./data/'+patch_+'/data.txt',
                                        './data/'+patch_+'/label.txt')
    edge = pd.read_csv('./data/'+patch_+'/edge.txt', sep=" ", header=None).values.T

    # 建立模型
    model = GCN(30, 16)

训练模型:

class Trainer():
    def __init__(self, data,y,edge,X_train_index,X_test_index, model, optimizer, loss_function, epochs):
        self.y = y
        self.edge = torch.from_numpy(edge).type(torch.LongTensor).to(device)
        self.X_train_index = X_train_index
        self.X_test_index = X_test_index
        self.data = torch.from_numpy(data).type(torch.FloatTensor).to(device)
        self.model = model.to(device)
        self.optimizer = optimizer
        self.loss_function = loss_function
        self.epochs = epochs

        self.y_train = torch.from_numpy(y[X_train_index]).type(torch.LongTensor).to(device)
        self.y_test = torch.from_numpy(y[X_test_index]).type(torch.LongTensor).to(device)

        self.preds = None
    def train(self):
        pass

    def test(self):
        self.model.eval()
        pass

    trainer = Trainer(
        data=data,
        y=label,
        edge=edge,
        X_train_index=X_train_index,
        X_test_index=X_test_index,
        model=model,
        optimizer=optim.Adam(model.parameters(), lr=0.001),
        loss_function=nn.CrossEntropyLoss(),
        epochs=1000
    )

    trainer.train()
    trainer.test()

5、可视化
if __name__ == '__main__':
    patch_ = "IP"

    graph, A = utils.create_Graphs_with_attributes_adjadjency_matrix('./data/' + patch_ + '/edge.txt',
                                                                     './data/' + patch_ + '/data.txt')
    data, label = utils.create_features('./data/' + patch_ + '/data.txt',
                                        './data/' + patch_ + '/label.txt')
    edge = pd.read_csv('./data/' + patch_ + '/edge.txt', sep=" ", header=None).values.T

    model = GCN(30, 16)
    model.eval()
    net_params = torch.load("./weight/model.pkl")
    model.load_state_dict(net_params)  # 加载模型可学习参数

    trainer = Trainer(
        data=data,
        y=label,
        edge=edge,
        model=model,
    )

    pred = trainer.pre() + 1

    y_ = sio.loadmat('./data/Indian_pines_gt.mat')['indian_pines_gt']
    a, b = y_.shape
    print('Label shape: ', y_.shape)

    y = y_.ravel()
    mask = y == 0

    outputs = np.zeros_like(y)
    outputs[~mask] = pred

    outputs = outputs.reshape((a, b))

    import spectral
    import matplotlib.pyplot as plt

    predict_image = spectral.imshow(classes=outputs.astype(int), figsize=(5, 5))
    plt.savefig('./results/pre.png', dpi=300)
    plt.pause(1)

三、项目代码

本项目的代码通过以下链接下载:基于图卷积神经网络(GCN)的高光谱图像分类详细教程(含python代码)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1955981.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Unity + Hybridclr + Addressable + 微信小程序 热更新报错

报错时机: Generate All 怎么All 死活就是报错 生成微信小程序,并启动后 报错内容: MissingMethodException:AoT generic method notinstantiated in aot.assembly:Unity.ResourceManager:dll, 原因: Hybridclr 开发文档 解…

【人工智能】深度剖析:Midjourney与Stable Diffusion的全面对比

文章目录 🍊1 如何选择合适的AI绘画工具1.1 个人需求选择1.2 比较工具特点1.3 社区和资源 🍊2 Midjourney VS Stable Diffusion:深度对比与剖析 2.1 使用费用对比 2.2 使用便捷性与系统兼容性对比 2.3 开源与闭源对比 2.4 图片质量对比 2.5 上…

MATLAB基础应用精讲-【数模应用】Poisson 回归分析(附R语言代码实现)

目录 前言 知识储备 基于泊松回归、负二项回归模型 数据分布介绍 模型介绍 模型的选择 案例介绍 算法原理 泊松回归 数学模型 适用条件 参数估计与假设检验 SPSSAU Poisson 回归案例 1、背景 2、理论 3、操作 4、SPSSAU输出结果 5、文字分析 6、剖析 疑难解…

【探索Linux】P.42(传输层 —— TCP面向字节流 | TCP粘包问题 | TCP异常情况 )

阅读导航 引言一、TCP面向字节流二、TCP粘包问题1. 粘包原因2. 粘包类型3. 粘包的影响4. 解决粘包的方法5. 对于UDP协议来说, 是否也存在 "粘包问题" 呢? 三、TCP异常情况温馨提示 引言 继上篇深入剖析TCP协议的拥塞控制、延迟应答和捎带应答之后,本文将…

TCP 协议的 time_wait 超时时间

优质博文:IT-BLOG-CN 灵感来源 Time_Wait 产生的时机 TCP四次挥手的流程 如上所知:客户端在收到服务端第三次FIN挥手后,就会进入TIME_WAIT状态,开启时长为2MSL的定时器。 【1】MSL是Maximum Segment Lifetime报文最大生存时间…

【六】集群管理工具

1. 群控命令 查看java程序的运行状态是最常用的指令。首先在ubuntu1输入该find命令,查找jps位置,需要首先完成java jdk的安装和配置。 find / -name jps回显如下,jps的位置确定了。rootubuntu1:/usr/local/bin# find / -name jps /usr/loca…

C语言 | Leetcode C语言题解之第300题最长递增子序列

题目&#xff1a; 题解&#xff1a; int lengthOfLIS(int* nums, int numsSize) {if(numsSize<1)return numsSize;int dp[numsSize],result1;for(int i0;i<numsSize;i){dp[i]1;}for(int i0;i<numsSize;i){printf("%d ",dp[i]);}for(int i1;i<numsSize;i…

科普文:万字详解Kafka基本原理和应用

一、Kafka 简介 1. 消息引擎系统ABC Apache Kafka是一款开源的消息引擎系统&#xff0c;也是一个分布式流处理平台。除此之外&#xff0c;Kafka还能够被用作分布式存储系统&#xff08;极少&#xff09;。 A. 常见的两种消息引擎系统传输协议&#xff08;即用什么方式把消息…

git 、shell脚本

git 文件版本控制 安装git yum -y install git 创建仓库 将文件提交到暂存 git add . #将暂存区域的文件提交仓库 git commit -m "说明" #推送到远程仓库 git push #获取远程仓库的更新 git pull #克隆远程仓库 git clone #分支&#xff0c;提高代码的灵活性 #检查分…

模板-树上点差分

题目链接&#xff1a;松鼠的新家 图解&#xff1a; 模板&#xff1a; #include <bits/stdc.h> #define int long long using namespace std; const int inf 0x3f3f3f3f3f3f3f3f; const int N 3e55; int n; vector<int>g[N]; int d[N],fa[N][35],dep[N]; int a[…

Java | Leetcode Java题解之第301题删除无效的括号

题目&#xff1a; 题解&#xff1a; class Solution {public List<String> removeInvalidParentheses(String s) {int lremove 0;int rremove 0;List<Integer> left new ArrayList<Integer>();List<Integer> right new ArrayList<Integer>(…

DS1302时钟芯片全解析——概况,性能,MCU连接,样例代码

DS1302概述&#xff1a; 数据&#xff1a; DS1302是一个可充电实时时钟芯片&#xff0c;包含时钟&#xff08;24小时格式或12小时格式&#xff09;、日历&#xff08;年&#xff0c;月&#xff0c;日&#xff0c;星期&#xff09;、31字节RAM&#xff08;断电数据丢失&#x…

【Test】 Qt 多元素控件

文章目录 1. Qt 中的多元素控件2. QListWidget 1. Qt 中的多元素控件 xxWidget 和 xxView之间的区别 2. QListWidget 小案例&#xff1a;实现下图

WSL快速入门

1. WSL介绍 WSL文档地址&#xff1a;https://learn.microsoft.com/zh-cn/windows/wsl WSL&#xff1a;全称 Windows Subsystem for Linux&#xff0c;即windows上的Linux子系统&#xff08;虚拟机工具&#xff09;。是Win10推出的全新特性&#xff0c;可以更轻量地在Windows系统…

R语言统计分析——整合和重构

参考资料&#xff1a;R语言实战【第2版】 R中提供了许多用来整合&#xff08;aggregate&#xff09;和重塑&#xff08;reshape&#xff09;数据的强大方法。在整合数据时&#xff0c;往往将多组观测替换为根据这些观测计算的描述性统计量。在重塑数据时&#xff0c;则会通过修…

【Unity插件】Editor Console Pro:提升开发效率的神器

在 Unity 开发过程中&#xff0c;控制台&#xff08;Console&#xff09;是我们排查错误、获取信息的重要窗口。而 Editor Console Pro 则是 Unity 编辑器控制台的强大替代品&#xff0c;为 Unity 的控制台带来了更多实用的功能和改进&#xff0c;极大地提升了开发效率。 一、…

[硬件]-电路噪声

电路噪声 1.电路噪声来源 本征噪声&#xff1a;晶体管、电阻&#xff1b;外部噪声&#xff1a;电源、参考、偏置、衬底、串扰&#xff1b; 将电路的输入短接&#xff0c;理想情况下输出为0&#xff0c;但实际输出不为0&#xff0c;即为电路噪声。 2.噪声大小衡量方法 2.1 时…

【Oracle 进阶之路】Oracle 简介

一、简述 Oracle Database&#xff0c;又名Oracle RDBMS&#xff0c;或简称Oracle。是甲骨文公司的一款关系数据库管理系统。它是在数据库领域一直处于领先地位的产品。可以说Oracle数据库系统是世界上流行的关系数据库管理系统&#xff0c;系统可移植性好、使用方便、功能强&…

初学Mybatis之多对一查询 association 和一对多查询 collection

XML 映射器 多对一&#xff1a;关联&#xff08;association&#xff09; 一对多&#xff1a;集合&#xff08;collection&#xff09; mysql 创建教师、学生表&#xff0c;插入数据 create table teacher(id int(10) primary key,name varchar(30) default null ) engineI…

Meta 发布地表最大、最强大模型 Llama 3.1

最近这一两周看到不少互联网公司都已经开始秋招提前批了。不同以往的是&#xff0c;当前职场环境已不再是那个双向奔赴时代了。求职者在变多&#xff0c;HC 在变少&#xff0c;岗位要求还更高了。 最近&#xff0c;我们又陆续整理了很多大厂的面试题&#xff0c;帮助一些球友解…