第12章 PyTorch图像分割代码框架-1

news2024/11/15 19:42:12

从本章开始,本书将会进行深度学习图像分割的实战阶段。PyTorch作为目前最为流行的一款深度学习计算框架,在计算机视觉和图像分割任务中已经广泛使用。本章将介绍基于PyTorch的深度学习图像分割代码框架,在总体框架的基础上,基于PASCAL VOC 2012数据集,分别介绍预处理模块、数据导入模块、模型模块、工具函数模块、配置模块、主函数模块、推理模块和部署模块等。每个模块都会在基本的代码结构基础上配以简单的代码示例,帮助读者快速掌握深度学习图像分割的代码范式。

图像分割代码总体框架

深度学习项目的代码框架和范式整体较为固定,一般按照机器学习的pipeline都会有:数据预处理、数据导入、模型搭建、训练、验证、测试和部署等流程。图像分割作为经典的深度学习应用方向之一,也不例外地遵循上述范式。根据深度学习图像分割项目的特点,一个相对完整的图像分割代码框架应包含:预处理模块、数据导入模块、模型模块、工具函数模块、配置模块、主函数模块、推理模块和部署模块。

PyTorch是目前最为流行的深度学习计算框架,无论在学术界还是工业界,都被广大用户作为首选用来进行项目搭建的框架。基于PyTorch的深度学习图像分割代码框架结构如图1所示。

181d4d6bd81ca2d3862db5926c3fbcbc.png

1是一个典型的深度学习图像分割代码框架,也是一个分割项目的工程示例。左侧为代码目录结构,与虚线对应的右侧部分则是具体的模块和功能文档名称。除了前述的8个核心模块之外,图中还补充了基于Jupyter Notebookexperiment.ipynb实验文档,方便开发者在将代码写成工程文件前进行一些尝试性的实验和探索。README.md文档则是关于该项目的说明文档,具体可包括项目的基本情况介绍,训练、验证和测试方法等。需要注意的是,图11-1仅是一个典型的参考代码框架,具体的目录结构可能会因开发者的习惯和项目的特殊需要而有所不同。但不论什么样的深度学习图像分割项目,数据、模型和包含训练验证的主函数这三个部分,一定是不可或缺的。

本章将基于经典的语义分割数据集PASCAL VOC 2012,针对上述8个代码模块,在给出基本代码模板的基础上,进而给出完整的代码实现方式。关于PASCAL VOC 2012数据集相关信息,读者可直接参考本书第10章内容。

预处理模块

数据预处理模块主要用于定义一些图像预处理的功能函数,包括像数据标准化方法、图像转换方法、图像重采样方法等。预处理模块不是必须的,具体结合每个分割项目的图像数据集实际情况来定,对于PASCAL VOC 2012这样成熟的公开数据集,可能一般不需要太多预处理的流程,但对于个人收集的原始数据、一些医学影像或者遥感影像等特定领域的图像数据,可能需要一定的预处理过程。预处理模块需要用户结合个人数据集实际情况酌情进行定义。在图像分割项目中,如果我们的图像数据集需要预处理,我们可以在preprocess.py文件中定义预处理过程。

数据导入模块

PyTorch提供了数据导入的标准化代码模板,可以直接借助于torch.utils.data模块下的数据对象类Dataset来完成数据的定义和读取,然后再借助于数据导入类DataLoaders将数据按批次导入到模型中。基于Dataset的数据读取框架如代码11-1所示。

# 导入数据对象类Dataset
from torch.utils.data import Dataset
# 基于Dataset的数据定义和读取框架
class CustomDataset(Dataset):
    def __init__(self, ...):
        # stuff
        
    def __getitem__(self, index):
        # stuff
        return (img, label)
        
    def __len__(self):
        # return examples size
        return len(self.images)

其中CustomDataset为自定义的数据对象类,该类继承于Dataset类,下面包括三个方法:__init__为类初始化方法,包含了数据的存放路径、输入图像、输出图像以及数据变换等属性的定义;__getitem__定义了图像读取和变换方法,一般可通过pillow库的Image.Open先完成读取,再用torchvision库下的transform完成数据变换;__len__方法则通过调用内置函数len返回数据的长度。按照该数据读取框架,我们可以定义PASCAL VOC 2012数据集的读取方法,定义voc.py文件如下。

# 导入相关模块
import numpy as np
from PIL import Image
from pathlib import Path
from torch.utils.data import Dataset


