Pytorch 基于 deeplabv3_resnet50 迁移训练自己的图像语义分割模型

news2024/11/24 0:28:00

一、图像语义分割

图像语义分割是计算机视觉领域的一项重要任务,旨在将图像中的每个像素分配到其所属的语义类别,从而实现对图像内容的细粒度理解。与目标检测不同,图像语义分割要求对图像中的每个像素进行分类,而不仅仅是确定物体的边界框。deeplabv3_resnet50 就是一个常用的语义分割模型,它巧妙地将两个强大的神经网络架构融合在一起,为像素级别的图像理解提供了强大的解决方案。

首先,DeepLabV3是一种专门设计用于语义分割的架构。通过采用扩张卷积(也称为空洞卷积)能够在不损失空间分辨率的情况下捕捉多尺度信息。这使得模型能够对图像进行精细的分割,识别并分类每个像素的语义信息。

其次,ResNet50ResNet系列中的一员,拥有50层深度的残差网络结构。通过引入残差连接,ResNet50解决了深层神经网络中梯度消失的问题,使得网络更易于训练。作为骨干网络,ResNet50提供了强大的特征提取能力,有助于捕捉图像中的高级语义特征。

本文基于 Pytorch 使用 deeplabv3_resnet50 迁移训练自己的图像语义分割模型,数据使用的数据集,最后效果如下所示:

在这里插入图片描述

下面使用的 torch 版本如下:

torch                   1.13.1+cu116
torchaudio              0.13.1+cu116
torchvision             0.14.1+cu116

二、数据集准备

图像数据可以从网上找一些或者自己拍摄,我这里准备了一些 的图片:

在这里插入图片描述

这里构建 VOC 格式数据集,因此需要新建如下结构目录:

VOCdevkit
	VOC2012
		Annotations
		ImageSets
			Segmentation
		JPEGImages
		SegmentationClass

在这里插入图片描述

目录解释如下:

  • Annotations 存放标注后的 xml 文件
  • Segmentation 划分后的训练样本名称和验证集样本名称(只存放名称)
  • JPEGImages 存放收集的图像
  • SegmentationClass 存放语义分割的mask标签图像

将收集的图像放到 JPEGImages 目录下:

在这里插入图片描述

三、图像标注

标注工具使用 labelme ,如果没有安装,使用下面方式引入该依赖:

pip install labelme -i https://pypi.tuna.tsinghua.edu.cn/simple

然后控制台输入:labelme ,即可打开标注工具:

在这里插入图片描述

通过构建一个区域后,需要给该区域一个标签,这里给 cat

在这里插入图片描述

xml 文件保存在 Annotations 下:

在这里插入图片描述

四、生成 mask 标签图像及数据划分

标注完成后,需要将标注数据转为 mask 标签图像:

trans_mask.py

import json
import os
import os.path as osp
import copy
import numpy as np
import PIL.Image

from labelme import utils

NAME_LABEL_MAP = {
    '_background_': 0,
    "cat": 1,
}


def main():

    annotations = './voc/VOCdevkit/VOC2012/Annotations'
    segmentationClass = './voc/VOCdevkit/VOC2012/SegmentationClass'

    list = os.listdir(annotations)
    for i in range(0, len(list)):
        path = os.path.join(annotations, list[i])
        filename = list[i][:-5]
        if os.path.isfile(path):
            data = json.load(open(path,encoding="utf-8"))
            img = utils.image.img_b64_to_arr(data['imageData'])
            lbl, lbl_names = utils.shape.labelme_shapes_to_label(img.shape, data['shapes'])  # labelme_shapes_to_label
            # modify labels according to NAME_LABEL_MAP
            lbl_tmp = copy.copy(lbl)
            for key_name in lbl_names:
                old_lbl_val = lbl_names[key_name]
                new_lbl_val = NAME_LABEL_MAP[key_name]
                lbl_tmp[lbl == old_lbl_val] = new_lbl_val
            lbl_names_tmp = {}
            for key_name in lbl_names:
                lbl_names_tmp[key_name] = NAME_LABEL_MAP[key_name]
            # Assign the new label to lbl and lbl_names dict
            lbl = np.array(lbl_tmp, dtype=np.int8)
            label_path = osp.join(segmentationClass, '{}.png'.format(filename))
            PIL.Image.fromarray(lbl.astype(np.uint8)).save(label_path)
            print('Saved to: %s' % label_path)


if __name__ == '__main__':
    main()

