基于Unet++在kaggle—2018dsb数据集上实现图像分割

news2024/12/23 22:07:16

目录

  • 1. 作者介绍
  • 2. 理论知识介绍
    • 2.1 Unet++模型介绍
  • 3. 实验过程
    • 3.1 数据集介绍
    • 3.2 代码实现
    • 3.3 结果
  • 4. 参考链接

1. 作者介绍

郭冠群,男,西安工程大学电子信息学院,2023级研究生
研究方向:机器视觉与人工智能
电子邮件:1347418097@qq.com

路治东,男,西安工程大学电子信息学院,2022级研究生,张宏伟人工智能课题组
研究方向:机器视觉与人工智能
电子邮件:2063079527@qq.com

2. 理论知识介绍

2.1 Unet++模型介绍

  • Unet
    语义分割是将图像划分为有意义的区域,并标注每个区域所属的类别。语义分割网络是实现这一任务的工具,其中Unet模型通过跨阶段融合不同尺寸的特征图来实现这一目标。
    在这里插入图片描述
  • 特征图融合
    特征图融合的目的是结合浅层和深层特征,提升分割效果。浅层特征能提取图像的简单特征如边界和颜色,而深层特征提取图像的深层次语义信息。多个特征图的融合能够弥补单一特征层次信息的不足。
  • Unet++
    Unet++通过嵌套的密集跳过路径连接编码器和解码器子网络,减少了特征映射之间的语义差距,从而提高了分割效果。在测试阶段,由于输入图像只进行前向传播,被剪掉的部分对前面输出没有影响,而在训练阶段,这些部分会帮助其他部分进行权重更新。
    在这里插入图片描述

3. 实验过程

3.1 数据集介绍

  • 数据集来源
    Kaggle—2018dsb数据集来自于2018年数据科学碗,其任务是从显微镜图像中分割细胞核。这对于推动医学发现具有重要意义,特别是在病理学、癌症研究和其他生命科学领域。
    在这里插入图片描述
  • 下载途径

百度网盘 链接:https://pan.baidu.com/s/1GXtZ0clE12oZKooF61siKQ
提取码:tsh7

  • 数据集内容
    数据集包含显微镜下细胞图像及其对应的分割掩码。训练集用于训练模型,测试集用于评估模型性能。
    在这里插入图片描述

3.2 代码实现

  1. train.py
import os
import argparse
from glob import glob
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from tqdm import tqdm
import albumentations as albu
from albumentations.core.composition import Compose, OneOf
from sklearn.model_selection import train_test_split
import archs
import losses
from dataset import CustomDataset
from metrics import iou_score
from utils import AverageMeter, str2bool

