UUNet训练自己写的网络

news2025/1/10 1:44:29

记录贴写的很乱仅供参考。
自己写的Unet网络不带深度监督,但是NNUNet默认的训练方法是深度监督训练的,对应的模型也是带有深度监督的。但是NNUNetV2也贴心的提供了非深度监督的训练方法在该目录下:
在这里插入图片描述
也或者说我们想要自己去定义一个nnUNWtTrainer 去扩展NNunet的话,就可以参考这里面的py文件去写自己的,但是都建议以nnUNetTrainer为基类去继承它。就如nnUNetTrainerNoDeepSupervision类的写法一样(这个类就是去实现无深度监督网络的训练的):
展示一下这个文件:以及要修改成自己网络的地方。
`import torch
from torch import autocast

from nnunetv2.training.loss.compound_losses import DC_and_BCE_loss, DC_and_CE_loss
from nnunetv2.training.loss.dice import get_tp_fp_fn_tn, MemoryEfficientSoftDiceLoss
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
from nnunetv2.utilities.helpers import dummy_context
from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels
from torch.nn.parallel import DistributedDataParallel as DDP
from nnunetv2.Network.UNet import UNet

class nnUNetTrainerNoDeepSupervision(nnUNetTrainer):
def _build_loss(self):
if self.label_manager.has_regions:
loss = DC_and_BCE_loss({},
{‘batch_dice’: self.configuration_manager.batch_dice,
‘do_bg’: True, ‘smooth’: 1e-5, ‘ddp’: self.is_ddp},
use_ignore_label=self.label_manager.ignore_label is not None,
dice_class=MemoryEfficientSoftDiceLoss)
else:
loss = DC_and_CE_loss({‘batch_dice’: self.configuration_manager.batch_dice,
‘smooth’: 1e-5, ‘do_bg’: False, ‘ddp’: self.is_ddp}, {}, weight_ce=1, weight_dice=1,
ignore_label=self.label_manager.ignore_label,
dice_class=MemoryEfficientSoftDiceLoss)
return loss

def _get_deep_supervision_scales(self):
    return None

def initialize(self):
    if not self.was_initialized:
        self.num_input_channels = determine_num_input_channels(self.plans_manager, self.configuration_manager,
                                                               self.dataset_json)

        # self.network = self.build_network_architecture(self.plans_manager, self.dataset_json,
        #                                                self.configuration_manager,
        #                                                self.num_input_channels,
        #                                                enable_deep_supervision=False).to(self.device)
        self.network = UNet(self.num_input_channels, 2, base_c=32).to(self.device)
        print("="*20)
        print("now use our unet")
        print("=" * 20)

        self.optimizer, self.lr_scheduler = self.configure_optimizers()
        # if ddp, wrap in DDP wrapper
        if self.is_ddp:
            self.network = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.network)
            self.network = DDP(self.network, device_ids=[self.local_rank])

        self.loss = self._build_loss()
        self.was_initialized = True
    else:
        raise RuntimeError("You have called self.initialize even though the trainer was already initialized. "
                           "That should not happen.")

def set_deep_supervision_enabled(self, enabled: bool):
    pass

def validation_step(self, batch: dict) -> dict:
    data = batch['data']
    target = batch['target']

    data = data.to(self.device, non_blocking=True)
    if isinstance(target, list):
        target = [i.to(self.device, non_blocking=True) for i in target]
    else:
        target = target.to(self.device, non_blocking=True)

    self.optimizer.zero_grad(set_to_none=True)

    # Autocast is a little bitch.
    # If the device_type is 'cpu' then it's slow as heck and needs to be disabled.
    # If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False is set. Whyyyyyyy. (this is why we don't make use of enabled=False)
    # So autocast will only be active if we have a cuda device.
    with autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context():
        output = self.network(data)
        del data
        l = self.loss(output, target)

    # the following is needed for online evaluation. Fake dice (green line)
    axes = [0] + list(range(2, output.ndim))

    if self.label_manager.has_regions:
        predicted_segmentation_onehot = (torch.sigmoid(output) > 0.5).long()
    else:
        # no need for softmax
        output_seg = output.argmax(1)[:, None]
        predicted_segmentation_onehot = torch.zeros(output.shape, device=output.device, dtype=torch.float32)
        predicted_segmentation_onehot.scatter_(1, output_seg, 1)
        del output_seg

    if self.label_manager.has_ignore_label:
        if not self.label_manager.has_regions:
            mask = (target != self.label_manager.ignore_label).float()
            # CAREFUL that you don't rely on target after this line!
            target[target == self.label_manager.ignore_label] = 0
        else:
            mask = 1 - target[:, -1:]
            # CAREFUL that you don't rely on target after this line!
            target = target[:, :-1]
    else:
        mask = None

    tp, fp, fn, _ = get_tp_fp_fn_tn(predicted_segmentation_onehot, target, axes=axes, mask=mask)

    tp_hard = tp.detach().cpu().numpy()
    fp_hard = fp.detach().cpu().numpy()
    fn_hard = fn.detach().cpu().numpy()
    if not self.label_manager.has_regions:
        # if we train with regions all segmentation heads predict some kind of foreground. In conventional
        # (softmax training) there needs tobe one output for the background. We are not interested in the
        # background Dice
        # [1:] in order to remove background
        tp_hard = tp_hard[1:]
        fp_hard = fp_hard[1:]
        fn_hard = fn_hard[1:]

    return {'loss': l.detach().cpu().numpy(), 'tp_hard': tp_hard, 'fp_hard': fp_hard, 'fn_hard': fn_hard}`