注意修改路径为你的地址,运行后可以在 SegmentationClass 目录下看到 mask 标签图像:

在这里插入图片描述

下面进行数据的划分,这里划分为90%训练集和10%验证集:

split_data.py

import os

if __name__ == '__main__':
    JPEGImages = "./voc/VOCdevkit/VOC2012/JPEGImages"
    Segmentation = "./voc/VOCdevkit/VOC2012/ImageSets/Segmentation"
    # 训练集比例 90%
    training_ratio = 0.9

    list = os.listdir(JPEGImages)
    all = len(list)
    print(all)
    train_count = int(all * training_ratio)
    train = list[0:train_count]
    val = list[train_count:]
    with open(os.path.join(Segmentation, "train.txt"), "w", encoding="utf-8") as f:
        for name in train:
            name = name.split(".")[0]
            f.write(name + "\n")
            f.flush()
    with open(os.path.join(Segmentation, "val.txt"), "w", encoding="utf-8") as f:
        for name in val:
            name = name.split(".")[0]
            f.write(name + "\n")
            f.flush()

运行后可以在 Segmentation 目录下看到两个文件:

在这里插入图片描述

到这里就已经准备好了 VOC 格式的数据集。

五、模型训练

deeplabv3_resnet50 的复现这里就不重复造轮子了,pytorch 官方的 vision 包已经做好了实现,拉取该工具包:

git clone https://github.com/pytorch/vision.git

可以在 references 下看到不同任务的实现:

在这里插入图片描述

这里我们主要关注 segmentation 中:

在这里插入图片描述

需要修改下 train.py 中的 voc 的分类数,由于我们只是分割出猫,加上背景就是 2 类:

在这里插入图片描述

控制台进入到该目录下,运行 train.py 文件开始训练:

python train.py --data-path ./voc --lr 0.02 --dataset voc --batch-size 2 --epochs 50 --model deeplabv3_resnet50 --device cuda:0 --output-dir model --aux-loss --weights-backbone ResNet50_Weights.IMAGENET1K_V1

如果缺失部分依赖直接 pip 安装即可。

其中参数的解释如下:

  • data-path:上面我们构建的 VOC 数据集的地址。
  • lr:初始学习率。
  • dataset:数据集的格式,这里我们是 voc 格式。
  • batch-size:一个批次的大小,这里我 GPU显存有限设的 2 ,如果显存大可以调大一些。
  • epochs:训练多少个周期。
  • model:训练使用的模型,可选:fcn_resnet50、fcn_resnet101、deeplabv3_resnet50、deeplabv3_resnet101、deeplabv3_mobilenet_v3_large、lraspp_mobilenet_v3_large
  • device:训练使用的设备。
  • output-dir:训练模型输出目录。
  • aux-loss:启用 aux-loss
  • weights-backbonebackbone模型。

更多参数可以打开 train.py 文件查看:

在这里插入图片描述

训练过程:

在这里插入图片描述

这里我训练完后 loss=0.3766, mean IoU= 85.4

在这里插入图片描述

五、模型预测

import os
import torch
import torch.utils.data
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms


# 转换输出,将每个标签换成对应的颜色
def decode_segmap(image, num_classes, label_colors):
    r = np.zeros_like(image).astype(np.uint8)
    g = np.zeros_like(image).astype(np.uint8)
    b = np.zeros_like(image).astype(np.uint8)

    for l in range(0, num_classes):
        idx = image == l
        r[idx] = label_colors[l, 0]
        g[idx] = label_colors[l, 1]
        b[idx] = label_colors[l, 2]

    rgb = np.stack([r, g, b], axis=2)
    return rgb


def main():
    # 基础模型
    base_model = "deeplabv3_resnet50"
    # 训练后的权重
    model_weights = "./model/model_49.pth"
    # 使用设备
    device = "cuda:0"
    # 预测图像目录地址
    prediction_path = "./voc/VOCdevkit/VOC2012/JPEGImages"
    # 分类数
    num_classes = 2
    # 标签对应的颜色,0: 背景,1:cat
    label_colors = np.array([(0, 0, 0), (255, 255, 255)])

    device = torch.device(device)
    print("using {} device.".format(device))
    # 加载模型
    model = torchvision.models.get_model(
        base_model,
        num_classes=2,
    )
    assert os.path.exists(model_weights), "{} file dose not exist.".format(model_weights)
    model.load_state_dict(torch.load(model_weights, map_location=device)["model"], strict=False)
    print(model)
    model.to(device)
    model.eval()
    files = os.listdir(prediction_path)
    for file in files:
        filename = os.path.join(prediction_path, file)
        input_image = Image.open(filename).convert('RGB')
        preprocess = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        input_tensor = preprocess(input_image)
        input_batch = input_tensor.unsqueeze(0).to(device)

        with torch.no_grad():
            output = model(input_batch)['out'][0]
        output_predictions = output.argmax(0)

        out = output_predictions.detach().cpu().numpy()
        rgb = decode_segmap(out, num_classes, label_colors)

        plt.figure()
        plt.subplot(1, 2, 1)
        plt.imshow(input_image)

        plt.subplot(1, 2, 2)
        plt.imshow(rgb)
        plt.show()


