PyTorch 图像分割模型教程

news2024/9/21 13:28:24

PyTorch 图像分割模型教程

在图像分割任务中,目标是将图像的每个像素归类为某一类,以分割出特定的物体。PyTorch 提供了非常灵活的工具,可以用于构建和训练图像分割模型。我们将使用 PyTorch 的经典网络架构,如 UNetDeepLabV3,并演示如何构建、训练和测试这些模型。

1. 图像分割概述

图像分割的目标是将图像的每个像素进行分类。常见的应用场景有医学图像分割(如肿瘤检测)、自动驾驶(道路、车辆、行人分割)等。

  • 语义分割:每个像素被分配给某个类别,例如道路、天空或车辆。
  • 实例分割:不仅对物体分类,还要区分物体实例,如区分不同的行人。

PyTorch 中有许多预训练的模型可以直接用于图像分割任务,常用的模型包括 UNetFCN (Fully Convolutional Network)DeepLabV3 等。

2. 官方文档链接
  • PyTorch 官方文档
  • Torchvision 模型

3. 准备工作

在开始训练之前,我们需要安装 torch, torchvisionPIL 等依赖项,并准备图像数据集。您可以使用自己的图像数据集,或者使用 COCO、VOC 等常用数据集。

pip install torch torchvision pillow

4. 使用预训练的 DeepLabV3 模型

DeepLabV3 是一个性能优异的语义分割模型,PyTorch 的 torchvision 提供了预训练的 DeepLabV3 模型。我们将使用 COCO 数据集中的预训练模型,并进行推理和测试。

import torch
from torchvision import models, transforms
from PIL import Image
import matplotlib.pyplot as plt

# 加载预训练的 DeepLabV3 模型
model = models.segmentation.deeplabv3_resnet50(pretrained=True)
model.eval()  # 切换到评估模式

# 定义预处理步骤
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# 加载图像
input_image = Image.open("test_image.jpg")
input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0)  # 创建 batch 维度

# 将输入移到 GPU(如果可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
input_batch = input_batch.to(device)

# 进行预测
with torch.no_grad():
    output = model(input_batch)['out'][0]  # DeepLabV3 的输出包含 'out' 字段

# 将输出转换为类别索引(每个像素对应一个类别)
output_predictions = output.argmax(0).cpu().numpy()

# 显示分割结果
plt.imshow(output_predictions)
plt.show()

说明

  • models.segmentation.deeplabv3_resnet50(pretrained=True):加载使用 ResNet-50 作为主干网络的 DeepLabV3 模型,预训练于 COCO 数据集。
  • preprocess:对输入图像进行预处理,包括调整大小、裁剪、归一化等。
  • output_predictions:模型的输出是每个像素的类别索引,经过 argmax 操作,获取每个像素的类别。

5. UNet 模型

UNet 是一个广泛用于医学图像分割的经典模型。我们将从头实现 UNet 模型,并在简单的合成数据上进行训练。

5.1 UNet 网络结构
import torch
import torch.nn as nn
import torch.nn.functional as F

class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()

        # 下采样(编码器部分)
        self.encoder1 = self.double_conv(1, 64)
        self.encoder2 = self.double_conv(64, 128)
        self.encoder3 = self.double_conv(128, 256)
        self.encoder4 = self.double_conv(256, 512)
        
        # 中间部分
        self.middle = self.double_conv(512, 1024)
        
        # 上采样(解码器部分)
        self.upconv4 = self.up_conv(1024, 512)
        self.decoder4 = self.double_conv(1024, 512)
        self.upconv3 = self.up_conv(512, 256)
        self.decoder3 = self.double_conv(512, 256)
        self.upconv2 = self.up_conv(256, 128)
        self.decoder2 = self.double_conv(256, 128)
        self.upconv1 = self.up_conv(128, 64)
        self.decoder1 = self.double_conv(128, 64)
        
        # 最后的分类层
        self.final = nn.Conv2d(64, 1, kernel_size=1)

    def double_conv(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )
    
    def up_conv(self, in_channels, out_channels):
        return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)

    def forward(self, x):
        # 编码器部分
        e1 = self.encoder1(x)
        e2 = self.encoder2(F.max_pool2d(e1, 2))
        e3 = self.encoder3(F.max_pool2d(e2, 2))
        e4 = self.encoder4(F.max_pool2d(e3, 2))
        
        # 中间部分
        middle = self.middle(F.max_pool2d(e4, 2))
        
        # 解码器部分
        d4 = self.upconv4(middle)
        d4 = torch.cat((e4, d4), dim=1)
        d4 = self.decoder4(d4)

        d3 = self.upconv3(d4)
        d3 = torch.cat((e3, d3), dim=1)
        d3 = self.decoder3(d3)

        d2 = self.upconv2(d3)
        d2 = torch.cat((e2, d2), dim=1)
        d2 = self.decoder2(d2)

        d1 = self.upconv1(d2)
        d1 = torch.cat((e1, d1), dim=1)
        d1 = self.decoder1(d1)

        return self.final(d1)

