基于Pytorch框架的深度学习DeepLabv3+网络头发语义分割系统源码

news2024/9/29 21:23:13

 第一步:准备数据

头发分割数据,总共有1050张图片,里面的像素值为0和1,所以看起来全部是黑的,不影响使用

第二步:搭建模型

DeepLabV3+的网络结构如下图所示,主要为Encoder-Decoder结构。其中,Encoder为改进的DeepLabV3,Decoder为3+版本新提出的。

1.1、Encoder
在Encoder部分,主要包括了backbone(即:图1中的DCNN)、ASPP两大部分。

其中backbone有两种网络结构:将layer4改为空洞卷积的Resnet系列、改进的Xception。从backbone出来的feature map分两部分:一部分是最后一层卷积输出的feature maps,另一部分是中间的低级特征的feature maps;backbone输出的第一部分送入ASPP模块,第二部分则送入Decoder模块。
ASPP模块接受backbone的第一部分输出作为输入,使用了四种不同膨胀率的空洞卷积块(包括卷积、BN、激活层)和一个全局平均池化块(包括池化、卷积、BN、激活层)得到一共五组feature maps,将其concat起来之后,经过一个1*1卷积块(包括卷积、BN、激活、dropout层),最后送入Decoder模块。
1.2、Decoder
在Decoder部分,接收来自backbone中间层的低级feature maps和来自ASPP模块的输出作为输入。

首先,对低级feature maps使用1*1卷积进行通道降维,从256降到48(之所以需要降采样到48,是因为太多的通道会掩盖ASPP输出的feature maps的重要性,且实验验证48最佳);
然后,对来自ASPP的feature maps进行插值上采样,得到与低级featuremaps尺寸相同的feature maps;
接着,将通道降维的低级feature maps和线性插值上采样得到的feature maps使用concat拼接起来,并送入一组3*3卷积块进行处理;
最后,再次进行线性插值上采样,得到与原图分辨率大小一样的预测图。

第三步:代码

1)损失函数为:交叉熵损失函数

2)网络代码:

import torch
import torch.nn as nn
import torch.nn.functional as F
from nets.xception import xception
from nets.mobilenetv2 import mobilenetv2

class MobileNetV2(nn.Module):
    def __init__(self, downsample_factor=8, pretrained=True):
        super(MobileNetV2, self).__init__()
        from functools import partial
        
        model           = mobilenetv2(pretrained)
        self.features   = model.features[:-1]

        self.total_idx  = len(self.features)
        self.down_idx   = [2, 4, 7, 14]

        if downsample_factor == 8:
            for i in range(self.down_idx[-2], self.down_idx[-1]):
                self.features[i].apply(
                    partial(self._nostride_dilate, dilate=2)
                )
            for i in range(self.down_idx[-1], self.total_idx):
                self.features[i].apply(
                    partial(self._nostride_dilate, dilate=4)
                )
        elif downsample_factor == 16:
            for i in range(self.down_idx[-1], self.total_idx):
                self.features[i].apply(
                    partial(self._nostride_dilate, dilate=2)
                )
        
    def _nostride_dilate(self, m, dilate):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            if m.stride == (2, 2):
                m.stride = (1, 1)
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate//2, dilate//2)
                    m.padding = (dilate//2, dilate//2)
            else:
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate, dilate)
                    m.padding = (dilate, dilate)

    def forward(self, x):
        low_level_features = self.features[:4](x)
        x = self.features[4:](low_level_features)
        return low_level_features, x 


