python基于GCN(图卷积神经网络模型)和LSTM(长短期记忆神经网络模型)开发构建污染物时间序列预测模型

news2025/1/15 19:51:21

在以往的时间序列预测建模中广泛使用的是回归类算法模型和RNN类的算法模型,相对来说技术栈会更稳定一些,最近有一个实际业务场景的需求,在建模的过程中要综合考虑其余点位的影响依赖,这时候我想到了之前做过的交通流量和速度预测相关的项目,在那里采用的就是图相关的算法模型,所以这里也想对标来开发。

GCN(Graph Convolutional Network)是一种用于处理图结构数据的卷积神经网络模型。它的构建原理是基于图卷积操作,通过在图上进行局部的卷积运算来提取节点的特征表示。

具体来说,GCN通过邻居节点的信息聚合来更新每个节点的表示。GCN的每一层都可以表示为以下的公式:

H^{(l+1)} = σ(D^{-0.5}AD^{-0.5}H^{(l)}W^{(l)})

其中,H^{(l)}表示第l层的节点表示矩阵,A表示节点之间的邻接矩阵,D是度矩阵,W^{(l)}表示第l层的权重矩阵,σ表示激活函数。

GCN的优点主要体现在以下几个方面:

  1. 考虑了节点的邻居信息:GCN通过聚合节点的邻居信息来更新节点的表示,能够捕捉到节点的局部结构信息,适用于处理图结构数据。
  2. 具有共享权重的卷积:GCN通过共享权重矩阵来进行卷积操作,减少了需要学习的参数数量,降低了模型的复杂度。
  3. 良好的泛化能力:GCN能够将节点的特征向量传递给邻居节点,从而在整个图上进行信息传播,有助于节点的特征表示学习。

然而,GCN也存在一些缺点:

  1. 无法处理动态图:GCN对于静态图结构的处理效果较好,但对于动态图的处理存在一定的困难,因为动态图的结构会随着时间的推移而变化。
  2. 局部邻居信息的限制:GCN只考虑节点的一阶邻居信息,对于高阶邻居信息的利用能力有限。
  3. 对大规模图的计算开销较大:GCN的计算复杂度与图的规模相关,对于大规模图的处理需要较高的计算资源。

GCN作为一种处理图结构数据的卷积神经网络模型,具有考虑邻居信息、共享权重、泛化能力强等优点,但对动态图的处理存在限制,并且对大规模图的计算开销较大。

LSTM(Long Short-Term Memory)是一种用于处理序列数据的循环神经网络模型。它的构建原理是通过引入门控机制来解决传统RNN存在的梯度消失和梯度爆炸问题,从而能够更好地捕捉长期依赖关系。

具体来说,LSTM通过三个门控单元(输入门、遗忘门和输出门)来控制信息的流动和记忆的更新。它的计算过程可以表示为以下公式:

输入门:i_t = σ(W_{xi}x_t + W_{hi}h_{t-1} + b_i)
遗忘门:f_t = σ(W_{xf}x_t + W_{hf}h_{t-1} + b_f)
输出门:o_t = σ(W_{xo}x_t + W_{ho}h_{t-1} + b_o)
新的记忆:\tilde{c}t = tanh(W{xc}x_t + W_{hc}h_{t-1} + b_c)
细胞状态更新:c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t
隐藏状态更新:h_t = o_t \odot tanh(c_t)

其中,i_t、f_t和o_t分别表示输入门、遗忘门和输出门的输出,\tilde{c}_t表示新的记忆,c_t表示细胞状态,h_t表示隐藏状态,x_t表示当前时刻的输入。

LSTM的优点主要体现在以下几个方面:

  1. 解决了梯度消失和梯度爆炸问题:通过门控机制,LSTM能够有效地捕捉长期依赖关系,解决了传统RNN在处理长序列时产生的梯度问题。
  2. 长期记忆能力强:LSTM能够在细胞状态中长期存储信息,有助于模型记住对当前任务有用的信息。
  3. 能够处理不定长的序列:LSTM可以处理不定长的序列数据,适用于多种应用场景,如自然语言处理、语音识别等。

然而,LSTM也存在一些缺点:

  1. 计算复杂度较高:LSTM中引入了多个门控单元和记忆单元,导致模型的计算复杂度较高,对计算资源要求较大。
  2. 可能存在过拟合:LSTM的参数数量较多,模型容易过拟合,需要采取一些正则化方法来缓解这个问题。
  3. 难以并行化:LSTM的计算过程涉及到门控单元的计算,导致模型难以进行有效的并行化,限制了其在大规模数据上的应用。

LSTM作为一种用于处理序列数据的循环神经网络模型,具有解决梯度问题、长期记忆能力强等优点,但计算复杂度较高、难以并行化等缺点。

这里本文的核心思想是想要基于GCN+LSTM来实现图网络和时序网络的融合,开发构建更加适合业务需求的时间序列预测模型。

首先是数据集加载,如下所示:

with open("distance.json") as f:
    route_distances = json.load(f)
with open("nodedata.json") as f:
    speeds_array = json.load(f)
