[迁移学习]域自适应代码解析

news2024/11/19 21:20:54

一、概述

        代码来自:https://github.com/jindongwang/transferlearning,可以前往github下载代码,本文涉及的代码的位置为:Code->DeepDA。理论基础可以参见:[迁移学习]域自适应

        整体网络结构如下:可以视为一个分类网络(如Resnet50)+一个fc_adapt模块组成 。其损失函数为L=L_c(x_i,y_i)+\lambda Distance(D_s,D_t),即原来的交叉熵损失函数后面添加一个衡量源域和目标域之间距离的损失函数,一般为MMD,超参数\lambda用来控制此损失函数的权重。

 二、代码分析

        1.dataloader

        转到main函数,可以看到与dataloader相关的代码为:

source_loader, target_train_loader, target_test_loader, n_class = load_data(args)

        跳转后可以看到,在load_data(.)函数中数据集被分成下面三个部分,分别对应源域、目标域训练集、目标域测试集。

source_loader,target_train_loader,target_test_loader

        而实际起作用的是data_loader中的load_data,该函数重写了DataSet类并调用了dataloader

data = datasets.ImageFolder(root=data_folder, transform=transform['train' if train else 'test'])
    data_loader = get_data_loader(data, batch_size=batch_size, 
                                shuffle=True if train else False, 
                                num_workers=num_workers, **kwargs, drop_last=True if train else False)
    n_class = len(data.classes)

        2.model

        main函数中,与model有关的代码为:

model = get_model(args)
def get_model(args):
    model = models.TransferNet(
        args.n_class, transfer_loss=args.transfer_loss, base_net=args.backbone, max_iter=args.max_iter, use_bottleneck=args.use_bottleneck).to(args.device)
    return model

        该函数实际是从models文件中的TransferNet类中提取网络模型,继续转到TransferNet,该类有以下几个参数:类别个数,骨干网络类型,transfer_loss类型,以及一些调整骨干网络的参数。

class TransferNet(nn.Module):
    def __init__(self, num_class, base_net='resnet50', 
        transfer_loss='mmd', use_bottleneck=True, 
        bottleneck_width=256, max_iter=1000, **kwargs):

        随后看该网络的前向传递函数,基本可以归类为以下几部:

                ①骨干网络提取

source = self.base_network(source)
target = self.base_network(target)

                ②源域分类

source_clf = self.classifier_layer(source)

                代码中的分类器是一个全连接层,对应的参数是隐藏层通道数和输出类别数

self.classifier_layer = nn.Linear(feature_dim, num_class)

                ③源域分类损失函数计算

clf_loss = self.criterion(source_clf, source_label)

                代码中采用了一个交叉熵损失函数

self.criterion = torch.nn.CrossEntropyLoss()

                ④迁移学习

                这一小节是域自适应迁移学习和传统分类网络的最大区别,除了传统的cls_loss之外,该网络还计算了transfer_loss。

                代码中提供了lmmd,daan,bnm三种方式,其代码基本大同小异,这里选取lmmd进行解析,lmmd代码如下:

if self.transfer_loss == "lmmd":
    kwargs['source_label'] = source_label
    target_clf = self.classifier_layer(target)
    kwargs['target_logits'] = torch.nn.functional.softmax(target_clf, dim=1)

                该段代码的主要功能是从参数列表中获取source_label,同时使用上面同样的分类器对由骨干网络提取到的目标域特征进行分类(使用softmax进行分类,结果记录为target_logits),随后源域特征和目标域特征以及参数会被送入adapt_loss(.)模块用以计算transfer_loss

                同时从后续代码得知,如果使用最简单的mmd是不需要经过这一步处理的

transfer_loss = self.adapt_loss(source, target, **kwargs)
self.adapt_loss = TransferLoss(**transfer_loss_args)

                TransferLoss类为一个transfer_loss的提取模块,提供了6种不同的损失函数,这里以mmd为例。

if loss_type == "mmd":
    self.loss_func = MMDLoss(**kwargs)

                MMDLoss的完整代码如下:

import torch
import torch.nn as nn

