【PyTorch项目实战】图像分割 —— U-Net:Semantic segmentation with PyTorch

news2025/1/11 12:56:28

文章目录

  • 一、项目介绍
  • 二、项目实战
    • 2.1、搭建环境
      • 2.1.1、下载源码
      • 2.1.2、下载预训练模型
      • 2.1.3、下载训练集
    • 2.2、环境配置
    • 2.3、模型预测

U-Net是一种用于生物医学图像分割的卷积神经网络架构,最初由Olaf Ronneberger等人于2015年提出。

  • 论文: U-Net: Convolutional Networks for Biomedical Image Segmentation
  • 作者: Olaf Ronneberger, Philipp Fischer, Thomas Brox
  • 会议: MICCAI 2015
  • 数据集:使用了ISBI挑战赛中的神经元结构分割和细胞追踪数据集。主要用于评估生物医学图像分割算法的性能。
  • 开源代码:原论文并未提供官方的开源代码,但社区中有多个实现版本可供参考。
  • nnUNet:生物医学领域
  • Segment Anything:强泛化模型(建立了迄今为止最大的分割数据集,有超过1亿个mask。)

一、项目介绍

社区版本:milesial/Pytorch-UNet 是一个基于 PyTorch 的 U-Net 实现项目,专注于语义分割任务。
项目名称:U-Net: Semantic segmentation with PyTorch
主要目的:针对 Kaggle 的Carvana 图像蒙版挑战赛(来自高清图像)在 PyTorch 中定制实现的U-Net 。

