【变化检测】基于UNet建筑物变化检测

news2024/11/14 6:11:43

主要内容如下:

1、LEVIR-CD数据集介绍及下载
2、运行环境安装
3、基于likyoo变化检测代码模型训练与预测
4、Onnx运行及可视化

运行环境:Python=3.8,torch1.12.0+cu113
likyoo变化检测源码:https://github.com/likyoo/change_detection.pytorch
使用情况:环境配置简单、训练速度也快。

1 LEVIR-CD数据集介绍

1.1 简介

LEVIR-CD 由 637 个超高分辨率(VHR,0.5m/像素)谷歌地球(GE)图像块对组成,大小为 1024 × 1024 像素。这些时间跨度为 5 到 14 年的双时态图像具有显着的土地利用变化,尤其是建筑增长。LEVIR-CD涵盖别墅住宅、高层公寓、小型车库、大型仓库等各类建筑。在这里,我们关注与建筑相关的变化,包括建筑增长(从土壤/草地/硬化地面或在建建筑到新的建筑区域的变化)和建筑衰退。这些双时态图像由遥感图像解释专家使用二进制标签(1 表示变化,0 表示不变)进行注释。我们数据集中的每个样本都由一个注释者进行注释,然后由另一个进行双重检查以生成高质量的注释。
数据来源:https://justchenhao.github.io/LEVIR/
论文地址:https://www.mdpi.com/2072-4292/12/10/1662
快速下载链接:https://aistudio.baidu.com/datasetdetail/104390/1

1.2 示例

在这里插入图片描述

2 运行环境安装

2.1 基础环境安装

【超详细】跑通YOLOv8之深度学习环境配置1-Anaconda安装
【超详细】跑通YOLOv8之深度学习环境配置2-CUDA安装

创建Python环境及换源可借鉴如下:
【超详细】跑通YOLOv8之深度学习环境配置3-YOLOv8安装

2.2 likyoo变化检测代码环境安装

2.2.1 代码下载

在这里插入图片描述

2.2.2 环境安装
# 1 创建环境
conda create -n likyoo python=3.8
conda activate likyoo

# 2 安装torch
# 方式1:
conda install pytorch==1.12.0 torchvision==0.13.0 torchaudio==0.12.0 cudatoolkit=11.3 -c pytorch
# 方式2:
pip install torch==1.12.0+cu113 torchvision==0.13.0+cu113 torchaudio==0.12.0 --extra-index-url https://download.pytorch.org/whl/cu113

# 3 验证torch安装是否为gpu版
import torch
print(torch.__version__)  # 打印torch版本
print(torch.cuda.is_available())  # True即为成功
print(torch.version.cuda)
print(torch.backends.cudnn.version())

# 4 安装其他依赖库
cd ./change_detection.pytorch-main
pip install -r requirements.txt
pip install six  # 训练报错缺少该库

3 模型训练与预测

3.1 模型架构

在这里插入图片描述

3.1 模型训练

训练代码为local_test.py

  1. 选择不同的分割架构,详情见cdp的__init__.py
  2. 修改训练数据路径
  3. epoch可以修改72行MAX_EPOCH,默认为60,batchsize默认为8【本文改成32训练,测试效果比8好,显存占8G左右】;
    在这里插入图片描述

训练报错:RuntimeError: Attempted to set the storage of a tensor on device “cpu” to a storage on different device “cuda:0”. This is no longer allowed; the devices must match.

解决方法
在 hub.py里修改最后一行,删去, map_location=map_location
在这里插入图片描述

3.2 模型预测

新建predict.py脚本,复制如下内容【注意输入路径是否正确,否则报错】:

import cv2
import numpy as np 
import torch
from torch.utils.data import DataLoader, Dataset
import albumentations as A
import change_detection_pytorch as cdp
from change_detection_pytorch.datasets import LEVIR_CD_Dataset, SVCD_Dataset
from change_detection_pytorch.utils.lr_scheduler import GradualWarmupScheduler
 
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
 
model = cdp.Unet(
    encoder_name="resnet34",  # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
    encoder_weights="imagenet",  # use `imagenet` pre-trained weights for encoder initialization
    in_channels=3,  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
    classes=2,  # model output channels (number of classes in your datasets)
    siam_encoder=True,  # whether to use a siamese encoder
    fusion_form='concat',  # the form of fusing features from two branches. e.g. concat, sum, diff, or abs_diff.
)
 
model_path = 'best_model.pth'  # 修改1
model.to(DEVICE)
# model.load_state_dict(torch.load(model_path))
model = torch.load(model_path)
model.eval()
test_transform = A.Compose([
            A.Normalize()])

