PyTorch 加载 Mask R-CNN 预训练模型并 fine-tuning

news2024/11/25 22:25:50

目录

      • 1 Mask R-CNN 原理(简单版)
      • 2 ROI Align
      • 3 PyTorch 加载预训练模型

1 Mask R-CNN 原理(简单版)

  Mask R-CNN 是一个实例分割(Instance segmentation)算法,主要是在目标检测的基础上再进行分割。 Mask R-CNN 算法主要是 Faster R-CNN + FCN,更具体一点就是 ResNet/VGGNet + RPN + RoI Align + Fast R-CNN + FCN,如下图所示。

在这里插入图片描述

https://arxiv.org/pdf/1703.06870.pdf

Mask R-CNN 的主要创新点有两点:

  • Backbone: ResNet-101 + FPN,这是考虑到图片的多尺度特征,金字塔特征有利于小物体的检测;
  • RoI Align 替换 RoI Pooling.
  • 在 Faster R-CNN 的基础上增加了实例分割分支.

Mask R-CNN 的算法步骤如下:

  • 输入一张图片,进行数据预处理(尺寸,归一化等等);
  • 将处理好的图片传入预训练的神经网络中(ResNet 等,优秀的主干特征提取网络)获得相应的 feature map,这里的 feature map 是金字塔特征图;
  • 通过 feature map 中的每一点设定 RoI,获得多个 RoI 候选框,这里与 Faster R-CNN 相同;
  • 对这些多个 RoI 候选框送到 RPN 中进行二值分类(前景或后景)和预测框回归(Bounding-box regression),并过滤掉一部分候选的 RoI;
  • 对剩下的 RoI 进行 RoI Align 操作(即先将原图和 feature map 的 pixel 对应起来,然后将 feature map 和固定的 feature 对应起来);
  • 对这些 RoI 进行分类(N 类别分类),预测框回归和 Mask 生成。

Mask R-CNN 一大创新点是使用了 ROI Align,下面介绍其流程。

2 ROI Align

  首先假设 RoI 在原图上的大小为 x,原图到特征图缩小了 k 倍,则 RoI 对应到特征图大小就是 x / k,这里不对坐标进行取整,这意味着网格中没有确定的像素可以取,因为新坐标是浮点值。

在这里插入图片描述

https://arxiv.org/pdf/1703.06870.pdf

假设我们最终需要的特征图大小为 2 × 2,我们需要将 RoI 的特征图平均等分成 4 份,然后再将每个方格分成 4 小块,取每个小块的中心作为黑点,然后对这 4 个黑点的值选择最大值或者均值,作为这个方格最终的特征,如上图。RoI Align 使用了双线性插值的方法。小黑点周围会有特征图上的 4 个特征点,利用这 4 个特征点双线性插值出该黑点的值。 计算方式举例如下:

在这里插入图片描述
在这里插入图片描述

https://firiuza.medium.com/roi-pooling-vs-roi-align-65293ab741db

双线性差值公式如下:
在这里插入图片描述

3 PyTorch 加载预训练模型

代码结构如下:

../MaskRCNN
├── checkpoint
│ └── mask_rcnn_with_10epochs.pth
├── LabPicsMedical
│ ├── Categories
│ ├── EvaluationScripts
│ ├── ReaderScriptsPython
│ ├── Test
│ └── Train
├── result.jpg
├── dataset.py
├── test.py
└── train.py

本次任务我加载了 pytorch 的 mask r-cnn 的预训练模型,并在该模型基础上进行 fine-tuning 10 个 epochs,并保存 fine-tuning 后的模型,最终进行测试。 下面是代码中各个部分实现的具体情况和文件说明:

  • checkpoint:保存 fine-tuning 后的模型参数;
  • LabPicsMedical:数据集,本次实验使用的数据集是用于视觉理解医学的 LabPics 数据集,该数据集主要是对医学容器和容器内的物体进行实例分割,本次实验我只对图片的容器进行实例分割;下载地址请参考:https://zenodo.org/record/4736111
  • result.jpg:测试图片;
  • dataset.py:该模块用于模型数据的生成。我在其中编写了 MedicalDataset 类,继承了 Dataset 类,类中的 data_load 方法用于图片数据和注释数据的读取,注释数据包括 boxes、 labels 和 masks;
  • train.py:该模块用于模型训练;
  • test.py:该模块用于模型测试。

下面对每份代码进行详细注释,不再另外说明,张贴如下:
dataset.py

import torch
import numpy as np
import cv2
import argparse
from pathlib import Path
from torch.utils.data import DataLoader, Dataset

