SSF-CNN:空间光谱融合的卷积光谱图像超分网络

news2024/11/29 22:54:32

SSF-CNN: SPATIAL AND SPECTRAL FUSION WITH CNN FOR HYPERSPECTRAL IMAGE SUPER-RESOLUTION

文章目录

  • SSF-CNN: SPATIAL AND SPECTRAL FUSION WITH CNN FOR HYPERSPECTRAL IMAGE SUPER-RESOLUTION
    • 简介
    • 解决问题
    • 网络框架
    • 代码实现
    • 训练部分
    • 运行结果

简介

​ 本文提出了一种利用空间和光谱进行高光谱融合图像超分辨率的新型CNN架构,首先是对高光谱图像进行双三次插值,使其空间分辨率大小和多光谱一致,然后进行concat操作。使用类似于SRCNN的网络框架对融合超分的图像进行优化,最后输出高分辨率高光谱超分图像。

​ 对于PDCon,也就是引入了部分密集连接,将输入concat到每一个卷积层后面。
Hyperspectral-Image-Super-Resolution-Benchmark——光谱图像超分基准-CSDN博客
Paper: IEEE
Code:https://github.com/miraclefan777/SSFCNN

2023-11-25_16-06-09

解决问题

  1. 传统方法通过基于优化的方法恢复 HR-HS 图像的质量在很大程度上取决于预定义的约束。此外,由于约束项数量较多,优化过程通常涉及较高的计算成本。
  2. 执行HSI SR的一个直接想法是直接应用这样的网络来放大LR-HS图像的空间维度或HR-RGB图像的光谱维度,我们称之为Spatial-CNN和Spectral-CNN,这两种单图像方法忽略了两种图像特有的信息互补优势。

网络框架

  1. 原始的SRCNN是将图片映射到Ycbcr空间,并只使用其中的 Y 分量作为输入来预测 HR Y 图像,该论文则是将图片的通道信息以及空间信息整个进行输入
  2. 原始SRCNN卷积核大小第1,2修改为3*3,增加上下文信息,同时为了避免高维数据(padding为same,保持和原有特征图大小一致)

代码实现

class SSFCNNnet(nn.Module):
    def __init__(self, num_spectral=31, scale_factor=8, pdconv=False):
        super(SSFCNNnet, self).__init__()
        self.scale_factor = scale_factor
        self.pdconv = pdconv

        self.Upsample = nn.Upsample(mode='bicubic', scale_factor=self.scale_factor)

        self.conv1 = nn.Conv2d(num_spectral + 3, 64, kernel_size=3, padding="same")
        if pdconv:
            self.conv2 = nn.Conv2d(64 + 3, 32, kernel_size=3, padding="same")
            self.conv3 = nn.Conv2d(32 + 3, num_spectral, kernel_size=5, padding="same")
        else:
            self.conv2 = nn.Conv2d(64, 32, kernel_size=3, padding="same")
            self.conv3 = nn.Conv2d(32, num_spectral, kernel_size=5, padding="same")

        self.relu = nn.ReLU(inplace=True)

    def forward(self, lr_hs, hr_ms):
        """
            :param lr_hs:LR-HSI低分辨率的高光谱图像
            :param hr_ms:高分辨率的多光谱图像
            :return:
        """
        # 对LR-HSI低分辨率图像进行上采样,让其分辨率更高
        lr_hs_up = self.Upsample(lr_hs)
        # 将上采样后的LR-HSI低分辨率图像与高分辨率的多光谱图像进行拼接
        x = torch.cat((lr_hs_up, hr_ms), dim=1)

        x = self.relu(self.conv1(x))
        if self.pdconv:
            x = torch.cat((x, hr_ms), dim=1)
            x = self.relu(self.conv2(x))
            x = torch.cat((x, hr_ms), dim=1)
        else:
            x = self.relu(self.conv2(x))

        out = self.conv3(x)
        return out

如果需要使用密集连接,只需要在初始化网络模型时,传参pdconv=True

训练部分

未提供自定义dataset类,根据自己的dateset进行参数的修改即可。