在self.network处将网络替换为自己的非深度监督网络即可,比如我改成自己编写的UNet网络如下:

self.network = UNet(self.num_input_channels, 2, base_c=32).to(self.device)
###下列为提示语句,以便确认是在调用该训练器进行训练
print("="*20)
print("now use our unet")
print("=" * 20)

最后需要在训练时候的脚本上加上 -tr 自己写的类名,此处就是 -tr nnUNetTrainerNoDeepSupervision
也就是最后的训练脚本如下:

nnUNetv2_train 002 2d 0 -tr nnUNetTrainerNoDeepSupervision

PS:此处也可以通过直接在run_training.py 文件中修改在这里插入图片描述
这个命令行参数的默认值来实现。
好记录完毕,继续炼丹

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1170994.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

pyusb环境搭建和无法发包问题

pyusb环境搭建和无法发包问题 项目需要对usb设备进行开发调试,选择搭建pyusb环境进行调试测试,这里记录下完整流程和中间解决的一些问题。 我使用的环境是window10 64bit, vscode 1.84.0 , Python 3.11.6 1 安装流程 参考github上的 https://github.…

实用篇-MQ消息队列

一、初识MQ 通讯分为同步通讯和异步通讯,同步通讯就比如我们日常生活中的打电话,看直播,能够得到及时的反馈。而异步通讯则类似于聊天软件聊天,不需要建立实时的连接,并且可以进行建立多个业务一起异步执行 1. 同步通…

关于SNAP的Biophysical Processor模块的计算准确率以及大厂10月种植情况

关于SNAP的Biophysical Processor模块的计算准确率 在处理河北省2022年的10月6日影像,使用SNAP的Biophysical Processor计算LAI时 发现很多农田地块出现了缺失值,但其实就是0值 SNAP的这个模块基于PROSAIL物理模型反演。不得不说,还是挺准…

AI:56-基于深度学习的微表情识别

🚀 本文选自专栏:AI领域专栏 从基础到实践,深入了解算法、案例和最新趋势。无论你是初学者还是经验丰富的数据科学家,通过案例和项目实践,掌握核心概念和实用技能。每篇案例都包含代码实例,详细讲解供大家学习。 📌📌📌在这个漫长的过程,中途遇到了不少问题,但是…

聚观早报 |盒马参战双11;真我GT5 Pro将压轴登场

【聚观365】11月4日消息 盒马参战双11 真我GT5 Pro将压轴登场 奇瑞汽车10月销量创新高 iQOO 12系列将首发电竞芯片Q1 苹果CEO库克称正改善供需平衡 盒马参战双11 不少消费者反映,今年盒马的双11已悄然开始:10月20日起,盒马APP很多商品页…

COE文件之读写操作

