【ICCV2023论文阅读】XNet(能跑通代码)

news2025/1/2 4:10:31

这里写目录标题

  • 论文阅读
    • 摘要
    • 介绍
    • 方法
      • overview
      • why use wavelet transform?
      • 融合方法
      • 用于全监督分割和半监督分割可行性分析
    • 效果
    • 局限性
    • 总结
  • 代码跑通
    • 去掉分布式训练
    • 生成低频和高频图片
    • 产生数据集
    • 改读取数据的位置
    • 损失函数
    • 添加自己数据集的信息
    • 结果

ps:我现在不知道自己研究方向是做什么的,就是分割也试试,医疗诊断也试试。然后之后更的尽量把代码跑通也写上。因为之前代码水平有限不能很好的跑通,然后我只是说我这个数据集怎么改,以及我这个硬件水平下,你们看着改就可以。
论文地址
代码地址

论文阅读

摘要

现状

  1. 把全监督分割和半监督当作两种训练方式,很少有把它们统一起来的。(本文就把这两个统一起来了,就是个创新点)
  2. 很少有完全监督的模型关注图像的固有的低频信息和高频信息去提升性能。
  3. 半监督学习的扰动是人为添加的,可能引入不利的学习偏见。
    方法
    提出了一种基于小波的LF和HF融合模型XNet,它支持全监督和半监督语义分割,并在这两个领域都优于最先进的模型。

介绍

启发:对于语义分割问题,HF信息通常表示图像细节,LF信息通常是抽象与一。提取和融合不同频率信息的策略可以帮助模型更好地关注LF予以和HF细节,以提高性能。模型使用小波变换生成LF和HF图像,用于基于一致性差分的半监督学习。这些一致性差异源于对LF和HF信息的不同关注,这缓解了人工设计造成的学习偏差。
contributions:

  1. 提出了低频和高频融合模型XNet,在有监督和半监督上实现了优异的性能。
  2. XNet使用小波变化生成LF和HF图片来进行一致性学习,可以减轻人为扰动引起的学习偏差。
  3. 在两个2D和两个3D公共生物医学数据集上进行的广泛基准测试证实了XNet的有效性。

方法

在这里插入图片描述

overview

获取相应的LF和HF图片。然后将它们输入到LF和HF编码器以分别生成LF(语义)和HF(边缘、纹理)特征。之后使用融合模块对他们的特征进行融合。然后把融合特征放到解码器中获得LF和HF分支的预测结果。全监督损失是监督损失(两个分支的预测和真实值之间的损失,记为 L s u p L_{sup} Lsup)和标记图像的一致性损失(记为 L u n s u p L_{unsup} Lunsup)。半监督训练,最大限度减少标记图像的监督损失和未标记图像的双重输出的一致性损失。都是dice loss。
L u n s u p L_{unsup} Lunsup是由交叉伪标签监督损失实现,用一个分支的预测作为伪标签去监督另一个分支。 L u n s u p = L u n s u p H ( p i L , p ^ i H ) + L u n s u p L ( p i H , p ^ i L ) L_{unsup}=L_{unsup}^{H}(p^L_{i},\hat{p}^H_{i})+L_{unsup}^{L}(p^H_{i},\hat{p}^L_{i}) Lunsup=LunsupH(piL,p^iH)+LunsupL(piH,p^iL)
我们选择在训练阶段表现更好的分支作为推理过程中的最终输出。

why use wavelet transform?

在这里插入图片描述
与其他方法(如傅立叶变换)相比,小波变换是生成L和H的有效方法。使用L作为输入,XNet可以更多地关注LF语义,因为L具有较少的噪声和细节。相比之下,H具有更多的噪声,但对象边界更清晰,这可以帮助模型更多地关注HF细节。此外,使用L和H进行半监督训练,一致性差异来自图像的固有LF和HF信息,这可以缓解人工扰动引起的学习偏差。

融合方法

在这里插入图片描述
LF和HF融合模块的架构。相同大小Conv表示输出和输入特征具有相同大小。下采样Conv将输出特征的大小减少一半。上采样Conv使输出特征的大小加倍。Transition Conv使用信道级联特征作为输入和输出融合特征。
就是LF Feature1是第4层的feature,它进行一次不改变大小的卷积得到第一个有花纹的蓝色块也就是特征,它进行一次下采样得到下面那个小的特征。LF Feature2是第5层的feature,它进行一次上采样得到横杠的看色特征,进行一次不改变大小的卷积得到方块特征。其他同理,结合之后进行卷积获得和原来LF Feature1相同大小的特征图,然后进行U-Net那个skip connect即可。

