U2Net使用方法和实现多类别语义分割模型改造

news2025/1/16 6:02:07

作者的碎碎念:U2Net是用来实现SOD的语义分割,本篇论文会介绍算法内容、主要代码、使用方法,以及如何将二分类语义分割修改为多类别语义模型。如果只想知道怎么训练自己的数据集,或者如何修改网络,可以通过目录进行跳转。
欢迎点赞、评论或收藏❤️


文章目录

  • (一)相关链接
  • (二)算法内容
    • 1. 摘要
    • 2. 介绍
    • 3. 网络架构
    • 4. loss函数
    • 5. 作者实验结果
  • (三)如何训练自己的数据
    • 1. 标注
    • 2. mask图像
    • 3. 训练数据集格式
    • 4. 配置文件修改
    • 5. 训练命令
    • 6. 测试命令
  • (四)多类别语义分割
    • 1. 实现思路
    • 2. 修改方法
    • 4. 测试
    • 5. 训练测试效果

(一)相关链接

  1. 论文名称
    《U2-Net: Going Deeper with Nested U-Structure for Salient Object Detection》
  2. github链接
    https://github.com/xuebinqin/U-2-Net
  3. paper
    https://arxiv.org/pdf/2005.09007.pdf

(二)算法内容

1. 摘要

  U²-Net是显著物体检测(salient object detection,简写SOD)的一个网络,并且现在已经是Python的抠图工具Rembg的基础算法

  • 什么是SOD?
      SOD是模拟人类视觉感知系统来定位场景中最吸引人的目标,例如人像
  • 算法优点总结
    (1)能获取到更多的上下文信息(RSU块,ReSidual U-blocks)
    (2)增加网络深度但没有增加计算量。并且可以从0开始训练,不用从分类预训练网络中再训练
  • 模型大小
      U2-Net (176.3 MB, 30 FPS on GTX 1080Ti GPU)
      U2-Net†(4.7 MB, 40 FPS)

2. 介绍

  • 现有的SOD网络存在什么问题?
    (1)现有的模式基本都是使用已有的backbone,例如AlexNet、VGG、ResNet。这些基础的网络都是为分类任务而设计的,提取的特征更多是语义特征,而不是定位特征和全局对比的信息。
    (2)耗用大量的资源
    (3)牺牲高分辨率的特征映射来实现更深层次的体系结构
  • U2Net的目标是网络更深、使用的资源和计算量更少、能够保持高分辨率的特征图。怎么做呢?
    (1)用两级的内嵌U型结构,不使用分类的backbone
    (2)新型的网络结构更深、能获取高分辨率图像、不增加内存和计算量

3. 网络架构

  • 卷积结构和RSU结构比对
    在这里插入图片描述

(1)( a ) Plain convolution blockPLN
     ( b ) Residual-like block RES
     ( c ) Dense-like block DSE
     ( d ) Inception-like block INC
     ( e ) Our residual U-blockRSU
(2)(a)到( c )是典型的卷积结构,用了1x1和3x3的卷积,感受野太小,只能用来获取local feature
(3)(d)用了空洞卷积增大了感受野,但是需要大的内存和计算资源
(4)RSU-L模块,(L代表层数),Cin:输入通道,Cout:输出通道,M:RSU内部通道

  • 开销比对
    在这里插入图片描述
    RSU的开销(overhead)不大,因为都是下采样,DSE和INC比较大
  • 残差结构比对
    在这里插入图片描述
    (1)残差块:H(x) = F2(F1(x))+x,H(x)是x的映射,F1和F2是卷积操作【对应两个weight layer】
    (2)RSU:HRSU (x) = U(F1(x))+F1(x),RSU和残差不同的地方,是将卷积替换成像Unet的U型结构U-block,原来的输入x替换成F1(x)【weight layer之后】
  • 网络架构
    在这里插入图片描述

  U-Net-like这种结构本来就有,只不过是级联起来,Uxn Net,而作者提出来的是 Un Net,用内嵌(nested)结构而不是级联结构
