Domain Adaptation(李宏毅)机器学习 2023 Spring HW11 (Boss Baseline)

news2025/1/12 9:57:25

1. 领域适配简介

领域适配是一种迁移学习方法,适用于源领域和目标领域数据分布不同但学习任务相同的情况。具体而言,我们在源领域(通常有大量标注数据)训练一个模型,并希望将其应用于目标领域(通常只有少量或没有标注数据)。然而,由于这两个领域的数据分布不同,模型在目标领域上的性能可能会显著下降。领域适配技术的目标是通过对模型进行适配,缩小源领域与目标领域之间的差距,从而提升模型在目标领域的表现。

Domain Shift (source: https://speech.ee.ntu.edu.tw/~hylee/ml/ml2021-course-data/da_v6.pdf)

以数字识别为例,如果我们的源数据是灰度图像,并且在这些数据上训练模型,我们可以预期模型会取得相当不错的效果。然而,如果我们将这个在灰度图像上训练的模型用于分类彩色图像,模型的表现可能会较差。这是因为这两个数据集之间存在领域转移。

领域适配方法可以根据目标领域中标签的可用性进行分类:

  1. 有监督领域适配:源领域和目标领域都有标注数据。这种情况较为少见,因为领域适配的主要动机是目标领域标签的稀缺性。

  2. 无监督领域适配:源领域有标注数据,而目标领域没有标注数据。这是最常见且最具挑战性的情况。

  3. 半监督领域适配:源领域有标注数据,目标领域则只有少量标注数据。

Different Domain Adaptation Scenarios (source: https://speech.ee.ntu.edu.tw/~hylee/ml/ml2021-course-data/da_v6.pdf)

我们的博客和作业主要关注目标领域缺乏标注数据的场景。

解决这个问题的基本概念如下:我们旨在找到一个特征提取器,它能够接收输入数据并输出特征空间。这个特征提取器应该能够滤除领域特定的变化,同时保留不同领域之间共享的特征。例如,在以下的示例中,特征提取器应该能够忽略图像的颜色,对于相同的数字,不论其颜色如何,都能生成具有相同分布的特征。

Basic Idea of Domain Adaptation (source: https://speech.ee.ntu.edu.tw/~hylee/ml/ml2021-course-data/da_v6.pdf)

研究人员提出了许多方法,其中对抗学习方法是最常见且最有效的技术之一。

Domain Adversarial Training - 1 (source: https://speech.ee.ntu.edu.tw/~hylee/ml/ml2021-course-data/da_v6.pdf)

我们将一个标准网络分为两部分:特征提取器和标签预测器。在训练过程中,我们以标准的有监督方式在源领域数据上训练整个网络。对于目标领域数据,我们只使用特征提取器提取特征,并采用技术手段将目标领域的特征与源领域的特征对齐。

Domain Adversarial Training - 2 (source: https://speech.ee.ntu.edu.tw/~hylee/ml/ml2021-course-data/da_v6.pdf)

具体来说,我们设计了一个新的领域分类器,它是一个二分类器,输入特征向量并判断输入数据是来自源领域还是目标领域。另一方面,特征生成器的设计目的是“欺骗”领域分类器,使其无法正确区分来源领域。

Domain Adversarial Training - 3 (source: https://speech.ee.ntu.edu.tw/~hylee/ml/ml2021-course-data/da_v6.pdf)

如果我们仔细思考上述方法,我们可以直观地理解,尽管对抗训练可以使源领域和目标领域的整体分布更加相似,如下图左侧所示,但这种分布可能并不适合或不适用于机器学习任务。理想情况下,我们期望获得右侧图像所示的分布。

Limitation of Domain Adversarial Training (source: https://speech.ee.ntu.edu.tw/~hylee/ml/ml2021-course-data/da_v6.pdf)

当然,已有大量论文提出了针对这一问题的解决方法。为了在这次作业中通过strong 和 boss baseline,我们需要深入相关文献,并采用合适的方法。在作业中,我将介绍更多相关的论文和技术。

2. Homework Results and Analysis

作业 11 聚焦于领域适配。给定真实图像(带标签)和涂鸦(无标签),任务是利用领域适配技术训练一个网络,能够准确预测绘制图像的标签。

task description (source: https://speech.ee.ntu.edu.tw/~hylee/ml/ml2023-course-data/HW11.pdf)

数据集设置:

  • 标签:10个类别(编号从0到9),如以下图片所示。

  • 训练集:5000张 (32, 32) RGB 真实图像(带标签)。

  • 测试集:100000张 (28, 28) 灰度绘制图像。

source and target data (source: https://speech.ee.ntu.edu.tw/~hylee/ml/ml2023-course-data/HW11.pdf)

baseline 的门槛 在 Kaggle 上的数值为:

Baseline

Public

Private

Simple

Score >= 0.44280

Score >= 0.44012

Medium

Score >= 0.65994

Score >= 0.65928

Strong

Score >= 0.75342

Score >= 0.75518

Boss

Score >= 0.81072

Score >= 0.80794

像往常一样,助教会提供关于如何超越各种基准模型的指导。

Hints for Simple, Medium and Strong Baseline (source: https://speech.ee.ntu.edu.tw/~hylee/ml/ml2023-course-data/HW11.pdf)

Hints for Boss Baseline (source: https://speech.ee.ntu.edu.tw/~hylee/ml/ml2023-course-data/HW11.pdf)

2.1 Simple Baseline

使用助教提供的默认代码足以通过 simple baseline。

2.2 Medium Baseline

通过增加训练轮数并调整超参数 lambda,可以通过 medium baseline。

num_epochs = 800
# train 800 epochs

with Progress(TextColumn("[progress.description]{task.description}"),
              BarColumn(),
              TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
              TimeRemainingColumn(),
              TimeElapsedColumn()) as progress:
    epoch_tqdm = progress.add_task(description="epoch progress", total=num_epochs)
    for epoch in range(num_epochs):
        train_D_loss, train_F_loss, train_acc = train_epoch(source_dataloader, target_dataloader, progress, lamb=0.6)

        progress.advance(epoch_tqdm, advance=1)
        if epoch == 10:
          torch.save(feature_extractor.state_dict(), f'extractor_model_early.bin')
          torch.save(label_predictor.state_dict(), f'predictor_model_early.bin')
        elif epoch == 100:
          torch.save(feature_extractor.state_dict(), f'extractor_model_mid.bin')
          torch.save(label_predictor.state_dict(), f'predictor_model_mid.bin')

        torch.save(feature_extractor.state_dict(), f'extractor_model.bin')
        torch.save(label_predictor.state_dict(), f'predictor_model.bin')
        print('epoch {:>3d}: train D loss: {:6.4f}, train F loss: {:6.4f}, acc {:6.4f}'.format(epoch, train_D_loss, train_F_loss, train_acc))

2.3 Strong Baseline

助教建议了几篇论文来提升性能并通过strong baseline。其中,我发现以下这篇论文特别有趣:《Minimum Class Confusion for Versatile Domain Adaptation》(Jin, Ying, et al.)(链接)。

他们“提出了一种新颖的损失函数:Minimum Class Confusion(MCC)。它可以被描述为一种新颖且多功能的领域适配方法,无需显式进行领域对齐,且具有较快的收敛速度。此外,它还可以作为一种通用正则化器,与现有的领域适配方法正交且互补,从而进一步加速和改善这些已有的竞争性方法。”(Jin, Ying, et al.,p. 3)

The schematic of the Minimum Class Confusion (MCC) loss function (source: https://arxiv.org/abs/1912.03699)

MCC 的计算过程如下:

给定以下变量:

  • \mathbf{f}_t:网络输出的目标领域数据的logits(即网络分类器的输出)。

  • T :一个温度参数,用于缩放logits,使其更加平滑并增大类别分布之间的差异。

  • \mathbf{p}_t:目标领域经温度平滑后的预测结果,表示通过softmax得到的概率分布。

  • H(\cdot):熵函数,用于衡量每个样本的预测不确定性。

MCC步骤1:目标领域logits的温度缩放:

目标领域的logits ​ \mathbf{f}_t 通过温度进行缩放,以平滑分类概率:

\\ \mathbf{f}_t' = \frac{\mathbf{f}_t}{T} \\

其中, T > 1 用于拉伸预测的概率分布,防止模型过于自信。

MCC步骤2:计算Softmax输出:

将经过温度缩放的logits通过softmax函数得到目标领域预测的概率分布 \mathbf{p}_t ​:

\mathbf{p}_t = \text{Softmax}(\mathbf{f}_t') \\

此处, \mathbf{p}_t ​是一个 N \times C 的矩阵,其中 N 是目标领域样本的数量,C 是分类的类别数。

MCC步骤3:计算样本熵权重:

每个样本的熵 H(\mathbf{p}_t) 使用以下公式计算:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2275443.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

netty请求行超出长度

说明:记录一次使用Netty异常,如下: 错误信息:An HTTP line is larger than 4096 bytes. 场景 项目是微服务架构,在使用Netty转发请求到其他服务的时候报了这个错误。因为该请求是GET方式,其中有个参数值是…

CES Asia 2025科技盛宴,AI智能体成焦点

2025第七届亚洲消费电子技术展(CES Asia赛逸展)将在北京拉开帷幕,AI智能体有望成为展会的核心亮点。 深圳市人工智能行业协会发文表示全力支持CES Asia 2025(赛逸展),称其为人工智能领域的创新发展提供了强…

linux:文件的创建/删除/复制/移动/查看/查找/权限/类型/压缩/打包

关于文件的关键词 创建 touch 删除 rm 复制 cp 权限 chmod 移动 mv 查看内容 cat(全部); head(前10行); tail(末尾10行); more,less 查找 find 压缩 gzip ; bzip 打包 tar 编辑 sed 创建文件 格式: touch 文件名 删除文件 复制文件 移动文件 查看文…

【计算机网络】lab3 802.11 (无线网络帧)

🌈 个人主页:十二月的猫-CSDN博客 🔥 系列专栏: 🏀计算机网络_十二月的猫的博客-CSDN博客 💪🏻 十二月的寒冬阻挡不了春天的脚步,十二点的黑夜遮蔽不住黎明的曙光 目录 1. 前言 2.…

机器人碳钢去毛刺,用大扭去毛刺主轴可轻松去除

在碳钢精密加工的最后阶段,去除毛刺是确保产品质量的关键步骤。面对碳钢这种硬度较高的材料,采用大扭矩的SycoTec去毛刺主轴,成为了行业内的高效解决方案。SycoTec作为精密加工领域的领军品牌,其生产的高速电主轴以其卓越的性能&a…

【漫话机器学习系列】042.提前停止训练的优势(Early Stopping Advantages)

提前停止训练(Early Stopping)的优势 提前停止是一种有效的正则化技术,在训练模型时通过监控验证集的性能来决定训练的结束点,从而避免过拟合。以下是提前停止的主要优势: 1. 防止过拟合 提前停止通过在验证集性能开…

ROS2快速入门0--节点

0:安装 wget http://fishros.com/install -O fishros && . fishros1:运行第一个机器人 ros2 run turtlesim turtlesim_node使用方向健进行控制(在另一个终端) ros2 run turtlesim turtle_teleop_key 2原理解析 打开另一个终端-->输入rqt-->Plugins-->Intr…

10.STM32F407ZGT6-内部温度传感器

参考: 1.正点原子 前言: 本笔记的主要目的和意义就是,再次练习ADC的使用。 32.1 内部温度传感器简介 STM32F407 有一个内部的温度传感器,可以用来测量 CPU 及周围的温度(TA)。对于STM32F407 系列来说,该温度传感器在…

新车月交付突破2万辆!小鹏汽车“激活”智驾之困待解

首次突破月交付2万辆规模的小鹏汽车,稳吗? 本周,高工智能汽车研究院发布的最新监测数据显示,2024年11月,小鹏汽车在国内市场(不含出口)交付量(上险口径,下同&#xff09…

【2024年华为OD机试】 (A卷,100分)- 租车骑绿岛(Java JS PythonC/C++)

一、问题描述 题目描述 部门组织绿岛骑行团建活动。租用公共双人自行车,每辆自行车最多坐两人,最大载重 M。 给出部门每个人的体重,请问最多需要租用多少双人自行车。 输入描述 第一行两个数字 m、n,分别代表自行车限重&#…

AI在零售行业中的应用:提升顾客体验与运营效率

你知道吗?零售行业正悄悄发生着一场革命!AI正在改变我们的购物方式,提升体验的同时,还让商家们的运营更高效! 1、个性化推荐 AI通过分析你的购物历史和兴趣,精准推荐你喜欢的商品,再也不怕刷到…

人才选拔中,如何优化面试流程

在与某大型央企的深入交流中,随着该企业的不断壮大与业务扩张,对技术人才的需求急剧上升,尽管企业加大了招聘力度并投入了大量资源,但招聘成效却不尽如人意。经过项目组细致调研与访谈,问题的根源逐渐浮出水面&#xf…

Deepin20.9 搭建 JDK 8 开发环境(VS Code)

一、安装指令 sudo apt-get install openjdk-8-jdk 二、切换 java 版本(可选) sudo update-alternatives --config java sudo update-alternatives --config javac sudo update-alternatives --config javadoc三、查看 java 与 javac 的版本 jav…

可靠的人形探测,未完待续(III)

一不小心,此去经年啊。问大家新年快乐! 那,最近在研究毫米波雷达模块嘛,期望用在后续的产品中,正好看到瑞萨的活动送板子,手一下没忍住。 拿了板子就得干活咯,我一路火花带闪电,开整…

论文笔记:FDTI: Fine-grained Deep Traffic Inference with Roadnet-enriched Graph

2023 PKDD 1 intro 一些交通预测下游任务对预测结果的粒度非常敏感,例如交通信号控制、拥堵发现和路径规划 然而,现有的深度学习方法主要关注粗粒度的交通数据,而在细粒度设置下利用深度学习方法解决交通预测任务的研究仍未被探索在细粒度设…

如何BugReport和PowerMonitor图形结合分析功耗问题

一、什么是BugReport和PowerMonitor图形结合呢? Battery Historian是支持PowerMonitor电流图显示的,具体显示效果如下:我们移动鼠标到PowerMonitor的电流波形时就会显示这个时刻的电流情况。 BugReport和PowerMonitor图形结合好处&#xff…

外部获取nVisual所在层级方法

Iframe嵌入nVisual,在iframe渲染完成之后,以后通过增加window.addEventListener()方法监听message事件,来获取nvisual当前的所在层级以及所选中的节点列表以及线缆列表。 nVisualPatrolDiagramIdList 变量是获取nVisual当前所在的层级的ID值…

UI自动化测试框架playwright--初级入门

一、背景:UI自动化的痛点: 1、设计脚本耗时: 需要思考要如何模拟用户的操作,如何触发页面的事件,还要思考如何设计脚本,定位和操作要交互的元素、路径、位置,再编写代码逻辑,往复循…

开放词汇检测新晋SOTA:地瓜机器人开源DOSOD实时检测算法

在计算机视觉领域,目标检测是一项关键技术,旨在识别图像或视频中感兴趣物体的位置与类别。传统的闭集检测长期占据主导地位,但近年来,开放词汇检测(Open-Vocabulary Object Detection-OVOD 或者 Open-Set Object Detec…

Jupyter Notebook 安装PyTorch

1、判断当前环境 通过如下命令可以看出是Anaconda 环境 2、Anaconda 环境安装 PyTorch 2.1 要执行的命令 如果你使用的是 Anaconda 环境,可以使用以下命令来安装 PyTorch: conda install pytorch -c pytorch 2.2 执行遇到的问题:没有权…