import argparse
from calculate_metrics import Loss_SAM, Loss_RMSE, Loss_PSNR
from models.SSFCNNnet import SSFCNNnet
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm
from train_dataloader import CAVEHSIDATAprocess
from utils import create_F, fspecial,AverageMeter
import os
import copy
import torch
import torch.nn as nn

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default="SSFCNNnet")
    parser.add_argument('--train-file', type=str, required=True)
    parser.add_argument('--eval-file', type=str, required=True)
    parser.add_argument('--outputs-dir', type=str, required=True)
    parser.add_argument('--scale', type=int, default=2)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--batch-size', type=int, default=32)
    parser.add_argument('--num-workers', type=int, default=0)
    parser.add_argument('--num-epochs', type=int, default=400)
    parser.add_argument('--seed', type=int, default=123)
    args = parser.parse_args()

    assert args.model in ['SSFCNNnet', 'PDcon_SSF']

    outputs_dir = os.path.join(args.outputs_dir, '{}'.format(args.model))
    if not os.path.exists(outputs_dir):
        os.makedirs(outputs_dir)

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    torch.manual_seed(args.seed)

    # 训练参数
    # loss_func = nn.L1Loss(reduction='mean').cuda()
    criterion = nn.MSELoss()


    #################数据集处理#################
    R = create_F()
    PSF = fspecial('gaussian', 8, 3)
    downsample_factor = 8
    training_size = 64
    stride = 32
    stride1 = 32

    train_dataset = CAVEHSIDATAprocess(args.train_file, R, training_size, stride, downsample_factor, PSF, 20)
    train_dataloader = DataLoader(dataset=train_dataset, batch_size=args.batch_size, shuffle=True)

    eval_dataset = CAVEHSIDATAprocess(args.eval_file, R, training_size, stride, downsample_factor, PSF, 12)
    eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=1)

    #################数据集处理#################

    # 模型
    if args.model == 'SSFCNNnet':
        model = SSFCNNnet().cuda()
    else:
        model = SSFCNNnet(pdconv=True).cuda()

    best_weights = copy.deepcopy(model.state_dict())
    best_epoch = 0
    best_psnr = 0.0

    # 模型初始化
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.xavier_uniform_(m.weight)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    optimizer = torch.optim.Adam([
        {'params': model.conv1.parameters()},
        {'params': model.conv2.parameters()},
        {'params': model.conv3.parameters(), 'lr': args.lr * 0.1}
    ], lr=args.lr)

    start_epoch = 0


    for epoch in range(start_epoch, args.num_epochs):
        model.train()
        epoch_losses = AverageMeter()

        with tqdm(total=(len(train_dataset) - len(train_dataset) % args.batch_size)) as t:
            t.set_description('epoch:{}/{}'.format(epoch, args.num_epochs - 1))

            for data in train_dataloader:
                label, lr_hs, hr_ms = data

                label = label.to(device)
                lr_hs = lr_hs.to(device)
                hr_ms = hr_ms.to(device)
                lr = optimizer.param_groups[0]['lr']
                pred = model(hr_ms, lr_hs)
                loss = criterion(pred, label)

                epoch_losses.update(loss.item(), len(label))

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                t.set_postfix(loss='{:.6f}'.format(epoch_losses.avg), lr='{0:1.8f}'.format(lr))
                t.update(len(label))

        # torch.save(model.state_dict(), os.path.join(outputs_dir, 'epoch_{}.pth'.format(epoch)))

        if epoch % 5 == 0:
            model.eval()
            val_loss = AverageMeter()

            SAM = Loss_SAM()
            RMSE = Loss_RMSE()
            PSNR = Loss_PSNR()

            sam = AverageMeter()
            rmse = AverageMeter()
            psnr = AverageMeter()

            for data in eval_dataloader:
                label, lr_hs, hr_ms = data
                lr_hs = lr_hs.to(device)
                hr_ms = hr_ms.to(device)
                label = label.cpu().numpy()

                with torch.no_grad():
                    preds = model(hr_ms, lr_hs).cpu().numpy()

                sam.update(SAM(preds, label), len(label))
                rmse.update(RMSE(preds, label), len(label))
                psnr.update(PSNR(preds, label), len(label))

            if psnr.avg > best_psnr:
                best_epoch = epoch
                best_psnr = psnr.avg
                best_weights = copy.deepcopy(model.state_dict())

            print('eval psnr: {:.2f}  RMSE: {:.2f}  SAM: {:.2f} '.format(psnr.avg, rmse.avg, sam.avg))