用于全监督分割和半监督分割可行性分析

对于生物医学图像,我们假设原始图像I由LF特征FL、HF特征FH、LF加性噪声NL和HF加性噪声NH组成。因此, I I I被定义为:
I = F L + F H + N L + N H I = F_L+F_H+N_L+N_H I=FL+FH+NL+NH
因为生物医学图像中的噪声通常是加性的。对于语义分割问题,准确的分割需要LF语义(如形状、颜色等)和HF细节(如边缘、纹理等)。
对于监督学习,对完整信息进行解码可以获得分割预测。对于半监督学习,由于每个解码分支对LF和HF信息的关注程度不同,因此双分支解码器的预测在LF语义和HF细节方面存在差异。这些差异可用于基于一致性规则的半监督训练。
总之,XNet既可以用于全监督学习,也可以用于半监督学习。图显示了XNet分割过程的拓扑流程图。
在这里插入图片描述

效果

在这里插入图片描述

局限性

由于XNet强调HF信息,当图像几乎没有HF信息时,XNet的性能会受到负面影响。

总结

我们提出了一种基于小波的低频和高频融合模型XNet,该模型在生物医学图像的全监督和半监督语义分割方面都取得了最先进的性能。在2D和3D数据集上进行的大量实验证明了我们提出的模型的有效性。然而,XNet的局限性在于,当高频信息不可用时,其性能可能会受到负面影响。我们认为,完全监督和半监督的语义分割模型可以而且应该是统一的。我们希望我们的研究能为它们的统一提供一些例证和思考。

代码跑通

完全可以按作者的那个readme对自己数据集进行修改跑通。
我不了解分布式训练,所以一直报错。下面展示不用分布式训练的代码。

去掉分布式训练

下面是我把有分布式训练的地方都删了。(应该问题不大吧。(lll¬ω¬))

from torchvision import transforms, datasets
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
from torch.utils.data import DataLoader
import argparse
import time
import os
import numpy as np
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.backends import cudnn
import random
from config.dataset_config.dataset_cfg import dataset_cfg

from config.train_test_config.train_test_config import print_train_loss_XNet, print_val_loss, print_train_eval_XNet, print_val_eval, save_val_best_2d, draw_pred_XNet, print_best
from config.visdom_config.visual_visdom import visdom_initialization_XNet, visualization_XNet, visual_image_XNet
from config.warmup_config.warmup import GradualWarmupScheduler
from config.augmentation.online_aug import data_transform_2d, data_normalize_2d
from loss.loss_function import segmentation_loss
from models.getnetwork import get_network
from dataload.dataset_2d import imagefloder_iitnn
from warnings import simplefilter

simplefilter(action='ignore', category=FutureWarning)