(1)结构特点:11个stage,每个stage都是RSU结构
   🔸 a six stages encoder
   🔸a five stages decoder
   🔸a saliency map fusion module attached with the decoder stages and the last encoder stage
(2)编码器:
   🔹En_1、En_2、En_3、En_4(即前四个)用到的RSU层数是 RSU-7、 RSU-6、 RSU-5、 RSU-4,层数越多,尺度信息越丰富
   🔹En-5和En-6用了RSU-4F,用了空洞卷积,保证了输入输出是相同的分辨率
(3)解码器:
   De-5也是用了RSU-4F,和En-5、En-6类似
(4)融合模块(saliency map fusion module):
   编码器和解码器的输出,经过3x3卷积和sigmoid,upsample,输出了6个概率热力图:S_side(6)、S_side(5)、S_side(4)、S_side(3)、S_side(2)、S_side(1) ,用1x1卷积进行融合,产生了S_fuse

4. loss函数

在这里插入图片描述
✅总Loss等于所有loss之和,包括S_side(6)、S_side(5)、S_side(4)、S_side(3)、S_side(2)、S_side(1),和融合的S_fuse
在这里插入图片描述
✅每一层的S_side(x)的loss,使用了二分类交叉熵损失函数

5. 作者实验结果

在这里插入图片描述
Red, Green, and Blue indicate the best, second best and third best performance
在这里插入图片描述

(三)如何训练自己的数据

1. 标注

用labelme标注图片,生成json文件
在这里插入图片描述

2. mask图像

将json文件转换为mask图片,背景黑色,物体白色,下面是转换代码:

import cv2
import json
import numpy as np
import os
import sys


def func(file:str) -> np.ndarray:
    with open(file, mode='r', encoding="utf-8") as f:
        configs = json.load(f)
    shapes = configs["shapes"]

    png = np.zeros((configs["imageHeight"], configs["imageWidth"], 3), np.uint8)

    for shape in shapes:
        cv2.fillPoly(png, [np.array(shape["points"], np.int32)], (255,255,255))

    return png


if  __name__ == "__main__":

    if len(sys.argv) != 3:
        raise ValueError("json文件或目录 输出路径")

    if os.path.isdir(sys.argv[1]):
        for file in os.listdir(sys.argv[1]):
            cv2.imwrite(os.path.join(sys.argv[2], os.path.splitext(file)[0]+".png" ), func(os.path.join(sys.argv[1], file)))
    else:
        cv2.imwrite(os.path.join(sys.argv[2], os.path.splitext(os.path.basename(sys.argv[1]))[0]+".png"), func(sys.argv[1]))

在这里插入图片描述

转换的mask图像

3. 训练数据集格式

1️⃣在工程目录创建目录:train_data/DUTS/DUTS-TR/DUTS-TR/
2️⃣在第一步骤创建的目录上,创建目录im_aug,将原图放在这
3️⃣在第一步骤创建的目录上,创建目录gt_aug,将转换的mask图放在这

4. 配置文件修改

  打开u2net_train.py,一般可以设置这几项:
  model_name = ‘u2net’ # 用u2net或者u2netp模型进行训练
  epoch_num = 100000 # 训练轮次
  batch_size_train = 12 # batchsize
  save_frq = 2000 # 每2000个iter保存一个模型

5. 训练命令

python u2net_train.py

6. 测试命令

python u2net_test.py

(四)多类别语义分割

  作者提供的代码只实现了二分类的语义分割,U2Net是否可以用来做多类别的语义分割?答案是可以了,下面提供了将二分类语义分割转换为多类别语义分割的方法

1. 实现思路

🔺项目背景:图片有两个类别,分别是螺丝钉和位移线
🔺类别:两个类别+背景,num_class = 3,如果有更多类别,则是n+1类,1是背景
🔺mask图片:二分类时,填充的是0和255;多分类,不同类别可以填充为0(背景)、1(螺丝钉)、2(位移线),所以最多只能分出0~255个类别。查看3个类别的mask,因为像素值只有0、1、2,肉眼看基本是一张黑色图像
🔺模型输出:三个类别,输出三个通道,如[3, 320, 320],每一个通道代表一个类别

