多任务高斯过程数学原理和Pytorch实现示例

news2024/11/25 4:53:46

高斯过程其在回归任务中的应用我们都很熟悉了,但是我们一般介绍的都是针对单个任务的,也就是单个输出。本文我们将讨论扩展到多任务gp,强调它们的好处和实际实现。

本文将介绍如何通过共区域化的内在模型(ICM)和共区域化的线性模型(LMC),使用高斯过程对多个相关输出进行建模。

多任务高斯过程

高斯过程是回归和分类任务中的一个强大工具,提供了一种非参数方式来定义函数的分布。当处理多个相关输出时,多任务高斯过程可以模拟这些任务之间的依赖关系,从而带来更好的泛化和预测效果。

数学上,高斯过程被定义为一组随机变量,其中任何有限数量的变量都具有联合高斯分布。对于一组输入点 X,相应的输出值 f(X) 是联合高斯分布的:

其中 m(X) 是均值函数,k(X, X) 是协方差矩阵。

在多任务环境中,目标是建模函数 f: X → R^T,这样我们就有 T 个输出或任务,f_t(X) 对于 t = 1, …, T。这意味着均值函数是 m: X → R^T,核函数是 k: X × X → R^{T × T}。

我们如何模拟这些任务之间的相关性?

独立多任务高斯过程

一个简单的独立多输出高斯过程模型将每个任务独立建模,不考虑任务之间的任何相关性。在这种情况下,每个任务都有自己的高斯过程,具有自己的均值和协方差函数。数学上可以表达为:

这使得协方差矩阵 k(x, x) 是块对角形的,即 diag(k_1(x, x), …, k_T(x, x))。

这种方法没有利用任务之间的共享信息,可能导致性能不佳,尤其是当某些任务的数据有限时。

Intrinsic model of coregionalization(ICM)

ICM(共区域化的内在模型)方法通过引入核心区域化矩阵 (B) 来推广独立多输出高斯过程,该矩阵模型化任务之间的相关性。ICM方法中的协方差函数定义如下:

其中 k_input是在输入空间上定义的协方差函数(例如,平方指数核),而 B ∈ R^{T × T} 是捕捉任务特定协方差的核心区域化矩阵。矩阵 (B) 通常参数化为 (B = W W^T),其中W ∈ R^{T × r*}* ,且 ® 是核心区域化矩阵的秩。这确保了核函数是半正定的。

ICM方法可以学习任务之间的共享结构。任务之间的皮尔逊相关系数可以表示为:

Linear model of coregionalization (LMC)

另一种常见的方法是LMC(线性核心区域化模型)模型,它通过允许更多种类的输入核来扩展ICM。在LMC模型中,协方差函数定义为:

其中 (Q) 是基核的数量,k_input^q 是基核,而 (B_q) 是每个基核的核心区域化矩阵。通过结合多个基核,这个模型可以捕捉任务之间更复杂的相关性。

我们可以通过设置 (Q=1) 来恢复ICM模型。或者说ICM是Q=1的LMC模型

噪声建模

在多任务高斯过程中,我们需要考虑一个多输出似然函数,该函数为每个任务模型化噪声。

标准的似然函数通常是多维高斯似然,可以表示为:

其中 (y) 是观测输出,(f(x)) 是潜在函数,Sigma是噪声协方差矩阵。

这里的灵活性在于噪声协方差矩阵的选择,它可以是对角线Σ=diag(σ12,…,σT2)(每个任务独立噪声)或完整(任务间相关噪声)。

后者通常表示为 Σ=LLT,其中 L∈RT×r,且 r是噪声协方差矩阵的秩。这允许捕捉不同任务的噪声项之间的相关性。

最终包含噪声的协方巧矩阵则由以下给出:

其中 (K) 是没有噪声的协方差矩阵,(I) 是单位矩阵,⊗表示克罗内克乘积,以便将噪声项添加到协方差矩阵的对角块中。

PyTorch实现

我们上面介绍了多任务的高斯过程的数学原理,下面开使用’ GPyTorch '与ICM内核实现多任务GP的示例。

首先需要安装所需的软件包,包括’ Torch ', ’ GPyTorch ', ’ matplotlib ‘和’ seaborn ', ’ numpy '。

 %pip install torch gpytorch matplotlib seaborn numpy pandas
 
 import torch
 import gpytorch
 from matplotlib import pyplot as plt
 import numpy as np
 import seaborn as sns
 import pandas as pd