path1 = 'E:/datasets/LEVIR-CD/test/A/test_7.png'  # 修改2
img1 = cv2.imread(path1)
img1 = test_transform(image = img1)
img1 = img1['image']
img1 = img1.transpose(2, 0, 1)
img1 = np.expand_dims(img1,0)
img1 = torch.Tensor(img1)
img1 = img1.cuda()
 
path2 = 'E:/datasets/LEVIR-CD/test/B/test_7.png'  # 修改3
img2 = cv2.imread(path2)
img2 = test_transform(image = img2)
img2 = img2['image']
img2 = img2.transpose(2, 0, 1)
img2 = np.expand_dims(img2,0)
img2 = torch.Tensor(img2)
img2 = img2.cuda()
 
pre = model(img1,img2)
pre = torch.argmax(pre, dim=1).cpu().data.numpy() * 255
cv2.imwrite('./result/test_7_pre.png', pre[0])  # 修改4

3.3 结果显示

其中epoch=60结果:
在这里插入图片描述

4 Onnx运行及可视化

4.1 Onnx导出静态和动态文件

  1. 新建export.py脚本,导出onnx,复制如下内容:
import torch
import change_detection_pytorch as cdp

model = cdp.Unet(
    encoder_name="resnet34",  # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
    encoder_weights="imagenet",  # use `imagenet` pre-trained weights for encoder initialization
    in_channels=3,  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
    classes=2,  # model output channels (number of classes in your datasets)
    siam_encoder=True,  # whether to use a siamese encoder
    fusion_form='concat',  # the form of fusing features from two branches. e.g. concat, sum, diff, or abs_diff.
)

model_path = "best_model.pth"  # 修改1
model = torch.load(model_path)

input1 = torch.randn(1, 3, 1024, 1024, device='cuda:0')
input2 = torch.randn(1, 3, 1024, 1024, device='cuda:0')

# 保存静态onnx
torch.onnx.export(model, (input1, input2),
                   "cd.onnx",
                   input_names=["images1", "images2"],
                   output_names=["output"],
                   verbose=True, opset_version=12)

# 保存动态onnx
torch.onnx.export(model, (input1, input2),
                   "cd_dy.onnx",
                   input_names=["images1", "images2"],
                   output_names=["output"],
                   verbose=True, opset_version=12,
                   dynamic_axes={
                       "images1": {0 :"batch_size", 2: "input_height", 3: "input_width"},
                       "images2": {0 :"batch_size", 2: "input_height", 3: "input_width"},
                       "output": {0 :"batch_size", 2: "output_height", 3: "output_width"}
                   })
  1. 查看模型结构
    https://netron.app/
    静态onnx:
    在这里插入图片描述

    动态onnx:
    在这里插入图片描述

4.2 Onnx运行及可视化

4.2.1 Onnx推理运行
import os
import cv2
import time
import argparse
import numpy as np
import onnxruntime as ort  # 使用onnxruntime推理用上,pip install onnxruntime-gpu==1.12.0 -i https://pypi.tuna.tsinghua.edu.cn/simple
 

class CD(object):
    def __init__(self, onnx_model, in_shape=1024):
        self.in_shape = in_shape  # 图像输入尺度
        self.mean = [0.485, 0.456, 0.406]  # 定义均值和标准差(确保它们与图像数据的范围相匹配)  
        self.std = [0.229, 0.224, 0.225]  # 基于0-1范围的

        # 构建onnxruntime推理引擎
        self.ort_session = ort.InferenceSession(onnx_model,
                                providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
                                if ort.get_device() == 'GPU' else ['CPUExecutionProvider'])

    # 归一化 
    def normalize(self, image, mean, std):  
        # 如果均值和标准差是基于0-255范围的图像计算的,那么需要先将图像转换为0-1范围  
        image = image / 255.0  
        image = image.astype(np.float32)  
        image_normalized = np.zeros_like(image)  

        for i in range(3):  # 对于 RGB 的每个通道  
            image_normalized[:, :, i] = (image[:, :, i] - mean[i]) / std[i]  
        return image_normalized
    

    def preprocess(self, img_a, img_b):
        # resize为1024大小
        if img_a.shape[0] != self.in_shape and img_a.shape[1] != self.in_shape:
            img_a = cv2.resize(img_a, (self.in_shape, self.in_shape), interpolation=cv2.INTER_LINEAR)
        if img_b.shape[0] != self.in_shape and img_b.shape[1] != self.in_shape:
            img_b = cv2.resize(img_b, (self.in_shape, self.in_shape), interpolation=cv2.INTER_LINEAR)

        # 应用归一化  
        img_a = self.normalize(img_a, self.mean, self.std)
        img_b = self.normalize(img_b, self.mean, self.std)
        img_a = np.ascontiguousarray(np.einsum('HWC->CHW', img_a)[::-1], dtype=np.single)  # (1024, 1024, 3)-->(3, 1024, 1024), BGR-->RGB
        img_b = np.ascontiguousarray(np.einsum('HWC->CHW', img_b)[::-1], dtype=np.single)  # np.single 和 np.float32 是等价的
        img_a = img_a[None] if len(img_a.shape) == 3 else img_a  # (1, 3, 256, 256)
        img_b = img_b[None] if len(img_b.shape) == 3 else img_b

        return img_a, img_b
    
    # 推理
    def infer(self, img_a, img_b):
        im1, im2 = self.preprocess(img_a, img_b)  # --> (1, 3, 256, 256)

        preds = self.ort_session.run(None, {self.ort_session.get_inputs()[0].name: im1, self.ort_session.get_inputs()[1].name: im2})[0]  
        out_img = (np.argmax(preds, axis=1)[0] * 255).astype("uint8")  
       
        return out_img