2. 修改方法

(1)获取多类别训练mask脚本

import cv2
import json
import numpy as np
import os
import sys


def func(file):
    with open(file, mode='r', encoding="utf-8") as f:
        configs = json.load(f)
    shapes = configs["shapes"]

    png = np.zeros((configs["imageHeight"], configs["imageWidth"], 3), np.uint8)

    for shape in shapes:
        label = shape['label']
        if label == 'lm':
            cv2.fillPoly(png, [np.array(shape["points"], np.int32)], (1,1,1))
        else:
            cv2.fillPoly(png, [np.array(shape["points"], np.int32)], (2,2,2))

    return png


if  __name__ == "__main__":
    json_dir = "./train_data/labels_json"
    
    save_dir = './train_data/masks'


    for file in os.listdir(json_dir):
        print(file)
        png = func(os.path.join(json_dir, file))
        print(png.shape)
        save_path = save_dir+'/'+os.path.splitext(file)[0]+".png"
        cv2.imwrite(save_path, png)
        print(save_path)

(2)data_loader.py
   class ToTensor(object)和class ToTensorLab(object)这两个类中,有对label进行归一化操作,去除该操作,因为计算loss的时候,多类别换成交叉熵损失函数,它本身包含了softmax操作
在这里插入图片描述
(3)model/u2net.py
   修改模型输出,作者在class U2NETP(nn.Module)和class U2NET(nn.Module)这两个类用了sigmoid函数,需要修改为直接输出,原因同上

# return F.sigmoid(d0), F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)
return d0, d1, d2, d3, d4, d5, d6

(4)u2net_train.py
   修改损失函数和模型输出通道,将损失函数由原来的BCELoss,修改为CrossEntropyLoss,并设置模型的输出通道和类别一致

# bce_loss = nn.BCELoss(size_average=True)  # 注释
ce_loss = nn.CrossEntropyLoss()  # 添加
# def muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v):  # 注释
#     loss0 = bce_loss(d0, labels_v)
#     loss1 = bce_loss(d1, labels_v)
#     loss2 = bce_loss(d2, labels_v)
#     loss3 = bce_loss(d3, labels_v)
#     loss4 = bce_loss(d4, labels_v)
#     loss5 = bce_loss(d5, labels_v)
#     loss6 = bce_loss(d6, labels_v)

#     loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6
#     print("l0: %3f, l1: %3f, l2: %3f, l3: %3f, l4: %3f, l5: %3f, l6: %3f\n" % (
#     loss0.data.item(), loss1.data.item(), loss2.data.item(), loss3.data.item(), loss4.data.item(), loss5.data.item(),
#     loss6.data.item()))

#     return loss0, loss

def muti_ce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v):  # 添加
    loss0 = ce_loss(d0, labels_v)
    loss1 = ce_loss(d1, labels_v)
    loss2 = ce_loss(d2, labels_v)
    loss3 = ce_loss(d3, labels_v)
    loss4 = ce_loss(d4, labels_v)
    loss5 = ce_loss(d5, labels_v)
    loss6 = ce_loss(d6, labels_v)

    loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6
    print("l0: %3f, l1: %3f, l2: %3f, l3: %3f, l4: %3f, l5: %3f, l6: %3f\n" % (
    loss0.data.item(), loss1.data.item(), loss2.data.item(), loss3.data.item(), loss4.data.item(), loss5.data.item(),
    loss6.data.item()))

    return loss0, loss
# ------- 3. define model --------
# define the net
n_class = 3
if (model_name == 'u2net'):
    net = U2NET(3, n_class)
elif (model_name == 'u2netp'):
    net = U2NETP(3, n_class)

4. 测试

   该例子中,存在三个类别,分别是背景、螺丝钉、位移线,对应模型三个通道的输出,但模型输出为概率值,如何获取到真实的类别,以及将类别用不同颜色表示出来?可以用下面这个脚本实现模型推理和输出结果图

import os
import cv2
from skimage import io, transform
import torch
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms#, utils
# import torch.optim as optim

