pytorch U²-Net教程

news2024/9/23 6:35:43

U²-Net (U2-Net) 是一个用于图像分割的神经网络模型,特别擅长于边界复杂的物体分割任务,如前景背景分割和抠图。U²-Net 的独特之处在于其 U 形结构和嵌套 U 形块,能够有效捕捉不同尺度的特征,同时保持较小的模型大小。它非常适合在资源受限的环境下使用。

官方文档链接

U²-Net 本身并没有一个独立的 Python 库,但可以通过 官方 GitHub 仓库 获取源码和模型细节。


一、U²-Net 架构概述

U²-Net 是基于 U-Net 结构的改进模型,由多个嵌套的 U 形编码器-解码器模块组成。其创新点在于 U2 模块,它在不同尺度上提取特征,增强了对边界信息的捕捉能力。

U²-Net 结构包含:

  1. 编码器(Encoder):使用多尺度卷积核提取图像的特征,逐渐压缩特征图尺寸。
  2. 解码器(Decoder):通过逐步上采样,恢复原始分辨率,同时结合编码器的跳跃连接。
  3. U2 模块:嵌套的 U 形块,能够同时处理不同分辨率的特征,从而保留高分辨率的局部细节和低分辨率的全局语义信息。

二、基础功能

在 U²-Net 中,通常的工作流程是加载预训练模型并对输入图像进行分割。U²-Net 最常见的任务是图像前景提取,比如抠图。

1. 加载 U²-Net 模型

从官方 GitHub 下载预训练模型权重,并通过 PyTorch 加载。

import torch
import torchvision.transforms as transforms
from PIL import Image
import numpy as np

# 加载预训练的 U²-Net 模型
model = torch.load('u2net.pth')
model.eval()  # 设置为评估模式

