多变量时间序列生成模型GAN介绍与实现

news2024/9/21 4:32:45

目录

    • 1. 模型介绍
    • 2. 问题提出
    • 3. 模型具体实现
      • 3.1 数据预处理
      • 3.2 生成对抗网络(GAN)结构
      • 3.3 模式崩溃解决
      • 3.4 合成数据验证
    • 4. 代码实现
    • 参考文献

在这里插入图片描述

1. 模型介绍

在大数据时代,生成逼真的时间序列数据对于负载平衡、负载预测和智能资源配置等方面至关重要。多变量时间序列数据生成模型基于生成对抗网络(GAN)的能力,能够在不泄露真实数据隐私的前提下生成相似的合成数据。本文介绍了一种用于生成多变量时间序列数据的GAN模型。

2. 问题提出

云和边缘计算领域的监控数据通常是商业机密或受到数据法规(如GDPR)的保护,获取真实数据用于研究和开发变得非常困难。为了应对这一挑战,研究人员使用合成数据来填补数据空缺。在这种背景下,多变量时间序列生成模型通过GAN的使用,为生成任意数量的时间序列工作负载数据提供了一种新方法。其目标是学习真实生产工作负载的概率分布,并生成统计上相似的时间序列数据。

3. 模型具体实现

3.1 数据预处理

  • 数据格式化:将原始数据转换为所需格式。
  • 样本过滤:过滤掉不完整的样本。
  • 特征缩放:将特征缩放到定义范围内,通常为[0, 1]。
  • 数据标准化:将样本适应到[0, 1]的范围内,加速梯度下降过程。

3.2 生成对抗网络(GAN)结构

GAN由两个人工神经网络组成:判别器(Discriminator)和生成器(Generator),通过最小最大博弈进行训练。

  • 判别器(Discriminator, D)
    • 两层LSTM单元构成,最后一层为单个LSTM单元输出层,用于最终分类。
    • 输入形状为 n × m n \times m n×m n n n为时间步数, m m m为特征数)。

y = D ( h output ) y = D(h_{\text{output}}) y=D(houtput)
在这里插入图片描述

  • 生成器(Generator, G)
    • 两层递归层,用高斯噪声初始化,最后连接一个全连接输出层。
    • 每个时间步对应一个输出单元。

h ^ S = g S ( z S ) , h ^ t = g X ( h ^ S , h ^ t − 1 , z t ) \hat{h}_S = g_S(z_S), \quad \hat{h}_t = g_X(\hat{h}_S, \hat{h}_{t-1}, z_t) h^S=gS(zS),h^t=gX(h^S,h^t1,zt)
在这里插入图片描述

  • GAN目标
    • 判别器目标:区分真实和合成数据。
    • 生成器目标:生成判别器无法区分的合成数据。

3.3 模式崩溃解决

  • 为不同分布的序列训练独立的GANs,避免模式崩溃。
  • 使用自动化容器化工作流提高可重复性和可扩展性。

3.4 合成数据验证

  • 描述性统计:均值和标准差用于评估数据的分布和趋势。
  • 时间序列分析:计算时间序列的相关性和协整性。

在这里插入图片描述

在这里插入图片描述