#-----------------------------------------#
#   ASPP特征提取模块
#   利用不同膨胀率的膨胀卷积进行特征提取
#-----------------------------------------#
class ASPP(nn.Module):
	def __init__(self, dim_in, dim_out, rate=1, bn_mom=0.1):
		super(ASPP, self).__init__()
		self.branch1 = nn.Sequential(
				nn.Conv2d(dim_in, dim_out, 1, 1, padding=0, dilation=rate,bias=True),
				nn.BatchNorm2d(dim_out, momentum=bn_mom),
				nn.ReLU(inplace=True),
		)
		self.branch2 = nn.Sequential(
				nn.Conv2d(dim_in, dim_out, 3, 1, padding=6*rate, dilation=6*rate, bias=True),
				nn.BatchNorm2d(dim_out, momentum=bn_mom),
				nn.ReLU(inplace=True),	
		)
		self.branch3 = nn.Sequential(
				nn.Conv2d(dim_in, dim_out, 3, 1, padding=12*rate, dilation=12*rate, bias=True),
				nn.BatchNorm2d(dim_out, momentum=bn_mom),
				nn.ReLU(inplace=True),	
		)
		self.branch4 = nn.Sequential(
				nn.Conv2d(dim_in, dim_out, 3, 1, padding=18*rate, dilation=18*rate, bias=True),
				nn.BatchNorm2d(dim_out, momentum=bn_mom),
				nn.ReLU(inplace=True),	
		)
		self.branch5_conv = nn.Conv2d(dim_in, dim_out, 1, 1, 0,bias=True)
		self.branch5_bn = nn.BatchNorm2d(dim_out, momentum=bn_mom)
		self.branch5_relu = nn.ReLU(inplace=True)

		self.conv_cat = nn.Sequential(
				nn.Conv2d(dim_out*5, dim_out, 1, 1, padding=0,bias=True),
				nn.BatchNorm2d(dim_out, momentum=bn_mom),
				nn.ReLU(inplace=True),		
		)

	def forward(self, x):
		[b, c, row, col] = x.size()
        #-----------------------------------------#
        #   一共五个分支
        #-----------------------------------------#
		conv1x1 = self.branch1(x)
		conv3x3_1 = self.branch2(x)
		conv3x3_2 = self.branch3(x)
		conv3x3_3 = self.branch4(x)
        #-----------------------------------------#
        #   第五个分支,全局平均池化+卷积
        #-----------------------------------------#
		global_feature = torch.mean(x,2,True)
		global_feature = torch.mean(global_feature,3,True)
		global_feature = self.branch5_conv(global_feature)
		global_feature = self.branch5_bn(global_feature)
		global_feature = self.branch5_relu(global_feature)
		global_feature = F.interpolate(global_feature, (row, col), None, 'bilinear', True)
		
        #-----------------------------------------#
        #   将五个分支的内容堆叠起来
        #   然后1x1卷积整合特征。
        #-----------------------------------------#
		feature_cat = torch.cat([conv1x1, conv3x3_1, conv3x3_2, conv3x3_3, global_feature], dim=1)
		result = self.conv_cat(feature_cat)
		return result

class DeepLab(nn.Module):
    def __init__(self, num_classes, backbone="mobilenet", pretrained=True, downsample_factor=16):
        super(DeepLab, self).__init__()
        if backbone=="xception":
            #----------------------------------#
            #   获得两个特征层
            #   浅层特征    [128,128,256]
            #   主干部分    [30,30,2048]
            #----------------------------------#
            self.backbone = xception(downsample_factor=downsample_factor, pretrained=pretrained)
            in_channels = 2048
            low_level_channels = 256
        elif backbone=="mobilenet":
            #----------------------------------#
            #   获得两个特征层
            #   浅层特征    [128,128,24]
            #   主干部分    [30,30,320]
            #----------------------------------#
            self.backbone = MobileNetV2(downsample_factor=downsample_factor, pretrained=pretrained)
            in_channels = 320
            low_level_channels = 24
        else:
            raise ValueError('Unsupported backbone - `{}`, Use mobilenet, xception.'.format(backbone))

        #-----------------------------------------#
        #   ASPP特征提取模块
        #   利用不同膨胀率的膨胀卷积进行特征提取
        #-----------------------------------------#
        self.aspp = ASPP(dim_in=in_channels, dim_out=256, rate=16//downsample_factor)
        
        #----------------------------------#
        #   浅层特征边
        #----------------------------------#
        self.shortcut_conv = nn.Sequential(
            nn.Conv2d(low_level_channels, 48, 1),
            nn.BatchNorm2d(48),
            nn.ReLU(inplace=True)
        )		

        self.cat_conv = nn.Sequential(
            nn.Conv2d(48+256, 256, 3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),

            nn.Conv2d(256, 256, 3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),

            nn.Dropout(0.1),
        )
        self.cls_conv = nn.Conv2d(256, num_classes, 1, stride=1)

    def forward(self, x):
        H, W = x.size(2), x.size(3)
        #-----------------------------------------#
        #   获得两个特征层
        #   low_level_features: 浅层特征-进行卷积处理
        #   x : 主干部分-利用ASPP结构进行加强特征提取
        #-----------------------------------------#
        low_level_features, x = self.backbone(x)
        x = self.aspp(x)
        low_level_features = self.shortcut_conv(low_level_features)
        
        #-----------------------------------------#
        #   将加强特征边上采样
        #   与浅层特征堆叠后利用卷积进行特征提取
        #-----------------------------------------#
        x = F.interpolate(x, size=(low_level_features.size(2), low_level_features.size(3)), mode='bilinear', align_corners=True)
        x = self.cat_conv(torch.cat((x, low_level_features), dim=1))
        x = self.cls_conv(x)
        x = F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True)
        return x