Carvana 图像蒙版挑战赛:自动识别图像中的汽车边界

  • 概述:由美国二手车零售平台 Carvana 于 2017 年在 Kaggle 上举办的竞赛,最初为提升车辆展示效果而开发,旨在通过前景分割技术从高分辨率图像中提取汽车主体,去除背景。这一挑战推动了高分辨率图像分割技术在自动驾驶和车辆识别等领域的发展。
  • 数据集:包含 5088 张汽车照片及其对应的掩码(mask),用于训练和评估图像分割模型,特别是汽车前景与背景的分离。
  • 代码(多人提交了Notebook格式的开源代码,部分提供了预训练模型
  • 模型(未开源
  • 排行榜(根据Dice得分,最高Dice=0.99733

在这里插入图片描述

二、项目实战

2.1、搭建环境

2.1.1、下载源码

官方下载地址:milesial/Pytorch-UNet

2.1.2、下载预训练模型

官方提供了两个预训练模型:Pretrained model

  1. unet_carvana_scale0.5_epoch2.pth
    • 模型说明: 这是在 Carvana 数据集上训练的 U-Net 模型,缩放因子为 0.5。这意味着输入图像的尺寸在训练时被缩小了一半,有助于降低计算复杂性和内存使用。
    • 应用场景: 适合于需要快速推理或资源受限的环境,例如移动设备或边缘计算设备。
    • 训练细节: 训练通常包括数据增强、交叉熵损失计算和优化,旨在提高模型的分割精度。
  2. unet_carvana_scale1.0_epoch2.pth
    • 模型说明: 这是相同模型在 Carvana 数据集上的训练,但缩放因子为 1.0,表示输入图像的尺寸与原始图像一致。
    • 应用场景: 适合于对图像分割精度要求较高的任务,因为使用原始尺寸可以保留更多的细节信息。
    • 训练细节: 该模型可能会有更多的计算需求和内存消耗,但在准确性上通常优于缩放因子为 0.5 的模型。

2.1.3、下载训练集

如果需要自训练模型,可以下载官方数据集:carvana-image-masking-challenge:dataset

2.2、环境配置

Note : Use Python 3.6 or newer

conda install python=3.6
pip install -r requirements.txt

2.3、模型预测

基于预训练模型的Unet【Pytorch版】

该项目具有一定的影响力,由于项目需要,尝试调用其预训练模型。

  • 问题:在项目复现过程中,发现 predict.py 无法运行且有部分BUG。
  • 解决:在不改动大框架的前提下,优化了部分内容,最终可以正常执行。

优化内容如下:
(1)get_args():指定路径(预训练模型、输入图像、输出图像)
(2)get_output_filenames()
(3)img = Image.open(filename)替换为img = Image.open(filename).convert('RGB')

备注:由于项目太过简单,优化内容少,建议自己搭建(没有备份优化后项目)。

只需要优化以下两个内容,即可完成项目复现:

  • (1)在原项目的基础上,添加蓝色标记内容,用于指定路径。
  • (2)使用下述代码替换原文中的 predict.py 文件。

在这里插入图片描述

  • 测试结果:使用官方提供的预训练模型,测试效果极差(没有过度探讨内部细节,但核查代码后确定定义的 UNet 模型没有问题)
  • 原因分析:提供的预训练模型中有 epoch2 字样,若为真,则模型确实不可能收敛(感兴趣可以尝试自训练,并增加epoch训练周期)

在这里插入图片描述

import argparse
import logging
import os

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

from utils.data_loading import BasicDataset
from unet import UNet
from utils.utils import plot_img_and_mask

def predict_img(net,
                full_img,
                device,
                scale_factor=1,
                out_threshold=0.5):
    net.eval()
    img = torch.from_numpy(BasicDataset.preprocess(None, full_img, scale_factor, is_mask=False))
    img = img.unsqueeze(0)
    img = img.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        output = net(img).cpu()
        output = F.interpolate(output, (full_img.size[1], full_img.size[0]), mode='bilinear')
        if net.n_classes > 1:
            mask = output.argmax(dim=1)
        else:
            mask = torch.sigmoid(output) > out_threshold

    return mask[0].long().squeeze().numpy()


def get_args():
    parser = argparse.ArgumentParser(description='Predict masks from input images')
    parser.add_argument('--model', '-m', type=str, default='./data/checkpoints/unet_carvana_scale1.0_epoch2.pth', help='Specify the file in which the model is stored')
    parser.add_argument('--input', '-i', type=str, default='./data/predict_data/input/t1.png', help='Filenames of input images')
    parser.add_argument('--output', '-o', type=str, default='./data/predict_data/output/t1.png', help='Filenames of output images')
    parser.add_argument('--viz', '-v', action='store_true', help='Visualize the images as they are processed')
    parser.add_argument('--no-save', '-n', action='store_true', help='Do not save the output masks')
    parser.add_argument('--mask-threshold', '-t', type=float, default=0.5, help='Minimum probability value to consider a mask pixel white')
    parser.add_argument('--scale', '-s', type=float, default=0.5, help='Scale factor for the input images')
    parser.add_argument('--bilinear', action='store_true', default=False, help='Use bilinear upsampling')
    parser.add_argument('--classes', '-c', type=int, default=2, help='Number of classes')
    
    return parser.parse_args()


def get_output_filenames(args):
    def _generate_name(fn):
        return f'{os.path.splitext(fn)[0]}_OUT.png'

    # return args.output or list(map(_generate_name, args.input))
    return [args.output] if args.output else list(map(_generate_name, args.input))

def mask_to_image(mask: np.ndarray, mask_values):
    if isinstance(mask_values[0], list):
        out = np.zeros((mask.shape[-2], mask.shape[-1], len(mask_values[0])), dtype=np.uint8)
    elif mask_values == [0, 1]:
        out = np.zeros((mask.shape[-2], mask.shape[-1]), dtype=bool)
    else:
        out = np.zeros((mask.shape[-2], mask.shape[-1]), dtype=np.uint8)

    if mask.ndim == 3:
        mask = np.argmax(mask, axis=0)

    for i, v in enumerate(mask_values):
        out[mask == i] = v

    return Image.fromarray(out)


if __name__ == '__main__':
    args = get_args()
    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')

    in_files = [args.input] if isinstance(args.input, str) else args.input
    out_files = get_output_filenames(args)

    net = UNet(n_channels=3, n_classes=args.classes, bilinear=args.bilinear)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Loading model {args.model}')
    logging.info(f'Using device {device}')

    net.to(device=device)
    state_dict = torch.load(args.model, map_location=device)
    mask_values = state_dict.pop('mask_values', [0, 1])
    net.load_state_dict(state_dict)

    logging.info('Model loaded!')

    for i, filename in enumerate(in_files):
        logging.info(f'Predicting image {filename} ...')
        # img = Image.open(filename)
        img = Image.open(filename).convert('RGB')

        mask = predict_img(net=net,
                           full_img=img,
                           scale_factor=args.scale,
                           out_threshold=args.mask_threshold,
                           device=device)

        if not args.no_save:
            out_filename = out_files[i]
            result = mask_to_image(mask, mask_values)
            result.save(out_filename)
            logging.info(f'Mask saved to {out_filename}')

        if args.viz:
            logging.info(f'Visualizing results for image {filename}, close to continue...')
            plot_img_and_mask(img, mask)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2237653.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

开源竞争-大数据项目期末考核

开源竞争: 自己没有办法完全掌握技术的时候就开源这个技术,培养出更多的技术依赖,让更多人完善你的技术,那么这不就是在砸罐子吗?一个行业里面总会有人砸罐子的,你不如先砸还能听个想。 客观现实&#xf…

11月7日星期四今日早报简报微语报早读

11月7日星期四,农历十月初七,早报#微语早读。 1、河南:旅行社组织1000人次境外游客在豫住宿2夜以上,可申请激励奖补; 2、主播宣称下播后商品恢复原价构成欺诈,广州市监:罚款5万元;…

HTMLCSS:3D 旋转卡片的炫酷动画

效果演示 这段代码是一个HTML和CSS的组合&#xff0c;用于创建一个具有3D效果的动画卡片。 HTML <div class"obj"><div class"objchild"><span class"inn6"><h3 class"text">我是谁&#xff1f;我在那<…

词嵌入方法(Word Embedding)

词嵌入方法&#xff08;Word Embedding&#xff09; Word Embedding是NLP中的一种技术&#xff0c;通过将单词映射到一个空间向量来表示每个单词 ✨️常见的词嵌入方法&#xff1a; &#x1f31f;Word2Vec&#xff1a;由谷歌提出的方法&#xff0c;分为CBOW&#xff08;conti…

2024下半年系统架构师考试【回忆版】

2024年11月10日&#xff0c;系统架构师考试如期举行&#xff0c;屡战屡败的参试倒是把北京的学校转了好几所。 本次考试时间 考试科目考试时间综合知识、案例分析8:30 - 12:30论文14:30 - 16:30 案例分析 1、RESTful 对于前后端的优势&#xff1b; 2、心跳相对于ping/echo的…

最简单解决NET程序员在centos系统安装c#网站

目前随着技术栈转移&#xff0c;c#程序员如何在linux服务器中部署net程序呢&#xff1f; 我做了一次实验&#xff1a;一般来说runtime和sdk都要装。 1.centos系统内命令行输入命令 sudo yum install dotnet-sdk-6.0 安装6.0版 2.检测下是否成功&#xff1a;dotnet --versio…

参数估计理论

估计理论的主要任务是在某种信号假设下&#xff0c;估算该信号中某个参数&#xff08;比如幅度、相位、达到时间&#xff09;的具体取值。 参数估计&#xff1a;先假定研究的问题具有某种数学模型&#xff0c; 如正态分布&#xff0c;二项分布&#xff0c;再用已知类别的学习样…

java多线程stop() 和 suspend() 方法为何不推荐使用?

大家好&#xff0c;我是锋哥。今天分享关于【java多线程stop() 和 suspend() 方法为何不推荐使用&#xff1f;】面试题。希望对大家有帮助&#xff1b; java多线程stop() 和 suspend() 方法为何不推荐使用&#xff1f; 1000道 互联网大厂Java工程师 精选面试题-Java资源分享网…

嵌入式硬件电子电路设计(三)电源电路之负电源

引言&#xff1a;在对信号线性度放大要求非常高的应用需要使用双电源运放&#xff0c;比如高精度测量仪器、仪表等;那么就需要给双电源运放提供正负电源。 目录 负电源电路原理 负电源的作用 如何产生负电源 负电源能作功吗&#xff1f; 地的理解 负电压产生电路 BUCK电…

C++高级编程(8)

八、标准IO库 1.输入输出流类 1)非格式化输入输出 2)put #include <iostream> #include <string> ​ using namespace std; int main() {string str "123456789";for (int i str.length() - 1; i > 0; i--) {cout.put(str[i]); //从最后一个字符开…

Python 分子图分类,GNN Model for HIV Molecules Classification,HIV 分子图分类模型;整图分类问题,代码实战

一、分子图 分子图&#xff08;molecular graph&#xff09;是一种用来表示分子结构的图形方式&#xff0c;其中原子被表示为节点&#xff08;vertices&#xff09;&#xff0c;化学键被表示为边&#xff08;edges&#xff09;。对于HIV&#xff08;人类免疫缺陷病毒&#xff…

如何调整pdf的页面尺寸

用福昕阅读器打开pdf&#xff0c;进入打印页面&#xff0c;选择“属性”&#xff0c;在弹出的页面选择“高级” 选择你想调成的纸张尺寸&#xff0c;然后打印&#xff0c;打印出来的pdf就是调整尺寸后的pdf

查缺补漏----用户上网过程(HTTP,DNS与ARP)

&#xff08;1&#xff09;HTTP 来自湖科大计算机网络微课堂&#xff1a; ① HTTP/1.0采用非持续连接方式。在该方式下&#xff0c;每次浏览器要请求一个文件都要与服务器建立TCP连接当收到响应后就立即关闭连接。 每请求一个文档就要有两倍的RTT的开销。若一个网页上有很多引…

koa、vue安装与使用

koa官网&#xff1a;https://koajs.com/ 首选创建一个文件夹&#xff1a;mkdir koaDemo (cmd即可) 文件夹初始化&#xff1a;npm init (cmd即可) 初始化完成后就会产生一个package.json的文件。 安装&#xff1a; npm install koa --save (vscode的控制台中安装&a…

Linux:版本控制器git的简单使用+gdb/cgdb调试器的使用

一&#xff0c;版本控制器git 1.1概念 为了能够更方便我们管理不同版本的文件&#xff0c;便有了版本控制器。所谓的版本控制器&#xff0c;就是能让你 了解到⼀个文件的历史&#xff0c;以及它的发展过程的系统。通俗的讲就是⼀个可以记录工程的每⼀次改动和版本迭代的⼀个…

ML 系列:第 21 节 — 离散概率分布(二项分布)

一、说明 二项分布描述了在固定数量的独立伯努利试验中一定数量的成功的概率&#xff0c;其中每个试验只有两种可能的结果&#xff08;通常标记为成功和失败&#xff09;。 二、探讨伯努利模型 例如&#xff0c;假设您正在抛一枚公平的硬币 &#xff08;其中正面成功&#xff…

【优选算法篇】微位至简,数之恢宏——解构 C++ 位运算中的理与美

文章目录 C 位运算详解&#xff1a;基础题解与思维分析前言第一章&#xff1a;位运算基础应用1.1 判断字符是否唯一&#xff08;easy&#xff09;解法&#xff08;位图的思想&#xff09;C 代码实现易错点提示时间复杂度和空间复杂度 1.2 丢失的数字&#xff08;easy&#xff0…

存算分离与计算向数据移动:深度解析与Java实现

背景 随着大数据时代的到来&#xff0c;数据量的激增给传统的数据处理架构带来了巨大的挑战。传统的“存算一体”架构&#xff0c;即计算资源与存储资源紧密耦合&#xff0c;在处理海量数据时逐渐显露出其局限性。为了应对这些挑战&#xff0c;存算分离&#xff08;Disaggrega…

WPS单元格重复值提示设置

选中要检查的所有的单元格 设置提示效果 当出现单元格值重复时&#xff0c;重复的单元格就会自动变化 要修改或删除&#xff0c;点击

Linux笔记之pandoc实现各种文档格式间的相互转换

Linux笔记之pandoc实现各种文档格式间的相互转换 code review! 文章目录 Linux笔记之pandoc实现各种文档格式间的相互转换1.安装 Pandoc2.Word转Markdown3.markdown转html4.Pandoc 支持的一些常见格式4.1.输入格式4.2.输出格式 1.安装 Pandoc sudo apt-get install pandoc # …