文章目录
- net.py
- 1.class Bottleneck:残差块
- 2.class ResNet:特征提取
- 3.class SRM:SR模块
- 4.class FAM:FIA模块
- 5.class CA:GCF模块
- 6.class SA:HA模块
- 7.class GCPANet:网络架构
- train.py
- test.py
论文:Global Context-Aware Progressive Aggregation Network for Salient Object Detection
论文链接:Global Context-Aware Progressive Aggregation Network for Salient Object Detection
代码链接:Github
net.py
1.class Bottleneck:残差块
class Bottleneck(nn.Module)
用于实现残差块。
class Bottleneck(nn.Module):
def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
#inplanes:输入通道数;planes:输出通道数;stride:步幅;downsample:下采样层;dilation:膨胀系数
super(Bottleneck, self).__init__()
#1×1卷积
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
#3×3卷积
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=(3*dilation-1)//2, bias=False, dilation=dilation)
self.bn2 = nn.BatchNorm2d(planes)
#1×1卷积
self.conv3 = nn.Conv2d(planes, planes*4, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes*4)
#下采样(若步幅不为1或输入通道数与目标通道数不匹配,则进行下采样)
self.downsample = downsample
def forward(self, x):
residual = x
#1×1卷积
out = F.relu(self.bn1(self.conv1(x)), inplace=True)
#3×3卷积
out = F.relu(self.bn2(self.conv2(out)), inplace=True)
#1×1卷积
out = self.bn3(self.conv3(out))
#若不能直接将x与特征残差连接,则需下采样
if self.downsample is not None:
residual = self.downsample(x)
#残差连接
return F.relu(out+residual, inplace=True)
2.class ResNet:特征提取
GCPANet模型使用
R
e
s
N
e
t
50
ResNet50
ResNet50作为特征提取器,
R
e
s
N
e
t
50
ResNet50
ResNet50共包含四个
B
l
o
c
k
Block
Block结构,每个
B
l
o
c
k
Block
Block中分别有3、4、6、3个
B
o
t
t
l
e
n
e
c
k
Bottleneck
Bottleneck。整体结构如下:
class ResNet(nn.Module):
def __init__(self):
super(ResNet, self).__init__()
#跟踪输入通道数
self.inplanes = 64
#conv1:7×7大小、输入通道3(RGB图像)、输出通道64、步长2、填充3
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(64)
#第一个残差层,对应conv_2
self.layer1 = self.make_layer( 64, 3, stride=1, dilation=1)
#第二个残差层,对应conv_3
self.layer2 = self.make_layer(128, 4, stride=2, dilation=1)
#第三个残差层,对应conv_4
self.layer3 = self.make_layer(256, 6, stride=2, dilation=1)
#第四个残差层,对应conv_5
self.layer4 = self.make_layer(512, 3, stride=2, dilation=1)
#权重初始化
self.initialize()
def make_layer(self, planes, blocks, stride, dilation):
downsample = None
#若步幅不为1或输入通道数与目标通道数不匹配,则进行下采样
if stride != 1 or self.inplanes != planes*4:
#使用1×1卷积和批量归一化进行下采样
downsample = nn.Sequential(nn.Conv2d(self.inplanes, planes*4, kernel_size=1, stride=stride, bias=False), nn.BatchNorm2d(planes*4))
#添加第一个残差块,使用Bottleneck结构(输入通道数、输出通道数、步长、下采样模块、膨胀系数)
layers = [Bottleneck(self.inplanes, planes, stride, downsample, dilation=dilation)]
#更新通道数,为原先四倍
self.inplanes = planes*4
#循环添加残差块
for _ in range(1, blocks):
layers.append(Bottleneck(self.inplanes, planes, dilation=dilation))
return nn.Sequential(*layers)
def forward(self, x):
#conv1,输出为112×112
out1 = F.relu(self.bn1(self.conv1(x)), inplace=True)
#conv2_x,输出为56×56
out1 = F.max_pool2d(out1, kernel_size=3, stride=2, padding=1)
out2 = self.layer1(out1)
#conv_3,输出为28×28
out3 = self.layer2(out2)
#conv_4,输出为14×14
out4 = self.layer3(out3)
#conv_5,输出为7×7
out5 = self.layer4(out4)
return out1, out2, out3, out4, out5
def initialize(self):
#加载预训练模型的权重,允许部分权重匹配(strict=False)
self.load_state_dict(torch.load('resnet50-19c8e357.pth'), strict=False)
3.class SRM:SR模块
class SRM(nn.Module)
实现自细化模块,用于将HA模块(一个)和FIA模块(三个)得到的特征图进一步细化和增强。
""" Self Refinement Module """
class SRM(nn.Module):
def __init__(self, in_channel):
super(SRM, self).__init__()
self.conv1 = nn.Conv2d(in_channel, 256, kernel_size=3, stride=1, padding=1)
self.bn1 = nn.BatchNorm2d(256)
self.conv2 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
def forward(self, x):
#先将输入特征压缩为256通道大小,再分别通过Batch Normalization、ReLU层
out1 = F.relu(self.bn1(self.conv1(x)), inplace=True)
#经过卷积运算转为512通道
out2 = self.conv2(out1)
#将前256通道作为权重,后256通道作为偏置0
w, b = out2[:, :256, :, :], out2[:, 256:, :, :]
#加权结合out1、w、b,并应用ReLU激活函数得到输出
return F.relu(w * out1 + b, inplace=True)
def initialize(self):
weight_init(self)
4.class FAM:FIA模块
class FAM(nn.Module)
定义特征交织聚合模块,用于融合低级特征、高级特征、上下文特征,从而产生具有全局感知的区分性和综合性特征。
""" Feature Interweaved Aggregation Module """
class FAM(nn.Module):
def __init__(self, in_channel_left, in_channel_down, in_channel_right):
#接受左、下、右三个方向的输入通道数(对应低级特征、高级特征、全局特征)
super(FAM, self).__init__()
#对低级特征f_l进行卷积、归一化
self.conv0 = nn.Conv2d(in_channel_left, 256, kernel_size=3, stride=1, padding=1)
self.bn0 = nn.BatchNorm2d(256)
#对高级特征f_h进行卷积、归一化
self.conv1 = nn.Conv2d(in_channel_down, 256, kernel_size=3, stride=1, padding=1)
self.bn1 = nn.BatchNorm2d(256)
#对全局特征f_g进行卷积、归一化
self.conv2 = nn.Conv2d(in_channel_right, 256, kernel_size=3, stride=1, padding=1)
self.bn2 = nn.BatchNorm2d(256)
self.conv_d1 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
self.conv_d2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
self.conv_l = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
self.conv3 = nn.Conv2d(256*3, 256, kernel_size=3, stride=1, padding=1)
self.bn3 = nn.BatchNorm2d(256)
def forward(self, left, down, right):
#依次将低级特征f_l、高级特征f_h、全局特征f_g卷积、归一化、ReLU激活,并压缩到256通道
left = F.relu(self.bn0(self.conv0(left)), inplace=True)
down = F.relu(self.bn1(self.conv1(down)), inplace=True)
right = F.relu(self.bn2(self.conv2(right)), inplace=True) #256
#上采样高级特征图
down_1 = self.conv_d1(down)
#对left特征图卷积,得到分割掩码w1
w1 = self.conv_l(left)
#检查高级特征图和低级特征图的空间维度,不匹配则使用线性插值调整高级特征图的大小.将分割掩码w1与高级特征图相乘并使用ReLU激活函数,得到f_{hl}
if down.size()[2:] != left.size()[2:]:
down_ = F.interpolate(down, size=left.size()[2:], mode='bilinear')
z1 = F.relu(w1 * down_, inplace=True)
else:
z1 = F.relu(w1 * down, inplace=True)
#将上采样后的高级特征图调整至与低级特征图相同的维度
if down_1.size()[2:] != left.size()[2:]:
down_1 = F.interpolate(down_1, size=left.size()[2:], mode='bilinear')
#将高级特征图与低级特征图相乘得到f_{lh}
z2 = F.relu(down_1 * left, inplace=True)
#上采样全局特征图
down_2 = self.conv_d2(right)
if down_2.size()[2:] != left.size()[2:]:
down_2 = F.interpolate(down_2, size=left.size()[2:], mode='bilinear')
#将全局特征图与低级特征图相乘得到f_{gl}
z3 = F.relu(down_2 * left, inplace=True)
#将三个结果cat
out = torch.cat((z1, z2, z3), dim=1)
#输入卷积层运算并返回
return F.relu(self.bn3(self.conv3(out)), inplace=True)
def initialize(self):
weight_init(self)
5.class CA:GCF模块
class CA(nn.Module)
对应模块
G
C
F
GCF
GCF,用于从
R
e
s
N
e
t
50
ResNet50
ResNet50提取的特征中捕获全局上下文信息,并输入到每个阶段的FIA模块。计算公式如下:
- f t o p f_{top} ftop:输入特征1。
- f g a p f_{gap} fgap:输入特征2。
class CA(nn.Module):
def __init__(self, in_channel_left, in_channel_down):
#in_channel_left:f_{top}通道数;in_channel_down:f_{gap}通道数
super(CA, self).__init__()
self.conv0 = nn.Conv2d(in_channel_left, 256, kernel_size=1, stride=1, padding=0)
self.bn0 = nn.BatchNorm2d(256)
self.conv1 = nn.Conv2d(in_channel_down, 256, kernel_size=1, stride=1, padding=0)
self.conv2 = nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0)
def forward(self, left, down):
#对f_{top}进行Conv+Batch Normlization+ReLU
left = F.relu(self.bn0(self.conv0(left)), inplace=True)
#平均池化,减少空间维度(H、W下降)
down = down.mean(dim=(2,3), keepdim=True)
#卷积+激活
down = F.relu(self.conv1(down), inplace=True)
#将输出值归一化到0-1之间
down = torch.sigmoid(self.conv2(down))
return left * down
def initialize(self):
weight_init(self)
6.class SA:HA模块
编码器顶层特征通常对于显著性目标检测是多余的,HA模块可利用空间和通道注意力机制来学习更多选择性和代表性的特征。计算公式:
代码中类SA仅获取
F
1
F1
F1,而
F
1
F1
F1与
f
f
f的计算由GCF模块(对应类CA)实现。
class SA(nn.Module):
def __init__(self, in_channel_left, in_channel_down):
super(SA, self).__init__()
self.conv0 = nn.Conv2d(in_channel_left, 256, kernel_size=3, stride=1, padding=1)
self.bn0 = nn.BatchNorm2d(256)
self.conv2 = nn.Conv2d(in_channel_down, 512, kernel_size=3, stride=1, padding=1)
def forward(self, left, down):
#left、down都是由ResNet提取的特征
#与SR模块相同操作
left = F.relu(self.bn0(self.conv0(left)), inplace=True) #256 channels
down_1 = self.conv2(down)
#检查down_1的空间尺寸是否与left相同.如果不同,则使用双线性插值调整down_1的尺寸.
if down_1.size()[2:] != left.size()[2:]:
down_1 = F.interpolate(down_1, size=left.size()[2:], mode='bilinear')
#与SR模块相同,分别获取权重w、b
w,b = down_1[:,:256,:,:], down_1[:,256:,:,:]
#得到F1
return F.relu(w*left+b, inplace=True)
def initialize(self):
weight_init(self)
7.class GCPANet:网络架构
class GCPANet(nn.Module)
定义了GCPANet的模型架构。
class GCPANet(nn.Module):
def __init__(self, cfg):
super(GCPANet, self).__init__()
self.cfg = cfg
#ResNet50:进行特征提取
self.bkbone = ResNet()
#GCF:初始化多个通道注意力模块(CA)、空间注意力模块(SA)用于特征加权
self.ca45 = CA(2048, 2048)
self.ca35 = CA(2048, 2048)
self.ca25 = CA(2048, 2048)
self.ca55 = CA(256, 2048)
self.sa55 = SA(2048, 2048)
#FIA:初始化特征交织聚合模块,用于处理不同层次的特征
self.fam45 = FAM(1024, 256, 256)
self.fam34 = FAM( 512, 256, 256)
self.fam23 = FAM( 256, 256, 256)
#SR:初始化自细化模块,用于对特征进行处理和提升
self.srm5 = SRM(256)
self.srm4 = SRM(256)
self.srm3 = SRM(256)
self.srm2 = SRM(256)
#四个卷积层,将特征图(256通道)映射为单通道输出
self.linear5 = nn.Conv2d(256, 1, kernel_size=3, stride=1, padding=1)
self.linear4 = nn.Conv2d(256, 1, kernel_size=3, stride=1, padding=1)
self.linear3 = nn.Conv2d(256, 1, kernel_size=3, stride=1, padding=1)
self.linear2 = nn.Conv2d(256, 1, kernel_size=3, stride=1, padding=1)
#初始化权重
self.initialize()
def forward(self, x):
#使用骨干网络ResNet提取多层次特征
out1, out2, out3, out4, out5_ = self.bkbone(x)
# GCF
out4_a = self.ca45(out5_, out5_)
out3_a = self.ca35(out5_, out5_)
out2_a = self.ca25(out5_, out5_)
# HA
out5_a = self.sa55(out5_, out5_)
out5 = self.ca55(out5_a, out5_)
#FIA+SR
out5 = self.srm5(out5)
out4 = self.srm4(self.fam45(out4, out5, out4_a))
out3 = self.srm3(self.fam34(out3, out4, out3_a))
out2 = self.srm2(self.fam23(out2, out3, out2_a))
#将四个阶段SR模块的输出线性插值,得到与原始图像有相同大小的特征图
out5 = F.interpolate(self.linear5(out5), size=x.size()[2:], mode='bilinear')
out4 = F.interpolate(self.linear4(out4), size=x.size()[2:], mode='bilinear')
out3 = F.interpolate(self.linear3(out3), size=x.size()[2:], mode='bilinear')
out2 = F.interpolate(self.linear2(out2), size=x.size()[2:], mode='bilinear')
#返回四张特征图
return out2, out3, out4, out5
def initialize(self):
if self.cfg.snapshot:
try:
self.load_state_dict(torch.load(self.cfg.snapshot))
except:
print("Warning: please check the snapshot file:", self.cfg.snapshot)
pass
else:
weight_init(self)
train.py
import sys
import datetime
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter
from data import dataset
from net import GCPANet
import logging as logger
from lib.data_prefetcher import DataPrefetcher
from lib.lr_finder import LRFinder
import numpy as np
import matplotlib.pyplot as plt
#设置日志文件标签和保存路径
TAG = "ours"
SAVE_PATH = "ours"
#配置日志记录的格式和输出文件
logger.basicConfig(level=logger.INFO, format='%(levelname)s %(asctime)s %(filename)s: %(lineno)d] %(message)s', datefmt='%Y-%m-%d %H:%M:%S', \
filename="train_%s.log"%(TAG), filemode="w")
#学习率更新策略
def get_triangle_lr(base_lr, max_lr, total_steps, cur, ratio=1., \
annealing_decay=1e-2, momentums=[0.95, 0.85]):
first = int(total_steps*ratio)
last = total_steps - first
min_lr = base_lr * annealing_decay
cycle = np.floor(1 + cur/total_steps)
x = np.abs(cur*2.0/total_steps - 2.0*cycle + 1)
if cur < first:
lr = base_lr + (max_lr - base_lr) * np.maximum(0., 1.0 - x)
else:
lr = ((base_lr - min_lr)*cur + min_lr*first - base_lr*total_steps)/(first - total_steps)
if isinstance(momentums, int):
momentum = momentums
else:
if cur < first:
momentum = momentums[0] + (momentums[1] - momentums[0]) * np.maximum(0., 1.-x)
else:
momentum = momentums[0]
return lr, momentum
#设置基本学习率、最大学习率和是否进行学习率查找的标志
BASE_LR = 1e-3
MAX_LR = 0.1
FIND_LR = False
#训练函数,参数为数据集、网络模型
def train(Dataset, Network):
#配置数据集参数
cfg = Dataset.Config(datapath='./data/DUTS', savepath=SAVE_PATH, mode='train', batch=8, lr=0.05, momen=0.9, decay=5e-4, epoch=30)
#创建数据集实例和数据加载器
data = Dataset.Data(cfg)
loader = DataLoader(data, batch_size=cfg.batch, shuffle=True, num_workers=8)
#初始化数据预取器并提高数据加载效率
prefetcher = DataPrefetcher(loader)
#创建模型、设为训练模式、转移到GPU
net = Network(cfg)
net.train(True)
net.cuda()
#根据参数名称将参数分为基础参数和头部参数
base, head = [], []
for name, param in net.named_parameters():
if 'bkbone' in name:
base.append(param)
else:
head.append(param)
#为基础参数和头部参数定义优化器
optimizer = torch.optim.SGD([{'params':base}, {'params':head}], lr=cfg.lr, momentum=cfg.momen, weight_decay=cfg.decay, nesterov=True)
#记录训练过程中的指标
sw = SummaryWriter(cfg.savepath)
#全局步数计数器
global_step = 0
db_size = len(loader)
#若启用学习率查找,执行查找测试并绘制结果
if FIND_LR:
lr_finder = LRFinder(net, optimizer, criterion=None)
lr_finder.range_test(loader, end_lr=50, num_iter=100, step_mode="exp")
plt.ion()
lr_finder.plot()
import pdb; pdb.set_trace()
#进行训练
for epoch in range(cfg.epoch):
prefetcher = DataPrefetcher(loader)
batch_idx = -1
#获取图像及掩模
image, mask = prefetcher.next()
while image is not None:
niter = epoch * db_size + batch_idx
#获取当前迭代的学习率和动量
lr, momentum = get_triangle_lr(BASE_LR, MAX_LR, cfg.epoch*db_size, niter, ratio=1.)
optimizer.param_groups[0]['lr'] = 0.1 * lr #for backbone
optimizer.param_groups[1]['lr'] = lr
optimizer.momentum = momentum
batch_idx += 1
global_step += 1
#获取模型输出
out2, out3, out4, out5 = net(image)
#计算各个特征图对应的损失值
loss2 = F.binary_cross_entropy_with_logits(out2, mask)
loss3 = F.binary_cross_entropy_with_logits(out3, mask)
loss4 = F.binary_cross_entropy_with_logits(out4, mask)
loss5 = F.binary_cross_entropy_with_logits(out5, mask)
#根据权重计算综合损失
loss = loss2*1 + loss3*0.8 + loss4*0.6 + loss5*0.4
optimizer.zero_grad()
loss.backward()
optimizer.step()
#绘制曲线
sw.add_scalar('lr' , optimizer.param_groups[0]['lr'], global_step=global_step)
sw.add_scalars('loss', {'loss2':loss2.item(), 'loss3':loss3.item(), 'loss4':loss4.item(), 'loss5':loss5.item(), 'loss':loss.item()}, global_step=global_step)
#每10个批次打印一次训练信息
if batch_idx % 10 == 0:
msg = '%s | step:%d/%d/%d | lr=%.6f | loss=%.6f | loss2=%.6f | loss3=%.6f | loss4=%.6f | loss5=%.6f'%(datetime.datetime.now(), global_step, epoch+1, cfg.epoch, optimizer.param_groups[0]['lr'], loss.item(), loss2.item(), loss3.item(), loss4.item(), loss5.item())
print(msg)
#格式化并打印当前的训练状态
logger.info(msg)
#获取下一批数据
image, mask = prefetcher.next()
#每10个epoch 或最后一个epoch 保存模型权重
if (epoch+1)%10 == 0 or (epoch+1)==cfg.epoch:
torch.save(net.state_dict(), cfg.savepath+'/model-'+str(epoch+1))
if __name__=='__main__':
train(dataset, GCPANet)
test.py
class Test(object):
def __init__(self, Dataset, datapath, Network):
## dataset
self.datapath = datapath.split("/")[-1]
print("Testing on %s"%self.datapath)
self.cfg = Dataset.Config(datapath = datapath, snapshot=sys.argv[1], mode='test')
self.data = Dataset.Data(self.cfg)
self.loader = DataLoader(self.data, batch_size=1, shuffle=True, num_workers=8)
## network
self.net = Network(self.cfg)
self.net.train(False)
self.net.cuda()
self.net.eval()
#计算模型准确度
def accuracy(self):
with torch.no_grad():
#初始化指标
mae, fscore, cnt, number = 0, 0, 0, 256
mean_pr, mean_re, threshod = 0, 0, np.linspace(0, 1, number, endpoint=False)
cost_time = 0
for image, mask, (H, W), maskpath in self.loader:
image, mask = image.cuda().float(), mask.cuda().float()
#记录开始时间并前向传播
start_time = time.time()
out2, out3, out4, out5 = self.net(image)
pred = torch.sigmoid(out2)
torch.cuda.synchronize()
end_time = time.time()
#计算前向传播所需时间,并更新总时间
cost_time += end_time - start_time
#计算MAE
cnt += 1
mae += (pred-mask).abs().mean()
#计算精确率、召回率
precision = torch.zeros(number)
recall = torch.zeros(number)
for i in range(number):
temp = (pred >= threshod[i]).float()
precision[i] = (temp*mask).sum()/(temp.sum()+1e-12)
recall[i] = (temp*mask).sum()/(mask.sum()+1e-12)
mean_pr += precision
mean_re += recall
fscore = mean_pr*mean_re*(1+0.3)/(0.3*mean_pr+mean_re+1e-12)
#每20批次打印MAE、F-score和每秒帧数(fps)
if cnt % 20 == 0:
fps = image.shape[0] / (end_time - start_time)
print('MAE=%.6f, F-score=%.6f, fps=%.4f'%(mae/cnt, fscore.max()/cnt, fps))
#计算整体FPS并打印最终结果(数据集路径、MAE 和 F-score)
fps = len(self.loader.dataset) / cost_time
msg = '%s MAE=%.6f, F-score=%.6f, len(imgs)=%s, fps=%.4f'%(self.datapath, mae/cnt, fscore.max()/cnt, len(self.loader.dataset), fps)
print(msg)
logger.info(msg)
#将预测结果保存为图像
def save(self):
with torch.no_grad():
for image, mask, (H, W), name in self.loader:
out2, out3, out4, out5 = self.net(image.cuda().float())
out2 = F.interpolate(out2, size=(H,W), mode='bilinear')
pred = (torch.sigmoid(out2[0,0])*255).cpu().numpy()
head = './pred_maps/{}/'.format(TAG) + self.cfg.datapath.split('/')[-1]
if not os.path.exists(head):
os.makedirs(head)
cv2.imwrite(head+'/'+name[0],np.uint8(pred))
if __name__=='__main__':
for e in DATASETS:
t =Test(dataset, e, GCPANet)
t.accuracy()
t.save()