第四步:统计一些指标(训练过程中的loss和miou)

第五步:搭建GUI界面

第六步:整个工程的内容

有训练代码和训练好的模型以及训练过程,提供数据,提供GUI界面代码

代码见:基于Pytorch框架的深度学习DeepLabv3+网络头发语义分割系统源码

有问题可以私信或者留言,有问必答

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2087903.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

深度学习基础--模型拟合

模型拟合 损失与网络参数有关,本章着重于探讨如何确定能使损失最小化的参数值。这个过程称为网络参数的学习,或更通俗地说,是模型的训练或拟合。该过程首先是选取一组初始参数值,随后重复执行两个步骤: (i) 计算损失…

张驰咨询:新界泵业六西格玛设计DFSS项目出成果

近日,新界泵业六西格玛设计项目通过专家评审,新界泵业石总等领导、六西格玛设计项目组成员、张驰咨询首席顾问张驰、首席六西格设计顾问赵老师共同出席项目评审会。 (顾问老师致辞) 本期项目涉及多款新产品开发,本期…

Java新手零基础教程!Java 异常详解.^◡^.

Java 异常 Java教程 - Java异常 异常是在运行时在代码序列中出现的异常状况。例如,读取一个不存在的文件。 Java异常是描述异常条件的对象发生在一段代码中。 关键词 Java异常处理通过五个关键字管理: try,catch,throw,throws…

优思学院|质量工程师常用英语【客户投诉篇】

作为质量工程师,你是否曾因国外客户的投诉而不知如何用英语回应,感到困扰? 质量工程师常常面对各种挑战。即使你解决问题的能力很强,但由于不熟悉使用英语,可能会影响客户对你的印象和信任。 接下来,让我…

C#面试题系列--动态更新

C#面试题系列 排版排了半天,也是不好看,那就不排版了,尽量由易到难 高级一些 什么是MVC模式C#中特性是什么?如何使用?C#中什么是反射?C#中的委托是什么 事件是不是一种委托C# 不安全代码C# 隐式类型 varC# linqC# 匿名…

在centos中安装 --nmon性能系统监控工具

参考资料 CentOS安装nmon-CSDN博客 Jmeter(十九):nmon性能系统监控工具_jmeter nmon(1)_jmeter nmon性能系统监控工具详解-CSDN博客 Linux性能监控命令_nmon 安装与使用_nmon安装方法linux-CSDN博客 资源监控工具nmon安装及使用 – TestGo 下载启宏插件 https…

学习日志8.21--防火墙NAT

在学习过基于路由器的NAT网络地址转换,现在学习基于防火墙NAT的网络地址转换,防火墙的NAT配置和路由器的NAT配置还是有比较大的区别。 防火墙NAT是通过NAT策略实现的,在创建防火墙NAT之前需要先创建防火墙的安全策略。防火墙是不能直接在接口…

i2c-tool工具

i2c-tool工具的使用方法 包括i2cdetect、i2cget、i2cset、i2cdump、i2ctransfer i2cdetect命令 该命令用于扫描I2C总线上的设备。 语法:i2cdetect [-y] [-a] [-q|-r] i2cbus [first last]:参数说明:参数y:关闭交互模式&#xf…

GEE 教程:如何实现对指定矢量集合的归一化操作(以北京市各区县面积和边长为例)

