BraTS2021脑肿瘤分割实战

news2025/1/13 11:23:42

Brain Tumor Segmentation (BraTS) Challenge 2021 Homepage

github项目地址 brats-unet: UNet for brain tumor segmentation

BraTS是MICCAI所有比赛中历史最悠久的,到2021年已经连续举办了10年,参赛人数众多,是学习医学图像分割最前沿的平台之一。
在这里插入图片描述

1.数据准备

简介

比赛方提供多机构、多参数多模态核磁共振成像(mpMRI)数据集,包括训练集(1251例)和验证集(219例)以及测试集(530例),一共2000例患者的mpMRI扫描结果。其中训练集包含图像和分割标签,验证集和测试集没有分割标签,验证集被用于公共排行榜,测试集不公开,用作参赛者的最终排名评测。

四种模态数据:flair, t1ce, t1, t2,每个模态的数据大小都为 240 x 240 x 155,且共享分割标签。

分割标签:[0, 1, 2, 4]

  • label0:背景(background)
  • label1:坏疽(NT, necrotic tumor core)
  • label2:浮肿区域(ED,peritumoral edema)
  • label4:增强肿瘤区域(ET,enhancing tumor)

​ 本次比赛包括两个任务:

  • Task1:mpMRI扫描中分割内在异质性脑胶质母细胞瘤区域
  • Task2:预测术前基线扫描中的MGMT启动子甲基化状态

本文从数据处理、评价指标、损失函数、模型训练四个方面介绍Task1的整体实现过程

数据集下载地址

1.官网:BraTS 2021 Challenge 需要注册和申请(包括训练集和验证集)

2.Kaggle:BRaTS 2021 Task 1 Dataset 建议在kaggle上下载,数据集与官网一致(不包括验证集)

数据准备

下载数据集,解压后如下图所示:

在这里插入图片描述

每个病例包含四种模态的MRI图像和分割标签,结构如下:

BraTS2021_00000
├── BraTS2021_00000_flair.nii.gz
├── BraTS2021_00000_seg.nii.gz
├── BraTS2021_00000_t1ce.nii.gz
├── BraTS2021_00000_t1.nii.gz
└── BraTS2021_00000_t2.nii.gz

建议使用3D Slicer查看图像和标签,直观的了解一下自己要用的数据集。

2.数据预处理

每个病例的四种MRI图像大小为 240 x 240 x 155,且共享标签。

鉴于此,我将四种模态的图像合并为一个4D图像(C x H x W x D , C=4),并且和分割标签一起保存为一个.h5文件,方便后续处理。

import h5py
import os
import numpy as np
import SimpleITK as sitk
from tqdm import tqdm
# 四种模态的mri图像
modalities = ('flair', 't1ce', 't1', 't2')

# train
train_set = {
        'root': '/data/omnisky/postgraduate/Yb/data_set/BraTS2021/data',  # 四个模态数据所在地址
        'out': '/data/omnisky/postgraduate/Yb/data_set/BraTS2021/dataset/',  # 预处理输出地址
        'flist': 'train.txt',  # 训练集名单(有标签)
        }
  • 将图像保存为32位浮点数(np.float32),标签保存为整数(np.uint8),写入.h5文件
  • 对每张图像的灰度进行标准化,但保持背景区域为0

在这里插入图片描述

  • 上图是预处理后的图像,背景区域为0
def process_h5(path, out_path):
    """ Save the data with dtype=float32.
        z-score is used but keep the background with zero! """
    # SimpleITK读取图像默认是是 DxHxW,这里转为 HxWxD
    label = sitk.GetArrayFromImage(sitk.ReadImage(path + 'seg.nii.gz')).transpose(1,2,0)
    print(label.shape)
    # 堆叠四种模态的图像,4 x (H,W,D) -> (4,H,W,D) 
    images = np.stack([sitk.GetArrayFromImage(sitk.ReadImage(path + modal + '.nii.gz')).transpose(1,2,0) for modal in modalities], 0)  # [240,240,155]
    # 数据类型转换
    label = label.astype(np.uint8)
    images = images.astype(np.float32)
    case_name = path.split('/')[-1]
    # case_name = os.path.split(path)[-1]  # windows路径与linux不同
    
    path = os.path.join(out_path,case_name)
    output = path + 'mri_norm2.h5'
    # 对第一个通道求和,如果四个模态都为0,则标记为背景(False)
    mask = images.sum(0) > 0
    for k in range(4):

        x = images[k,...]  #
        y = x[mask]

        # 对背景外的区域进行归一化
        x[mask] -= y.mean()
        x[mask] /= y.std()

        images[k,...] = x
    print(case_name,images.shape,label.shape)
    f = h5py.File(output, 'w')
    f.create_dataset('image', data=images, compression="gzip")
    f.create_dataset('label', data=label, compression="gzip")
    f.close()


