区域识别——基于python语言

news2024/11/23 11:47:23

目录

目录

1.水域识别

2.模型介绍

3.文件框架

4.代码示例

4.1 data_preprocess.py

4.2 model1.0.py

4.3 train2.0.py

4.4 predict.py

4.5 运行结果

5.总结


1.水域识别

人眼看见河道可以直接分辨出这是河道,但是如何让计算机也能识别出这是河道呢?

这里就用到了深度学习之图像处理技术。

想要识别河道中的河水,对拍摄的河道图片进行编辑,可以创建两个图片文件的集合,一个文件中存放具体的河道图片,另一个存放带有标注的河道水域的分割掩码(标签),然后进行模型训练,训练一个可以识别河道水域的深度学习模型。

原始图像

掩码图像

2.模型介绍

使用深度学习框架TensorFlow和Keras进行图像分割任务。

模型的训练首先要进行数据处理,然后再进行模型训练。

3.文件框架

 data2.0中的三个数据文件夹分别是训练集,验证集和测试集的图片。

data_preprocess.py是图片预处理代码。

model1.0.py是定义了一个基于U-Net架构的卷积神经网络模型,用来做图像分割任务。

train2.0.py是训练程序,最重要的是要运行这个。

predict.py是调用训练后的模型的代码。

4.代码示例

4.1 data_preprocess.py

import os
import cv2
import numpy as np
from tensorflow.keras.preprocessing.image import img_to_array, load_img

# 图像大小和通道数
IMG_HEIGHT = 256
IMG_WIDTH = 256
IMG_CHANNELS = 3

# 路径设置
image_dir = './River Water Segmentation/data2.0/train/images'
mask_dir = './River Water Segmentation/data2.0/train/masks'

def preprocess_images(image_dir, target_size=(IMG_HEIGHT, IMG_WIDTH)):
    images = []
    for filename in os.listdir(image_dir):
        img_path = os.path.join(image_dir, filename)
        img = load_img(img_path, target_size=target_size)
        img_array = img_to_array(img)
        images.append(img_array)
    return np.array(images)

def preprocess_masks(mask_dir, target_size=(IMG_HEIGHT, IMG_WIDTH)):
    masks = []
    for filename in os.listdir(mask_dir):
        mask_path = os.path.join(mask_dir, filename)
        mask = load_img(mask_path, target_size=target_size, color_mode="grayscale")
        mask_array = img_to_array(mask) / 255.0  # 将掩码值标准化为0和1
        masks.append(mask_array)
    return np.array(masks)

if __name__ == "__main__":
    images = preprocess_images(image_dir)
    masks = preprocess_masks(mask_dir)
    print(f"Processed {len(images)} images and {len(masks)} masks.")

4.2 model1.0.py

import tensorflow as tf
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Conv2DTranspose, concatenate, Input, Dropout
from tensorflow.keras.models import Model