import numpy as np
from PIL import Image
import glob

from data_loader import RescaleT
from data_loader import ToTensor
from data_loader import ToTensorLab
from data_loader import SalObjDataset

from model import U2NET # full size version 173.6 MB
from model import U2NETP # small version u2net 4.7 MB

# normalize the predicted SOD probability map
def normPRED(d):
    ma = torch.max(d)
    mi = torch.min(d)

    dn = (d-mi)/(ma-mi)

    return dn

def save_output(image_name,pred,d_dir):

    predict = pred
    predict = predict.squeeze()
    predict_np = predict.cpu().data.numpy()

    im = Image.fromarray(predict_np*255).convert('RGB')
    img_name = image_name.split(os.sep)[-1]
    image = io.imread(image_name)
    imo = im.resize((image.shape[1],image.shape[0]),resample=Image.BILINEAR)

    pb_np = np.array(imo)

    aaa = img_name.split(".")
    bbb = aaa[0:-1]
    imidx = bbb[0]
    for i in range(1,len(bbb)):
        imidx = imidx + "." + bbb[i]

    imo.save(d_dir+imidx+'.png')

def main():

    # --------- 1. get image path and name ---------
    model_name='u2net'#u2netp

    num_class = 3

    image_dir = os.path.join(os.getcwd(), 'test_data', 'ls_test_images')
    prediction_dir = os.path.join(os.getcwd(), 'test_data', model_name + '_results_ls' + os.sep)
    model_dir = os.path.join(os.getcwd(), 'saved_models', model_name, 'u2net_bce_itr_1000_train_1.046126_tar_0.124982.pth')

    img_name_list = glob.glob(image_dir + os.sep + '*')
    print(img_name_list)

    # --------- 2. dataloader ---------
    #1. dataloader
    test_salobj_dataset = SalObjDataset(img_name_list = img_name_list,
                                        lbl_name_list = [],
                                        transform=transforms.Compose([RescaleT(320),
                                                                      ToTensorLab(flag=0)])
                                        )
    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=1,
                                        shuffle=False,
                                        num_workers=1)

    # --------- 3. model define ---------
    if(model_name=='u2net'):
        print("...load U2NET---173.6 MB")
        net = U2NET(3,num_class)
    elif(model_name=='u2netp'):
        print("...load U2NEP---4.7 MB")
        net = U2NETP(3,num_class)

    if torch.cuda.is_available():
        net.load_state_dict(torch.load(model_dir))
        net.cuda()
    else:
        net.load_state_dict(torch.load(model_dir, map_location='cpu'))
    net.eval()

    # --------- 4. inference for each image ---------
    for i_test, data_test in enumerate(test_salobj_dataloader):

        print("inferencing:",img_name_list[i_test].split(os.sep)[-1])

        inputs_test = data_test['image']

        image = cv2.imread(img_name_list[i_test])
        image_name = os.path.basename(img_name_list[i_test])

        inputs_test = inputs_test.type(torch.FloatTensor)

        if torch.cuda.is_available():
            inputs_test = Variable(inputs_test.cuda())
        else:
            inputs_test = Variable(inputs_test)

        d1,d2,d3,d4,d5,d6,d7= net(inputs_test)
        d1 = d1.squeeze(dim=0)    # torch.Size([1, 3, 320, 320]) -> torch.Size([3, 320, 320])
        
        d1 = F.softmax(d1, dim=0)   # [3, 320, 320] 
        # print(d1[0, :, :])

        predict_np = torch.argmax(d1, dim=0, keepdim=True)
        # print(predict_np.shape)  # [1, 320, 320],3个类别,对应3个通道,获取概率值最高的下标

        predict_np = predict_np.cpu().detach().numpy().squeeze()   # 转到cpu设备

        predict_np = cv2.resize(predict_np, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_NEAREST)  # resize和原图一样的大小
        
        r = predict_np.copy()
        b = predict_np.copy()
        g = predict_np.copy()

        cls = dict([(1, (0, 0, 255)),
                    (2, (255, 0, 255)),
                    (3, (0, 255, 0)),
                    (4, (255, 0, 0)),
                    (5, (255, 255, 0))])
        for c in cls:
            r[r == c] = cls[c][0]
            g[g == c] = cls[c][1]
            b[b == c] = cls[c][2]

        rgb = np.zeros((image.shape[0], image.shape[1], 3))
        # print('类别', np.unique(predict_np))
        rgb[:, :, 0] = r
        rgb[:, :, 1] = g
        rgb[:, :, 2] = b

        im = Image.fromarray(rgb.astype(np.uint8))
        im.save('./test_data/my_results_2/' + str(image_name)[:-4] + '.png')

        del d1,d2,d3,d4,d5,d6,d7