route_distances = np.array(route_distances)
speeds_array = np.array(speeds_array)
print(f"route_distances shape={route_distances.shape}")
print(f"speeds_array shape={speeds_array.shape}")

结果输出如下所示:

route_distances shape=(23, 23)
speeds_array shape=(1056, 23)

共包含了23个节点。

之后是创建邻接矩阵,如下:

num_routes = route_distances.shape[0]
route_distances = route_distances / 10000
w2, w_mask = (
    route_distances * route_distances,
    np.ones([num_routes, num_routes]) - np.identity(num_routes),
)
adjacency_matrix = (np.exp(-w2 / sigma2) >= epsilon) * w_mask

结果输出如下所示:

一个简单的GCN+LSTM实现如下所示:

def gcn_layer(adj_matrix, input_features, output_dim):
    # GCN层的实现
    adj_normalized = normalize_adj(adj_matrix) # 对邻接矩阵进行归一化处理
    output = Dense(output_dim)(adj_normalized @ input_features) # GCN的公式实现
    output = Activation('relu')(output)
    return output

# 构建GCN+LSTM模型
def build_gcn_lstm_model(adj_matrix, input_dim, hidden_dim, output_dim):
    # 输入层
    inputs = Input(shape=(None, input_dim))
    
    # GCN层
    gcn_output = gcn_layer(adj_matrix, inputs, hidden_dim)
    
    # LSTM层
    lstm_output = LSTM(hidden_dim)(gcn_output)
    
    # 输出层
    outputs = Dense(output_dim, activation='softmax')(lstm_output)
    
    model = Model(inputs=inputs, outputs=outputs)
    model.compile(optimizer=Adam(lr=0.001), loss='categorical_crossentropy', metrics=['accuracy'])
    
    return model

也可以自己定义新的layer都是可以的,如下:

class GCNLSTM(layers.Layer):

    def __init__(**kwargs):
        super().__init__(**kwargs)
        self.graph_conv = gcnLayer(in_feat, out_feat, graph_info, **graph_conv_params)
        self.lstm = layers.LSTM(lstm_units, activation="relu")
        self.dense = layers.Dense(output_seq_len)
        self.input_seq_len, self.output_seq_len = input_seq_len, output_seq_len

    def get_config(self):
        config = super().get_config().copy()
        return config

    def call(self, inputs):
        inputs = tf.transpose(inputs, [2, 0, 1, 3])
        gcn_out = self.graph_conv(inputs) 
        shape = tf.shape(gcn_out)
        gcn_out = tf.reshape(gcn_out, (batch_size * num_nodes, input_seq_len, out_feat))
        lstm_out = self.lstm(gcn_out) 
        dense_output = self.dense(lstm_out) 
        output = tf.reshape(dense_output, (num_nodes, batch_size, self.output_seq_len))
        return tf.transpose(output, [1, 2, 0]) 

之后就可以进行模型的训练了如下:

history=model.fit(X_train,y_train,validation_data=(X_test,y_test),epochs=200)
model.save("model.h5")
plotLoss(history.history["loss"],history.history["val_loss"],"loss.png")

loss可视化如下所示:

这里我们随机选取了几个node对其数据进行可视化,如下所示:

node之间的热力图如下所示:

对测试集进行测试评估,对比曲线可视化结果如下所示:

这里对不同node未来时刻进行持续预测,分别展示,如下所示:

这是一种比较新的建模方式,在实际应用中可以更多捕捉到不同节点的依赖性,这个是比较有用的一点,后续会在实际应用过程中继续挖掘。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1240008.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【软件测试】接口——基本测试流程

1、接口测试需求获取 (1)获取接口文档 (2)通过接口文档获取接口信息 (3)确认接口测试需求 2、 接口测试计划编写 (1)目标:确认需求、资源、方法、进度方案 &#x…

【VScode】安装配置、插件及远程SSH连接

一、VSCode安装 二、配置安装插件 三、配置远程连接SSH 四、MinGW 一、VSCode安装 VS官网 Visual Studio Code - Code Editing. Redefined下载安装包: 二、配置安装插件 安装中文插件 配置字体为20 配置文件–>首选项->设置->Font Size为20 设置 VSC…

想分析全国用电及煤气、液化石油气供应利用情况,这部分数据对你有帮助!

随着经济的发展和人民生活水平的提高,能源的需求量越来越大。其中,电力和煤气、液化石油气等能源的供应利用情况与我们的日常生活息息相关。 今天我们根据《中国城市统计年鉴》统计的中国地级及以上城市的煤气及液化石油气供应及利用情况的指标&#xff…

芯片被磨掉型号分辨不出怎么解决?

很多电子工程师,在芯片或单片机解密时,经常会遇到芯片被加密的情况,更糟糕的是这个芯片的型号被打磨了,无法确定其准确型号,那么如何解决?今天凡小亿开课好好讲讲,希望对小伙伴们有所帮助。 1、…

11.6AOP

一.AOP是什么 是面向切面编程,是对某一类事情的集中处理. 二.解决的问题 三.AOP的组成 四.实现步骤 1.添加依赖(版本要对应): maven仓库链接 2.添加两个注解 3.定义切点 4.通知 5.环绕通知 五.execution表达式 六.AOP原理 1.建立在动态代理的基础上,对方法级别的拦截. 2.…

