基于Pytorch框架的深度学习MODNet网络精细人像分割系统源码

news2024/11/15 17:56:39

 第一步:准备数据

人像精细分割数据,可分割出头发丝,为PPM-100开源数据

第二步:搭建模型

MODNet网络结构如图所示,主要包含3个部分:semantic estimation(S分支)、detail prediction(D分支)、semantic-detail fusion(F分支)。

网络结构简单描述一下:

输入一幅图像I,送入三个模块:S、D、F;
S模块:在低分辨率分支进行语义估计,在backbone最后一层输出接上e-ASPP得到语义feature map Sp;
D模块:在高分辨率分支进行细节预测,通过融合来自低分辨率分支的信息得到细节feature map Dp;
F模块:融合来自低分辨率分支和高分辨率分支的信息,得到alpha matte ap;
对S、D、F模块,均使用来自GT的显式监督信息进行监督训练。

第三步:代码

1)损失函数为:L2损失

2)网络代码:

import torch
import torch.nn as nn
import torch.nn.functional as F

from .backbones import SUPPORTED_BACKBONES


#------------------------------------------------------------------------------
#  MODNet Basic Modules
#------------------------------------------------------------------------------

class IBNorm(nn.Module):
    """ Combine Instance Norm and Batch Norm into One Layer
    """

    def __init__(self, in_channels):
        super(IBNorm, self).__init__()
        in_channels = in_channels
        self.bnorm_channels = int(in_channels / 2)
        self.inorm_channels = in_channels - self.bnorm_channels

        self.bnorm = nn.BatchNorm2d(self.bnorm_channels, affine=True)
        self.inorm = nn.InstanceNorm2d(self.inorm_channels, affine=False)
        
    def forward(self, x):
        bn_x = self.bnorm(x[:, :self.bnorm_channels, ...].contiguous())
        in_x = self.inorm(x[:, self.bnorm_channels:, ...].contiguous())

        return torch.cat((bn_x, in_x), 1)


class Conv2dIBNormRelu(nn.Module):
    """ Convolution + IBNorm + ReLu
    """

    def __init__(self, in_channels, out_channels, kernel_size, 
                 stride=1, padding=0, dilation=1, groups=1, bias=True, 
                 with_ibn=True, with_relu=True):
        super(Conv2dIBNormRelu, self).__init__()

        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size, 
                      stride=stride, padding=padding, dilation=dilation, 
                      groups=groups, bias=bias)
        ]

        if with_ibn:       
            layers.append(IBNorm(out_channels))
        if with_relu:
            layers.append(nn.ReLU(inplace=True))

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x) 