class Config:
    @staticmethod
    def from_cmdline():
        parser = argparse.ArgumentParser(description='Training configuration')
        parser.add_argument('--name', default=None, help='Model name: (default: arch+timestamp)')
        parser.add_argument('--epochs', type=int, default=100, help='Number of total epochs to run')
        parser.add_argument('-b', '--batch_size', type=int, default=8, help='Mini-batch size (default: 16)')
        parser.add_argument('--arch', default='NestedUNet', choices=archs.__all__, help='Model architecture')
        parser.add_argument('--deep_supervision', type=str2bool, default=False, help='Use deep supervision if True')
        parser.add_argument('--input_channels', type=int, default=3, help='Number of input channels')
        parser.add_argument('--num_classes', type=int, default=1, help='Number of classes')
        parser.add_argument('--input_w', type=int, default=96, help='Input image width')
        parser.add_argument('--input_h', type=int, default=96, help='Input image height')
        parser.add_argument('--loss', default='BCEDiceLoss', choices=losses.__all__, help='Loss function')
        parser.add_argument('--dataset', default='dsb2018_96', help='Dataset name')
        parser.add_argument('--img_ext', default='.png', help='Image file extension')
        parser.add_argument('--mask_ext', default='.png', help='Mask file extension')
        parser.add_argument('--optimizer', default='SGD', choices=['Adam', 'SGD'], help='Optimizer type')
        parser.add_argument('--lr', '--learning_rate', type=float, default=1e-3, help='Initial learning rate')
        parser.add_argument('--momentum', type=float, default=0.9, help='Optimizer momentum')
        parser.add_argument('--weight_decay', type=float, default=1e-4, help='Weight decay rate')
        parser.add_argument('--nesterov', type=str2bool, default=False, help='Nesterov momentum')
        parser.add_argument('--scheduler', default='CosineAnnealingLR',
                            choices=['CosineAnnealingLR', 'ReduceLROnPlateau', 'MultiStepLR', 'ConstantLR'],
                            help='Learning rate scheduler')
        parser.add_argument('--min_lr', type=float, default=1e-5, help='Minimum learning rate')
        parser.add_argument('--factor', type=float, default=0.1, help='Factor for ReduceLROnPlateau')
        parser.add_argument('--patience', type=int, default=2, help='Patience for ReduceLROnPlateau')
        parser.add_argument('--milestones', type=str, default='1,2', help='Milestones for MultiStepLR')
        parser.add_argument('--gamma', type=float, default=2 / 3, help='Gamma for MultiStepLR')
        parser.add_argument('--early_stopping', type=int, default=-1, help='Early stopping threshold')
        parser.add_argument('--num_workers', type=int, default=0, help='Number of data loading workers')

        args = parser.parse_args()
        return vars(args)


class ModelManager:
    def __init__(self, config):
        self.config = config
        self.model = self.create_model().cuda()
        self.criterion = self.create_criterion().cuda()
        self.optimizer = self.create_optimizer()
        self.scheduler = self.create_scheduler()

    def create_model(self):
        return archs.__dict__[self.config['arch']](
            self.config['num_classes'],
            self.config['input_channels'],
            self.config['deep_supervision']
        )

    def create_criterion(self):
        if self.config['loss'] == 'BCEWithLogitsLoss':
            return nn.BCEWithLogitsLoss()
        else:
            return losses.__dict__[self.config['loss']]()

    def create_optimizer(self):
        params = filter(lambda p: p.requires_grad, self.model.parameters())
        if self.config['optimizer'] == 'Adam':
            return optim.Adam(params, lr=self.config['lr'], weight_decay=self.config['weight_decay'])
        elif self.config['optimizer'] == 'SGD':
            return optim.SGD(params, lr=self.config['lr'], momentum=self.config['momentum'],
                             nesterov=self.config['nesterov'], weight_decay=self.config['weight_decay'])

    def create_scheduler(self):
        if self.config['scheduler'] == 'CosineAnnealingLR':
            return lr_scheduler.CosineAnnealingLR(
                self.optimizer, T_max=self.config['epochs'], eta_min=self.config['min_lr'])
        elif self.config['scheduler'] == 'ReduceLROnPlateau':
            return lr_scheduler.ReduceLROnPlateau(
                self.optimizer, factor=self.config['factor'], patience=self.config['patience'],
                min_lr=self.config['min_lr'])
        elif self.config['scheduler'] == 'MultiStepLR':
            milestones = list(map(int, self.config['milestones'].split(',')))
            return lr_scheduler.MultiStepLR(self.optimizer, milestones=milestones, gamma=self.config['gamma'])