在Xilinx的FIR、Block Memory等 IP核的配置中,需要使用COE(Coefficient)文件来进行初始化。 写COE文件 一般是通过Matlab设计好参数后,再生成COE文件。具体代码如下。 x 1:512; fid fopen(test.coe,w); fprintf(fid, memory_…

MachineLearning 14. 机器学习之集成分类器(AdaBoost)

这期介绍一下NB的最佳集成分类方法之一 AdaBoost,并实现在具体数据集上的应用,尤其是临床数据。 简 介 Adaboost是Adaptive Boosting的缩写,使用一组简单的弱分类器,通过强调被弱分类器错误分类的样本来实现改进的分类器。AdaBoo…

Java金字塔、空心金字塔、空心菱形

Java金字塔 public class TestDemo01 {public static void main(String[] args){//第一个for用于每行输出 从i1开始到i<5,总共5行for(int i1;i<5;i){//每行前缀空格&#xff0c;这个for用于表示每行输出*前面的空格//从上面规律可得,每行输出的空格数为总层数&#xff0c…

【计算机网络】金管局计算机岗位——计算机网络(⭐⭐⭐⭐)

计算机网络知识点 计算机网络基础知识计算机网络的定义与组成、分类网络的发展、常识&#xff08;⭐⭐⭐⭐&#xff09;计算机网络的定义计算机网络的功能计算机网络的组成计算机网络的分类计算机网络的性能指标主要包括&#xff08;⭐⭐⭐⭐&#xff09; 网络体系结构OSI模型定…

英伟达发布 Windows 版 TensorRT-LLM 库

导读英伟达发布了 Windows 版本的 TensorRT-LLM 库&#xff0c;称其将大模型在 RTX 上的运行速度提升 4 倍。 GeForce RTX 和 NVIDIA RTX GPU 配备了名为 Tensor Core 的专用 AI 处理器&#xff0c;正在为超过 1 亿台 Windows PC 和工作站带来原生生成式 AI 的强大功能。 Tens…

Python笔记——linux/ubuntu下安装mamba,安装bob.learn库

Python笔记——linux/ubuntu下安装mamba&#xff0c;安装bob.learn库 一、安装/卸载anaconda二、安装mamba1. 命令行安装&#xff08;大坑&#xff0c;不推荐&#xff09;2. 命令行下载guihub上的安装包并安装&#xff08;推荐&#xff09;3. 网站下载安装包并安装&#xff08;…

R语言中的自带的调色板--五种--全平台可用

R语言中的自带的调色板–五种–全平台可用

YOLOv5论文作图教程(2)— 软件界面布局和基础功能介绍

前言:Hello大家好,我是小哥谈。通过上一节课的学习,相信大家都已成功安装好软件了,本节课就给大家详细介绍一下Axure RP9软件的界面布局及相关基础功能,希望大家学习之后能够有所收获!🌈 前期回顾: YOLOv5论文作图教程(1)— 软件介绍及下载安装(包括软件包+下载安…

Java字符串常用函数 详解5000字 (刷题向 / 应用向)

1.直接定义字符串 直接定义字符串是指使用双引号表示字符串中的内容&#xff0c;例如"Hello Java"、"Java 编程"等。具体方法是用字符串常量直接初始化一个 String 对象&#xff0c;示例如下&#xff1a; 1. String str"Hello Java"; 或者 …

生成m3u8视频:批量剪辑与分割的完美结合

在视频处理领域&#xff0c;m3u8视频格式的出现为高效处理和优化视频内容提供了新的可能。尤其在批量剪辑和分割视频的过程中&#xff0c;掌握m3u8视频的生成技巧&#xff0c;意味着更高效的工作流程和更出色的创作效果。现在一起来看看云炫AI智剪如何生成m3u8视频的操作吧。 步…

代码生成器

Easycode Entity ##导入宏定义 $!{define.vm}##保存文件&#xff08;宏定义&#xff09; #save("/entity", ".java")##包路径&#xff08;宏定义&#xff09; #setPackageSuffix("entity")##自动导入包&#xff08;全局变量&#xff09; $!{au…

港科夜闻|香港科大戴希教授被选为腾讯公司新基石研究员

关注并星标 每周阅读港科夜闻 建立新视野 开启新思维 1、香港科大戴希教授被选为腾讯公司“新基石研究员”。10月30日&#xff0c;作为目前国内社会力量资助基础研究力度最大的公益项目之一&#xff0c;“新基石研究员项目”揭晓了第二期获资助名单&#xff0c;来自13个城市28家…

【Orangepi Zero2 全志H616】驱动超声波测距、gettimeofday时间函数API

一、HC-SR04超声波模块 超声波测距原理超声波的时序图 二、时间函数 API测试代码代码实现和验证 一、HC-SR04超声波模块 型号&#xff1a;HC-SR04 接线参考&#xff1a;模块除了两个电源引脚外&#xff0c;还有TRIG、ECHO引脚 / P0、P1 超声波测距原理 让它发送波&#…

竞赛选题 深度学习手势检测与识别算法 - opencv python

文章目录 0 前言1 实现效果2 技术原理2.1 手部检测2.1.1 基于肤色空间的手势检测方法2.1.2 基于运动的手势检测方法2.1.3 基于边缘的手势检测方法2.1.4 基于模板的手势检测方法2.1.5 基于机器学习的手势检测方法 3 手部识别3.1 SSD网络3.2 数据集3.3 最终改进的网络结构 4 最后…

Oracle安全基线检查

一、账户安全 1、禁止SYSDBA用户远程连接 用户具备数据库超级管理员(SYSDBA)权限的用户远程管理登录SYSDBA用户只能本地登录,不能远程。REMOTE_LOGIN_PASSWORDFILE函数的Value值为NONE。这意味着禁止共享口令文件,只能通过操作系统认证登录Oracle数据库。 1)检查REMOTE…