# 准备图像输入
def load_image(image_path):
    transform = transforms.Compose([
        transforms.Resize((320, 320)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    image = Image.open(image_path).convert('RGB')
    image = transform(image).unsqueeze(0)
    return image

# 加载图片并转换为张量
input_image = load_image("input_image.jpg")

# 前向传播,生成分割结果
with torch.no_grad():
    result = model(input_image)

2. 处理模型输出

U²-Net 的输出通常为前景掩码 (mask),可以通过阈值处理生成二值化图像。

def process_output(output):
    # 提取前景掩码
    mask = output[0][0].squeeze().cpu().numpy()
    
    # 归一化到0-1范围
    mask = (mask - np.min(mask)) / (np.max(mask) - np.min(mask))
    
    # 二值化处理
    mask = (mask > 0.5).astype(np.uint8)
    
    return mask

# 处理输出的前景掩码
foreground_mask = process_output(result)

三、进阶功能

1. 前景提取并保存透明 PNG

U²-Net 可以用于精细化的图像前景提取。通过将背景像素设置为透明,生成透明的 PNG 图片。

from PIL import Image

def save_foreground(image_path, mask, save_path):
    image = Image.open(image_path).convert('RGBA')
    width, height = image.size
    mask = Image.fromarray(mask * 255).resize((width, height), Image.BILINEAR)
    
    # 转换为 RGBA 格式,将背景设置为透明
    image_data = np.array(image)
    mask_data = np.array(mask)
    
    # 将背景区域的 alpha 通道设置为 0(完全透明)
    image_data[:, :, 3] = mask_data
    
    # 保存带有透明背景的 PNG 图片
    output_image = Image.fromarray(image_data)
    output_image.save(save_path)

# 使用掩码提取前景并保存
save_foreground("input_image.jpg", foreground_mask, "output_image.png")

2. 使用其他输入尺寸

虽然 U²-Net 默认是使用 320x320 的输入尺寸,但它对不同的输入尺寸有一定的适应性。我们可以根据需要调整输入图像的大小。

# 自定义输入尺寸
def load_image_custom_size(image_path, size=(320, 320)):
    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    image = Image.open(image_path).convert('RGB')
    image = transform(image).unsqueeze(0)
    return image

# 调整输入图像尺寸
custom_size_image = load_image_custom_size("input_image.jpg", size=(512, 512))

四、高级教程

U²-Net 的高级用法可以结合其他深度学习框架或任务,例如对分割结果进行进一步的图像处理或增强。

1. 与 OpenCV 结合处理分割结果

可以利用 OpenCV 对分割后的图像进行一些后处理,例如边缘检测、轮廓提取等。

import cv2

def process_with_opencv(mask):
    # 使用 OpenCV 检测轮廓
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    
    # 绘制轮廓
    contour_image = np.zeros_like(mask)
    cv2.drawContours(contour_image, contours, -1, (255), 2)
    
    return contour_image

# 使用 OpenCV 处理分割结果
contour_image = process_with_opencv(foreground_mask)
cv2.imwrite("contour_image.png", contour_image)

2. 自定义损失函数与训练

如果需要训练自己的 U²-Net 模型,可以基于 Binary Cross Entropy 损失函数进行训练。以下是一个自定义损失函数的示例。

import torch.nn as nn

class U2NetLoss(nn.Module):
    def __init__(self):
        super(U2NetLoss, self).__init__()
        self.bce_loss = nn.BCELoss()

    def forward(self, d0, d1, d2, d3, d4, d5, d6, labels):
        # 对不同尺度的预测进行加权损失计算
        loss0 = self.bce_loss(d0, labels)
        loss1 = self.bce_loss(d1, labels)
        loss2 = self.bce_loss(d2, labels)
        loss3 = self.bce_loss(d3, labels)
        loss4 = self.bce_loss(d4, labels)
        loss5 = self.bce_loss(d5, labels)
        loss6 = self.bce_loss(d6, labels)
        return loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6

3. 模型优化与推理加速

U²-Net 的推理速度在某些情况下可能是瓶颈,尤其在移动端。可以通过模型量化、剪枝或者使用推理加速库(如 TensorRT)来提高效率。


五、总结

U²-Net 是一个轻量级、功能强大的模型,专注于高质量的前景分割任务。它具有以下特点:

  1. 多尺度特征捕捉:通过 U2 模块,U²-Net 能够捕捉到不同尺度的细节,适用于精细的边缘分割任务。
  2. 易于使用:通过 PyTorch 实现,能够轻松加载预训练模型并进行推理。
  3. 适应性强:U²-Net 适用于不同分辨率的输入图像,具有良好的推广性。

如果你有更多问题或需要代码测试,请随时告诉我!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2156883.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【测试】——JUnit

📖 前言:JUnit 是一个流行的 Java 测试框架,主要用于编写和运行单元测试,用来管理测试用例。本文采用JUnit 5 目录 🕒 1. 添加依赖🕒 2. 注解🕘 2.1 Test🕘 2.2 BeforeAll AfterAll&…

【Docker】基于docker compose部署artifactory-cpp-ce服务

基于docker compose部署artifactory-cpp-ce服务 1 环境准备2 必要文件创建与编写3 拉取镜像-创建容器并后台运行4 访问JFog Artifactory 服务 1 环境准备 docker 以及其插件docker compose ,我使用的版本如下图所示: postgresql 的jdbc驱动, 我使用的是…

【图像检索】基于纹理(LBP)和形状特征的图像检索,matlab实现

博主简介:matlab图像代码项目合作(扣扣:3249726188) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 本次案例是基于纹理(LBP)和形状特征(hu特征)的图像检索,用m…

力扣206.反转链表

力扣《反转链表》系列文章目录 刷题次序,由易到难,一次刷通!!! 题目题解206. 反转链表反转链表的全部 题解192. 反转链表 II反转链表的指定段 题解224. 两两交换链表中的节点两个一组反转链表 题解325. K 个一组翻转…

【二等奖论文】2024年华为杯研究生数学建模E题成品论文获取入口

您的点赞收藏是我继续更新的最大动力! 一定要点击如下的卡片链接,那是获取资料的入口! 点击链接获取【2024华为杯研赛资料汇总】: https://qm.qq.com/q/Wgk64ntZCihttps://qm.qq.com/q/Wgk64ntZCi 详细建模思路: 要解…

C++--模板(template)详解—— 函数模板与类模板

目录 1.泛型编程 2.函数模板 2.1 函数模板概念 2.2 函数模板格式 2.3 函数模板的原理 2.4 函数模板的实例化 2.5 模板参数的匹配原则 3.类模板 3.1 类模板的定义格式 3.2 类模板的实例化 1.泛型编程 在C中如果让你写一个交换函数,应该怎么做呢&#xff1f…

二叉树进阶【c++实现】【二叉搜索树的实现】

目录 二叉树进阶1.二叉搜索树1.1二叉搜索树的实现1.1.1二叉搜索树的查找1.1.2二叉搜索树的插入1.1.3中序遍历(排序)1.1.4二叉搜索树的删除(重点) 1.2二叉搜索树的应用1.2.1K模型1.2.2KV模型 1.3二叉搜索树的性能分析 二叉树进阶 前言: map和set特性需要先铺垫二叉搜…

Python3网络爬虫开发实战(16)分布式爬虫(第一版)

文章目录 一、分布式爬虫原理1.1 分布式爬虫架构1.2 维护爬取队列1.3 怎样来去重1.4 防止中断1.5 架构实现 二、Scrapy-Redis 源码解析2.1 获取源码2.2 爬取队列2.3 去重过滤2.4 调度器 三、Scrapy 分布式实现3.1 准备工作3.2 搭建 Redis 服务器3.3 部署代理池和 Cookies 池3.4…

超越sora,最新文生视频CogVideoX-5b模型分享

CogVideoX-5B是由智谱 AI 开源的一款先进的文本到视频生成模型,它是 CogVideoX 系列中的更大尺寸版本,旨在提供更高质量的视频生成效果。 CogVideoX-5B 采用了 3D 因果变分自编码器(3D causal VAE)技术,通过在空间和时…

【OpenAI o1背后技术】Sef-play RL:LLM通过博弈实现进化

【OpenAI o1背后技术】Sef-play RL:LLM通过博弈实现进化 OpenAI o1是经过强化学习训练来执行复杂推理任务的新型语言模型。特点就是,o1在回答之前会思考——它可以在响应用户之前产生一个很长的内部思维链。也就是该模型在作出反应之前,需要…

简单题104. 二叉树的最大深度 (python)20240922

问题描述: python: # Definition for a binary tree node. # class TreeNode(object): # def __init__(self, val0, leftNone, rightNone): # self.val val # self.left left # self.right right class Solution(object…

Python 入门(一、使用 VSCode 开发 Python 环境搭建)

Python 入门第一课 ,环境搭建...... by 矜辰所致前言 现在不会 Python ,好像不那么合适,咱先不求精通,但也不能不会,话不多说,开干! 这是 Python 入门第一课,当然是做好准备工作&a…

论前端框架的对比和选择 依据 前端框架的误区

前端框架的对比和选择依据 在前端开发中,有多种框架可供选择,以下是一些常见前端框架的对比和选择依据: 一、Vue.js 特点: 渐进式框架,灵活度高,可以逐步引入到项目中。学习曲线相对较平缓,容…

Java项目实战II基于Java+Spring Boot+MySQL的民宿在线预定平台(开发文档+源码+数据库)

目录 一、前言 二、技术介绍 三、系统实现 四、文档参考 五、核心代码 六、源码获取 全栈码农以及毕业设计实战开发,CSDN平台Java领域新星创作者,专注于大学生项目实战开发、讲解和毕业答疑辅导。获取源码联系方式请查看文末 一、前言 在旅游市场…

强大的重命名工具 | Bulk Rename Utility v4.0 便携版

软件简介 Bulk Rename Utility是一款功能强大且易于使用的文件批量重命名工具。它不仅体积小巧,而且完全免费,提供了友好的用户界面。该软件允许用户对文件或文件夹进行批量重命名,支持递归操作,即包含子文件夹的重命名。 软件特…

Apache Iceberg 概述

Apache Iceberg概述 一、what is Apache Iceberg? 为了解决数据存储和计算引擎之间的适配的问题,Netflix开发了Iceberg,2018年11月16日进入Apache孵化器,2020 年5月19日从孵化器毕业,成为Apache的顶级项目。 Apache…

SpringBoot实战(三十)发送HTTP/HTTPS请求的五种实现方式【下篇】(Okhttp3、RestTemplate、Hutool)

目录 一、五种实现方式对比结果二、Demo接口地址实现方式三、Okhttp3 库实现3.1 简介3.2 Maven依赖3.3 配置文件3.4 配置类3.5 工具类3.6 示例代码3.7 执行结果实现方式四、Spring 的 RestTemplate 实现4.1 简介4.2 Maven依赖4.3 配置文件4.4 配置类4.5 HttpClient 和 RestTemp…

华为HarmonyOS灵活高效的消息推送服务(Push Kit) - 5 发送通知消息

场景介绍 通知消息通过Push Kit通道直接下发,可在终端设备的通知中心、锁屏、横幅等展示,用户点击后拉起应用。您可以通过设置通知消息样式来吸引用户。 开通权益 Push Kit根据消息内容,将通知消息分类为服务与通讯、资讯营销两大类别&…

idea2021git从dev分支合并到主分支master

1、新建分支 新建一个名称为dev的分支,切换到该分支下面,输入新内容 提交代码到dev分支的仓库 2、切换分支 切换到主分支,因为刚刚提交的分支在dev环境,所以master是没有 3、合并分支 点击push,将dev里面的代码合并到…

Spring AI Alibaba,阿里的AI Java 开发框架

源码地址 https://github.com/alibaba/spring-ai-alibaba