4. 代码实现

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# 定义生成器网络
class Generator(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(Generator, self).__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers=2, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        # 确保 x 是三维的
        if x.dim() == 2:
            x = x.unsqueeze(1)  # 添加时间维度
        batch_size = x.size(0)
        h_0 = torch.zeros(2, batch_size, hidden_dim).to(x.device)
        c_0 = torch.zeros(2, batch_size, hidden_dim).to(x.device)
        out, _ = self.lstm(x, (h_0, c_0))
        out = self.fc(out[:, -1, :])  # 使用最后一个时间步的输出
        return out

# 定义判别器网络
class Discriminator(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(Discriminator, self).__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers=2, batch_first=True)
        self.fc = nn.Linear(hidden_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # 确保 x 是三维的
        if x.dim() == 2:
            x = x.unsqueeze(1)  # 添加时间维度
        batch_size = x.size(0)
        h_0 = torch.zeros(2, batch_size, hidden_dim).to(x.device)
        c_0 = torch.zeros(2, batch_size, hidden_dim).to(x.device)
        out, _ = self.lstm(x, (h_0, c_0))
        out = self.fc(out[:, -1, :])  # 使用最后一个时间步的输出
        out = self.sigmoid(out)
        return out

# 超参数设置
input_dim = 100  # 输入维度
hidden_dim = 64  # 隐藏层维度
output_dim = 100  # 输出维度
batch_size = 32
num_epochs = 1000
learning_rate = 0.0002

# 初始化生成器和判别器
generator = Generator(input_dim, hidden_dim, output_dim)
discriminator = Discriminator(output_dim, hidden_dim)

# 优化器
g_optimizer = optim.Adam(generator.parameters(), lr=learning_rate)
d_optimizer = optim.Adam(discriminator.parameters(), lr=learning_rate)

# 损失函数
criterion = nn.BCELoss()

# 加载数据(使用随机数据进行示例)
real_data = torch.randn(1000, input_dim).unsqueeze(1)  # 确保输入是三维的
dataloader = DataLoader(real_data, batch_size=batch_size, shuffle=True)

# 训练GAN模型
for epoch in range(num_epochs):
    for real_samples in dataloader:
        # 训练判别器
        real_samples = real_samples.float()
        batch_size = real_samples.size(0)
        real_labels = torch.ones(batch_size, 1)
        fake_labels = torch.zeros(batch_size, 1)

        # 生成假的样本
        noise = torch.randn(batch_size, input_dim).unsqueeze(1)  # 确保噪声是三维的
        fake_samples = generator(noise)

        # 计算判别器损失
        d_real_loss = criterion(discriminator(real_samples), real_labels)
        d_fake_loss = criterion(discriminator(fake_samples.detach()), fake_labels)
        d_loss = d_real_loss + d_fake_loss

        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # 训练生成器
        noise = torch.randn(batch_size, input_dim).unsqueeze(1)  # 确保噪声是三维的
        fake_samples = generator(noise)
        g_loss = criterion(discriminator(fake_samples), real_labels)

        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

    # 打印损失
    if (epoch + 1) % 100 == 0:
        print(f'Epoch [{epoch + 1}/{num_epochs}], d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}')

参考文献

[1] Leznik, M., Michalsky, P., Willis, P., Schanzel, B., Östberg, P., & Domaschka, J. (2021). Multivariate Time Series Synthesis Using Generative Adversarial Networks. In Proceedings of the 2021 ACM/SPEC International Conference on Performance Engineering (ICPE ’21), April 19–23, 2021, Virtual Event, France. ACM, New York, NY, USA, 8 pages. https://doi.org/10.1145/3427921.3450257

[2] Goodfellow, I., Pouget-Abadie, J., Mirza, M., Xu, B., Warde-Farley, D., Ozair, S., ... & Bengio, Y. (2014). Generative adversarial nets. Advances in neural information processing systems, 27.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1990690.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

openwrt 性能工具perf和cpu占用查看工具sysstat编译及使用

代码使用的lean源码,只需要用make menuconfig打开perf对应的编译选项即可 1.第一步选择Global build settings 2.第二步选择Kernel build options 3.第三步选择Enable kernel cgroups 4.第四步选择Enable perf_event per-cpu per-container group (cgroup) monitor…

计算机网络-CSP初赛知识点整理

历年真题 [2016-NOIP-普及-第3题] 以下不属于无线通信技术的是( ) A. 蓝牙 B. Wifi C. GPRS D. 以太网 [2015-NOIP-普及-第10题] FTP 可以用于( )。 A. 远程传输文件 B. 发送电子邮件 C. 浏览网页 D. 网上聊天 [2019-CSP-J-第1题] 中国的国家顶级域名是( ). A. .cn B. .ch C.…

国内自闭症学校指南:了解孩子的康复需求和解决方案

在国内,自闭症儿童的数量逐年增加,为他们提供专业的教育和康复支持变得至关重要。对于家长来说,选择一所合适的自闭症学校是帮助孩子走向康复的关键一步。在众多的选择中,星贝育园以其独特的优势和全面的服务脱颖而出。 当孩子被诊…

android系统中data下的xml乱码无法查看问题剖析及解决方法

背景: Android12高版本以后系统生成的很多data路径下的xml都变成了二进制类型,根本没办法看xml的内容具体如下: 比如想要看当前系统的widget的相关数据 ./system/users/0/appwidgets.xml 以前老版本都是可以直接看的,这些syste…

Cxx primer-chap13-Copy Control

copy控制涉及类的五个成员函数:,这五个成员函数被显式或隐式的被调用,各司其职:我们必须根据类的实际情况来确定是否需要显式定义这些成员函数:什么是拷贝构造函数呢?简单讲就是该函数的第一个形参是refere…

malloc函数与free函数

目录 开头1.怎样把数组初始化时的项数变成变量?malloc函数free函数 2.malloc函数与free函数的实际运用CC6 牛牛的排序随机乱码打印随机数组打印 结尾 开头 大家好,我叫这是我58。今天,我们来学一下如何把数组初始化时的项数变成变量的一些知识。 1.怎…

【书生大模型实战营第三期 | 入门岛第3关-Git 基础知识】

学习Git版本控制系统心得体会 摘要 通过参与InternLM Git教程,我对Git这一开源的分布式版本控制系统有了更深入的理解和实践。Git以其高效的团队协作能力、详尽的代码历史记录以及灵活的分支管理功能,成为软件开发中不可或缺的工具。 文章大纲 Git简介…

自查出癌症后 凯特王妃的生活观发生了变化 王室的粉丝们也应该会很少见到她

凯特米德尔顿今年的健康问题令人意外,这也改变了王室的面貌。这位威尔士王妃每次露面都引来巨大关注,因此王室不得不发挥创意,将更多精力放在威廉王子、索菲、爱丁堡公爵夫人,甚至查理三世国王的社交日程上。王室粉丝们可能期待着凯特恢复健康,恢复正常日程,但内部人士称…

unity 粒子系统学习

差不多了解了基本的ui面板,学一下粒子系统 取消轮廓线 这样粒子biubiu的时候就没有橙黄色的轮廓线了

lvs的dr模式实现

目录 一、实验环境准备 1、五台红帽9系统的主机 2、关闭所有的防火墙以及关闭selinux 二、在lvs中配置 1、在lvs中安装lvs软件并设置开机启动 2、在lvs中打开内核路由功能,并把它写入/etc/sysctl.conf文件中 3、webserver1和webserver2下载httpd 4、在lvs主机…

【Redis进阶】Redis单线程模型和多线程模型

目录 单线程 为什么Redis是单线程 处文件事件理器的结构 文件处理器的工作流程 总结 文件事件处理器 连接应答处理器 命令请求处理器 命令回复处理器 多线程 为什么引入多线程 多线程架构 多线程执行流程 关于Redis的问题 Redis为什么采用单线程模型 Redis为什…

【STM32】USART串口和I2C通信

个人主页~ USART串口和I2C通信 USART串口一、串口1、简介2、电路要求3、参数及时序 二、USART外设1、USART结构2、波特率发生器 三、数据包1、HEX数据包HEX数据包接收 2、文本数据包文本数据包接收 I2C通信一、简介二、通信协议1、硬件电路2、I2C时序基本单元 三、I2C外设1、简…

Chapter 29 类型注解

欢迎大家订阅【Python从入门到精通】专栏,一起探索Python的无限可能! 文章目录 前言一、变量的类型注解二、函数的类型注解三、Union类型注解 前言 类型注解为我们提供了一种清晰的方式来描述变量和函数的预期类型,使得代码的意图更加明确。…

GMMREG:基于高斯混合模型的鲁棒点集配准

其关键思想都是用连续密度函数表示离散点集,即高斯混合模型。不同点在于本节算法采用L2距离来衡量两个点云之间的相似性,而5.7节中的NDT算法采用的是作者定义的匹配势来衡量,实际上是所有线段对之间的差异。并且本节算法中加入了薄板样条插值…

打造分布式缓存组件【场景】

本文将采用AOP 反射 Redis自定义缓存标签,重构缓存代码,打造基础架构分布式缓存组件 配置 需要在Redis配置类中开启AOP自动代理,即通过EnableAspectJAutoProxy 注解实现该功能 import com.fasterxml.jackson.annotation.JsonAutoDetect; …

「链表」链表原地算法合集:原地翻转|原地删除|原地取中|原地查重 / LeetCode 206|237|2095|287(C++)

概述 对于一张单向链表,我们总是使用双指针实现一些算法逻辑,这旨在用常量级别空间复杂度和线性时间复杂度来解决一些问题。 所谓原地算法,是指不使用额外空间的算法。 现在,我们利用双指针实现以下四种行为。 //Definition fo…

Linux驱动.之I2C,iic驱动层(二)

一、 Linux下IIC驱动架构 本篇只分析,一个整体框架。 1、首先说说,单片机,的i2c硬件接口图,一个i2c接口,通过sda和scl总线,外接了多个设备device,通过单片机,来控制i2c的信号发生&…

VUE和Element Plus

1.VUE 1.下载和配置环境 使用vue编程,我们需要使用到的编程软件是vs code,还需要使用node.js,这个的作用就类似于JDK,当我们都下载好之后,winR键打开命令提示符,我们在这里可以查看版本, npm…

《计算机网络 - 自顶向下方法》阅读笔记

《计算机网络 - 自顶向下方法》阅读笔记 应用层、运输层、网络层、数据链路层 计算机网络和因特网: 因特网: ​ 是一个世界范围的计算机网络,互联了全世界的计算机设备 计算机设备:手机,电脑,游戏机&#…

MATLAB数据可视化:在地图上画京沪线的城市连线

matlab自带的geoplot(lat,lon) 可以在地理坐标中绘制线条。使用 lat和lon分别指定以度为单位的经度和纬度坐标。 绘制京沪线所经城市线条: citys [116.350009,39.853928; 116.683546,39.538304; 117.201509,39.085318; 116.838715,38.304676;...116.359244,37.436…