class MMDLoss(nn.Module):
    def __init__(self, kernel_type='rbf', kernel_mul=2.0, kernel_num=5, fix_sigma=None, **kwargs):
        super(MMDLoss, self).__init__()
        self.kernel_num = kernel_num
        self.kernel_mul = kernel_mul
        self.fix_sigma = None
        self.kernel_type = kernel_type

    def guassian_kernel(self, source, target, kernel_mul, kernel_num, fix_sigma):
        n_samples = int(source.size()[0]) + int(target.size()[0])
        total = torch.cat([source, target], dim=0)
        total0 = total.unsqueeze(0).expand(
            int(total.size(0)), int(total.size(0)), int(total.size(1)))
        total1 = total.unsqueeze(1).expand(
            int(total.size(0)), int(total.size(0)), int(total.size(1)))
        L2_distance = ((total0-total1)**2).sum(2)
        if fix_sigma:
            bandwidth = fix_sigma
        else:
            bandwidth = torch.sum(L2_distance.data) / (n_samples**2-n_samples)
        bandwidth /= kernel_mul ** (kernel_num // 2)
        bandwidth_list = [bandwidth * (kernel_mul**i)
                          for i in range(kernel_num)]
        kernel_val = [torch.exp(-L2_distance / bandwidth_temp)
                      for bandwidth_temp in bandwidth_list]
        return sum(kernel_val)

    def linear_mmd2(self, f_of_X, f_of_Y):
        loss = 0.0
        delta = f_of_X.float().mean(0) - f_of_Y.float().mean(0)
        loss = delta.dot(delta.T)
        return loss

    def forward(self, source, target):
        if self.kernel_type == 'linear':
            return self.linear_mmd2(source, target)
        elif self.kernel_type == 'rbf':
            batch_size = int(source.size()[0])
            kernels = self.guassian_kernel(
                source, target, kernel_mul=self.kernel_mul, kernel_num=self.kernel_num, fix_sigma=self.fix_sigma)
            XX = torch.mean(kernels[:batch_size, :batch_size])
            YY = torch.mean(kernels[batch_size:, batch_size:])
            XY = torch.mean(kernels[:batch_size, batch_size:])
            YX = torch.mean(kernels[batch_size:, :batch_size])
            loss = torch.mean(XX + YY - XY - YX)
            return loss

                里面有大量的数学变换在这里就不进行深究了。其前向函数的作用是将源域和目标域之间的差距计算出来并返回loss

                ⑤返回函数

                该模型的前向传递函数最后会返回两个参数:clf_loss, transfer_loss,这两个参数将会被用于后面的反向传递。

        3.训练

        训练过程中,会从source_loader中提取源域图像和标签,从target_train_loader中提取目标域图像(不需要目标域的标签)。

        然后将数据这三个数据送入模型,得到clf_loss和transfer_loss。最后将transfer_loss×权重系数lambda后与clf_loss相加后可以得到最终的损失函数loss。

clf_loss, transfer_loss = model(data_source, data_target, label_source)
loss = clf_loss + args.transfer_loss_weight * transfer_loss

        再然后就是自动的反向传递:

optimizer.zero_grad()
loss.backward()
optimizer.step()

        4.测试

        测试过程与上面的训练大同小异,主要是不再需要前向传递。使用的是model中的predict函数而不是默认的前向传递函数

s_output = model.predict(data)

        该函数与之前的前向传递函数相比没有adapt_loss模块:

def predict(self, x):
    features = self.base_network(x)
    x = self.bottleneck_layer(features)
    clf = self.classifier_layer(x)
    return clf

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/641171.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Win7下静态变量析构导致进程卡死无法退出问题解决

项目中在用户机器Win7系统上好几次出现进程卡死,无法退出,在用户机器上抓取了dump,发现是在DllMain函数中执行了静态变量的析构,这个静态变量析构的时候会使用std::condition_variable 类型的成员变量通知其他线程退出。同时本地在…

PDF怎样转换成长图?这个方法,超级简单!

在当今社会,PDF文档广泛应用于各个领域。然而,在某些情况下,我们可能需要将多个PDF页面合并成一个单独的长图,以便更方便地浏览、共享或嵌入到其他文件中。为了满足这一需求,记灵在线工具应运而生,它为我们…

一种全新的图像变换理论的实验(六)——研究目的替代DCT和小波

一、变换算法在图像视频中的核心作用 我们国产的变换算法是比较少的,基本上都是在小波、DCT和FFT上发展优化升级的应用。我之前的文章给出了一种基于加权概率模型的变换算法,该算法在一定的程度上能有效的保存低频数据。而且我基于该算法给出了一些新的…

微信小程序快速开发— TDesign模版初始化

最近有个商城类的小程序业务需要快速上线,看了一下微信官方的模版库,相中了TDesign,调研了半天,决定就从这个开始干。 调研的两个重点: 1、网络请求,即数据获取 2、模板本身存在些bug,如&…

从Kotlin中return@forEach了个寂寞

点击上方蓝字关注我,知识会给你力量 今天在Review(copy)同事代码的时候,发现了一个问题,想到很久之前,自己也遇到过这个问题,那么就来看下吧。首先,我们抽取最小复现代码。 (1..7).f…

Python 基于人脸识别的实验室智能门禁系统的设计与实现,附源码

1 简介 本基于人脸识别的实验室智能门禁系统通过大数据和信息化的技术实现了门禁管理流程的信息化的管理操作。平台的前台页面通过简洁的平台页面设计和功能结构的分区更好的提高用户的使用体验,没有过多的多余的功能,把所有的功能操作都整合在功能操作…

聚观早报|微软Xbox2023发布会汇总;苹果VisionPro头显低配版曝光

今日要闻:微软Xbox 2023发布会汇总;苹果Vision Pro头显低配版曝光;台积电在熊本县建设半导体工厂;苹果今年或能出货2.4亿台;中国含氯废塑料高效无害升级回收 微软Xbox 2023发布会汇总 6 月 12 日凌晨,微软…

Java 实战介绍 Cookie 和 Session 的区别

HTTP 是一种不保存状态的协议,即无状态协议,HTTP 协议不会保存请求和响应之间的通信状态,协议对于发送过的请求和响应都不会做持久化处理。 无状态协议减少了对服务压力,如果一个服务器需要处理百万级用户的请求状态,对…

Linux教程——Linux绝对路径和相对路径详解

在 Linux 中,简单的理解一个文件的路径,指的就是该文件存放的位置,只要我们告诉 Linux 系统某个文件存放的准确位置,那么它就可以找到这个文件。 指明一个文件存放的位置,有 2 种方法,分别是使用绝对路径和…

深度解读 KaiwuDB 的排序操作

一、单节点执行 在单节点环境执行一条简单的 SQL 语句 SELECT * FROM NATION ORDER BY N_NAME。NATION 是一张小表,只有 25 条记录;对第 2 列 N_NAME 进行升序排列。 1. 抽象语法树 上述示例中的 SQL 语句经过分析器解析后得到 AST,如下图…

(文章复现)面向配电网韧性提升的移动储能预布局与动态调度策略(2)-灾后调度matlab代码

参考文献: [1]王月汉,刘文霞,姚齐,万海洋,何剑,熊雪君.面向配电网韧性提升的移动储能预布局与动态调度策略[J].电力系统自动化,2022,46(15):37-45. 1.基本原理 1. 1 目标函数 在灾害发生后,配电网失去主网供电,设故障的持续时间可根据灾害…

基于SpringBoot+Vue的酒店管理系统设计与实现

博主介绍: 大家好,我是一名在Java圈混迹十余年的程序员,精通Java编程语言,同时也熟练掌握微信小程序、Python和Android等技术,能够为大家提供全方位的技术支持和交流。 我擅长在JavaWeb、SSH、SSM、SpringBoot等框架下…

代码随想录 二叉树 Java(二)

文章目录 (*中等)222. 完全二叉树的节点个数(*简单)110. 平衡二叉树(*简单)257. 二叉树的所有路径(简单)404. 左叶子之和(简单)513. 找树左下角的值&#xff…

设计模式的原则(一)

相信自己,无论自己到了什么局面,请一定要继续相信自己。 新的世界开始了,接下来,老蝴蝶带领大家学习一下设计模式。 我们先了解一下 设计原则 一.设计模式 一.一 设计原则 设计模式常用的七大原则: 单一职责原则接口隔离原则…

【项目】接入飞书平台

前言 项目有和飞书打通的需求,因为是第一次打通,摸索过程还是花了些时间的,现在相关笔记分享给大家。 步骤 1、熟悉开发文档 熟悉飞书的开发文档:开发文档 ,找到你需要的接口,拿我为例,我需…

长生的秘密:肠道菌群代谢组学

欲遂长生志,但求千金方。长生不老是人类文明历程中苦苦追寻的目标之一,影响人类寿命的因素也复杂多样,包括但不限于遗传因素如性别、线粒体状态、染色体稳定性、端粒长短、疾病、干细胞活性;环境因素如肠道微生物、饮食、运动、空…

如何解决“RuntimeError: CUDA Out of memory”问题

当遇到这个问题时,你可以尝试一下这些建议,按代码更改的顺序递增: 减少“batch_size” 降低精度 按照错误说的做 清除缓存 修改模型/训练 在这些选项中,如果你使用的是预训练模型,则最容易和最有可能解决问题的选项是第一个。 修改batchsize 如果你是在运行现成的代码或…

页面置换算法的模拟与比较

前言 在计算机操作系统中,页面置换算法是虚拟存储管理中的重要环节。通过对页面置换算法的模拟实验,我们可以更深入地理解虚拟存储技术,并比较不同算法在请求页式虚拟存储管理中的优劣。 随着计算机系统和应用程序的日益复杂,内存…

技术管理方法论

今天来跟大家分享一下我对于技术管理的理解。先介绍一下对于管理最普遍的认识,我们每一个人在公司里面都有两种类型的角色,一种是通过个人的能力和产出来实现组织利益的最大化,另外一类人就是通过管理使得一群人产出结果最大化。 也就是我们…

阿里P8传授的80K+星的MySQL笔记助我修行,一周快速进阶

MySQL 是最流行的关系型数据库之一,广泛的应用在各个领域。下面这些问题对于程序员的你来说应该很常见,来看看你面对这些问题是否会胆怯? MySQL数据库作发布系统的存储,一天五万条以上的增量,预计运维三年,怎么优化? …