【域适应】基于深度域适应MMD损失的典型四分类任务实现

news2024/11/27 5:28:32

关于

MMD (maximum mean discrepancy)是用来衡量两组数据分布之间相似度的度量。一般地,如果两组数据分布相似,那么MMD 损失就相对较小,说明两组数据/特征处于相似的特征空间中。基于这个想法,对于源域和目标域数据,在使用深度学习进行特征提取中,使用MMD损失,可以让模型提取两个域的共有特征/空间,从而实现源域到目标域的迁移。

参考论文:https://arxiv.org/abs/1409.6041

工具

Python

 

方法实现

定义mmd函数
#!/usr/bin/env python
# encoding: utf-8

import torch

# Consider linear time MMD with a linear kernel:
# K(f(x), f(y)) = f(x)^Tf(y)
# h(z_i, z_j) = k(x_i, x_j) + k(y_i, y_j) - k(x_i, y_j) - k(x_j, y_i)
#             = [f(x_i) - f(y_i)]^T[f(x_j) - f(y_j)]
#
# f_of_X: batch_size * k
# f_of_Y: batch_size * k
def mmd_linear(f_of_X, f_of_Y):
    delta = f_of_X - f_of_Y
    loss = torch.mean(torch.mm(delta, torch.transpose(delta, 0, 1)))
    return loss

def guassian_kernel(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    n_samples = int(source.size()[0])+int(target.size()[0])
    total = torch.cat([source, target], dim=0)
    total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
    total1 = total.unsqueeze(1).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
    L2_distance = ((total0-total1)**2).sum(2)
    if fix_sigma:
        bandwidth = fix_sigma
    else:
        bandwidth = torch.sum(L2_distance.data) / (n_samples**2-n_samples)
    bandwidth /= kernel_mul ** (kernel_num // 2)
    bandwidth_list = [bandwidth * (kernel_mul**i) for i in range(kernel_num)]
    kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]
    return sum(kernel_val)#/len(kernel_val)


def mmd_rbf_accelerate(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    batch_size = int(source.size()[0])
    kernels = guassian_kernel(source, target,
        kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma)
    loss = 0
    for i in range(batch_size):
        s1, s2 = i, (i+1)%batch_size
        t1, t2 = s1+batch_size, s2+batch_size
        loss += kernels[s1, s2] + kernels[t1, t2]
        loss -= kernels[s1, t2] + kernels[s2, t1]
    return loss / float(batch_size)

def mmd_rbf_noaccelerate(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    batch_size = int(source.size()[0])
    kernels = guassian_kernel(source, target,
                              kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma)
    XX = kernels[:batch_size, :batch_size]
    YY = kernels[batch_size:, batch_size:]
    XY = kernels[:batch_size, batch_size:]
    YX = kernels[batch_size:, :batch_size]
    loss = torch.mean(XX + YY - XY -YX)
    return loss
定义基于mmd特征对齐CNN模型
# encoding=utf-8

import torch.nn as nn
import torch.nn.functional as F


class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(1, 3)),
            nn.ReLU()
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(1, 3)),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.MaxPool2d(kernel_size=(1, 2), stride=2)
        )
        self.fc1 = nn.Sequential(
            nn.Linear(in_features=64 * 98, out_features=100),
            nn.ReLU()
        )
        self.fc2 = nn.Sequential(
            nn.Linear(in_features=100, out_features=2)
        )

    def forward(self, src, tar):
        x_src = self.conv1(src)
        x_tar = self.conv1(tar)
        
        x_src = self.conv2(x_src)
        x_tar = self.conv2(x_tar)
        #print(x_src.shape)
        x_src = x_src.reshape(-1, 64 * 98)
        x_tar = x_tar.reshape(-1, 64 * 98)
        
        x_src_mmd = self.fc1(x_src)
        x_tar_mmd = self.fc1(x_tar)
        
        #x_src = self.fc1(x_src)
        #x_tar = self.fc1(x_tar)
        
        #x_src_mmd = self.fc2(x_src)
        #x_tar_mmd = self.fc2(x_tar)
        
        y_src = self.fc2(x_src_mmd)
        
        return y_src, x_src_mmd, x_tar_mmd
    

