第12章 PyTorch图像分割代码框架-2

news2024/10/6 2:48:13

模型模块

本书的第5-9章重点介绍了各种2D3D的语义分割和实例分割网络模型,所以在模型模块中,我们需要做的事情就是将要实验的分割网络写在该目录下。有时候我们可能想尝试不同的分割网络结构,所以在该目录下可以存在多个想要实验的网络模型定义文件。对于PASCAL VOC这样的自然数据集,我们可能想实验Deeplab v3+PSPNetRefineNet等网络的训练效果。代码11-3给出了Deeplab v3+网络封装后的主体部分,完整网络搭建代码可参考本书配套代码对应章节。

代码11-3 Deeplab v3+网络的主体部分

# 定义Deeplab V3+类
class DeepLabHeadV3Plus(nn.Module):
    def __init__(self, in_channels, low_level_channels, num_classes, aspp_dilate=[12, 24, 36]):
        super(DeepLabHeadV3Plus, self).__init__()


        self.project = nn.Sequential(
            nn.Conv2d(low_level_channels, 48, 1, bias=False),
            nn.BatchNorm2d(48),
            nn.ReLU(inplace=True),
        )
    # ASPP
        self.aspp = ASPP(in_channels, aspp_dilate)
    # classifier head
        self.classifier = nn.Sequential(
            nn.Conv2d(304, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, num_classes, 1)
        )


        self._init_weight()
  # forward method
    def forward(self, feature):
        # print(feature['low_level'].shape)
        # print(feature['out'].shape)
        low_level_feature = self.project(feature['low_level'])
        output_feature = self.aspp(feature['out'])
        output_feature = F.interpolate(
            output_feature, size=low_level_feature.shape[2:], mode='bilinear', align_corners=False)
        return self.classifier(torch.cat([low_level_feature, output_feature], dim=1))
  # weight initilize
    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

对于复杂网络搭建,一般都是采用自下而上的搭建方法,先搭建底层组件,再逐步向上封装,对于本例中的Deeplab v3+,可以先分别搭建backbone骨干网络、ASPP和编解码结构,最后再进行封装。

工具函数模块

工具函数是为项目完成各项功能所自定义的辅助函数,可以统一定义在utils文件夹下,根据实际项目的不同,工具函数也各不相同。常用的工具函数包括各种损失函数的定义loss.py、训练可视化函数的定义visualize.py、用于记录训练日志的log.py等。代码11-4给出了一个关于Focal loss损失函数的定义,该损失函数作为工具函数可放在loss.py文件中。

代码11-4 工具函数示例:定义一个Focal loss

# 导入相关库
import torch
import torch.nn as nn
import torch.nn.functional as F
# 定义一个Focal loss类
class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma


    def forward(self, inputs, targets):
        # Compute cross-entropy loss
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')


        # Compute the focal loss
        pt = torch.exp(-ce_loss)  
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        return focal_loss.mean()

配置模块

配置模块是为项目模型训练传入各种参数而进行设置的模块,比如训练数据所在目录、训练所需要的各种参数、训练过程是否需要可视化等。一般来说,我们有两种方式来对项目执行参数进行配置管理,一种是直接在主函数main.py中使用argparse库对参数进行配置,然后再命令行中进行传入;另一种则是单独定义一个config.py或者config.yaml文件来对所有参数进行统一配置。基于argparse库的参数配置管理简单示例如代码11-5所示。

代码11-5 argparser参数配置管理

# 导入argparse库
import argparse
# 创建参数管理器
parser = argparse.ArgumentParser()
# 涉及数据相关的参数管理
parser.add_argument("--data_root", type=str, default='./dataset',
                     help="path to Dataset")
parser.add_argument("--save_root", type=str, default='./',
                     help="path to save result")
parser.add_argument("--dataset", type=str, default='voc',
                     choices=['voc', 'cityscapes', 'ade'], help='Name of dataset')
parser.add_argument("--num_classes", type=int, default=None,
                     help="num classes (default: None)")

在上述代码中,我们基于argparse给出了一小部分参数配置管理代码,涉及训练数据相关的部分参数,包括数据读取路径、存放路径、训练所用数据集、分割类别数量等。

主函数模块

主函数模块main.py是项目的启动模块,该模块将定义好的数据和模型模块进行组装,并结合损失函数、优化器、评估方法和可视化等组件,将config.py中配置好的项目参数传入,根据训练-验证的模式,执行图像分割项目模型训练和验证。代码11-6VOC数据集训练验证部分代码。

代码11-6 主函数模块中的训练迭代部分