然后,定义多任务GP模型。使用ICM内核(秩r=1)来捕获任务之间的相关性。

我们为两个任务(正弦和移位正弦)生成合成的噪声训练数据,以便有相关的输出。

噪声协方差矩阵为

其中σ²1 = σ²2 = 0.1²,ρ = 0.3。

最后,我们通过绘制每个任务的平均预测和置信区间来训练模型并评估其性能。

 # Define the kernel with coregionalization
 class MultitaskGPModel(gpytorch.models.ExactGP):
     def __init__(self, train_x, train_y, likelihood, num_tasks):
         super(MultitaskGPModel, self).__init__(train_x, train_y, likelihood)
         self.mean_module = gpytorch.means.MultitaskMean(
             gpytorch.means.ConstantMean(), num_tasks=num_tasks
         )
         self.covar_module = gpytorch.kernels.MultitaskKernel(
             gpytorch.kernels.RBFKernel(), num_tasks=num_tasks, rank=1
         )
 
     def forward(self, x):
         mean_x = self.mean_module(x)
         covar_x = self.covar_module(x)
         return gpytorch.distributions.MultitaskMultivariateNormal(mean_x, covar_x)
 
 # Training data
 f1 = lambda x:  torch.sin(x * (2 * torch.pi))
 f2 = lambda x: torch.sin((x - 0.1) * (2 * torch.pi))
 train_x = torch.linspace(0, 1, 10)
 train_y = torch.stack([
     f1(train_x),
     f2(train_x)
 ]).T
 # Define the noise covariance matrix with correlation = 0.3
 sigma2 = 0.1**2
 Sigma = torch.tensor([[sigma2, 0.3 * sigma2], [0.3 * sigma2, sigma2]])
 # Add noise to the training data
 train_y += torch.tensor(np.random.multivariate_normal(mean=[0,0], cov=Sigma, size=len(train_x)))
 
 # Model and likelihood
 num_tasks = 2
 likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=num_tasks, rank=1)
 model = MultitaskGPModel(train_x, train_y, likelihood, num_tasks)
 
 # Training the model
 model.train()
 likelihood.train()
 
 optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
 
 mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)
 
 scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)
 
 num_iter = 500
 for i in range(num_iter):
     optimizer.zero_grad()
     output = model(train_x)
     loss = -mll(output, train_y)
     loss.backward()
     optimizer.step()
     scheduler.step()
 
 # Evaluation
 model.eval()
 likelihood.eval()
 
 test_x = torch.linspace(0, 1, 100)
 
 with torch.no_grad(), gpytorch.settings.fast_pred_var():
     pred_multi = likelihood(model(test_x))
 
 # Plot predictions
 fig, ax = plt.subplots()
 
 colors = ['blue', 'red']
 for i in range(num_tasks):
     ax.plot(test_x, pred_multi.mean[:, i], label=f'Mean prediction (Task {i+1})', color=colors[i])
     ax.plot(test_x, [f1(test_x), f2(test_x)][i], linestyle='--', label=f'True function (Task {i+1})')
     lower = pred_multi.confidence_region()[0][:, i].detach().numpy()
     upper = pred_multi.confidence_region()[1][:, i].detach().numpy()
     ax.fill_between(
         test_x,
         lower,
         upper,
         alpha=0.2,
         label=f'Confidence interval (Task {i+1})',
         color=colors[i]
     )
 
 ax.scatter(train_x, train_y[:, 0], color='black', label=f'Training data (Task 1)')
 ax.scatter(train_x, train_y[:, 1], color='gray', label=f'Training data (Task 2)')
 
 ax.set_title('Multitask GP with ICM')
 ax.legend(loc='lower center', bbox_to_anchor=(0.5, -0.2),
           ncol=3, fancybox=True)

在使用

GPyTorch

时,通过使用

MultitaskMean

MultitaskKernel

MultitaskGaussianLikelihood

类,ICM模型的实现非常简单。这些类可以处理多任务结构、噪声和核心区域化矩阵,允许我们专注于模型定义和训练。

训练的循环也与标准高斯过程类似,以负边际对数似然作为损失函数,并使用优化器来更新模型参数。