运行结果

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1253720.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于STC12C5A60S2系列1T 8051单片读写掉电保存数据IIC总线器件24C02多字节并显示在液晶显示器LCD1602上应用

基于STC12C5A60S2系列1T 8051单片多字节读写掉电保存数据IIC总线器件24C02多字节并显示在液晶显示器LCD1602上应用 STC12C5A60S2系列1T 8051单片机管脚图STC12C5A60S2系列1T 8051单片机I/O口各种不同工作模式及配置STC12C5A60S2系列1T 8051单片机I/O口各种不同工作模式介绍IIC通…

pytorch分布式训练

1 基本概念 rank:进程号,在多进程上下文中,我们通常假定rank 0是第一个进程或者主进程,其它进程分别具有1,2,3不同rank号,这样总共具有4个进程 node:物理节点,可以是一个…

Java游戏之王者荣耀

首先创建类: 游戏运行结果如下: GameFrame类 所需图片: GameObject类 Turret类 所需图片: TurretBlue类 TurretRed类 Champion类 所需图片: 单个: move包: ChampionDaji类 所需图片: Minio…

css加载会造成阻塞吗??

前言 前几天面试问到了这个问题,当时这个答得不敢确定哈哈,虽然一面还是过了 现在再分析下这个,总结下,等下次遇到就能自信得回答,666 准备工作 为了完成本次测试,先来科普一下,如何利用chr…

输出后,我悟了!

大家好,我是木川 今天和前同事吃饭聊天,谈到了输出,今天简单谈下关于输出的重要性 一、为什么要输出 1、不输出容易忘,如果不输出很容易就忘记了,如果再遇见一次,还是需要重新学习,实际上是浪费…

【如何学习Python自动化测试】—— Python 的 unittest 框架

10 、Python 的 unittest 框架 10.1 Unittest 框架介绍 Unittest是Python语言中的一种测试框架,是Python标准库中的一个模块。它可以帮助开发者编写自动化测试,可以进行单元测试、集成测试、功能测试等各种类型的测试。 Unittest的特点是简单易学&#…

用 Addon 增强 Node.js 和 Electron 应用的原生能力

前言 Node.js Addon 是 Node.js 中为 JavaScript 环境提供 C/C 交互能力的机制。其形态十分类似 Java 的 JNI,都是通过提供一套 C/C SDK,用于在 C/C 中创建函数方法、进行数据转换,以便 JavaScript / Java 等语言进行调用。这样编写的代码通常…

突破技术障碍:软件工程师如何应对项目中的难题?

在软件开发项目中,工程师常常会遇到各种技术难题。这些难题可能涉及到复杂的算法、不兼容的系统、难以预见的软件行为,或者其他许多方面。 以下是一些策略和方法,可以帮助软件工程师有效地应对这些挑战: 1、理解问题:…

MySQL数据库——存储函数(介绍、案例)

目录 介绍 案例 介绍 存储函数是有返回值的存储过程,存储函数的参数只能是IN类型的。具体语法如下: CREATE FUNCTION 存储函数名称 ([ 参数列表 ]) RETURNS type [characteristic ...] BEGIN-- SQL语句RETURN ...;END ; characteristic说明&#xf…

Matplotlib子图的创建_Python数据分析与可视化

Matplotlib子图的创建 plt.axes创建子图fig.add_axes()创建子图 plt.axes创建子图 前面已经介绍过plt.axes函数,这个函数默认配置是创建一个标准的坐标轴,填满整张图。 它还有一个可选的参数,由图形坐标系统的四个值构成。这四个值表示为坐…