def doit(dset):
    root, out_path = dset['root'], dset['out']
    file_list = os.path.join(root, dset['flist'])
    subjects = open(file_list).read().splitlines()
    names = ['BraTS2021_' + sub for sub in subjects]
    paths = [os.path.join(root, name, name + '_') for name in names]

    for path in tqdm(paths):
        process_h5(path, out_path)
        # break
    print('Finished')


if __name__ == '__main__':
    doit(train_set)

数据保存在 mri_norm2.h5 文件中,每个 mri_norm2.h5 相当于一个字典,字典的键为 image 和 label ,值为对应的数组。

在这里插入图片描述

处理后的数据,可以用下面的几行代码测试一下,记得修改为你自己的路径

import h5py
import numpy as np
p = '/***/data_set/BraTS2021/all/BraTS2021_00000_mri_norm2.h5'
h5f = h5py.File(p, 'r')
image = h5f['image'][:]
label = h5f['label'][:]
print('image shape:',image.shape,'\t','label shape',label.shape)
print('label set:',np.unique(label))

# image shape: (4, 240, 240, 155)          label shape (240, 240, 155)
# label set: [0 1 2 4]

将数据集按照 8:1:1随机划分为训练集、验证集和测试集,将划分后的数据名保存为.txt文件

import os
from sklearn.model_selection import train_test_split

# 预处理输出地址
data_path = "/***/data_set/BraTS2021/dataset"
train_and_test_ids = os.listdir(data_path)

train_ids, val_test_ids = train_test_split(train_and_test_ids, test_size=0.2,random_state=21)
val_ids, test_ids = train_test_split(val_test_ids, test_size=0.5,random_state=21)
print("Using {} images for training, {} images for validation, {} images for testing.".format(len(train_ids),len(val_ids),len(test_ids)))

with open('/***/data_set/BraTS2021/train.txt','w') as f:
    f.write('\n'.join(train_ids))

with open('/***/data_set/BraTS2021/valid.txt','w') as f:
    f.write('\n'.join(val_ids))

with open('/***/data_set/BraTS2021/test.txt','w') as f:
    f.write('\n'.join(test_ids))

划分结果:

Using 1000 images for training, 125 images for validation, 126 images for testing.
......
BraTS2021_00002_mri_norm2.h5
BraTS2021_00003_mri_norm2.h5
BraTS2021_00014_mri_norm2.h5
......

3.数据增强

下面是我写的Dataset类以及一些数据增强方法

整体架构

import os
import torch
from torch.utils.data import Dataset
import random
import numpy as np
from torchvision.transforms import transforms
import h5py


class BraTS(Dataset):
    def __init__(self,data_path, file_path,transform=None):
        with open(file_path, 'r') as f:
            self.paths = [os.path.join(data_path, x.strip()) for x in f.readlines()]
        self.transform = transform

    def __getitem__(self, item):
        h5f = h5py.File(self.paths[item], 'r')
        image = h5f['image'][:]
        label = h5f['label'][:]
        #[0,1,2,4] -> [0,1,2,3]
        label[label == 4] = 3
        # print(image.shape)
        sample = {'image': image, 'label': label}
        if self.transform:
            sample = self.transform(sample)
        return sample['image'], sample['label']

    def __len__(self):
        return len(self.paths)

    def collate(self, batch):
        return [torch.cat(v) for v in zip(*batch)]


if __name__ == '__main__':
    from torchvision import transforms
    data_path = "/***/data_set/BraTS2021/dataset"
    test_txt = "/***/data_set/BraTS2021/test.txt"
    test_set = BraTS(data_path,test_txt,transform=transforms.Compose([
        RandomRotFlip(),
        RandomCrop((160,160,128)),
        GaussianNoise(p=0.1),
        ToTensor()
    ]))
    d1 = test_set[0]
    image,label = d1
    print(image.shape)
    print(label.shape)
    print(np.unique(label))

具体的数据增强方法我列在了下面,包括裁剪、旋转、翻转、高斯噪声、对比度变换和亮度增强的源码,部分代码借鉴了nnUNet的数据增强方法。

随机裁剪

原始图像尺寸为 240 x 240 x 155,但图像周围是有很多黑边的,我将图像裁剪为 160 x 160 x 128

class RandomCrop(object):
    """
    Crop randomly the image in a sample
    Args:
    output_size (int): Desired output size
    """
    def __init__(self, output_size):
        self.output_size = output_size

    def __call__(self, sample):
        image, label = sample['image'], sample['label']

        (c, w, h, d) = image.shape
        w1 = np.random.randint(0, w - self.output_size[0])
        h1 = np.random.randint(0, h - self.output_size[1])
        d1 = np.random.randint(0, d - self.output_size[2])

        label = label[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]]
        image = image[:,w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]]
        return {'image': image, 'label': label}

中心裁剪

class CenterCrop(object):
    def __init__(self, output_size):
        self.output_size = output_size

    def __call__(self, sample):
        image, label = sample['image'], sample['label']

        (c,w, h, d) = image.shape

        w1 = int(round((w - self.output_size[0]) / 2.))
        h1 = int(round((h - self.output_size[1]) / 2.))
        d1 = int(round((d - self.output_size[2]) / 2.))

        label = label[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]]
        image = image[:,w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]]

        return {'image': image, 'label': label}

