【人工智能概论】 Python标准库——dalib(领域自适应)

news2024/11/19 9:29:43

【人工智能概论】 Python标准库——dalib(领域自适应)

文章目录

  • 【人工智能概论】 Python标准库——dalib(领域自适应)
  • 一. 领域鉴别器(DomainDiscriminator)
  • 二. 领域对抗损失(DomainAdversarialLoss)
  • 三. 高斯核(GaussianKernel)
  • 四. 多核最大均值差异(MK-MMD)


一. 领域鉴别器(DomainDiscriminator)

dalib.modules.domain_discriminator.DomainDiscriminator(in_feature: int, hidden_size: int)

  • 功能: 区分输入的特征是来自源域还是目标域,源域标签为1,目标域标签为0。
  • 参数:
  • in_feature(int): 输入特征的维度;
  • hidden_size(int): 隐层特征的维度。
  • 形状:
  • inputs:(minibatch, in_feature);
  • outputs: (minibatch, 1)。
  • 举例:
  • 见 领域对抗损失(DomainAdversarialLoss) 的举例。

二. 领域对抗损失(DomainAdversarialLoss)

dalib.adaptation.dann.DomainAdversarialLoss(domain_discriminator: torch.nn.modules.module.Module, reduction: Optional[str]= 'mean')

  • 定义: L o s s ( D s , D t ) = E x i s ⌢ D s l o g [ D ( f i s ) ] + E x j t ⌢ D t l o g [ 1 − D ( f j t ) ] Loss(D_{s},D_{t})=E_{x_{i}^{s}\frown D_{s}}log[D(f_{i}^{s})]+E_{x_{j}^{t}\frown D_{t}}log[1-D(f_{j}^{t})] Loss(Ds,Dt)=ExisDslog[D(fis)]+ExjtDtlog[1D(fjt)]其中,D是领域鉴别器,f是领域的特征。
  • 参数:
  • domain_discriminator(nn.Module) : 域鉴别器对象,用于预测特征的域;
  • reduction(string,Optional) : 指定输出损失的方式,‘none’, ‘sum’,‘mean’,其中’none’指不使用任何降维直接输出,‘sum’、'mean’分别是对损失求和、求均值,默认为求均值。
  • 输入:
  • f_s (tensor):源域的特征 f s f^{s} fs
  • f_t (tensor) :目标域的特征 f t f^{t} ft
  • 形状:
  • f_s, f_t : (N, F)F是输入特征的维度;
  • outputs : 默认是标量,但如果reduction是’none’输出的形状是(N,)。
  • 举例:
from dalib.modules.domain_discriminator import DomainDiscriminator
from dalib.adaptation.dann import DomainAdversarialLoss

discriminator = DomainDiscriminator(in_feature= 1024, hidden_size= 2048)
loss = DomainAdversarialLoss(discriminator, reduction= 'mean')

f_s, f_t = torch.rand(20, 1024), torch.rand(20, 1024)
output = loss(f_s, f_t)

print(output)

在这里插入图片描述

三. 高斯核(GaussianKernel)

dalib.modules.kernels.GaussianKernel(sigma: Optional[float] = None, track_running_stats: Optional[bool] = True, alpha: Optional[float] = 1.0)

  • 定义:
  • 高斯核 k k k的定义: k ( x 1 , x 2 ) = e x p ( − ∥ x 1 − x 2 ∥ 2 2 σ 2 ) k(x_{1},x_{2})=exp(-\frac{\left \| x_{1}-x_{2} \right \|^{2} }{2\sigma ^{2}} ) k(x1,x2)=exp(2σ2x1x22) 其中 x 1 , x 2 ∈ R d x_{1},x_{2}\in R^{d} x1,x2Rd是一维张量。
  • 高斯核矩阵 K K K被定义在 X = ( x 1 , x 2 , . . . x m ) X=(x_{1},x_{2},...x_{m}) X=(x1,x2,...xm)上: K ( x ) i , j = k ( x i , x j ) K(x)_{i,j} = k(x_{i},x_{j}) K(x)i,j=k(xi,xj)
  • 在运算中 σ 2 \sigma ^{2} σ2有两种确认方法:
    第一种通过下式计算动态获得: σ 2 = α n 2 ∑ i , j ∥ x i − x j ∥ 2 \sigma ^{2} = \frac{\alpha }{n^{2}}\sum _{i,j}\left \| x_{i}-x_{j} \right \| ^{2} σ2=n2αi,jxixj2
    第二种是直接给定数值。
  • 参数:
  • sigma(float, optional): 即 σ \sigma σ,默认为None;
  • track_running_stats(bool, optional):如果是’True’则用前面的公式计算 σ 2 \sigma^{2} σ2,若为’False’则使用固定的 σ 2 \sigma^{2} σ2,默认为’True’;
  • alpha(float, optional):当track_running_stats为’True’时为计算 σ 2 \sigma^{2} σ2提供 α \alpha α
  • 输入:
  • X(tensor):输入组X。
  • 形状:
  • inputs:(minibatch, F) , F是输入特征的维数;
  • outputs:(minibatch, minibatch) 。