在训练过程中添加了一个调度器来降低学习率,这可以帮助稳定优化过程。

 W = model.covar_module.task_covar_module.covar_factor
 B = W @ W.T
 
 fig, ax = plt.subplots()
 sns.heatmap(B.detach().numpy(), annot=True, ax=ax, cbar=False, square=True)
 ax.set_xticklabels(['Task 1', 'Task 2'])
 ax.set_yticklabels(['Task 1', 'Task 2'])
 ax.set_title('Coregionalization matrix B')
 fig.show()
 
 
 L = model.likelihood.task_noise_covar_factor.detach().numpy()
 Sigma = L @ L.T
 
 fig, ax = plt.subplots()
 sns.heatmap(Sigma, annot=True, ax=ax, cbar=False, square=True)
 ax.set_xticklabels(['Task 1', 'Task 2'])
 ax.set_yticklabels(['Task 1', 'Task 2'])
 ax.set_title('Noise covariance matrix')
 fig.show()

下面的图展示了模型学习的核心区域化矩阵 B 以及噪声协方差矩阵 Σ。

这张图捕捉了任务之间的相关性。如我们所见,B 的非对角线元素是正的。

这个图代表每个任务的噪声水平。注意到模型已经正确学习了噪声相关性。

比较

为了突出使用ICM方法对相关输出建模的优势,我们可以将其与独立处理每个任务、忽略任务之间任何潜在相关性的模型进行比较。

为每个任务定义一个单独的GP,训练它们,并在测试数据上评估它们的性能。

 class IndependentGPModel(gpytorch.models.ExactGP):
     def __init__(self, train_x, train_y, likelihood):
         super(IndependentGPModel, self).__init__(train_x, train_y, likelihood)
         self.mean_module = gpytorch.means.ConstantMean()
         self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())
 
     def forward(self, x):
         mean_x = self.mean_module(x)
         covar_x = self.covar_module(x)
         return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
 
 # Create models and likelihoods for each task
 likelihoods = [gpytorch.likelihoods.GaussianLikelihood() for _ in range(num_tasks)]
 models = [IndependentGPModel(train_x, train_y[:, i], likelihoods[i]) for i in range(num_tasks)]
 
 # Training the independent models
 for i, (model, likelihood) in enumerate(zip(models, likelihoods)):
     model.train()
     likelihood.train()
     optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
     mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)
     scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)
     
     for _ in range(num_iter):
         optimizer.zero_grad()
         output = model(train_x)
         loss = -mll(output, train_y[:, i])
         loss.backward()
         optimizer.step()
         scheduler.step()
 
 # Evaluation
 for model, likelihood in zip(models, likelihoods):
     model.eval()
     likelihood.eval()
 
 with torch.no_grad(), gpytorch.settings.fast_pred_var():
     pred_inde = [likelihood(model(test_x)) for model, likelihood in zip(models, likelihoods)]
 
 # Plot predictions
 fig, ax = plt.subplots()
 
 for i in range(num_tasks):
     ax.plot(test_x, pred_inde[i].mean, label=f'Mean prediction (Task {i+1})', color=colors[i])
     ax.plot(test_x, [f1(test_x), f2(test_x)][i], linestyle='--', label=f'True function (Task {i+1})')
     lower = pred_inde[i].confidence_region()[0]
     upper = pred_inde[i].confidence_region()[1]
     ax.fill_between(
         test_x,
         lower,
         upper,
         alpha=0.2,
         label=f'Confidence interval (Task {i+1})',
         color=colors[i]
     )
 
 ax.scatter(train_x, train_y[:, 0], color='black', label='Training data (Task 1)')
 ax.scatter(train_x, train_y[:, 1], color='gray', label='Training data (Task 2)')
 
 ax.set_title('Independent GPs')
 ax.legend(loc='lower center', bbox_to_anchor=(0.5, -0.2),
           ncol=3, fancybox=True)

在性能方面,比较多任务GP与ICM和独立GP在测试数据上预测的均方误差(MSE)。

 mean_multi = pred_multi.mean.numpy()
 mean_inde = np.stack([pred.mean.numpy() for pred in pred_inde]).T
 
 test_y = torch.stack([f1(test_x), f2(test_x)]).T.numpy()
 MSE_multi = np.mean((mean_multi - test_y) ** 2)
 MSE_inde = np.mean((mean_inde - test_y) ** 2)
 
 df = pd.DataFrame({
     'Model': ['ICM', 'Independent'],
     'MSE': [MSE_multi, MSE_inde]
   })
 df