随机翻转

旋转可能会导致图像重采样,因为数据集比较充分,我只在{90,180,270}度做一个简单旋转,不涉及重采样。

class RandomRotFlip(object):
    """
    Crop randomly flip the dataset in a sample
    Args:
    output_size (int): Desired output size
    """

    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        k = np.random.randint(0, 4)
        image = np.stack([np.rot90(x,k) for x in image],axis=0)
        label = np.rot90(label, k)
        axis = np.random.randint(1, 4)
        image = np.flip(image, axis=axis).copy()
        label = np.flip(label, axis=axis-1).copy()

        return {'image': image, 'label': label}

高斯噪声

def augment_gaussian_noise(data_sample, noise_variance=(0, 0.1)):
    if noise_variance[0] == noise_variance[1]:
        variance = noise_variance[0]
    else:
        variance = random.uniform(noise_variance[0], noise_variance[1])
    data_sample = data_sample + np.random.normal(0.0, variance, size=data_sample.shape)
    return data_sample


class GaussianNoise(object):
    def __init__(self, noise_variance=(0, 0.1), p=0.5):
        self.prob = p
        self.noise_variance = noise_variance

    def __call__(self, sample):
        image = sample['image']
        label = sample['label']
        if np.random.uniform() < self.prob:
            image = augment_gaussian_noise(image, self.noise_variance)
        return {'image': image, 'label': label}

对比度变换

  • contrast_range:对比度增强的范围
  • preserve_range:是否保留数据的取值范围
  • per_channel:是否对每个通道的图像分别进行对比度增强
def augment_contrast(data_sample, contrast_range=(0.75, 1.25), preserve_range=True, per_channel=True):
    if not per_channel:
        mn = data_sample.mean()
        if preserve_range:
            minm = data_sample.min()
            maxm = data_sample.max()
        if np.random.random() < 0.5 and contrast_range[0] < 1:
            factor = np.random.uniform(contrast_range[0], 1)
        else:
            factor = np.random.uniform(max(contrast_range[0], 1), contrast_range[1])
        data_sample = (data_sample - mn) * factor + mn
        if preserve_range:
            data_sample[data_sample < minm] = minm
            data_sample[data_sample > maxm] = maxm
    else:
        for c in range(data_sample.shape[0]):
            mn = data_sample[c].mean()
            if preserve_range:
                minm = data_sample[c].min()
                maxm = data_sample[c].max()
            if np.random.random() < 0.5 and contrast_range[0] < 1:
                factor = np.random.uniform(contrast_range[0], 1)
            else:
                factor = np.random.uniform(max(contrast_range[0], 1), contrast_range[1])
            data_sample[c] = (data_sample[c] - mn) * factor + mn
            if preserve_range:
                data_sample[c][data_sample[c] < minm] = minm
                data_sample[c][data_sample[c] > maxm] = maxm
    return data_sample


class ContrastAugmentationTransform(object):
    def __init__(self, contrast_range=(0.75, 1.25), preserve_range=True, per_channel=True,p_per_sample=1.):
        self.p_per_sample = p_per_sample
        self.contrast_range = contrast_range
        self.preserve_range = preserve_range
        self.per_channel = per_channel

    def __call__(self, sample):
        image = sample['image']
        label = sample['label']
        for b in range(len(image)):
            if np.random.uniform() < self.p_per_sample:
                image[b] = augment_contrast(image[b], contrast_range=self.contrast_range,
                                            preserve_range=self.preserve_range, per_channel=self.per_channel)
        return {'image': image, 'label': label}

亮度变换

附加亮度从具有μ和σ的高斯分布中采样

def augment_brightness_additive(data_sample, mu:float, sigma:float , per_channel:bool=True, p_per_channel:float=1.):
    if not per_channel:
        rnd_nb = np.random.normal(mu, sigma)
        for c in range(data_sample.shape[0]):
            if np.random.uniform() <= p_per_channel:
                data_sample[c] += rnd_nb
    else:
        for c in range(data_sample.shape[0]):
            if np.random.uniform() <= p_per_channel:
                rnd_nb = np.random.normal(mu, sigma)
                data_sample[c] += rnd_nb
    return data_sample


class BrightnessTransform(object):
    def __init__(self, mu, sigma, per_channel=True, p_per_sample=1., p_per_channel=1.):
        self.p_per_sample = p_per_sample
        self.mu = mu
        self.sigma = sigma
        self.per_channel = per_channel
        self.p_per_channel = p_per_channel

    def __call__(self, sample):
        data, label = sample['image'], sample['label']

        for b in range(data.shape[0]):
            if np.random.uniform() < self.p_per_sample:
                data[b] = augment_brightness_additive(data[b], self.mu, self.sigma, self.per_channel,
                                                      p_per_channel=self.p_per_channel)

        return {'image': data, 'label': label}

数据类型转换

将Numpy数组转为Tensor

