U-Net网络模型改进(添加通道与空间注意力机制)---亲测有效,指标提升

news2025/1/17 2:49:23

U-Net网络模型(注意力改进版本)
这一段时间做项目用到了U-Net网络模型,但是原始的U-Net网络还有很大的改良空间,在卷积下采样的过程中加入了通道注意力和空间注意力 。

常规的U-net模型如下图:
在这里插入图片描述
红色箭头为可以添加的地方:即下采样之间。
在这里插入图片描述

通道空间注意力是一个即插即用的注意力模块(如下图):
在这里插入图片描述
代码加入之后对于分割效果是有提升的:(代码如下)

CBAM代码:

class ChannelAttentionModule(nn.Module):
    def __init__(self, channel, ratio=16):
        super(ChannelAttentionModule, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.shared_MLP = nn.Sequential(
            nn.Conv2d(channel, channel // ratio, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(channel // ratio, channel, 1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avgout = self.shared_MLP(self.avg_pool(x))
        maxout = self.shared_MLP(self.max_pool(x))
        return self.sigmoid(avgout + maxout)

class SpatialAttentionModule(nn.Module):
    def __init__(self):
        super(SpatialAttentionModule, self).__init__()
        self.conv2d = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=7, stride=1, padding=3)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avgout = torch.mean(x, dim=1, keepdim=True)
        maxout, _ = torch.max(x, dim=1, keepdim=True)
        out = torch.cat([avgout, maxout], dim=1)
        out = self.sigmoid(self.conv2d(out))
        return out

class CBAM(nn.Module):
    def __init__(self, channel):
        super(CBAM, self).__init__()
        self.channel_attention = ChannelAttentionModule(channel)
        self.spatial_attention = SpatialAttentionModule()

    def forward(self, x):
        out = self.channel_attention(x) * x
        out = self.spatial_attention(out) * out
        return out

网络模型结合之后代码:

class conv_block(nn.Module):
    def __init__(self,ch_in,ch_out):
        super(conv_block,self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size=3,stride=1,padding=1,bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch_out, ch_out, kernel_size=3,stride=1,padding=1,bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True)
        )

    def forward(self,x):
        x = self.conv(x)
        return x

class up_conv(nn.Module):
    def __init__(self,ch_in,ch_out):
        super(up_conv,self).__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=1,padding=1,bias=True),
		    nn.BatchNorm2d(ch_out),
			nn.ReLU(inplace=True)
        )

    def forward(self,x):
        x = self.up(x)
        return x

class U_Net_v1(nn.Module):   #添加了空间注意力和通道注意力
    def __init__(self,img_ch=3,output_ch=2):
        super(U_Net_v1,self).__init__()
        
        self.Maxpool = nn.MaxPool2d(kernel_size=2,stride=2)

        self.Conv1 = conv_block(ch_in=img_ch,ch_out=64) #64
        self.Conv2 = conv_block(ch_in=64,ch_out=128)  #64 128
        self.Conv3 = conv_block(ch_in=128,ch_out=256) #128 256
        self.Conv4 = conv_block(ch_in=256,ch_out=512) #256 512
        self.Conv5 = conv_block(ch_in=512,ch_out=1024) #512 1024

        self.cbam1 = CBAM(channel=64)
        self.cbam2 = CBAM(channel=128)
        self.cbam3 = CBAM(channel=256)
        self.cbam4 = CBAM(channel=512)

        self.Up5 = up_conv(ch_in=1024,ch_out=512)  #1024 512
        self.Up_conv5 = conv_block(ch_in=1024, ch_out=512)  

        self.Up4 = up_conv(ch_in=512,ch_out=256)  #512 256
        self.Up_conv4 = conv_block(ch_in=512, ch_out=256)  
        
        self.Up3 = up_conv(ch_in=256,ch_out=128)  #256 128
        self.Up_conv3 = conv_block(ch_in=256, ch_out=128) 
        
        self.Up2 = up_conv(ch_in=128,ch_out=64) #128 64
        self.Up_conv2 = conv_block(ch_in=128, ch_out=64)  

        self.Conv_1x1 = nn.Conv2d(64,output_ch,kernel_size=1,stride=1,padding=0)  #64


    def forward(self,x):
        # encoding path
        x1 = self.Conv1(x)
        x1 = self.cbam1(x1) + x1

        x2 = self.Maxpool(x1)
        x2 = self.Conv2(x2)
        x2 = self.cbam2(x2) + x2
        
        x3 = self.Maxpool(x2)
        x3 = self.Conv3(x3)
        x3 = self.cbam3(x3) + x3

        x4 = self.Maxpool(x3)
        x4 = self.Conv4(x4)
        x4 = self.cbam4(x4) + x4

        x5 = self.Maxpool(x4)
        x5 = self.Conv5(x5)

        # decoding + concat path
        d5 = self.Up5(x5)
        d5 = torch.cat((x4,d5),dim=1)
        
        d5 = self.Up_conv5(d5)
        
        d4 = self.Up4(d5)
        d4 = torch.cat((x3,d4),dim=1)
        d4 = self.Up_conv4(d4)

        d3 = self.Up3(d4)
        d3 = torch.cat((x2,d3),dim=1)
        d3 = self.Up_conv3(d3)

        d2 = self.Up2(d3)
        d2 = torch.cat((x1,d2),dim=1)
        d2 = self.Up_conv2(d2)

        d1 = self.Conv_1x1(d2)

        return d1

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1287382.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