# 创建模型实例
unet_model = UNet()
print(unet_model)

说明

  • UNet 是一种编码-解码结构,包含多个下采样(编码器)和上采样(解码器)层。每次下采样都会减少特征图的大小,并增加特征通道数,上采样则恢复原始图像的大小。
  • ConvTranspose2d 用于进行上采样操作。
5.2 训练 UNet 模型

为了训练 UNet 模型,我们需要构建一个数据加载器并定义损失函数和优化器。我们以一个简单的二分类分割任务为例。

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# 创建合成数据集
class SyntheticSegmentationDataset(Dataset):
    def __init__(self, num_samples, image_size):
        self.num_samples = num_samples
        self.image_size = image_size
        self.transform = transforms.Compose([
            transforms.ToTensor(),
        ])

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        image = torch.rand(1, self.image_size, self.image_size)
        mask = (image > 0.5).float()  # 简单的二分类掩码
        return image, mask

# 创建数据集
dataset = SyntheticSegmentationDataset(num_samples=1000, image_size=128)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

# 定义损失函数和优化器
criterion = nn.BCEWithLogitsLoss()  # 二分类交叉熵损失
optimizer = torch.optim.Adam(unet_model.parameters(), lr=0.001)

# 训练循环
unet_model.train()
for epoch in range(5):  # 简单训练 5 个 epoch
    for images, masks in dataloader:
        # 前向传播
        outputs = unet_model(images)
        loss = criterion(outputs, masks)

        # 反向传播和优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f'Epoch [{epoch

+1}/5], Loss: {loss.item():.4f}')

说明

  • BCEWithLogitsLoss 是二分类任务的标准损失函数,适合输出为单通道(1 表示目标类,0 表示背景)的分割任务。
  • 我们创建了一个合成数据集,其中图像为随机值,掩码为图像中值大于 0.5 的部分。

6. 总结

  • DeepLabV3 是一种非常强大的图像分割模型,适用于各种复杂场景,PyTorch 提供了预训练模型,适合快速部署。
  • UNet 是经典的医学图像分割模型,适用于更细致的分割任务。

通过使用 PyTorch,您可以轻松实现并训练图像分割模型,利用 GPU 加速并扩展到大规模数据集。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2152447.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Python】探索 Errbot:多功能聊天机器人框架

不是旅行治愈了你,是你在路上放过了自己。 在当今的数字化时代,聊天机器人已成为企业与客户互动、提升工作效率和增加乐趣的重要工具。Errbot是一个高度可扩展的聊天机器人框架,它允许开发者使用Python轻松创建和定制机器人。本文将介绍Errb…

【linux008】目录操作命令篇 - rmdir 命令

文章目录 1、基本用法2、常见选项3、举例4、注意事项 rmdir 是 Linux 系统中的一个命令,用于删除空目录。它只能删除 空目录,如果目录中存在文件或子目录,则无法删除。 1、基本用法 rmdir [选项] 目录名...2、常见选项 -p, --parents&…

1.4 MySql配置文件

既然我们开始学习数据库,就不能像大学里边讲数据库课程那样简单讲一下,增删改查,然后介绍一下怎么去创建索引,怎么提交和回滚事务。我们学习数据库要明白怎么用,怎么配置,学懂学透彻了。当然MySql的配置参数…

关于群里脱敏系统的讨论2024-09-20

群里大家讨论脱敏系统,傅同学:秦老师,银行数据脱敏怎么做的,怎么存储的? 采购了脱敏系统,一般是硬件(厂商直接卖的一体机)。这个系统很复杂,大概卖50-100万一台。 最核…

为什么消费还能返利?2024年全新返利模型!

在当今竞争激烈的电商市场中,一种名为“循环购”的创新商业模式正悄然兴起,以其独特的消费返利机制和积分体系,为消费者带来了前所未有的购物体验 一、循环购模式:消费即投资的智慧选择 循环购模式并非简单的消费行为&#xff0c…

MySQL | 知识 | 从底层看清 InnoDB 数据结构

文章目录 一、InnoDB 简介InnoDB 行格式COMPACT 行格式CHAR(M) 列的存储格式VARCHAR(M) 最多能存储的数据记录中的数据太多产生的溢出行溢出的临界点 二、表空间文件的结构三、InnoDB 数据页结构页页的概览Infimum 和 Supremum使用Page Directory页的真实面貌 四、B 树是如何进…

重生奇迹MU 强化玩法套路多 极品装备由你打造

欢迎来到重生奇迹MU的强化玩法指南!想要打造极品装备吗?不可错过这篇文章,我们将为您揭开最多套路的强化技巧和窍门,帮您节省时间和资源,并带来最高效的升级结果。无论您是新手还是老玩家,本文适合所有级别…

基于MySQL全量备份+GTID同步的主从架构恢复数据至指定时间点