简介 数据归一化处理是指将数据按照一定的规则进行变换,使数据落入一个特定的区间范围内。数据归一化处理的目的是消除数据之间的量纲差异,同时保留数据的分布特征,以便更好地进行数据分析和建模。 常见的数据归一化方法有如下几种&#xf…

快讯 | 谷歌AI引擎GameNGen颠覆游戏产业:0代码生成《毁灭战士》

硅纪元快讯栏目,每日追踪AI领域的最新动态,快速汇总最新科技新闻,助您时刻紧跟行业趋势。简明扼要的呈现资讯概要,让您快速了解前沿资讯。 1分钟速览新闻 ChatGPT用户翻倍突破2亿,AI工具融入日常生活 智谱AI发布尖端语…

电工手册 v77.9 — 专业电气知识与实用工具(Mod版)

电工手册是一款专门为电力领域从业者及爱好者设计的知识普及与技能提升应用。无论你是专业电工、DIY爱好者,还是对电力领域有兴趣的学生,这款应用都能为你提供大量实用的资源和工具。内容涵盖电气理论、接线图和计算器等多个方面,旨在帮助用户…

(echarts)散点图怎么给横坐标添加单位

(echarts)散点图怎么给横坐标添加单位 效果: 代码: 拓展-给值加

类在JVM中的工作原理

文章目录 引言I 类在JVM中的工作原理class文件的结构类的生命周期II JVM运行时数据区堆栈的意义栈帧内部结构堆III 在JIT中比较常见的优化手段引言 类是一种抽象概念,它是一种模板,用来定义一类事物的属性和行为。类是面向对象编程的基础,它是一种抽象的概念,代表一类事物…

Java 魔法类 Unsafe 源码解读(一)

Java 魔法类 Unsafe 源码解读(一) 前言 阅读过 JUC 源码的同学,一定会发现很多并发工具都调用了一个叫做 Unsafe 的类。 那这个类的作用是什么呢?有什么使用场景呢?底层源码是什么样呢?这篇文章笔者就带…

uni-app商城小程序+后台管理系统,手把手教你搭建

uni-app商城小程序是一种通过uni-app框架开发的,可以在微信、支付宝、字节跳动等多个平台上运行的轻量级电商应用。 一、特点 跨平台兼容:基于uni-app框架,一次开发,可同时适配微信小程序、支付宝小程序、H5、App等多个平台&…

3DMAX2025新款插件精选大全

关于3DMAX2025的新款插件,虽然无法提供一个详尽无遗的列表,本文根据公开发布的信息和插件的流行趋势,概述一些新款插件或插件更新。请注意,由于插件市场不断变化,以下信息可能随时间而有所更新。 以下插件按首字母排序…

ts转mp4怎么转?分享3个方法,快速搞定

在视频编辑和处理的世界里,格式转换是一个常见且必要的任务。特别是当你手头上有一些ts格式的视频文件,而你又需要将它们转换成更通用、更容易分享的mp4格式时,了解如何进行转换就显得尤为重要。 只有掌握了格式转换的技能,我们才…

yolov8训练野火烟雾识别检测模型

1.数据集下载 数据集下载链接:https://hyper.ai/datasets/33096 2. 数据集格式转换 需要将json中的标注信息转换为yolo格式的标注文件数据 import json import os import shutil import cv2 import matplotlib.pyplot as plttarget "./data/val" def…

如何在没有密码的情况下解锁 Oppo 手机?5 种简单的方法

保护智能手机隐私的一种很好的方法是设置复杂的锁屏密码或图案。一些 OPPO 手机的所有者在更改后一夜之间经历了图案或密码的内存丢失。事实上,OPPO 用户遇到的众多问题包括忘记密码或锁定屏幕。遗憾的是,没有多少人知道无需密码即可解锁 OPPO 手机的简单…

JAVA毕业设计166—基于Java+Springboot+vue3的流浪宠物救助管理小程序(源代码+数据库)

毕设所有选题: https://blog.csdn.net/2303_76227485/article/details/131104075 基于JavaSpringbootvue3的流浪宠物救助管理小程序(源代码数据库)166 一、系统介绍 本项目前后端分离带小程序(可以改为ssm版本),分为用户、救助站、管理员三种角色 1、…