Attention U-Net:Learning Where to Look for the Pancreas论文总结和代码实现

news2024/11/18 10:29:57

论文: https://arxiv.org/abs/1804.03999

中文版:https://blog.csdn.net/hhw999/article/details/110134398

源码: https://github.com/ozan-oktay/Attention-Gated-Networks

目录

一、论文背景和出发点

二、创新点

三、Attention U-Net的具体实现

四、实验

五、结论


一、论文背景和出发点

本文提出了一种用于医学成像的新型注意门(AG)模型,该模型可以自动学习聚焦于不同形状和大小的目标结构。

使用AG训练的模型可以不显著地学习输入图像中需要抑制的不相关区域,同时突出对特定任务有用的显著特征。这样就无需再使用级联卷积神经网络(CNNs)的显式外部组织/器官定位模块。

AGs可以很容易地集成到标准的CNN架构中,如U-Net模型,只需最少的计算开销,同时提高了模型的灵敏度和预测精度

二、创新点

1. 提出了基于网格的门控,使注意系数更聚焦于局部区域。该方法可以用于密集预测。

2. 提出的soft-attention技术是第一个用于医学成像任务的前馈CNN模型。提出的注意力门可以替代图像分类和图像分割框架中的外部器官定位模型中使用的hard-attention方法。

3. 提出了一种对标准U-Net模型的扩展,在不需要复杂启发式的情况下提高模型对前景像素的灵敏度。

AGs中的特征选择性是通过使用在较粗尺度上提取的上下文信息(门控)来实现的。 注意力门(AG)通过跳跃连接的方式过滤被传播的特征。

三、Attention U-Net的具体实现

1. AG

所提出的加性注意门(AG)示意图,如下:

目的:抑制不相关背景区域的特征响应,突出IOU区域。

方法:输入特征(x^l)使用在AG中计算的注意系数(α)进行缩放。

步骤:首先,对输入特征x^l1x1的卷积操作,同时也对与x^l同一层的下采样特征g也做1x1的卷积操作,然后,将卷积后的两个输出特征相加,将相加结果做relu激活,再然后,对激活结果做1x1x1的卷积操作,将特征图通道数变换为1(原文中说这里是线性变换??),其次,对线性变换结果做sigmoid激活,再进行resample(整形,源码中没有这一步),得到一个与原特征大小一致1维的权重矩阵\alpha,最后,权重矩阵\alpha与输入特征x^l相乘,返回一个新的特征图\hat{x}^l。对应算子公式如下:

                                                                         {\hat{x}^l_{ic}}=x^l_{ic} \cdot \alpha ^l_i

其中,g为与x^l同一层下采样得到的特征,W_{x} \in \mathbb{R}^{F_{t} \times F_{int}}W_{g} \in \mathbb{R}^{F_{g} \times F_{int}}是1x1的卷积操作,b_gb_\Psi是偏置项,\psi ^T是1x1x1的卷积操作,\sigma_2(xi,c)是softmax激活函数,\sigma_1是sigmoid激活函数,\alpha^l_i权重矩阵

详情可见作者给出的2D unet源码

import numpy as np
import torch
import torch.nn as nn
from torchsummary import summary


class Attention_block(nn.Module):
    def __init__(self, F_g, F_l, F_int):
        super(Attention_block, self).__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )

        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )

        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)

        return x * psi


class conv_block(nn.Module):
    def __init__(self, ch_in, ch_out):
        super(conv_block, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)


class up_conv(nn.Module):
    def __init__(self, ch_in, ch_out):
        super(up_conv, self).__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.up(x)