系列文章目录 基于GTID同步搭建主从复制 MySQL全量备份 文章目录 系列文章目录前言一、环境准备二、构建测试数据1.安装sysbench2.构建测试数据3.准备全量备份4.将全量备份和binlog拷贝到临时数据库服务器5.模拟误删除表操作 三、恢复数据到指定时间点1.临时数据库恢复数据2.找…

【Delphi】中的数据绑定(LiveBindings)

LiveBindings 是 RAD Studio 中 VCL 和 FireMonkey 框架都支持的数据绑定功能。 LiveBindings 是一个基于表达式的框架,这意味着它使用绑定表达式将对象绑定到其他对象或数据集字段。 LiveBindings 概述 LiveBindings 基于关系表达式,即绑定表达式&am…

react 甘特图之旅

react-gantt GitHub 仓库: https://github.com/clayrisser/react-gantt react-gantt-chart GitHub 仓库: https://github.com/MaTeMaTuK/gantt-task-react easy-gant-beta GitHub 仓库: https://github.com/web-widgets/react-gantt-demos 上面的版本不兼容 dhtmlx-gant…

一周热门|比GPT-4强100倍,OpenAI有望年底发布GPT-Next;1个GPU,1分钟,16K图像

大模型周报将从【企业动态】【技术前瞻】【政策法规】【专家观点】四部分,带你快速跟进大模型行业热门动态。 01 企业动态 Ilya 新公司 SSI 官宣融资 10 亿美元 据路透社报道,由 OpenAI 联合创始人、前首席科学家 Ilya Sutskever 在 2 个多月前共同创…

抖音如何改ip地址到另外城市

在数字化时代,抖音作为广受欢迎的社交媒体平台,不仅连接了亿万用户,也成为了展示个人生活、分享创意内容的重要舞台。然而,有时候出于隐私保护等需求,用户可能希望更改抖音账号显示的IP地址,使其看起来像是…

奇安信渗透2面经验分享

《网安面试指南》http://mp.weixin.qq.com/s?__bizMzkwNjY1Mzc0Nw&mid2247484339&idx1&sn356300f169de74e7a778b04bfbbbd0ab&chksmc0e47aeff793f3f9a5f7abcfa57695e8944e52bca2de2c7a3eb1aecb3c1e6b9cb6abe509d51f&scene21#wechat_redirect 《Java代码审…

泛微E9开发 创建自定义浏览框,关联物品管理表【1】

创建自定义浏览框,关联物品管理表【1】 1、自定义浏览框1.1 概念1.2 前端样式 2、创建物品管理表2.1 新建建模表单操作方法2.2 物品管理表 3、创建浏览按钮 1、自定义浏览框 1.1 概念 自定义浏览框可以理解为是建模引擎中的表与表关联的一个桥梁。比如利用建模引擎…

【学习笔记】数据结构(六 ①)

树和二叉树 (一) 文章目录 树和二叉树 (一)6.1 树(Tree)的定义和基本术语6.2 二叉树6.2.1 二叉树的定义1、斜树2、满二叉树3、完全二叉树4、二叉排序树5、平衡二叉树(AVL树)6、红黑树 6.2.2 二叉树的性质6.…

2024“智衡屋” 智能感知挑战赛决赛即将来袭

2024“智衡屋” 智能感知挑战赛决赛将于 2024 年 9 月 24 日在安徽省合肥市举行,决赛将作为 2024 年中国计量测试学会首届人工智能计量学术大会的重要环节率先举行。 2024“智衡屋” 智能感知挑战赛自启动以来,吸引了700余支高校学生、科研机构研究人员以…

Spring Boot框架在心理教育辅导系统中的应用

3 系统分析 3.1可行性分析 在进行可行性分析时,我们通常根据软件工程里方法,通过四个方面来进行分析,分别是技术、经济、操作和法律可行性。因此,在基于对目标系统的基本调查和研究后,对提出的基本方案进行可行性分析。…

weblogic CVE-2018-2894 靶场攻略

漏洞描述 Weblogic Web Service Test Page中⼀处任意⽂件上传漏洞,Web Service Test Page 在 "⽣产模式"下默认不开启,所以该漏洞有⼀定限制。 漏洞版本 weblogic 10.3.6.0 weblogic 12.1.3.0 weblogic 12.2.1.2 28 weblogic 12.2.1.3 …

ChromaDB教程_2024最新版(下)

前言 Embeddings(嵌入)是表示任何类型数据的AI原生方式,它非常适用于各种AI驱动的工具和算法中。它们可以表示文本、图像,很快还可以表示音频和视频。有许多创建嵌入的选项,无论是在本地使用已安装的库,还是…

LabVIEW 可以同时支持脚本编程和图形编程

LabVIEW 可以同时支持脚本编程和图形编程,但主要依赖其独特的 图形编程 环境(G语言),其中程序通过连线与节点来表示数据流和功能模块。不过,LabVIEW 也支持通过以下方式实现脚本编程的能力: 1. 调用外部脚本…