if __name__ == '__main__':
    main()

输出结果:

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1258850.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

华夏ERP信息泄露漏漏洞复现 [附POC]

文章目录 华夏ERP信息泄露漏漏洞复现 [附POC]0x01 前言0x02 漏洞描述0x03 影响版本0x04 漏洞环境0x05 漏洞复现1.访问漏洞环境2.构造POC3.复现 0x06 修复建议 华夏ERP信息泄露漏漏洞复现 [附POC] 0x01 前言 免责声明:请勿利用文章内的相关技术从事非法测试&#x…

深思:C与C++相互调用问题

背景 上周,偶然看到同事愁眉苦脸的样子,便善意咨询了下发生了什么。简单沟通下,才知道他遇到了一个工程编译的问题,一直无法编译通过,困扰了他快一天时间。出于个人的求知欲和知识的渴望,我便主动与他一同分…

动态规划学习——等差子序列问题

目录 一,最长等差子序列 1.题目 2.题目接口 3.解题思路及其代码 二,等差序列的划分——子序列 1.题目 2.题目接口 3.解题思路及其代码 一,最长等差子序列 1.题目 给你一个整数数组 nums,返回 nums 中最长等差子序列的长度…

M3u8视频文件怎么转换成MP4?一分钟解决!

大部分网课平台或者视频平台,都是基于m3u8格式的,这是因为m3u8格式本身的特点,既支持直播又支持点播。但是往往在其他平台或者设备上不兼容,就需要转成MP4格式,那么就像大家介绍3种好用的方法~ 方法一:使用…

详解HTTP协议(介绍--版本--工作过程--Fiddler 抓包显示--请求响应讲解)

目录 一.HTTP协议的介绍 1.1HTTP是什么? 1.2HTTP版本的演变 二.HTTP的工作过程 三.使用Fiddler抓包工具 3.1简单讲解Fiddler 3.2Fiddler工作的原理 3.3抓包结果分析 四.HTTP请求 4.1认识URL 4.2关于URL encode 4.3认识方法 4.3.1认识get和post 4.3.…

Fiddler弱网测试究竟该怎么做?

前言 使用Fiddler对手机App应用进行抓包,可以对App接口进行测试,也可以了解App传输中流量使用及请求响应情况,从而测试数据传输过程中流量使用的是否合理。 抓包过程: 1、Fiddler设置 1)启动Fiddler->Tools->…

易点天下携AIGC创新成果KreadoAI亮相数贸会,解锁电商文化出海新可能

11月27日,第二届全球数字贸易博览会(以下简称“数贸会”)在浙江杭州完美落幕。作为出海营销领域最早一批布局AIGC战略的营销科技公司,易点天下受邀与来自全球800余家境内外数字贸易企业同台参展,并分享了旗下AIGC数字营…

拦截器使用详解

什么是拦截器? 拦截器是 Spring 框架提供的核⼼功能之⼀,主要⽤来拦截⽤户的请求, 在指定⽅法前后,根据业务需要执行预先设定的代码. 也就是说,允许开发⼈员提前预定义⼀些逻辑,在请求访问接口前/后执行.也可以在⽤户请求前阻止其进入接口执行 在拦截器当中,开发⼈…

广西铁塔发布ZETag定位服务,快递物流可视化将成趋势

在万亿的快递、物流红海市场中,可视化或将成为凸显差异化优势的一枚棋子。 11月17日,中国铁塔股份有限公司广西壮族自治区分公司(以下简称“广西铁塔”)在广西南宁召开“万物智联,贴‘芯’服务”物联网产品发布会&…

UE5富文本框学习(用途:A(名字)用刀(图片)击杀B(名字))