class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""
    def __call__(self, sample):
        image = sample['image']
        label = sample['label']

        image = torch.from_numpy(image).float()
        label = torch.from_numpy(label).long()

        return {'image': image, 'label': label}

相比其他医学影像数据集,BraTS2021是非常高质量的,对数据增强方法并不是很敏感。

4.评价损失

损失函数:

combination of dice and crossentropy loss

在这里插入图片描述

dice loss

在这里插入图片描述

  • μ是网络的softmax输出
  • v是分割标签的one-hot编码

其实就是将计算dice时的torch.argmax替换为了torch.softmax

import torch.nn.functional as F
import torch.nn as nn
import torch
from einops import rearrange


class Loss(nn.Module):
    def __init__(self, n_classes, weight=None, alpha=0.5):
        "dice_loss_plus_cetr_weighted"
        super(Loss, self).__init__()
        self.n_classes = n_classes
        self.weight = weight.cuda()
        # self.weight = weight
        self.alpha = alpha

    def forward(self, input, target):
        smooth = 0.01  # 防止分母为0
        input1 = F.softmax(input, dim=1)
        target1 = F.one_hot(target,self.n_classes)
        input1 = rearrange(input1,'b n h w s -> b n (h w s)')
        target1 = rearrange(target1,'b h w s n -> b n (h w s)')

        input1 = input1[:, 1:, :]
        target1 = target1[:, 1:, :].float()

        # 以batch为单位计算loss和dice_loss,据说训练更稳定,和上面的公式有出入
        # 注意,这里的dice不是真正的dice,叫做soft_dice更贴切
        inter = torch.sum(input1 * target1)
        union = torch.sum(input1) + torch.sum(target1) + smooth
        dice = 2.0 * inter / union

        loss = F.cross_entropy(input,target, weight=self.weight)

        total_loss = (1 - self.alpha) * loss + (1 - dice) * self.alpha

        return total_loss


if __name__ == '__main__':
    torch.manual_seed(3)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    losser = Loss(n_classes=4, weight=torch.tensor([0.2, 0.3, 0.25, 0.25])).to(device)
    x = torch.randn((2, 4, 16, 16, 16)).to(device)
    y = torch.randint(0, 4, (2, 16, 16, 16)).to(device)
    print(losser(x, y))

评价指标:

在这里插入图片描述

dice计算方法:
2 ( A ∩ B ) A + B 2{(A \cap B)}\over{A + B} A+B2(AB)

def Dice(output, target, eps=1e-3):
    inter = torch.sum(output * target,dim=(1,2,3)) + eps
    union = torch.sum(output,dim=(1,2,3)) + torch.sum(target,dim=(1,2,3)) + eps * 2
    x = 2 * inter / union
    dice = torch.mean(x)
    return dice
  • output: (b, num_class, d, h, w) target: (b, d, h, w)
  • dice1(ET):label4
  • dice2(TC):label1 + label4
  • dice3(WT): label1 + label2 + label4
  • 注意,这里的label4已经被替换为3
def cal_dice(output, target):
    output = torch.argmax(output,dim=1)
    dice1 = Dice((output == 3).float(), (target == 3).float())
    dice2 = Dice(((output == 1) | (output == 3)).float(), ((target == 1) | (target == 3)).float())
    dice3 = Dice((output != 0).float(), (target != 0).float())

    return dice1, dice2, dice3

5.模型训练

在这里插入图片描述

UNet为例,我把完整代码放在了下面

module:

import torch
import torch.nn as nn


class InConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(InConv, self).__init__()
        self.conv = DoubleConv(in_ch, out_ch)

    def forward(self, x):
        x = self.conv(x)
        return x