if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--path_trained_models', default='./checkpoints/sup_xnet')
    parser.add_argument('--path_seg_results', default='./seg_pred/sup_xnet')
    parser.add_argument('--path_dataset', default='自己数据集的根目录')
    parser.add_argument('--dataset_name', default='自己数据集的名称', help='CREMI, ISIC-2017, GlaS')
    parser.add_argument('--input1', default='L')
    parser.add_argument('--input2', default='H')
    parser.add_argument('--sup_mark', default='100')
    parser.add_argument('-b', '--batch_size', default=4, type=int)
    parser.add_argument('-e', '--num_epochs', default=200, type=int)
    parser.add_argument('-s', '--step_size', default=50, type=int)
    parser.add_argument('-l', '--lr', default=0.5, type=float)
    parser.add_argument('-g', '--gamma', default=0.5, type=float)
    parser.add_argument('-u', '--unsup_weight', default=5, type=float)
    parser.add_argument('--loss', default='dice', type=str)
    parser.add_argument('-w', '--warm_up_duration', default=20)
    parser.add_argument('--momentum', default=0.9, type=float)
    parser.add_argument('--wd', default=-5, type=float, help='weight decay pow')

    parser.add_argument('-i', '--display_iter', default=5, type=int)
    parser.add_argument('-n', '--network', default='xnet', type=str)
    parser.add_argument('--local_rank', default=-1, type=int)

    args = parser.parse_args()


    dataset_name = args.dataset_name
    cfg = dataset_cfg(dataset_name)

    print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14
    print_num_minus = print_num - 2
    print_num_half = int(print_num / 2 - 1)

    path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1])


    path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1])

    # Dataset
    if args.input1 == 'image':
        input1_mean = 'MEAN'
        input1_std = 'STD'
    else:
        input1_mean = 'MEAN_' + args.input1
        input1_std = 'STD_' + args.input1

    if args.input2 == 'image':
        input2_mean = 'MEAN'
        input2_std = 'STD'
    else:
        input2_mean = 'MEAN_' + args.input2
        input2_std = 'STD_' + args.input2

    data_transforms = data_transform_2d()
    data_normalize_1 = data_normalize_2d(cfg[input1_mean], cfg[input1_std])
    data_normalize_2 = data_normalize_2d(cfg[input2_mean], cfg[input2_std])

    dataset_train = imagefloder_iitnn(
        data_dir=args.path_dataset+'/train',
        input1=args.input1,
        input2=args.input2,
        data_transform_1=data_transforms['train'],
        data_normalize_1=data_normalize_1,
        data_normalize_2=data_normalize_2,
        sup=True,
        num_images=None,
    )
    dataset_val = imagefloder_iitnn(
        data_dir=args.path_dataset + '/val',
        input1=args.input1,
        input2=args.input2,
        data_transform_1=data_transforms['val'],
        data_normalize_1=data_normalize_1,
        data_normalize_2=data_normalize_2,
        sup=True,
        num_images=None,
    )

    dataloaders = dict()
    dataloaders['train'] = DataLoader(dataset_train, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8)
    dataloaders['val'] = DataLoader(dataset_val, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8)

    num_batches = {'train_sup': len(dataloaders['train']), 'val': len(dataloaders['val'])}

    # Model
    model = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])
    model = model.cuda()

    # Training Strategy
    criterion = segmentation_loss(args.loss, False).cuda()

    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5*10**args.wd)
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)
    scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler)

    # Train & Val
    since = time.time()
    count_iter = 0

    best_model = model
    best_result = 'Result1'
    best_val_eval_list = [0 for i in range(4)]

    for epoch in range(args.num_epochs):

        count_iter += 1
        if (count_iter - 1) % args.display_iter == 0:
            begin_time = time.time()
        model.train()

        train_loss_sup_1 = 0.0
        train_loss_sup_2 = 0.0
        train_loss_unsup = 0.0
        train_loss = 0.0
        val_loss_sup_1 = 0.0
        val_loss_sup_2 = 0.0

        unsup_weight = args.unsup_weight * (epoch + 1) / args.num_epochs

        # dist.barrier()

        for i, data in enumerate(dataloaders['train']):

            inputs_train_1 = Variable(data['image'].cuda())
            inputs_train_2 = Variable(data['image_2'].cuda())
            mask_train = Variable(data['mask'].cuda())

            optimizer.zero_grad()
            outputs_train1, outputs_train2 = model(inputs_train_1, inputs_train_2)
            torch.cuda.empty_cache()

            if count_iter % args.display_iter == 0:
                if i == 0:
                    score_list_train1 = outputs_train1
                    score_list_train2 = outputs_train2
                    mask_list_train = mask_train
                # else:
                elif 0 < i <= num_batches['train_sup'] / 4:
                    score_list_train1 = torch.cat((score_list_train1, outputs_train1), dim=0)
                    score_list_train2 = torch.cat((score_list_train2, outputs_train2), dim=0)
                    mask_list_train = torch.cat((mask_list_train, mask_train), dim=0)

            max_train1 = torch.max(outputs_train1, dim=1)[1]
            max_train2 = torch.max(outputs_train2, dim=1)[1]
            max_train1 = max_train1.long()
            max_train2 = max_train2.long()

            loss_train_sup1 = criterion(outputs_train1, mask_train)
            loss_train_sup2 = criterion(outputs_train2, mask_train)
            loss_train_unsup = criterion(outputs_train1, max_train2) + criterion(outputs_train2, max_train1)
            loss_train_unsup = loss_train_unsup * unsup_weight
            loss_train = loss_train_sup1 + loss_train_sup2 + loss_train_unsup

            loss_train.backward()
            optimizer.step()

            train_loss_sup_1 += loss_train_sup1.item()
            train_loss_sup_2 += loss_train_sup2.item()
            train_loss_unsup += loss_train_unsup.item()
            train_loss += loss_train.item()

        scheduler_warmup.step()
        # torch.cuda.empty_cache()
        if count_iter % args.display_iter == 0:

            print('=' * print_num)
            print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|')
            train_epoch_loss_sup1, train_epoch_loss_sup2, train_epoch_loss_cps, train_epoch_loss = print_train_loss_XNet(
                train_loss_sup_1, train_loss_sup_2, train_loss_unsup, train_loss, num_batches, print_num,
                print_num_half)
            # print(score_list_train1)
            # print(score_list_train2)
            train_eval_list1, train_eval_list2, train_m_jc1, train_m_jc2 = print_train_eval_XNet(cfg['NUM_CLASSES'], score_list_train1, score_list_train2, mask_list_train, print_num_half)
            torch.cuda.empty_cache()

            with torch.no_grad():
                model.eval()

                for i, data in enumerate(dataloaders['val']):

                    # if 0 <= i <= num_batches['val']:

                    inputs_val = Variable(data['image'].cuda())
                    inputs_val_wavelet = Variable(data['image_2'].cuda())
                    mask_val = Variable(data['mask'].cuda())
                    name_val = data['ID']

                    optimizer.zero_grad()
                    outputs_val1, outputs_val2 = model(inputs_val, inputs_val_wavelet)
                    torch.cuda.empty_cache()

                    if i == 0:
                        score_list_val1 = outputs_val1
                        score_list_val2 = outputs_val2
                        mask_list_val = mask_val
                        name_list_val = name_val
                    else:
                        score_list_val1 = torch.cat((score_list_val1, outputs_val1), dim=0)
                        score_list_val2 = torch.cat((score_list_val2, outputs_val2), dim=0)
                        mask_list_val = torch.cat((mask_list_val, mask_val), dim=0)
                        name_list_val = np.append(name_list_val, name_val, axis=0)

                    loss_val_sup1 = criterion(outputs_val1, mask_val)
                    loss_val_sup2 = criterion(outputs_val2, mask_val)

                    val_loss_sup_1 += loss_val_sup1.item()
                    val_loss_sup_2 += loss_val_sup2.item()
                torch.cuda.empty_cache()

                val_epoch_loss_sup1, val_epoch_loss_sup2 = print_val_loss(val_loss_sup_1, val_loss_sup_2,
                                                                          num_batches, print_num, print_num_half)
                val_eval_list1, val_eval_list2, val_m_jc1, val_m_jc2 = print_val_eval(cfg['NUM_CLASSES'],
                                                                                      score_list_val1,
                                                                                      score_list_val2,
                                                                                      mask_list_val, print_num_half)
                best_val_eval_list, best_model, best_result = save_val_best_2d(cfg['NUM_CLASSES'], best_model,
                                                                               best_val_eval_list, best_result,
                                                                               model, model, score_list_val1,
                                                                               score_list_val2, name_list_val,
                                                                               val_eval_list1, val_eval_list2,
                                                                               path_trained_models,
                                                                               path_seg_results, cfg['PALETTE'])

                torch.cuda.empty_cache()
        torch.cuda.empty_cache()