Arcgis中通过函数实现字符串截取

效果 从字符串中提取最右侧的符号,如“/”后面的字符串 步骤 1、VB dim bbindexinstrrev( [WGCJ] ,"/")bbright( [WGCJ] ,len( [WGCJ] )- index )2、python def bb(aa):index(aa.rfind("/"))bbaa[index1:]return bb

基于docker实现JMeter分布式压测

为什么需要分布式? 在工作中经常需要对一些关键接口做高QPS的压测,JMeter是由Java 语言开发,没创建一个线程(虚拟用户),JVM默认会为每个线程分配1M的堆栈内存空间。受限于单台试压机的配置很难实现太高的并…

在windows笔记本中安装tensorflow1.13.2版本的gpu环境2

tensorflow1.13.2版本的gpu环境 看python-anacona的安装只需要看1.1部分即可 目录 1.1 Anaconda安装 1.2 tensorflow-gpu安装 1.3 python编译器-pycharm安装 1.1 Anaconda安装 从镜像源处下载anaconda,地址:Index of /anaconda/archive/ | 北京…

【无标题】dp80采集机和机器人通信相关框架总结

采血机器人通信解析相关框架总结: 类似于dp80,将整个过程进行了分解如下: 类似于dp80,将整个过程进行了分解如下: 上位机界面在进行点击操作的时候,先是通信协议的解析,解析后改变采血的控制状态如下: Dp80主要框架解析࿱

Godot

前言 为什么要研究开源引擎 主要原因有: 可以享受“信创”政策的红利,非常有利于承接政府项目。中美脱钩背景下,国家提出了“信创”政策。这个政策的核心就是,核心技术上自主可控。涉及的产业包括:芯片、操作系统、数据…

Vue 3 渲染机制解密:从模板到页面的魔法

Vue 3 渲染机制解密 前言Vue 3的响应性系统1. **Reactivity API:**2. **Proxy 对象:**3. **Getter 和 Setter:**4. **依赖追踪:**5. **批量更新:**6. **异步更新:**7. **递归追踪:**8. **删除属性:** 虚拟DOM的角色1. **减少直接操作真实 DOM:**2. **高效的批量更新:**3. **跨平…

sonar对webgoat进行静态扫描

安装sonar并配置 docker安装sonarqube,sonarQube静态代码扫描 - Joson6350 - 博客园 (cnblogs.com) 对webgoat进行sonar扫描 扫描结果 bugs Change this condition so that it does not always evaluate to "false" 意思是这里的else if语句不会执行…

git merge 和 git rebase

一、是什么 在使用 git 进行版本管理的项目中,当完成一个特性的开发并将其合并到 master 分支时,会有两种方式: git merge git rebasegit rebase 与 git merge都有相同的作用,都是将一个分支的提交合并到另一分支上,…

Redis的持久化(新)

Redis中数据都保存在内存,但是内存中的数据变换很快,也很容易丢失,比如连接断开、宕机停机等等。而Redis提供的数据持久化机制有RDB(Redis DataBase)和AOF(Append Only File)。 1.RDB RDB是指在指定的时间间隔内将内存中的数据集快照写入到磁…

static和extern

1.extern extern 是⽤来声明外部符号的,如果⼀个全局的符号在A⽂件中定义的,在B⽂件中想使⽤,就可以使⽤ extern 进⾏声明,然后使⽤。 即在一个源文件中想要使用另一个源文件,即可通过这个extern来声明使用。 2.st…

NC65 修改元数据字段长度

NC65 修改元数据字段长度,执行下面sql,执行完后需要重启NC服务才生效。 --属性 update md_property set attrlength 200 where name fphm and classidece96dd8-bdf8-4db3-a112-9d2f636d388f ;--列 update md_column set columnlength 200 where tab…

Java的判空

校验List 在进行参数非空校验时,我们总是容易直接就对任何的东西都来一个! null 其实这样并不完全正确,对list类型来说 不等于null和list.size ! 0是不一样的。 不等于null是指已经声明了list的存在,并且在堆内存中有了引用。 …

ROS设置DHCP option121

配置时,了解格式很关键,16进制填写格式如下: 将要访问的IPV4地址:192.168.100.0/24 192.168.30.254 转换为:掩码 目标网段 网关 0x18c0a864c0a81efe,0不用填写 ROS配置如下图: 抓…

【Qt开发流程】之富文本处理

描述 Scribe框架提供了一组类,用于读取和操作结构化的富文本文档。与Qt中以前的富文本支持不同,新的类集中在QTextDocument类上,而不是原始文本信息。这使开发者能够创建和修改结构化的富文本文档,而不必准备中间标记格式的内容。…

修改 OkHttp3 的超时时间

修改 OkHttp3 的超时时间 一. 前言二. 导入mavengradle 三. 设置超时时间 一. 前言 OkHttp是一个处理网络请求的开源项目,是安卓端最火热的轻量级框架,由移动支付Square公司开发。OkHttp3是Java和Android都能用,Android还有一个著名网络库叫Volley,那个…