class MedicalDataset(Dataset) :
    def __init__(self, args):
        super(MedicalDataset, self).__init__()
        # args 是传入的所有参数集合
        self.args = args
        # 加载图片和图片的注释数据,也即分割对象的 masks、labels、boxes
        self.images, self.targets = self.data_load()

    def data_load(self):
        imgs = []
        # 得到训练数据的路径
        for path in Path(self.args.train_dir).iterdir() :
            imgs.append(path)

        Imgs = []    # 图片数据
        Annos = []   # 注释数据

        for idx in range(len(imgs)) :
            # 图片读取
            img = cv2.imread(imgs[idx].joinpath("Image.jpg").__str__(), cv2.IMREAD_COLOR)
            # BGR 图片转成 RGB 图片
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            # 图片 resize 至固定大小
            img = cv2.resize(img, self.args.image_size)
            
            # 得到分割目标也就是图片中容器的路径,本次任务只分割容器
            maskdir = imgs[idx].joinpath('Vessels')
            
            masks = []
            # 判断该图片是否有容器
            if not maskdir.exists() :
                continue
            for maskname in maskdir.iterdir() :
                # 读取图片
                vesmask = cv2.imread(maskname.__str__(), cv2.IMREAD_GRAYSCALE)
                # 制作标签,令容器所在位置的像素等于 1,其余部分作为背景,赋 0
                vesmask = (vesmask > 0).astype(np.uint8)
                vesmask = cv2.resize(vesmask, self.args.image_size)
                masks.append(vesmask)
            
            # 查看有多少个容器,也即分割对象的数量
            num_objs = len(masks)

            if num_objs == 0 :
                continue
            
            # 制作真正预测框,维度是 [N, 4]
            boxxes = torch.zeros([num_objs, 4], dtype = torch.float32)
            for i in range(num_objs) :
                # 得到 mask 的最小外接矩阵
                x, y, w, h = cv2.boundingRect(masks[i])
                # 存储左上角和右下角坐标
                boxxes[i] = torch.tensor([x, y, x + w, y + h])

            masks = torch.as_tensor(masks, dtype = torch.uint8)
            img = torch.as_tensor(img, dtype = torch.float32)
            
            # 将 3 个注释信息存到字典里面
            annotations = {}
            annotations['boxes'] = boxxes
            # 这次像素分割只是二分类,也即容器和背景
            annotations['labels'] = torch.ones((num_objs, ), dtype = torch.int64)
            annotations['masks'] = masks

            Imgs.append(img)
            Annos.append(annotations)
        
        # 将图片信息拼接
        Imgs = torch.stack([torch.as_tensor(image) for image in Imgs], dim = 0)
        # 将 (H, W, C) 转成 (C, H, W)
        Imgs = Imgs.permute(0, 3, 1, 2)

        return Imgs, Annos

    def __getitem__(self, item):
        return self.images[item], self.targets[item]

    def __len__(self):
        return len(self.images)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_size', type=tuple, default=(600, 600))
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--train_dir', type=str, default='LabPicsMedical/Train')
    parser.add_argument('--num_classes', type=int, default=2)
    parser.add_argument('--device', type=str, default="cuda:8" if torch.cuda.is_available() else "cpu")
    parser.add_argument('--max_epoch', type=int, default=200)

    args = parser.parse_args()
    data = MedicalDataset(args)

train.py

import torch
import torch.optim as optim
import argparse
from pathlib import Path
from torch.utils.data import DataLoader
from torchvision.models.detection import maskrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from dataset import MedicalDataset