四. 多核最大均值差异(MK-MMD)

dalib.adaptation.dan.MultipleKernelMaximumMeanDiscrepancy(kernels: Sequence[torch.nn.modules.module.Module], Linear: Optional[bool]= False, quadratic_program: Optional[bool]= False)

  • MK-MMD:
  • 源域为: D s = { ( x i s , y i s ) } i = 1 n s D_{s}= \left \{ (x_{i}^{s},y_{i}^{s}) \right \}_{i=1}^{n_{s}} Ds={(xis,yis)}i=1ns
  • 目标域: D t = { x j t } j = 1 n t D_{t}= \left \{ x_{j}^{t} \right \}_{j=1}^{n_{t}} Dt={xjt}j=1nt
  • 它们各自的样本间都符合独立同分布;
  • 则MK-MMD的计算公式为: d M K − M M D ( D s , D t ) = ∥ E s [ g ( D s ) ] − E t [ g ( D t ) ] ∥ H k 2 d_{MK-MMD}(D_{s},D_{t})=\left \| E_{s}[g(D_{s})]-E_{t}[g(D_{t})] \right \| ^{2}_{H_{k}} dMKMMD(Ds,Dt)=Es[g(Ds)]Et[g(Dt)]Hk2
  • H k H_{k} Hk表示具有特定内核 k k k R K H S RKHS RKHS g ( ∗ ) g(*) g()是与核函数相关的连续映射, E [ ∗ ] E[*] E[]是给定分布的期望;
  • 应当注意的是,核函数 k k k是被定义为 r r r个不同的半正定核的凸组合,如下形式: k ( x s , x t ) = ∑ i = 1 r β i k i ( x s , x t ) k(x^{s},x^{t})= {\textstyle \sum_{i=1}^{r}}\beta _{i}k_{i}(x^{s},x^{t}) k(xs,xt)=i=1rβiki(xs,xt)
  • 其中: ∑ i r β i = 1 , β i ≥ 0 {\textstyle \sum_{i}^{r}}\beta _{i}=1,\beta _{i}\ge 0 irβi=1,βi0
  • 所谓半正定性是核函数的常见的性质(可以联系SVM中的相关概念学习),凸组合是一种线性组合,若满足 λ i ≥ 0 , ∑ i r λ i = 1 \lambda _{i}\ge 0,{\textstyle \sum_{i}^{r}}\lambda _{i}=1 λi0irλi=1 ∑ i r λ i x i {\textstyle \sum_{i}^{r}}\lambda _{i}x_{i} irλixi即为凸组合;
  • 使用内核技巧,MK-MMD可以简化计算为: D ^ k ( D s , D t ) = 1 n s 2 ∑ i = 1 n s ∑ j = 1 n s k ( D s i , D s j ) + 1 n t 2 ∑ i = 1 n t ∑ j = 1 n t k ( D t i , D t j ) − 2 n s n t ∑ i = 1 n s ∑ j = 1 n t k ( D s i , D t j ) \hat{D}_{k}(D_{s},D_{t})= \frac{1}{n_{s}^{2}} {\textstyle \sum_{i=1}^{n_{s}}} {\textstyle \sum_{j=1}^{n_{s}}} k(D_{s}^{i},D_{s}^{j}) +\frac{1}{n_{t}^{2}} {\textstyle \sum_{i=1}^{n_{t}}} {\textstyle \sum_{j=1}^{n_{t}}} k(D_{t}^{i},D_{t}^{j}) -\frac{2}{n_{s}n_{t}} {\textstyle \sum_{i=1}^{n_{s}}} {\textstyle \sum_{j=1}^{n_{t}}} k(D_{s}^{i},D_{t}^{j}) D^k(Ds,Dt)=ns21i=1nsj=1nsk(Dsi,Dsj)+nt21i=1ntj=1ntk(Dti,Dtj)nsnt2i=1nsj=1ntk(Dsi,Dtj)
  • 参数:
  • Kernel(tuple(nn.Module)): 核方程;
  • Linear(bool):是否使用DAN的线性版本,默认不用;
  • quadratic_program(bool): 是否使用二次规划求解 β \beta β,默认不用。
  • 输入:
  • d_s(tensor):源域通过映射所得的特征 D s D_{s} Ds
  • d_t(tensor): 目标域通过映射所得的特征 D t D_{t} Dt
  • 注意它俩必须相同的形状。
  • 形状:
  • inputs: (minibatch, *) *代表任意数,实际上就是传入的特征维度;
  • outputs: 标量。
  • 举例:
