BiseNet实现遥感影像地物分类

news2024/9/28 13:20:57

遥感地物分类通过对遥感图像中的地物进行准确识别和分类,为资源管理、环境保护、城市规划、灾害监测等领域提供重要信息,有助于实现精细化管理和科学决策,提升社会治理和经济发展水平。深度学习遥感地物分类在提高分类精度、自动化程度、处理大规模数据、普适性以及推动遥感技术创新和发展等方面都具有重要的意义。本文将利用深度学习BiseNet实现遥感地物分类。

数据集

本文使用的数据集为WHDLD数据集[1](Wuhan dense labeling dataset)。WHDLD数据集包括4940张高分辨率遥感影像,包含6种土地覆盖类型,影像尺寸均被裁剪至256×256像素。下面是一些数据集示例。 alt

BiSeNet

BiseNet[2](Bilateral Segmentation Network)是一种用于图像分割的深度学习网络。它具有双边分割的特点,可以同时处理空间信息和上下文信息,从而实现高效、准确的图像分割。

具体来说,BiseNet由两个分支组成:空间路径(spatial path)和上下文路径(context path)。其中,空间路径具有较小的感受野,可以捕获丰富的空间信息并生成高分辨率的特征图;而上下文路径则具有较大的感受野,可以捕获更多的上下文信息并生成低分辨率的特征图。这两个路径通过一个特征融合模块进行融合,从而生成既包含丰富空间信息又包含上下文信息的分割结果。

在BiseNet中,还有一些关键的技术和设计,如轻量级模型设计、注意力机制、特征融合等,这些技术和设计可以进一步提升网络的性能和效率。 alt

网络复现

resnet18

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as modelzoo

resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'


from torch.nn import BatchNorm2d


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    def __init__(self, in_chan, out_chan, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(in_chan, out_chan, stride)
        self.bn1 = BatchNorm2d(out_chan)
        self.conv2 = conv3x3(out_chan, out_chan)
        self.bn2 = BatchNorm2d(out_chan)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = None
        if in_chan != out_chan or stride != 1:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_chan, out_chan,
                          kernel_size=1, stride=stride, bias=False),
                BatchNorm2d(out_chan),
                )

    def forward(self, x):
        residual = self.conv1(x)
        residual = self.bn1(residual)
        residual = self.relu(residual)
        residual = self.conv2(residual)
        residual = self.bn2(residual)

        shortcut = x
        if self.downsample is not None:
            shortcut = self.downsample(x)

        out = shortcut + residual
        out = self.relu(out)
        return out


def create_layer_basic(in_chan, out_chan, bnum, stride=1):
    layers = [BasicBlock(in_chan, out_chan, stride=stride)]
    for i in range(bnum-1):
        layers.append(BasicBlock(out_chan, out_chan, stride=1))
    return nn.Sequential(*layers)


class Resnet18(nn.Module):
    def __init__(self):
        super(Resnet18, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
        self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
        self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
        self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
        self.init_weight()

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        feat8 = self.layer2(x) # 1/8
        feat16 = self.layer3(feat8) # 1/16
        feat32 = self.layer4(feat16) # 1/32
        return feat8, feat16, feat32

    def init_weight(self):
        state_dict = modelzoo.load_url(resnet18_url)
        self_state_dict = self.state_dict()
        for k, v in state_dict.items():
            if 'fc' in k: continue
            self_state_dict.update({k: v})
        self.load_state_dict(self_state_dict)

    def get_params(self):
        wd_params, nowd_params = [], []
        for name, module in self.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                wd_params.append(module.weight)
                if not module.bias is None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.modules.batchnorm._BatchNorm):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params

BiSeNet


import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

from .resnet import Resnet18

from torch.nn import BatchNorm2d


