CV(11)-图像分割

news2025/2/8 9:40:33

前言

仅记录学习过程,有问题欢迎讨论

图像分割

  • 语义分割不需要区分具体的个体,实例分割需要

反卷积/转置卷积:

  • 它并不是正向卷积的完全逆过程。反卷积是一种特殊的正向卷积,先按照一定的比例通过补0
    来扩大输入图像的尺寸,接着旋转卷积核,再进行正向卷积。只能还原原图的尺寸,还可提升图像精度。

  • 缺点:输出大量无用信息(添0);计算比较消耗资源

语义分割– FCN (生成像素级预测,用于实例分割)

  • FCN将传统卷积网络后面的全连接层换成了卷积层,这样网络输出不再是类别而是heatmap;
    同时为了解决因为卷积和池化对图像尺寸的影响,提出使用上采样的方式恢复尺寸

  • 对图像进行像素级的分类,在上采样的特征图上进行逐像素分类

  • 增大数据尺寸的反卷积(deconv)层。能够输出精细的结果(保持一定精度)

实例分割– Mask R-CNN

  • 需要同时检测出目标的位置并且对目标进行分割,目标检测+语义分割

MASK-RCNN

与Faster RCNN的区别:

1)使用ResNet网络作为backbone
2)将 Roi Pooling 层替换成了 RoiAlign;(pooling会有误差,反卷积后误差会很大,所以要替换)

  • RoiAlign使用线性插值代替取整操作,固定像素点,使得精度提升

3)添加并列的 Mask 层;

  • 添加掩膜,分类卷积,通过RoiAlign的结果获取分类结果

4)引入FPN 和 FCN

  • FPN:提取多尺度特征( 生成特征金字塔包含多个尺度的特征图),提升目标检测性能。
  • FCN:生成像素级预测,用于实例分割

实现Mask-RCNN网络结构


import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import numpy as np
import cv2

# 定义骨干网络,这里使用 ResNet
class ResNetBackbone(nn.Module):
    def __init__(self):
        super(ResNetBackbone, self).__init__()
        resnet = torchvision.models.resnet50(pretrained=True)
        self.features = nn.Sequential(*list(resnet.children())[:-2])

    def forward(self, x):
        x = self.features(x)
        return x

# 区域生成网络 (RPN)
class RPN(nn.Module):
    def __init__(self, in_channels, num_anchors):
        super(RPN, self).__init__()
        self.conv = nn.Conv2d(in_channels, 512, kernel_size=3, stride=1, padding=1)
        self.cls_layer = nn.Conv2d(512, num_anchors * 2, kernel_size=1, stride=1)
        self.reg_layer = nn.Conv2d(512, num_anchors * 4, kernel_size=1, stride=1)

    def forward(self, x):
        x = F.relu(self.conv(x))
        cls_scores = self.cls_layer(x)
        bbox_preds = self.reg_layer(x)
        cls_scores = cls_scores.permute(0, 2, 3, 1).contiguous().view(x.size(0), -1, 2)
        bbox_preds = bbox_preds.permute(0, 2, 3, 1).contiguous().view(x.size(0), -1, 4)
        return cls_scores, bbox_preds

# RoI Align 层
class RoIAlign(nn.Module):
    def __init__(self, output_size):
        super(RoIAlign, self).__init__()
        self.output_size = output_size

    def forward(self, features, rois):
        roi_features = []
        for i in range(features.size(0)):
            roi = rois[i]
            roi_feature = torchvision.ops.roi_align(features[i].unsqueeze(0), [roi], self.output_size)
            roi_features.append(roi_feature)
        roi_features = torch.cat(roi_features, dim=0)
        return roi_features

