【3D 图像分割】基于 Pytorch 的 VNet 3D 图像分割10(测试推理篇)

news2024/12/23 0:09:55

对于直接将裁剪的patch,一个个的放到训练好的模型中进行预测,这部分代码可以直接参考前面的训练部分就行了。其实说白了,就是验证部分。不使用dataloader的方法,也只需要修改少部分代码即可。

但是,这种方法是不end to end的。我们接下来要做的,就是将一个CT数组作为输入,产生patch,最后得到预测的完整结果。这样一个初衷,就需要下面几个步骤:

  1. 读取一个序列的CT数组;
  2. OverLap的遍历所有位置,裁剪出一个个patch
  3. 一个个patch送进模型,进行预测;
  4. 对预测结果,再一个个拼接起来,组成一个和CT数组一样大小的预测结果。

这里,就要不得不先补齐下对数组cropmerge操作的方法了,建议先去学习下本系列的文章,链接:【3D 图像分割】基于 Pytorch 的 VNet 3D 图像分割9(patch 的 crop 和 merge 操作)。对于本系列的其他文章,也可直达专栏:人工智能与医学影像专栏。后期可不知道什么时候就收费了,先Mark住。

一、导言

前面提到,其实推理过程,是可以参考训练过程的。那么,两者之间又存在着哪些的不同呢?

  1. 训练是有标签金标准的,推理没有;
  2. 训练是要梯度回归,更新模型的,推理没有;
  3. 训练要保存模型,推理没有;
  4. 训练要循环很多次epoch,推理1次就好。

除了上面的这些训练要做,而推理不需要的地方,推理最最重要的就是要把预测结果,给保存下来。可能是图像形式、类别形式、或者字典形式等等。

本文就将预测结果,保存成和输入数组一样大小的数组,便于后面查看和统计。

二、模型预测

说到将已经训练好的模型,给独立进行测试,需要经历哪些步骤呢?

  1. 模型和保存参数加载;
  2. 数据预处理;
  3. 前向推理,进行预测;
  4. 预测结果后处理;
  5. 结果保存。

相比于训练过程,预测过程就简单了很多,最最关键的也就在于数据的前处理,和预测结果的后处理上面。

2.1、数据前处理

由于我们在本篇的开始,就定义了目标。就是输入的一个CT的完整数组,输出是一个和输入一样大小的预测结果。

但是呢,我们的模型输入,仅仅是一个固定大小的patch,比如48x96x96的大小。所以,这就需要将shape320x265x252,或298x279x300大小的CT数组,裁剪成一个个小块,也就是一个个patch

这里其实在独立的一篇文章,进行了单独详细的介绍。对于本文调用的函数,也是直接从那里引用的。这里就不过多的介绍了,简单说下就是:

  1. 分别遍历z、y、x的长度;
  2. overlap size的方式,有重叠的进行裁剪;
  3. 一个个patch组成patches列表。

详细介绍的文章在这里,点击去看:【3D 图像分割】基于 Pytorch 的 VNet 3D 图像分割9(patch 的 crop 和 merge 操作)

此时单个patch的数组大小为48x96x96大小,但是输入模型的,需要是[b, 1, 48, 96, 96],其中b为batch size。就还需要进行一次预处理,将一个patch,转成对应大小的数组,然后转成Tensor

对于一些patch在边缘的,可能还存在裁剪来的数组大小不够的情况,这就需要进行pad操作。代码如下,对这部分进行了记录。

def data_preprocess(img_crop, crop_size):
    if img_crop.shape != crop_size:
        pad_width = [(0, crop_size[0] - img_crop.shape[0]),
                     (0, crop_size[1] - img_crop.shape[1]),
                     (0, crop_size[2] - img_crop.shape[2])]
        img_crop = np.pad(img_crop, pad_width, mode='constant', constant_values=0)

    img_crop = np.expand_dims(img_crop, 0)  # (1, 16, 96, 96)
    img_crop = torch.from_numpy(img_crop).float()
    return img_crop/255.0

2.2、结果后处理

后处理恰恰与前处理相反。他需要将预测数组shape[b, 2, 48, 96, 96]大小的数组,转为大小为[48, 96, 96]的数组。然后多个patch,按照逆过程,再merge在一起,组成一个和输入CT一样大小的数组。

