EIS-Net

news2024/11/24 5:05:54

我们提出了一种新的领域泛化框架(称为EISNet),该框架利用来自多源领域图像的外在关系监督和内在自我监督学习,学习如何同时在不同领域中进行泛化。

  • 具体而言,我们采用多任务学习范式,通过特征嵌入来构建我们的框架。
  • 除了进行常规的监督识别任务外,我们无缝地集成了动量度量学习任务和自监督辅助任务,以共同整合外在和内在监督。
  • 此外,我们还开发了一种有效的动量度量学习方案,采用K-hard负样本挖掘来提高网络的泛化能力。

在这里插入图片描述

  • 我们维护一个动量更新编码器(MuEncoder),生成存储在大型内存库中的动量更新嵌入。
  • 此外,我们设计了一个K-hard负选择器,以从内存库中定位信息量大的三元组来计算三元组损失。辅助的自监督任务预测图像内部的patch顺序。

One is an extrinsic supervision with momentum metric learning, and the other is an intrinsic supervision with a self-supervised auxiliary task.
The momentum metric learning is employed by a triplet loss with a K-hard negative selector on the momentum updated embeddings stored in a large memory bank. We implement a self-supervised auxiliary task by predicting the order of patches within an image. All these tasks adopt a shared encoder f and are seamlessly integrated into an end-to-end learning framework. Below, we introduce the extrinsic supervision and intrinsic self-supervision in detail.

Extrinsic Supervision with Momentum Metric Learning

对于领域泛化问题,需要确保具有相同标签的样本特征彼此接近,而不同类别样本的特征之间相距较远。否则,在未知的目标领域上进行预测时,可能会遭受模糊的决策边界和性能下降[8,20]。这与度量学习的理念是一致的。

因此,我们设计了一个动量度量学习方案,通过考虑跨域样本之间的相互关系,鼓励网络学习这种领域无关但类别特定的特征。具体而言,我们提出了一种新的K-hard负选择器用于三元组损失,通过在内存库中选择信息量大的三元组来提高训练效果,以及一个动量更新编码器,以保证内存库中的嵌入表示的一致性。

K-hard negative selector for triplet loss

The triplet loss is defined as:

L_triplet = max(0, d(fθ(xa), fθ(xp)) - d(fθ(xa), fθ(xn)) + margin)

where d(·, ·) is a distance function, and margin is a hyperparameter that controls
the margin between the positive and negative pairs.

However, randomly selecting negative samples may not be sufficient to provide
informative triplets for effective training. Therefore, we propose a K-hard negative
selector to select the K hardest negative samples for each anchor sample xa from
the memory bank
.

The memory bank stores the momentum updated embeddings
generated by the momentum updated Encoder MuEncoder.

The K-hard negative selector ensures that the selected negative samples have the highest distance to the anchor sample among all the negative samples in the memory bank.

  • This strategy improves the training effectiveness by selecting informative triplets with larger distances and more challenging sample pairs.

Intrinsic Supervision with Self-supervised Auxiliary Task

许多工作都专注于设计辅助自监督任务,例如旋转度预测和图像中两个补丁的相对位置预测[7,13,21]。在这里,我们采用最近提出的拼图任务[2,32]作为我们的辅助任务。然而,大多数关注高级语义特征学习的自监督任务都可以纳入我们的框架。

  • 具体而言,我们首先将图像分成九个(3×3)patch,并按照[2]的方法将这些patch随机组合成30种不同的组合方式

正如[2]所指出的,当类别数设置为30时,模型的性能最高,当任务难度增加时,随着组合方式的增加,预测性能下降。一个新的辅助任务分支ha跟随提取的特征表示fθ,预测补丁的顺序。交叉熵损失被应用于解决这个顺序分类任务:
在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

import argparse
import os
import torch
from torch import nn
from torch.nn import functional as F
from data import data_helper
from models import model_factory
from optimizer.optimizer_helper import *
from utils.Logger import Logger
from utils.losses import *
from utils.anchor_selector import *
from tqdm import tqdm
import torch.backends.cudnn as cudnn
cudnn.benchmark = True