代码获取

后台私信;

其他相关域适应问题和代码开发,欢迎沟通和交流。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1587458.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

实况窗助力美团打造鸿蒙原生外卖新体验,用户可实时掌握外卖进展

自2023年华为宣布全新HarmonyOS NEXT蓄势待发,鸿蒙原生应用全面启动以来,已有金融、旅行、社交等多个领域的企业和开发者陆续宣布加入鸿蒙生态。其中,美团作为国内头部的科技零售企业,是首批加入鸿蒙生态的伙伴,其下的…

华为ensp中PPPOE (点对点协议)原理和配置命令

作者主页:点击! ENSP专栏:点击! 创作时间:2024年4月12日6点30分 PPPoE(PPP over Ethernet)是一种将PPP协议封装到以太网帧中的链路层协议。它可以使以太网网络中的多台主机连接到远端的宽带接…

【应用】SpringBoot-自动配置原理

前言 本文简要介绍SpringBoot的自动配置原理。 本文讲述的SpringBoot版本为:3.1.2。 前置知识 在看原理介绍之前,需要知道Import注解的作用: 可以导入Configuration注解的配置类、声明Bean注解的bean方法;可以导入ImportSele…

AtCoder Beginner Contest 141 F. Xor Sum 3(异或性质+异或线性基求最大异或值)

题目 n(2<n<1e5)个数&#xff0c;第i个数ai(0<ai<2^60) 将n个数分成两堆&#xff0c;对每一堆求异或和&#xff0c;再将得到的两个数求和&#xff0c; 现在希望这个和最大&#xff0c;输出这个最大的值 思路来源 ABC141F - 洛谷专栏 题解 感觉思路来源说的很…

再说vue响应式数据

请说一下你对响应式数据的理解 如何实现响应式数据据 对象 vue2 响应式核心代码 数组 vue2 处理缺陷Vue3则采用 proxy - vue3 响应式核心代码 请说一下你对响应式数据的理解 如何实现响应式数据据 数组和对象类型当值变化时如何劫持到。 对象 对象内部通过defineReactive方…

mysql重启失败

服务器重启了一下&#xff0c;然后启动后发现mysql自动启动没有生效&#xff0c;于是手动通过systemctl启动mysqld&#xff0c;然后就报错:Starting MySQL...........The server quit without updating P[FAILED](/data/mysql/iz2zebvmy1qv3fao9c5riuz.pid). 根据配置my.cnf文…

Hello 算法10:搜索

https://www.hello-algo.com/chapter_searching/binary_search/ 二分查找法 给定一个长度为 n的数组 nums &#xff0c;元素按从小到大的顺序排列&#xff0c;数组不包含重复元素。请查找并返回元素 target 在该数组中的索引。若数组不包含该元素&#xff0c;则返回 -1 。 # 首…

Ubuntu下配置Android NDK环境

Android-NDK的下载 下载Android-NDK wget -c http://dl.google.com/android/ndk/android-ndk-r10e-linux-x86_64.bin 执行bin文件&#xff08;即解压&#xff09; ./android-ndk-r10c-linux-x86_64.bin Android-NDK的配置 要想使用Android-NDK&#xff0c;还需要进行环境变量…

程序猿之路

我接触计算机算对自己来说是比较晚的了&#xff0c;上初中的时候就有微机课&#xff0c;但是在那个小县城&#xff0c;上课也只是3个人共用一个电脑&#xff0c;我初中整个过程只会开关机&#xff0c;哈哈&#xff0c;虽然学过word&#xff0c;但是无奈&#xff0c;我插不上手呀…

Qlik Sense :use Peek function to Group by and Get Rowno

Question Row number based on groups of data Calculate row number for groups 有时候我们需要基于分组来对数据进行内部排序&#xff0c;例如一个iddate&#xff0c;把不同的属性的记录标记为123&#xff0c;又或者把重复记录标记出来 Solved: Calculate row number for…

如何实现word一键注音?给一篇word文章快速注音的方法