if __name__ == "__main__":
    main()

5. 训练测试效果

   经过少量数据的训练测试,证明U2Net可以用来做多类别语义分割
输入图片

输入测试图片

在这里插入图片描述

模型测试效果

撒花完结🌟🌟🌟

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/980702.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【C语言】错题本(2)

题目: 将题目代码粘贴在下面便于分析: #define MAX_SIZE AB struct _Record_Struct {unsigned char Env_Alarm_ID : 4;unsigned char Para1 : 2;unsigned char state;unsigned char avail : 1;}*Env_Alarm_Record;struct _Record_Struct *pointer (struct _Record_Struct*)m…

【PyCharm Community Edition】:PCAN-USB上位机开发

PCAN-USB上位机开发 一级目录二次开发问题记录继承父类的函数platform.system()判断不准确打开PCANBasic.dll出错延伸阅读一级目录 二次开发 下载链接 问题记录 继承父类的函数 python中super().init() platform.system()判断不准确 根据不同系统,打开DLL的方法不同,本来…

IDEA批量解决Lombok警告,开发者必备技巧!

问题背景 用Lombok的Data注解的时候,因为封装了一个公共的Base基础类,总是会提示以下警告提示: Generating equals/hashCode implementation but without a call to superclass, even though this class does not extend java.lang.Object. …

软件测试的基础(1)

程序员(开发) :编写程序代码(实现产品需求) 产品:收集并设计需求-需求文档(根据用户需求进行产品设计) UI设计师:设计界面,向外展示的形态 前端:用代码实现页面的显示 DBA:数据库设计(系统数据之间的关联) 运维:版本控制和发布、升级迭代,环境搭建和维护 客服:客户支持,…

jsvmp逆向(补环境篇)

书接上回 上篇文章写到tx的jsxmp的算法逆向,文章链接在这里。初试jsvmp加密 。但是可能有伙伴觉得不够详细。 这里放一个大佬的文章链接。 https://www.52pojie.cn/thread-1521480-1-1.html 。其实就是一个变形的xtea加密。 大佬的文章已经讲了很清楚了&#…

小程序测试应该进行哪些测试?起到什么作用?

在如今小程序蓬勃发展的时代,越来越多的企业选择开发小程序来扩大业务。在推出小程序之前,进行全面的测试是至关重要的。 一、小程序测试的注意事项 1、功能测试:测试小程序的各项功能是否正常,包括页面跳转、数据加载、提交操作…

依赖导入失败场景和解决方案

在使用 Maven 构建项目时,可能会发生依赖项下载错误的情况,主要原因有以下几种: 下载依赖时出现网络故障或仓库服务器宕机等原因,导致无法连接至 Maven 仓库,从而无法下载依赖。 依赖项的版本号或配置文件中的版本号错…

恒运资本:股市板块轮动顺口溜?

股市是一个变化多端的场所,不同的板块会因为不同的方针、商场影响、经济形势等多种原因而有不同的体现。因而,不同时期不同板块的轮动也成为了研究的热门。下面咱们就通过一个顺口溜,来深化了解股市板块轮动: “钢铁、水泥、煤炭…

Linux如何安装MySQL

Linux安装MySQL5.7 1、下载 官网下载地址:http://dev.mysql.com/downloads/mysql/ 2、复制下面几个文件 3、检查当前系统是否安装过mysql、检查当前mysql依赖环境、检查/tmp文件夹权限 1)检查当前系统是否安装过mysql,执行安装命令前&am…

