SENet实现遥感影像场景分类

news2024/11/16 15:51:21

今天我们分享SENet实现遥感影像场景分类。

数据集

本次实验我们使用的是NWPU-RESISC45 Dataset。NWPU Dataset 是一个遥感影像数据集,其中 NWPU-RESISC45 Dataset 是由西北工业大学创建的遥感图像场景分类可用基准,该数据集包含像素大小为 256*256 共计 31500 张图像,涵盖 45 个场景类别,其中每个类别有 700 张图像。

这 45 个场景类别包括飞机、机场、棒球场、篮球场、海滩、桥梁、丛林、教堂、圆形农田、云、商业区、密集住宅、沙漠、森林、高速公路、高尔夫球场、地面田径、港口、工业地区、交叉口、岛、湖、草地、中型住宅、移动房屋公园、山、立交桥、宫、停车场、铁路、火车站、矩形农田、河、环形交通枢纽、跑道、海、船舶、雪山、稀疏住宅、体育场、储水箱、网球场、露台、火力发电站和湿地。

数据集划分

首先我们可以对数据集进行划分,按训练集、验证集、测试集比例7:1.5:1.5进行划分。

import os
import shutil
import random

# 设置数据集根目录
data_root = './datasets/NWPU-RESISC45'  

# 设置训练集、验证集、测试集的目录
train_dir = './datasets/train'
val_dir = './datasets/val'
test_dir = './datasets/test'

# 创建目录
os.makedirs(train_dir, exist_ok=True)
os.makedirs(val_dir, exist_ok=True)
os.makedirs(test_dir, exist_ok=True)

# 获取所有子文件夹列表
class_folders = sorted(os.listdir(data_root))

# 定义训练集、验证集、测试集比例
train_ratio = 0.7
val_ratio = 0.15
test_ratio = 0.15

for class_folder in class_folders:
    class_path = os.path.join(data_root, class_folder)
    images = os.listdir(class_path)
    random.shuffle(images)  # 随机打乱顺序

    num_images = len(images)
    num_train = int(num_images * train_ratio)
    num_val = int(num_images * val_ratio)

    train_images = images[:num_train]
    val_images = images[num_train:num_train + num_val]
    test_images = images[num_train + num_val:]

    # 移动图像到对应目录
    for img in train_images:
        src = os.path.join(class_path, img)
        dest = os.path.join(train_dir, class_folder, img)
        os.makedirs(os.path.dirname(dest), exist_ok=True)
        shutil.copy(src, dest)

    for img in val_images:
        src = os.path.join(class_path, img)
        dest = os.path.join(val_dir, class_folder, img)
        os.makedirs(os.path.dirname(dest), exist_ok=True)
        shutil.copy(src, dest)

    for img in test_images:
        src = os.path.join(class_path, img)
        dest = os.path.join(test_dir, class_folder, img)
        os.makedirs(os.path.dirname(dest), exist_ok=True)
        shutil.copy(src, dest)

划分完毕后,数据集分别保存在train、val、test三个文件夹内。每个文件夹内有21个子文件夹分别对应21类。

SENet

SeNet(Squeeze-and-Excitation Networks)是一种卷积神经网络(CNN)架构,由Jie Hu、Li Shen和Gang Sun于2017年提出。SeNet旨在通过引入注意力机制来增强模型对重要特征的学习能力,从而提高CNN在图像分类等计算机视觉任务上的性能。 SeNet的关键创新在于引入了“Squeeze-and-Excitation”模块,这个模块可以在不增加网络复杂度的情况下,自适应地学习特征通道之间的相关性,并对每个通道进行加权,以增强重要特征的表示。它由两个关键步骤组成: Squeeze(压缩)阶段:通过全局池化操作(通常是全局平均池化),将特征图的每个通道的信息进行汇总,生成通道级别的描述信息。 Excitation(激发)阶段:在Squeeze阶段生成的描述信息基础上,引入了多层感知机(MLP)结构来学习每个通道的权重。这些权重用于重新加权特征图,以增强有助于任务的重要特征并抑制不重要的特征。 SeNet模块可以轻松地集成到各种CNN架构中,例如ResNet、Inception等,通过在这些网络中插入SeNet模块,可以提高模型的性能,使其更具有泛化能力。 SeNet的提出在图像分类、目标检测和语义分割等计算机视觉任务中取得了显著的性能提升,并成为了当时领域内的重要技术之一。 alt