生成低频和高频图片

里面有一个wavelet2D.py(我是2D图片)。运行即可。

import numpy as np
from PIL import Image
import pywt
import argparse
import os

if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--image_path', default='自己数据的位置')
   # parser.add_argument('--mask_path', default='')
    parser.add_argument('--L_path', default='自己保存低频图片的位置')
    parser.add_argument('--H_path', default='自己保存高频图片的位置')
    parser.add_argument('--wavelet_type', default='db2', help='haar, db2, bior1.5, bior2.4, coif1, dmey')
    parser.add_argument('--if_RGB', default=False)
    args = parser.parse_args()

    if not os.path.exists(args.L_path):
        os.mkdir(args.L_path)
    if not os.path.exists(args.H_path):
        os.mkdir(args.H_path)

    for i in os.listdir(args.image_path):
        image_path = os.path.join(args.image_path, i)
        L_path = os.path.join(args.L_path, i)
        H_path = os.path.join(args.H_path, i)

        if args.if_RGB:
            image = Image.open(image_path).convert('L')
        else:
            image = Image.open(image_path)
        image = np.array(image)

        LL, (LH, HL, HH) = pywt.dwt2(image, args.wavelet_type)

        LL = (LL - LL.min()) / (LL.max() - LL.min()) * 255

        LL = Image.fromarray(LL.astype(np.uint8))
        LL.save(L_path)

        LH = (LH - LH.min()) / (LH.max() - LH.min()) * 255
        HL = (HL - HL.min()) / (HL.max() - HL.min()) * 255
        HH = (HH - HH.min()) / (HH.max() - HH.min()) * 255

        merge1 = HH + HL + LH
        merge1 = (merge1-merge1.min()) / (merge1.max()-merge1.min()) * 255

        merge1 = Image.fromarray(merge1.astype(np.uint8))
        merge1.save(H_path)