# 初始化区间损失
interval_loss = 0
while True:  
  # 执行训练
  model.train()
  cur_epochs += 1
  for (images, labels) in train_loader:
    cur_itrs += 1
    images = images.to(device, dtype=torch.float32)
    labels = labels.to(device, dtype=torch.long)
    optimizer.zero_grad()
    outputs = model(images)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()


    np_loss = loss.detach().cpu().numpy()
    interval_loss += np_loss


    if vis is not None:
      vis.vis_scalar('Loss', cur_itrs, np_loss)
    # 打印训练信息
    if (cur_itrs) % opts.print_interval == 0:
      pass
    # 保存模型
    if (cur_itrs) % opts.val_interval == 0:
      pass
      # 日志记录
      logger.info("Save the latest model to %s" % save_path_checkpoints)
      # 模型验证
      print("validation...")
      model.eval()
      val_score, ret_samples = validate(
        opts=opts, model=model, loader=val_loader, device=device, metrics=metrics,
        ret_samples_ids=vis_sample_id)
      logger.info("Validation performance: %s", val_score)
      
      # 保存最优模型
      if val_score['mean_dice'] > best_score:  
        best_score = val_score['mean_dice']
        save_ckpt(os.path.join(save_path_checkpoints, 'best_%s_%s_os%d.pth' %
                     (opts.model, opts.dataset, opts.output_stride)))
        logger.info("Save best-performance model so far to %s" % save_path_checkpoints)


      # 训练过程可视化
      if vis is not None:  
        vis.vis_scalar("[Val] Overall Acc", cur_itrs, val_score['Overall Acc'])
        vis.vis_scalar("[Val] Mean IoU", cur_itrs, val_score['Mean IoU'])
        vis.vis_table("[Val] Class IoU", val_score['Class IoU'])


        for k, (img, target, lbl) in enumerate(ret_samples):
          img = (denorm(img) * 255).astype(np.uint8)
          target = train_dst.decode_target(target).transpose(2, 0, 1).astype(np.uint8)
          lbl = train_dst.decode_target(lbl).transpose(2, 0, 1).astype(np.uint8)
          concat_img = np.concatenate((img, target, lbl), axis=2)  
          vis.vis_image('Sample %d' % k, concat_img)
          
    scheduler.step()

在代码11-6中,我们展示了一个图像分割项目主函数模块中最核心的训练和验证部分。在训练时,按照指定迭代次数保存模型和对训练过程进行可视化展示。图11-2为训练打印的部分信息。

29f5835d31863a3c0e12336eba35dd8d.png

11-2 VOC训练过程信息

11-3为基于visdom的训练过程可视化展示,包括当前训练配置参数信息,训练损失函数变化曲线、验证集全局准确率、mIoU和类别IoU等指标变化曲线图。

c829adc88b9de5b266170dd6b3b86385.png

11-3 Deeplab v3+训练过程可视化

11-4展示了两组训练过程中验证集的输入图像、标签图像和模型预测图像的对比图。可以看到,基于Deeplab v3+的分割模型在PASCAL VOC 2012上表现还不错。

ea3738bd115f8bb89add3ea4ab2e67b6.png

11-4 验证集模型效果图

后续全书内容和代码将在github上开源,请关注仓库:

https://github.com/luwill/Deep-Learning-Image-Segmentation

(未完待续)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1178778.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

JVM虚拟机:垃圾回收器之Parallel Scavenge

本文重点 在前面的课程中,我们学习了新生代的串行化垃圾回收器Serial,本文我们将学习新生代的另外一个垃圾回收器Parallel Scavenge(PS),PS是一个并行化的垃圾回收器,它使用复制算法来清理新生代的垃圾。 运行方式 如上所示,当进行垃圾回收的时候,它会暂停工作线程,而…

第二章: 创建第一个Spring Boot 应用

第二章: 创建第一个Spring Boot 应用 前言 本章重点知识:构建你的第一个Spring Boot应用:以一个简单的例子来引导你进入Spring Boot的开发,包括如何使用Spring Initializr来创建项目,以及如何使用Maven或Gradle构建和运行项目等 IntelliJ IDEA 开发工具中安装 Spring Init…

网络原理---网络初识

文章目录 网络发展史独立模式网络互连局域网LAN广域网WAN 网络通信基础IP地址端口号 认识协议什么是协议?协议分层为什么要分层?两种典型的分层方式:OSI七层TCP/IP五层 网络发展史 从我们出生以来,网络世界就已经纷繁错杂。我们虽…

简单CMake入门

CMake可以生成不同平台下的Makefile,有了CMake不用再写复杂的Makefile 视频教程:CMake 6分钟入门,不用再写复杂的Makefile 先前知识 Makefile简单入门 Cmake特性 CMake是一个用于管理C/C项目的跨平台构建工具。 跨平台:CMake是…

CSS示例001:鼠标放div上,实现旋转、放大、移动等效果

GPT能够很好的应用到我们的代码开发中,能够提高开发速度。你可以利用其代码,做出一定的更改,然后实现效能。 css实战中,经常会看到这样的场景,鼠标放到一个图片或者一个div块状时候,会出现旋转、放大、移动…

webgoat-Insecure Deserialization不安全的序列化

A(8)不安全的反序列化 反序列化是将已序列化的数据还原回对象的过程。然而,如果反序列化是不安全的,那么恶意攻击者可以在序列化的数据中夹带恶意代码,从而在反序列化时执行这些代码。这种攻击被称为反序列化。 什么…