class SEBlock(nn.Module):
    """ SE Block Proposed in https://arxiv.org/pdf/1709.01507.pdf 
    """

    def __init__(self, in_channels, out_channels, reduction=1):
        super(SEBlock, self).__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, int(in_channels // reduction), bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(int(in_channels // reduction), out_channels, bias=False),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        b, c, _, _ = x.size()
        w = self.pool(x).view(b, c)
        w = self.fc(w).view(b, c, 1, 1)

        return x * w.expand_as(x)


#------------------------------------------------------------------------------
#  MODNet Branches
#------------------------------------------------------------------------------

class LRBranch(nn.Module):
    """ Low Resolution Branch of MODNet
    """

    def __init__(self, backbone):
        super(LRBranch, self).__init__()

        enc_channels = backbone.enc_channels
        
        self.backbone = backbone
        self.se_block = SEBlock(enc_channels[4], enc_channels[4], reduction=4)
        self.conv_lr16x = Conv2dIBNormRelu(enc_channels[4], enc_channels[3], 5, stride=1, padding=2)
        self.conv_lr8x = Conv2dIBNormRelu(enc_channels[3], enc_channels[2], 5, stride=1, padding=2)
        self.conv_lr = Conv2dIBNormRelu(enc_channels[2], 1, kernel_size=3, stride=2, padding=1, with_ibn=False, with_relu=False)

    def forward(self, img, inference):
        enc_features = self.backbone.forward(img)
        enc2x, enc4x, enc32x = enc_features[0], enc_features[1], enc_features[4]

        enc32x = self.se_block(enc32x)
        lr16x = F.interpolate(enc32x, scale_factor=2, mode='bilinear', align_corners=False)
        lr16x = self.conv_lr16x(lr16x)
        lr8x = F.interpolate(lr16x, scale_factor=2, mode='bilinear', align_corners=False)
        lr8x = self.conv_lr8x(lr8x)

        pred_semantic = None
        if not inference:
            lr = self.conv_lr(lr8x)
            pred_semantic = torch.sigmoid(lr)

        return pred_semantic, lr8x, [enc2x, enc4x] 


class HRBranch(nn.Module):
    """ High Resolution Branch of MODNet
    """

    def __init__(self, hr_channels, enc_channels):
        super(HRBranch, self).__init__()

        self.tohr_enc2x = Conv2dIBNormRelu(enc_channels[0], hr_channels, 1, stride=1, padding=0)
        self.conv_enc2x = Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=2, padding=1)

        self.tohr_enc4x = Conv2dIBNormRelu(enc_channels[1], hr_channels, 1, stride=1, padding=0)
        self.conv_enc4x = Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1)

        self.conv_hr4x = nn.Sequential(
            Conv2dIBNormRelu(3 * hr_channels + 3, 2 * hr_channels, 3, stride=1, padding=1),
            Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1),
            Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1),
        )

        self.conv_hr2x = nn.Sequential(
            Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1),
            Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1),
            Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1),
            Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1),
        )

        self.conv_hr = nn.Sequential(
            Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=1, padding=1),
            Conv2dIBNormRelu(hr_channels, 1, kernel_size=1, stride=1, padding=0, with_ibn=False, with_relu=False),
        )

    def forward(self, img, enc2x, enc4x, lr8x, inference):
        img2x = F.interpolate(img, scale_factor=1/2, mode='bilinear', align_corners=False)
        img4x = F.interpolate(img, scale_factor=1/4, mode='bilinear', align_corners=False)

        enc2x = self.tohr_enc2x(enc2x)
        hr4x = self.conv_enc2x(torch.cat((img2x, enc2x), dim=1))

        enc4x = self.tohr_enc4x(enc4x)
        hr4x = self.conv_enc4x(torch.cat((hr4x, enc4x), dim=1))

        lr4x = F.interpolate(lr8x, scale_factor=2, mode='bilinear', align_corners=False)
        hr4x = self.conv_hr4x(torch.cat((hr4x, lr4x, img4x), dim=1))

        hr2x = F.interpolate(hr4x, scale_factor=2, mode='bilinear', align_corners=False)
        hr2x = self.conv_hr2x(torch.cat((hr2x, enc2x), dim=1))

        pred_detail = None
        if not inference:
            hr = F.interpolate(hr2x, scale_factor=2, mode='bilinear', align_corners=False)
            hr = self.conv_hr(torch.cat((hr, img), dim=1))
            pred_detail = torch.sigmoid(hr)

        return pred_detail, hr2x


class FusionBranch(nn.Module):
    """ Fusion Branch of MODNet
    """

    def __init__(self, hr_channels, enc_channels):
        super(FusionBranch, self).__init__()
        self.conv_lr4x = Conv2dIBNormRelu(enc_channels[2], hr_channels, 5, stride=1, padding=2)
        
        self.conv_f2x = Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1)
        self.conv_f = nn.Sequential(
            Conv2dIBNormRelu(hr_channels + 3, int(hr_channels / 2), 3, stride=1, padding=1),
            Conv2dIBNormRelu(int(hr_channels / 2), 1, 1, stride=1, padding=0, with_ibn=False, with_relu=False),
        )

    def forward(self, img, lr8x, hr2x):
        lr4x = F.interpolate(lr8x, scale_factor=2, mode='bilinear', align_corners=False)
        lr4x = self.conv_lr4x(lr4x)
        lr2x = F.interpolate(lr4x, scale_factor=2, mode='bilinear', align_corners=False)

        f2x = self.conv_f2x(torch.cat((lr2x, hr2x), dim=1))
        f = F.interpolate(f2x, scale_factor=2, mode='bilinear', align_corners=False)
        f = self.conv_f(torch.cat((f, img), dim=1))
        pred_matte = torch.sigmoid(f)

        return pred_matte


#------------------------------------------------------------------------------
#  MODNet
#------------------------------------------------------------------------------