产生数据集

我是这样的所以我读取数据集的时候还得改。大家也可以按照这个项目的readme中的那个文件对我下面这个产生数据集代码中的路径进行修改。
dataset
├── train
├── L
├── 1.png
├── 2.png
└── …
├── H
├── 1.png
├── 2.png
└── …
└── mask
├── 1.png
├── 2.png
└── …
└── val
├── L
├── H
└── mask

import os
import argparse
import random
import shutil
from shutil import copyfile



def rm_mkdir(dir_path):
    if os.path.exists(dir_path):
        shutil.rmtree(dir_path)
        print('Remove path - %s' % dir_path)
    os.makedirs(dir_path)
    print('Create path - %s' % dir_path)


def main(config):
    rm_mkdir(os.path.join(config.train_path,'H'))
    rm_mkdir(os.path.join(config.train_path,'L'))
    rm_mkdir(os.path.join(config.train_path, 'mask'))
    rm_mkdir(os.path.join(config.valid_path, 'H'))
    rm_mkdir(os.path.join(config.valid_path, 'L'))
    rm_mkdir(os.path.join(config.valid_path, 'mask'))
    H_path = os.path.join(config.origin_data_path, 'H')

    H_filenames = os.listdir(H_path)
    data_list = []

    for filename in H_filenames:
        ext = os.path.splitext(filename)[-1]
        if ext == '.png':
            filename = os.path.basename(filename)
            data_list.append(filename)

    num_total = len(data_list)
    num_train = int((config.train_ratio / (config.train_ratio + config.valid_ratio )) * num_total)
    num_valid = int((config.valid_ratio / (config.train_ratio + config.valid_ratio )) * num_total)

    print('\nNum of train set : ', num_train)
    print('\nNum of valid set : ', num_valid)

    Arange = list(range(num_total))
    random.shuffle(Arange)

    for i in range(num_train):
        idx = Arange.pop()

        src = os.path.join(config.origin_data_path,'H', data_list[idx])
        dst = os.path.join(config.train_path,'H', data_list[idx])
        copyfile(src, dst)

        src = os.path.join(config.origin_data_path, 'L', data_list[idx])
        dst = os.path.join(config.train_path, 'L', data_list[idx])
        copyfile(src, dst)

        src = os.path.join(config.origin_data_path, 'mask_', data_list[idx])
        dst = os.path.join(config.train_path, 'mask', data_list[idx])
        copyfile(src, dst)


    for i in range(num_valid):
        idx = Arange.pop()

        src = os.path.join(config.origin_data_path, 'H', data_list[idx])
        dst = os.path.join(config.valid_path, 'H', data_list[idx])
        copyfile(src, dst)

        src = os.path.join(config.origin_data_path, 'L', data_list[idx])
        dst = os.path.join(config.valid_path, 'L', data_list[idx])
        copyfile(src, dst)

        src = os.path.join(config.origin_data_path, 'mask', data_list[idx])
        dst = os.path.join(config.valid_path, 'mask', data_list[idx])
        copyfile(src, dst)




if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # model hyper-parameters
    parser.add_argument('--train_ratio', type=float, default=0.8)#训练集和测试集的比例
    parser.add_argument('--valid_ratio', type=float, default=0.2)

    # data path
    parser.add_argument('--origin_data_path', type=str, default='自己数据的位置')
    parser.add_argument('--train_path', type=str, default='./train/')#自己要保存的训练集和测试集的位置←↓
    parser.add_argument('--valid_path', type=str, default='./val/')

    config = parser.parse_args()
    print(config)
    main(config)

改读取数据的位置

main.py中他原来是train_sup100,我用的是train文件夹。所以dataloader的参数要改。

