参数量仅有50KB的超轻量级unet变种网络egeunet【参数和计算量降低494和160倍】医疗图像分割实践

news2025/1/15 19:45:35

今天看到一篇挺有意思的文章,做的是跟医疗图像分割相关的工作,但是不像之前看到的一些工作一味地去追求高精度,因为医疗领域本身就是一个相对特殊的行业,对于模型产生的结果的精确性要求是很高的,带来的是参数量级的庞大,之所以觉得这篇论文挺有意思的就是因为这里的主要的点在于超轻量级但是并没有导致精度大幅下降。

官方论文地址在这里,如下所示:

 可见刚发表不久。

 EGE-UNet融合了两个主要模块:
Group multi-axis Hadamard Product Attention module (GHPA)
Group Aggregation Bridge module (GAB)
GHPA 利用哈达玛积注意力机制(HPA),通过将输入特征进行分组,对不同轴进行 HPA 操作,从多个视角提取病变信息。
GAB 通过分组聚合将不同规模的高级语义特征和低级细节特征以及解码器生成的掩码进行融合,从而有效提取多尺度信息,
通过融合上述两个模块提出了EGE-UNet模型实现了在参数和计算复杂度极低的情况下优秀的分割性能。
EGE-UNet的设计沿用了 U 形架构,包括对称的编码器-解码器部分。编码器由六个 stage 组成,各阶段的通道数量为{8, 16, 24, 32, 48, 64}。前三个阶段采用了普通卷积,而后三个阶段使用提出的 GHPA 来从多视角提取表征信息。
EGE-UNet 在编码器和解码器之间的每个阶段都集成了GAB。此外,模型还利用深监督生成不同规模的掩膜预测,这些预测用于损失函数并作为 GAB 的输入之一。通过这些高级模块的集成,EGE-UNet 在比先前的方法提升了分割性能的同时,显著减少了参数和计算负载。

 进一步详情可以自行研读发表的论文。

这里我也是初步了解了一下,主要是想要实际使用一下这个超轻量级的网络,因为我觉得这种类型的网络在现实工作里的意义更大,大参数量高精度模型固然很好,但是并未所有的工业或者是医疗场景里面的设备都具备那么高的算力能够支撑如此庞大的计算量的,如果能在高度轻量化的网络基础上保持不俗的精度性能的话着实还是很有实际意义的。

官方同时开源了项目,地址在这里,如下所示:

 感觉目前的star量很少,估计是了解到的人还不多吧,就让我来带一波热度吧。

从readme来看,作者给出来的实操训练手册可以说是简单到了极致了:

 数据集也一并准备好了,地址在这里,如下所示:

 自行下载下来即可,体积不大,下载起来应该还是很快的。

下载下载放到项目data目录下面解压缩即可,如下所示:

 可以看到:作者同时提供了两组数据集,项目源码默认使用的是isic2017的数据集的。

直接终端执行train.py模块即可,如下所示:

 默认300个epoch的迭代计算:

 训练完成截图如下所示:

 结果默认存储在results目录下。如下所示:

 checkpoints目录下存放的是训练得到的模型文件,如下所示:

 log目录下存放的是训练日志数据,如下所示:

 outputs目录下存放的是实际测试的实例图像可视化结果,如下所示:

 官方项目只提供了训练、评估使用的代码,没有提供离线推理可直接使用的代码,但是基于训练和评估部分的代码可以自行开发离线推理的代码,这里我为了能够更加简单的使用开发了专用的可视化系统界面,实例推理效果如下所示:

 到这里基本完整的实践就结束了,前面也说过了源码默认使用的是isic2017的数据集,所以后面我又考虑基于isic2018的数据集也开发训练一下模型,只需要修改configs目录下的参数即可,如下所示:

 修改后的config_setting模块如下所示:

from torchvision import transforms
from utils import *

from datetime import datetime