class ConvBNReLU(nn.Module):

    def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
        super(ConvBNReLU, self).__init__()
        self.conv = nn.Conv2d(in_chan,
                out_chan,
                kernel_size = ks,
                stride = stride,
                padding = padding,
                bias = False)
        self.bn = BatchNorm2d(out_chan)
        self.relu = nn.ReLU(inplace=True)
        self.init_weight()

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if not ly.bias is None: nn.init.constant_(ly.bias, 0)


class UpSample(nn.Module):

    def __init__(self, n_chan, factor=2):
        super(UpSample, self).__init__()
        out_chan = n_chan * factor * factor
        self.proj = nn.Conv2d(n_chan, out_chan, 1, 1, 0)
        self.up = nn.PixelShuffle(factor)
        self.init_weight()

    def forward(self, x):
        feat = self.proj(x)
        feat = self.up(feat)
        return feat

    def init_weight(self):
        nn.init.xavier_normal_(self.proj.weight, gain=1.)


class BiSeNetOutput(nn.Module):

    def __init__(self, in_chan, mid_chan, n_classes, up_factor=32, *args, **kwargs):
        super(BiSeNetOutput, self).__init__()
        self.up_factor = up_factor
        out_chan = n_classes
        self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
        self.conv_out = nn.Conv2d(mid_chan, out_chan, kernel_size=1, bias=True)
        self.up = nn.Upsample(scale_factor=up_factor,
                mode='bilinear', align_corners=False)
        self.init_weight()

    def forward(self, x):
        x = self.conv(x)
        x = self.conv_out(x)
        x = self.up(x)
        return x

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if not ly.bias is None: nn.init.constant_(ly.bias, 0)

    def get_params(self):
        wd_params, nowd_params = [], []
        for name, module in self.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                wd_params.append(module.weight)
                if not module.bias is None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.modules.batchnorm._BatchNorm):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params


class AttentionRefinementModule(nn.Module):
    def __init__(self, in_chan, out_chan, *args, **kwargs):
        super(AttentionRefinementModule, self).__init__()
        self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
        self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False)
        self.bn_atten = BatchNorm2d(out_chan)
        #  self.sigmoid_atten = nn.Sigmoid()
        self.init_weight()

    def forward(self, x):
        feat = self.conv(x)
        atten = torch.mean(feat, dim=(2, 3), keepdim=True)
        atten = self.conv_atten(atten)
        atten = self.bn_atten(atten)
        #  atten = self.sigmoid_atten(atten)
        atten = atten.sigmoid()
        out = torch.mul(feat, atten)
        return out

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if not ly.bias is None: nn.init.constant_(ly.bias, 0)


class ContextPath(nn.Module):
    def __init__(self, *args, **kwargs):
        super(ContextPath, self).__init__()
        self.resnet = Resnet18()
        self.arm16 = AttentionRefinementModule(256, 128)
        self.arm32 = AttentionRefinementModule(512, 128)
        self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
        self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
        self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
        self.up32 = nn.Upsample(scale_factor=2.)
        self.up16 = nn.Upsample(scale_factor=2.)

        self.init_weight()

    def forward(self, x):
        feat8, feat16, feat32 = self.resnet(x)

        avg = torch.mean(feat32, dim=(2, 3), keepdim=True)
        avg = self.conv_avg(avg)

        feat32_arm = self.arm32(feat32)
        feat32_sum = feat32_arm + avg
        feat32_up = self.up32(feat32_sum)
        feat32_up = self.conv_head32(feat32_up)

        feat16_arm = self.arm16(feat16)
        feat16_sum = feat16_arm + feat32_up
        feat16_up = self.up16(feat16_sum)
        feat16_up = self.conv_head16(feat16_up)

        return feat16_up, feat32_up # x8, x16

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if not ly.bias is None: nn.init.constant_(ly.bias, 0)

    def get_params(self):
        wd_params, nowd_params = [], []
        for name, module in self.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                wd_params.append(module.weight)
                if not module.bias is None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.modules.batchnorm._BatchNorm):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params