代码如下:

                for d in range(0, patches.shape[0], 1):
                    img_crop = patches[d, :, :, :]
                    data = data_preprocess(img_crop, Config.Crop_Size)

                    data = data.unsqueeze(0).to(DEVICE)
                    output = model(data)  # (output.shape) torch.Size([b, class_num, 16, 96, 96])

                    data_cpu = data.clone().cpu()
                    output_cpu = output.clone().cpu()

                    i=0
                    img_tensor = data_cpu[i][0]  # 16 * 96 * 96
                    res_tensor = torch.gt(output_cpu[i][1], output_cpu[i][0])  # 16 * 96 * 96

                    patch_res = res_tensor.numpy()
                    patches_res_list.append(patch_res)

                mask_ai = res_merge(patches_res_list, volume_size, Config.OverLap_Size)
                nrrd.write(os.path.join(mask_ai_dir,  name + '.nrrd'), mask_ai)

在本系列中,增加了一个背景类,于是就需要对包含背景类的channel与目标比类的channel做比较,得到最终的,包含目标的层。

torch.gt(Tensor1,Tensor2)

其中Tensor1Tensor2为同维度的张量或者矩阵

含义:比较Tensor1Tensor2的每一个元素,并返回一个0-1值。若Tensor1中的元素大于Tensor2中的元素,则结果取1,否则取0。

经过这样一个步骤,将包含背景类别channel=2的,变成channel为1的状态。比背景大,记为1,为前景;反着,则记为0,为背景。

2.3、预测代码

在这里,就完整的进行测试,将前面提到的需要经历所有步骤,统一进行了汇总。其中一些patchcropmerge操作,你就参照上面提到的文章去看就可以了,调用的也是那个函数,比较好上手的。

看了那篇文章,即便没有本篇下面的代码,我相信你也能知道怎么搞了,没得担心的。

import os
import cv2
import numpy as np
import torch
import torch.utils.data
import matplotlib.pyplot as plt
import shutil
import nrrd
from tqdm import tqdm

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # 没gpu就用cpu
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # 屏蔽通知和警告信息
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # 使用gpu0

def test():
    Config = Configuration()
    Config.display()

    results_dir = './results'
    mask_ai_dir = os.path.join(results_dir, 'pred_nrrd')

    if os.path.exists(results_dir):
        shutil.rmtree(results_dir)
    os.mkdir(results_dir)
    os.makedirs(mask_ai_dir, exist_ok=True)

    from models.unet3d_bn_activate import UNet3D
    model = UNet3D(num_out_classes=2, input_channels=1, init_feat_channels=32, testing=True)

    model_ckpt = torch.load(Config.model_path + "/best_model.pth")
    model.load_state_dict(model_ckpt)
    model = model.to(DEVICE)  # 模型部署到gpu或cpu里
    model.eval()

    with torch.no_grad():
        for idx, file in enumerate(tqdm(os.listdir(Config.valid_path))):
            if '_clean.nrrd' in file:
                name = file.split('_clean.nrrd')[0]
                nrrdClean_path = os.path.join(Config.valid_path, file)
                imgs, volume_size = load_img(nrrdClean_path)
                print('volume_size:', volume_size)

                # crop
                patches = crop_volume(imgs, Config.Crop_Size, Config.OverLap_Size)
                print('patches shape:', patches.shape)

                print(patches.shape)
                patches_res_list = []
                for d in range(0, patches.shape[0], 1):
                    img_crop = patches[d, :, :, :]
                    data = data_preprocess(img_crop, Config.Crop_Size)

                    data = data.unsqueeze(0).to(DEVICE)
                    output = model(data)  # (output.shape) torch.Size([b, class_num, 16, 96, 96])

                    data_cpu = data.clone().cpu()
                    output_cpu = output.clone().cpu()

                    i=0
                    img_tensor = data_cpu[i][0]  # 16 * 96 * 96
                    res_tensor = torch.gt(output_cpu[i][1], output_cpu[i][0])  # 16 * 96 * 96

                    patch_res = res_tensor.numpy()
                    patches_res_list.append(patch_res)

                mask_ai = res_merge(patches_res_list, volume_size, Config.OverLap_Size)
                nrrd.write(os.path.join(mask_ai_dir,  name + '.nrrd'), mask_ai)

class Configuration(object):
    valid_path = r"./database/valid"
    model_path = r'./checkpoints'

    Crop_Size = (48, 96, 96)
    OverLap_Size = [4, 8, 8]
    Num_Workers = 0

    def display(self):
        """Display Configuration values."""
        print("\nConfigurations:")
        print("")
        for a in dir(self):
            if not a.startswith("__") and not callable(getattr(self, a)):
                print("{:30} {}".format(a, getattr(self, a)))
        print("\n")

