代码来自:GitHub - ChuHan89/WSSS-Tissue
借助了一些人工智能
2_generate_PM.py
功能总结
该代码用于 生成弱监督语义分割(WSSS)所需的伪掩码(Pseudo-Masks),是 Stage2 训练的前置步骤。其核心流程为:
-
加载 Stage1 训练好的分类模型(支持 CAM 生成)。
-
为不同层次的特征图生成伪掩码(如
b4_5
,b5_2
,bn7
对应的不同网络层)。 -
保存伪掩码图像,使用调色板将类别标签映射为彩色图像。
代码解析
1. 导入依赖库
import os import torch import argparse import importlib from torch.backends import cudnn cudnn.enabled = True # 启用CUDA加速 from tool.infer_fun import create_pseudo_mask # 自定义函数:生成伪掩码
-
关键依赖:
-
cudnn.enabled = True
:启用 cuDNN 加速,优化 GPU 计算性能。 -
create_pseudo_mask
:核心函数(用户需参考其实现),负责生成并保存伪掩码。
-
2. 主函数与参数解析
if __name__ == '__main__': # 定义命令行参数 parser = argparse.ArgumentParser() parser.add_argument("--weights", default='checkpoints/stage1_checkpoint_trained_on_bcss.pth', type=str) parser.add_argument("--network", default="network.resnet38_cls", type=str) parser.add_argument("--dataroot", default="datasets/BCSS-WSSS/", type=str) parser.add_argument("--dataset", default="bcss", type=str) parser.add_argument("--num_workers", default=8, type=int) parser.add_argument("--n_class", default=4, type=int) args = parser.parse_args() print(args) # 打印参数列表
-
参数说明:
-
--weights
:Stage1 训练好的模型权重文件路径(默认指向 BCSS 数据集)。 -
--network
:网络结构定义文件(如network.resnet38_cls
)。 -
--dataroot
:数据集根目录(包含训练/测试数据)。 -
--dataset
:数据集标识(bcss
或luad
)。 -
--n_class
:类别数量(BCSS 为 4 类,LUAD 可能不同)。
-
3. 定义调色板(颜色映射)
if args.dataset == 'luad': palette = [0]*15 # 初始化长度为15的列表(每类3个RGB通道) palette[0:3] = [205,51,51] # 类别1:红色 palette[3:6] = [0,255,0] # 类别2:绿色 palette[6:9] = [65,105,225] # 类别3:蓝色 palette[9:12] = [255,165,0] # 类别4:橙色 palette[12:15] = [255, 255, 255] # 背景或未标注区域:白色 elif args.dataset == 'bcss': palette = [0]*15 palette[0:3] = [255, 0, 0] # 类别1:红色 palette[3:6] = [0,255,0] # 类别2:绿色 palette[6:9] = [0,0,255] # 类别3:蓝色 palette[9:12] = [153, 0, 255] # 类别4:紫色 palette[12:15] = [255, 255, 255] # 背景:白色
-
作用:将类别标签映射为 RGB 颜色,用于伪掩码的可视化。
-
细节:
-
每个类别占 3 个连续位置(RGB 通道)。
-
palette[12:15]
可能表示背景或未标注区域。 -
不同数据集使用不同的颜色方案(如 BCSS 用紫色表示第4类)。
-
4. 创建伪掩码保存路径
PMpath = os.path.join(args.dataroot, 'train_PM') # 路径示例:datasets/BCSS-WSSS/train_PM if not os.path.exists(PMpath): os.mkdir(PMpath) # 若目录不存在则创建
-
目的:在数据集根目录下创建
train_PM
文件夹,用于保存生成的伪掩码。
5. 加载模型
model = getattr(importlib.import_module("network.resnet38_cls"), 'Net_CAM')(n_class=args.n_class) model.load_state_dict(torch.load(args.weights), strict=False) model.eval() # 设置为评估模式(禁用Dropout等随机操作) model.cuda() # 将模型移至GPU
-
关键步骤:
-
动态加载模型:从
network.resnet38_cls
模块加载Net_CAM
类(支持 CAM 生成的变体)。 -
加载权重:使用 Stage1 训练好的模型参数(
strict=False
允许部分参数不匹配)。 -
评估模式:关闭 BatchNorm 和 Dropout 的随机性,确保结果一致性。
-
6. 生成多级伪掩码
## fm = 'b4_5' # 特征模块名称(可能对应网络中的某个中间层) savepath = os.path.join(PMpath, 'PM_' + fm) # 保存路径:train_PM/PM_b4_5 if not os.path.exists(savepath): os.mkdir(savepath) create_pseudo_mask(model, args.dataroot, fm, savepath, args.n_class, palette, args.dataset) ## 重复相同流程生成其他层级的伪掩码 fm = 'b5_2' savepath = os.path.join(PMpath, 'PM_' + fm) if not os.path.exists(savepath): os.mkdir(savepath) create_pseudo_mask(model, args.dataroot, fm, savepath, args.n_class, palette, args.dataset) ## fm = 'bn7' savepath = os.path.join(PMpath, 'PM_' + fm) if not os.path.exists(savepath): os.mkdir(savepath) create_pseudo_mask(model, args.dataroot, fm, savepath, args.n_class, palette, args.dataset)
-
功能:针对不同特征模块(
fm
)生成伪掩码,保存到对应子目录。 -
关键参数:
-
fm
:特征模块标识,可能对应网络中的不同层(如 ResNet 的block4
、block5
或bottleneck
)。 -
create_pseudo_mask
:核心函数,推测其功能为:-
加载训练集图像。
-
使用模型提取指定层的特征图。
-
生成类别激活图(CAM)。
-
根据阈值将 CAM 转换为二值伪掩码。
-
应用调色板将掩码保存为彩色 PNG 图像。
-
-
代码执行示例
python generate_pseudo_masks.py \ --dataset bcss \ --dataroot datasets/BCSS-WSSS/ \ --weights checkpoints/stage1_checkpoint_trained_on_bcss.pth
-
输出:在
datasets/BCSS-WSSS/train_PM/
下生成三个子目录:-
PM_b4_5
:基于b4_5
层特征的伪掩码。 -
PM_b5_2
:基于b5_2
层特征的伪掩码。 -
PM_bn7
:基于bn7
层特征的伪掩码。
-
总结
该代码是弱监督语义分割流程中 生成多级伪掩码的关键步骤,利用 Stage1 训练的分类模型提取不同层级的特征,生成伪标签供 Stage2 的分割模型训练。通过多级伪掩码的融合,可以提升最终分割结果的精度和鲁棒性。
3_train_stage2.py
功能总结
该代码是弱监督语义分割(WSSS)的 Stage2 训练与测试脚本,核心功能为:
-
训练分割模型:基于 DeepLab v3+ 架构,使用 Stage1 生成的伪掩码(Pseudo-Masks)进行监督训练。
-
验证与测试:评估模型在验证集和测试集上的性能(如 mIoU、像素准确率等)。
-
门控机制(Gate Mechanism):在测试阶段结合 Stage1 的分类结果过滤分割预测,提升精度。
-
多任务损失:融合不同层次伪掩码的损失(主伪掩码 + 两种增强版本)。
代码结构
# 1. 依赖库导入 import argparse, os, numpy as np from tqdm import tqdm import torch from tool.GenDataset import make_data_loader from network.sync_batchnorm.replicate import patch_replication_callback from network.deeplab import * from tool.loss import SegmentationLosses from tool.lr_scheduler import LR_Scheduler from tool.saver import Saver from tool.summaries import TensorboardSummary from tool.metrics import Evaluator # 2. 定义训练器类 class Trainer(object): def __init__(self, args): ... # 初始化模型、数据、优化器等 def training(self, epoch): ... # 训练一个epoch def validation(self, epoch): ... # 验证集评估 def test(self, epoch, Is_GM): ... # 测试集评估(支持门控机制) def load_the_best_checkpoint(self): ... # 加载最佳模型 # 3. 主函数 def main(): ... # 解析参数、启动训练 if __name__ == "__main__": main()
关键代码解析
1. Trainer
类初始化
class Trainer(object): def __init__(self, args): self.args = args # 初始化日志记录与模型保存工具 self.saver = Saver(args) # 保存模型检查点 self.summary = TensorboardSummary('logs') # TensorBoard日志 self.writer = self.summary.create_summary() # 数据加载 kwargs = {'num_workers': args.workers, 'pin_memory': False} self.train_loader, self.val_loader, self.test_loader = make_data_loader(args, **kwargs) # 模型定义(DeepLab v3+) self.nclass = args.n_class model = DeepLab( num_classes=self.nclass, backbone=args.backbone, # 骨干网络(如ResNet) output_stride=args.out_stride, # 输出步长(控制特征图分辨率) sync_bn=args.sync_bn, # 多GPU同步BatchNorm freeze_bn=args.freeze_bn # 冻结BN层参数 ) # 优化器配置(分层学习率) train_params = [ {'params': model.get_1x_lr_params(), 'lr': args.lr}, # 骨干网络低学习率 {'params': model.get_10x_lr_params(), 'lr': args.lr * 10} # 分类头高学习率 ] optimizer = torch.optim.SGD( train_params, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=args.nesterov ) # 损失函数(交叉熵或Focal Loss) self.criterion = SegmentationLosses(weight=None, cuda=args.cuda).build_loss(mode=args.loss_type) self.model, self.optimizer = model, optimizer # 评估工具(计算mIoU等指标) self.evaluator = Evaluator(self.nclass) # 学习率调度(Poly策略) self.scheduler = LR_Scheduler( args.lr_scheduler, args.lr, args.epochs, len(self.train_loader) ) # 加载Stage1的分类模型(用于门控机制) model_stage1 = getattr(importlib.import_module('network.resnet38_cls'), 'Net_CAM')(n_class=4) resume_stage1 = 'checkpoints/stage1_checkpoint_trained_on_'+str(args.dataset)+'.pth' weights_dict = torch.load(resume_stage1) model_stage1.load_state_dict(weights_dict) self.model_stage1 = model_stage1.cuda() self.model_stage1.eval() # 固定Stage1模型参数 # GPU并行化 if args.cuda: self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids) patch_replication_callback(self.model) # 修复多GPU BatchNorm同步问题 self.model = self.model.cuda() # 加载预训练权重(如DeepLab预训练模型) if args.resume is not None: checkpoint = torch.load(args.resume) # 处理分类头权重(微调时保留,否则删除) if args.ft: self.model.load_state_dict(checkpoint['state_dict']) self.optimizer.load_state_dict(checkpoint['optimizer']) else: del checkpoint['state_dict']['decoder.last_conv.8.weight'] del checkpoint['state_dict']['decoder.last_conv.8.bias'] self.model.load_state_dict(checkpoint['state_dict'], strict=False) # 初始化最佳mIoU self.best_pred = 0.0
2. 训练阶段 training
def training(self, epoch): train_loss = 0.0 self.model.train() tbar = tqdm(self.train_loader) # 进度条 num_img_tr = len(self.train_loader) for i, sample in enumerate(tbar): # 加载数据(图像 + 三个伪掩码) image, target, target_a, target_b = sample['image'], sample['label'], sample['label_a'], sample['label_b'] if self.args.cuda: image, target, target_a, target_b = image.cuda(), target.cuda(), target_a.cuda(), target_b.cuda() # 调整学习率 self.scheduler(self.optimizer, i, epoch, self.best_pred) self.optimizer.zero_grad() # 前向传播 output = self.model(image) # 添加额外通道处理类别4(背景或忽略类) one = torch.ones((output.shape[0],1,224,224)).cuda() output = torch.cat([output, (100 * one * (target==4).unsqueeze(dim=1)], dim=1) # 计算多任务损失(主伪掩码 + 两种增强版本) loss_o = self.criterion(output, target) loss_a = self.criterion(output, target_a) loss_b = self.criterion(output, target_b) loss = 0.6*loss_o + 0.2*loss_a + 0.2*loss_b # 反向传播 loss.backward() self.optimizer.step() # 统计损失 train_loss += loss.item() tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1))) # 记录TensorBoard日志 self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch) # 输出epoch总结 self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch) print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0])) print('Loss: %.3f' % train_loss)
3. 验证阶段 validation
def validation(self, epoch): self.model.eval() self.evaluator.reset() tbar = tqdm(self.val_loader, desc='\r') test_loss = 0.0 for i, sample in enumerate(tbar): image, target = sample[0]['image'], sample[0]['label'] if self.args.cuda: image, target = image.cuda(), target.cuda() with torch.no_grad(): output = self.model(image) # 转换为CPU numpy数组 pred = output.data.cpu().numpy() target = target.cpu().numpy() pred = np.argmax(pred, axis=1) # 处理类别4(设为忽略类) pred[target==4] = 4 # 更新评估指标 self.evaluator.add_batch(target, pred) # 计算并记录指标 Acc = self.evaluator.Pixel_Accuracy() Acc_class = self.evaluator.Pixel_Accuracy_Class() mIoU = self.evaluator.Mean_Intersection_over_Union() ious = self.evaluator.Intersection_over_Union() FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union() # 输出结果 print('Validation:') print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU)) # 保存最佳模型 if mIoU > self.best_pred: self.best_pred = mIoU self.saver.save_checkpoint({ 'state_dict': self.model.module.state_dict(), 'optimizer': self.optimizer.state_dict() }, 'stage2_checkpoint_trained_on_'+self.args.dataset+'.pth')
4. 测试阶段 test
(含门控机制)
def test(self, epoch, Is_GM): self.load_the_best_checkpoint() # 加载最佳模型 self.model.eval() self.evaluator.reset() tbar = tqdm(self.test_loader, desc='\r') for i, sample in enumerate(tbar): image, target = sample[0]['image'], sample[0]['label'] if self.args.cuda: image, target = image.cuda(), target.cuda() with torch.no_grad(): output = self.model(image) # 门控机制:利用Stage1的分类结果过滤分割预测 if Is_GM: _, y_cls = self.model_stage1.forward_cam(image) # Stage1的分类输出 y_cls = y_cls.cpu().data pred_cls = (y_cls > 0.1) # 类别存在性判断(阈值0.1) # 应用门控机制 pred = output.data.cpu().numpy() if Is_GM: pred = pred * pred_cls.unsqueeze(dim=2).unsqueeze(dim=3).numpy() # 处理类别4 pred = np.argmax(pred, axis=1) pred[target==4] = 4 self.evaluator.add_batch(target, pred) # 计算并输出指标 Acc = self.evaluator.Pixel_Accuracy() Acc_class = self.evaluator.Pixel_Accuracy_Class() mIoU = self.evaluator.Mean_Intersection_over_Union() print('Test:') print("Acc:{}, Acc_class:{}, mIoU:{}".format(Acc, Acc_class, mIoU))
5. 主函数 main
def main(): # 解析命令行参数 parser = argparse.ArgumentParser(description="WSSS Stage2") # 模型结构参数 parser.add_argument('--backbone', default='resnet', choices=['resnet', 'xception', 'drn', 'mobilenet']) parser.add_argument('--out-stride', type=int, default=16) # 输出步长(控制特征图下采样率) parser.add_argument('--Is_GM', type=bool, default=True) # 是否启用门控机制 # 数据集参数 parser.add_argument('--dataroot', default='datasets/BCSS-WSSS/') parser.add_argument('--dataset', default='bcss') parser.add_argument('--n_class', type=int, default=4) # 训练超参数 parser.add_argument('--epochs', type=int, default=30) parser.add_argument('--batch-size', type=int, default=20) parser.add_argument('--lr', type=float, default=0.01) parser.add_argument('--lr-scheduler', default='poly', choices=['poly', 'step', 'cos']) # 其他配置 parser.add_argument('--gpu-ids', default='0') # 指定使用的GPU parser.add_argument('--resume', default='init_weights/deeplab-resnet.pth.tar') # 预训练权重 args = parser.parse_args() # 配置CUDA args.cuda = not args.no_cuda and torch.cuda.is_available() if args.cuda: args.gpu_ids = [int(s) for s in args.gpu_ids.split(',')] # 自动设置SyncBN if args.sync_bn is None: args.sync_bn = True if args.cuda and len(args.gpu_ids) > 1 else False # 初始化训练器并启动训练 trainer = Trainer(args) for epoch in range(trainer.args.epochs): trainer.training(epoch) if epoch % args.eval_interval == 0: trainer.validation(epoch) # 最终测试 trainer.test(epoch, args.Is_GM) trainer.writer.close()
关键设计解析
-
多任务损失:
-
目标:同时优化主伪掩码(
target
)及其两种增强版本(target_a
,target_b
),提升模型对不同噪声伪标签的鲁棒性。 -
权重分配:主损失占60%,增强损失各占20%(
0.6*loss_o + 0.2*loss_a + 0.2*loss_b
)。
-
-
门控机制(Gate Mechanism):
-
作用:在测试阶段,利用 Stage1 的分类结果过滤分割预测,仅保留分类模型认为存在的类别。
-
实现:若 Stage1 对某类别的预测概率 > 0.1,则保留该类的分割结果,否则置零。
-
-
类别4处理:
-
背景或忽略类:在标签中,类别4可能表示背景或未标注区域,预测时直接继承真实标签的值(
pred[target==4] = 4
),避免错误优化。
-
-
模型初始化:
-
预训练权重:加载 DeepLab 在 ImageNet 上的预训练权重(
init_weights/deeplab-resnet.pth.tar
),加速收敛。 -
分层学习率:骨干网络使用较低学习率(
args.lr
),分类头使用更高学习率(args.lr * 10
)。
-
运行示例
python train_stage2.py \ --dataset bcss \ --dataroot datasets/BCSS-WSSS/ \ --backbone resnet \ --Is_GM True \ --batch-size 20 \ --epochs 30
总结
该代码实现了弱监督语义分割的第二阶段训练,通过多任务损失融合多级伪标签,结合门控机制提升测试精度,最终生成高精度分割模型。训练过程支持多GPU加速、Poly学习率调度及多种评估指标监控,适用于医学图像(如BCSS)或自然场景图像的分割任务。