class SpatialPath(nn.Module):
    def __init__(self, *args, **kwargs):
        super(SpatialPath, self).__init__()
        self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
        self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
        self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
        self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)
        self.init_weight()

    def forward(self, x):
        feat = self.conv1(x)
        feat = self.conv2(feat)
        feat = self.conv3(feat)
        feat = self.conv_out(feat)
        return feat

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if not ly.bias is None: nn.init.constant_(ly.bias, 0)

    def get_params(self):
        wd_params, nowd_params = [], []
        for name, module in self.named_modules():
            if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
                wd_params.append(module.weight)
                if not module.bias is None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.modules.batchnorm._BatchNorm):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params


class FeatureFusionModule(nn.Module):
    def __init__(self, in_chan, out_chan, *args, **kwargs):
        super(FeatureFusionModule, self).__init__()
        self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
        ## use conv-bn instead of 2 layer mlp, so that tensorrt 7.2.3.4 can work for fp16
        self.conv = nn.Conv2d(out_chan,
                out_chan,
                kernel_size = 1,
                stride = 1,
                padding = 0,
                bias = False)
        self.bn = nn.BatchNorm2d(out_chan)
        #  self.conv1 = nn.Conv2d(out_chan,
        #          out_chan//4,
        #          kernel_size = 1,
        #          stride = 1,
        #          padding = 0,
        #          bias = False)
        #  self.conv2 = nn.Conv2d(out_chan//4,
        #          out_chan,
        #          kernel_size = 1,
        #          stride = 1,
        #          padding = 0,
        #          bias = False)
        #  self.relu = nn.ReLU(inplace=True)
        self.init_weight()

    def forward(self, fsp, fcp):
        fcat = torch.cat([fsp, fcp], dim=1)
        feat = self.convblk(fcat)
        atten = torch.mean(feat, dim=(2, 3), keepdim=True)
        atten = self.conv(atten)
        atten = self.bn(atten)
        #  atten = self.conv1(atten)
        #  atten = self.relu(atten)
        #  atten = self.conv2(atten)
        atten = atten.sigmoid()
        feat_atten = torch.mul(feat, atten)
        feat_out = feat_atten + feat
        return feat_out

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if not ly.bias is None: nn.init.constant_(ly.bias, 0)

    def get_params(self):
        wd_params, nowd_params = [], []
        for name, module in self.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                wd_params.append(module.weight)
                if not module.bias is None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.modules.batchnorm._BatchNorm):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params


class BiSeNetV1(nn.Module):

    def __init__(self, n_classes, aux_mode='train', *args, **kwargs):
        super(BiSeNetV1, self).__init__()
        self.cp = ContextPath()
        self.sp = SpatialPath()
        self.ffm = FeatureFusionModule(256, 256)
        self.conv_out = BiSeNetOutput(256, 256, n_classes, up_factor=8)
        self.aux_mode = aux_mode
        if self.aux_mode == 'train':
            self.conv_out16 = BiSeNetOutput(128, 64, n_classes, up_factor=8)
            self.conv_out32 = BiSeNetOutput(128, 64, n_classes, up_factor=16)
        self.init_weight()

    def forward(self, x):
        H, W = x.size()[2:]
        feat_cp8, feat_cp16 = self.cp(x)
        feat_sp = self.sp(x)
        feat_fuse = self.ffm(feat_sp, feat_cp8)

        feat_out = self.conv_out(feat_fuse)
        if self.aux_mode == 'train':
            feat_out16 = self.conv_out16(feat_cp8)
            feat_out32 = self.conv_out32(feat_cp16)
            return feat_out, feat_out16, feat_out32
        elif self.aux_mode == 'eval':
            return feat_out,
        elif self.aux_mode == 'pred':
            feat_out = feat_out.argmax(dim=1)
            return feat_out
        else:
            raise NotImplementedError

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if not ly.bias is None: nn.init.constant_(ly.bias, 0)

    def get_params(self):
        wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
        for name, child in self.named_children():
            child_wd_params, child_nowd_params = child.get_params()
            if isinstance(child, (FeatureFusionModule, BiSeNetOutput)):
                lr_mul_wd_params += child_wd_params
                lr_mul_nowd_params += child_nowd_params
            else:
                wd_params += child_wd_params
                nowd_params += child_nowd_params
        return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params