可以看到由于共区域化矩阵学习到的共享结构,ICM在MSE方面略优于独立gp。当处理更复杂的任务或有限的数据时,这种改进可能更为显著。

在独立GP的场景中,每个GP从一个较小的10个点的数据集中学习,这可能会导致过拟合或次优泛化。具有ICM的多任务GP使用所有20个点来学习指数核参数的平方。这种共享信息有助于改进对这两个任务的预测。

https://avoid.overfit.cn/post/f804e93bd5dd4c4ab9ede5bf1bc9b6c8

作者:Andrea Ruglioni

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1939075.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

驱动LSM6DS3TR-C实现高效运动检测与数据采集(10)----融合磁力计进行姿态解算

驱动LSM6DS3TR-C实现高效运动检测与数据采集.10--融合磁力计进行姿态解算 概述视频教学样品申请源码下载硬件准备DataLogFusion磁力计校准过程初始化磁力计MFX_Arithmetic_Init卡尔曼滤波算法演示 概述 MotionFX库包含用于校准陀螺仪、加速度计和磁力计传感器的例程。 将磁力计…

【网络】windows和linux互通收发

windows和linux互通收发 一、windows的udp客户端代码1、代码剖析2、总体代码 二、linux服务器代码三、成果展示 一、windows的udp客户端代码 1、代码剖析 首先我们需要包含头文件以及lib的一个库&#xff1a; #include <iostream> #include <WinSock2.h> #inclu…

swiftui中onChange函数的使用,监听变量的变化

在 SwiftUI 中&#xff0c;onChange 修饰符用于在指定值发生变化时执行某些操作。它允许你监听一个状态或绑定值的变化&#xff0c;并在变化发生时运行一些代码。这个功能非常适合需要对状态变化做出响应的场景。 使用示例&#xff1a; struct AppStorageTest: View {State p…

友力科技数据中心搬迁方案

将当前运行机房中的所有设备、应用系统安全搬迁至新数据中心机房&#xff0c;实现平滑切换、平稳过渡&#xff0c;最大限度地降低搬迁工作对业务的影响。 为了确保企事业单位能够顺利完成数据中心机房搬迁工作&#xff0c;我们根据实际经验提供了4个基本原则&#xff0c;希望能…

【Linux】编辑器vscode与linux的联动

1.vscode简单学习 vscode是编辑器&#xff0c;可以写各种语言的程序 下载链接&#xff1a;Download Visual Studio Code - Mac, Linux, Windows 来用一下vscode 我们保存了就能在我们的那个文件夹里面看到这个 这个就是编辑器&#xff0c;跟我们的文本文件好像差不多&#…

RPM、YUM 安装 xtrabackup 8 (mysql 热备系列一)包含rpm安装 mysql 8 配置主从

RPM安装 percona-xtrabackup-80-8.0.35-30.1.el7.x86_64.rpm 官网&#xff1a; https://www.percona.com/ 下载地址&#xff1a; https://www.percona.com/downloads wget https://downloads.percona.com/downloads/percona-distribution-mysql-ps/percona-distribution-mysq…

51单片机14(独立按键实验)

一、按键介绍 1、按键是一种电子开关&#xff0c;使用的时候&#xff0c;只要轻轻的按下我们的这个按钮&#xff0c;按钮就可以使这个开关导通。 2、当松开这个手的时候&#xff0c;我们的这个开关&#xff0c;就断开开发板上使用的这个按键&#xff0c;它的内部结构&#xff…

从千台到十万台,浪潮信息InManage V7解锁智能运维密码

随着大模型技术的深度渗透&#xff0c;金融行业正经历着前所未有的智能化变革。从“投顾助手”精准导航投资蓝海&#xff0c;到“智能客服”秒速响应客户需求&#xff0c;大模型以其对海量金融数据的深度挖掘与高效利用&#xff0c;正显著提升金融服务的智能化水准&#xff0c;…

Java:拦截器简介和应用示例(多个拦截器+校验token是否为空)

JAVA 拦截器 简介 拦截器和过滤器均可以拦截http请求&#xff0c;过滤器偏向于基础设施工作&#xff0c;拦截器偏向于业务&#xff0c;拦截器允许在执行Controller之前做验证预处理&#xff0c;在Controller执行之后对返回对象做加工处理。可以用于&#xff1a;权限检查、日志…