from dalib.modules.kernels import GaussianKernel
from dalib.adaptation.dan import MultipleKernelMaximumMeanDiscrepancy

feature_dim = 1024
batch_size = 10

kernels = (GaussianKernel(alpha=0.5), GaussianKernel(alpha=1.), GaussianKernel(alpha=2.))
loss = MultipleKernelMaximumMeanDiscrepancy(kernels)

# features from source domain and target domain
z_s, z_t = torch.randn(batch_size, feature_dim), torch.randn(batch_size,feature_dim)
output = loss(z_s, z_t)

print(output)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/669904.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【MongoDB】五、MongoDB分片集群的部署

【MongoDB】五、MongoDB分片集群的部署 实验目的实验内容实验步骤环境准备部署 Config server配置Config Server副本集部署Shard部署mongos启动分片功能查看分片信息 实验小结 实验目的 能够通过部署MongoDB分片集群熟悉MongoDB分片集群架构和基本操作,从而解决大数…

调用有道API实现语音翻译(汉译英)

目录 1. 作者介绍2. 相关介绍2.1 API介绍2.2 网易API介绍 3. 实验过程3.1 调用过程3.2 代码获取3.3 完整代码 1. 作者介绍 南旭东,男,西安工程大学电子信息学院,2022级研究生 研究方向:机器视觉与人工智能 电子邮件:1…

win10 搭建vue环境并运行项目

win10 搭建vue环境并运行项目 1、参考链接2、遇到的问题及解决 1、参考链接 https://blog.csdn.net/qq_44959735/article/details/128886550 2、遇到的问题及解决 运行的时候不要再git bash里,要在自带的powershell里,以管理员权限运行。 问题&#xf…

未来5年,生产力的底层逻辑变了,影响所有企业

上周,K哥带领20多位企业家、技术高管参访了阿里钉钉,学习AI大模型如何应用到企业生产经营和组织管理当中,以及企业如何使用新生产力工具实现降本增效。 通过这次参访,我对AI大模型如何驱动企业管理变革有了新的认识,三…

js-排序数组中两个数字之和

给定一个已按照 升序排列 的整数数组 numbers &#xff0c;请你从数组中找出两个数满足相加之和等于目标数 target 。 函数应该以长度为 2 的整数数组的形式返回这两个数的下标值。numbers 的下标 从 0 开始计数 &#xff0c;所以答案数组应当满足 0 < answer[0] < answ…

【ArcGIS】栅格重分类(Reclass)

ArcGIS栅格重分类&#xff08;Reclass&#xff09; 1 重分类&#xff08;Relassify&#xff09;1.1 新值替代1.2 将值组合到一起1.3 按相同等级对一组栅格的值进行重分类1.4 将特定值设置为NoData或者为NoData像元设置某个值1.5 操作步骤 2 查找表&#xff08;Lookup&#xff0…

CS5366设计原理图|Type-C转HDMI2.0 4K60+USB3.0+PD3.1视频转换芯片应用电路图

CS5366Type-C转HDMI2.0的显示协议转换芯片, 内部集成了PD3.0及DSC decoder, 并能按客户需求配置成不同的功能组合&#xff0c; 是目前集成度与功耗更小的一颗芯片。 Type-C转HDMI2.0 4K60USB3.0PD3.1视频转换芯片应用电路图&#xff1a; CS5366系列符合USB电源传输规范3.0。CS…

C语言+单片机-内存分布详解,全网最全,值得收藏保存

目录 一、C语言内存分区 1. 代码区 2. 常量区 3. 全局(静态)区 4. 堆区(heap) 5. 栈区(stack) 二、STM32存储器分配 1. 随机存储器—RAM 2. 只读存储器—ROM 三、基于STM32代码验证 1. 详细代码如下 2. 运行结果如下 四、单片机中的内存分布 1.含义解释 2. 程序…

Makerbase CANable V2.0在Window系统使用