05 _ 系统设计目标(三):如何让系统易于扩展?

从架构设计上来说,高可扩展性是一个设计的指标,它表示可以通过增加机器的方式来线性提高系统的处理能力,从而承担更高的流量和并发。 你可能会问:“在架构设计之初,为什么不预先考虑好使用多少台机器,支持…

C语言之指针知识点总结

C语言之指针知识点总结 文章目录 C语言之指针知识点总结1. 初识指针1.1 取地址操作符 &1.2 指针变量1.3 解引用操作符 *1.4 指针变量1.4.1 大小1.4.2 指针类型的意义 1.5 void*指针1.6 const关键字1.61 const修饰变量1.6.2 const修饰指针变量 1.7 指针的运算1.7.1 指针-整数…

堆和栈的区别 重点来说一下堆和栈;堆与栈之间的联系

文章目录 堆和栈的区别重点来说一下堆和栈:那么堆和栈是怎么联系起来的呢? 堆与栈的区别 很明显: 今天来聊一聊java中的堆和栈,工作当中这两个也是经常遇到的,知识我们没有去注意理论上的这些内容,今天就来分享一下。…

MIPI 打怪升级之DSI篇

MIPI 打怪升级之DSI篇 目录 1 Overview2 DSI Mode 2.1 Video 模式2.2 Command 模式3 DSI Physical Layer 3.1 数据流控3.2 双向性3.3 Video Mode Interfaces3.4 Command Mode Interfaces3.5 Clock4 多通道管理 4.1 通道数匹配4.2 线上数据分布5 DSI 协议 5.1 包格式 5.1.1 短包…

Scrapy爬虫异步框架(一篇文章齐全)

1、Scrapy框架初识 2、Scrapy框架持久化存储(点击前往查阅) 3、Scrapy框架内置管道(点击前往查阅) 4、Scrapy框架中间件(点击前往查阅) Scrapy 是一个开源的、基于Python的爬虫框架,它提供了…

CleanMyMac X4.14.5Crack最新Mac电脑清理优化最佳应用

CleanMyMac X 4.14.5是用于清理和优化Mac的最佳应用程序和强大工具。它看起来很棒而且很容易理解。该软件可以清理、保护、优化、稳定和维护您的 Mac 系统。您可以立即删除不必要的、不寻常的、无用的垃圾文件、损坏的文件垃圾,并释放大量内存空间。此外&#xff0c…

【Unity实战】切换场景加载进度和如何在后台异步加载具有庞大世界的游戏场景,实现无缝衔接(附项目源码)

文章目录 最终效果前言一、绘制不同的场景二、切换场景加载进度1. 简单实现2. 优化 三、角色移动和跳跃控制四、添加虚拟摄像机五、触发器动态加载场景六、最终效果参考源码完结 最终效果 前言 观看本文后,我的希望你对unity场景管理有更好的理解,并且能…

Error PostCSS plugin autoprefixer requires PostCSS 8

文章目录 一、情况一二、情况二三、总结 在启动 vue项目时,突然控制台报错: Error: PostCSS plugin autoprefixer requires PostCSS 8。然后依次出现下面几种情况,依次解决完,项目就可以正常启动了 一、情况一 error in ./src/…

04 _ 系统设计目标(二):系统怎样做到高可用?

这里将探讨高并发系统设计的第二个目标——高可用性。 高可用性(High Availability,HA)是你在系统设计时经常会听到的一个名词,它指的是系统具备较高的无故障运行的能力。 我们在很多开源组件的文档中看到的HA方案就是提升组件可…

蓝桥杯第2119题 特殊时间 C++ 思维暴力

题目 思路和解题方法 1110 代表 1110年11月10号11点10分1110 4*4*4 有0111 1011 1101 1110 可以符合年 月日 时分秒的都有4种例如 1113有1113 1131 1311 3111 年份符合月日只有11 13 时分秒 只有11 13 11 31 13 11 无31 11 c 代码 #include <bits/stdc.h> using…