2023年中国大学生程序设计竞赛女生专场题解, K. RSP

Dashboard - 2023年中国大学生程序设计竞赛女生专场 - Codeforces K. RSP time limit per test1 second memory limit per test512 megabytes input standard input output standard output 小 A 和小 B 在玩一种叫做石头剪刀布的游戏。 这个游戏的规则很复杂&#xff0c…

Java8强大的新特性 —— “Stream API”

一、什么是Stream API? Java Stream API是Java 8中引入的一个重要功能,它允许开发者以声明性方式处理数据集合,使代码更加简洁、可读性更好,同时还提供了并行操作的能力,从而能够更有效地利用多核处理器。 Stream AP…

1.RestCloud部署安装

一、背景 项目使用StarRocks数仓,在网上找了一遍ETL工具,本来想用DataX ,但考虑到DataX的学习成本就没使用,最后找到了RestCloud,RestCloud提供了社区开源版本,提供图形化的操作界面,相对于DataX来说更容易上手。 二、环境准备 RestCloud依赖的环境如下: 1.安装准备…

『亚马逊云科技产品测评』活动征文|EC2云服务器一键部署wordpress博客

『亚马逊云科技产品测评』活动征文|EC2云服务器一键部署wordpress博客 授权声明:本篇文章授权活动官方亚马逊云科技文章转发、改写权,包括不限于在 Developer Centre, 知乎,自媒体平台,第三方开发者媒体等亚马逊云科技…

嵌入式面试常见问题(三)

1.linux下的proc文件夹是干什么的? 进程信息:/proc文件夹包含有关系统上运行的每个进程的信息。您可以在/proc中找到以进程ID(PID)为名称的子文件夹,每个子文件夹包含有关特定进程的信息,如状态、命令行参数…

案例研究|腾讯音乐娱乐集团与JumpServer共探安全运维审计解决方案

近年来,得益于人民消费水平的提升以及版权意识的加强,用户付费意愿和在线用户数量持续增长,中国在线音乐市场呈现出稳定增长的发展态势。随着腾讯音乐于2018年12月上市,进一步推动了中国在线音乐市场的发展。 腾讯音乐娱乐集团&a…

了解计算机的大小端存储模式

我们在计算机中存储数据时,数据是如何组织和表示的是一个重要的问题。其中一个关键概念是 大小端存储模式(Endianness),它描述了多字节数据在内存中的存储方式。本文将介绍大小端存储模式的原理、应用和区别。 什么是大小端存储模…

国外住宅IP代理选择的8个方法,稳定的海外IP哪个靠谱?

一、国外住宅IP代理是什么? 代理服务器充当您和互联网之间的网关。它是一个中间服务器,将最终用户与他们浏览的网站分开。如果您使用国外代理IP,互联网流量将通过国外代理服务器流向您请求的地址。然后,请求通过同一个代理服务器…

【独立开发】跨境电商商城源码!源码全开源,无加密,软著加持,交付源码!

大家好,今天要给大家带来一个重磅好消息! 一直在寻找优质跨境电商源码?那么这个你一定不能错过! 1、独立开发:这款源码是由我们团队独立开发,从需求分析、设计、编码到测试,全部由我们亲自完成。这里没有中间商,也没有…

小红书运营篇1,新手如何快速分析拆解对标账号

hi,同学们,本期是第1期AI运营技巧篇 很多新手博主初期都非常迷茫,主要是因为他们没有找对标账号,也没有充分分析同行账号。 有些人可能会说,“我不想参考同行,我想要追求创新”。这种勇气是真的非常值得鼓…

【Unity实战】最全面的库存系统(一)(附源码)

文章目录 先来看看最终效果前言定义物品定义人物背包物品插槽数据拾取物品物品堆叠绘制UI移动拖拽物品选中物品跟随鼠标移动背包物品交换物品拆分物品物品堆叠源码完结先来看看最终效果 前言 它又来了,库存系统我前面其实一句做过很多次了,但是这次的与以往的不太一样,这个…

浅谈电力物联网时代物联网技术在电力系统中的应用

贾丽丽 安科瑞电气股份有限公司 上海嘉定201801 摘要:在电力系统建设中,物联网的应用不仅促进了我国电力工业的发展,而且对我国的物联网技术也起到了一定的促进作用。随着物联网技术应用于电力系统,推动了中国工业的快速发展。因…

Electron[3] 基础配置准备和Electron入门案例

1 背景 上一篇文章已经分享了,如何准备Electron的基础环境了。但是博客刚发才一天,就发现有人问问题了。经过实践发现,严格按照作者的博客教程走是不会有问题的,其中包括安装的环境版本等都要一致。因为昨天发的博客,…

【Java 进阶篇】JSP 指令详解

JavaServer Pages(JSP)是一种用于开发动态 Web 应用程序的强大技术。与传统的 Servlet 编程相比,JSP 更易于编写和维护。在 JSP 中,我们可以使用指令来定义页面的行为和属性。本博客将深入探讨 JSP 中的指令,从入门到精…