本文基于 MONAI 库实现。其实我不太喜欢这种高度封装的写法,可扩展性太差,除非你只是想快速在自己的数据集上出结果。不过 MONAI 的 transform 可以直接对 3D 医学图像做数据增强,比 torch(torchvision)的 transform 强一些:它的增强操作直接作用于 (C, H, W, D) 格式的三维数据,即通道维之后依次是 x、y、z 三个空间轴。我还没有试过单独用它的 transform 搭配纯 torch 训练。
前提
pip install monai nibabel tqdm matplotlib(读取 .nii.gz 需要 nibabel,train.py 里还用到了 tqdm 和 matplotlib)
目录结构:项目根目录下放 train.py,子目录 nets/ 下放 swin_model.py,子目录 utils/ 下放 dataloaderd.py。各文件代码如下。
train.py
from nets.swin_model import GetSwinUnetr
import torch
from utils.dataloaderd import GetDataLoader
from monai.losses import DiceCELoss
from monai.metrics import DiceMetric
from monai.inferers import sliding_window_inference
from tqdm import tqdm
import os
import matplotlib.pyplot as plt
from torch import nn
os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"
from monai.data import (
ThreadDataLoader,
CacheDataset,
load_decathlon_datalist,
decollate_batch,
set_track_meta,
)
from monai.transforms import (AsDiscrete,)
def validation(epoch_iterator_val):
    """Evaluate ``model`` on the validation loader and return the mean Dice score.

    Reads module-level globals set up in the ``__main__`` section of this
    script: ``model``, ``cutSize``, ``post_label``, ``post_pred`` and
    ``dice_metric``.

    Args:
        epoch_iterator_val: a ``tqdm`` wrapper around the validation DataLoader.
            Each batch is a dict with ``"image"`` and ``"label"`` tensors.

    Returns:
        float: the Dice value aggregated by ``dice_metric`` over every batch.
        The metric's accumulated state is reset before returning.

    The model is switched to eval mode for the duration of the call and is
    then restored to whatever train/eval mode it was in before, so a caller
    in the middle of a training loop keeps training in train mode.
    """
    was_training = model.training
    model.eval()
    roi_size = (cutSize[0], cutSize[1], cutSize[2])
    # tqdm records the wrapped loader's length as .total (None if unknown).
    total_steps = getattr(epoch_iterator_val, "total", None)
    try:
        with torch.no_grad():
            for val_step, batch in enumerate(epoch_iterator_val, start=1):
                val_inputs, val_labels = (batch["image"].cuda(), batch["label"].cuda())
                with torch.cuda.amp.autocast():
                    # Tile the whole volume with roi_size windows, one window per forward pass.
                    val_outputs = sliding_window_inference(val_inputs, roi_size, 1, model)
                # Split the batch into per-sample tensors and one-hot encode them:
                # labels directly, predictions after an argmax (see post_pred).
                val_labels_convert = [post_label(t) for t in decollate_batch(val_labels)]
                val_output_convert = [post_pred(t) for t in decollate_batch(val_outputs)]
                dice_metric(y_pred=val_output_convert, y=val_labels_convert)
                # Report real validation progress rather than the training step.
                epoch_iterator_val.set_description("Validate (%d / %s Steps)" % (val_step, total_steps))
            mean_dice_val = dice_metric.aggregate().item()
            dice_metric.reset()
    finally:
        # Undo model.eval() so training code calling us is not left in eval mode.
        model.train(was_training)
    return mean_dice_val
def train(global_step, train_loader, dice_val_best, global_step_best):
    """Run one pass over ``train_loader``, validating every ``eval_num`` steps.

    Reads module-level globals set up in the ``__main__`` section of this
    script: ``model``, ``loss_function``, ``optimizer``, ``max_iterations``,
    ``eval_num``, ``val_loader``, ``root_dir``, ``epoch_loss_values`` and
    ``metric_values``.

    Args:
        global_step: global iteration counter at the start of this pass.
        train_loader: training DataLoader yielding dicts with ``"image"`` and
            ``"label"`` tensors.
        dice_val_best: best validation Dice seen so far.
        global_step_best: global step at which ``dice_val_best`` was reached.

    Returns:
        tuple: updated ``(global_step, dice_val_best, global_step_best)``.

    Side effects:
        At every evaluation, appends the mean training loss of this pass so far
        to ``epoch_loss_values`` and the validation Dice to ``metric_values``.
        Saves ``model.state_dict()`` to ``root_dir/best_metric_model.pth``
        whenever the Dice improves.
    """
    model.train()
    epoch_loss = 0.0
    epoch_iterator = tqdm(train_loader, desc="Training (X / X Steps) (loss=X.X)", dynamic_ncols=True)
    for step, batch in enumerate(epoch_iterator, start=1):
        x, y = (batch["image"].cuda(), batch["label"].cuda())
        # Mixed precision is disabled: to enable it, wrap the forward pass in
        # torch.cuda.amp.autocast() and drive backward/step through a GradScaler.
        logit_map = model(x)
        loss = loss_function(logit_map, y)
        loss_value = loss.item()
        epoch_loss += loss_value
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_iterator.set_description(
            f"Training ({global_step} / {max_iterations} Steps) (loss={loss_value:2.5f})"
        )
        is_eval_step = (global_step % eval_num == 0 and global_step != 0) or global_step == max_iterations
        if is_eval_step:
            epoch_iterator_val = tqdm(val_loader, desc="Validate (X / X Steps) (dice=X.X)", dynamic_ncols=True)
            dice_val = validation(epoch_iterator_val)
            # validation() calls model.eval(); make sure the rest of the pass trains in train mode.
            model.train()
            # Record the running mean without mutating the accumulator, so a
            # second evaluation in the same pass still averages correctly.
            epoch_loss_values.append(epoch_loss / step)
            metric_values.append(dice_val)
            if dice_val > dice_val_best:
                dice_val_best = dice_val
                global_step_best = global_step
                torch.save(model.state_dict(), os.path.join(root_dir, "best_metric_model.pth"))
                print(
                    "Model Was Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}".format(dice_val_best, dice_val)
                )
            else:
                print(
                    "Model Was Not Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}".format(
                        dice_val_best, dice_val
                    )
                )
        global_step += 1
        # The step numbered max_iterations has been trained and evaluated; any
        # further steps in this pass would change weights that are never evaluated.
        if global_step > max_iterations:
            break
    return global_step, dice_val_best, global_step_best
def _strip_data_parallel_prefix(state_dict):
    """Return ``state_dict`` without the ``module.`` key prefix, if every key has it.

    ``nn.DataParallel`` stores the wrapped network as ``.module``, so a
    ``state_dict()`` taken from the wrapper prefixes every key with
    ``module.``. A plain network expects the bare names. Dicts where not every
    key carries the prefix (including empty dicts) are returned unchanged.
    """
    prefix = "module."
    if state_dict and all(key.startswith(prefix) for key in state_dict):
        return {key[len(prefix):]: value for key, value in state_dict.items()}
    return state_dict


def _load_checkpoint(net, path, map_location=None):
    """Load the weights stored at ``path`` into ``net``.

    Works whether the checkpoint came from a plain network or from an
    ``nn.DataParallel`` wrapper, and whether ``net`` itself is wrapped.
    """
    state_dict = _strip_data_parallel_prefix(torch.load(path, map_location=map_location))
    target = net.module if isinstance(net, nn.DataParallel) else net
    target.load_state_dict(state_dict)


if __name__ == '__main__':
    # ---------------- configuration ----------------
    data_dir = "/home/data/hablee_data_dir/aorta2023/"
    root_dir = "/home/hbli/pythonFiles3/c_swinUnetr_torch/"
    split_json = "dataset_0.json"
    cutSize = [64, 64, 32]  # training patch size, also the sliding-window ROI
    numClasses = 3  # background + 2 foreground labels (see dataset_0.json "labels")
    batchSize = 2
    train_loader, val_loader, val_ds = GetDataLoader(
        data_dir=data_dir, split_json=split_json,
        cut_size=cutSize, batch_size=batchSize)

    # ---------------- model / optimisation ----------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = GetSwinUnetr(cut_size=cutSize, num_classes=numClasses)
    # Optionally resume from an earlier run. Skipped when no checkpoint exists,
    # and tolerant of checkpoints saved from an nn.DataParallel wrapper.
    resume_path = "./best_metric_model.pth"
    if os.path.isfile(resume_path):
        _load_checkpoint(model, resume_path, map_location="cpu")
        print(f"Resumed weights from {resume_path}")
    gpu_counts = torch.cuda.device_count()
    if gpu_counts > 1:
        # DataParallel splits every batch across the visible GPUs; the loaders
        # above keep their per-step batch size of batchSize.
        model = nn.DataParallel(model)
    model.to(device=device)
    # torch.backends.cudnn.benchmark = True
    loss_function = DiceCELoss(to_onehot_y=True, softmax=True)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
    # scaler = torch.cuda.amp.GradScaler()
    max_iterations = 30000
    eval_num = 500
    post_label = AsDiscrete(to_onehot=numClasses)
    post_pred = AsDiscrete(argmax=True, to_onehot=numClasses)
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)

    # ---------------- training loop ----------------
    # train()/validation() read the names above as module globals.
    global_step = 0
    dice_val_best = 0.0
    global_step_best = 0
    epoch_loss_values = []
    metric_values = []
    while global_step < max_iterations:
        global_step, dice_val_best, global_step_best = train(global_step, train_loader, dice_val_best, global_step_best)
    print(f"train completed, best_metric: {dice_val_best:.4f} " f"at iteration: {global_step_best}")

    # ---------------- training curves ----------------
    plt.figure("train", (12, 6))
    plt.subplot(1, 2, 1)
    plt.title("Iteration Average Loss")
    x = [eval_num * (i + 1) for i in range(len(epoch_loss_values))]
    y = epoch_loss_values
    plt.xlabel("Iteration")
    plt.plot(x, y)
    plt.subplot(1, 2, 2)
    plt.title("Val Mean Dice")
    x = [eval_num * (i + 1) for i in range(len(metric_values))]
    y = metric_values
    plt.xlabel("Iteration")
    plt.plot(x, y)
    # Save before show(): show() may close the figure, leaving savefig a blank canvas.
    plt.savefig('loss.jpg')
    plt.show()

    # ---------------- visual check on one validation case ----------------
    case_num = 4
    slice_map = {
        "72.nii.gz": 50,
        "97.nii.gz": 50,
        "82.nii.gz": 50,
        "153.nii.gz": 50,
        "54.nii.gz": 50,
        "104.nii.gz": 50
    }
    _load_checkpoint(model, os.path.join(root_dir, "best_metric_model.pth"), map_location=device)
    model.eval()
    with torch.no_grad():
        sample = val_ds[case_num]
        img = sample["image"]
        label = sample["label"]
        img_name = os.path.split(img.meta["filename_or_obj"])[1]
        # Fall back to the middle slice along the last axis for unmapped cases.
        slice_idx = slice_map.get(img_name, img.shape[-1] // 2)
        # Add the batch dimension in front of the channel dimension.
        val_inputs = img.unsqueeze(0).to(device)
        val_labels = label.unsqueeze(0).to(device)
        val_outputs = sliding_window_inference(val_inputs, (cutSize[0], cutSize[1], cutSize[2]), 1, model, overlap=0.8)
        plt.figure("check", (18, 6))
        plt.subplot(1, 3, 1)
        plt.title("image")
        plt.imshow(val_inputs.cpu().numpy()[0, 0, :, :, slice_idx], cmap="gray")
        plt.subplot(1, 3, 2)
        plt.title("label")
        plt.imshow(val_labels.cpu().numpy()[0, 0, :, :, slice_idx])
        plt.subplot(1, 3, 3)
        plt.title("output")
        plt.imshow(torch.argmax(val_outputs, dim=1).detach().cpu()[0, :, :, slice_idx])
        plt.savefig("result.jpg")
        plt.show()
nets/swin_model.py
# import os
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
from monai.networks.nets.swin_unetr import SwinUNETR
# import torch
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def GetSwinUnetr(cut_size, num_classes, in_channels=1, feature_size=48, use_checkpoint=True):
    """Build a MONAI ``SwinUNETR`` sized for the training patches.

    Args:
        cut_size: sequence whose first three entries are the 3-D patch size;
            passed to SwinUNETR as ``img_size``.
        num_classes: number of output channels (segmentation classes,
            background included).
        in_channels: number of input image channels. Defaults to 1, matching
            the single-channel CT volumes used in this project.
        feature_size: SwinUNETR embedding feature size (default 48).
        use_checkpoint: enable gradient checkpointing in SwinUNETR, trading
            extra compute for lower memory use (default True).

    Returns:
        A freshly initialised ``SwinUNETR`` module.

    NOTE(review): newer MONAI releases deprecate the ``img_size`` argument;
    confirm against the installed MONAI version.
    """
    return SwinUNETR(
        img_size=(cut_size[0], cut_size[1], cut_size[2]),
        in_channels=in_channels,
        out_channels=num_classes,
        feature_size=feature_size,
        use_checkpoint=use_checkpoint,
    )
utils/dataloaderd.py
import os
import shutil
import tempfile
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch
from monai.losses import DiceCELoss
from monai.inferers import sliding_window_inference
from monai.transforms import (
AsDiscrete,
Compose,
CropForegroundd,
LoadImaged,
Orientationd,
RandFlipd,
RandCropByPosNegLabeld,
RandShiftIntensityd,
ScaleIntensityRanged,
Spacingd,
RandRotate90d,
EnsureTyped,
)
from monai.metrics import DiceMetric
from monai.networks.nets import SwinUNETR
from monai.config import print_config
from monai.data import (
DataLoader,
CacheDataset,
load_decathlon_datalist,
decollate_batch,
set_track_meta,
)
# Number of random patches RandCropByPosNegLabeld draws from each training volume.
num_samples = 4
# GPU selection. setdefault() only fills these in when the caller has not
# already exported them, so e.g. `CUDA_VISIBLE_DEVICES=0 python train.py` is
# respected instead of being overwritten when this module is imported.
# They only take effect if set before CUDA is initialised.
os.environ.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID")
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "2,3")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Deterministic preprocessing for the validation set (no random augmentation).
# track_meta=True keeps MetaTensor metadata such as meta["filename_or_obj"].
val_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"], ensure_channel_first=True),
        # Clip intensities to [-175, 250] and rescale linearly to [0, 1].
        ScaleIntensityRanged(keys=["image"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        # Resample to a common spacing; nearest-neighbour keeps label values integral.
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.5, 1.5, 2.0),
            mode=("bilinear", "nearest"),
        ),
        EnsureTyped(keys=["image", "label"], device=device, track_meta=True),
    ]
)
def GetDataLoader(data_dir, split_json, cut_size, batch_size,
                  train_cache_num=24, val_cache_num=6, val_batch_size=1):
    """Build the MONAI training/validation datasets and their DataLoaders.

    Args:
        data_dir: directory that contains the split file ``split_json``. The
            ``image``/``label`` paths inside it are resolved by
            ``load_decathlon_datalist`` relative to the split file's directory.
        split_json: file name of a decathlon-style split with ``"training"`` and
            ``"validation"`` lists of ``{"image": ..., "label": ...}`` entries.
        cut_size: spatial size (first three entries) of the random training patches.
        batch_size: number of volumes per training batch. Each volume yields
            ``num_samples`` patches via RandCropByPosNegLabeld.
        train_cache_num: how many training items CacheDataset keeps
            preprocessed (default 24, the original hard-coded value).
        val_cache_num: how many validation items CacheDataset keeps
            preprocessed (default 6, the original hard-coded value).
        val_batch_size: validation batch size. Defaults to 1 because validation
            items are whole foreground-cropped, resampled volumes whose spatial
            sizes generally differ and so cannot be stacked into one batch.

    Returns:
        tuple: ``(train_loader, val_loader, val_ds)``.
    """
    train_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"], ensure_channel_first=True),
            # Same intensity window as validation: clip to [-175, 250], rescale to [0, 1].
            ScaleIntensityRanged(
                keys=["image"],
                a_min=-175,
                a_max=250,
                b_min=0.0,
                b_max=1.0,
                clip=True,
            ),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            Orientationd(keys=["image", "label"], axcodes="RAS"),
            Spacingd(
                keys=["image", "label"],
                pixdim=(1.5, 1.5, 2.0),
                mode=("bilinear", "nearest"),
            ),
            # Move data to `device` before the random transforms; metadata is not tracked here.
            EnsureTyped(keys=["image", "label"], device=device, track_meta=False),
            # Sample num_samples cut_size patches per volume, with a 1:1
            # ratio of positive- to negative-centred patches.
            RandCropByPosNegLabeld(
                keys=["image", "label"],
                label_key="label",
                spatial_size=(cut_size[0], cut_size[1], cut_size[2]),
                pos=1,
                neg=1,
                num_samples=num_samples,
                image_key="image",
                image_threshold=0,
            ),
            RandFlipd(
                keys=["image", "label"],
                spatial_axis=[0],
                prob=0.10,
            ),
            RandFlipd(
                keys=["image", "label"],
                spatial_axis=[1],
                prob=0.10,
            ),
            RandFlipd(
                keys=["image", "label"],
                spatial_axis=[2],
                prob=0.10,
            ),
            RandRotate90d(
                keys=["image", "label"],
                prob=0.10,
                max_k=3,
            ),
            RandShiftIntensityd(
                keys=["image"],
                offsets=0.10,
                prob=0.50,
            ),
        ]
    )
    # os.path.join works whether or not data_dir ends with a path separator.
    split_path = os.path.join(data_dir, split_json)
    datalist = load_decathlon_datalist(split_path, True, "training")
    val_files = load_decathlon_datalist(split_path, True, "validation")
    train_ds = CacheDataset(
        data=datalist,
        transform=train_transforms,
        cache_num=train_cache_num,
        cache_rate=1.0,
        num_workers=0,
    )
    # NOTE(review): num_workers=0 is presumably deliberate, since the transforms
    # may already place tensors on a CUDA device — confirm before raising it.
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=False)
    val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_num=val_cache_num, cache_rate=1.0, num_workers=0)
    val_loader = DataLoader(val_ds, batch_size=val_batch_size, shuffle=False, num_workers=0, pin_memory=False)
    return train_loader, val_loader, val_ds
# if __name__ == '__main__':
# data_dir = "/home/data/hablee_data_dir/aorta2023/"
# split_json = "dataset_0.json"
# datasets = data_dir + split_json
# datalist = load_decathlon_datalist(datasets, True, "training")
# val_files = load_decathlon_datalist(datasets, True, "validation")
# train_ds = CacheDataset(
# data=datalist,
# transform=train_transforms,
# cache_num=24,
# cache_rate=1.0,
# num_workers=4,
# )
# train_loader = DataLoader(train_ds, batch_size=1, shuffle=True, num_workers=4, pin_memory=True)
# val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_num=6, cache_rate=1.0, num_workers=4)
# val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=4, pin_memory=True)
# slice_map = {
# "72.nii.gz": 50,
# "97.nii.gz": 50,
# "82.nii.gz": 50,
# "153.nii.gz": 50,
# "54.nii.gz": 50,
# "104.nii.gz": 50
# }
# case_num = 2
# img_name = os.path.split(val_ds[case_num]["image"].meta["filename_or_obj"])[1]
# img = val_ds[case_num]["image"]
# label = val_ds[case_num]["label"]
# img_shape = img.shape
# label_shape = label.shape
# print(f"image shape: {img_shape}, label shape: {label_shape}")
# plt.figure("image", (18, 6))
# plt.subplot(1, 2, 1)
# plt.title("image")
# plt.imshow(img[0, :, :, slice_map[img_name]].detach().cpu(), cmap="gray")
# plt.subplot(1, 2, 2)
# plt.title("label")
# plt.imshow(label[0, :, :, slice_map[img_name]].detach().cpu())
# plt.show()
# plt.savefig("a.jpg")
有这几个文件就可以开始训练了。
数据
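数据都是 .nii.gz 格式的。图像和标签分别放在 data_dir 下的 images/ 和 labels/ 目录中,数据划分文件 dataset_0.json 也放在 data_dir 下,内容如下。注意:这份示例里 validation 的 6 个病例(72、97、82、153、54、104)以及 test 里的病例同时也出现在 training 中,因此验证集 Dice 会偏乐观,正式实验前应把它们从 training 里移除。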
dataset_0.json
{
"description": "zdm fjm",
"labels": {
"0": "background",
"1": "zdm",
"2": "fjm"
},
"licence": "hablee",
"modality": {
"0": "CT"
},
"name": "zdm fjm",
"numTest": 5,
"numTraining": 100,
"reference": "None",
"release": "2023/02/14",
"tensorImageSize": "3D",
"test":[
"images/100.nii.gz",
"images/102.nii.gz",
"images/103.nii.gz",
"images/104.nii.gz",
"images/105.nii.gz"
],
"training":[
{
"image": "images/72.nii.gz",
"label": "labels/72.nii.gz"
},
{
"image": "images/97.nii.gz",
"label": "labels/97.nii.gz"
},
{
"image": "images/82.nii.gz",
"label": "labels/82.nii.gz"
},
{
"image": "images/153.nii.gz",
"label": "labels/153.nii.gz"
},
{
"image": "images/54.nii.gz",
"label": "labels/54.nii.gz"
},
{
"image": "images/104.nii.gz",
"label": "labels/104.nii.gz"
},
{
"image": "images/110.nii.gz",
"label": "labels/110.nii.gz"
},
{
"image": "images/89.nii.gz",
"label": "labels/89.nii.gz"
},
{
"image": "images/95.nii.gz",
"label": "labels/95.nii.gz"
},
{
"image": "images/129.nii.gz",
"label": "labels/129.nii.gz"
},
{
"image": "images/118.nii.gz",
"label": "labels/118.nii.gz"
},
{
"image": "images/141.nii.gz",
"label": "labels/141.nii.gz"
},
{
"image": "images/134.nii.gz",
"label": "labels/134.nii.gz"
},
{
"image": "images/146.nii.gz",
"label": "labels/146.nii.gz"
},
{
"image": "images/100.nii.gz",
"label": "labels/100.nii.gz"
},
{
"image": "images/142.nii.gz",
"label": "labels/142.nii.gz"
},
{
"image": "images/85.nii.gz",
"label": "labels/85.nii.gz"
},
{
"image": "images/154.nii.gz",
"label": "labels/154.nii.gz"
},
{
"image": "images/103.nii.gz",
"label": "labels/103.nii.gz"
},
{
"image": "images/59.nii.gz",
"label": "labels/59.nii.gz"
},
{
"image": "images/84.nii.gz",
"label": "labels/84.nii.gz"
},
{
"image": "images/124.nii.gz",
"label": "labels/124.nii.gz"
},
{
"image": "images/125.nii.gz",
"label": "labels/125.nii.gz"
},
{
"image": "images/58.nii.gz",
"label": "labels/58.nii.gz"
},
{
"image": "images/68.nii.gz",
"label": "labels/68.nii.gz"
},
{
"image": "images/81.nii.gz",
"label": "labels/81.nii.gz"
},
{
"image": "images/115.nii.gz",
"label": "labels/115.nii.gz"
},
{
"image": "images/77.nii.gz",
"label": "labels/77.nii.gz"
},
{
"image": "images/127.nii.gz",
"label": "labels/127.nii.gz"
},
{
"image": "images/131.nii.gz",
"label": "labels/131.nii.gz"
},
{
"image": "images/147.nii.gz",
"label": "labels/147.nii.gz"
},
{
"image": "images/73.nii.gz",
"label": "labels/73.nii.gz"
},
{
"image": "images/102.nii.gz",
"label": "labels/102.nii.gz"
},
{
"image": "images/66.nii.gz",
"label": "labels/66.nii.gz"
},
{
"image": "images/67.nii.gz",
"label": "labels/67.nii.gz"
},
{
"image": "images/135.nii.gz",
"label": "labels/135.nii.gz"
},
{
"image": "images/149.nii.gz",
"label": "labels/149.nii.gz"
},
{
"image": "images/48.nii.gz",
"label": "labels/48.nii.gz"
},
{
"image": "images/83.nii.gz",
"label": "labels/83.nii.gz"
},
{
"image": "images/145.nii.gz",
"label": "labels/145.nii.gz"
},
{
"image": "images/45.nii.gz",
"label": "labels/45.nii.gz"
},
{
"image": "images/61.nii.gz",
"label": "labels/61.nii.gz"
},
{
"image": "images/122.nii.gz",
"label": "labels/122.nii.gz"
},
{
"image": "images/96.nii.gz",
"label": "labels/96.nii.gz"
},
{
"image": "images/60.nii.gz",
"label": "labels/60.nii.gz"
},
{
"image": "images/144.nii.gz",
"label": "labels/144.nii.gz"
},
{
"image": "images/91.nii.gz",
"label": "labels/91.nii.gz"
},
{
"image": "images/111.nii.gz",
"label": "labels/111.nii.gz"
},
{
"image": "images/114.nii.gz",
"label": "labels/114.nii.gz"
},
{
"image": "images/90.nii.gz",
"label": "labels/90.nii.gz"
},
{
"image": "images/52.nii.gz",
"label": "labels/52.nii.gz"
},
{
"image": "images/132.nii.gz",
"label": "labels/132.nii.gz"
},
{
"image": "images/107.nii.gz",
"label": "labels/107.nii.gz"
},
{
"image": "images/109.nii.gz",
"label": "labels/109.nii.gz"
},
{
"image": "images/139.nii.gz",
"label": "labels/139.nii.gz"
},
{
"image": "images/143.nii.gz",
"label": "labels/143.nii.gz"
},
{
"image": "images/119.nii.gz",
"label": "labels/119.nii.gz"
},
{
"image": "images/55.nii.gz",
"label": "labels/55.nii.gz"
},
{
"image": "images/80.nii.gz",
"label": "labels/80.nii.gz"
},
{
"image": "images/53.nii.gz",
"label": "labels/53.nii.gz"
},
{
"image": "images/120.nii.gz",
"label": "labels/120.nii.gz"
},
{
"image": "images/65.nii.gz",
"label": "labels/65.nii.gz"
},
{
"image": "images/88.nii.gz",
"label": "labels/88.nii.gz"
},
{
"image": "images/47.nii.gz",
"label": "labels/47.nii.gz"
},
{
"image": "images/57.nii.gz",
"label": "labels/57.nii.gz"
},
{
"image": "images/130.nii.gz",
"label": "labels/130.nii.gz"
},
{
"image": "images/108.nii.gz",
"label": "labels/108.nii.gz"
},
{
"image": "images/151.nii.gz",
"label": "labels/151.nii.gz"
},
{
"image": "images/113.nii.gz",
"label": "labels/113.nii.gz"
},
{
"image": "images/71.nii.gz",
"label": "labels/71.nii.gz"
},
{
"image": "images/46.nii.gz",
"label": "labels/46.nii.gz"
},
{
"image": "images/105.nii.gz",
"label": "labels/105.nii.gz"
},
{
"image": "images/148.nii.gz",
"label": "labels/148.nii.gz"
},
{
"image": "images/112.nii.gz",
"label": "labels/112.nii.gz"
},
{
"image": "images/106.nii.gz",
"label": "labels/106.nii.gz"
},
{
"image": "images/49.nii.gz",
"label": "labels/49.nii.gz"
},
{
"image": "images/140.nii.gz",
"label": "labels/140.nii.gz"
},
{
"image": "images/92.nii.gz",
"label": "labels/92.nii.gz"
},
{
"image": "images/137.nii.gz",
"label": "labels/137.nii.gz"
},
{
"image": "images/74.nii.gz",
"label": "labels/74.nii.gz"
},
{
"image": "images/62.nii.gz",
"label": "labels/62.nii.gz"
},
{
"image": "images/99.nii.gz",
"label": "labels/99.nii.gz"
},
{
"image": "images/150.nii.gz",
"label": "labels/150.nii.gz"
},
{
"image": "images/75.nii.gz",
"label": "labels/75.nii.gz"
},
{
"image": "images/98.nii.gz",
"label": "labels/98.nii.gz"
},
{
"image": "images/86.nii.gz",
"label": "labels/86.nii.gz"
},
{
"image": "images/50.nii.gz",
"label": "labels/50.nii.gz"
},
{
"image": "images/93.nii.gz",
"label": "labels/93.nii.gz"
},
{
"image": "images/138.nii.gz",
"label": "labels/138.nii.gz"
},
{
"image": "images/126.nii.gz",
"label": "labels/126.nii.gz"
},
{
"image": "images/69.nii.gz",
"label": "labels/69.nii.gz"
},
{
"image": "images/64.nii.gz",
"label": "labels/64.nii.gz"
},
{
"image": "images/136.nii.gz",
"label": "labels/136.nii.gz"
},
{
"image": "images/51.nii.gz",
"label": "labels/51.nii.gz"
},
{
"image": "images/70.nii.gz",
"label": "labels/70.nii.gz"
},
{
"image": "images/56.nii.gz",
"label": "labels/56.nii.gz"
},
{
"image": "images/128.nii.gz",
"label": "labels/128.nii.gz"
},
{
"image": "images/76.nii.gz",
"label": "labels/76.nii.gz"
},
{
"image": "images/123.nii.gz",
"label": "labels/123.nii.gz"
},
{
"image": "images/152.nii.gz",
"label": "labels/152.nii.gz"
}
],
"validation":[
{
"image": "images/72.nii.gz",
"label": "labels/72.nii.gz"
},
{
"image": "images/97.nii.gz",
"label": "labels/97.nii.gz"
},
{
"image": "images/82.nii.gz",
"label": "labels/82.nii.gz"
},
{
"image": "images/153.nii.gz",
"label": "labels/153.nii.gz"
},
{
"image": "images/54.nii.gz",
"label": "labels/54.nii.gz"
},
{
"image": "images/104.nii.gz",
"label": "labels/104.nii.gz"
}
]
}