def build_unet(input_shape):
    inputs = Input(input_shape)
    
    # 编码器部分 (下采样)
    c1 = Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(inputs)
    c1 = Dropout(0.1)(c1)
    c1 = Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c1)
    p1 = MaxPooling2D((2, 2))(c1)

    c2 = Conv2D(128, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(p1)
    c2 = Dropout(0.1)(c2)
    c2 = Conv2D(128, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c2)
    p2 = MaxPooling2D((2, 2))(c2)

    c3 = Conv2D(256, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(p2)
    c3 = Dropout(0.2)(c3)
    c3 = Conv2D(256, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c3)
    p3 = MaxPooling2D((2, 2))(c3)

    c4 = Conv2D(512, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(p3)
    c4 = Dropout(0.3)(c4)
    c4 = Conv2D(512, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c4)
    p4 = MaxPooling2D((2, 2))(c4)

    # 底层
    c5 = Conv2D(1024, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(p4)
    c5 = Dropout(0.4)(c5)
    c5 = Conv2D(1024, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c5)

    # 解码器部分 (上采样)
    u6 = Conv2DTranspose(512, (2, 2), strides=(2, 2), padding='same')(c5)
    u6 = concatenate([u6, c4])
    c6 = Conv2D(512, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(u6)
    c6 = Dropout(0.3)(c6)
    c6 = Conv2D(512, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c6)

    u7 = Conv2DTranspose(256, (2, 2), strides=(2, 2), padding='same')(c6)
    u7 = concatenate([u7, c3])
    c7 = Conv2D(256, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(u7)
    c7 = Dropout(0.2)(c7)
    c7 = Conv2D(256, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c7)

    u8 = Conv2DTranspose(128, (2, 2), strides=(2, 2), padding='same')(c7)
    u8 = concatenate([u8, c2])
    c8 = Conv2D(128, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(u8)
    c8 = Dropout(0.1)(c8)
    c8 = Conv2D(128, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c8)

    u9 = Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same')(c8)
    u9 = concatenate([u9, c1], axis=3)
    c9 = Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(u9)
    c9 = Dropout(0.1)(c9)
    c9 = Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c9)

    outputs = Conv2D(1, (1, 1), activation='sigmoid')(c9)  # 单通道输出用于二分类分割

    model = Model(inputs=[inputs], outputs=[outputs])
    
    return model

if __name__ == "__main__":
    # 输入形状 (256, 256, 3) 用于 RGB 图像
    model = build_unet((256, 256, 3))
    model.summary()

4.3 train2.0.py

import numpy as np
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from model import build_unet
from data_preprocess import preprocess_images, preprocess_masks

# 设置路径
train_images_path = './River Water Segmentation/data2.0/train/images'
train_masks_path = './River Water Segmentation/data2.0/train/masks'
val_images_path = './River Water Segmentation/data2.0/train/images'
val_masks_path = './River Water Segmentation/data2.0/train/masks'

# 加载数据
train_images = preprocess_images(train_images_path)
train_masks = preprocess_masks(train_masks_path)
val_images = preprocess_images(val_images_path)
val_masks = preprocess_masks(val_masks_path)

# 数据增强
data_gen_args = dict(rotation_range=0.2,
                     width_shift_range=0.05,
                     height_shift_range=0.05,
                     shear_range=0.05,
                     zoom_range=0.05,
                     horizontal_flip=True,
                     fill_mode='nearest')

image_datagen = ImageDataGenerator(**data_gen_args)
mask_datagen = ImageDataGenerator(**data_gen_args)

image_datagen.fit(train_images)
mask_datagen.fit(train_masks)


# 编译模型
model.compile(optimizer=Adam(learning_rate=1e-4), loss='binary_crossentropy', metrics=['accuracy'])

# 设置回调函数
checkpoint = ModelCheckpoint('models/unet_model.h5', save_best_only=True, monitor='val_loss', mode='min')
early_stopping = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=5, min_lr=1e-6)

# 训练模型
history = model.fit(image_datagen.flow(train_images, train_masks, batch_size=16),
                    validation_data=(val_images, val_masks), 
                    epochs=50, 
                    callbacks=[checkpoint, early_stopping, reduce_lr])

# 保存最终模型
model.save('./River Water Segmentation/models/6.0model.h5')

4.4 predict.py

import numpy as np
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image
import matplotlib.pyplot as plt
import cv2

# 加载模型
model = load_model('./River Water Segmentation/models/unet_final_model.h5')

# 预测函数
def predict_image(image_path):
    img = image.load_img(image_path, target_size=(256, 256))
    img_array = image.img_to_array(img) / 255.0
    img_array = np.expand_dims(img_array, axis=0)
    
    # 预测掩码
    pred_mask = model.predict(img_array)
    pred_mask = (pred_mask > 0.5).astype(np.uint8)  # 二值化处理
    
    # 显示原图和预测掩码
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.imshow(img)
    plt.title('Original')
    
    plt.subplot(1, 2, 2)
    plt.imshow(pred_mask[0, :, :, 0], cmap='gray')
    plt.title('Predicted')
    plt.show()

# 示例预测
predict_image('./River Water Segmentation/data/test/river/water_body_27.jpg')

4.5 运行结果

5.总结

模型主要用到了U-net深度学习模型,它通过编码器-解码器结构,能够有效地捕捉图像的上下文信息并进行精细的分割。

关于此例程可以考虑用在其他的图像分割对象上,比如森林区域识别,沙漠区域识别,改进代码可以进一步提高准确率,如果需要帮助或者合作,请私信我。

还请转载文章或者使用代码请备注来源呀!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2209266.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

无序中的秩序:Transformer模型的创新性结构解析

最近我在看论文时,发现很多都在用 Transformer 模型,我知道transformer很有名,但是我也只是听说过他的大名,不知道他具体是做什么怎么做的,因此我决定深入了解一下,并做个简单记录,方便以后参考…

TDC上YARN Web-UI 查看application日志方法

方法一 #通过浏览器访问tdc,访问的工作节点对于TDC都是外部节点。在提交给yarn任务后,YarnRM的Web UI 可以展示yarnnm上运行的application日志,但是由于跳转的svc地址,无法直接访问。 #在tdc界面上找到yarn实例,进入ya…

【Scala入门学习】集合常用方法和函数操作

1. foreach循环遍历 foreach 方法的原型: // f 返回的类型是Unit, foreach 返回的类型是Unit def foreach[U](f: Elem > U) 该方法接受一个函数 f 作为参数, 函数 f 的类型为Elem > U,即 f 接受一个参数,参数…

达梦数据库(DM)单机典型安装

达梦数据库(DM)单机典型安装 环境:centos7.6 1、创建用户 #增加用户和组,用于安装管理达梦数据库。 新建用户组:groupadd dinstall 新建用户:指定用户组,家目录,shell。useradd -g…

反转链表解题思路

题目描述 给定一个单链表的头结点pHead,长度为n,反转该链表后,返回新链表的表头。 示例:当输入链表{1,2,3}时,经反转后,原链表变为{3,2,1},所以对应的输出为{3,2,1}。 解题思路:迭…

股市大涨下的会展业创新者

近期,股市涨势强劲有力,各大指数普遍上扬,市场活力空前。与此同时,伴随全球经济逐步复苏及会展行业不断发展,上市展览公司机遇与挑战并存。国内外市场需求持续增长拓展了广阔发展空间,但同时行业竞争愈发激…

中国宏观经济与产业发展:挑战与机遇并存

#长沙屿# 在复杂多变的国内外经济形势之下,中国经济已然步入一个至关重要的发展阶段。今日,让我们深入剖析当前经济形势,对中国宏观经济的运行现状及产业发展的趋势展开深度探讨。 2024年,中国经济运行总体平稳、稳中有进&#x…

职场启悟:没有靠山的你,45岁前必知的5大潜规则

我是农村孩子,父母都是农民,毕业一切都是靠着自己找工作,在陌生的大城市除了认识老师就是同学。记得那是我初入职场的第三个月,每天我都沉浸在无尽的工作中,加班到深夜,周末也时常无休。我觉的农村孩子只能…

Chainbase :链原生的 Web3 AI 基建设施

“随着 Chainbase 在生态系统和市场方面的进一步拓展,其作为链原生 Web3 AI 基建设施的价值将愈发显著。” 算法、算力和数据是 AI 技术的三大核心要素。实际上,几乎所有的 AI 大模型都在不断革新算法,以确保模型能够跟上行业的发展趋势&…

机器学习中的模型设计与训练流程详解

目录 前言1. 模型设计1.1 数据特性分析1.2 计算资源限制1.3 应用场景需求 2. 模型训练2.1 训练集与验证集的划分2.2 损失函数的选择2.3 模型参数更新 3. 优化方法3.1 梯度下降法3.2 正则化方法 4. 模型测试4.1 性能评估指标4.2 模型的泛化能力 5. 模型选择5.1 数据规模与模型复…

怎么提取人声去掉背景音乐?人声提取秘籍:去掉背景音乐的技巧

在数字化时代,音频处理变得越来越普遍,我们经常需要从一段音频或视频中提取出纯净的人声,而去除掉背景音乐或其他杂音。这种需求在视频编辑、音乐制作、甚至在学习和娱乐中都十分常见。本文将介绍几种简单易行的方法,帮助你轻松提…

【Spring】获取 Cookie和Session

回顾 Cookie HTTP 协议自身是属于“无状态”协议 无状态:默认情况下,HTTP 协议的客户端和服务器之间的这次通信和下次通信之间没有直接的联系 但是在实际开发中,我们很多时候是需要知道请求之间的关联关系的 例如登录网站成功后&#xff…

抖音小游戏画图位置移动

文章目录 画图移动图形位置 画图 const canvas tt.createCanvas(); const context canvas.getContext(2d);context.width 500; context.height 500;let isPressing false; // 是否按下 let startX 0; let startY 0;context.fillStyle "#f00"; context.fillR…

@zabbix监控网站黑链接监控及数据推送

zabbix监控网站黑链接及数据推送 文章目录 zabbix监控网站黑链接及数据推送1.检测脚本1》编写python脚本2》脚本执行 2.数据推送1》方案一2》方案二 3.zabbix web 1.检测脚本 1》编写python脚本 创建脚本check_black_links.py,使用python脚本实现网站黑链接检测&a…

93. 复原 IP 地址【回溯】

文章目录 93. 复原 IP 地址解题思路Go代码 93. 复原 IP 地址 93. 复原 IP 地址 有效 IP 地址 正好由四个整数(每个整数位于 0 到 255 之间组成,且不能含有前导 0),整数之间用 . 分隔。 例如:"0.1.2.201" …

Datawhale组队学习|全球AI攻防挑战赛——赛道二:AI核身之金融场景凭证篡改检测

目录 前言Baseline代码解读 前言 Datawhale 2024.10 组队学习来了!这次选择的是动手实践专区——CV方向——“全球AI攻防挑战赛—赛道二:AI核身之金融场景凭证篡改检测”。 Baseline代码解读 1、读取数据集 !apt update > /dev/null; apt install…

美团测试面试真题学习

美团真题1–测试基础-业务场景说下你的测试用例设计 功能角度 方法论 边界值、等价类划分、错误推测法示例 输入已注册的用户名和正确的密码,验证是否登录成功;输入已注册的用户名和不正确的密码,验证是否登录失败输入未注册的用户名和任意密码&#xff…

Win10自带录屏神器?这4款工具让你秒变剪辑达人!

小伙伴们,随着电子设备使用率越来越高,日常工作中我们需要进行一些操作的演示,或者是游戏中精彩的瞬间都希望录下来,那就少不了好用的录屏工具了。这次我来跟大家聊聊Windows 10自带的那些让人惊艳的录屏工具。这不仅仅是我个人推…

Halcon 3D应用 - 胶路提取

1. 需求 本文基于某手环(拆机打磨处理)做的验证性工作,为了项目保密性,只截取部分数据进行测试。 这里使用的是海康3D线激光轮廓相机直线电机的方式进行的高度数据采集,我们拿到的是高度图亮度图数据。 提取手环上的胶…

IBM Flex System服务器硬件监控指标解读

随着企业IT架构的日益复杂,服务器的稳定运行对于保障业务连续性至关重要。IBM Flex System作为一款模块化、可扩展的服务器解决方案,广泛应用于各种企业级环境中。为了确保IBM Flex System服务器的稳定运行,监控易作为一款专业的IT基础设施监…