class setting_config:
    """
    the config of training setting.
    """

    network = 'egeunet'
    model_config = {
        'num_classes': 1, 
        'input_channels': 3, 
        'c_list': [8,16,24,32,48,64], 
        'bridge': True,
        'gt_ds': True,
    }

    datasets = 'isic18' 
    if datasets == 'isic18':
        data_path = './data/isic2018/'
    elif datasets == 'isic17':
        data_path = './data/isic2017/'
    else:
        raise Exception('datasets in not right!')

    criterion = GT_BceDiceLoss(wb=1, wd=1)

    pretrained_path = './pre_trained/'
    num_classes = 1
    input_size_h = 256
    input_size_w = 256
    input_channels = 3
    distributed = False
    local_rank = -1
    num_workers = 0
    seed = 42
    world_size = None
    rank = None
    amp = False
    gpu_id = '0'
    batch_size = 8
    epochs = 300

    work_dir = 'results/' + network + '_' + datasets + '_' + datetime.now().strftime('%A_%d_%B_%Y_%Hh_%Mm_%Ss') + '/'

    print_interval = 20
    val_interval = 30
    save_interval = 100
    threshold = 0.5

    train_transformer = transforms.Compose([
        myNormalize(datasets, train=True),
        myToTensor(),
        myRandomHorizontalFlip(p=0.5),
        myRandomVerticalFlip(p=0.5),
        myRandomRotation(p=0.5, degree=[0, 360]),
        myResize(input_size_h, input_size_w)
    ])
    test_transformer = transforms.Compose([
        myNormalize(datasets, train=False),
        myToTensor(),
        myResize(input_size_h, input_size_w)
    ])

    opt = 'AdamW'
    assert opt in ['Adadelta', 'Adagrad', 'Adam', 'AdamW', 'Adamax', 'ASGD', 'RMSprop', 'Rprop', 'SGD'], 'Unsupported optimizer!'
    if opt == 'Adadelta':
        lr = 0.01 # default: 1.0 – coefficient that scale delta before it is applied to the parameters
        rho = 0.9 # default: 0.9 – coefficient used for computing a running average of squared gradients
        eps = 1e-6 # default: 1e-6 – term added to the denominator to improve numerical stability 
        weight_decay = 0.05 # default: 0 – weight decay (L2 penalty) 
    elif opt == 'Adagrad':
        lr = 0.01 # default: 0.01 – learning rate
        lr_decay = 0 # default: 0 – learning rate decay
        eps = 1e-10 # default: 1e-10 – term added to the denominator to improve numerical stability
        weight_decay = 0.05 # default: 0 – weight decay (L2 penalty)
    elif opt == 'Adam':
        lr = 0.001 # default: 1e-3 – learning rate
        betas = (0.9, 0.999) # default: (0.9, 0.999) – coefficients used for computing running averages of gradient and its square
        eps = 1e-8 # default: 1e-8 – term added to the denominator to improve numerical stability 
        weight_decay = 0.0001 # default: 0 – weight decay (L2 penalty) 
        amsgrad = False # default: False – whether to use the AMSGrad variant of this algorithm from the paper On the Convergence of Adam and Beyond
    elif opt == 'AdamW':
        lr = 0.001 # default: 1e-3 – learning rate
        betas = (0.9, 0.999) # default: (0.9, 0.999) – coefficients used for computing running averages of gradient and its square
        eps = 1e-8 # default: 1e-8 – term added to the denominator to improve numerical stability
        weight_decay = 1e-2 # default: 1e-2 – weight decay coefficient
        amsgrad = False # default: False – whether to use the AMSGrad variant of this algorithm from the paper On the Convergence of Adam and Beyond 
    elif opt == 'Adamax':
        lr = 2e-3 # default: 2e-3 – learning rate
        betas = (0.9, 0.999) # default: (0.9, 0.999) – coefficients used for computing running averages of gradient and its square
        eps = 1e-8 # default: 1e-8 – term added to the denominator to improve numerical stability
        weight_decay = 0 # default: 0 – weight decay (L2 penalty) 
    elif opt == 'ASGD':
        lr = 0.01 # default: 1e-2 – learning rate 
        lambd = 1e-4 # default: 1e-4 – decay term
        alpha = 0.75 # default: 0.75 – power for eta update
        t0 = 1e6 # default: 1e6 – point at which to start averaging
        weight_decay = 0 # default: 0 – weight decay
    elif opt == 'RMSprop':
        lr = 1e-2 # default: 1e-2 – learning rate
        momentum = 0 # default: 0 – momentum factor
        alpha = 0.99 # default: 0.99 – smoothing constant
        eps = 1e-8 # default: 1e-8 – term added to the denominator to improve numerical stability
        centered = False # default: False – if True, compute the centered RMSProp, the gradient is normalized by an estimation of its variance
        weight_decay = 0 # default: 0 – weight decay (L2 penalty)
    elif opt == 'Rprop':
        lr = 1e-2 # default: 1e-2 – learning rate
        etas = (0.5, 1.2) # default: (0.5, 1.2) – pair of (etaminus, etaplis), that are multiplicative increase and decrease factors
        step_sizes = (1e-6, 50) # default: (1e-6, 50) – a pair of minimal and maximal allowed step sizes 
    elif opt == 'SGD':
        lr = 0.01 # – learning rate
        momentum = 0.9 # default: 0 – momentum factor 
        weight_decay = 0.05 # default: 0 – weight decay (L2 penalty) 
        dampening = 0 # default: 0 – dampening for momentum
        nesterov = False # default: False – enables Nesterov momentum 
    
    sch = 'CosineAnnealingLR'
    if sch == 'StepLR':
        step_size = epochs // 5 # – Period of learning rate decay.
        gamma = 0.5 # – Multiplicative factor of learning rate decay. Default: 0.1
        last_epoch = -1 # – The index of last epoch. Default: -1.
    elif sch == 'MultiStepLR':
        milestones = [60, 120, 150] # – List of epoch indices. Must be increasing.
        gamma = 0.1 # – Multiplicative factor of learning rate decay. Default: 0.1.
        last_epoch = -1 # – The index of last epoch. Default: -1.
    elif sch == 'ExponentialLR':
        gamma = 0.99 #  – Multiplicative factor of learning rate decay.
        last_epoch = -1 # – The index of last epoch. Default: -1.
    elif sch == 'CosineAnnealingLR':
        T_max = 50 # – Maximum number of iterations. Cosine function period.
        eta_min = 0.00001 # – Minimum learning rate. Default: 0.
        last_epoch = -1 # – The index of last epoch. Default: -1.  
    elif sch == 'ReduceLROnPlateau':
        mode = 'min' # – One of min, max. In min mode, lr will be reduced when the quantity monitored has stopped decreasing; in max mode it will be reduced when the quantity monitored has stopped increasing. Default: ‘min’.
        factor = 0.1 # – Factor by which the learning rate will be reduced. new_lr = lr * factor. Default: 0.1.
        patience = 10 # – Number of epochs with no improvement after which learning rate will be reduced. For example, if patience = 2, then we will ignore the first 2 epochs with no improvement, and will only decrease the LR after the 3rd epoch if the loss still hasn’t improved then. Default: 10.
        threshold = 0.0001 # – Threshold for measuring the new optimum, to only focus on significant changes. Default: 1e-4.
        threshold_mode = 'rel' # – One of rel, abs. In rel mode, dynamic_threshold = best * ( 1 + threshold ) in ‘max’ mode or best * ( 1 - threshold ) in min mode. In abs mode, dynamic_threshold = best + threshold in max mode or best - threshold in min mode. Default: ‘rel’.
        cooldown = 0 # – Number of epochs to wait before resuming normal operation after lr has been reduced. Default: 0.
        min_lr = 0 # – A scalar or a list of scalars. A lower bound on the learning rate of all param groups or each group respectively. Default: 0.
        eps = 1e-08 # – Minimal decay applied to lr. If the difference between new and old lr is smaller than eps, the update is ignored. Default: 1e-8.
    elif sch == 'CosineAnnealingWarmRestarts':
        T_0 = 50 # – Number of iterations for the first restart.
        T_mult = 2 # – A factor increases T_{i} after a restart. Default: 1.
        eta_min = 1e-6 # – Minimum learning rate. Default: 0.
        last_epoch = -1 # – The index of last epoch. Default: -1. 
    elif sch == 'WP_MultiStepLR':
        warm_up_epochs = 10
        gamma = 0.1
        milestones = [125, 225]
    elif sch == 'WP_CosineLR':
        warm_up_epochs = 20