class DataManager:
    def __init__(self, config):
        self.config = config
        self.train_loader, self.val_loader = self.setup_loaders()

    def setup_loaders(self):
        img_ids = glob(os.path.join('inputs', self.config['dataset'], 'images', '*' + self.config['img_ext']))
        img_ids = [os.path.splitext(os.path.basename(p))[0] for p in img_ids]
        train_img_ids, val_img_ids = train_test_split(img_ids, test_size=0.2, random_state=41)

        train_transform = Compose([
            albu.RandomRotate90(), albu.Flip(),
            OneOf([albu.HueSaturationValue(), albu.RandomBrightnessContrast()], p=1),
            albu.Resize(self.config['input_h'], self.config['input_w']), albu.Normalize()])

        val_transform = Compose([
            albu.Resize(self.config['input_h'], self.config['input_w']), albu.Normalize()])

        train_dataset = CustomDataset(img_ids=train_img_ids,
                                      img_dir=os.path.join('inputs', self.config['dataset'], 'images'),
                                      mask_dir=os.path.join('inputs', self.config['dataset'], 'masks'),
                                      img_ext=self.config['img_ext'],
                                      mask_ext=self.config['mask_ext'],
                                      num_classes=self.config['num_classes'],
                                      transform=train_transform)
        val_dataset = CustomDataset(img_ids=val_img_ids,
                                    img_dir=os.path.join('inputs', self.config['dataset'], 'images'),
                                    mask_dir=os.path.join('inputs', self.config['dataset'], 'masks'),
                                    img_ext=self.config['img_ext'],
                                    mask_ext=self.config['mask_ext'],
                                    num_classes=self.config['num_classes'],
                                    transform=val_transform)

        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=self.config['batch_size'],
                                                   shuffle=True, num_workers=self.config['num_workers'], drop_last=True)
        val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=self.config['batch_size'],
                                                 shuffle=False, num_workers=self.config['num_workers'], drop_last=False)
        return train_loader, val_loader


def main():
    config = Config.from_cmdline()
    manager = ModelManager(config)
    data_manager = DataManager(config)

    for epoch in range(config['epochs']):
        train_loss, train_iou = train_epoch(data_manager.train_loader, manager.model, manager.criterion,
                                            manager.optimizer, config)
        val_loss, val_iou = validate_epoch(data_manager.val_loader, manager.model, manager.criterion, config)
        print(f'Epoch: {epoch}, Train Loss: {train_loss}, Train IOU: {train_iou}, Val Loss: {val_loss}, Val IOU: {val_iou}')
        # Update scheduler, save models, etc.
        if config['scheduler'] == 'ReduceLROnPlateau':
            manager.scheduler.step(val_loss)
        else:
            manager.scheduler.step()


def train_epoch(train_loader, model, criterion, optimizer, config):
        model.train()
        avg_meters = {'loss': AverageMeter(), 'iou': AverageMeter()}
        pbar = tqdm(total=len(train_loader), desc='Train')

        for data in train_loader:
            if len(data) == 2:
                input, target = data
            elif len(data) > 2:
                input, target, _ = data  # 根据实际返回的数据格式解包
            input = input.cuda()
            target = target.cuda()

            # Compute output
            if config['deep_supervision']:
                outputs = model(input)
                loss = 0
                for output in outputs:
                    loss += criterion(output, target)
                loss /= len(outputs)
                iou = iou_score(outputs[-1], target)
            else:
                output = model(input)
                loss = criterion(output, target)
                iou = iou_score(output, target)

            # Compute gradient and do optimizer step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            avg_meters['loss'].update(loss.item(), input.size(0))
            avg_meters['iou'].update(iou, input.size(0))
            pbar.update(1)
            pbar.set_postfix({'Loss': avg_meters['loss'].avg, 'IoU': avg_meters['iou'].avg})

        pbar.close()
        return avg_meters['loss'].avg, avg_meters['iou'].avg