class MODNet(nn.Module):
    """ Architecture of MODNet
    """

    def __init__(self, in_channels=3, hr_channels=32, backbone_arch='mobilenetv2', backbone_pretrained=True):
        super(MODNet, self).__init__()

        self.in_channels = in_channels
        self.hr_channels = hr_channels
        self.backbone_arch = backbone_arch
        self.backbone_pretrained = backbone_pretrained

        self.backbone = SUPPORTED_BACKBONES[self.backbone_arch](self.in_channels)

        self.lr_branch = LRBranch(self.backbone)
        self.hr_branch = HRBranch(self.hr_channels, self.backbone.enc_channels)
        self.f_branch = FusionBranch(self.hr_channels, self.backbone.enc_channels)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                self._init_conv(m)
            elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.InstanceNorm2d):
                self._init_norm(m)

        if self.backbone_pretrained:
            self.backbone.load_pretrained_ckpt()                

    def forward(self, img, inference):
        pred_semantic, lr8x, [enc2x, enc4x] = self.lr_branch(img, inference)
        pred_detail, hr2x = self.hr_branch(img, enc2x, enc4x, lr8x, inference)
        pred_matte = self.f_branch(img, lr8x, hr2x)

        return pred_semantic, pred_detail, pred_matte
    
    def freeze_norm(self):
        norm_types = [nn.BatchNorm2d, nn.InstanceNorm2d]
        for m in self.modules():
            for n in norm_types:
                if isinstance(m, n):
                    m.eval()
                    continue

    def _init_conv(self, conv):
        nn.init.kaiming_uniform_(
            conv.weight, a=0, mode='fan_in', nonlinearity='relu')
        if conv.bias is not None:
            nn.init.constant_(conv.bias, 0)

    def _init_norm(self, norm):
        if norm.weight is not None:
            nn.init.constant_(norm.weight, 1)
            nn.init.constant_(norm.bias, 0)

第四步:搭建GUI界面

第五步:整个工程的内容

有训练代码和训练好的模型以及训练过程,提供数据,提供GUI界面代码

代码见:基于Pytorch框架的深度学习MODNet网络精细人像分割系统源码

有问题可以私信或者留言,有问必答

e0420b8919fe4c1cb3d1e3dd52176a8a.png

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2152707.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

pyqt瀑布流布局

最近研究瀑布流布局,发现都是收费的,所以只能自己写算法写布局。 所以啥都不说直接上代码 ImageLabel 参考 pyqt5 QLabel显示网络图片或qfluentwidgets官网 代码 import math import sys from pathlib import Pathfrom PyQt5.Qt import * from qflue…

传统美业通过小魔推短视频矩阵系统,实现逆势增长?

许多美甲店在经营过程中常常陷入一个误区:他们认为自己缺少的是客户,但实际上,他们真正缺少的是有效的营销策略,美甲店经营者普遍面临的两大难题包括: 1. 高客户流失率: 据研究显示,约70%的顾…

如何成立一家自己的等级保护测评机构?需要哪些条件?有哪些要求?

给大家的福利,点击下方蓝色字 即可免费领取↓↓↓ 🤟 基于入门网络安全/黑客打造的:👉黑客&网络安全入门&进阶学习资源包 前言 各省、自治区、直辖市公安厅、局网络安全保卫总队,新疆生产建设兵团公安局网络安…

SpringBoot学习指南

文章目录 一、为什么要学习SpringBoot二、SpringBoot介绍2.1 约定优于配置2.2 SpringBoot中的约定三、SpringBoot快速入门3.1 快速构建SpringBoot3.1.1 选择构建项目的类型3.1.2 项目的描述3.1.3 指定SpringBoot版本和需要的依赖3.1.4 导入依赖3.1.5 编写了Controller3.1.6 测试…

机器翻译之Bahdanau注意力机制在Seq2Seq中的应用

目录 1.创建 添加了Bahdanau的decoder 2. 训练 3.定义评估函数BLEU 4.预测 5.知识点个人理解 1.创建 添加了Bahdanau的decoder import torch from torch import nn import dltools#定义注意力解码器基类 class AttentionDecoder(dltools.Decoder): #继承dltools.Decoder写…

为什么年轻人都热衷找搭子,而不是找对象?

在繁华的都市中,有一个名叫晓悦的年轻人。晓悦每天穿梭于忙碌的工作和快节奏的生活之间,渐渐地,她发现身边的朋友们都开始找起了 “搭子”。 有一天,晓悦结束了一天疲惫的工作,坐在咖啡店里,看着窗外匆匆而…

在SpringBoot项目中利用Redission实现布隆过滤器(布隆过滤器的应用场景、布隆过滤器误判的情况、与位图相关的操作)

文章目录 1. 布隆过滤器的应用场景2. 在SpringBoot项目利用Redission实现布隆过滤器3. 布隆过滤器误判的情况4. 与位图相关的操作5. 可能遇到的问题(Redission是如何记录布隆过滤器的配置参数的)5.1 问题产生的原因5.2 解决方案5.2.1 方案一:…

DBeaver启动报错 Faild to load the JNI shared library