dataset_train = imagefloder_iitnn(
        data_dir=args.path_dataset+'/train',
        input1=args.input1,
        input2=args.input2,
        data_transform_1=data_transforms['train'],
        data_normalize_1=data_normalize_1,
        data_normalize_2=data_normalize_2,
        sup=True,
        num_images=None,
    )
    dataset_val = imagefloder_iitnn(
        data_dir=args.path_dataset + '/val',
        input1=args.input1,
        input2=args.input2,
        data_transform_1=data_transforms['val'],
        data_normalize_1=data_normalize_1,
        data_normalize_2=data_normalize_2,
        sup=True,
        num_images=None,
    )

还有dataset_2d.py中的,也有train_sup100好像也改了。具体的忘了。

class dataset_iitnn(Dataset):
    
    def __init__(self, data_dir, input1, input2, augmentation1, normalize_1, normalize_2, sup=True,
                 num_images=None, **kwargs):
        super(dataset_iitnn, self).__init__()

        img_paths_1 = []
        img_paths_2 = []
        mask_paths = []

        image_dir_1 = data_dir + '/' + input1
        image_dir_2 = data_dir + '/' + input2
        if sup:
            mask_dir = data_dir + '/mask'

损失函数

我数据集是只有一个类别。

class DiceLoss(nn.Module):
    """Dice loss, need one hot encode input"""

    def __init__(self, weight=None, aux=False, aux_weight=0.4, ignore_index=-1, **kwargs):
        super(DiceLoss, self).__init__()
        self.kwargs = kwargs
        self.weight = weight
        self.ignore_index = ignore_index
        self.aux = aux
        self.aux_weight = aux_weight

    def _base_forward(self, predict, target, valid_mask):

        dice = BinaryDiceLoss(**self.kwargs)
        total_loss = 0
        predict = F.softmax(predict, dim=1)


        for i in range(target.shape[-1]):
            if i != self.ignore_index:
                dice_loss = dice(predict, target, valid_mask)#这里只有一个类别的把[i,:]删了,不然会报错因为超出范围
                if self.weight is not None:
                    assert self.weight.shape[0] == target.shape[1], \
                        'Expect weight shape [{}], get[{}]'.format(target.shape[1], self.weight.shape[0])
                    dice_loss *= self.weights[i]
                total_loss += dice_loss

        return total_loss / target.shape[-1]

    def _aux_forward(self, output, target, **kwargs):
        # *preds, target = tuple(inputs)
        valid_mask = (target != self.ignore_index).long()
        target_one_hot = F.one_hot(torch.clamp_min(target, 0))
        loss = self._base_forward(output[0], target_one_hot, valid_mask)
        for i in range(1, len(output)):
            aux_loss = self._base_forward(output[i], target_one_hot, valid_mask)
            loss += self.aux_weight * aux_loss
        return loss

    def forward(self, output, target):
        # preds, target = tuple(inputs)
        # inputs = tuple(list(preds) + [target])
        if self.aux:
            return self._aux_forward(output, target)
        else:
            valid_mask = (target != self.ignore_index).long()
            # target_one_hot = F.one_hot(torch.clamp_min(target, 0))
            # target_one_hot = F.one_hot(torch.clamp_min(target, 0))#这个注释掉
            return self._base_forward(output, target, valid_mask)#把target_one_hot改成target

添加自己数据集的信息

在/config/dataset_config/dataset_cfg.py中。
添加自己数据集的信息。我的理解。下面是求相关数据的程序。

'Data_one':
            {
                'IN_CHANNELS': 1,#单通道的
                'NUM_CLASSES': 1,
                'MEAN': [0.1612872],
                'STD': [0.1612872],
                'MEAN_H': [0.44275072],
                'STD_H': [0.44275072],
                'MEAN_L': [0.21374299],
                'STD_L': [0.22170983],
                'PALETTE': list(np.array([
                    [255, 255, 255],
                ]).flatten())
            },
import cv2
import numpy as np
import os

def compute_mean_std(dataset_path):
    # 初始化累积器
    mean_accumulator = np.zeros(3)
    std_accumulator = np.zeros(3)
    total_samples = 0

    # 遍历数据集
    for image_file in os.listdir(dataset_path):
        if image_file.endswith(".jpg") or image_file.endswith(".png"):
            image_path = os.path.join(dataset_path, image_file)

            # 读取图像
            image = cv2.imread(image_path)
            image = image / 255.0  # 将像素值缩放到 [0, 1]

            # 计算均值和标准差
            mean_accumulator += np.mean(image, axis=(0, 1))
            std_accumulator += np.std(image, axis=(0, 1))
            total_samples += 1

    # 计算平均值
    mean_values = mean_accumulator / total_samples

    # 计算标准差
    std_values = std_accumulator / total_samples

    return mean_values, std_values