if __name__ == '__main__':
    # Create an argument parser to handle command-line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default='cd.onnx', help='Path to ONNX model')
    parser.add_argument('--source_A', type=str, default=str('E:/datasets/LEVIR-CD/test/A/test_7.png'), help='A期图像')
    parser.add_argument('--source_B', type=str, default=str('E:/datasets/LEVIR-CD/test/B/test_7.png'), help='B期图像')
    parser.add_argument('--in_shape', type=int, default=1024, help='输入模型图像尺度')
    args = parser.parse_args()

    # 实例化变化检测模型
    cd= CD(args.model, args.in_shape)

    t1 = time.time()
    # Read image by OpenCV
    img_a = cv2.imread(args.source_A)
    img_b = cv2.imread(args.source_B)

    # 推理+输出
    out = cd.infer(img_a, img_b)
    
    # 保存结果
    cv2.imwrite('./result/test_7_res.png', out)
    print('总耗时:{}'.format(time.time() - t1))
4.2.2 结果可视化

注意:Onnx前处理与前面的pth不一样,结果略有不同。
显存:占用2G左右,速度:100ms左右(RTX4080)
在这里插入图片描述

4.2.3 后处理:

(1)可加入腐蚀膨胀处理,消除一些小白点等区域;
(2)将变化区域绘制在第二期图上,便于观察;

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2071003.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

数据仓库中的表设计模式:全量表、增量表与拉链表

在现代数据仓库中,管理和分析海量数据需要高效且灵活的数据存储策略。全量表、增量表和拉链表是三种常见的数据存储模式,各自针对不同的数据管理需求提供了解决方案。全量表通过保存完整的数据快照确保数据的一致性,增量表则通过记录数据的变…

如何在 Ubuntu 系统中安装PyCharm集成开发环境?

在上一篇文章中,我们探讨了Jupyter notebook,今天再来看看另一款常用的Python 工具,Pycharm。 PyCharm也是我们日常开发和学习常用的Python 集成开发环境 (IDE),由 JetBrains 开发。 PyCharm 带有一整套可以帮助用户在使用Pytho…

大数据-91 Spark 集群 RDD 编程-高阶 RDD广播变量 RDD累加器 Spark程序优化

