时间序列生成数据,TransformerGAN

news2024/11/15 8:27:44

        简介:这个代码可以用于时间序列修复和生成。使用transformer提取单变量或者多变时间窗口的趋势分布情况。然后使用GAN生成分布类似的时间序列。

        此外,还实现了基于prompt的数据生成,比如指定生成某个月份的数据、某半个月的数据、某一个星期的数据。

1、模型架构

        如下图所示,生成器和鉴别器都使用Transformer的编码器部分提取时间序列的特征,然后鉴别器使用这些进行二分类、生成器使用这些特征生成伪造的数据。

        重点:在下面的图的基础上,我还添加了基于提示的生成代码,类似于AI提示绘画一样,因此可以指定生成一月份、二月份等任意指定周期的数据。

2、训练GAN的代码

        下面是GAN的训练部分。

# 训练GAN
num_epochs = 100
for epoch in range(num_epochs):
    for real_x,x_g,zz in loader: # 分别是真实值real_x、提示词信息x_g、噪声zz
        real_data = real_x
        noisy_data = x_g
        # Train Discriminator
        optimizer_D.zero_grad()
        out = discriminator(real_data)
        real_loss = criterion(discriminator(real_data), torch.ones(real_data.size(0), 1))
        fake_data = generator(noisy_data,zz)
        fake_loss = criterion(discriminator(fake_data.detach()), torch.zeros(fake_data.size(0), 1))
        d_loss = real_loss + fake_loss
        d_loss.backward()
        optimizer_D.step()

        # Train Generator
        optimizer_G.zero_grad()
        g_loss = criterion(discriminator(fake_data), torch.ones(fake_data.size(0), 1))
        g_loss.backward()
        optimizer_G.step()

    print(f'Epoch [{epoch+1}/{num_epochs}], D Loss: {d_loss.item()}, G Loss: {g_loss.item()}')

3、生成器代码

class Generator(nn.Module):
    def __init__(self, seq_len=8, patch_size=2, channels=1, num_classes=9, latent_dim=100, embed_dim=10, depth=1,
                 num_heads=5, forward_drop_rate=0.5, attn_drop_rate=0.5):
        super(Generator, self).__init__()
        self.channels = channels
        self.latent_dim = latent_dim
        self.seq_len = seq_len
        self.embed_dim = embed_dim
        self.patch_size = patch_size
        self.depth = depth
        self.attn_drop_rate = attn_drop_rate
        self.forward_drop_rate = forward_drop_rate
        
        self.l1 = nn.Linear(self.latent_dim, self.seq_len * self.embed_dim)
        self.pos_embed = nn.Parameter(torch.zeros(1, self.seq_len, self.embed_dim))
        self.blocks = Gen_TransformerEncoder(
                         depth=self.depth,
                         emb_size = self.embed_dim,
                         drop_p = self.attn_drop_rate,
                        )

        self.deconv = nn.Sequential(
            nn.Conv2d(self.embed_dim, self.channels, 1, 1, 0)
        )

    def forward(self, z):
        x = self.l1(z).view(-1, self.seq_len, self.embed_dim)
        x = x + self.pos_embed
        H, W = 1, self.seq_len
        x = self.blocks(x)
        x = x.reshape(x.shape[0], 1, x.shape[1], x.shape[2])
        output = self.deconv(x.permute(0, 3, 1, 2))
        output = output.view(-1, self.channels, H, W)
        return output

4、生成数据和真实数据分布对比

        使用PCA和TSNE对生成的时间窗口数据进行降维,然后scatter这些二维点。如果生成的真实数据的互相混合在一起,说明模型学习到了真东西,也就是模型伪造的数据和真实数据分布是一样的,美滋滋。从下面的PCA可以看出,两者的分布还是近似的。

        进一步的,可以拟合两个二维正态分布,然后计算他们的KL散度作为一个评价指标。

5、生成数据展示

        上面是真实数据、下面是伪造的数据。由于只有几百个样本,以及参数都没有进行调整,但是效果还不错。

6、损失函数变化情况

        模型还是学习到了一点东西的。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1630970.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Git变更账户、查看账户

1、变更账户 (1)修改当前文件夹用户 git config user.name “新用户名” git config user.email “新邮箱” (2)修改全局git用户 git config - -global user.name “新用户名” git config - -global user.email “新邮箱”…

TensorRT模型压缩

参考链接: https://www.jianshu.com/p/805b76c9d229 https://blog.csdn.net/chenhaogu/article/details/132677778 基本理论 TensorRT是一个高性能的深度学习推理优化器,可以为深度学习应用提供低延迟、高吞吐率的部署推理。TensorRT可用于对超大规模数…

(六)几何平均数计算 补充案例 #统计学 #CDA学习打卡

一. 两个案例 1)几何平均数计算:基金年平均增长率计算 在财务、投资和银行业的问题中,几何平均数的应用尤为常见,当你任何时候想确定过去几个连续时期的平均变化率时,都能应用几何平均数。其他通常的应用包括物种总体…

《HCIP-openEuler实验指导手册》1.2Apache主页面配置

知识点 一、配置服务器监听IP及端口 注释主配置文件“监听IP及端口”部分 cd /etc/httpd/conf cp httpd.conf httpd.conf.bak vim httpd.conf可以在普通模式下搜索Listen关键字 :/Listen按n键继续向后搜索 在/etc/httpd/conf.d中新建子配置文件port.conf: tou…

FPGA秋招-笔记整理(3)无符号数、有符号数

参考:Verilog学习笔记——有符号数的乘法和加法 一、无符号数、有符号数 将输入输出全部定义为有符号数 (1)无符号数的读取按照原码进行,有符号数的读取应该按照补码读取,计算规则为去掉符号位后取反、加1在计算数值…