class AttU_Net(nn.Module):
    def __init__(self, n_channels=3, n_classes=1, scale_factor=1):
        super(AttU_Net, self).__init__()
        filters = np.array([64, 128, 256, 512, 1024])
        filters = filters // scale_factor
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.scale_factor = scale_factor
        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.Conv1 = conv_block(ch_in=n_channels, ch_out=filters[0])
        self.Conv2 = conv_block(ch_in=filters[0], ch_out=filters[1])
        self.Conv3 = conv_block(ch_in=filters[1], ch_out=filters[2])
        self.Conv4 = conv_block(ch_in=filters[2], ch_out=filters[3])
        self.Conv5 = conv_block(ch_in=filters[3], ch_out=filters[4])

        self.Up5 = up_conv(ch_in=filters[4], ch_out=filters[3])
        self.Att5 = Attention_block(F_g=filters[3], F_l=filters[3], F_int=filters[2])
        self.Up_conv5 = conv_block(ch_in=filters[4], ch_out=filters[3])

        self.Up4 = up_conv(ch_in=filters[3], ch_out=filters[2])
        self.Att4 = Attention_block(F_g=filters[2], F_l=filters[2], F_int=filters[1])
        self.Up_conv4 = conv_block(ch_in=filters[3], ch_out=filters[2])

        self.Up3 = up_conv(ch_in=filters[2], ch_out=filters[1])
        self.Att3 = Attention_block(F_g=filters[1], F_l=filters[1], F_int=filters[0])
        self.Up_conv3 = conv_block(ch_in=filters[2], ch_out=filters[1])

        self.Up2 = up_conv(ch_in=filters[1], ch_out=filters[0])
        self.Att2 = Attention_block(F_g=filters[0], F_l=filters[0], F_int=filters[0] // 2)
        self.Up_conv2 = conv_block(ch_in=filters[1], ch_out=filters[0])

        self.Conv_1x1 = nn.Conv2d(filters[0], n_classes, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # encoding path
        x1 = self.Conv1(x)

        x2 = self.Maxpool(x1)
        x2 = self.Conv2(x2)

        x3 = self.Maxpool(x2)
        x3 = self.Conv3(x3)

        x4 = self.Maxpool(x3)
        x4 = self.Conv4(x4)

        x5 = self.Maxpool(x4)
        x5 = self.Conv5(x5)

        # decoding + concat path
        d5 = self.Up5(x5)
        x4 = self.Att5(g=d5, x=x4)
        d5 = torch.cat((x4, d5), dim=1)
        d5 = self.Up_conv5(d5)

        d4 = self.Up4(d5)
        x3 = self.Att4(g=d4, x=x3)
        d4 = torch.cat((x3, d4), dim=1)
        d4 = self.Up_conv4(d4)

        d3 = self.Up3(d4)
        x2 = self.Att3(g=d3, x=x2)
        d3 = torch.cat((x2, d3), dim=1)
        d3 = self.Up_conv3(d3)

        d2 = self.Up2(d3)
        x1 = self.Att2(g=d2, x=x1)
        d2 = torch.cat((x1, d2), dim=1)
        d2 = self.Up_conv2(d2)

        d1 = self.Conv_1x1(d2)

        return d1

if __name__ == '__main__':
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    net = AttU_Net().to(device)
    # 打印网络结构和参数
    summary(net, (3, 224, 224))

这里增加了主函数,方便观察模型结构。

2. 将AG整合到UNet

目的:消除跳跃连接产生的不相关和嘈杂的响应(加强特征和去噪)。

方法:在每次上采样拼接操作之前,添加AG模块。AGs在前向传播和反向传播时通过对神经元的激活操作,减少背景区域产生的权重,加强对IOU区域的提取。

作者对l-1层卷积参数的更新方式的过程进行了推导,对应推导公式如下:

四、实验

数据集:已标注的150例胃癌患者腹部3D CT扫描(CT-150)图像、NIH-TCIA(胰脏数据集)。

训练分配比例:进行训练(120)和测试(30);训练(30)和测试(120)。

评估指标:dice score、surface to surface distance(s2s)。

实验1:

由上图可见,胰腺预测的结果表明,注意力门(AGs)通过提高模型的表达能力(AGs提高了前景区域的提取率)来提高recall值。

实验2:

 各种当前较为先进的CT胰腺分割模型的结果,与att u-net相比,att u-net有显著提升。

预测效果:

五、结论

提出了一种新的用于医学图像分割的注意力门控模型。该方法消除了使用额外目标定位模型的必要性。所提出的方法是通用的和模块化的,因此它可以很容易地应用于图像分类和回归问题,如在自然图像分析和机器翻译的例子。实验结果表明,所提出的AGs对组织/器官的识别和定位非常有利。对于可变的小尺寸器官,如胰腺,这一点尤其正确,而对于全局的分类任务,预期也会有类似的行为。
 

参考博文:图像分割UNet系列------Attention Unet详解

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/660750.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【5G MAC】5G中传输块(TBS)大小的计算方式

博主未授权任何人或组织机构转载博主任何原创文章,感谢各位对原创的支持! 博主链接 本人就职于国际知名终端厂商,负责modem芯片研发。 在5G早期负责终端数据业务层、核心网相关的开发工作,目前牵头6G算力网络技术标准研究。 博客…

【学习日记2023.6.18】之 分布式缓存redis持久化_redis主从_reids哨兵_redis分片集群

文章目录 分布式缓存1. Redis持久化1.1 RDB持久化1.1.1 执行时机1.1.2 RDB原理1.1.3 小结 1.2 AOF持久化1.2.1 AOF原理1.2.2 AOF配置1.2.3 AOF文件重写 1.3 RDB与AOF对比 2 Redis主从2.1 搭建主从架构2.1.1 准备实例和配置2.1.2 启动2.1.3 开启主从关系2.1.4 测试 2.2 主从数据…

计算服务资源调度管理

文章目录 前言总体架构“ULT”和“KLT”抽象“内核”“容器”“虚容器” 内存抽象虚拟存储(容器调用) 多机器调度 前言 今天复习了一下操作系统,系统过了一下,感觉还有点时间,那么顺便来讨论一下,关于我的…

.maloxx勒索病毒数据怎么处理|数据解密恢复,malox/mallox

导语: 随着科技的快速发展,数据成为了企业和个人不可或缺的财富。然而,网络安全威胁也日益增多,其中Mallox勒索病毒家族的最新变种.maloxx勒索病毒的出现给我们带来了巨大的困扰。但不要担心!91数据恢复研究院将为您揭…

一、Docker介绍

学习参考:尚硅谷Docker实战教程、Docker官网、其他优秀博客(参考过的在文章最后列出) 目录 前言一、Docker是什么?二、Docker能干撒?三、容器虚拟化技术 和 虚拟机有啥区别?1.虚拟机2.容器虚拟化技术3.对比 四、Docker组成4.1 镜像…

python自动化办公——定制化将电子签名批量签写到PDF文件

python自动化办公——定制化将电子签名批量签写到PDF文件 文章目录 python自动化办公——定制化将电子签名批量签写到PDF文件1、安装依赖2、需求分析3、代码 1、安装依赖 首先需要下载所需要的库 pip install pdf2image pip install img2pdf pip install opencv-python此外还…

【工作记录】基于可视化爬虫spiderflow实战天气数据爬取@20230618

前言 之前写过一篇关于可视化爬虫spiderflow的文章,介绍了基本语法并实战了某校园新闻数据的爬取。 还有一篇文章介绍了基于docker-compose快速部署spiderflow的过程,需要部署的话可参考该文章。 文章链接如下: 可视化爬虫框架spiderflow入门及实战【…

基于SpringBoot+Vue+微信小程序的电影平台

✌全网粉丝20W,csdn特邀作者、博客专家、CSDN新星计划导师、java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取项目下载方式🍅 一、项目背景介绍: 研究背景:…

【通过Data Studio连接openGauss---快速入门】

【通过Data Studio连接openGauss---快速入门】 🔻 一、访问openGauss🔰 1.1 确认连接信息(单节点)🔰 1.2 使用gsql访问openGauss(本地连接数据库)🔰 1.3 使用gsql访问openGauss&…

多道程序设计(操作系统)

目录 1 单道程序设计的缺点 2 多道程序设计的提出 3 多道程序设计的问题 多道程序设计目标: 多道程序设计是操作系统所采用最基本、最重要的技术,其根本目标是提高整个计算机系统的效率。衡量系统效率有一个尺度,那就是吞吐量。 提高系统…

clDice-一种新的分割标准-能够促进管状结构分割的连接性

clDice-a Novel Topology-Preserving Loss Function for Tubular StructureSegmentation论文总结 论文:clDice-A Novel Topology-Preserving Loss Function for Tubular Structure 源码:GitHub - jocpae/clDice 目录 一、论文背景和出发点 二、创新点 …

动态规划III (买股票-121、122、123、188、309)

CP121 买股票的最佳时机 题目描述: 给定一个数组 prices ,它的第 i 个元素 prices[i] 表示一支给定股票第 i 天的价格。你只能选择 某一天 买入这只股票,并选择在 未来的某一个不同的日子 卖出该股票。设计一个算法来计算你所能获取的最大利…

Advanced-C.04.函数

函数 函数的定义 包括两个部分,“函数头"和"函数体” 返回值类型 函数名(形参1,形参2,...)//函数头{}//函数体 函数类型决定返回值类型,执行函数需要调用 函数的返回值和参数可以是任何类型,包括空类型!!函…

Android adb shell命令捕获systemtrace

Android adb shell命令捕获systemtrace (1)抓取trace文件: adb shell perfetto -o /data/misc/perfetto-traces/trace_file.perfetto-trace -t 20s sched freq idle am wm gfx view binder_driver hal dalvik camera input res memory -t 时长,20s&a…

Java学习笔记23——集合进阶

集合进阶 集合进阶CollectionCollection集合常用方法Collection集合的遍历Iterator中的常用方法集合的使用步骤 List集合概述和特点List集合的特点List集合的特有方法并发修改异常ListIterator 列表迭代器常用方法增强for循环 数据结构栈队列数组链表 Set集合Set特点实现类Hash…

Presto(Trino)的逻辑执行计划和Fragment生成过程

文章目录 1. 前言2. 从SQL提交到Fragment计划生成全过程2.1 Statement生成2.2 对结构化的Statement进行分析2.3 生成未优化的逻辑执行计划2.4 基于Visitor模型对逻辑执行计划进行优化2.4.1 Visitor模型介绍2.4.2 Presto中常见的逻辑执行计划优化器常规OptimizerIterativeOptimi…

阿里月薪23k软件测试工程师:必备的6大技能(建议收藏)

随着软件开发行业的日益发展,岗位需求量和行业薪资都不断增长,想要入行的人也是越来越多,但不知道从哪里下手,今天,就给大家分享一下,软件测试行业都有哪些必会的方法和技术知识点,作为小白该从…

EmGU(4.7) 和C#中特征检测算法详解集合

C#联合Emgu实现计算机视觉任务(特征提取篇) 文章目录 C#联合Emgu实现计算机视觉任务(特征提取篇)前言一、Emgu库中特征提取有哪些类函数?二、特征提取函数1.AgastFeatureDetector类2.AKAZE 类3.FastFeatureDetector类4…

Docker部署(2)——实现两个容器互相访问并运行项目

一、拉取MySQL镜像,并启动镜像对应的容器 由于上一篇文章实现了拉取jdk8的环境,同时将jar包打成了一个镜像。但是要想真正的把项目运行起来(此处仅以单体项目为例)还需要MySQL的容器提供数据支持(当然这里面方法有多种…

深蓝学院C++基础与深度解析笔记 第 4 章 表达式

第 4 章 表达式 一、表达式基础 A、表达式: 由一到多个操作数组成&#xff0c;可以求值并 ( 通常会 ) 返回求值结果: #include <iostream> int main(){int x;x 3; }最基本的表达式&#xff1a;变量、字面值通常来说&#xff0c;表达式会包含操作符&#xff08;运算符…