def validate_epoch(val_loader, model, criterion, config):
        model.eval()
        avg_meters = {'loss': AverageMeter(), 'iou': AverageMeter()}
        pbar = tqdm(total=len(val_loader), desc='Validate')

        with torch.no_grad():
            for data in val_loader:
                if len(data) == 2:
                    input, target = data
                elif len(data) > 2:
                    input, target, _ = data  # 根据实际返回的数据格式解包
                input = input.cuda()
                target = target.cuda()

                # Compute output
                if config['deep_supervision']:
                    outputs = model(input)
                    loss = 0
                    for output in outputs:
                        loss += criterion(output, target)
                    loss /= len(outputs)
                    iou = iou_score(outputs[-1], target)
                else:
                    output = model(input)
                    loss = criterion(output, target)
                    iou = iou_score(output, target)

                avg_meters['loss'].update(loss.item(), input.size(0))
                avg_meters['iou'].update(iou, input.size(0))
                pbar.update(1)
                pbar.set_postfix({'Loss': avg_meters['loss'].avg, 'IoU': avg_meters['iou'].avg})

        pbar.close()
        return avg_meters['loss'].avg, avg_meters['iou'].avg

    # Training and validation logic can go here using manager and data_manager


if __name__ == '__main__':
    main()
  1. val.py
import argparse
import os
from glob import glob
import matplotlib.pyplot as plt
import numpy as np
import cv2
import torch
import torch.backends.cudnn as cudnn
import yaml
import albumentations as albu
from albumentations.core.composition import Compose
from sklearn.model_selection import train_test_split
from tqdm import tqdm

import archs
from dataset import CustomDataset
from metrics import iou_score
from utils import AverageMeter

"""
需要指定参数:--name dsb2018_96_NestedUNet_woDS
"""

def parse_args():
    parser = argparse.ArgumentParser()

    parser.add_argument('--name', default=None,
                        help='model name')

    args = parser.parse_args()

    return args