# 定义VOC数据集类
class VOCSegmentation(Dataset):
    """
    A PyTorch Dataset class for the VOC Segmentation dataset.
    """
    def __init__(self, root, image_set='train', transform=None):
        """
        Initialize the dataset.
        Args:
            root (str): Path to the dataset root directory.
            image_set (str): The image set to use ('train' or 'val').
            transform (callable, optional): A function/transform to apply to the images and masks.
        """
        self.root = Path(root)
        self.transform = transform
        self.image_set = image_set
        base_dir = 'VOCdevkit/VOC2012'
        voc_root = self.root / base_dir
        image_dir = voc_root / 'JPEGImages'
        if not voc_root.is_dir():
            raise RuntimeError('Dataset not found.')


        mask_dir = voc_root / 'SegmentationClass'
        splits_dir = voc_root / 'ImageSets/Segmentation'
        split_f = splits_dir / f"{image_set.rstrip()}.txt"


        with open(split_f, "r") as f:
            file_names = [x.strip() for x in f.readlines()]


        self.images = [image_dir / f"{x}.jpg" for x in file_names]
        self.masks = [mask_dir / f"{x}.png" for x in file_names]
        assert (len(self.images) == len(self.masks))


    def __getitem__(self, index):
        """
        Get an item from the dataset.
        Args:
            index (int): Index of the item to get.   
        Returns:
            tuple: (image, target) where target is the image segmentation.
        """
        img = Image.open(self.images[index]).convert('RGB')
        target = Image.open(self.masks[index])
        if self.transform is not None:
            img, target = self.transform(img, target)
        return img, target


    def __len__(self):
        return len(self.images)

如代码11-2所示,我们通过pathlib库来定义数据路径,通过pillowImage.open函数来读取图像并进行同步转换。除此之外,VOC数据集掩码还需要单独进行颜色编码的接码,所以实际操作时要单独定义voc_map函数以及在VOCSegmentation类中补充一个掩码图像的解码方法decode_target。完整代码可参考本书配套代码对应章节。

需要特别说明的是,torchvision库中提供的transform模块提供了各种图像变换方法,也就是我们通常所说的在线数据增强(Online Data Augmentation)。在线数据增强是指在训练过程中,每次读取一个样本时都会进行数据增强操作。也就是说,数据增强是在每个小批量(batch)的数据上实时进行的。在线数据增强可以通过数据转换的方式,在每个训练迭代中生成多个不同的数据样本,以增加训练集的多样性,但不实际增加训练数据的数量。常见的在线数据增强操作包括随机裁剪(random cropping)、翻转(flipping)、旋转(rotation)、缩放(scaling)等。与在线数据增强对应的是离线数据增强(Offline Data Augmentation)。离线数据增强是指在训练开始之前,将原始数据集进行增强,并将增强后的数据保存为新的训练集。然后,在训练过程中,使用增强后的训练集进行模型训练。离线数据增强的好处是可以节省训练时间,因为数据增强只需在训练开始之前完成一次。常见的离线数据增强操作包括扩充数据集,例如通过旋转、平移、缩放等方式生成新的图像。常用的离线数据增强库包括imgaugalbumentationsAugmentor等。在代码11-2中,我们在初始化方法里面提供数据transform方式,同步接收imgtarget作为输入进行在线数据增强,以增强训练样本的多样性。当然了,这需要我们对torchvision中的transform方法稍微进行改动。

后续全书内容和代码将在github上开源,请关注仓库:

https://github.com/luwill/Deep-Learning-Image-Segmentation

(未完待续)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1134391.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

kuaishou web端did注册激活 学习记录

快手web端 did 注册激活的流程大概如下: 1.访问web端的接口,主动触发滑块,拿到滑块信息 2.然后滑块验证did 获取captchaToken 3.携带captchaToken访问接口 4.最后校验web端的did 是否激活 最后激活以后的效果如下: 经过测试&…

微服务-服务拆分

文章目录 服务拆分及注意事项服务拆分案例案例代码分析 服务拆分及注意事项 每个服务有独立的数据库,订单模块需要查询用户信息时,通过调用用户模块的接口,自身的数据库并没有用户信息。 服务拆分案例 案例结构 案例有2个微服务,…

【跟小嘉学 Rust 编程】三十三、Rust的Web开发框架之一: Actix-Web的基础

系列文章目录 【跟小嘉学 Rust 编程】一、Rust 编程基础 【跟小嘉学 Rust 编程】二、Rust 包管理工具使用 【跟小嘉学 Rust 编程】三、Rust 的基本程序概念 【跟小嘉学 Rust 编程】四、理解 Rust 的所有权概念 【跟小嘉学 Rust 编程】五、使用结构体关联结构化数据 【跟小嘉学…

基于Kubesphere容器云平台物联网云平台Devops实践

基于Kubesphere容器云平台物联网云平台Devops实践 项目背景 ​ 公司是做工业物联网相关业务的,现业务是云平台,技术栈 后端为 Springboot2.7JDK11 ,前端为 Vue3Ts,需要搭建自动化运维平台以实现业务代码自动部署上线,…

【C++笔记】如何用检查TCP或UDP端口是否被占用

一、检查步骤 使用socket函数创建socket_fd套接字。使用sockaddr_in结构体配置协议和端口号。使用bind函数尝试与端口进行绑定,成功返回0表示未被占用,失败返回-1表示已被占用。 二、步骤详解 2.1 socket函数 socket 函数是用于创建套接字的函数&…

【MySql】9- 实践篇(七)

文章目录 1. 一主多从的主备切换1.1 基于位点的主备切换1.2 GTID1.3 基于 GTID 的主备切换1.4 GTID 和在线 DDL 2. 读写分离问题2.1 强制走主库方案2.2 Sleep 方案2.3 判断主备无延迟方案2.4 配合 semi-sync方案2.5 等主库位点方案2.6 GTID 方案 3. 如何判断数据库是否出问题了…