class Solver :
    def __init__(self, args):
        self.args = args
        # 制作训练集,要注意输出的数据状态,分割任务中因为每张图片的目标数量可能不一样,因此无法进行张量拼接
        # 因此,对于图片或者标注信息我们要获取列表型数据,由参数 collate_fn 决定,可以查查该参数的用法
        self.train_loader = DataLoader(MedicalDataset(args), batch_size = args.batch_size, shuffle = True,
                                       collate_fn = lambda x: tuple(zip(*x)))

    def train(self):
        # 具有 ResNet-50-FPN 主干的 maskrcnn 的预训练模型
        model = maskrcnn_resnet50_fpn(pretrained = True)
        # 更换分类器
        in_features = model.roi_heads.box_predictor.cls_score.in_features
        model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes = self.args.num_classes)
        model = model.to(self.args.device)
        # 优化器
        optimizer = optim.AdamW(model.parameters(), lr = 1e-5)
        model.train()

        losses = []
        min_loss = 100000
        # 训练
        for epoch in range(self.args.max_epoch) :
            train_loss = 0.0
            for images, targets in self.train_loader :
                images = list(image.to(self.args.device) for image in images)
                targets = [{k : v.to(self.args.device) for k, v in t.items()} for t in targets]
                
                # 损失,如果输入了 target 则输出损失,否则输出的是预测分数、框、分割等等信息
                output = model(images, targets)
                loss = sum(loss for loss in output.values())

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                train_loss += loss.item()

            print(f'[{epoch + 1:02d}/{self.args.max_epoch:02d}], train_loss:{train_loss / len(self.train_loader):.5f}')
            losses.append(round(float(train_loss) / len(self.train_loader), 5))
            
            # 保存 fine-tuning 后的模型
            model_save_path = Path(self.args.model_save_path)
            if not model_save_path.exists() :
                model_save_path.mkdir(parents = True, exist_ok = True)

            if train_loss < min_loss :
                min_loss = train_loss
                torch.save(model.state_dict(), model_save_path.joinpath(f'mask_rcnn_with_{self.args.max_epoch}epochs.pth'))

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_size', type = tuple, default = (600, 600))
    parser.add_argument('--batch_size', type = int, default = 16)
    parser.add_argument('--train_dir', type = str, default = 'LabPicsMedical/Train')
    parser.add_argument('--num_classes', type = int, default = 2)
    parser.add_argument('--device', type = str, default = "cuda:5" if torch.cuda.is_available() else "cpu")
    parser.add_argument('--max_epoch', type = int, default = 10)
    parser.add_argument('--model_save_path', type = str, default = 'checkpoint')

    args = parser.parse_args()
    solver = Solver(args)
    solver.train()

test.py

import torch
import cv2
import argparse
import numpy as np
import random
from pathlib import Path
from torchvision.models.detection import maskrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

class Solver :
    def __init__(self, args):
        self.args = args

    def test(self):
        # 这部分与训练一样
        model = maskrcnn_resnet50_fpn(pretrained = True)
        in_features = model.roi_heads.box_predictor.cls_score.in_features
        model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes = self.args.num_classes)
        model = model.to(self.args.device)
        model.eval()
        # 加载已经保存的 fine-tuning 模型
        ckpt = torch.load(Path(self.args.model_save_path).joinpath(f'mask_rcnn_with_{self.args.max_epoch}epochs.pth').__str__())
        model.load_state_dict(ckpt)

        # 读入图片数据并进行维度转换等操作
        imgs = cv2.imread(self.args.test_img_path)
        imgs = cv2.resize(imgs, self.args.image_size)
        imgs = cv2.cvtColor(imgs, cv2.COLOR_BGR2RGB)
        imgs = torch.as_tensor(imgs, dtype = torch.float32).unsqueeze(0)
        imgs = imgs.permute(0, 3, 1, 2)
        imgs = [img.to(self.args.device) for img in imgs]

        with torch.no_grad() :
            # 预测
            pred = model(imgs)

        im = imgs[0].permute(1, 2, 0).detach().cpu().numpy().astype(np.uint8)
        im = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
        im2 = im.copy()
        showed = 0
        for i in range(len(pred[0]['masks'])):
            # mask,[i, 0] 之所以要有 0,是因为 mask 输出维度为 [N, 1, H, W]
            msk = pred[0]['masks'][i, 0].detach().cpu().numpy()
            # 置信度
            showed = pred[0]['scores'][i].detach().cpu().numpy()
            # 置信分数超过 0.5 则接受该 mask
            if showed > 0.5:
                # 得到最终的分割掩膜,对软掩膜进行阈值处理,一般取值为 0.5
                im2[:, :, 0][msk > 0.5] = random.randint(0, 255)
                im2[:, :, 1][msk > 0.5] = random.randint(0, 255)
                im2[:, :, 2][msk > 0.5] = random.randint(0, 255)

        cv2.imwrite(f'{str(showed)}.jpg', np.hstack([im, im2]))
        # cv2.imshow(str(showed), np.hstack([im, im2]))
        # cv2.waitKey()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_size', type = tuple, default = (600, 600))
    parser.add_argument('--test_img_path', type = str, default = 'LabPicsMedical/Test/9Eval_IVbags/Image.jpg')
    parser.add_argument('--num_classes', type = int, default = 2)
    parser.add_argument('--device', type = str, default = "cuda:5" if torch.cuda.is_available() else "cpu")
    parser.add_argument('--max_epoch', type = int, default = 10)
    parser.add_argument('--model_save_path', type = str, default = 'checkpoint')

    args = parser.parse_args()
    solver = Solver(args)
    solver.test()