def main():
    args = parse_args()

    if args.name is None:
        print("Error: You must specify the model name using the --name argument.")
        return

    with open(f'models/{args.name}/config.yml', 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    print('-' * 20)
    for key in config.keys():
        print('%s: %s' % (key, str(config[key])))
    print('-' * 20)

    cudnn.benchmark = True

    # create model
    print("=> creating model %s" % config['arch'])
    model = archs.__dict__[config['arch']](config['num_classes'],
                                           config['input_channels'],
                                           config['deep_supervision'])

    model = model.cuda()

    # Data loading code
    img_ids = glob(os.path.join('inputs', config['dataset'], 'images', '*' + config['img_ext']))
    img_ids = [os.path.splitext(os.path.basename(p))[0] for p in img_ids]

    _, val_img_ids = train_test_split(img_ids, test_size=0.2, random_state=41)

    model.load_state_dict(torch.load(f'models/{config["name"]}/model.pth'))
    model.eval()

    val_transform = Compose([
        albu.Resize(config['input_h'], config['input_w']),
        albu.Normalize(),
    ])

    val_dataset = CustomDataset(
        img_ids=val_img_ids,
        img_dir=os.path.join('inputs', config['dataset'], 'images'),
        mask_dir=os.path.join('inputs', config['dataset'], 'masks'),
        img_ext=config['img_ext'],
        mask_ext=config['mask_ext'],
        num_classes=config['num_classes'],
        transform=val_transform
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=config['batch_size'],
        shuffle=False,
        num_workers=config['num_workers'],
        drop_last=False
    )

    avg_meter = AverageMeter()

    for c in range(config['num_classes']):
        os.makedirs(os.path.join('outputs', config['name'], str(c)), exist_ok=True)

    with torch.no_grad():
        for input, target, meta in tqdm(val_loader, total=len(val_loader)):
            input = input.cuda()
            target = target.cuda()

            # 将元组转换为字典
            meta_dict = {'img_id': meta}
            print(f"meta_dict: {meta_dict}")

            # compute output
            if config['deep_supervision']:
                output = model(input)[-1]
            else:
                output = model(input)

            iou = iou_score(output, target)
            avg_meter.update(iou, input.size(0))

            output = torch.sigmoid(output).cpu().numpy()

            for i in range(len(output)):
                for c in range(config['num_classes']):
                    cv2.imwrite(os.path.join('outputs', config['name'], str(c), str(meta_dict['img_id'][i]) + '.png'),
                                (output[i, c] * 255).astype('uint8'))

    print('IoU: %.4f' % avg_meter.avg)

    plot_examples(input, target, model, num_examples=3)

    torch.cuda.empty_cache()


def plot_examples(datax, datay, model, num_examples=6):
    fig, ax = plt.subplots(nrows=num_examples, ncols=3, figsize=(18, 4 * num_examples))
    m = datax.shape[0]
    for row_num in range(num_examples):
        image_indx = np.random.randint(m)
        image_arr = model(datax[image_indx:image_indx + 1]).squeeze(0).detach().cpu().numpy()
        ax[row_num][0].imshow(np.transpose(datax[image_indx].cpu().numpy(), (1, 2, 0))[:, :, 0])
        ax[row_num][0].set_title("Orignal Image")
        ax[row_num][1].imshow(np.squeeze((image_arr > 0.40)[0, :, :].astype(int)))
        ax[row_num][1].set_title("Segmented Image localization")
        ax[row_num][2].imshow(np.transpose(datay[image_indx].cpu().numpy(), (1, 2, 0))[:, :, 0])
        ax[row_num][2].set_title("Target image")
    plt.show()


if __name__ == '__main__':
    main()
  1. archs.py
import torch
from torch import nn

__all__ = ['UNet', 'NestedUNet']

# 基本计算单元
class VGGBlock(nn.Module):
    def __init__(self, in_channels, middle_channels, out_channels):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_channels, middle_channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(middle_channels)
        self.conv2 = nn.Conv2d(middle_channels, out_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        # VGGBlock实际上就是相当于做了两次卷积
        out = self.conv1(x)
        out = self.bn1(out)     # 归一化
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        return out


class UNet(nn.Module):
    def __init__(self, num_classes, input_channels=3, **kwargs):
        super().__init__()

        nb_filter = [32, 64, 128, 256, 512]

        self.pool = nn.MaxPool2d(2, 2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)#scale_factor:放大的倍数  插值

        self.conv0_0 = VGGBlock(input_channels, nb_filter[0], nb_filter[0])
        self.conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1])
        self.conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2])
        self.conv3_0 = VGGBlock(nb_filter[2], nb_filter[3], nb_filter[3])
        self.conv4_0 = VGGBlock(nb_filter[3], nb_filter[4], nb_filter[4])

        self.conv3_1 = VGGBlock(nb_filter[3]+nb_filter[4], nb_filter[3], nb_filter[3])
        self.conv2_2 = VGGBlock(nb_filter[2]+nb_filter[3], nb_filter[2], nb_filter[2])
        self.conv1_3 = VGGBlock(nb_filter[1]+nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv0_4 = VGGBlock(nb_filter[0]+nb_filter[1], nb_filter[0], nb_filter[0])

        self.final = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)


    def forward(self, input):
        x0_0 = self.conv0_0(input)
        x1_0 = self.conv1_0(self.pool(x0_0))
        x2_0 = self.conv2_0(self.pool(x1_0))
        x3_0 = self.conv3_0(self.pool(x2_0))
        x4_0 = self.conv4_0(self.pool(x3_0))

        x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1))
        x2_2 = self.conv2_2(torch.cat([x2_0, self.up(x3_1)], 1))
        x1_3 = self.conv1_3(torch.cat([x1_0, self.up(x2_2)], 1))
        x0_4 = self.conv0_4(torch.cat([x0_0, self.up(x1_3)], 1))

        output = self.final(x0_4)
        return output


