These are rough working notes, for reference only.
The UNet I wrote myself does not use deep supervision, but nnUNet's default training routine trains with deep supervision, and the default model is built with deep-supervision heads to match. Fortunately, nnUNetV2 also thoughtfully ships a ready-made non-deep-supervision trainer in its trainer variants directory.
Put differently, if we want to define our own nnUNetTrainer to extend nnUNet, we can use the .py files in that directory as templates for writing our own, but in all cases it is recommended to inherit from nnUNetTrainer as the base class, exactly as the nnUNetTrainerNoDeepSupervision class does (this class implements training of a network without deep supervision). A short sketch of the general pattern follows below.
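To make the extension pattern concrete, here is a minimal sketch: subclass nnUNetTrainer and override only what you need. The class name and the overridden epoch count here are my own illustrative choices, not something shipped by nnUNet, and the `__init__` signature is copied from the nnUNetTrainer base class in my version of nnUNetV2 (check it against yours):

```python
import torch

from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer


# hypothetical example: a trainer variant that only changes the training length;
# everything else (loss, network, LR schedule) is inherited from nnUNetTrainer
class nnUNetTrainerShortSchedule(nnUNetTrainer):
    def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
                 unpack_dataset: bool = True, device: torch.device = torch.device('cuda')):
        super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
        self.num_epochs = 250  # stock default is 1000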
Here is the file itself, along with the places that need changing for your own network:
```python
import torch
from torch import autocast
from torch.nn.parallel import DistributedDataParallel as DDP

from nnunetv2.training.loss.compound_losses import DC_and_BCE_loss, DC_and_CE_loss
from nnunetv2.training.loss.dice import get_tp_fp_fn_tn, MemoryEfficientSoftDiceLoss
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
from nnunetv2.utilities.helpers import dummy_context
from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels

from nnunetv2.Network.UNet import UNet  # my own plain UNet module; not part of stock nnUNet


class nnUNetTrainerNoDeepSupervision(nnUNetTrainer):
    def _build_loss(self):
        if self.label_manager.has_regions:
            loss = DC_and_BCE_loss({},
                                   {'batch_dice': self.configuration_manager.batch_dice,
                                    'do_bg': True, 'smooth': 1e-5, 'ddp': self.is_ddp},
                                   use_ignore_label=self.label_manager.ignore_label is not None,
                                   dice_class=MemoryEfficientSoftDiceLoss)
        else:
            loss = DC_and_CE_loss({'batch_dice': self.configuration_manager.batch_dice,
                                   'smooth': 1e-5, 'do_bg': False, 'ddp': self.is_ddp},
                                  {}, weight_ce=1, weight_dice=1,
                                  ignore_label=self.label_manager.ignore_label,
                                  dice_class=MemoryEfficientSoftDiceLoss)
        return loss

    def _get_deep_supervision_scales(self):
        # no deep supervision -> the data loader must not produce downsampled targets
        return None

    def initialize(self):
        if not self.was_initialized:
            self.num_input_channels = determine_num_input_channels(self.plans_manager,
                                                                   self.configuration_manager,
                                                                   self.dataset_json)
            # stock nnUNet network construction, disabled in favor of my own UNet below:
            # self.network = self.build_network_architecture(self.plans_manager, self.dataset_json,
            #                                                self.configuration_manager,
            #                                                self.num_input_channels,
            #                                                enable_deep_supervision=False).to(self.device)
            self.network = UNet(self.num_input_channels, 2, base_c=32).to(self.device)
            print("=" * 20)
            print("now use our unet")
            print("=" * 20)
            self.optimizer, self.lr_scheduler = self.configure_optimizers()
            # if ddp, wrap in DDP wrapper
            if self.is_ddp:
                self.network = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.network)
                self.network = DDP(self.network, device_ids=[self.local_rank])
            self.loss = self._build_loss()
            self.was_initialized = True
        else:
            raise RuntimeError("You have called self.initialize even though the trainer was already initialized. "
                               "That should not happen.")

    def set_deep_supervision_enabled(self, enabled: bool):
        # the network has a single output head, so there is nothing to toggle
        pass

    def validation_step(self, batch: dict) -> dict:
        data = batch['data']
        target = batch['target']

        data = data.to(self.device, non_blocking=True)
        if isinstance(target, list):
            target = [i.to(self.device, non_blocking=True) for i in target]
        else:
            target = target.to(self.device, non_blocking=True)

        self.optimizer.zero_grad(set_to_none=True)

        # Autocast is a little bitch.
        # If the device_type is 'cpu' then it's slow as heck and needs to be disabled.
        # If the device_type is 'mps' then it will complain that mps is not implemented, even if
        # enabled=False is set. Whyyyyyyy. (this is why we don't make use of enabled=False)
        # So autocast will only be active if we have a cuda device.
        with autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context():
            output = self.network(data)
            del data
            l = self.loss(output, target)

        # the following is needed for online evaluation. Fake dice (green line)
        axes = [0] + list(range(2, output.ndim))

        if self.label_manager.has_regions:
            predicted_segmentation_onehot = (torch.sigmoid(output) > 0.5).long()
        else:
            # no need for softmax
            output_seg = output.argmax(1)[:, None]
            predicted_segmentation_onehot = torch.zeros(output.shape, device=output.device, dtype=torch.float32)
            predicted_segmentation_onehot.scatter_(1, output_seg, 1)
            del output_seg

        if self.label_manager.has_ignore_label:
            if not self.label_manager.has_regions:
                mask = (target != self.label_manager.ignore_label).float()
                # CAREFUL that you don't rely on target after this line!
                target[target == self.label_manager.ignore_label] = 0
            else:
                mask = 1 - target[:, -1:]
                # CAREFUL that you don't rely on target after this line!
                target = target[:, :-1]
        else:
            mask = None

        tp, fp, fn, _ = get_tp_fp_fn_tn(predicted_segmentation_onehot, target, axes=axes, mask=mask)

        tp_hard = tp.detach().cpu().numpy()
        fp_hard = fp.detach().cpu().numpy()
        fn_hard = fn.detach().cpu().numpy()
        if not self.label_manager.has_regions:
            # if we train with regions all segmentation heads predict some kind of foreground. In conventional
            # (softmax training) there needs to be one output for the background. We are not interested in the
            # background Dice
            # [1:] in order to remove background
            tp_hard = tp_hard[1:]
            fp_hard = fp_hard[1:]
            fn_hard = fn_hard[1:]

        return {'loss': l.detach().cpu().numpy(), 'tp_hard': tp_hard, 'fp_hard': fp_hard, 'fn_hard': fn_hard}
```
Inside `initialize`, simply replace the network assigned to `self.network` with your own non-deep-supervision model; for example, I swapped in my own UNet. Note how the class also makes `_get_deep_supervision_scales` return `None` and turns `set_deep_supervision_enabled` into a no-op, since a single-output network has nothing to toggle:

```python
self.network = UNet(self.num_input_channels, 2, base_c=32).to(self.device)
# the prints below are just a banner, to confirm that this trainer is the one actually being used
print("=" * 20)
print("now use our unet")
print("=" * 20)
```
Finally, append `-tr <your trainer class name>` to the training command, which here means `-tr nnUNetTrainerNoDeepSupervision`. (As far as I know, nnUNetV2 resolves this name by recursively searching the nnunetv2.training.nnUNetTrainer package, so the class file needs to live somewhere under that package.)
The full training command therefore becomes:
```
nnUNetv2_train 002 2d 0 -tr nnUNetTrainerNoDeepSupervision
```
PS: Alternatively, you can achieve the same thing by changing the default value of this command-line argument directly in the run_training.py file.
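From memory, the relevant argparse line in nnunetv2/run/run_training.py looks roughly like the snippet below; verify it against your installed version before editing:

```python
# in nnunetv2/run/run_training.py (paraphrased from memory, check your version)
parser.add_argument('-tr', type=str, required=False,
                    default='nnUNetTrainerNoDeepSupervision',  # stock default is 'nnUNetTrainer'
                    help='[OPTIONAL] Use this flag to specify a custom trainer.')
```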
All right, notes done. Back to the alchemy of training models.