def get_args():
    parser = argparse.ArgumentParser(description="Script to launch jigsaw training", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-g', '--gpu', type=int, default=0, help='gpu id')
    parser.add_argument("--source", choices=data_helper.available_datasets, help="Source", nargs='+')
    parser.add_argument("--target", choices=data_helper.available_datasets, help="Target")
    parser.add_argument("--batch_size", "-b", type=int, default=64, help="Batch size")
    parser.add_argument("--image_size", type=int, default=225, help="Image size")
    # data aug stuff
    parser.add_argument("--min_scale", default=0.8, type=float, help="Minimum scale percent")
    parser.add_argument("--max_scale", default=1.0, type=float, help="Maximum scale percent")
    parser.add_argument("--random_horiz_flip", default=0.0, type=float, help="Chance of random horizontal flip")
    parser.add_argument("--jitter", default=0.4, type=float, help="Color jitter amount")
    parser.add_argument("--tile_random_grayscale", default=0.1, type=float, help="Chance of randomly greyscaling a tile")

    parser.add_argument("--limit_source", default=None, type=int, help="If set, it will limit the number of training samples")
    parser.add_argument("--limit_target", default=None, type=int, help="If set, it will limit the number of testing samples")

    parser.add_argument("--learning_rate", "-l", type=float, default=.01, help="Learning rate")
    parser.add_argument("--learning_rate_moco", "-lmoco", type=float, default=.003, help="Learning rate")
    parser.add_argument("--epochs", "-e", type=int, default=100, help="Number of epochs")
    parser.add_argument("--n_classes", "-c", type=int, default=31, help="Number of classes")
    parser.add_argument("--jigsaw_n_classes", "-jc", type=int, default=30, help="Number of classes for the jigsaw task")
    parser.add_argument("--network", choices=model_factory.nets_map.keys(), help="Which network to use", default="caffenet")
    parser.add_argument("--jig_weight", type=float, default=0.7, help="Weight for the jigsaw puzzle")
    parser.add_argument("--tri_weight", type=float, default=0.001, help="Weight for the triplet loss")
    parser.add_argument("--moco_weight", type=float, default=0.1, help="Weight for the moco loss")
    parser.add_argument("--ooo_weight", type=float, default=0, help="Weight for odd one out task")
    parser.add_argument("--tf_logger", type=bool, default=True, help="If true will save tensorboard compatible logs")
    parser.add_argument("--val_size", type=float, default="0.1", help="Validation size (between 0 and 1)")
    parser.add_argument("--folder_name", default=None, help="Used by the logger to save logs")
    parser.add_argument("--bias_whole_image", default=0.9, type=float, help="If set, will bias the training procedure to show more often the whole image")

    parser.add_argument("--train_all", default=True, type=bool, help="If true, all network weights will be trained")
    parser.add_argument("--suffix", default="", help="Suffix for the logger")
    parser.add_argument("--nesterov", default=False, type=bool, help="Use nesterov")
    parser.add_argument("--margin", default=0.2, type=float, help="Margin in triplet loss")

    # loss function
    parser.add_argument('--softmax', action='store_true', help='using softmax contrastive loss rather than NCE')
    parser.add_argument('--nce_k', type=int, default=4096)
    parser.add_argument('--nce_t', type=float, default=0.07)
    parser.add_argument('--nce_m', type=float, default=0.5)

    parser.add_argument('--k_triplet', type=int, default=5)

    # memory setting
    parser.add_argument('--moco', default=True, action='store_true', help='using MoCo (otherwise Instance Discrimination)')
    parser.add_argument('--alpha', type=float, default=0.999, help='exponential moving average weight')
    
    return parser.parse_args()


def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']


def moment_update(model, model_ema, m):
    """ model_ema = m * model_ema + (1 - m) model """
    for p1, p2 in zip(model.parameters(), model_ema.parameters()):
        p2.data.mul_(m).add_(1-m, p1.detach().data)


class Trainer:
    def __init__(self, args, device):
        self.args = args
        self.device = device
        model = model_factory.get_network(args.network)(classes=args.n_classes)
        model_ema = model_factory.get_network(args.network)(classes=args.n_classes)
        self.model = model.to(device)
        self.model_ema = model_ema.to(device)
        if args.moco:
            moment_update(self.model, self.model_ema, 0)
        self.tri_weight = self.args.tri_weight
        print(self.model)
        self.source_loader, self.val_loader = data_helper.get_train_dataloader(args, False)

        self.target_loader = data_helper.get_val_dataloader(args)
        self.test_loaders = {"val": self.val_loader, "test": self.target_loader}
        self.len_dataloader = len(self.source_loader)
        print("Dataset size: train %d, val %d, test %d" % (len(self.source_loader.dataset), len(self.val_loader.dataset), len(self.target_loader.dataset)))
        self.optimizer, self.scheduler = get_optim_and_scheduler(self.model, args.epochs, args.learning_rate)

        self.jig_weight = args.jig_weight
        self.n_classes = args.n_classes
        if args.target in args.source:
            self.target_id = args.source.index(args.target)
            print("Target in source: %d" % self.target_id)
            print(args.source)
        else:
            self.target_id = None

        self.moco_weight_init = self.args.moco_weight
        self.moco_weight = self.moco_weight_init
        self.k_triplet = args.k_triplet
        self.initialize_queue()

    def update_moco_weight(self):
        self.moco_weight = self.moco_weight_init + (1-self.moco_weight_init)/self.args.epochs * self.current_epoch

    def queue_data(self, data, k, label_list, label):
        self.queue = torch.cat([data, k], dim=0)
        label = label.float().unsqueeze(1)
        self.label_list = torch.cat([label_list, label], dim=0)

    def dequeue_data(self, K=4096):
        if len(self.queue) > K:
            self.queue = self.queue[-K:]
            self.label_list = self.label_list[-K:]

    def initialize_queue(self):
        queue = torch.zeros((0, 128), dtype=torch.float).cuda()
        label_list = torch.zeros((0, 1), dtype=torch.float).cuda()

        for batch_idx, ((data, jig, class_l), d_idx) in enumerate(self.source_loader):
            data1, data2 = torch.split(data, [3, 3], dim=1)
            x_k = data1
            x_k = x_k.cuda()
            with torch.no_grad():
                _, _, k, _ = self.model_ema(x_k)
            k = k.detach()
            self.queue_data(queue, k, label_list, class_l.cuda())
            self.dequeue_data(K=self.args.nce_k)
            break

    def momentum_update(self, model_q, model_k, beta=0.999):
        param_k = model_k.state_dict()
        param_q = model_q.named_parameters()
        for n, q in param_q:
            if n in param_k:
                param_k[n].data.copy_(beta * param_k[n].data + (1 - beta) * q.data)
        model_k.load_state_dict(param_k)

    def _do_epoch(self):
        margin = self.args.margin
        criterion_class = nn.CrossEntropyLoss()
        triplet_loss = OnlineKTripletLoss(margin, KSemihardNegativeTripletSelectorFromMomentum(margin, k=self.k_triplet))
        jigen_loss = nn.CrossEntropyLoss()

        self.model.train()
        self.model_ema.eval()
        moco_loss = 0

        for it, ((data, order, class_l), d_idx) in enumerate(self.source_loader):
            data, class_l, d_idx, order = data.to(self.device), class_l.to(self.device), d_idx.to(self.device), order.to(self.device)

            self.optimizer.zero_grad()
            data1, data2 = torch.split(data, [3, 3], dim=1)  # normal  Jigen

            class_logit, jig, q, qc = self.model(data2)  # , lambda_val=lambda_val)
            with torch.no_grad():
                _, _, k, kc = self.model_ema(data1)

            k = k.detach()
            self.queue_data(self.queue, k, self.label_list, class_l)
            self.dequeue_data(K=self.args.nce_k)

            moco_loss, _ = triplet_loss(q, k, self.queue, class_l, self.label_list)
            moco_loss = moco_loss * self.moco_weight

            jig_loss = jigen_loss(jig, order)
            jig_loss = jig_loss * self.jig_weight
            class_loss = criterion_class(class_logit, class_l)
            _, cls_pred = class_logit.max(dim=1)

            loss = class_loss + jig_loss + moco_loss

            loss.backward()
            self.optimizer.step()

            self.momentum_update(self.model, self.model_ema)

            if (it) % 30 == 0:

                print("{}/{} iter/epoch, [losses] class: {}, jig: {}, moco: {}, total: {}. ".format(it, self.current_epoch,
                                                                                                            class_loss.item(),
                                                                                                            jig_loss.item(),
                                                                                                            moco_loss.item(),
                                                                                                            loss.item(),
                                                                                                            ))
            self.logger.writer.add_scalar('training loss/class', class_loss.item(), self.current_epoch*len(self.source_loader)+it)
            self.logger.writer.add_scalar('training loss/jig', jig_loss.item(),
                                          self.current_epoch * len(self.source_loader) + it)
            self.logger.writer.add_scalar('training loss/moco', moco_loss.item(),
                                          self.current_epoch * len(self.source_loader) + it)
            self.logger.writer.add_scalar('training loss/total', loss.item(),
                                          self.current_epoch * len(self.source_loader) + it)

        self.model.eval()
        with torch.no_grad():
            for phase, loader in self.test_loaders.items():
                total = len(loader.dataset)
                if loader.dataset.isMulti():
                    class_correct, single_acc = self.do_test_multi(loader)
                    print("Single vs multi: %g %g" % (float(single_acc) / total, float(class_correct) / total))
                else:
                    class_correct = self.do_test(loader)
                class_acc = float(class_correct) / total
                self.logger.writer.add_scalar('acc/'+phase, class_acc,
                                              self.current_epoch * len(self.source_loader) + it, )
                print("[{}] acc: {}".format(phase, class_acc))
                self.results[phase][self.current_epoch] = class_acc

    def do_test(self, loader):
        class_correct = 0
        for it, ((data, class_l), _) in enumerate(loader):
            data, class_l = data.to(self.device), class_l.to(self.device)
            class_logit, _, _, _ = self.model(data)
            _, cls_pred = class_logit.max(dim=1)
            class_correct += torch.sum(cls_pred == class_l.data)
        return class_correct

    def do_test_multi(self, loader):
        class_correct = 0
        single_correct = 0
        for it, ((data, class_l), d_idx) in enumerate(loader):
            data, class_l = data.to(self.device), class_l.to(self.device)
            n_permutations = data.shape[1]
            class_logits = torch.zeros(n_permutations, data.shape[0], self.n_classes).to(self.device)
            for k in range(n_permutations):
                class_logits[k] = F.softmax(self.model(data[:, k])[1], dim=1)
            class_logits[0] *= 4 * n_permutations  # bias more the original image
            class_logit = class_logits.mean(0)
            _, cls_pred = class_logit.max(dim=1)
            single_logit, _ = self.model(data[:, 0])
            _, single_logit = single_logit.max(dim=1)
            single_correct += torch.sum(single_logit == class_l.data)
            class_correct += torch.sum(cls_pred == class_l.data)
        return class_correct, single_correct

    def save_tsne(self):
        self.model.eval()
        embedding = torch.zeros((0, 2048), dtype=torch.float)
        embedding_label = torch.zeros((0, 1), dtype=torch.long)
        with torch.no_grad():
            for it, ((data, class_l), _) in enumerate(self.target_loader):
                data, class_l = data.to(self.device), class_l.to(self.device)

                class_logit, jig, q, qc = self.model(data)

                embedding = torch.cat([embedding, qc.cpu()])
                embedding_label = torch.cat([embedding_label, class_l.unsqueeze(1).cpu()])
                torch.cuda.empty_cache()
        self.logger.writer.add_embedding(embedding.data,
                                         metadata=embedding_label.data,
                                         tag="TestImg")

    def do_training(self):
        self.logger = Logger(self.args, update_frequency=30)  # , "domain", "lambda"
        self.results = {"val": torch.zeros(self.args.epochs), "test": torch.zeros(self.args.epochs)}
        for self.current_epoch in tqdm(range(self.args.epochs)):
            self._do_epoch()
            torch.cuda.empty_cache()
            self.scheduler.step()
            self.logger.writer.add_scalar('lr/basic', get_lr(self.optimizer), self.current_epoch)

        self.save_tsne()
        val_res = self.results["val"]
        test_res = self.results["test"]
        idx_best = val_res.argmax()

        print('Val best test: ', test_res[idx_best], 'Test best test: ', test_res.max())
        return self.logger, self.model


def main():
    args = get_args()
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    trainer = Trainer(args, device)
    trainer.do_training()


if __name__ == "__main__":
    torch.backends.cudnn.benchmark = True
    main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/559063.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

AI智慧安监视频平台EasyCVR视频出现不能播放的情况排查与解决

EasyCVR基于云边端协同,可支持海量视频的轻量化接入与汇聚管理。平台兼容性强、拓展度高,可提供视频监控直播、视频轮播、视频录像、云存储、回放与检索、智能告警、服务器集群、语音对讲、云台控制、电子地图、H.265自动转码、平台级联等功能。 有用户反…

如何动态生成列表视图?

UE5 插件开发指南 前言0 什么是列表视图?1 如何动态生成?1.0 指定ListView生成的条目前言 这里将其拆分成两个问题来分析: (1)什么是列表视图? (2)如何动态生成? 0 什么是列表视图? 列表视图就是用来展示一系列对象的UI列表,在UE编辑器的UserWidget设计窗口中可以找到…

linux 安装 ffmpeg

linux 安装 ffmpeg windows上安装,直接下载压缩包解压。linux安装,找了半天各种技术文章,说最好编译安装,按照步骤安装编译环境编译成功了,但是使用的时候总要安装各种外部库,转码转不了等等问题...... 最…

城市生命线监测系统包括哪些内容?

城市排水、供水、燃气、供热、桥梁、隧道、综合管廊等基础设施是城市正常运转的基石,被称为“城市生命线”。城市生命线一旦出现故障或事故,将会给城市和居民带来巨大的经济和生活损失。通过对城市生命线的实时监测和预警,可以及时发现潜在的…

第十五届“中国电机工程学会杯”数学建模竞赛

第十五届电工杯5月26号就要开始啦,今天给大家回顾第十四届全国大学生电工数学建模竞赛A题,主要从赛题重述和问题分析与代码实战展开。第十五届全国大学生电工数学建模竞赛已经开始报名了哦,后续我也会分享对应的建模思路哦,大家记…

Leetcode452. 用最少数量的箭引爆气球

Every day a Leetcode 题目来源:452. 用最少数量的箭引爆气球 解法1:排序 贪心 题解:用最少数量的箭引爆气球 我们首先随机地射出一支箭,再看一看是否能够调整这支箭地射出位置,使得我们可以引爆更多数目的气球。…

CVPR论文解读 | 点云匹配的旋转不变变压器

原创 | 文 BFT机器人 传统的手工特征描述符通常具有内在的旋转不变性,但是最近的深度匹配器通常通过数据增强来获得旋转不变性。 然而,由于增强旋转数量有限,无法覆盖连续SO(3)空间中所有可能的旋转,因此这…

VC6.0的工程设置解读Project--Settings

做开发差不多一年多了,突然感觉对VC的工程设置都不是很清楚,天天要和VC见面,虽然通常情况下一般都不会修改工程设置,但是还是有必要对它的一些设置项的来龙去脉有一定的了解,所以狂查资料,稍作整理&#xf…

(仿真)创建 URDF 机器人模(1)

继上一篇基础篇的结束,不用看以前的也可以,这里是不受前面的影响的。 如果你没有这个目录,就创建一个catkin_ws文件夹 然后里面再一个src文件夹就ok了,我在基础篇第一篇的时候就有这个文件夹了,所有我现在是直接进入 …

【 计算机组成原理 】第七章 外围设备

系列文章目录 第一章 计算系统概论 第二章 运算方法和运算器 第三章 多层次的存储器 第四章 指令系统 第五章 中央处理器 第六章 总线系统 第七章 外围设备 第八章 输入输出系统 文章目录 系列文章目录前言第七章 外围设备7.1 外围设备概述7.1.1 外围设备的一般功能7.1.2 外围…

zabbix安装部署、三分钟分钟部署zabbix监控(超详细)

zabbix安装部署 1,快速安装部署zabbix2,一键脚本安装zabbix 1,快速安装部署zabbix 1,关闭防火墙,selinux systemctl stop firewalld systemctl disable firewalld setenforce 0 #临时 sed -i s/SELINUXenforcing/SE…

运维宝典大全

运维宝典大全 网络拓展Linux 概述什么是LinuxUnix和Linux有什么区别?什么是 Linux 内核?Linux的基本组件是什么?Linux 的体系结构BASH和DOS之间的基本区别是什么?Linux 开机启动过程?Linux系统缺省的运行级别&#xff…

Jmeter性能测试 -3 Jmeter使用中的一些问题

请求内容出现乱码的处理方法 1 内容编码:utf-8 2 请求头添加编码 Content-Type: application/json;charsetutf-8 3 请求体为参数类型时,勾选参数“编码”,编码为urlencoded编码。当参数值为非字符(汉字、特殊符号)时…

全面了解Java连接MySQL的基础知识,快速实现数据交互

全面了解Java连接MySQL的基础知识,快速实现数据交互 1. 数据库的重要性2. MySQL数据库简介2.1 MySQL数据库的基本概念2.2 MySQL的基本组成部分包括服务器、客户端和存储引擎。2.3 安装MySQL数据库2.3.1安装MySQL数据库2.3.2 下载MySQL安装程序2.3.3 运行MySQL安装程…

API接口|了解API接口测试|API接口测试指南

part1.什么是API接口 API接口是指应用程序接口(Application Programming Interface),它是一组定义、控制和描述软件程序中不同组件之间交互的方式和规则。 API接口允许不同的软件系统之间进行信息共享和相互访问,而无需了解在其…

QT6之QTimeZone

一、简介 QTimeZone 标识时间表示与 UTC 的关系,也可以表示 UTC、本地时间和与 UTC 的固定偏移量。 QTimeZone(自 Qt 6.5 起)统一了它们与一般时间系统的表示,大多数操作系统普遍支持的一个时区被指定为本地时间。 总结&#x…

bim精装修常用软件【建模助手】有什么功能?

大家好,这里是BIM建模助手。 今天有个重磅消息要告诉大家,那就是BIM建模助手的【精装模块】上线啦! 为了辅助BIMer快速设计出精装修的房屋效果,我们开发了【精装模块】,无论是装饰面层、铺排瓷砖、布置吊顶、统计出量…

chatgpt赋能Python-python_kazoo

Python Kazoo: 优质的分布式应用程序开发工具 什么是Python Kazoo? Python Kazoo是一个Python库,它提供了高级别的API,使得分布式应用程序的开发更加容易。Kazoo是基于Zookeeper实现的,它是一个分布式系统协调器,为分…

力扣sql中等篇练习(二十六)

力扣sql中等篇练习(二十六) 1 世界排名的变化 1.1 题目内容 1.1.1 基本题目信息1 1.1.2 基本题目信息2 1.1.3 示例输入输出 a 示例输入 b 示例输出 1.2 示例sql语句 # 分别求出变化前后的排名 然后再进行内连接即可 # row_number()里面也可以用多个字段加减的表达式去进行…

【C++】类和对象(3)

文章目录 一、初始化列表二、explicit关键字三、static成员四、友元4.1 友元函数4.2 友元类 五、内部类六、匿名对象七、编译器的优化 一、初始化列表 首先我们先回顾一下构造函数,对象的初始化由构造函数来完成,我们可以在构造函数的函数体内对对象的成…