class NestedUNet(nn.Module):
    def __init__(self, num_classes, input_channels=3, deep_supervision=False, **kwargs):
        super().__init__()
        # 定义了一个列表,包含NestedUNet中不同层的通道数
        nb_filter = [32, 64, 128, 256, 512]
        # 深度监督:是否需要都计算损失函数
        self.deep_supervision = deep_supervision

        self.pool = nn.MaxPool2d(2, 2)  # 最大池化,池化核大小为2x2,步幅为2
        # 创建一个上采样层实例,尺度因子为2,采用双线性插值的方式进行上采样,边缘对齐方式为True
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        self.conv0_0 = VGGBlock(input_channels, nb_filter[0], nb_filter[0])
        self.conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1])
        self.conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2])
        self.conv3_0 = VGGBlock(nb_filter[2], nb_filter[3], nb_filter[3])
        self.conv4_0 = VGGBlock(nb_filter[3], nb_filter[4], nb_filter[4])

        self.conv0_1 = VGGBlock(nb_filter[0]+nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_1 = VGGBlock(nb_filter[1]+nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv2_1 = VGGBlock(nb_filter[2]+nb_filter[3], nb_filter[2], nb_filter[2])
        self.conv3_1 = VGGBlock(nb_filter[3]+nb_filter[4], nb_filter[3], nb_filter[3])

        self.conv0_2 = VGGBlock(nb_filter[0]*2+nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_2 = VGGBlock(nb_filter[1]*2+nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv2_2 = VGGBlock(nb_filter[2]*2+nb_filter[3], nb_filter[2], nb_filter[2])

        self.conv0_3 = VGGBlock(nb_filter[0]*3+nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_3 = VGGBlock(nb_filter[1]*3+nb_filter[2], nb_filter[1], nb_filter[1])

        self.conv0_4 = VGGBlock(nb_filter[0]*4+nb_filter[1], nb_filter[0], nb_filter[0])

        if self.deep_supervision:
            self.final1 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final2 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final3 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final4 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
        else:
            self.final = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)


    def forward(self, input):
        # 入口函数打个断点,看数据的维度很重要
        print('input:', input.shape)
        x0_0 = self.conv0_0(input)  # 第一次卷积
        print('x0_0:',x0_0.shape)   # 升维 input: torch.Size([8, 32, 96, 96])
        x1_0 = self.conv1_0(self.pool(x0_0))
        print('x1_0:', x1_0.shape)  # 升维,降数据量,x1_0: torch.Size([8, 32, 96, 96])
        x0_1 = self.conv0_1(torch.cat([x0_0, self.up(x1_0)], 1))
        # cat 拼接,再经历一次卷积,input是96=32+64,output=32
        print('x0_1:', x0_1.shape)   # x0_1: torch.Size([8, 32, 96, 96])
        # 梳理清楚一个关键点即可,后面依次类推,可以打印结果自己手动推一下
        x2_0 = self.conv2_0(self.pool(x1_0))
        print('x2_0:', x2_0.shape)
        x1_1 = self.conv1_1(torch.cat([x1_0, self.up(x2_0)], 1))
        print('x1_1:',x1_1.shape)
        x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up(x1_1)], 1))
        print('x0_2:',x0_2.shape)

        x3_0 = self.conv3_0(self.pool(x2_0))
        print('x3_0:',x3_0.shape)
        x2_1 = self.conv2_1(torch.cat([x2_0, self.up(x3_0)], 1))
        print('x2_1:',x2_1.shape)
        x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.up(x2_1)], 1))
        print('x1_2:',x1_2.shape)
        x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.up(x1_2)], 1))
        print('x0_3:',x0_3.shape)
        x4_0 = self.conv4_0(self.pool(x3_0))
        print('x4_0:',x4_0.shape)
        x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1))
        print('x3_1:',x3_1.shape)
        x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.up(x3_1)], 1))
        print('x2_2:',x2_2.shape)
        x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.up(x2_2)], 1))
        print('x1_3:',x1_3.shape)
        x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.up(x1_3)], 1))
        print('x0_4:',x0_4.shape)

        if self.deep_supervision:
            output1 = self.final1(x0_1)
            output2 = self.final2(x0_2)
            output3 = self.final3(x0_3)
            output4 = self.final4(x0_4)
            return [output1, output2, output3, output4]

        else:
            # 输出一个结果,结果是0~1之间
            output = self.final(x0_4)
            return output