电表峰谷平是怎么分时间的?

电表的峰谷平时间是指电力公司根据电力需求的不同,将一天的时间划分为不同的时段,以此来确定不同时间段内的电费价格。这种不同时段对应不同电费价格的制度,旨在更好地平衡电力供需,促进能源的高效利用。 首先,我们来了…

记录一下Mac配置SpringBoot开发环境

由于很多项目喜欢使用传统的 Java 8 进行开发,而且 Java 8 的稳定性也是经过长久考验的,我们接下来就尝试一下,在一台新的 Mac 中配置 Java 环境,并且开始创建 SpringBoot 项目。 首先,去 Oracle 官网下载 java8 JDK …

springboot详解Mybatis-Plus中分页插件PaginationInterceptor标红

1.问题描述 在springboot项目中,类中引用PaginationInterceptor,标红,如下图所示: 2.问题分析 可能是因为pom.xml中的配置原因,导致不支持PaginationInterceptor 3.解决问题 更换版本后 更换后,记得Rel…

开发的客户收到样品表示质量不如原供应商如何应对

有小伙伴问,在开发客户的过程当中,给客户寄了样品,客户说他的样品没有原来供应商的好怎么办? 这个问题我们来想一下,客户既然愿意把地址给我们,愿意去接你的样品,说明什么?说明客户…

【剑指offer|图解|位运算】训练计划VI+撞色搭配

🌈个人主页:聆风吟 🔥系列专栏:数据结构、剑指offer每日一练 🔖少年有梦不应止于心动,更要付诸行动。 文章目录 一. ⛳️训练计划VI(题目难度:中等)1.1 题目1.2 示例1.3 …

2024年猫罐头排行榜前十名有哪些?分享2024年猫罐头排行榜前10名

很多人家里的哈基米是不是吃猫粮吃腻了,或者猫猫平时不喜欢喝水,又或者看猫猫太瘦了想入手几款猫罐头但是又愁于不会选择。而且现在猫罐头风这么大不知道选什么好~ 作为从业5年的宠物医生,给你说猫罐头行业内幕。我告诫大家,选择…

Python网络爬虫环境的安装指南

网络爬虫是一种自动化的网页数据抓取技术,广泛用于数据挖掘、信息搜集和互联网研究等领域。Python作为一种强大的编程语言,拥有丰富的库支持网络爬虫的开发。本文将为你详细介绍如何在你的计算机上安装Python网络爬虫环境。 一、安装python开发环境 进…

SIT75179B,5V 供电, 10Mbps,全双工差分芯片/全双工 RS485/RS422 收发器

SIT75179B 是一款 4.5V~5.5V 供电全双工差分芯片,可完全满足 TIA/EIA-485/422 标准要求 的收发器。 SIT75179B 包括一个驱动器和一个接收器,两者均可独立传输信号。 SIT75179B 具有 1/8 负 载,允许 256 个 SIT75179B 收发器并…

【AIGC】AI作图最全提示词prompt集合(收藏级)