训练过程精度变化

alt

测试精度

alt

结果展示

alt

总结

今天的分享到此结束,感兴趣的点点关注,后续将分享更多案例。

参考资料

[1]

WHDLD: https://sites.google.com/view/zhouwx/dataset#h.p_hQS2jYeaFpV0

[2]

BiSeNet: https://arxiv.org/abs/1808.00897

本文由 mdnice 多平台发布

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1284119.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

java后端技术演变杂谈(未完结)

1.0版本javaWeb:原始servletjspjsbc 早期的jsp:htmljava,页面先在后端被解析,里面的java代码动态渲染完成后,成为纯html,再通过服务器发送给浏览器显示。 缺点: 服务器压力很大,因为…

深入微服务架构 | 微服务与k8s架构解读

微服务项目架构解读 ① 什么是微服务? 微服务是指开发一个单个小型的但有业务功能的服务,每个服务都有自己的处理和轻量通讯机制,可以部署在单个或多个服务器上。 微服务也指一种种松耦合的、有一定的有界上下文的面向服务架构。也就是说&…

C++数据结构:B树

目录 一. 常见的搜索结构 二. B树的概念 三. B树节点的插入和遍历 3.1 插入B树节点 3.2 B树遍历 四. B树和B*树 4.1 B树 4.2 B*树 五. B树索引原理 5.1 索引概述 5.2 MyISAM 5.3 InnoDB 六. 总结 一. 常见的搜索结构 表示1为在实际软件开发项目中,常用…

链表【2】

文章目录 🥝24. 两两交换链表中的节点🥑题目🌽算法原理🥬代码实现 🍎143. 重排链表🍒题目🍅算法原理🍓代码实现 🥝24. 两两交换链表中的节点 🥑题目 题目链接…

【超详细】vue项目:Tinymce富文本使用教程以及踩坑总结+功能扩展

【【超详细】vue项目:Tinymce富文本使用教程以及踩坑总结功能扩展 引言:一、 开始二、快速开始1、安装Tinymce 三、封装成Vue组件1、文件结构2、index.vue3、dynamicLoadScript.js4、plugin.js5、toolbar.js 四、使用Tinymce组件五、业务逻辑实现1、添加…

vue中的this.$nextTick().then()