# Mask 分支
class MaskBranch(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(MaskBranch, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 256, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.deconv = nn.ConvTranspose2d(256, 256, kernel_size=2, stride=2)
        self.mask_layer = nn.Conv2d(256, num_classes, kernel_size=1, stride=1)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = F.relu(self.deconv(x))
        mask_preds = self.mask_layer(x)
        return mask_preds

# Mask R-CNN 模型
class MaskRCNN(nn.Module):
    def __init__(self, num_classes):
        super(MaskRCNN, self).__init__()
        self.backbone = ResNetBackbone()
        self.rpn = RPN(2048, 9)  # 假设使用 9 个锚点
        self.roi_align = RoIAlign((14, 14))  # RoI Align 到 14x14
        self.fc1 = nn.Linear(2048 * 14 * 14, 1024)
        self.fc2 = nn.Linear(1024, 1024)
        self.cls_layer = nn.Linear(1024, num_classes)
        self.reg_layer = nn.Linear(1024, num_classes * 4)
        self.mask_branch = MaskBranch(2048, num_classes)

    def forward(self, x, rois=None):
        features = self.backbone(x)
        cls_scores, bbox_preds = self.rpn(features)
        if rois is not None:
            roi_features = self.roi_align(features, rois)
            roi_features_fc = roi_features.view(roi_features.size(0), -1)
            fc1 = F.relu(self.fc1(roi_features_fc))
            fc2 = F.relu(self.fc2(fc1))
            cls_preds = self.cls_layer(fc2)
            reg_preds = self.reg_layer(fc2)
            mask_preds = self.mask_branch(roi_features)
            return cls_preds, reg_preds, mask_preds, cls_scores, bbox_preds
        else:
            return cls_scores, bbox_preds

# 自定义数据集类
class CustomDataset(Dataset):
    def __init__(self, image_paths, target_paths, transform=None):
        self.image_paths = image_paths
        self.target_paths = target_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = cv2.imread(self.image_paths[idx])
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        target = np.load(self.target_paths[idx], allow_pickle=True)
        if self.transform:
            image = self.transform(image)
        return image, target

# 数据预处理
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# 训练函数
def train(model, dataloader, optimizer, criterion_cls, criterion_reg, criterion_mask):
    model.train()
    total_loss = 0
    for images, targets in dataloader:
        images = images.to(device)
        targets = [t.to(device) for t in targets]
        optimizer.zero_grad()
        cls_preds, reg_preds, mask_preds, cls_scores, bbox_preds = model(images, targets)
        # 计算分类、回归和掩码损失
        cls_loss = criterion_cls(cls_preds, targets)
        reg_loss = criterion_reg(reg_preds, targets)
        mask_loss = criterion_mask(mask_preds, targets)
        loss = cls_loss + reg_loss + mask_loss
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(dataloader)

# 评估函数
def evaluate(model, dataloader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, targets in dataloader:
            images = images.to(device)
            targets = [t.to(device) for t in targets]
            cls_preds, reg_preds, mask_preds, _, _ = model(images)
            # 计算评估指标,这里可根据具体需求实现
            # 例如计算 mAP 等
    return correct / total

if __name__ == "__main__":
    # 假设的图像和标注文件路径
    image_paths = ['img/street.jpg', 'img/street.jpg']
    target_paths = ['target1.npy', 'target2.npy']
    dataset = CustomDataset(image_paths, target_paths, transform)
    dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    num_classes = 2  # 包括背景类
    model = MaskRCNN(num_classes).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    criterion_cls = nn.CrossEntropyLoss()
    criterion_reg = nn.SmoothL1Loss()
    criterion_mask = nn.BCEWithLogitsLoss()  # 用于掩码的损失函数
    num_epochs = 10
    for epoch in range(num_epochs):
        loss = train(model, dataloader, optimizer, criterion_cls, criterion_reg, criterion_mask)
        print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {loss}')
    # 评估
    accuracy = evaluate(model, dataloader)
    print(f'Accuracy: {accuracy}')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2294717.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【STM32系列】利用MATLAB配合ARM-DSP库设计FIR数字滤波器(保姆级教程)

ps.源码放在最后面 设计IIR数字滤波器可以看这里:利用MATLAB配合ARM-DSP库设计IIR数字滤波器(保姆级教程) 前言 本篇文章将介绍如何利用MATLAB与STM32的ARM-DSP库相结合,简明易懂地实现FIR低通滤波器的设计与应用。文章重点不在…

STM32上部署AI的两个实用软件——Nanoedge AI Studio和STM32Cube AI

1 引言 STM32 微控制器在嵌入式领域应用广泛,因为它性能不错、功耗低,还有丰富的外设,像工业控制、智能家居、物联网这些场景都能看到它的身影。与此同时,人工智能技术发展迅速,也逐渐融入各个行业。 把 AI 部署到 STM…

qt+gstreamer快速创建一个流媒体播放器

目录 1 前言 2 playbin3 3 videooverlay 4 关键代码 5 运行示例 1 前言 最近因为工作需求,要实现一个桌面流媒体播放器来支持常见的流媒体协议,经过调研发现使用gstreamer配合一些桌面级的gui应用开发工具如qt可以进行快速实现,在此进…

DeepSeek V2报告阅读

概况 MoE架构,236B参数,每个token激活参数21B,支持128K上下文。采用了包括多头潜在注意力(MLA)和DeepSeekMoE在内的创新架构。MLA通过将KV缓存显著压缩成潜在向量来保证高效的推理,而DeepSeekMoE通过稀疏计…

零基础Vue入门6——Vue router

本节重点: 路由定义路由跳转 前面几节学习的都是单页面的功能(都在专栏里面https://blog.csdn.net/zhanggongzichu/category_12883540.html),涉及到项目研发都是有很多页面的,这里就需要用到路由(vue route…

关于JS继承的七种方式和理解

1.原型链继承 function Fun1() {this.name parentthis.play [1, 2, 3] } function Fun2() {this.type child }Fun2.prototype new Fun1()let s1 new Fun2() let s2 new Fun2() s1.play.push(4) console.log(s1.play, s2.play) // [1, 2, 3, 4] [1, 2, 3, 4]可以看到两个…

【Vue】在Vue3中使用Echarts的示例 两种方法

文章目录 方法一template渲染部分js部分方法一实现效果 方法二template部分js or ts部分方法二实现效果 贴个地址~ Apache ECharts官网地址 Apache ECharts示例地址 官网有的时候示例显示不出来,属于正常现象,多进几次就行 开始使用前,记得先…

每日Attention学习18——Grouped Attention Gate

模块出处 [ICLR 25 Submission] [link] UltraLightUNet: Rethinking U-shaped Network with Multi-kernel Lightweight Convolutions for Medical Image Segmentation 模块名称 Grouped Attention Gate (GAG) 模块作用 轻量特征融合 模块结构 模块特点 特征融合前使用Group…

124,【8】buuctf web [极客大挑战 2019] Http

进入靶场 查看源码 点击 与url有关,抓包 over

源路由 | 源路由网桥 / 生成树网桥

注:本文为 “源路由” 相关文章合辑。 未整理去重。 什么是源路由(source routing)? yzx99 于 2021-02-23 09:45:51 发布 考虑到一个网络节点 A 从路由器 R1 出发,可以经过两台路由器 R2、R3,到达相同的…

FPGA的IP核接口引脚含义-快解

疑问 手册繁琐,怎样快速了解IP核各输入输出接口引脚的含义。 答疑 不慌不慌,手册确实比较详细但繁琐,如何快速知晓该部分信息,涛tao道长给你们说,简单得很,一般新入门的道友有所不知,往往后面…

Qwen2-VL-2B-Instruct 模型 RK3576 板端部署过程

需要先在电脑上运行 RKLLM-Toolkit 工具,将训练好的模型转换为 RKLLM 格式的模型,然后使用 RKLLM C API 在开发板上进行推理。 在安装前先查看板端的内存容量,和自己模型占用大小比较一下,别安装编译好了不能用。 这里我就是先尝试…

如何设计光耦电路

光耦长这样,相信小伙伴们都见过,下图是最为常用的型号PC817 怎么用?我们先看图,如下图1: Vin为输入信号,一般接MCU的GPIO口,由于这里的VCC1为3.3V,故MCU这边的供电电源不能超过3.3V…

ADC模数转换器概念函数及应用

ADC模数转换器概念函数及应用 文章目录 ADC模数转换器概念函数及应用1.ADC简介2.逐次逼近型ADC2.1逐次逼近型ADC2.2stm32逐次逼近型2.3ADC基本结构2.4十六个通道 3.规则组的4种转换模式3.1单次转换,非扫描模式3.2连续转换,非扫描模式3.3单次转换&#xf…

DFX(Design for eXcellence)架构设计全解析:理论、实战、案例与面试指南*

一、什么是 DFX ?为什么重要? DFX(Design for eXcellence,卓越设计)是一种面向产品全生命周期的设计理念,旨在确保产品在设计阶段就具备**良好的制造性(DFM)、可测试性(…

【LeetCode】152、乘积最大子数组

【LeetCode】152、乘积最大子数组 文章目录 一、dp1.1 dp1.2 简化代码 二、多语言解法 一、dp 1.1 dp 从前向后遍历, 当遍历到 nums[i] 时, 有如下三种情况 能得到最大值: 只使用 nums[i], 例如 [0.1, 0.3, 0.2, 100] 则 [100] 是最大值使用 max(nums[0…i-1]) * nums[i], 例…

《云夹:让书签管理变得轻松又高效》

在当今数字化的生活与工作场景中,我们畅游于网络的浩瀚海洋,每天都会邂逅各式各样有价值的网页内容。而如何妥善管理这些如繁星般的书签,使其能在我们需要时迅速被找到,已然成为众多网络使用者关注的焦点。云夹,作为一…

Microsoft Fabric - 尝试一下在pipeline中发送请求给web api(获取数据和更新数据)

1.简单介绍 Microsoft Fabric中的Pipeline支持很多种activity,分成数据转换和控制流两种类型的activitly。 这边将尝试一下发送web请求的activity,要做成的pipeline大概如下图所示, 上图中有4个Activity,作用如下 Web - 从一个…

数据完整性与约束的分类

一、引言 为什么需要约束?为了保证数据的完整性。 (1)数据完整性 数据完整性指的是数据的精确性和可靠性。 为了保证数据的完整性,SQL对表数据进行额外的条件限制,从以下四方面考虑: ①实体完整性&…

docker安装nacos2.x

本文为单机模式,非集群教程,埋坑 nacos2.x官方强制条件 64 bit OS,支持 Linux/Unix/Mac/Windows,推荐选用 Linux/Unix/Mac。 64 bit JDK 1.8 Maven 3.2.x 环境介绍 centos 7 maven 3.9.9 jdk 17 nacos 2.3.1 1. 拉取docker镜像 d…