# 示例用法
dataset_path = "自己要求平均值、方差的数据的位置"#L,H,
#mask_path = ""
mean_values, std_values = compute_mean_std(dataset_path)

print("MEAN:", mean_values)
print("STD:", std_values)

结果

暂时跑出来是这样的。要是有问题之后会更新。大家也可以调调错误。感谢。
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1275329.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

<JavaEE> 什么是线程安全?产生线程不安全的原因和处理方式

目录 一、线程安全的概念 二、线程不安全经典示例 三、线程不安全的原因和处理方式 3.1 线程的随机调度和抢占式执行 3.2 修改共享数据 3.3 关键代码或指令不是“原子”的 3.4 内存可见性和指令重排序 四、Java标准库自带的线程安全类 一、线程安全的概念 线程安全是指…

WebGL开发交互式艺术品技术方案

开发交互式艺术品需要使用 WebGL 技术&#xff0c;并结合其他前端技术以实现丰富的用户体验。以下是一个可能的技术方案&#xff0c;希望对大家有所帮助。北京木奇移动技术有限公司&#xff0c;专业的软件外包开发公司&#xff0c;欢迎交流合作。 1.WebGL 框架&#xff1a; 选…

业余爱好-社会工程管理记账报税

税务问题笔记 印花税税费申报及缴纳财务和行为税合并纳税申报增值税及附加税费申报企业所得税季度A类申报残疾人就业保障金申报财务报表个税申报 印花税 印花税是对在经济活动和经济交往中书立、领受具有法律效力的凭证的行为征收的一种税。 税费申报及缴纳 财务和行为税合并…

String类 ---java

目录 一. 常用的字符串的构造 二. 字符串的源代码 三. 字符串比较 1. 是不能比较字符串的值的 ​编辑 2.比较两个字符串 --- compareTo() 3. 忽略大小写比较 ---compareToIgnoreCase() 四. 字符串转化 1. 数字转字符串 valueOf() 2. 字符串转数字 3. 小写转大写 to…

【C指针】深入理解指针(最终篇)数组指针指针运算题解析(一)

&#x1f308;write in front :&#x1f50d;个人主页 &#xff1a; 啊森要自信的主页 ✏️真正相信奇迹的家伙&#xff0c;本身和奇迹一样了不起啊&#xff01; 欢迎大家关注&#x1f50d;点赞&#x1f44d;收藏⭐️留言&#x1f4dd;>希望看完我的文章对你有小小的帮助&am…

如何解决SSL证书部署后未生效或网站显示不安全

本文介绍SSL证书部署后未生效或网站显示不安全的排查方法。 浏览器提示“您与此网站建立的连接不安全” 浏览器提示“无法访问此页面” 浏览器提示“这可能是因为站点使用过期或者不全的TLS安全设置” 浏览器提示“此页面上部分内容不安全&#xff08;例如图像&#xff09;”…

LeetCode刷题---汉诺塔问题

个人主页&#xff1a;元清加油_【C】,【C语言】,【数据结构与算法】-CSDN博客 前言&#xff1a;这个专栏主要讲述递归递归、搜索与回溯算法&#xff0c;所以下面题目主要也是这些算法做的 我讲述题目会把讲解部分分为3个部分&#xff1a; 1、题目解析 2、算法原理思路讲解 …

c++ pcl出现LNK2019 宏定义 PCL_NO_PRECOMPILE

问题&#xff1a;c pcl使用拟合圆柱时出现LNK2019问题&#xff1b; 说明&#xff1a;lib等配置没有问题&#xff1b; 解决方案 在上述代码中添加如下代码即可 #define PCL_NO_PRECOMPILE 是 C 中的预处理器指令&#xff0c;用于在代码中定义一个宏。而 #undef PCL_NO_PRECOM…

【数电笔记】基本和复合逻辑运算

说明&#xff1a; 笔记配套视频来源&#xff1a;B站 基本逻辑运算 1. 与运算 &#xff08;and gate&#xff09; 2. 或运算 &#xff08;or gate&#xff09; 3. 非运算 &#xff08;not gate &#xff09; 复合逻辑运算 1. 与非运算&#xff08;nand&#xff09; 2. 或非运…