CAS策略

CAS CAS(Compare And Swap)比较并交换 CAS是多线程环境下对共享变量进行修改时的一种策略,主要存在三个参数:当前值、预估值、结果 CAS采用的策略是当一个线程要对共享变量进行修改时,需要获取内存中共享变量的值作…

自动驾驶中间件

自动驾驶中间件 1. 什么是中间件2. 中间件的分类3. 自动驾驶为什么需要中间件4. 通信中间件 Reference: 自动驾驶中间件:量产落地的关键技术通俗易懂的告诉你什么是中间件 对于初入自动驾驶行业的人来说,各色各样的新型传感器、线控系统、芯…

基于Java+SpringBoot+Vue前后端分离家政服务管理系统的设计与实现【Java毕业设计·文档报告·代码讲解·安装调试】

🍊作者:计算机编程-吉哥 🍊简介:专业从事JavaWeb程序开发,微信小程序开发,定制化项目、 源码、代码讲解、文档撰写、ppt制作。做自己喜欢的事,生活就是快乐的。 🍊心愿:点…

【考研数学】高等数学第五模块 —— 级数(2,幂级数)

文章目录 引言二、幂级数2.1 基本概念2.2 幂级数的收敛半径与收敛域2.3 幂级数的性质2.4 将函数展开为幂级数2.4.1 直接法2.4.2 间接法 2.5 求幂级数的和函数 写在最后 引言 承接前文的常数项级数,我们来继续看看关于幂级数的内容。 二、幂级数 2.1 基本概念 函数…

SpringCloud(二)

1.Nacos配置管理 Nacos除了可以做注册中心,同样可以做配置管理来使用。 1.1.统一配置管理 当微服务部署的实例越来越多,达到数十、数百时,逐个修改微服务配置就会让人抓狂,而且很容易出错。我们需要一种统一配置管理方案&#…

php使用jwt作登录验证

1 在项目根目录下,安装jwt composer require firebase/php-jwt 2 在登录控制器中加入生成token的代码 use Firebase\JWT\JWT; use Firebase\JWT\Key; class Login extends Cross {/*** 显示资源列表** return \think\Response*/public function index(Request $r…

在贸易发展新时代,我为什么推荐你使用全渠道支持平台?

在当今世界,客户希望在与企业互动时获得无缝体验,无论他们使用什么渠道进行联系。这就是全渠道支持的用武之地。通过提供全渠道客户支持,企业可以满足客户期望并提高客户满意度。在本文中,我们将探讨它的好处以及认识全渠道客户沟…

前人栽树,后人才能乘凉!聊聊低代码对开发者的意义

一、低代码很火 LCDP(低代码开发平台)市场规模大,增长迅速。Gartner机构的预测,到2025年,企业70%的新应用将会通过低代码或者无代码技术开发,这将加快低代码市场的全面爆发。而另外一家研究机构海比研究院数…

YOLOv5:对yolov5n模型进一步剪枝压缩

YOLOv5:对yolov5n模型进一步剪枝压缩 前言前提条件相关介绍具体步骤修改yolov5n.yaml配置文件单通道数据(黑白图片)修改models/yolo.py文件修改train.py文件 剪枝后模型大小 参考 前言 由于本人水平有限,难免出现错漏,…

2023高教社杯数学建模国赛题目这样选择

2023高教社杯数学建模国赛题目如何选择 一年一度的数学建模国赛要来啦!!!小编仔细阅读了比赛官方网站上的规则和要求,以及比赛的题型和时间安排,现总结分享给大家。小编将会在开赛后第一时间发布选题建议、所有题目的…

Leetcode - 112双周赛

一,2839. 判断通过操作能否让字符串相等 I ​ 该题的题意就是看 单数下标 和 偶数下标的 s1 和 s2 中的字母及其数量是否相等。 代码如下(也可以使用哈希表来做): class Solution {public boolean canBeEqual(String s1, String s2) {int[] a new in…