DBeaver启动报错 Faild to load the JNI shared library 问题现象 安装完成后,启动dbeaver报错 查看版本为64位版本,JAVA也为64为版本 dbeaver版本 java版本 解决 在dberver.ini添加指定配置,即可启动成功

msvcp100.dll是什么意思?msvcp100.dll丢失有什么可靠的解决方法

当我们在电脑中试图启动某些程序或游戏时,可能会遇到一个错误消息:"程序无法启动,因为计算机缺少msvcp100.dll"。其实遇到这种情况是非常的常见的,只要你是经常使用电脑的人,我们要解决它也非常的简单&#…

工作中遇到的问题总结(1)

文章目录 第一题问题描述解决思路 第二题问题描述解决思路核心大表如何优化数据迁移过程是怎么样的如何将流量从旧系统迁移到新系统上 第三题问题描述解决思路 第四题问题描述解决思路方案一:双写机制方案二:基于时间戳的分流机制方案三:灰度…

数据结构之线性表——LeetCode:707. 设计链表,206. 反转链表,92. 反转链表 II

707. 设计链表 题目描述 707. 设计链表 你可以选择使用单链表或者双链表,设计并实现自己的链表。 单链表中的节点应该具备两个属性:val 和 next 。val 是当前节点的值,next 是指向下一个节点的指针/引用。 如果是双向链表,则…

【滑动窗口】算法总结

文章目录 滑动窗口算法总结1.暴力求解vs滑动窗口2.需要注意的细节问题 2.滑动窗口的基本模板1.非固定窗口大小的滑动窗口2.固定窗口大小的滑动窗口细节 滑动窗口算法总结 1.暴力求解vs滑动窗口 遇到那些可以转化成一个子数组的长度的问题时,往往需要用到双指针。 …

二,MyBatis -Plus 关于映射 Java Bean 对象的注意事项和细节(详细说明)

二,MyBatis -Plus 关于映射 Java Bean 对象的注意事项和细节(详细说明) 文章目录 二,MyBatis -Plus 关于映射 Java Bean 对象的注意事项和细节(详细说明)1. 映射2. 表的映射3. 字段映射4. 字段失效5. 视图属性6. 总结:7. 最后: 1.…

【C/C++】速通涉及string类的经典编程题

【C/C】速通涉及string类的经典编程题 一.字符串最后一个单词的长度代码实现:(含注释) 二.验证回文串解法一:代码实现:(含注释) 解法二:(推荐)1. 函数isalnum…

单卡3090 选用lora微调ChatGLM3-6B

环境配置 Python 3.10.12 transformers 4.36.2 torch 2.0.1 下载demo代码 在官方网址https://github.com/THUDM/ChatGLM3/blob/main/finetune_demo 下载demo代码cd 进入文件夹 pip install -r requirements.txt 安装一些包 基本知识 SFT 全量微调: 4张显卡平均分配&#…

13年计算机考研408-数据结构

解析: 这个降序链表不影响时间复杂度,因为是链表,所以你想要升序就使用头插法,你想要降序就使用尾插法。 然后我们来分析一下最坏的情况是什么样的。 因为m和n都是两个有序的升序序列。 如果刚好m的最大值小于n的最小值&#xff0…

AI宠物拟人化新玩法,教你如何用0成本打造爆款创意内容!

近年来,随着AI技术的快速发展,各种创新玩法不断涌现,尤其是在内容创作领域,AI带来的变革尤为显著。 **其中,宠物拟人化逐渐成为社交媒体上的一大热门话题。**通过AI生成工具,我们不仅可以将宠物拟人化&…

Snapchat API 访问:Objective-C 实现示例

Snapchat 是一个流行的社交媒体平台,它允许用户发送和接收短暂存在的图片和视频。对于开发者来说,访问 Snapchat API 可以为应用程序添加独特的社交功能。本文将介绍如何在 Objective-C 中实现对 Snapchat API 的访问,并提供一个详细的代码示…

GD32F103单片机-EXTI外部中断

GD32F103单片机-EXTI外部中断 一、EXTI及NVIC介绍二、编程实验2.1 相关库函数2.2 实验代码 一、EXTI及NVIC介绍 GD32和STM32的EXTI基本相似,具体见STM32F1单片机-外部中断GD32的EXTI包括20个相互独立的边沿检测电路请求产生中断或事件,4位优先级配置寄存…

热像仪是如何工作的?

红外热像仪是一种非接触式设备,能够检测红外能量(热量)并将其转变成可见光图像。让我们深入了解红外热像仪的科学原理,以及借助红外热像仪我们能够看到的隐形世界。 捕捉红外波,而不是可见光 首先必须清楚的是&#…