目录 一、正向和负向提示词 二、作图参数 你好,我是giszz. AI做图真是太爽了,解放生产力,发展生产力。 但是,你是不是也总疑惑,为什么别人的图,表现力那么丰富呢,而且指哪打哪,要…

陀螺仪防抖术语

陀螺仪防抖术语 fov 视场角 drift 零偏   MotionFusion即运动传感器的融合补偿,对陀螺仪、加速度计等运动测量器件的数据 进行预处理,通过标定和补偿,为防抖提供校准后的陀螺仪数据 ratio 系数 gyro 陀螺仪 calibration 校准 标定 DIS&…

深度优先搜索LeetCode979. 在二叉树中分配硬币

给你一个有 n 个结点的二叉树的根结点 root ,其中树中每个结点 node 都对应有 node.val 枚硬币。整棵树上一共有 n 枚硬币。 在一次移动中,我们可以选择两个相邻的结点,然后将一枚硬币从其中一个结点移动到另一个结点。移动可以是从父结点到…

游戏:火星孤征 - deliver us mars - 美图秀秀~~

今天水一篇,借着免费周下载了deliver us mars,玩下来截了好多图,就放这里了。 游戏没有难度,剧情也不难理解,美图到处都是,建模细节也是满满,值得一玩。 游戏中的 A.S.E是守卫飞行机器人&…

甘草书店:#7 2023年11月19日 星期日 波澜不惊的日子里稳步前行

前进,可以伴着惊涛骇浪,也可以波澜不惊。 几番沟通,多方协商之后,甘草书店硬装方案基本确定,近期开始施工。 书目选择方面也在逐步推进。 就像之前设想的,划分成企业经管和个人成长两大类的前提下&#…

三、DVP摄像头调试笔记(图片成像质量微调整,非ISP)

说明:当前调试仅仅用来测试和熟悉部分摄像头寄存器模式 一、图片成像方向控制,基本每个摄像头都会有上下左右翻转寄存器 正向图片 反向图片 二、设置成像数据成各种颜色,(黑白/原彩/黄色等等) 在寄存器书册描述中…

【实用+干货】如何使用Clickhouse搭建百亿级用户画像平台看这一篇就够了

背景 如果你是用户,当你使用抖音、小红书的时候,假如平台能根据你的属性、偏好、行为推荐给你感兴趣的内容,那就能够为你节省大量获取内容的时间。 如果你是商家,当你要进行广告投放的时候,假如平台推送的用户都是你潜…

感冒 发烧 咳嗽记录

感冒 风寒: 清鼻涕 热感冒: 细菌记录, 脓鼻涕. 咳嗽 先是清痰咳嗽, 后是浓痰,细菌感染, 白细胞噬菌体, 所以要补充蛋白质,维生素. 胸骨上窝 , 天突穴 ,后面上支气管的位置, 往下会变成左右两支,连接到肺部 普通咳嗽: 用哈气拍打背部的方式. 把痰去除. 吃点 盐酸氨溴索片 增加支…

5.2k Star!一个可视化全球实时天气开源项目!

大家好,本文给大家推荐一款全球实时天气开源项目:Earth。 项目简介 Earth 是一个可视化全球天气实况的项目。该项目以可视化的方式展示了全球的天气情况,提供了风、温度、相对湿度等多种天气数据,以及风、洋流和波浪的动画效果…

你好!哈希表【JAVA】

1.初识🎶🎶🎶 它基本上是由一个数组和一个哈希函数组成的。哈希函数将每个键映射到数组的特定索引位置,这个位置被称为哈希码。当我们需要查找一个键时,哈希函数会计算其哈希码并立即返回结果,因此我们可以…

it统一运维平台怎么样?有可以推荐的品牌吗?

随着互联网化,随着信息化的不断发展,企业IT系统的规模和复杂性也在日益增加。在这个背景下,IT统一运维平台就应用而生了。它以一种全面、集成的方式管理企业IT资源,从而提高效率、降低成本、改善服务,为企业提供更快更…

Java中的Future源码讲解

JAVA Future源码解析 文章目录 JAVA Future源码解析前言一、传统异步实现的弊端二、what is Future ?2.1 Future的基本概念2.2Future 接口方法解析2.2.1 取消任务执行cancel2.2.2 检索任务是否被取消 isCancelled2.2.3 检索任务是否完成 isDone2.2.3 检索任务计算结果 get 三、…