if __name__=='__main__':
    test()

最终,我们输入的是一个CT.nrrd的数组文件,最终预测结果也存储在了.nrrd的文件内。这个文件可以和标注文件做比较,进而对预测结果进行评估,也可以将预测结果打印出来,更加直观的在二维slice层面上进行查看。

三、结果可视化

现在假定,你是有了这批数据的CT数据、标注数据,和在二章节里面的预测结果,你可以有下面两种方式进行查看。

  1. itk-snap软件直接查看nrrd文件,但是他一次只能查看CT数据和标注数据,或者CT数据、预测结果,同时打开两个窗口,实现联动也是可以的,如下面这样;

1

  1. 也可以将标注数据和预测结果都以slice层的形式,绘制到一起,存储到本地,这样一层一层的查看。如下面这样:

2

第一种方式就不赘述了,直接下载itk-snap打开即可,这部分的资料比较多。

对于第二种,将标注和预测结果,按照slice都绘制到图像上面,,就稍微展开介绍下,将代码给到大家,自己可以使用。

  1. 读取CT数组,标注数组和预测结果数组,都是nrrd文件;
  2. 获取单层的slice,包括了上面三种类型数据的;
  3. 分别将标注内容,预测内容,绘制到图像上;
  4. 最后就是以图片的形式,存储到本地。

完整的代码如下:

import numpy as np
import nrrd
import os, cv2

def load_img(path_to_img):
    if path_to_img.startswith('LKDS'):
        img = np.load(path_to_img)
    else:
        img, _ = nrrd.read(path_to_img)

    return img, img.shape
    
def load_mask(path_to_mask):
    mask, _ = nrrd.read(path_to_mask)
    mask[mask > 1] = 1
    return mask, mask.shape