Django 实战开发(一)项目搭建

1.项目搭建 用pycharm 编辑器可以直接 New 一个 Django 项目 2.新建应用 python manage.py startapp demo项目结构如下: 3.编写第一个Django 视图函数 /demo/views: from django.http import HttpResponse def welcome(request):return HttpResponse("welcome to dja…

品牌媒介工作流程是什么,媒体投放目标怎么做?

品牌媒介其实说简单也很简单,说难也很难,简单在于其实事情流程简洁,难呢,在于很多东西如果不亲身体验是无法领悟到精髓的。今天为大家分享下品牌媒介工作流程是什么,媒体投放目标怎么做? 我们怎么才能在媒体…

JWT的封装、[Authorize]的使用

JWT的封装 需要安装两个包。 包1:System.IdentityModel.Tokens.Jwt Install-Package System.IdentityModel.Tokens.Jwt 包2:Microsoft.AspNetCore.Authentication.JwtBearer Install-Package Microsoft.AspNetCore.Authentication.JwtBearer 我们创建一…

【Unity】3D跑酷游戏

展示 finish_all * 方块跑酷 1.教程链接 翻墙:https://www.youtube.com/watch?v9ZEu_I-ido4&listPLPV2KyIb3jR53Jce9hP7G5xC4O9AgnOuL&index3 2.基础制作 最终成果 2.1 基本场景 1.创建Cube作为跑道 1)记得把位置Reset; 2&#…

C#使用mysql-connector-net驱动连接mariadb报错

给树莓派用最新的官方OS重刷了一下,并且用apt install mariadb-server装上“mysql”作为我的测试服务器。然后神奇的事情发生了,之前用得好好的程序突然就报错了,经过排查,发现在连接数据库的Open阶段就报错了。写了个最单纯的Con…

CSDN学院 < 华为战略方法论进阶课 > 正式上线!

目录 你将收获 适用人群 课程内容 内容目录 CSDN学院 作者简介 你将收获 提升职场技能提升战略规划的能力实现多元化发展综合能力进阶 适用人群 主要适合公司中高层、创业者、产品经理、咨询顾问,以及致力于改变现状的学员。 课程内容 本期课程主要介绍华为…

【发展史】鼠标的发展史

最早可以追溯到1952年,皇家加拿大海军将5针保龄球放在能够侦测球面转动的硬件上,这个硬件再将信息转化成光标在屏幕上移动,用作军事计算机输入。这是我们能够追溯到的最早的依靠手部运动进行光标移动的输入设备。但当时这个东西不叫鼠标&…

Ps:套索工具

Ps 的套索工具有三种,主要通过手动绘制的方式创建选区。 套索工具 Lasso Tool 又称“自由套索工具”,可绘制任意形状的选区,灵活快速但不够精确,是仅需粗略选区时(比如,生成式填充等)最常用的工…

XTU-OJ 1178-Rectangle

题目描述 给你两个平行于坐标轴的矩形,请判断两者是不是相交(面积有重合的部分)? 输入 第一行是一个整数K,表示样例数。 每个样例占两行,每行是4个整数,表示一个矩形的对角线点的坐标&#xff0…

【API篇】十一、Flink水位线传递与迟到数据处理

文章目录 1、水位线传递2、水位线设置空闲等待3、迟到数据处理:窗口允许迟到4、迟到数据处理:侧流输出5、问 1、水位线传递 上游task处理完水位线,时钟改变后,要把数据和当前水位线继续往下游算子的task发送。当一个任务接收到多…

对mysql的联合索引的深刻理解

背景 对mysql的联合索引的考察是Java程序员面试高频考点!必须深刻理解掌握否则容易丢分非常可惜。 技术难点 考察对最左侧匹配原理理解。 原理 暂且不表。网上讲这非常多。我理解就是,B树每个非叶子节点的值都是有序存放索引的值。 比如对A、B、C …

unity 基于UGUI的无限动态滚动列表

基于UGUI的动态滚动列表,主要支持以下功能: 继承自UGUI的SrollRect,支持ScrollRect的所有功能; 使用对象池来管理列表元素,以实现列表元素的复用; 支持一行多个元素或一列多个元素; 可使用不…

漏洞复现--用友 畅捷通T+ .net反序列化RCE

免责声明: 文章中涉及的漏洞均已修复,敏感信息均已做打码处理,文章仅做经验分享用途,切勿当真,未授权的攻击属于非法行为!文章中敏感信息均已做多层打马处理。传播、利用本文章所提供的信息而造成的任何直…

互联网Java工程师面试题·Spring篇·第五弹

目录 1、什么是 spring? 2、使用 Spring 框架的好处是什么? 3、Spring 由哪些模块组成? 4、核心容器(应用上下文) 模块。 5、BeanFactory – BeanFactory 实现举例。 6、XMLBeanFactory 7、解释 AOP 模块 8、解释 JDBC 抽象和 DAO 模块。 9、…