基于SAM大模型的遥感影像分割工具,用于创建交互式标注、识别地物的能力,可利用Flask进行封装作为Web后台服务

news2024/9/22 17:14:04

如有帮助,支持一下(GitHub - Lvbta/ImageSegmentationTool-SAM: An interactive annotation case developed based on SAM for remote sensing image annotation, which can generate corresponding segmentation results based on point, multi-point, and rectangular box prompts, and convert the recognition results into vector data shp.)

本项目提供了一个图像分割工具,利用 Segment Anything Model (SAM) 对大规模的卫星或航拍图像进行分割。该工具支持通过单点、多点或边界框输入进行图像分割,并将分割结果保存为 shapefile,以便进一步进行地理空间分析。

功能特点

  • 单点分割:支持基于单个点的输入进行分割。
  • 多点分割:支持使用多个点进行分割。
  • 边界框分割:支持在指定的边界框内进行分割。
  • 地理空间集成:使用 GDAL 读取地理空间图像,并将分割的掩膜转换为多边形。
  • Shapefile 导出:将分割结果保存为 shapefile,方便与 GIS 工具集成。
  • 可视化:在原始图像上可视化分割结果,便于验证和分析。

安装

  1. 克隆仓库:

    git clone https://github.com/Lvbta/ImageSegmentationTool.git
    cd ImageSegmentationTool
  2. 下载SAM权重:

    • default or vit_h: ViT-H SAM model.
    • vit_l: ViT-L SAM model.
    • vit_b: ViT-B SAM model.
  3. 安装所需的依赖:

    pip install -r requirements.txt
  4. 设置环境变量:

    • 代码内已设置 KMP_DUPLICATE_LIB_OK 变量,以避免冲突。

使用方法

步骤 1:准备数据

  • 图像:确保您拥有地理参考的卫星或航拍图像,格式为 TIFF。
  • SAM 模型检查点:下载 SAM 模型检查点文件,并将其放置在项目目录中。

步骤 2:配置参数

在脚本中设置以下参数:

  • image_path: 您的地理参考图像文件的路径(例如 ./sentinel2.tif)。
  • sam_checkpoint: 您的 SAM 模型检查点文件的路径(例如 ./sam_vit_b_01ec64.pth)。
  • model_type: 用于分割的模型类型(vit_bvit_l 等)。
  • device: 用于运行模型的设备(cpu 或 cuda)。
  • output_shp: 保存输出 shapefile 的路径。

步骤 3:运行分割

选择分割模式并指定必要的输入点或边界框:

  • 单点模式

    seg_mode = 'single_point'
    input_points = [[1248, 1507]]
    single_label = [1]
  • 多点模式

    seg_mode = 'multi_point'
    input_points = [[389, 1041],[411, 1094]]
    single_label = [1, 1]
  • 边界框模式

    seg_mode = 'box'
    input_box = [[0, 951, 1909, 2383]]
    single_label = [1]

步骤 4:执行脚本

运行脚本以进行分割:

python main.py

步骤 5:可视化并保存结果

分割的掩膜将被可视化,多边形将作为 shapefile 保存到指定位置。

示例

使用边界框对图像进行分割,脚本配置如下:

# 边界框模式示例配置
seg_mode = 'box'
input_box = [[0, 951, 1909, 2383]]
single_label = [1]

segmenter = ImageSegmentation(image_path, sam_checkpoint, model_type, device)
masks, scores, x_off, y_off = segmenter.predict(mode=seg_mode, input_box=input_box, input_labels=single_label, multimask_output=True)
polygons = segmenter.masks_to_polygons(masks, x_off, y_off)
segmenter.save_polygons_gdal(polygons, output_shp)
segmenter.show_masks(seg_mode, masks, scores, x_off, y_off, input_box, single_label, image_chunk)
import numpy as np
import torch
import cv2
import sys
from osgeo import gdal, ogr, osr
from shapely.geometry import Polygon
from shapely.wkb import dumps
import matplotlib.pyplot as plt
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
plt.rcParams['font.sans-serif'] = 'SimHei'  # 设置中文显示
plt.rcParams['axes.unicode_minus'] = False
# plt.style.use('ggplot')