分割结果:
在这里插入图片描述
如果单看对针管的分割,效果还是很不错的。但是模型只分割出了针管,右下方的医疗袋没分割出来,需要调一下参数。

参数可在此处下载 https://drive.google.com/file/d/1oERLJzAel8CB2iBWfDJNqy5KwXpJhVvU/view?usp=sharing

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/4439.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

算法练习题(涉外黄成老师)

1.带锁的门在走廊上有n个带锁的门&#xff0c;从1到n依次编号。最初所有的门都是关着的。我们从门前经过n次&#xff0c;每一次都从1号门开始。在第i次经过时(i1,2,…,n)我们改变i的整数倍号锁的状态:如果门是关的&#xff0c;就打开它;如果门是打开的&#xff0c;就关上它。在…

CEC2015:(二)动态多目标野狗优化算法DMODOA求解DIMP2、dMOP2、dMOP2iso、dMOP2dec(提供Matlab代码)

一、cec2015中测试函数DIMP2、dMOP2、dMOP2iso、dMOP2dec详细信息 CEC2015&#xff1a;动态多目标测试函数之DIMP2、dMOP2、dMOP2iso、dMOP2dec详细信息 二、动态多目标野狗优化算法 多目标野狗优化算法&#xff08;Multi-Objective Dingo Optimization Algorithm&#xff0…

#入坑keychron#你还没一起入坑吗?

经济和科技飞速发展的今天&#xff0c;我们早已不在像从前那样有电脑玩就行&#xff0c;现在的我们追求的是更高的配置、更好的体验&#xff0c;就像从前一碗泡面就是最高的理想&#xff0c;而现在最少都得有根泡面搭档才能勉强接受&#xff0c;连泡面都有搭档&#xff0c;电脑…

web前端期末大作业:旅游网页设计与实现——个人旅游博客(4页)HTML+CSS

&#x1f468;‍&#x1f393;学生HTML静态网页基础水平制作&#x1f469;‍&#x1f393;&#xff0c;页面排版干净简洁。使用HTMLCSS页面布局设计,web大学生网页设计作业源码&#xff0c;这是一个不错的旅游网页制作&#xff0c;画面精明&#xff0c;排版整洁&#xff0c;内容…

【后端】初识HTTP_2

我们学习的HTTP协议&#xff0c;是应用层里面最广泛使用的协议~ 我们主要是学习HTTP的请求响应的报文格式 我们可以借助抓包工具来学习&#xff0c;抓包抓到的是文本格式~~ 根据上节内容 我们大概了解了请求和响应的格式 请求有4部分&#xff1a; &#xff08;1&#xff…

leetcode 51. N皇后 回溯法求解(c++版本)

题目描述 简单来说就给一个N*N的棋盘 棋盘上的每一列每一行以及每一个对角不能出现两个皇后 因此明确以下几点 要找出所有可能的解法也是采用回溯法进行求解&#xff08;具体在下面进行详解&#xff09; 用下面一张示例图来说明回溯法的思路 说白了就是进行搜索&#xff0c;…

java项目-第102期基于ssm的校园二手交易平台-java毕业设计

java项目-第102期基于ssm的校园二手交易平台 【源码请到资源专栏下载】 1、项目简述 Hi&#xff0c;大家好&#xff0c;今天分享的源码是基于ssm的校园二手交易平台。 该交易平台分为两部分&#xff0c;前台和后台。用户在前台进行商品选购以及交易&#xff1b;管理员登录后台可…

python-(6-3-3)爬虫---requests入门(对参数封装)

文章目录一 需求二 分析三 代码四 补充说明一 需求 爬取豆瓣电影的“纪录片”的电影信息数据 二 分析 老规矩&#xff0c;先在网页的“检查”中提取我们需要的信息 如下图所示。在“纪录片”那一页面&#xff0c;选择"network"----“XHR”----“preview”。 我们…

【附源码】Python计算机毕业设计面向社区的购物平台系统

项目运行 环境配置&#xff1a; Pychram社区版 python3.7.7 Mysql5.7 HBuilderXlist pipNavicat11Djangonodejs。 项目技术&#xff1a; django python Vue 等等组成&#xff0c;B/S模式 pychram管理等等。 环境需要 1.运行环境&#xff1a;最好是python3.7.7&#xff0c;…

java计算机毕业设计ssm+vue网络考试信息网站