archs解读
在这里插入图片描述
在这里插入图片描述
conv00代表着图中的X00,conv20代表图中的X20,以此类推。
每一个vggblock
在这里插入图片描述

3.3 结果

在这里插入图片描述

4. 参考链接

深度学习分割任务——Unet++分割网络代码详细解读

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1818765.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【电机】开环控制系统和闭环控制系统

1 什么是控制系统 控制系统是指由控制主体、控制客体和控制媒体组成的具有自身目标和功能的管理系统。也可以理解为:为了使控制对象达到预期的稳定状态。例如一个水箱的温度控制,可以通过控制加热设备输出的功率进而来改变水温达到目标温度,…

Linux发邮件的工具推荐有哪些?如何配置?

Linux发邮件的功能怎么样?Linux系统如何设置服务器? 在Linux操作系统中,有多种工具可供选择用来发送电子邮件,每种工具都有其独特的特点和适用场景。AokSend将介绍几种常用的Linux发邮件工具,并分析它们的优缺点和适用…

接口自动化测试的全面解析与实战指南!

🚀 【引言】🚀 接口自动化测试,作为现代软件开发生命周期中的关键一环,扮演着“质量守门员”的角色。它不仅关乎提升开发速度,更在于确保每一次更新都能可靠地满足用户期待。接下来,我们将踏上一场深入浅出…

Redis分布式锁的实现、优化与Redlock算法探讨

Redis分布式锁最简单的实现 要实现分布式锁,首先需要Redis具备“互斥”能力,这可以通过SETNX命令实现。SETNX表示SET if Not Exists,即如果key不存在,才会设置它的值,否则什么也不做。利用这一点,不同客户端就能实现互斥,从而实现一个分布式锁。 举例: 客户端1申请加…

RH850---注意问题积累--1

硬件规格(引脚分配,内存映射,外设功能规格、电气特性、时序图)和操作说明 注意:有关使用的详细信息,请参阅应用说明 ---------外围函数。。。 1:存储指令完成与后续同步指令的一代 当控制寄存器被存储指令更新时,从存储的执行开始…

在网站建设时,如何选择适合自己的网站模版

可以根据以下几个地方选择适合的网站模板 1.公司的核心业务 根据公司的业务内容来确定网站展示的内容之一,不同的业务内容可以有不同的展示方式,以此来确定网站的展示风格之一,公司肯定是要有明确的业务内容,并且能够在网站…

