目录
1. 前言
2. 更改epochs
3. 推理
3.1 nnUNet_predict
3.2 切成小的nii gz文件推理
切片代码
融合代码
3.3 可视化展示
3.4 评估指标
参考
1. 前言
训练了一天半,终于跑完了。。。。
训练的模型在这可以免费下载:
基于nnUnet3d-fullres训练的spineCT训练结果资源-CSDN文库
关于nnUnet的环境搭建、数据集制作、训练网络参考:
第四章:nnUnet大模型之环境配置、数据集制作_nnunet代码详解pytorch-CSDN博客
nnUnet 大模型学习笔记(续):训练网络(3d_fullres)以及数据集标签的处理-CSDN博客
训练过程如下:
生成的结果在 nnUnet_trained_models 目录下:
训练过程的指标可以看曲线图或者训练日志(validation_raw/summary.json):
这里validation_raw/summary.json没有生成,不知道因为什么原因,程序被kill了。。。。
2. 更改epochs
1000个epoch太多了,可以更改官方的参数,如果按照本文的环境搭建,参数在这里
*/environments/nnunet/lib/python3.8/site-packages/nnunet/training/network_training/
3. 推理
nnUnet 是没有推理和测试放在一起的,它会对指定的数据进行推理
如果你有推理的labels的话,那么就可以进行指标计算,这样就可以测试
如果没有labels,那么只有推理
3.1 nnUNet_predict
在最初的数据集里,新建inferTs,用于推理nnUnet推理的结果,把想要推理的数据放在imagesTs下就行了
好像这里的测试数据必须得是0000.nii.gz结尾的,是因为多模态?
运行下面的命令:
nnUNet_predict -i DATASET/nnUNet_raw/nnUNet_raw_data/Task01_Spine/imagesTs/ -o DATASET/nnUNet_raw/nnUNet_raw_data/Task01_Spine/inferTs/ -t 1 -m 3d_fullres -f 0
- -i 是想要预测的数据目录 ,一般为imagesTs
- -o 是保存推理后的数据目录,一般为inferTs
- -t 是任务的训练标号
- -m 是nnUnet训练好的模型
- -f 是训练m模型的几折交叉验证
由于是直接使用模型进行推理, inference done. 后出现 WARNING! Cannot run postprocessing because the postprocessing file is missing. 也即表示已经完成推理。
若CT层数太多或层间距小,可能会卡在 inference done. 阶段,此时需要将CT切分成几部分分别进行 推理。
这里有时候nnUnet推理不出来,可能是因为输入层数太多?参考3.2 处理
3.2 切成小的nii gz文件推理
如果3.1可以成功推理,可以不参考这步!!!
切片代码
注意,这里只需要把想要推理的数据进行切片,然后推理完再拼接即可!!
代码如下:
import SimpleITK as sitk
import numpy as np
import os
import cv2
# 切片函数
def sliceMain(rt):
img = sitk.ReadImage(rt)
img_array = sitk.GetArrayFromImage(img) # nii-->array
print('input size:', img_array.shape)
channel = img_array.shape[0]
y, z = channel // 100, channel % 100 # 2 91
if z == 0:
n = y
print('切出的nii.gz文件个数:', n)
else:
n = y + 1
print('切出的nii.gz文件个数:', n)
for i in range(n):
star, end = i * 100, i * 100 + 100
if i == n - 1: # 最后一个切片
img_select = img_array[star:, :, :]
shape = img_select.shape
img_select = sitk.GetImageFromArray(img_select)
img_save_name = 'data_' + str(i) + '_0000.nii.gz'
print(img_save_name, 'channel:', shape)
sitk.WriteImage(img_select, img_save_name)
else:
img_select = img_array[star:end, :, :]
shape = img_select.shape
img_select = sitk.GetImageFromArray(img_select)
img_save_name = 'data_' + str(i) + '_0000.nii.gz'
print(img_save_name, 'channel:', shape)
sitk.WriteImage(img_select, img_save_name)
if __name__ == '__main__':
root = 'spine_001.nii.gz'
# 切片函数
sliceMain(rt=root)
切片结果:
效果如下:
然后推理就行了!
nnUNet_predict -i DATASET/nnUNet_raw/nnUNet_raw_data/Task01_Spine/imagesTs/ -o DATASET/nnUNet_raw/nnUNet_raw_data/Task01_Spine/inferTs/ -t 1 -m 3d_fullres -f 0
融合代码
推理完成的nii数据下放在data目录下,然后运行下面代码会自动拼接:
import SimpleITK as sitk
import numpy as np
import os
import cv2
# 切片函数
def sliceMain():
data = [os.path.join('data',u) for u in os.listdir('data')]
ret_nii = None
for index,i in enumerate(data):
img = sitk.ReadImage(i)
img_array = sitk.GetArrayFromImage(img) # nii-->array
print(i,':',img_array.shape)
if index ==0:
ret_nii = img_array
else:
ret_nii = np.concatenate((ret_nii,img_array),axis=0)
print('返回的数组size:',ret_nii.shape)
sitk.WriteImage(sitk.GetImageFromArray(ret_nii),'ret.nii.gz')
if __name__ == '__main__':
# 切片函数
sliceMain()
效果如下:
左上角是拼接后的,其余三个是nnUNet推理生成的
3.3 可视化展示
下面是原图加真实gt
下面是原图加nnUNet推理结果:
3.4 评估指标
在labelsTs下放入对应的标签即可
如果评估的话,需要真实的gt图!!!
import numpy as np
import SimpleITK as sitk
from tqdm import tqdm
def main(pred, gt,n):
gt = sitk.GetArrayFromImage(sitk.ReadImage(gt)) # [ 0 2 3 4 5 6 7 8 9 10 11 12 13 14 15]
pred = sitk.GetArrayFromImage(sitk.ReadImage(pred))
dice = []
for h in tqdm(range(n)):
if h == 0:
continue
g = np.zeros(gt.shape,dtype=np.uint8) # 单独提取某个灰度级
g[gt == h] = 255
g[g<255] = 0
g[g==255] = 1
p = np.zeros(pred.shape,dtype=np.uint8) # 单独提取某个灰度级
p[pred == h] = 255
p[p<255] = 0
p[p==255] = 1
if len(np.unique(p)) == 1 or len(np.unique(g)) == 1:
dice.append('None')
else:
dice_score = (2*(p*g).sum() / ((p+g).sum()+1e-8))
dice.append(round(dice_score,4))
print(dice)
for i in range(len(dice) - 1,-1,-1):
if dice[i] == 'None':
dice.remove('None')
print('mean dice',np.array(dice).mean())
if __name__ == "__main__":
gt_path = 'labels.nii.gz'
pred_path = 'infer.nii.gz'
classes = 19
main(pred=pred_path,gt=gt_path,n=classes)
指标如下:
['None', 0.4523, 0.9311, 0.967, 0.9732, 0.9756, 0.9687, 0.9665, 0.9674, 0.9787, 0.9743, 0.9802, 0.9811, 0.9826, 0.974, 'None', 'None', 'None']
mean dice 0.9337642857142857
代码主要实现思路:
因为推理的时候,不是所有的数据同时包含所有标签,所以这里为了方便评估,将所有的类别全部显示。如果没有某个标签就设定为None,然后计算平均dice的时候,就会去掉相应的空标签。
这是nnUnet某个epoch计算的平均dice指标
参考
参考博文如下:nnUNet使用指南(一):Ubuntu系统下使用nnUNet对自己的多模态MR数据集训练 - 梅雨明夏 - 博客园
nnUNet训练并推理自己的数据集_nnunet训练自己数据集-CSDN博客