84.柱形图中最大的矩阵

二刷终于能过了. 思路解析: 不愧是hard,第一步就很难想, 对于每一个矩阵,我们要想清楚怎么拿到最大矩阵, 对于每个height[i],我们需要找到left和right,left是i左边第一个小于height[i]的,right是右边第一个小于height[i]的,那么他的最大矩阵就是height[i] * (right-left-…

暗区突围无限测试资格获取教程 暗区突围怎么参加测试教程推荐

《暗区突围:无限》是一款强调极致真实与沉浸感的军事策略游戏,它邀请玩家踏入危机四伏的战场,通过射击、智取与突袭,铺就一条满载财富的道路。在这片未知与挑战并存的土地上,游戏赋予玩家极高的行动自由,旨…

苍穹外卖绕过微信支付

经过以下改动可实现: 1、不用微信支付端口 2、弹出支付成功的界面 3、数据库修改支付成功后的数据 #在OrderServiceImpl.java里加入Autowiredprivate OrderService orderService; #在OrderServiceImpl.java里的payment函数做以下改动 #图片里有,红色为原…

如何从requirements.txt文件中安装pytorch

平时使用requirements.txt文件来安装python的依赖,如下所示: Flask3.0.0 Flask-Cors4.0.0 elastic-transport8.11.0 elasticsearch8.11.1但是如果我们的依赖中包含pytorch依赖,显然是不能简单的通过这个方式来进行的,例如&#x…

怎样将便签软件搬家?便签迁移攻略

便签软件已成为我们日常生活中不可或缺的记录工具。无论是重要的工作内容,还是琐碎的生活事务,我们都会在便签中一一记下。然而,当我们需要更换电脑或其他设备时,如何将这些珍贵的便签内容迁移到新设备上,成为了许多人…

Linxu系统服务管理,systemd知识/进程优先级/平均负载/php进程CPU100%怎么解决系列知识!

shell脚本(命令)放后台 sleep 300& 放到后台运行,脚本或命令要全路径 nohup:用户推出系统进程继续工作 【功能说明】 nohup 命令可以将程序以忽略挂起信号的方式运行起来,被运行程序的输出信息将不会显示到终端 如…

VS2019配合QT5.9开发IRayAT430相机SDK

环境配置 VS2019 QT5.9 编译器版本 MSVC2017_64添加系统环境变量(完毕后重启电脑) 从VS2019中下载Qt插件 从VS2019中添加单个编译组件 上述操作完成后用VS打开工程文件,工程文件地址 : C:\Users\86173\Desktop\IRCNETSDK_W…

337. 打家劫舍 III 树上最大独立集

这个题目其实就是 树形结构的 打家劫舍 (隔一家抢一家) 在打家劫舍中,我们是使用dp[i]记录当前的最大值,如果上一家刚抢过,就返回dp[i-2] 而对于这个题目,我们对每个节点记录两个状态,就是抢和没抢的最大值 具体转移方程如下: dfs(node,rob) 表示对于当前node节点,抢了true或…

最大连续1的个数 ||| ---- 滑动窗口

题目链接 题目: 分析: 题目中说可以将最多k个0翻转成1, 如果我们真的这样算就会十分麻烦, 所以我们可以换一种思路: 找到一个最长的子数组, 最多有k个0解法一: 暴力解法: 找到所有的最多有k个0的子字符串, 返回最长的解法二: 找到最长的子数组, 我们可以想到"滑动窗口算…

968.监控二叉树 树上最小支配集

法一: 动态规划 一个被支配的节点只会有三种状态 1.它本身有摄像头 2.他没有摄像头, 但是它的父节点有摄像头 3.他没有摄像头, 但是它的子节点有摄像头 我们 dfs(node,state) 记录在node节点时(以node为根的子树),状态为state下的所有最小摄像头 // 本身有摄像头就看左右孩子…

2024/4/28 C++day5

有以下类&#xff0c;完成特殊成员函数 class Person { string name; int *age; } class Stu:public Person { const double score; } #include <iostream> #include <string> using namespace std; class Person { string name; int *age ; publi…

Vision Pro“裸眼上车”,商汤绝影全新舱内3D交互亮相

2023年&#xff0c;Apple Vision Pro的横空出世让人们领略到了3D交互的魅力&#xff0c;商汤绝影通过深厚的技术研发实力和高效的创新迭代效率&#xff0c;带来两大全新座舱3D交互&#xff1a;3D Gaze高精视线交互和3D动态手势交互。 作为全球首创的能够通过视线定位与屏幕图标…

16.Blender 基础渲染工作流程及安装ACES

安装插件和菜单栏设置 在菜单栏的编辑里打开偏好设置 里面的插件界面 搜索node 给第三个打勾 点击安装&#xff0c;导入cat插件 安装完后&#xff0c;一定要打勾&#xff0c;选择上cat插件 这样N窗口才会显示MMD选项 导入场景 点击打开 把输出模式的帧率改为30fps 按…

CSS 之 transition过渡动画

一、简介 ​ CSS 制作 Web 动画有两种方式&#xff1a; 帧动画&#xff08;Keyframe Animation&#xff09;和过渡动画&#xff08;Transition Animation&#xff09;。针对不同的业务场景中&#xff0c;我们应该选择不同的动画方式&#xff0c;通常来说&#xff1a;对于交互元…

OpenHarmony实战开发-多层级手势事件

多层级手势事件指父子组件嵌套时&#xff0c;父子组件均绑定了手势或事件。在该场景下&#xff0c;手势或者事件的响应受到多个因素的影响&#xff0c;相互之间发生传递和竞争&#xff0c;容易出现预期外的响应。 本章主要介绍了多层级手势事件的默认响应顺序&#xff0c;以及…