在日常生活和工作中&#xff0c;我们经常需要处理各种文档&#xff0c;其中不乏包含大量生僻字或需要标注拼音的文本。手动为每一个字添加拼音不仅效率低下&#xff0c;而且容易出错。那么&#xff0c;有没有一种方法可以实现Word文档的一键注音呢&#xff1f;本文将为大家详细…

基于SpringBoot和Vue的企业客户管理系统

今天要和大家聊的是基于SpringBoot和Vue的企业客户管理系统 &#xff01;&#xff01;&#xff01; 有需要的小伙伴可以通过文章末尾名片咨询我哦&#xff01;&#xff01;&#xff01; &#x1f495;&#x1f495;作者&#xff1a;李同学 &#x1f495;&#x1f495;个人简介…

IntelliJ IDEA(WebStorm、PyCharm、DataGrip等)设置中英文等宽字体,英文为中文的一半(包括标点符号)

1.设置前&#xff08;idea默认字体为 JetBrains Mono&#xff09; 2.设置后&#xff08;楷体&#xff09;

Oracle 19c补丁升级(Windows)

文章目录 一、打补丁前备份检查1、补丁包获取2、备份数据包以及数据库软件3、检查OPatch版本 二、补丁升级1、更新OPatch2、关闭监听以及服务3、补丁升级过程4、启动监听以及服务 三、数据库补丁应用 一、打补丁前备份检查 1、补丁包获取 补丁包&#xff1a; 百度网盘链接&am…

贪心算法:排列算式

题目描述 给出n数字&#xff0c;对于这些数字是否存在一种计算顺序&#xff0c;使得计算过程中数字不会超过3也不会小于0&#xff1f; 输入描述: 首行给出一个正整数t,(1≤t≤1000)代表测试数据组数每组测试数据第一行一个正整数n,(1≤n≤500)第二行包含n个以空格分隔的数字…

CLIPSeg如果报“目标计算机积极拒绝,无法连接。”怎么办?

CLIPSeg这个插件在使用的时候&#xff0c;偶尔会遇到以下报错&#xff1a; Error occurred when executing CLIPSeg: (MaxRetryError("HTTPSConnectionPool(hosthuggingface.co, port443): Max retries exceeded with url: /CIDAS/clipseg-rd64-refined/resolve/main/toke…

练习题(2024/4/11)

1每日温度 给定一个整数数组 temperatures &#xff0c;表示每天的温度&#xff0c;返回一个数组 answer &#xff0c;其中 answer[i] 是指对于第 i 天&#xff0c;下一个更高温度出现在几天后。如果气温在这之后都不会升高&#xff0c;请在该位置用 0 来代替。 示例 1: 输入…

使用 vue3-sfc-loader 加载远程Vue文件, 在运行时动态加载 .vue 文件。无需 Node.js 环境,无需 (webpack) 构建步骤

加载远程Vue文件 vue3-sfc-loader vue3-sfc-loader &#xff0c;它是Vue3/Vue2 单文件组件加载器。 在运行时从 html/js 动态加载 .vue 文件。无需 Node.js 环境&#xff0c;无需 (webpack) 构建步骤。 主要特征 支持 Vue 3 和 Vue 2&#xff08;参见dist/&#xff09;仅需…

订单中台架构:打造高效订单管理系统的关键

在现代商业环境下&#xff0c;订单管理对于企业来说是至关重要的一环。然而&#xff0c;随着业务规模的扩大和多渠道销售的普及&#xff0c;传统的订单管理方式往往面临着诸多挑战&#xff0c;如订单流程复杂、信息孤岛、数据不一致等问题。为了应对这些挑战并抓住订单管理的机…

Redis 的数据结构和内部编码

Redis的 5 种数据类型 Redis 底层在实现上述数据结构的时候&#xff0c;会在源码层面&#xff0c;针对上述实现进行 特定的优化 &#xff0c;来达到节省时间/节省空间效果 特定的优化&#xff1a;内部的具体实现的数据结构&#xff0c;在特定场景下&#xff0c;不是其对应的标准…