class ImageSegmentation:
    def __init__(self, image_path, sam_checkpoint, model_type='vit_b', device='cpu'):
        self.image_path = image_path
        self.sam_checkpoint = sam_checkpoint
        self.model_type = model_type
        self.device = device
        self.geo_transform, self.proj = self.get_geoinfo()
        self.sam = self.load_sam_model()
        self.predictor = self.init_predictor()

    def get_geoinfo(self):
        dataset = gdal.Open(self.image_path)
        geo_transform = dataset.GetGeoTransform()
        proj = dataset.GetProjection()
        dataset = None  # 关闭
        return geo_transform, proj

    def read_image_chunk(self, x_off, y_off, x_size, y_size):
        dataset = gdal.Open(self.image_path)
        image = dataset.ReadAsArray(x_off, y_off, x_size, y_size)
        dataset = None  # 关闭
        if len(image.shape) == 3:
            image = np.transpose(image, (1, 2, 0))  # GDAL reads in (bands, height, width) format
        else:
            image = np.stack([image] * 3, axis=-1)  # If it's a single-band image, stack to (height, width, 3)
        return image

    def load_sam_model(self):
        sys.path.append("..")
        from segment_anything import sam_model_registry

        sam = sam_model_registry[self.model_type](checkpoint=self.sam_checkpoint)
        sam.to(device=self.device)
        return sam

    def init_predictor(self):
        from segment_anything import SamPredictor

        predictor = SamPredictor(self.sam)
        return predictor

    def predict(self, mode='single_point', input_points=None, input_labels=None, input_box=None, multimask_output=None):
        if mode == 'single_point':
            assert input_points is not None and input_labels is not None, "Points and labels are required for single point mode."
            x, y = input_points[0]
            chunk_size = 512  # or any appropriate size
            x_off = max(x - chunk_size // 2, 0)
            y_off = max(y - chunk_size // 2, 0)
            x_size = y_size = chunk_size

            image_chunk = self.read_image_chunk(x_off, y_off, x_size, y_size)
            self.predictor.set_image(image_chunk)

            adjusted_points = [(x - x_off, y - y_off)]
            masks, scores, logits = self.predictor.predict(
                point_coords=np.array(adjusted_points),
                point_labels=np.array(input_labels),
                multimask_output=multimask_output,
            )
        elif mode == 'multi_point':
            assert input_points is not None and input_labels is not None, "Points and labels are required for multi point mode."
            # Determine bounding box of all points
            x_min = min(p[0] for p in input_points)
            y_min = min(p[1] for p in input_points)
            x_max = max(p[0] for p in input_points)
            y_max = max(p[1] for p in input_points)
            margin = 256  # or any appropriate margin
            x_off = max(x_min - margin, 0)
            y_off = max(y_min - margin, 0)
            x_size = min(x_max - x_min + 2 * margin, 2048)
            y_size = min(y_max - y_min + 2 * margin, 2048)

            image_chunk = self.read_image_chunk(x_off, y_off, x_size, y_size)
            self.predictor.set_image(image_chunk)

            adjusted_points = [(x - x_off, y - y_off) for x, y in input_points]
            masks, scores, logits = self.predictor.predict(
                point_coords=np.array(adjusted_points),
                point_labels=np.array(input_labels),
                multimask_output=multimask_output,
            )
        elif mode == 'box':
            assert input_box is not None, "Box coordinates are required for box mode."
            x_min, y_min, x_max, y_max = input_box[0]
            margin = 256  # or any appropriate margin
            x_off = max(x_min - margin, 0)
            y_off = max(y_min - margin, 0)
            x_size = min(x_max - x_min + 2 * margin, 2048)
            y_size = min(y_max - y_min + 2 * margin, 2048)

            image_chunk = self.read_image_chunk(x_off, y_off, x_size, y_size)
            self.predictor.set_image(image_chunk)

            adjusted_box = [(x_min - x_off, y_min - y_off, x_max - x_off, y_max - y_off)]
            masks, scores, logits = self.predictor.predict(
                box=np.array(adjusted_box).reshape(1, -1),
                multimask_output=multimask_output,
            )
        else:
            raise ValueError("Mode must be 'single_point', 'multi_point', or 'box'.")

        return masks, scores, x_off, y_off

    def masks_to_polygons(self, masks, x_off, y_off):
        polygons = []
        for mask in masks:
            contours, _ = cv2.findContours((mask > 0).astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            for contour in contours:
                contour = contour.squeeze()
                if len(contour.shape) == 2 and len(contour) >= 3:  # valid polygon
                    geo_contour = [self.pixel_to_geo(x + x_off, y + y_off) for x, y in contour]
                    polygon = Polygon(geo_contour)
                    if polygon.is_valid:
                        polygons.append(polygon)
        return polygons

    def pixel_to_geo(self, x, y):
        geox = self.geo_transform[0] + x * self.geo_transform[1] + y * self.geo_transform[2]
        geoy = self.geo_transform[3] + x * self.geo_transform[4] + y * self.geo_transform[5]
        return geox, geoy

    def save_polygons_gdal(self, polygons, output_shp):
        driver = ogr.GetDriverByName("ESRI Shapefile")
        data_source = driver.CreateDataSource(output_shp)

        spatial_ref = osr.SpatialReference()
        spatial_ref.ImportFromWkt(self.proj)  # 使用图像的投影信息

        layer = data_source.CreateLayer("segmentation", spatial_ref, ogr.wkbPolygon)
        layer_defn = layer.GetLayerDefn()

        for i, polygon in enumerate(polygons):
            feature = ogr.Feature(layer_defn)
            geom_wkb = dumps(polygon)  # 将Shapely几何对象转换为WKB
            ogr_geom = ogr.CreateGeometryFromWkb(geom_wkb)  # 从WKB创建OGR几何对象
            feature.SetGeometry(ogr_geom)
            feature.SetField("id", i + 1)
            layer.CreateFeature(feature)
            feature = None

        data_source = None

    def show_masks(self, mode, masks, scores,x_off, y_off, input_point, input_label, image):
        for i, (mask, score) in enumerate(zip(masks, scores)):
            plt.figure(figsize=(10, 10))
            plt.imshow(image)
            self.show_mask(mask, plt.gca())
            if mode == 'box':
                self.show_box(np.array(input_point[0]), plt.gca(), x_off, y_off)
            else:
                self.show_points(np.array(input_point), np.array(input_label), plt.gca(), x_off, y_off)
            plt.title(f"{mode}模式 {i + 1}, Score: {score:.3f}", fontsize=18)
            plt.axis('on')
            plt.show()

    def show_mask(self, mask, ax, x_off=0, y_off=0):
        mask_resized = np.zeros((mask.shape[0] + y_off, mask.shape[1] + x_off), dtype=np.uint8)
        mask_resized[y_off:y_off + mask.shape[0], x_off:x_off + mask.shape[1]] = mask.astype(np.uint8)
        contours, _ = cv2.findContours(mask_resized, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        for contour in contours:
            contour[:, :, 0] += x_off
            contour[:, :, 1] += y_off
            ax.plot(contour[:, 0, 0], contour[:, 0, 1], color='lime', linewidth=2)

    def show_points(self, points, labels, ax, x_off, y_off):
        for point, label in zip(points, labels):
            x, y = point
            x -= x_off  
            y -= y_off  
            ax.scatter(x, y, c='red', marker='o', label=f'Label: {label}')

    @staticmethod
    def show_box(box, ax, x_off, y_off):
        x0, y0 = box[0]-x_off, box[1]-y_off
        w, h = box[2] - box[0], box[3] - box[1]
        ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='red', facecolor=(0, 0, 0, 0), lw=2))


if __name__ == '__main__':
    # Usage
    image_path = r'./data/sentinel2.tif'
    sam_checkpoint = "./model/sam_vit_b_01ec64.pth"
    model_type = "vit_b"
    device = "cpu"
    output_shp = r'./result/segmentation_results.shp'
    # # 预测模式
    # seg_mode = 'single_point'
    # # # 模型参数
    # input_points = [[1248, 1507]]
    # single_label = [1]

    # # 预测模式
    # seg_mode = 'multi_point'
    # # 模型参数
    # input_points = [[389, 1041],[411, 1094]]
    # single_label = [1, 1]

    # 预测模式
    seg_mode = 'box'
    # 模型参数
    input_box = [[0, 951, 1909, 2383]]
    single_label = [1]


    # 实例化类
    segmenter = ImageSegmentation(image_path, sam_checkpoint, model_type, device)

    # # 调用segAnything模型
    # masks, scores, x_off, y_off = segmenter.predict(mode=seg_mode, input_points=input_points,
    #                                                     input_labels=single_label, multimask_output=False)
    # box
    masks, scores, x_off, y_off = segmenter.predict(mode=seg_mode, input_box=input_box,
                                                    input_labels=single_label, multimask_output=True)

    # 模型预测结果转矢量多边形
    polygons = segmenter.masks_to_polygons(masks, x_off, y_off)

    # 保存为shp
    segmenter.save_polygons_gdal(polygons, output_shp)

    # 可视化
    image_chunk = segmenter.read_image_chunk(x_off, y_off, 512, 512)
    # segmenter.show_masks(seg_mode, masks, scores, x_off, y_off, input_points, single_label, image_chunk)
    # box
    segmenter.show_masks(seg_mode, masks, scores, x_off, y_off, input_box, single_label, image_chunk)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2155537.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Selenium4.0实现自动搜索功能

01.Selenium4.0实现搜索功能 1.安装Selenium及查看Selenium版本 pip install selenium pip show seleniumfrom selenium import webdriver from chromedriver_py import binary_path import time from selenium.webdriver.common.by import By from selenium.webdriver.commo…

postman控制变量和常用方法

1、添加环境: 2、环境添加变量: 3、配置不同的环境:local、dev、sit、uat、pro 4、 接口调用 5、清除cookie方法: 6、下载文件方法:

SSM+vue音乐播放器管理系统

音乐播放器管理系统 随着社会的发展,计算机的优势和普及使得音乐播放器管理系统的开发成为必需。音乐播放器管理系统主要是借助计算机,通过对首页、音乐推荐、付费音乐、论坛信息、个人中心、后台管理等信息进行管理。减少管理员的工作,同时…

c# 线程等待变量的值符合条件

在C#中,如果你想让一个线程等待直到某个变量的值满足特定条件,你可以使用ManualResetEvent或者AutoResetEvent来实现线程间的同步。以下是使用AutoResetEvent实现的一个简单例子: 在这个例子中,同时实现了如何让static函数访问非…

高等数学 3.6 函数图像的描绘

利用导数描绘函数图形的一般步骤如下: (1)确定函数 y f ( x ) y f(x) yf(x) 的定义域及函数所具有的某些特性(如奇偶性、周期性等),并求出函数的一阶导数 f ′ ( x ) f^{}(x) f′(x) 和二阶导数 f ′ …

【YOLO目标检测手势识别数据集】共55952张、已标注txt格式、有训练好的yolov5的模型

目录 说明图片示例 说明 数据集格式:YOLO格式 图片数量:55952 标注数量(txt文件个数):55952 标注类别数:7 标注类别名称: one two three four five good ok 数据集下载:手势识别数据集 图片示例 数…

【CSS】伪类选择器 :root 声明全局CSS变量

语法 <style>:root {/* ... */} </style>这个 CSS 伪类 :root 匹配文档树的根元素&#xff0c;表示选中 <html> 元素&#xff0c;除了优先级更高之外&#xff0c;与 html 选择器相同 <style>/* 选中文档的根元素&#xff08;HTML 中的 <html>&a…

制造业的智能化革命:工业物联网(IIoT)的优势、层级应用及挑战解析

在全球制造业的蓬勃发展中&#xff0c;工业物联网&#xff08;IIoT&#xff09;作为一股颠覆性力量&#xff0c;正逐步重塑传统制造业的面貌。IIoT技术通过无缝连接设备、系统与人员&#xff0c;促进了数据的即时流通与处理&#xff0c;不仅极大地提升了制造效率&#xff0c;还…

pikachu XXE(XML外部实体注入)通关

靶场&#xff1a;pikachu 环境: 系统&#xff1a;Windows10 服务器&#xff1a;PHPstudy2018 靶场&#xff1a;pikachu 关卡提示说&#xff1a;这是一个接收xml数据的api 常用的Payload 回显 <?xml version"1.0"?> <!DOCTYPE foo [ <!ENTITY …

【Godot4.3】基于状态切换的游戏元素概论

提示 本文的设想性质比较大,只是探讨一种设计思路。完全理论阶段&#xff0c;不可行就当是闹了个笑话O(∩_∩)O哈哈~但很符合我瞎搞的气质。 概述 一些游戏元素&#xff0c;其实是拥有多个状态的。比如一个宝箱&#xff0c;有打开和关闭两个状态。那么只需要设定两个状态的图…

演示:基于WPF的DrawingVisual开发的Chart图表和表格绘制

一、目的&#xff1a;基于WPF的DrawingVisual开发的Chart图表和表格绘制 二、预览 钻井井轨迹表格数据演示示例&#xff08;应用Table布局&#xff0c;模拟井轨迹深度的绘制&#xff09; 饼图表格数据演示示例&#xff08;应用Table布局&#xff0c;模拟多个饼状图组合显示&am…

OpenCV_图像膨胀腐蚀与形态学操作及具体应用详解

在本教程中&#xff0c;您将学习如何&#xff1a; 应用两个非常常见的形态运算符&#xff1a;腐蚀和膨胀&#xff1a; cv::erodecv::dilate 使用OpenCV函数cv :: morphologyEx应用形态转换&#xff0c;如&#xff1a; 开运算闭运算形态学梯度顶帽运算黑帽运算 形态作业 简…

算法-环形链表(141)

这道题其实是一个非常经典的快慢指针的问题 &#xff0c;也成为Floyd的乌龟和兔子算法。 设置两个指针&#xff0c;一个快指针&#xff0c;一个满指针&#xff0c;都从头节点开始遍历&#xff0c;如果链表中存在环&#xff0c;那么快指针最终会在环内某个节点相遇&#xff0c;…

Linux网络工具:用于查询DNS(域名系统)域名解析信息的命令nslookup详解

目录 一、概述 二、基本功能 1、查询域名对应的IP地址 2、查询IP地址对应的主机名 3、查询特定类型的DNS记录 三、用法 1、命令格式 2、常用选项 五、nslookup的安装 1. 打开终端 2. 更新的系统包列表 3. 安装 bind-utils 软件包 &#xff08;1&#xff09;对于Ce…

树和二叉树的概念以及结构

一起加油学数据结构 目录 树的概念以及结构 树的概念 树的相关概念 树的表示 二叉树的概念以及结构 二叉树的概念 特殊的二叉树 二叉树的性质 二叉树的存储结构 树的概念以及结构 树的概念 树是一种非线性的数据结构&#xff0c;它是由n&#xff08;n>0&#xff09…

【Elasticsearch系列十九】评分机制详解

&#x1f49d;&#x1f49d;&#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

Mapper核心配置文件

文章目录 environment 数据库环境typeAlias 起别名 environment 数据库环境 typeAlias 起别名

【QGIS入门实战精品教程】6.2:QGIS选择要素的多种方法

本文讲解QGIS中选择要素的多种方法。 文章目录 一、选择要素二、多边形选择三、自由手绘四、按半径选择五、按值选择要素六、按表达式选择在QGIS中,选择要素有多种方法,如下所示: 下面举例说明。 一、选择要素 可以直接点选、框选实现单个或者多个点线面要素的选择(按住C…

【计算机网络 - 基础问题】每日 3 题(十八)

✍个人博客&#xff1a;Pandaconda-CSDN博客 &#x1f4e3;专栏地址&#xff1a;http://t.csdnimg.cn/fYaBd &#x1f4da;专栏简介&#xff1a;在这个专栏中&#xff0c;我将会分享 C 面试中常见的面试题给大家~ ❤️如果有收获的话&#xff0c;欢迎点赞&#x1f44d;收藏&…