点一下关注吧!!!非常感谢!!持续更新!!! 目前已经更新到了: Hadoop(已更完)HDFS(已更完)MapReduce(已更完&am…

【代码随想录训练营第42期 Day39打卡 - 打家劫舍问题 - LeetCode 198.打家劫舍 213.打家劫舍II 337.打家劫舍III

目录 一、做题心得 二、题目与题解 题目一:198.打家劫舍 题目链接 题解:动态规划 题目二:213.打家劫舍II 题目链接 题解:动态规划 题目三:337.打家劫舍III 题目链接 题解:动态规划 三、小结 一、…

卸载nomachine

网上的方法:提示找不到命令 我的方法: step1. 终端输入 sudo find / -name nxserver 2>/dev/null确认 NoMachine 的实际安装路径。你可以使用 find 命令在系统中查找 nxserver 脚本的位置。 找到路径后,你可以使用该路径来卸载 NoMachine。 如下图,紫色框中是我的路径…

【ACM出版】第三届公共管理、数字经济与互联网技术国际学术会议(ICPDI 2024,9月06-08)

第三届公共管理、数字经济与互联网技术国际学术会议(ICPDI 2024)定于2024年9月06-08日在中国-济南举行。 会议主要围绕公共管理、数字经济,互联网技术等研究领域展开讨论。会议旨在为从事公共管理、经济、大数据、互联网研究的专家学者提供一…

解决LabVIEW配置文件中文乱码问题

LabVIEW配置文件中的中文字符在程序调用时出现乱码,通常是由于字符编码不匹配引起的。LabVIEW默认使用ANSI编码格式,而配置文件可能使用了不同的编码格式(如UTF-8),导致中文字符在读取时无法正确解析。 解决方法 统一编…

导数的基本法则与常用导数公式的推导

目录 n 次幂函数导数公式的推导导数和的运算法则的证明正弦、余弦函数导数公式的推导代数证明两个重要极限(引理)及证明具体推导 几何直观 导数积的运算法则的证明导数商的法则的证明链式法则的证明有理幂函数求导法则的证明反函数求导法则的证明反正切函…

SSH 远程登录报错:kex_exchange_identification: Connection closed.....

一 问题起因 在公司,使用ssh登录远程服务器。有一天,mac终端提示:`kex_exchange_identification: Connection closed by remote host Connection closed by UNKNOWN port 65535`。 不知道为啥会出现这样的情形,最近这段时间登录都是正常的,不知道哪里抽风了,就提示这个。…

C++初学(15)

前面学习了循环的工作原理,接下来来看看循环完成的一项最常见的任务:逐字符地读取来自文本或键盘的文本。 15.1、使用cin进行输入 如果需要程序使用循环来读取来自键盘的文本输入,则必须有办法直到何时停止读取。一种方式是选择某个特殊字符…

发布分班查询,老师都在用哪个小程序?

新学期伊始,校园里又迎来了一批朝气蓬勃的新生。老师们的日程表上,除了日常的教学准备,还多了一项重要的任务——分班。这项工作不仅需要老师们精心策划,以确保每个班级的平衡,还要在分班完成后,及时将结果…

系统架构不是设计出来的

今天给大家分享一个 X/ Twitter 早期系统架构演变的故事,内容来自《数据密集型应用系统设计》这本书,具体数据来自 X/ Twitter 在 2012 年 11 月发布的分享。 推特的两个主要业务是: 发布推文(Tweets)。用户可以向其粉…

零基础入门~汇编语言(第四版王爽)~第3章寄存器(内存访问)

文章目录 前言3.1 内存中字的存储3.2 DS 和[address]3.3 字的传送3.4 mov、add、sub指令3.5 数据段检测点3.13.6 栈3.7 CPU提供的栈机制3.8 栈顶超界的问题3.9 push、pop指令3.10 栈 段检测点3.2实验2 用机器指令和汇编指令编程 前言 第2章中,我们主要从CPU 如何执…

2月公开赛Web-ssrfme

考点&#xff1a; redis未授权访问 源码&#xff1a; <?php highlight_file(__file__); function curl($url){ $ch curl_init();curl_setopt($ch, CURLOPT_URL, $url);curl_setopt($ch, CURLOPT_HEADER, 0);echo curl_exec($ch);curl_close($ch); }if(isset($_GET[url…

回归预测 | Matlab实现WOA-ESN鲸鱼算法优化回声状态网络多输入单输出回归预测

回归预测 | Matlab实现WOA-ESN鲸鱼算法优化回声状态网络多输入单输出回归预测 目录 回归预测 | Matlab实现WOA-ESN鲸鱼算法优化回声状态网络多输入单输出回归预测效果一览基本介绍程序设计参考资料 效果一览 基本介绍 1.Matlab实现WOA-ESN鲸鱼算法优化回声状态网络多输入单输出…

vue3 生命周期钩子

在 Vue 3 中&#xff0c;可以在组件不同阶段的生命周期执行特定逻辑。 生命周期整体分为四个阶段&#xff1a;创建、挂载、更新、卸载。 创建阶段 组合式APIsetup() 这是组合式 API 的入口点&#xff0c;在组件实例创建之前被调用。在此阶段&#xff0c;可以初始化响应式数据…

一键批量查询邮政快递,物流状态尽在掌握

邮政快递批量查询&#xff0c;轻松掌握物流动态 在电商行业蓬勃发展的今天&#xff0c;邮政快递作为连接商家与消费者的桥梁&#xff0c;其物流信息的及时性和准确性对于提升客户体验至关重要。然而&#xff0c;面对海量的快递单号&#xff0c;如何高效地进行批量查询&#xf…

【最长上升子序列】

题目 代码 #include <bits/stdc.h> using namespace std; const int N 1010; int a[N], f[N]; int main() {int n;cin >> n;for(int i 1; i < n; i) cin >> a[i];int res 0;for(int i 1; i < n; i){f[i] 1;for(int j 1; j < i; j){if(a[j] &…

(贪心) LeetCode 135. 分发糖果

原题链接 一. 题目描述 n 个孩子站成一排。给你一个整数数组 ratings 表示每个孩子的评分。 你需要按照以下要求&#xff0c;给这些孩子分发糖果&#xff1a; 每个孩子至少分配到 1 个糖果。 相邻两个孩子评分更高的孩子会获得更多的糖果。 请你给每个孩子分发糖果&#xf…

UE5 蓝图 计算当前时间段

思路&#xff1a; 那当前hour与阈值hour对比 。小于返回&#xff0c;大于就继续循环对比。 临时变量 折叠图表↓