LeetCode Hot100 287.寻找重复数

题目&#xff1a; 给定一个包含 n 1 个整数的数组 nums &#xff0c;其数字都在 [1, n] 范围内&#xff08;包括 1 和 n&#xff09;&#xff0c;可知至少存在一个重复的整数。 假设 nums 只有 一个重复的整数 &#xff0c;返回 这个重复的数 。 你设计的解决方案必须 不修…

基于YOLOv8深度学习的钢材表面缺陷检测系统【python源码+Pyqt5界面+数据集+训练代码】目标检测、深度学习实战

《博主简介》 小伙伴们好&#xff0c;我是阿旭。专注于人工智能、AIGC、python、计算机视觉相关分享研究。 ✌更多学习资源&#xff0c;可关注公-仲-hao:【阿旭算法与机器学习】&#xff0c;共同学习交流~ &#x1f44d;感谢小伙伴们点赞、关注&#xff01; 《------往期经典推…

Python中进行特征重要性分析的8个常用方法

更多资料获取 &#x1f4da; 个人网站&#xff1a;ipengtao.com 在机器学习和数据科学领域&#xff0c;理解特征在模型中的重要性对于构建准确且可靠的预测模型至关重要。Python提供了多种强大的工具和技术&#xff0c;能够探索特征重要性的各个方面。 本文将详细介绍8种常用…

长沙上市公司董秘联谊会首次活动,到底讲了什么?

“首场活动就这么提振士气&#xff0c;一场干货满满的头脑风暴&#xff0c;真的太荣幸加入这个集体了。”这是可孚医疗科技股份有限公司副总裁、董秘薛小桥参加长沙上市公司董秘联谊会首次活动后的感慨。 11月29日&#xff0c;长沙上市公司董秘联谊会首场活动在爱尔眼科全球总…

【Excel】WPS快速按某列查重数据

查重值 excel列几条数据肉眼可见&#xff0c;如何千万级别数据查验呢&#xff1f;平时很少用&#xff0c;记录一下: 先框选列要验证的数据&#xff0c;然后&#xff1a;开始->条件格式->突出显示单元格规则->重复值 效果

Service的双向跨进程通信

一、客户端向服务端通信。 1、创建AIDL文件&#xff0c;用于生成跨进程通信代码。 // ITestService.aidl package com.example.servicetest;interface ITestService {void sayHello(); } 2、创建服务端Service&#xff0c;添加如下代码。 public class TestService extends…

基于springboot+vue的点餐系统(前后端分离)

博主主页&#xff1a;猫头鹰源码 博主简介&#xff1a;Java领域优质创作者、CSDN博客专家、公司架构师、全网粉丝5万、专注Java技术领域和毕业设计项目实战 主要内容&#xff1a;毕业设计(Javaweb项目|小程序等)、简历模板、学习资料、面试题库、技术咨询 文末联系获取 项目介绍…

java+springboot学生宿舍公寓管理系统xueshenggongy

经过查阅资料和调查统计发现&#xff0c;高校学生宿舍管理工作变得越来越繁重和琐碎&#xff0c;如在学生住宿安排&#xff08;特别是新生住宿安排&#xff09;、宿舍大幅调换、公共设施统计维护、宿舍杂费统计收取、宿舍卫生管理统计、出入登记记录等各个方法存在着大量问题和…

【Java8系列06】Java8数据计算

&#x1f49d;&#x1f49d;&#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

设计模式之美学习笔记-单例模式-为什么说支持懒加载的双重检测不比饿汉式更优?

单例设计模式&#xff1a;一个类只允许创建一个对象&#xff08;或者实例&#xff09;&#xff0c;那这个类就是一个单例类&#xff0c;这种设计模式就叫作单例设计模式&#xff0c;简称单例模式。 实战案例一&#xff1a;处理资源访问冲突 我们先来看第一个例子。在这个例子…

近期知识点随笔

菜单查询&#xff08;编写权限时的细节&#xff09; 菜单查询list为了侧边框展示更完整&#xff08;不报空指针&#xff09; 登录时&#xff08;用户名&#xff09;查询出多个结果&#xff08;保证用户名唯一&#xff09; 文件上传 前端 对权限与菜单绑定的修改&#xff08;实…