MENU 示例一示例二sortsplicepushrandomfloorMathwhile演示 示例一 let reorganize function (arr){let rest [];while (arr.length > 0) {let random Math.floor(Math.random() * arr.length);// 把获取到的值放到新定义的数组中rest.push(arr[random]);// 这句代码的作…

Leetcode每日一题学习训练——Python3版(从二叉搜索树到更大和树)

版本说明 当前版本号[20231204]。 版本修改说明20231204初版 目录 文章目录 版本说明目录从二叉搜索树到更大和树理解题目代码思路参考代码 原题可以点击此 1038. 从二叉搜索树到更大和树 前去练习。 从二叉搜索树到更大和树 给定一个二叉搜索树 root (BST),请…

网络安全卫士:上海迅软DSE的员工上网管理策略大揭秘!

在日常办公中,企业员工可能会在互联网上有意或无意的将一些包含内部重要信息的内容发布出去,从而造成不必要的违规及泄密风险,因此对终端用户进行规范的上网行为管理,既能有效预防重要数据泄密,同时也能提高员工办公效…

Java数据结构之《直接插入排序》(难度系数75)

一、前言: 这是怀化学院的:Java数据结构中的一道难度中等的一道编程题(此方法为博主自己研究,问题基本解决,若有bug欢迎下方评论提出意见,我会第一时间改进代码,谢谢!) 后面其他编程题只要我写完…

新书推荐——《Copilot和ChatGPT编程体验:挑战24个正则表达式难题》

《Copilot和ChatGPT编程体验:挑战24个正则表达式难题》呈现了两方竞争的格局。一方是专业程序员David Q. Mertz,是网络上最受欢迎的正则表达式教程的作者。另一方则是强大的AI编程工具OpenAI ChatGPT和GitHub Copilot。 比赛规则如下:David编…

OpenResty(nginx+lua+resty-http)实现访问鉴权

OpenResty(nginxluaresty-http)实现访问鉴权 最近用BI框架解决了一些报表需求并生成了公开链接,现在CMS开发人员打算将其嵌入到业务系统中,结果发现公开链接一旦泄露任何人都可以访问,需要实现BI系统报表与业务系统同步的权限控制。但是目前…

7、Qt延时的使用

一、说明 平时用到两种延时方式QThread::sleep()和QTimer::singleShot() 1、QThread::sleep() QThread类中如下三个静态函数: QThread::sleep(n); //延迟n秒 QThread::msleep(n); //延迟n毫秒 QThread::usleep(n); //延迟n微妙 这种方式使用简单,但是会阻…

X540t2关于手动安装intel驱动

首先去intel驱动官网下载,win10和win11驱动一样 https://www.intel.cn/content/www/cn/zh/download/18293/intel-network-adapter-driver-for-windows-10.html 然后下载下来解压 将Wired_driver_28.2_x64.exe修改成Wired_driver_28.2_x64.zip文件再解压 打开设备管…

[Mac软件]HitPaw Video Converter 功能强大的视频格式转换编辑软件激活版

软件介绍: 以令人难以置信的速度将无损视频和音乐转换为1000多种格式:MP4、MOV、AVI、VOB、MKV等。不仅适用于普通编解码器,也适用于高级VP9、ProRes和Opus编码器。这解决了您不支持格式的所有问题,并允许您在任何平台和设备上播…

语义分割 DeepLab V1网络学习笔记 (附代码)

论文地址:https://arxiv.org/abs/1412.7062 代码地址:GitHub - TheLegendAli/DeepLab-Context 1.是什么? DeepLab V1是一种基于VGG模型的语义分割模型,它使用了空洞卷积和全连接条件随机(CRF)来提高分割…

SQL手工注入漏洞测试(PostgreSQL数据库)-墨者

———靶场专栏——— 声明:文章由作者weoptions学习或练习过程中的步骤及思路,非正式答案,仅供学习和参考。 靶场背景: 来源: 墨者学院 简介: 安全工程师"墨者"最近在练习SQL手工注入漏洞&#…

10、pytest通过assert进行断言

官方实例 # content of test_assert1.pydef f():return 3def test_function():assert f() 4def test_assert_desc():a f()# assert a % 2 0assert a % 2 0, "value was odd, should be even"解读与实操 pytest允许你使用标准python断言来验证测试中的期望和值&…

【工具使用-Audition】如何使用Audition频谱分析

一,简介 本文以Audition 2020为例,介绍如何生成频谱分析的图像。 二,操作步骤 使用快捷键“shift D” 三,总结 本文主要介绍如何查看频谱分析,供参考。

不会代码(零基础)学语音开发(语音控制双色LED)

语音开发板到手后,跟着说明做一遍。例程:语音控制双色LED! 首先,进行固件烧录。 资料中已经给了固件,按照手册中的说明进行烧录即可。整体感受还挺简单,认真读手册即可实现。 需要注意的事项&#xff1a…

C语言——交换两个int变量的值,不能使用第三个变量。

交换两个int变量的值&#xff0c;不能使用第三个变量。即 a3,b5,交换之后a5,b3; #include<stdio.h> int main() {int a3;int b5;printf("a%d b%d\n",a,b);aa^b;ba^b;aa^b;printf("a%d b%d\n",a,b); } “^”——按位异或操作符&#xff0c;这里的按…