import torch.nn as nn
from torch.nn import functional as F


class Residual(nn.Module):
    def __init__(self, in_channel, out_channel, use_1x1Conv=False, strides=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1, stride=strides)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channel)

        if use_1x1Conv:
            self.conv3 = nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=strides)
        else:
            self.conv3 = None

    def forward(self, X):
        out = F.relu(self.bn1(self.conv1(X)))
        out = self.bn2(self.conv2(out))
        if self.conv3:
            X = self.conv3(X)
        out += X
        return F.relu(out)


def residualBlock(in_channel, out_channel, num_residuals, first_block=False):
    blks = []
    for i in range(num_residuals):
        if i == 0 and not first_block:
            blks.append(Residual(in_channel, out_channel, use_1x1Conv=True,
                                 strides=2))
        else:
            blks.append(Residual(out_channel, out_channel))

    return blks

class SEBlock(nn.Module):
    def __init__(self, C, r=16):
        super().__init__()
        self.squeeze = nn.AdaptiveAvgPool2d(1)
        self.excitation = nn.Sequential(
            nn.Linear(C, C//r, bias=False),
            nn.ReLU(),
            nn.Linear(C//r, C, bias=False),
            nn.Sigmoid())

    def forward(self, x):
        bs, c, _, _ = x.shape
        s = self.squeeze(x).view(bs, c)
        e = self.excitation(s).view(bs, c, 1, 1)
        return x * e.expand_as(x)

class SENet(nn.Module):
    def __init__(self, input_channel, n_classes):
        super().__init__()
        self.b1 = nn.Sequential(
            nn.Conv2d(input_channel, 64, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(64),nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
        self.b2 = nn.Sequential(*[SEBlock(C=64)])
        self.b3 = nn.Sequential(*residualBlock(64, 64, 2, first_block=True))
        self.b4 = nn.Sequential(*[SEBlock(C=64)])
        self.b5 = nn.Sequential(*residualBlock(64, 128, 2))
        self.b6 = nn.Sequential(*[SEBlock(C=128)])
        self.b7 = nn.Sequential(*residualBlock(128, 256, 2))
        self.b8 = nn.Sequential(*[SEBlock(C=256)])
        self.b9 = nn.Sequential(*residualBlock(256, 512, 2))
        self.b10 = nn.Sequential(*[SEBlock(C=512)])
        self.finalLayer = nn.Sequential(
            nn.AdaptiveAvgPool2d((1,1)),
            nn.Flatten(),
            nn.Linear(512, n_classes))

        self.b1.apply(self.init_weights)
        self.b2.apply(self.init_weights)
        self.b3.apply(self.init_weights)
        self.b4.apply(self.init_weights)
        self.b5.apply(self.init_weights)
        self.b6.apply(self.init_weights)
        self.b7.apply(self.init_weights)
        self.b8.apply(self.init_weights)
        self.b9.apply(self.init_weights)
        self.b10.apply(self.init_weights)
        self.finalLayer.apply(self.init_weights)

    def init_weights(self, layer):
        if type(layer) == nn.Conv2d:
            nn.init.kaiming_normal_(layer.weight, mode='fan_out')
        if type(layer) == nn.Linear:
            nn.init.normal_(layer.weight, std=1e-3)
        if type(layer) == nn.BatchNorm2d:
            nn.init.constant_(layer.weight, 1)
            nn.init.constant_(layer.bias, 0)
    

    def forward(self, X):
        out = self.b1(X)
        out = self.b2(out)
        out = self.b3(out)
        out = self.b4(out)
        out = self.b5(out)
        out = self.b6(out)
        out = self.b7(out)
        out = self.b8(out)
        out = self.b9(out)
        out = self.finalLayer(out)

        return out

训练过程

alt

精度与测试

「精度」

import torch
import torchvision.transforms as transforms
from torchvision import datasets
from models.SENet import SENet

# 定义测试集目录
test_dir = './datasets/test'

# 加载测试集数据
transform = transforms.Compose([
    transforms.Resize((256, 256)),  # 图像调整为模型输入大小
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

test_data = datasets.ImageFolder(root=test_dir, transform=transform)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=32, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 加载模型
model = SENet(input_channel=3, n_classes=45).to(device)
model.load_state_dict(torch.load(f'SENet.pt', map_location='cuda:0'))

model.eval()

# 对测试集进行验证
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100*correct / total
print(f"Accuracy on test set: {accuracy}")
alt

「测试」 这里我们从测试集中选取几张图片并在我们的GUI界面中进行测试看看

1

总结

感兴趣的可以按文末方式,免费获取数据集、完整代码与训练结果

获取方法

如有需要,请关注微信公众号「DataAssassin」后,后台回复「027」领取。

更多更多内容与代码请加入我们的星球! alt 加入前不要忘了领取优惠券哦! alt

本文由 mdnice 多平台发布

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1372330.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

CentOS本地部署SQL Server数据库无公网ip环境实现远程访问

文章目录 前言1.安装GeoServer2. windows 安装 cpolar3. 创建公网访问地址4. 公网访问Geo Servcer服务5. 固定公网HTTP地址 前言 GeoServer是OGC Web服务器规范的J2EE实现,利用GeoServer可以方便地发布地图数据,允许用户对要素数据进行更新、删除、插入…

23种设计模式精讲,配套23道编程题目 ,支持 C++、Java、Python、Go

关于设计模式的学习,大家应该还是看书或者看博客,但却没有一个边学边练的学习环境。 学完了一种设计模式 是不是应该去练一练? 所以卡码网 针对 23种设计,推出了 23道编程题目,来帮助大家练习设计模式,地…

oracle基本用户管理和权限分配

1.用户与模式的关系,一一对应的关系 2.创建与管理用户 2.1创建用户语法 CREATE user wdf IDENTIFIED by tiger--创建用户wdf,密码tiger DEFAULT tablespace users--用户的默认表空间 quota 10M on users;--在表空间的占用最大空间 注意:用户创建以后…

C练习——N个水手分椰子

题目: 五个水手在岛上发现一堆椰子,先由第1个水手把椰子分为等量的5堆,还剩下1个给了猴子,自己藏起1堆。然后,第2个水手把剩下的4堆混合后重新分为等量的5堆,还剩下1个给了猴子,自己藏起1堆。以…

2023-12-30 买卖股票的最佳时机 II和跳跃游戏以及跳跃游戏 II

122. 买卖股票的最佳时机 II 思路:关键点是每一次利用峰值来计算【画图好理解一点,就是计算陡坡的值】!每一次累加和的最大! 或者可以这样理解,把利润划分为每天的,如假如第 0 天买入,第 3 天卖出&#xf…

Mybatis之逆向工程

目录 一、逆向工程概述 二、逆向工程的实现 1、创建一个java工程 2、 第一步:mapper生成配置文件: 3、第二步:使用java类生成mapper文件: 4、第三步:拷贝生成的mapper文件到工程中指定的目录中 5、第四步Mapper接…

控制el-table的列显示隐藏

控制el-table的列显示隐藏,一般的话可以通过循环来实现,但是假如业务及页面比较复杂的话,list数组循环并不好用。 在我们的页面中el-table-column是固定的,因为现在是对现有的进行维护和迭代更新。 对需要控制列显示隐藏的页面进…

C语言可变参数输入

本博文源于笔者正在学习的可变参数输入&#xff0c;可变参数是c语言函数中的一部分&#xff0c;下面本文就以一个很小的demo演示可变参数的编写 问题来源 想要用可变参数进行多个整数相加 方法源码 #include<stdio.h> #include<stdlib.h> #include<stdarg.h…

Git 实战指南:常用指令精要手册(持续更新)

&#x1f451;专栏内容&#xff1a;Git⛪个人主页&#xff1a;子夜的星的主页&#x1f495;座右铭&#xff1a;前路未远&#xff0c;步履不停 目录 一、Git 安装过程1、Windows 下安装2、Cent os 下安装3、Ubuntu 下安装 二、配置本地仓库1、 初始化 Git 仓库2、配置 name 和 e…

Java微服务系列之 ShardingSphere - ShardingSphere-JDBC

&#x1f339;作者主页&#xff1a;青花锁 &#x1f339;简介&#xff1a;Java领域优质创作者&#x1f3c6;、Java微服务架构公号作者&#x1f604; &#x1f339;简历模板、学习资料、面试题库、技术互助 &#x1f339;文末获取联系方式 &#x1f4dd; 系列专栏目录 [Java项…

通过两台linux主机配置ssh实现互相免密登入

一 1.使用Xshell远程连接工工具生成公钥文件 2.生产密钥参数 3.生成公钥对 4.用户密钥信息 5.公钥注册 二 1.关闭服务端防火墙 ---systemctl stop firewalld 2.检查是否有/root/.ssh目录&#xff0c;没有则创建有则打开/root/.ssh/authorized_keys文件将密钥粘贴创建/ro…

麻省理工、Meta开源:无需人工标注,创新文生图模型

文生图领域一直面临着一个核心难题,就是有条件图像生成的效果&#xff0c;远超无条件的图像生成。有条件图像生成是指模型在生成图像的过程中,会额外使用类别、文本等辅助信息进行指导,这样可以更好的理解用户的文本意图&#xff0c;生成的图像质量也更高。 而无条件图像生成完…

【MYSQL】MYSQL 的学习教程(十一)之 MySQL 不同隔离级别,都使用了哪些锁

聊聊不同隔离级别下&#xff0c;都会使用哪些锁&#xff1f; 1. MySQL 锁机制 对于 MySQL 来说&#xff0c;如果只支持串行访问的话&#xff0c;那么其效率会非常低。因此&#xff0c;为了提高数据库的运行效率&#xff0c;MySQL 需要支持并发访问。而在并发访问的情况下&…

【LLM的概念理解能力】Concept Understanding In Large Language Models: An Empirical Study

大语言模型中的概念理解&#xff1a;一个实证研究 摘要 大语言模型&#xff08;LLMs&#xff09;已经在广泛的任务中证明了其卓越的理解能力和表达能力&#xff0c;并在现实世界的应用中显示出卓越的能力。因此&#xff0c;研究它们在学术界和工业界的值得信赖的性能的潜力和…

buuctf[极客大挑战 2019]BabySQL--联合注入、双写过滤

目录 1、测试万能密码&#xff1a; 2、判断字段个数 3、尝试联合注入 4、尝试双写过滤 5、继续尝试列数 6、查询数据库和版本信息 7、查询表名 8、没有找到和ctf相关的内容&#xff0c;查找其他的数据库 9、查看ctf数据库中的表 10、查询Flag表中的字段名 11、查询表…

C++学习笔记——对象的指针

目录 一、对象的指针 二、减少对象的复制开销 三、应用案例 游戏引擎 图像处理库 数据库管理系统 航空航天软件 金融交易系统 四、代码的案例应用 一、对象的指针 是一种常用的技术&#xff0c;用于处理对象的动态分配和管理。使用对象的指针可以实现以下几个方面的功…

Python GIL 一文全知道!

GIL 作为 Python 开发者心中永远的痛&#xff0c;在最近即将到来的更新中&#xff0c;终于要彻底解决了&#xff0c;整个 Python 社群都沸腾了 什么是GIL&#xff1f; GIL是英文学名global interpreter lock的缩写&#xff0c;中文翻译成全局解释器锁。GIL需要解决的是线程竞…

遥感影像-语义分割数据集:云数据集详细介绍及训练样本处理流程

原始数据集详情 简介&#xff1a;该云数据集包括150张RGB三通道的高分辨率图像&#xff0c;在全球不同区域的分辨率从0.5米到15米不等。这些图像采集自谷歌Earth的五种主要土地覆盖类型&#xff0c;即水、植被、湿地、城市、冰雪和贫瘠土地。 KeyValue卫星类型谷歌Earth覆盖区…

太惨了,又一个程序员被渣的开年大瓜

今天闲暇之余浏览了一下mm&#xff0c;忽然看见一条瓜&#xff1a;某东pdf瓜&#xff0c;一份19页的PDF文件&#xff0c;题为《婚房变赠予&#xff0c;京东渣女出轨连环套设计冤大头程序员》&#xff0c;点进去看了一下&#xff0c;简直炸裂了三观&#xff0c;男同志们一定要保…

EI级 | Matlab实现VMD-TCN-LSTM变分模态分解结合时间卷积长短期记忆神经网络多变量光伏功率时间序列预测

EI级 | Matlab实现VMD-TCN-LSTM变分模态分解结合时间卷积长短期记忆神经网络多变量光伏功率时间序列预测 目录 EI级 | Matlab实现VMD-TCN-LSTM变分模态分解结合时间卷积长短期记忆神经网络多变量光伏功率时间序列预测预测效果基本介绍程序设计参考资料 预测效果 基本介绍 1.【E…