[C#]winform使用onnxruntime部署LYT-Net轻量级低光图像增强算法

【训练源码】 https://github.com/albrateanu/LYT-Net 【参考源码】 https://github.com/hpc203/Low-Light-Image-Enhancement-onnxrun 【算法介绍】 一、研究动机 1.研究目标 研究的目标是提出一种轻量级的基于YUV Transformer 的网络(LYT-Net)&…

neo4j-官网学习

1、cypher 代码学习文档 https://neo4j.com/docs/cypher-cheat-sheet/5/auradb-enterprise 2、APOC函数包安装(desktop) 直接点击就可以安装,安装完之后重启一下,Cypher查询中使用CALL apoc.help(‘apoc’)来检查APOC插件是否已…

Java技术驱动的工程项目管理系统源码:工程管理的数字化解决方案

工程项目管理系统是一款基于Java技术的专业工程管理软件,它采用了Spring Cloud、Spring Boot、Mybatis、Vue和ElementUI等前沿技术,通过前后端分离架构构建了一个功能全面的工程项目管理系统。 随着公司的发展,工程管理的需求日益增长&#x…

图像处理与视觉感知复习--彩色图像处理

文章目录 三原色原理及其两种应用常用彩色模型及其应用领域各种颜色模型的转换彩色图像处理 三原色原理及其两种应用 三基色原理 自然界中绝大多数的颜色都可看作是由红、绿、蓝三种颜色组合而成;自然界中的绝大多数的颜色都可以分解成红、绿、蓝这三种颜色。这即…

渗透测试模拟实战-tomexam网络考试系统

渗透测试,也称为“pentest”或“道德黑客”,是一种模拟攻击的网络安全评估方法,旨在识别和利用系统中的安全漏洞。这种测试通常由专业的安全专家执行,他们使用各种技术和工具来尝试突破系统的防御,如网络、应用程序、主…

【PyQt5】简要介绍

文章目录 一、PyQt5的简介、安装、配置1.1 简介1.2 安装与配置1.3 QtDesigner1.3.1 基础操作 二、PyQt5的基本控件(Widget Box)2.1 基类(QWidget)2.1.1 QWidget 2.2 Button类(属于QtWidgets:QPushButton&am…

轮到国产游戏统治Steam榜单

6月10日晚8点,《黑神话:悟空》实体版正式开启全款预售,预售开启不到5分钟,所有产品即宣告售罄。 Steam上,《黑神话:悟空》持续占据着热销榜榜首的位置。 但在《黑神话:悟空》傲人的光环下,还有一款国产游戏取得出色的成绩。 6月10日&#…

vue3+ Element-Plus 点击勾选框往input中动态添加多个tag

实现效果&#xff1a; template&#xff1a; <!--产品白名单--><div class"con-item" v-if"current 0"><el-form-item label"平台名称"><div class"contaion" click"onclick"><!-- 生成的标签 …

六西格玛培训都培训哪些内容 ?

天行健六西格玛培训的内容通常涵盖多个方面&#xff0c;旨在帮助学员全面理解和应用六西格玛管理方法。以下是详细的培训内容概述&#xff1a; 一、六西格玛基础知识 引入六西格玛的概念、原理和历史&#xff0c;包括DMAIC&#xff08;定义、测量、分析、改进、控制&#xff0…

轻松搭建,一键开发,MemFire Cloud:懒人开发者的创意神器

在如今快节奏的时代&#xff0c;对于开发者来说&#xff0c;时间就是金钱。但是&#xff0c;要想快速开发一个应用&#xff0c;却需要花费大量时间搭建服务、开发接口、集成认证等等&#xff0c;这无疑增加了开发者的工作负担。但现在&#xff0c;有了MemFire Cloud&#xff0c…

Petalinux由于网络原因产生的编译错误(2)--Fetcher failure:Unable to find file

1 Fetcher failure:Unable to find file 错误 如果编译工程遇到如下图所示的“Fetcher failure for URL”或相似错误 出现这种错误的原因是 Petalinux 在配置和编译的时候&#xff0c;需要联网下载一些文件&#xff0c;由于网 络原因这些文件不能正常下载&#xff0c;导致编译…

分享一些经典的国外二维码活动案例,含二维码制作技巧

二维码具有信息容量大、编码范围广、自由度高、容错能力强、保密性、防伪性好、译码可靠性高等多项优势&#xff0c;所以二维码应用极其广泛&#xff0c;它作为一种及时、准确、可靠、经济的数据输入手段&#xff0c;已在工业、商业、交通、金融、医疗卫生、办公自动化等许多领…

Python发送Outlook邮件的步骤流程有哪些?

Python发送Outlook邮件的技巧&#xff1f;如何使用Python发信&#xff1f; 在Python中使用SMTP协议发送邮件到Outlook邮箱是一项常见的任务。AokSend将介绍如何通过Python编程语言实现这一过程&#xff0c;从准备工作到实际发送邮件的具体步骤。 Python发送Outlook邮件&#…

Beyond Compare 提示“缺少评估信息或损坏”,无法打开只要操作一行命令就可以了

在CMD 或者powershell下执行如下命令重新打开即可。 reg delete "HKEY_CURRENT_USER\Software\Scooter Software\Beyond Compare 4" /v CacheID /f重新打开&#xff0c;就ok 了