2014年全国大学生数学建模竞赛C题生猪养殖管理(含word论文和源代码资源)

文章目录 一、部分题目二、部分论文三、部分源代码四、完整word版论文和源代码 一、部分题目 2014高教社杯全国大学生数学建模竞赛题目 C题 生猪养殖场的经营管理 某养猪场最多能养10000头猪&#xff0c;该养猪场利用自己的种猪进行繁育。养猪的一般过程是&#xff1a;母猪配…

第3关 -- Git 基础知识

任务1: 破冰活动&#xff1a;自我介绍 任务2: 实践项目&#xff1a;构建个人项目 MeiHuaYiShu

【BUG】已解决:ModuleNotFoundError: No module named ‘_ctypes‘

已解决&#xff1a;ModuleNotFoundError: No module named ‘_ctypes‘ 欢迎来到英杰社区https://bbs.csdn.net/topics/617804998 欢迎来到我的主页&#xff0c;我是博主英杰&#xff0c;211科班出身&#xff0c;就职于医疗科技公司&#xff0c;热衷分享知识&#xff0c;武汉城…

【Langchain大语言模型开发教程】记忆

&#x1f517; LangChain for LLM Application Development - DeepLearning.AI 学习目标 1、Langchain的历史记忆 ConversationBufferMemory 2、基于窗口限制的临时记忆 ConversationBufferWindowMemory 3、基于Token数量的临时记忆 ConversationTokenBufferMemory 4、基于历史…

双笼转子感应电机建模仿真(2):任意速旋转坐标系下xy/xy数学模型及仿真模型

1.概述 2. 双笼转子三相感应电机数学模型 2.1. 定子基准下ABC/qd数学模型 2.2. 任意速旋转坐标系下xy/xy数学模型 2.3. 空间矢量数学模型 3. 双笼转子三相感应电动机仿真模型 3.1 基于任意速xy/xy坐标系数学模型(1)~(5)的仿真模型 3.2. 基于任意速xy/xy坐标系中瞬态等效电…

MATLAB图像处理分析基础(一)

一、引言 MATLAB软件得到许多数字图像处理学生、老师和科研工作者的喜爱&#xff0c;成为数字图像处理领域不可或缺的工具之一&#xff0c;其与其他软件相比有以下诸多显著优点。首先&#xff0c;MATLAB 拥有强大的内置函数库&#xff0c;涵盖了图像读取、显示、处理及分析的全…

OpenCV 遍历Mat,像素操作,使用TrackBar 调整图像的亮度和对比度 C++实现

文章目录 1.使用C遍历Mat,完成颜色反转1.1 常规遍历方式1.2 迭代器遍历方式1.3指针访问方式遍历&#xff08;最快&#xff09;1.4不同遍历方式的时间对比 2.图像像素操作&#xff0c;提高图像的亮度3.TrackBar 进度条操作3.1使用TrackBar 调整图像的亮度3.2使用TrackBar 调整图…

【JavaEE进阶】——Spring事务和事务传播机制

目录 &#x1f6a9;事务 &#x1f388;为什么需要事务? &#x1f388;事务的操作 &#x1f6a9;Spring 中事务的实现 &#x1f388;数据准备 &#x1f388;Spring 编程式事务(了解) &#x1f388;Spring 声明式事务 Transactional &#x1f36d;Transactional 详解 &…

2013年全国大学生数学建模竞赛B题碎纸片复原(含word论文和源代码资源)

文章目录 一、部分题目二、部分论文三、部分源代码四、完整word版论文和源代码&#xff08;两种获取方式&#xff09; 一、部分题目 2013高教社杯全国大学生数学建模竞赛题目 B题 碎纸片的拼接复原 破碎文件的拼接在司法物证复原、历史文献修复以及军事情报获取等领域都有着重…

基于术语词典干预的机器翻译挑战赛笔记Task2 #Datawhale AI 夏令营

上回&#xff1a; 基于术语词典干预的机器翻译挑战赛笔记Task1 跑通baseline Datawhale AI 夏令营-CSDN博客文章浏览阅读718次&#xff0c;点赞11次&#xff0c;收藏8次。基于术语词典干预的机器翻译挑战赛笔记Task1 跑通baselinehttps://blog.csdn.net/qq_23311271/article/d…

统计一个页面用到的html,css,js

<!DOCTYPE html> <html lang"zh-CN"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>统计html</title><style>* {margin: …