项目介绍 对网络考试系统进行了介绍&#xff0c;包括研究的现状&#xff0c;还有涉及的开发背景&#xff0c;然后还对系统的设计目标进行了论述&#xff0c;还有系统的需求&#xff0c;以及整个的设计方案&#xff0c;对系统的设计以及实现&#xff0c;也都论述的比较细致&…

五大模型看深度学习用于时序预测的最新进展

引言 在以往的时序预测中&#xff0c;大部分使用的是基于统计和机器学习的一些方法。然而&#xff0c;由于深度学习在时间序列的预测中表现并不是很好&#xff0c;且部分论文表述&#xff0c;在训练时间方面&#xff0c;用 Transformer、Informer 、Logtrace 等模型来做时间序…

ESP32 入门笔记06: FreeRTOS+《两只老虎》 (ESP32 for Arduino IDE)

ESP32FreeRTOS Esp32 模块中已经提供了 FreeRTOS&#xff08;实时操作系统&#xff09;固件。 FreeRTOS有助于提高系统性能和管理模块的资源。FreeRTOS允许用户处理多项任务&#xff0c;如测量传感器读数&#xff0c;发出网络请求&#xff0c;控制电机速度等&#xff0c;所有…

旅游推荐系统

摘要 随着社会的发展&#xff0c;人们生活水平的提高&#xff0c;旅游逐渐成为人们生活中的重要活动&#xff0c;2019年国内旅游人数超过60亿人次。并且&#xff0c;旅游业已经成为了我国经济发展的一个重要支柱&#xff0c;近年来我国旅游业对GDP贡献值呈上升趋势。2019年&am…

【附源码】计算机毕业设计java装修服务分析系统设计与实现

项目运行 环境配置&#xff1a; Jdk1.8 Tomcat7.0 Mysql HBuilderX&#xff08;Webstorm也行&#xff09; Eclispe&#xff08;IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持&#xff09;。 项目技术&#xff1a; SSM mybatis Maven Vue 等等组成&#xff0c;B/S模式 M…

【数据结构】链表OJ特别篇 —— 面试情景带你深度剖析 环形链表系列问题 复制带随机指针的链表

文章目录0. 前言1. 环形链表2. 环形链表延伸问题3. 环形链表 II4. 复制带随机指针的链表5. 结语0. 前言 Hello&#xff0c;大家好&#xff0c;我是anduin。这期的内容依然是链表OJ。但与前两期不同的是&#xff0c;这次我会模拟面试的情景&#xff0c;由浅入深&#xff0c;以生…

冶金行业数字化供应链管理系统:赋能企业端到端供应链优化,推动行业数字化变革

作为国之重器&#xff0c;冶金行业在国家大基建战略的支持下&#xff0c;迎来了新一轮的发展良机&#xff0c;然而冶金行业整体产能过剩、对于产业链下游依赖严重等情况依然突出&#xff0c;机遇与挑战并存&#xff0c;在此环境下&#xff0c;如何在激烈的竞争中提升企业自身竞…

多对一和一对多的处理P21,P22

多对一&#xff1a; 按照查询嵌套处理&#xff1a; StudentMapper&#xff1a;&#xff08;接口很简单需要返回一个List&#xff09; package com.Li.dao;import com.Li.pojo.Student; import org.apache.ibatis.annotations.Select;import java.util.List;public interface …

无人机和热成像

当它变得可见时&#xff0c;看不见的世界是一个令人惊叹的景象。 热成像就是这样&#xff0c;透过使用专门的摄影机和技术&#xff0c;使看不见的红外光谱可见。 我们周围那个无形的世界隐藏着什么&#xff0c;这是一个完整的隐藏世界。 图片来源&#xff1a;Keith Davis/热成像…

腾讯面试真题 | 没在我八股文列表里。。。

&#x1f388; 作者&#xff1a;Linux猿 &#x1f388; 简介&#xff1a;CSDN博客专家&#x1f3c6;&#xff0c;华为云享专家&#x1f3c6;&#xff0c;Linux、C/C、云计算、物联网、面试、刷题、算法尽管咨询我&#xff0c;关注我&#xff0c;有问题私聊&#xff01; &…

零基础自学javase黑马课程第十天

零基础自学javase黑马课程第十天 ✨欢迎关注&#x1f5b1;点赞&#x1f380;收藏⭐留言✒ &#x1f52e;本文由京与旧铺原创&#xff0c;csdn首发&#xff01; &#x1f618;系列专栏&#xff1a;java学习 &#x1f4bb;首发时间&#xff1a;&#x1f39e;2022年11月1日&#x…