UE5-UMG教程-通用控件:多格式文本块(RichTextBlock)_哔哩哔哩_bilibilihttps://www.bilibili.com/video/BV1Pu4y1k7Z2/?p54&spm_id_frompageDriver 结果示例: 1.添加富文本框 2.添加文字样式库 点添加,更改每行行…

又又又重新刷题的第一天第一天第一天,这次目标是top100一定要刷完整至少一次两次吧:1/150:两数之和 2/150两数相加 3/150无重复字符的最长字串

题目1/150:两数之和 给定一个整数数组 nums 和一个整数目标值 target,请你在该数组中找出 和为目标值 target 的那 两个 整数,并返回它们的数组下标。 你可以假设每种输入只会对应一个答案。但是,数组中同一个元素在答案里不能重…

智慧工厂人员定位系统源码,融合位置物联网、GIS可视化等技术,实现对人员、物资精确定位管理

智慧工厂人员定位系统源码,UWB高精度定位系统源码 随着中国经济发展进入新常态,在资源和环境约束不断强化的背景下,创新驱动传统制造向智能制造转型升级,越发成为企业生存发展的关键。智能工厂作为实现智能制造的重要载体&#xf…

电商平台为什么要使用CDN加速?

随着电商零售市场的成熟,消费者越来越关注购物体验。电商零售平台的响应速度、稳定性和安全性,均可能直接影响用户购买欲和用户转化率。如何进一步提升用户体验成为电商零售企业在市场决胜的关键。 阿里云全站加速DCDN全球覆盖3200节点,在提…

【知网稳定检索】2024年应用经济学,管理科学与社会发展国际学术会议(AEMSS 2024)

2024年应用经济学,管理科学与社会发展国际学术会议(AEMSS 2024) 2024 International Conference on Applied Economics, Management Science and Social Development 2024年应用经济学,管理科学与社会发展国际学术会议&#xff…

无公网IP下,如何实现公网远程访问MongoDB文件数据库

文章目录 前言1. 安装数据库2. 内网穿透2.1 安装cpolar内网穿透2.2 创建隧道映射2.3 测试随机公网地址远程连接 3. 配置固定TCP端口地址3.1 保留一个固定的公网TCP端口地址3.2 配置固定公网TCP端口地址3.3 测试固定地址公网远程访问 前言 MongoDB是一个基于分布式文件存储的数…

即时电商需求快速爆发,商城系统平台的安全性如何保障?

双11的狂欢刚刚落下帷幕,留下的不仅是消费者的购物满足和品牌商家的销售增长,更让我们看到了一个行业变革的微妙暗示。 回溯这场全民购物的盛大节日,我们不难发现,线下零售品牌在电商巨浪的冲击下,非但没有萎靡&#x…

2023年最值得推荐的数据分析平台,可能是它!

在知乎上,商业数据分析和可视化是热门话题,其中Tableau和PowerBI是讨论最多的两个工具。随着数据分析行业的快速发展,利用这两个工具生成可视化dashboard并进行数据探索分析确实高效便捷。 作为一名数据分析爱好者,我也经常尝试各…

点成案例 | 使用自动细胞计数仪进行酵母细胞计数

一、概述 酵母可用于基础研究、酿造和蒸馏以及食品生产等多种应用,这些应用的全过程都离不开准确的细胞计数和活力测定。事实证明,较小的尺寸和形态对于自动细胞计数仪来说是相当具有挑战性的,利用活体染色剂手动计数酵母的方法繁琐且容易出…

竞赛选题 题目:基于机器视觉opencv的手势检测 手势识别 算法 - 深度学习 卷积神经网络 opencv python

文章目录 1 简介2 传统机器视觉的手势检测2.1 轮廓检测法2.2 算法结果2.3 整体代码实现2.3.1 算法流程 3 深度学习方法做手势识别3.1 经典的卷积神经网络3.2 YOLO系列3.3 SSD3.4 实现步骤3.4.1 数据集3.4.2 图像预处理3.4.3 构建卷积神经网络结构3.4.4 实验训练过程及结果 3.5 …

微软Azure AI新增Phi、Jais等,40种新大模型

微软在官方宣布在Azure AI云开发平台中,新增了Falcon、Phi、Jais、Code Llama、CLIP、Whisper V3、Stable Diffusion等40个新模型,涵盖文本、图像、代码、语音等内容生成。 开发人员只需要通过API或SDK就能快速将模型集成在应用程序中,同时支…