class Down(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(Down, self).__init__()
        self.mpconv = nn.Sequential(
            nn.MaxPool3d(2, 2),
            DoubleConv(in_ch, out_ch)
        )

    def forward(self, x):
        x = self.mpconv(x)
        return x

class OutConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(OutConv, self).__init__()
        self.conv = nn.Conv3d(in_ch, out_ch, 1)
        # self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.conv(x)
        # x = self.sigmoid(x)
        return x

class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv3d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm3d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_ch, out_ch, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm3d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.conv(x)
        return x

class Up(nn.Module):
    def __init__(self, in_ch, skip_ch,out_ch):
        super(Up, self).__init__()
        self.up = nn.ConvTranspose3d(in_ch, in_ch, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_ch+skip_ch, out_ch)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        x = torch.cat([x2, x1], dim=1)
        x = self.conv(x)
        return x

model:

class UNet(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(UNet, self).__init__()
        features = [32,64,128,256]

        self.inc = InConv(in_channels, features[0])
        self.down1 = Down(features[0], features[1])
        self.down2 = Down(features[1], features[2])
        self.down3 = Down(features[2], features[3])
        self.down4 = Down(features[3], features[3])

        self.up1 = Up(features[3], features[3], features[2])
        self.up2 = Up(features[2], features[2], features[1])
        self.up3 = Up(features[1], features[1], features[0])
        self.up4 = Up(features[0], features[0], features[0])
        self.outc = OutConv(features[0], num_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)

        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        x = self.outc(x)
        return x


if __name__ == '__main__':
    x = torch.randn(1, 4, 160, 160, 128)
    net = UNet(in_channels=4, num_classes=4)
    y = net(x)
    print("params: ", sum(p.numel() for p in net.parameters()))
    print(y.shape)

Train:

下面是我写的训练函数,具体细节见代码注释

  • 优化器:optim.SGD(model.parameters(),momentum=0.9, lr=0, weight_decay=5e-4)
  • 学习率余弦衰减:最大学习率0.004,最小学习率0.002,预热10个epoch
  • 优化策略可参考我的另一篇博客nnUnet代码解读–优化策略
import os
import argparse

from torch.utils.data import DataLoader
import torch
import torch.optim as optim
from tqdm import tqdm
from BraTS import *
from networks.Unet import UNet
from utils import Loss,cal_dice,cosine_scheduler


def train_loop(model,optimizer,scheduler,criterion,train_loader,device,epoch):
    model.train()
    running_loss = 0
    dice1_train = 0
    dice2_train = 0
    dice3_train = 0
    pbar = tqdm(train_loader)
    for it,(images,masks) in enumerate(pbar):
        # update learning rate according to the schedule
        it = len(train_loader) * epoch + it
        param_group = optimizer.param_groups[0]
        param_group['lr'] = scheduler[it]
        # print(scheduler[it])

        # [b,4,128,128,128] , [b,128,128,128]
        images, masks = images.to(device),masks.to(device)
        # [b,4,128,128,128], 4分割
        outputs = model(images)
        # outputs = torch.softmax(outputs,dim=1)
        loss = criterion(outputs, masks)
        dice1, dice2, dice3 = cal_dice(outputs,masks)
        pbar.desc = "loss: {:.3f} ".format(loss.item())

        running_loss += loss.item()
        dice1_train += dice1.item()
        dice2_train += dice2.item()
        dice3_train += dice3.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    loss = running_loss / len(train_loader)
    dice1 = dice1_train / len(train_loader)
    dice2 = dice2_train / len(train_loader)
    dice3 = dice3_train / len(train_loader)
    return {'loss':loss,'dice1':dice1,'dice2':dice2,'dice3':dice3}


def val_loop(model,criterion,val_loader,device):
    model.eval()
    running_loss = 0
    dice1_val = 0
    dice2_val = 0
    dice3_val = 0
    pbar = tqdm(val_loader)
    with torch.no_grad():
        for images, masks in pbar:
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)
            # outputs = torch.softmax(outputs,dim=1)

            loss = criterion(outputs, masks)
            dice1, dice2, dice3 = cal_dice(outputs, masks)

            running_loss += loss.item()
            dice1_val += dice1.item()
            dice2_val += dice2.item()
            dice3_val += dice3.item()
            # pbar.desc = "loss:{:.3f} dice1:{:.3f} dice2:{:.3f} dice3:{:.3f} ".format(loss,dice1,dice2,dice3)

    loss = running_loss / len(val_loader)
    dice1 = dice1_val / len(val_loader)
    dice2 = dice2_val / len(val_loader)
    dice3 = dice3_val / len(val_loader)
    return {'loss':loss,'dice1':dice1,'dice2':dice2,'dice3':dice3}


def train(model,optimizer,scheduler,criterion,train_loader,
          val_loader,epochs,device,train_log,valid_loss_min=999.0):
    for e in range(epochs):
        # train for epoch
        train_metrics = train_loop(model,optimizer,scheduler,criterion,train_loader,device,e)
        # eval for epoch
        val_metrics = val_loop(model,criterion,val_loader,device)
        info1 = "Epoch:[{}/{}] train_loss: {:.3f} valid_loss: {:.3f} ".format(e+1,epochs,train_metrics["loss"],val_metrics["loss"])
        info2 = "Train--ET: {:.3f} TC: {:.3f} WT: {:.3f} ".format(train_metrics['dice1'],train_metrics['dice2'],train_metrics['dice3'])
        info3 = "Valid--ET: {:.3f} TC: {:.3f} WT: {:.3f} ".format(val_metrics['dice1'],val_metrics['dice2'],val_metrics['dice3'])
        print(info1)
        print(info2)
        print(info3)
        with open(train_log,'a') as f:
            f.write(info1 + '\n' + info2 + ' ' + info3 + '\n')

        if not os.path.exists(args.save_path):
            os.makedirs(args.save_path)
        save_file = {"model": model.state_dict(),
                     "optimizer": optimizer.state_dict()}
        if val_metrics['loss'] < valid_loss_min:
            valid_loss_min = val_metrics['loss']
            torch.save(save_file, 'results/UNet.pth')
        else:
            torch.save(save_file,os.path.join(args.save_path,'checkpoint{}.pth'.format(e+1)))
    print("Finished Training!")


def main(args):
    torch.manual_seed(args.seed)  # 为CPU设置种子用于生成随机数,以使得结果是确定的
    torch.cuda.manual_seed_all(args.seed)  # 为所有的GPU设置种子,以使得结果是确定的

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # data info
    patch_size = (160,160,128)
    train_dataset = BraTS(args.data_path,args.train_txt,transform=transforms.Compose([
        RandomRotFlip(),
        RandomCrop(patch_size),
        GaussianNoise(p=0.1),
        ToTensor()
    ]))
    val_dataset = BraTS(args.data_path,args.valid_txt,transform=transforms.Compose([
        CenterCrop(patch_size),
        ToTensor()
    ]))
    test_dataset = BraTS(args.data_path,args.test_txt,transform=transforms.Compose([
        CenterCrop(patch_size),
        ToTensor()
    ]))

    train_loader = DataLoader(dataset=train_dataset, batch_size=args.batch_size, num_workers=12,   # num_worker=4
                              shuffle=True, pin_memory=True)
    val_loader = DataLoader(dataset=val_dataset, batch_size=args.batch_size, num_workers=12, shuffle=False,
                            pin_memory=True)
    test_loader = DataLoader(dataset=test_dataset, batch_size=args.batch_size, num_workers=12, shuffle=False,
                             pin_memory=True)

    print("using {} device.".format(device))
    print("using {} images for training, {} images for validation.".format(len(train_dataset), len(val_dataset)))
    # img,label = train_dataset[0]

    # 1-坏疽(NT,necrotic tumor core),2-浮肿区域(ED,peritumoral edema),4-增强肿瘤区域(ET,enhancing tumor)
    # 评价指标:ET(label4),TC(label1+label4),WT(label1+label2+label4)
    model = UNet(in_channels=4,num_classes=4).to(device)
    criterion = Loss(n_classes=4, weight=torch.tensor([0.2, 0.3, 0.25, 0.25])).to(device)
    optimizer = optim.SGD(model.parameters(),momentum=0.9, lr=0, weight_decay=5e-4)
    scheduler = cosine_scheduler(base_value=args.lr,final_value=args.min_lr,epochs=args.epochs,
                                 niter_per_ep=len(train_loader),warmup_epochs=args.warmup_epochs,start_warmup_value=5e-4)

    # 加载训练模型
    if os.path.exists(args.weights):
        weight_dict = torch.load(args.weights, map_location=device)
        model.load_state_dict(weight_dict['model'])
        optimizer.load_state_dict(weight_dict['optimizer'])
        print('Successfully loading checkpoint.')

    train(model,optimizer,scheduler,criterion,train_loader,val_loader,args.epochs,device,train_log=args.train_log)

    # metrics1 = val_loop(model, criterion, train_loader, device)
    metrics2 = val_loop(model, criterion, val_loader, device)
    metrics3 = val_loop(model, criterion, test_loader, device)

    # 最后再评价一遍所有数据,注意,这里使用的是训练结束的模型参数
    # print("Train -- loss: {:.3f} ET: {:.3f} TC: {:.3f} WT: {:.3f}".format(metrics1['loss'], metrics1['dice1'],metrics1['dice2'], metrics1['dice3']))
    print("Valid -- loss: {:.3f} ET: {:.3f} TC: {:.3f} WT: {:.3f}".format(metrics2['loss'], metrics2['dice1'], metrics2['dice2'], metrics2['dice3']))
    print("Test  -- loss: {:.3f} ET: {:.3f} TC: {:.3f} WT: {:.3f}".format(metrics3['loss'], metrics3['dice1'], metrics3['dice2'], metrics3['dice3']))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_classes', type=int, default=4)
    parser.add_argument('--seed', type=int, default=21)
    parser.add_argument('--epochs', type=int, default=60)
    parser.add_argument('--warmup_epochs', type=int, default=10)
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--lr', type=float, default=0.004)
    parser.add_argument('--min_lr', type=float, default=0.002)
    parser.add_argument('--data_path', type=str, default='/***/data_set/BraTS2021/dataset')
    parser.add_argument('--train_txt', type=str, default='/***/data_set/BraTS2021/train.txt')
    parser.add_argument('--valid_txt', type=str, default='/***/data_set/BraTS2021/valid.txt')
    parser.add_argument('--test_txt', type=str, default='/***/data_set/BraTS2021/test.txt')
    parser.add_argument('--train_log', type=str, default='results/UNet.txt')
    parser.add_argument('--weights', type=str, default='results/UNet.pth')
    parser.add_argument('--save_path', type=str, default='checkpoint/UNet')

    args = parser.parse_args()

    main(args)

训练集1000张,验证集125张,测试集126张。保存在验证集上损失最小的模型。

6.实验结果

在这里插入图片描述

训练30轮的loss曲线如上图所示,下面是我用不同的模型训练60轮,在测试集上的评价指标:

3D MRI Brain Tumor Segmentation(BraTS2021)
网络模型三维数据大小ETTCWT均值
UNet160×160×1280.8390.8770.9070.874
Attention UNet160×160×1280.8500.8770.9150.881
  • Attention UNetUNet的基础上,在上采样模块引入像素注意力。

7.滑动推理

加载训练好的权重,采用滑动窗口法进行推理,代码见inference.py

def test_single_case(net, image, stride_xy, stride_z, patch_size, num_classes=1):
    # print(image.shape)
    c, ww, hh, dd = image.shape

    sx = math.ceil((ww - patch_size[0]) / stride_xy) + 1
    sy = math.ceil((hh - patch_size[1]) / stride_xy) + 1
    sz = math.ceil((dd - patch_size[2]) / stride_z) + 1
    # print("{}, {}, {}".format(sx, sy, sz))
    score_map = np.zeros((num_classes, ) + image.shape[1:]).astype(np.float32)
    cnt = np.zeros(image.shape[1:]).astype(np.float32)

    for x in range(0, sx):
        xs = min(stride_xy*x, ww-patch_size[0])
        for y in range(0, sy):
            ys = min(stride_xy * y,hh-patch_size[1])
            for z in range(0, sz):
                zs = min(stride_z * z, dd-patch_size[2])
                test_patch = image[:,xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]]
                test_patch = np.expand_dims(test_patch,axis=0).astype(np.float32)
                test_patch = torch.from_numpy(test_patch).cuda()
                y1 = net(test_patch)
                y = F.softmax(y1, dim=1)
                y = y.cpu().data.numpy()
                y = y[0,:,:,:,:]
                score_map[:, xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \
                  = score_map[:, xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + y
                cnt[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \
                  = cnt[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + 1
    score_map = score_map/np.expand_dims(cnt,axis=0)
    label_map = np.argmax(score_map, axis = 0)
    return label_map, score_map

在这里插入图片描述

以标签1(NT, necrotic tumor core)为例,上图中红色的是金标签,蓝色的是UNet预测结果


确实,脑肿瘤分割相比其他三维分割任务,结果要好太多了,是一个非常适合练手的项目。感兴趣的同学可以按照我的步骤复现一下,效果也不会差。

代码我都放在上面了,码字不易,有用的话还请点个赞,后续也会更新图像分割和深度学习方面的内容,欢迎交流讨论。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1147858.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【ChatGPT 01】ChatGPT基础科普

1. 从图灵测试到ChatGPT 1950年&#xff0c;艾伦•图灵(Alan Turing)发表论文**《计算机器与智能》&#xff08; Computing Machinery and Intelligence&#xff09;&#xff0c;提出并尝试回答“机器能否思考”这一关键问题。在论文中&#xff0c;图灵提出了“模仿游戏”&…

2.2 消元法的概念

一、消元法介绍 消元法&#xff08;elimination&#xff09;是一个求解线性方程组的系统性方法。下面是使用消元法求解一个 2 2 2\times2 22 线性方程组的例子。消元之前&#xff0c;两个方程都有 x x x 和 y y y&#xff0c;消元后&#xff0c;第一个未知数 x x x 将从第…

Websocket传递JWT令牌

在访问带有[Authorize]的方法的时候&#xff0c;需要前端通过自定义报文头的形式将JWT令牌传递给后端进行验证&#xff0c;否则是不能访问带有[Authorize]的方法。 [Authorize]是用于限制对web应用程序中某些操作或控制器的访问。当[授权]属性应用于操作或控制器时&#xff0c;…

【Linux】多路IO复用技术①——select详解如何使用select在本地主机实现简易的一对多服务器(附图解与代码实现)

这一篇的篇幅可能有点长&#xff0c;但真心希望大家能够静下心来看完&#xff0c;相信一定会有不小的收获。那么话不多说&#xff0c;我们这就开始啦&#xff01;&#xff01;&#xff01; 目录 一对一服务器中的BUG 如何实现简易的一对多服务器 实现简易一对多服务器的大体…

Python超入门(7)__迅速上手操作掌握Python

# 31.类 class Point:# 构造函数def __init__(self, x, y, z):self.x xself.y yself.z z# 自定义函数def move(self):print("move")def draw(self):print("draw")# 定义一个Point类的实例point1 # 注意&#xff1a;新建实例的参数要与构造函数一致 poi…

一天下来一个微信号能添加多少个微信好友?

在即时通讯领域&#xff0c;微信的用户量处于领先的地位。据了解微信及WeChat合并的月活跃账户数已超13亿。远远超越QQ的移动端5.71亿的月活跃用户数量。 那么&#xff0c;微信的用户数量这么多&#xff0c;一天可以加多少好友呢&#xff1f; 新号和不活跃的号 01 微信新号是…

【计算机网络】分层模型和应用协议

网络分层模型和应用协议 1. 分层模型 1.1 五层网络模型 网络要解决的问题是&#xff1a;两个程序之间如何交换数据。 四层&#xff1f;五层&#xff1f;七层&#xff1f; 2. 应用层协议 2.1 URL URL&#xff08;uniform resource locator&#xff0c;统一资源定位符&#…

ZYNQ连载06-EasyLogger日志组件

ZYNQ连载06-EasyLogger日志组件 1. EasyLogger介绍 Easylogger仓库 2. EasyLogger移植 EasyLogger移植比较简单&#xff0c;在Vitis中移植时主要注意路径问题&#xff0c;然后适配下接口即可&#xff1a; void elog_port_output(const char *log, size_t size) {printf(&…

密码学基础

密码学总览 信息安全面临的危险与应对这些威胁的密码技术&#xff1a; 关于上图中的威胁&#xff0c;这里在简单的说明&#xff1a; 窃听&#xff1a;指的是需要保密的消息被第三方获取。篡改&#xff1a;指的是消息的内容被第三方修改&#xff0c;达到欺骗的效果。伪装&…

k8s命令式对象管理、命令式对象配置、声明式对象配置管理资源介绍

目录 一.kubernetes资源管理简介 二.三种资源管理方式优缺点比较 三.命令式对象管理介绍 1.kubectl命令语法格式 2.资源类型 &#xff08;1&#xff09;通过“kubectl api-resources”来查看所有的资源 &#xff08;2&#xff09;每列含义 &#xff08;3&#xff09;常…

RabbitMQ学习05

文章目录 交换机1.Exchanges1.1 概念1.2 类型1.3 无名exchange 2. 临时队列3. 绑定&#xff08;bings&#xff09;4. Fanout4.1 介绍 5.Direct exchange5.1 介绍5.2 多重绑定5.3 实战: 6. Topics6.1 规则6.2 实战 交换机 1.Exchanges 1.1 概念 RabbitMQ 消息传递模型的核心思…

C++初阶2

目录 一&#xff0c;auto关键字 1-1&#xff0c;auto的使用 1-2&#xff0c;基于范围auto的for循环 二&#xff0c;nullptr的运用 三&#xff0c;C类的初步学习 3-1&#xff0c;类的引用 3-2&#xff0c;类的访问权限 3-3&#xff0c;类的使用 1&#xff0c;类中函数的…

SA实战 ·《SpringCloud Alibaba实战》第12章-服务网关:网关概述与核心架构

作者:冰河 星球:http://m6z.cn/6aeFbs 博客:https://binghe.gitcode.host 文章汇总:https://binghe.gitcode.host/md/all/all.html 大家好,我是冰河~~ 一不小心《SpringCloud Alibaba实战》专栏都更新到第12章了,再不上车就跟不上了,小伙伴们快跟上啊! 在《SpringClou…

基于AliO Things和阿里云的智能环境监控系统。

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 前言一、实习内容二、实习方法2.1搭建开发环境并完成编译2.1.1 正常完成编译的标志2.1.2 编写实例烧录程序&#xff0c;并完成烧录 2.2按键实现流水灯2.2.1 HaaS ED…

General Expression In Oral English

1. With all Due Respect 恕我直言&#xff0c; 那些和你混在一起的混蛋们只关心政治 With all due respect , these sons of bitches that youre mingling with(与.. 交际) only care about politics

A. Doremy‘s Paint 3(规律)

Problem - A - Codeforces 解析&#xff1a; 首先最多只能存在两个值&#xff0c;因为间隔必须相同。并且两个值的数量相差小于等于1 #include<bits/stdc.h> using namespace std; #define int long long const int N2e55; int t,n,a[N]; map<int,int>mp; signed…

【SpringSecurity】快速入门—通俗易懂

目录 1.导入依赖 2.继承WebSecurityConfigurerAdapter 3.实现UserDetailsService 4.记住我 5.用户注销 6.CSRF理解 7.注解功能 7.1Secured 7.2PreAuthorized 7.3PostAuthorized 7.4PostFilter 7.5ZPreFilter 8.原理解析 1.导入依赖 首先&#xff0c;在pom.xml文…

第五章 I/O管理 二、I/O控制器

目录 一、电子部件 1、I/O控制器 1.功能&#xff1a; &#xff08;1&#xff09;接受和识别CPU发出的命令&#xff1a; &#xff08;2&#xff09;向CPU报告设备的状态 &#xff08;3&#xff09;数据交换 &#xff08;4&#xff09;地址识别 2.组成 二、内存映像和寄…

磁盘调度算法之先来先服务(FCFS),最短寻找时间优先(SSTF),扫描算法(SCAN,电梯算法),LOOK调度算法

目录 1.一次磁盘读/写操作需要的时间1.寻找时间2.延迟时间3.传输时间4.影响读写操作的因素 2.磁盘调度算法1.先来先服务(FCFS)1.例题2.优缺点 2.最短寻找时间优先(SSTF)1.例题2.优缺点3.饥饿的原因 3.扫描算法(SCAN)1.例题2.优缺点 4.LOOK调度算法1.例题2.优点 5.循环扫描算法(…

81 分割回文串

分割回文串 题解1 回溯题解2 回溯dp利用dp相当于先判断哪段是回文(省掉了每次都需要调用的isValid)【预处理】 给你一个字符串 s&#xff0c;请你将 s 分割成一些子串&#xff0c;使 每个子串都是 回文串 。返回 s 所有可能的分割方案。 回文串 是正着读和反着读都一样的字…