重新训练启动日志输出如下所示:

 整体的资源占用可以看到还是很低的,如下所示:

 等到模型训练完成后再来看下实际效果,感兴趣的话都可以自己尝试实践一下。后面可以考虑将本文中的超轻量级的模型应用到实际项目开发过程中。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/814819.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Pytorch深度学习-----神经网络之卷积层用法详解

系列文章目录 PyTorch深度学习——Anaconda和PyTorch安装 Pytorch深度学习-----数据模块Dataset类 Pytorch深度学习------TensorBoard的使用 Pytorch深度学习------Torchvision中Transforms的使用(ToTensor,Normalize,Resize ,Co…

交通运输安全大数据分析解决方案

当前运输市场竞争激烈,道路运输企业受传统经营观念影响,企业管理者安全意识淡薄,从业人员规范化、流程化的管理水平较低,导致制度规范在落实过程中未能有效监督与管理,执行过程中出现较严重的偏差,其营运车…

【C++入门到精通】C++入门 —— 类和对象(初始化列表、Static成员、友元、内部类、匿名对象)

目录 一、初始化列表 ⭕初始化列表概念 ⭕初始化列表的优点 ⭕使用场景 ⭕explicit关键字 二、Static成员 ⭕Static成员概念 🔴静态数据成员: 🔴静态函数成员: ⭕使用静态成员的优点 ⭕使用静态成员的注意事项 三、友…

国际化警告Fall back to translate ‘creator‘ key with ‘zn‘ locale.

发现是自己粗心写错了一个单词 这个需要改成zh

OC对象内存布局与isa指针

文章目录 一、Objective-C的本质二、一个objc对象如何进行内存布局?考虑父类的情况三、一个objc对象的isa指针指向什么?有什么作用四、objc对象的类方法和实例方法有什么本质区别和联系? 一、Objective-C的本质 Objc的底层实现是C\C代码&…

微信小程序tab加列表demo

一、效果 代码复制即可使用,记得把图标替换成个人工程项目图片。 微信小程序开发经常会遇到各种各样的页面组合,本demo为list列表与tab组合,代码如下: 二、json代码 {"usingComponents": {},"navigationStyle&q…

goctl template一定制化服务配置生成

官网介绍: 模板(Template)是数据驱动生成的基础,所有的代码(rest api、rpc、model、docker、kube)生成都会依赖模板, 默认情况下,模板生成器会选择内存中的模板进行生成&#xff0c…

easyui实用点

easyui实用点 1.下拉框(input框只能选不能手动输入编辑) data-options"editable:false"//不可编辑2.日期框,下拉框,文本框等class class"easyui-datebox"//不带时分秒 class"easyui-datetimebox"…

口碑+丨香港邮政联合极智嘉建立全港首个机器人邮包分拣系统

近日,香港邮政与全球仓储机器人引领者极智嘉(Geek)在其中央邮件中心联手建立全港首个机器人包裹分拣系统。该全新系统采用极智嘉分运结合解决方案,每小时可处理达1,000个邮包,助力香港邮政利用创新科技简化邮包分拣流程、提升工作效率&#x…

allwinner 全志RS485调试,GPIO状态与万用表测量不同

全志RS485调试 思路:UART驱动中已经将485流控功能加进去了,所以我们只需要根据硬件原理图配置一下485脚的GPIO就行了。 硬件原理图: 将UART3 UART4的RTS脚配置为485流控脚就行, RX和TX不需要配置,在pinctrol已经配置好…

热风梳C22.2 NO.3亚马逊加拿大审核标准

加拿大是目前亚马逊所有站点中,商业规模大、发展势头迅猛的站点之一。亚马逊加拿大站每月吸引近1600万访客。其优势在于在加拿大,目前平台的竞争较小,商家容易出单。既然加拿大站有这么多优势,那产品上架需要有哪些检测认证合规方…

[MAUI 项目实战] 手势控制音乐播放器: 手势交互

原理 定义一个拖拽物,和它拖拽的目标,拖拽物可以理解为一个平底锅(pan),拖拽目标是一个坑(pit),当拖拽物进入坑时,拖拽物就会被吸附在坑里。可以脑补一下下图&#xff1…

腾讯地图点标记加调用

先看效果 PHP代码 <?phpnamespace kds_addons\edata\controller;use think\addons\Controller; use think\Db;class Maps extends Controller {// 经纬度计算面积function calculate_area($points){$totalArea 0;$numPoints count($points);if ($numPoints > 2) {…

Qt、Qt Creator下载、安装

一、Qt、Qtcreator简介 Qt是一个跨平台应用开发框架。 Qt Creator是一个跨平台的集成开发环境&#xff08;IDE&#xff09;&#xff0c;集成了Qt所提供的功能&#xff0c;可以单独下载使用&#xff0c;也可以结合Qt组合使用。 二、下载 下载地址&#xff1a;https://downloa…

2023 ChinaJoy | 移远通信携手高通,共创数字娱乐新体验

当前&#xff0c; 5G、AI、大数据等智能创新技术正以惊人的速度蔓延至越来越多的领域&#xff0c;从智能家居、智能交通、智能医疗到智能制造&#xff0c;改变了我们的工作和生活方式。 而在数字娱乐领域&#xff0c;智能创新技术也展现出了巨大的潜力。作为全球领先的物联网整…

13.5.5 【Linux】其他相关文件

除了前一小节谈到的 /etc/securetty 会影响到 root 可登陆的安全终端机&#xff0c; /etc/nologin 会影响到一般使用者是否能够登陆的功能之外&#xff0c;我们也知道 PAM 相关的配置文件在 /etc/pam.d &#xff0c;说明文档在 /usr/share/doc/pam-&#xff08;版本&#xff09…

牛客网Verilog刷题——VL47

牛客网Verilog刷题——VL47 题目答案 题目 实现4bit位宽的格雷码计数器。 电路的接口如下图所示&#xff1a; 输入输出描述&#xff1a; 信号类型输入/输出位宽描述clkwireIntput1时钟信号rst_nwireIntput1异步复位信号&#xff0c;低电平有效gray_outregOutput4输出格雷码计数…

管理ceph集群

文章目录 ceph的常用命令查看集群状态查看pg的状态查看mon节点状态查看osd的通用命令查看osd的容量查看osd池写入文件测试查看池的属性查看文件映射过程 添加磁盘删除磁盘 ceph的常用命令 查看集群状态 ceph osd pool application enable pool-name rbd #将池启用rbd功能 ceph…

Java集合框架-List、Set、Map

一、Java集合框架概述&#xff1a; 1.1 Collection接口继承树 JDK提供的集合API位于java.util包内。 Map接口继承树 1.2 Collection接口方法 Collection 接口 Collection 接口是 List、Set 和 Queue 接口的父接口&#xff0c;该接口里定义的方法既可用于操作 Set 集合&#…

Matlab进阶绘图第24期—悬浮柱状图

悬浮柱状图是一种特殊的柱状图。 与常规柱状图相比&#xff0c;悬浮柱状图可以通过悬浮的矩形展示最小值到最大值的范围&#xff08;或其他范围表达&#xff09;&#xff0c;因此在多个领域得到应用。 本文使用自己制作的Floatingbar小工具进行悬浮柱状图的绘制&#xff0c;先…