应用软件与固件 应用软件CANable V2.0CANable V1.0cangaroocandleLight/slcan(支持CAN FD)candleLight/slcan/cantactBUSMASTER V3.2.2candleLightcandleLight/pcan/cantactTSMastercandleLightcandleLight/pcan/cantactPCAN-Explorer 5、pcan view不支持pcancantactslcan(不支…

北理工软件工程基础考试要点

文章目录 前言题型分析概念部分大题部分数据流图和数据字典数据流图数据字典 前言 这篇文章就是针对北理工计科同学写的&#xff0c;这是精心筛选&#xff0c;针对老师最后一节课的重点以及题型写的笔记&#xff0c;保你一天速通软件工程基础这门课。 题型分析 闭卷 单选题…

SAP-ABAP-SM30自建表维护如何如何自动带出描述

文章目录 1 Requirement2 Process2.1 When you finish table maintainer , and create event .2.2 Create sub routine .2.3 Write code 3. Result4 Reference Document 1 Requirement The requirement is that we input MATNR and HKONT and get the description automatic .…

DDoS攻击原理是什么?

&#x1f482; 个人网站:【海拥】【游戏大全】【神级源码资源网】&#x1f91f; 前端学习课程&#xff1a;&#x1f449;【28个案例趣学前端】【400个JS面试题】&#x1f485; 寻找学习交流、摸鱼划水的小伙伴&#xff0c;请点击【摸鱼学习交流群】 目录 前言DDoS攻击的原理DDo…

考研算法第28天:冒泡排序和简单选择排序 【排序】

算法介绍 冒泡排序就不需要多说了&#xff0c;大一就会的东西&#xff0c;所以这里就不多言了。记录一下y总对他的分析就是了 简单选择排序 每次循环遍历后面的元素&#xff0c;然后将最小的放到最前面&#xff1a;举个例子 第一次 如上图 第一次发现最小的元素是2就将位于第…

EPM创建报表时,子节点有数据但是父节点无数据的解决方案

目录 一、环境二、问题描述与分析三、解决方案1、确认HANA MDX是否启用2、确认BPC前端的TIME维是否正常维护3. 在SPRO中设置模型参数4、使用SE38执行程序UJHANA_REFRESH_VIR_CUBE刷新模型 一、环境 产品版本BWSAP BW/4HANA 2021BPCSAP BPC 2021 Version for SAP BW/4 HANAEPM1…

电力数据安全治理实践思路探讨

01电力数据安全实践背景 数字经济的快速发展根本上源自数据的高质量治理和高价值转化&#xff0c;近年来&#xff0c;国家层面相继推出促进数据高质量治理的政策法规&#xff0c;围绕加强数据安全保障、提高数据质量等方面&#xff0c;明确了相关规定和要求。作为重要数据持有者…

TC8:TCP_UNACCEPTABLE_05-09

TCP_UNACCEPTABLE_05: [listen] unacceptable ACK -> RST [listen] 目的 在LISTEN状态下的TCP接收到携带一个不可接受的ACK号的段,发送RST并且保持在相同的状态 在LISTEN状态下的TCP,只能接收到SYN消息,不能有ACK标志位和ACK号,如果有,就是不可接受的 测试步骤 Teste…

C++IO流和类型处理(13)

IO流 IO流包括 标准IO流&#xff0c;字符串流&#xff0c;文件流 标准IO流 基础使用 #include <iostream> //包括istream和ostream cin >> ----- 标准输入 cout<< ----- 标准输出 clog<< ----- 带缓冲区的标准错误 cerr<< ----- 不带缓冲…

【C6】11111

文章目录 10.动静态库&#xff1a;.a&#xff0c;指定.so&#xff0c;LD_10.1 静态库&#xff1a;链接库的文件名是libpublic.a&#xff0c;链接库名是public&#xff0c;缺点使用的静态库发生更新改变&#xff0c;程序必须重新编译10.2 动态库&#xff1a;动态库发生改变&…

Selenium java自动化

文章目录 1. Selenium的安装2. 了解自动化和selenium2.1 什么是自动化以及为什么要做2.2为什么选择selenium作为我们的文本自动化工具2.3 环境部署2.4什么驱动&#xff0c;驱动的工作原理。2.5一个简单的自动化演示 3. 掌握selenium的基础语法3.1)元素的定位3.2) 元素的操作3.3…

Java解析String类的使用及String a = b + “c“面试题

1.概述 String:字符串&#xff0c;使用一对""引起来表示。 1.String声明为final的&#xff0c;不可被继承 2.String实现了Serializable接口&#xff1a;表示字符串是支持序列化的。 实现了Comparable接口&#xff1a;表示String可以比较大小 3.String内部定义了fina…