def getContours(output):
    img_seged = output.copy()
    img_seged = img_seged * 255

    # ---- Predict bounding box results with txt ----
    kernel = np.ones((5, 5), np.uint8)
    img_seged = cv2.dilate(img_seged, kernel=kernel)
    _, img_seged_p = cv2.threshold(img_seged, 127, 255, cv2.THRESH_BINARY)
    try:
        _, contours, _ = cv2.findContours(np.uint8(img_seged_p), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    except:
        contours, _ = cv2.findContours(np.uint8(img_seged_p), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    return contours

def drawImg(img, mask_ai, isAI=True):
    pred_oneImg = np.expand_dims(mask_ai, axis=2)
    contours = getContours(pred_oneImg)
    print(contours)
    if isAI:
        color = (0, 0, 255)
    else:
        color = (0, 255, 0)
    if len(contours) != 0:
        for contour in contours:
            x, y, w, h = cv2.boundingRect(contour)
            xmin, ymin, xmax, ymax = x, y, x + w, y + h
            print('contouts:', xmin, ymin, xmax, ymax)
            # if isAI:
            #     cv2.drawContours(img, contour, -1, color, 2)
            # else:
            cv2.rectangle(img, (int(xmin), int(ymin)), (int(xmax), int(ymax)), color, thickness=2)

    return img

if __name__ == '__main__':
  ai_dir = r'./results/pred_nrrd'
  gt_dir = r'./database_nodule/valid'
  save_dir = r'./results/img_ai_gt'
  os.makedirs(save_dir, exist_ok=True)

  for filename in os.listdir(ai_dir):
      name = os.path.splitext(filename)[0]
      print(name, filename)
      ai_path = os.path.join(ai_dir, filename)
      gt_path = os.path.join(gt_dir, name + '_mask.nrrd')
      clean_path = os.path.join(gt_dir, name + '_clean.nrrd')

      imgs, volume_size = load_img(clean_path)

      masks_ai, masks_ai_shape = load_mask(ai_path)
      masks_gt, masks_gt_shape = load_mask(gt_path)

      assert volume_size==masks_ai_shape==masks_gt_shape

      for i in range(volume_size[0]):
          img = imgs[i, :, :]  # 获得第i张的单一数组
          mask_ai = masks_ai[i, :, :]
          mask_gt = masks_gt[i, :, :]

          print(np.max(img), np.min(img))

          img = np.expand_dims(img, axis=2)
          img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

          img = drawImg(img, mask_ai, isAI=True)
          img = drawImg(img, mask_gt, isAI=False)

          save_path = os.path.join(save_dir, name)
          os.makedirs(save_path, exist_ok=True)
          cv2.imwrite(os.path.join(save_path, '{}_img.png'.format(i)), img)

四、总结

本文是继训练之后,对训练的模型进行独立的推理,实现对单个CT图像,经过patch操作,进行预测后,恢复成与原始CT一样大小的预测结果。并对预测结果和标注结果进行可视化对比,可以直观的看到对于单个检查,哪些结节是很容易的被识别到,而哪些比较的困难,哪些又是假阳性。

后面,就是对多个检查进行预测结果的评估,包括了结节级别的敏感度、特异度。在这样一个评估下,我们可以知道这个训练好的模型,究竟整体的性能怎么样,需不需要进一步的提高,从哪些角度进行提高?敬请期待。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1185416.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Busco-真核生物为主基因组质量评估

文章目录 简介Install必须参数谱系数据集输出结果自动谱系选择结果解读完整片段化缺失 自动选择:多domain和污染匹配注意BUSCO报告常用脚本真核Ref 简介 Busco评估基因组质量的核心原理在于通过计算基因组的通用单拷贝标记基因的比例来估计基因组的完整性。其中两个…

Javascript知识点详解:对象的继承、原型对象、原型链

目录 对象的继承 原型对象概述 构造函数的缺点 prototype 属性的作用 原型链 constructor 属性 instanceof 运算符 构造函数的继承 多重继承 对象的继承 面向对象编程很重要的一个方面,就是对象的继承。A 对象通过继承 B 对象,就能直接拥有 B …

C++之函数中实现类、调用总结(二百五十四)

简介: CSDN博客专家,专注Android/Linux系统,分享多mic语音方案、音视频、编解码等技术,与大家一起成长! 优质专栏:Audio工程师进阶系列【原创干货持续更新中……】🚀 人生格言: 人生…

MySQL之表的增删改查

目录 表的增删改查1.Create1.1 单行数据 全列插入1.2 多行数据 指定列插入1.3 插入否则更新1.4 替换 2.Retrieve1.1 SELECT 列全列查询指定列查询查询字段为表达式为查询结果指定别名结果去重 1.2 WHERE 条件普遍使用NULL 的查询结果排序筛选分页结果 3.Update对查询到的结果…

vue3+setup 解决:this.$refs引用子组件报错 is not a function

一、如果在父组件中以下四步都没问题的话&#xff0c;再看下面步骤 二、如果父组件引用的是index页面 请在 头部加上以下代码 &#xff08;如果是form页面请忽略这一步&#xff09; <template> <a-modalv-model:visible"visible"title"头部名称&…

MySQL库的库操作指南

1.创建数据库 一般格式&#xff1a;create database (if not exists) database1_name,database2_name...... 特殊形式&#xff1a; create database charset harset_name collate collate_name 解释&#xff1a; 红色字是用户自己设置的名称charset&#xff1a;指定数据…

照片处理软件 DxO FilmPack 7 mac中文版软件介绍

DxO FilmPack 7 mac是一款照片处理软件&#xff0c;专为摄影后期制作而设计。该软件来自法国的DXO公司&#xff0c;它可以在数码影像上模拟胶卷的颜色、对比度、颗粒感等。DxO FilmPack 7提供了多种胶卷颜色效果&#xff0c;包括7种正片胶卷颜色、9种单色照片胶卷颜色、5种负片…

71 内网安全-域横向网络传输应用层隧道技术

目录 必备知识点&#xff1a;1.代理和隧道技术区别?2.隧道技术为了解决什么?3.隧道技术前期的必备条件? 演示案例:网络传输应用层检测连通性-检测网络层ICMP隧道Ptunnel使用-检测利用传输层转发隧道Portmap使用-检测,利用传输层转发隧道Netcat使用-检测,利用,功能应用层DNS隧…

浮点数保留指定位数的小数,小数位自动去掉多余的0

通过DecimalFormat.format可以按照指定的格式格式化数据。 public class test { public static void main(String[] args) { DecimalFormat dnew DecimalFormat(“#.#”);//在有小数的情况下留一位小数&#xff0c;默认是四舍五入 BigDecimal decimalnew BigDecimal(“3.14159…

SQL入门语句

MySQL和SQL的区别是什么&#xff1f;之间是什么关系&#xff1f; SQL&#xff08;Structured Query Language&#xff09;是用于管理和操作关系型数据库&#xff08;RDBMS&#xff09;的标准语言。SQL还可以用于这些RDBMS&#xff1a;MySQL、Oracle、Microsoft SQL Server、Pos…

React进阶之路(二)-- 组件通信、组件进阶

文章目录 组件通信组件通信的意义父传子实现props说明子传父实现兄弟组件通信跨组件通信Context通信案例 React组件进阶children属性props校验组件生命周期 组件通信 组件通信的意义 组件是独立且封闭的单元&#xff0c;默认情况下组件只能使用自己的数据&#xff08;state&a…

0成本LLM微调上手项目,⚡️一步一步使用colab训练法律LLM,基于microsoft/phi-1_5,包含lora微调,全参微调

项目地址 &#xff1a;https://github.com/billvsme/train_law_llm ✏️LLM微调上手项目 一步一步使用Colab训练法律LLM&#xff0c;基于microsoft/phi-1_5 。通过本项目你可以0成本手动了解微调LLM。 nameColabDatasets自我认知lora-SFT微调train_self_cognition.ipynbsel…

P1131 [ZJOI2007] 时态同步

Portal. 先找出树上以 S S S 为起点最长的一条链&#xff0c;然后让其他链的长度都和该链对齐即可。 维护每个结点 x x x 的子树最长链 d max ⁡ ( x ) d_{\max}(x) dmax​(x)&#xff0c;则每次 DFS 求出最长链之后调整对齐的代价为 d max ⁡ ( x ) − ( d max ⁡ ( s o …

Java算法(二):数组元素求和(元素个位和十位不能是 7 ,且只能是偶数)

java算法&#xff08;二&#xff09; 需求&#xff1a; ​ 有这样一个数组&#xff1a; 元素是&#xff1a;{68, 27, 95, 88, 171, 996, 51, 210} ​ 求出该数组中满足要求的元素和 ​ 要求是&#xff1a; 求和的元素各位和十位都不能是 7 &#xff0c;并且只能是偶数 packa…

机器学习中的假设检验

正态性检验相关分析回归分析 所谓假设检验&#xff0c;其实就是根据原假设来构造一种已知分布的统计量来计算概率&#xff0c;根据概率值大小来判断能否拒绝原假设&#xff0c;从而得到一种结论。假设检验的过程就是&#xff0c;构造一个原假设成立条件下的事件A&#xff0c;计…

如何实现单病种上报的多院区/集团化/平台联动管理

背 景 米软售前人员在了解客户单病种上报的相关需求中发现&#xff0c;部分医院分为本部、分部或总院、分院等多个院区&#xff0c;各院区需共用一套系统&#xff1b;部分医院与其他兄弟医院隶属于同一集团医院&#xff0c;全集团需统一部署&#xff1b;部分市/区卫健委要求全…

【Node.js入门】1.3 开始开发Node.js应用程序

1.3 开始开发Node.js应用程序 学习目标 &#xff08;1&#xff09;熟悉开发工具Visual Studio Code的基本使用&#xff1b; &#xff08;2&#xff09;掌握Node.js应用程序的编写、运行和调试的基本方法。 构建第一个 Node.js应用程序 代码 const http require("htt…

RabbitMQ 消息中间件 消息队列

RabbitMQ1、RabbitMQ简介2、RabbitMQ 特点3、什么是消息队列4、RabbiMQ模式5、集群中的基本概念 单实例安装RabbitMQ安装依赖安装erlang安装rabbitmq开启rabbitmq的web访问界面添加用户修改配置文件重启服务浏览器访问Rabbit-test rabbitMQ集群准备工作&#xff08;三台&#x…

AM@向量代数@向量基本概念和向量线性运算

文章目录 abstract向量的基本概念向量向量的坐标分解式和坐标&#x1f47a;向量的模向量的长度(大小)&#x1f47a;零向量单位向量&#x1f47a;方向向量非零向量的单位向量正规化向量夹角&#x1f47a; 向量方向角和向量间夹角投影几何描述向量的线性运算向量的加减运算向量的…

【STM32 开发】| INA219采集电压、电流值

目录 前言1 原理图2 IIC地址说明3 寄存器地址说明4 开始工作前配置5 程序代码1&#xff09;驱动程序2&#xff09;头文件3) 测试代码 前言 INA219 是一款具备 I2C 或 SMBUS 兼容接口的分流器和功率监测计。该器件监测分流器电压降和总线电源电压&#xff0c;转换次数和滤波选项…