ConvLSTM时空预测实战代码详解

news2025/1/13 7:39:30

写在前面

时空预测是很多领域都存在的问题,不同于时间序列,时空预测不仅需要探究时间的变化,也需要关注空间的变化。许多预测问题都只片面的关注时间问题,如预测某人未来3年患某种病的概率,食堂就餐人数等,往往忽视了空间问题,如作为决策者,我不仅想知道明天患新冠的人数,而且想知道这些人会在哪些位置发病,以便精准管理。换句话说,决策者更多关注的是人群层面,而干预和施控是工作人员主要关注的问题。空间问题能够解答众多人对于预测问题最后的疑问,既:事件A可能会发生,那它会在哪里发生?
近些年关于时空预测的问题不断刺激着机器学习领域的科学家和学者,有别于传统统计学时空预测模型,机器学习或深度学习已经展现了其强大的优势:非线性拟合、高维数据的处理能力、较少担心变量共线性对模型的影响等。目前,通过长短期记忆网络LSTM和卷积神经网络CNN分别提取时间和空间特征来实现时空预测已经成为了可能,本文主要基于2015年发表在arxiv平台上的一篇building block文章《Convolutional LSTM Network: A Machine Learning Approach for Precipitation Nowcasting》来实现时空预测的经典问题——下一帧视频预测,当然,其他机器学习时空预测模型和时空预测问题不在本文讨论范围内,有兴趣的同学可以自行尝试。

一、时空预测

很多问题都可以尝试时空预测,如新冠疫情发病,交通流量、天气预报、慢性病区域发病风险预测等,空间可以被时间切分为无数个2维图片,这样用CNN来提取图片中的信息,用LSTM来处理时序问题,便能很好的解决时空预测问题,如下图所示。注意,与普通的LSTM不同,时空预测在其他维度上仍有位置变化,这是空间的体现。
在这里插入图片描述
关于论文中提到ConvLSTM的公式原理,在很多博客上都有详细论述,这里不再过多介绍,有兴趣的同学可以自己在网络上查看或者查看原文献的解读。

二、数据集的选取和下载

和论文中一样,我们也选取Moving-MNIST数据集,既移动MNIST数据集,该公开数据集可以在多伦多大学提供的网站上下载,Moving-MNIST数据集是时空预测常用的数据集之一,数据集下载代码如下:

import numpy as np
from tensorflow import keras
fpath = keras.utils.get_file(
    "moving_mnist.npy",
    "http://www.cs.toronto.edu/~nitish/unsupervised_video/mnist_test_seq.npy",
)
dataset = np.load(fpath)
print(dataset.shape)

下载好数据集后,我们输出数据集的shape,结果为(20, 10000, 64, 64),表示一个seq有二十个图片,前十帧为input,后十帧为target,一共有10000个sequence,每个图片的大小为64✖64,如下图所示:
在这里插入图片描述

三、数据集预处理与数据集划分

由于我们的ConvLSTM接受的输入是(sanmples,seq,wide,height,channel),因此我们要把数据集进行改造以符合模型的输入要求。

# 转换数据集的seq和samples维度,便于输入我们的模型
dataset = np.swapaxes(dataset, 0, 1)
# 10000个样本太多,我们只选取1000个
dataset = dataset[:1000, ...]
# 我们此时是二维灰度图片,因此要增加一维,代表单通道,如果是彩色,则为3
dataset = np.expand_dims(dataset, axis=-1)
print(dataset.shape)

转换后数据集的shape为(1000,20,64,64,1)已经符合模型的输入要求,接下来是划分数据集,这里要打乱索引,实现随机划分训练集和测试集。

indexes = np.arange(dataset.shape[0])
np.random.shuffle(indexes)  # 打乱索引顺序
# 训练集:测试集=9:1
train_index = indexes[: int(0.9 * dataset.shape[0])]
val_index = indexes[int(0.9 * dataset.shape[0]):]
train_dataset = dataset[train_index]
val_dataset = dataset[val_index]
print(train_dataset.shape)
print(val_dataset.shape)

划分好数据集后我们通过除以255实现归一化,归一化要在划分数据集之后完成,不然会导致数据泄露。

# 归一化,除255就是把3基色都调到0-1区间,得到绝对色彩信息
train_dataset = train_dataset / 255
val_dataset = val_dataset / 255

分离x和y,按照论文,我们是前20帧预测后20帧,类似于下图:
在这里插入图片描述
代码如下:

# 分离x和y,注意,此时的y是下一帧图像,既最后一个片子,我们用前20帧预测后20帧,既序号0-19
def create_shifted_frames(data):
    x = data[:, 0: data.shape[1] - 1, :, :]
    y = data[:, 1: data.shape[1], :, :]
    return x, y
x_train, y_train = create_shifted_frames(train_dataset)
x_val, y_val = create_shifted_frames(val_dataset)

四、模型构建

# 模型构建核心代码,这里我们修改超参数与keras官方超参数一致
model = Sequential([
    keras.layers.ConvLSTM2D(filters=64, kernel_size=(5, 5),
                   input_shape=(None, 64, 64, 1),
                   padding='same', return_sequences=True),
    keras.layers.BatchNormalization(),
    keras.layers.ConvLSTM2D(filters=64, kernel_size=(3, 3),
                   padding='same', return_sequences=True),
    keras.layers.BatchNormalization(),
    keras.layers.ConvLSTM2D(filters=64, kernel_size=(1, 1),
                   padding='same', return_sequences=True),
    keras.layers.Conv3D(filters=1, kernel_size=(3, 3, 3),
               activation='sigmoid',
               padding='same', data_format='channels_last')
])
model.compile(loss='binary_crossentropy', optimizer='adadelta')
model.summary()

模型结构如下:
在这里插入图片描述

五、模型训练

构建好模型后,我们开始训练,本人电脑显卡为GTX1660ti ,显存6G,显存有限,因此把batch size调小一些,并适当增大一些epoch,如果算力允许,可以增大epoch和batch size。这里定义了早期终止训练和调整学习率回调函数,因此无需担心无效训练时间增加问题。

# 定义回调
early_stopping = keras.callbacks.EarlyStopping(monitor="val_loss", patience=10)
reduce_lr = keras.callbacks.ReduceLROnPlateau(monitor="val_loss", patience=5)

# 设置训练参数
epochs = 50
batch_size = 2

# 拟合模型.
model.fit(
    x_train,
    y_train,
    batch_size=batch_size,
    epochs=epochs,
    validation_data=(x_val, y_val),
    callbacks=[early_stopping, reduce_lr],
)
model.save('model.h5')

六、结果查看

在这里插入图片描述

与自然语言处理不同,时空预测不能单纯对比accuracy、F1-score这些指标,要对比预测结果和实际的差别。从结果来看,我们预测结果比较模糊,因此,考虑改用adam优化和调整卷积核的个数至30个以及epoch改为15,结果如下:
在这里插入图片描述
从效果来看,略微有所提升,验证集损失在0.0244,继续增加epoch发现尽管验证集上的cost function在不断减小,但图像效果又变差了,和论文中图像一样,效果可以看到还是依然很差的,说实话这个想做好确实难度很高,因为数字在图中极度的非线性运动,很难全部学到运动的规则,就算是学到了,单个字的空间信息可能也会丢失,所以越往后效果越差。
在这里插入图片描述
查看了一些博主的解决方法,有提示改用SSIM(结构相似度)损失函数的,有建议减小学习率的,有建议用反卷积的等等,但深度学习本身就是一个“炼丹“的过程,无非就是增加一些卷积核,或者增加减少一些卷积层,ConvLSTM在本次任务中本质是图像生成,图像生成的超参数调整是比较敏感的,由于本人未安装tensorflow_contrib库,无法使用SSIM损失函数,后续如有新的进展,将第一时间发出来。

七、总结

目前来讲,空间统计学在医疗领域的运用是较为普遍的,特别是传染病和环境暴露因素的研究中,时常能看到论文作者使用时空统计模型,而ConLSTM的诞生和近些年来的改进能够推动基于机器学习的时空预测模型在传染病和环境暴露因素研究中的应用,除此之外,ConvLSTM甚至能够分析患者的视频数据,如步态或者某区域慢性病风险矩阵图等,这些领域依旧空白,我相信,AI的助力能够推动这些领域的蓬勃发展。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/112205.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

CSS--圆角边框

单独对四个角进行设置: boder-top-left-radius:30px; //左上角 boder-top-right-radius:30px; //右上角 boder-bottom-left-radius:30px; //右下角 boder-bottom-right-radius:30px&…

群晖 Sonology NAS DS920+ 拆机装机方法

文章内容:群晖 Sonology NAS DS920 拆机方法 关键词组:群晖,Sonology, nas, ds920, 拆机, 外壳 使用软件:无 虚拟环境:无 操作系统:无 目录一、事件起因三、拆装机方法一、事件起因 起初,由于机…

OpenCV环境下实现图像任意角度旋转的原理及代码

OpenCV环境下实现图像任意角度旋转的原理及代码 实现图像任意角度旋转的原理如下: Step01-把图像原点从左上角转换到旋转中心点。 Step02-利用极坐标系计算出旋转后各点的坐标。 Step03-确定旋转后图像的左边界、右边界、上边界、下边界,进而得出旋转后…

计数排序 [数据结构与算法][Java]

计数排序 计数排序和基数排序都是桶排序的一种应用 适用场景: 量大但是范围小 比如对10000个数进行排序, 但是这10000个数中只有10种数字(0 - 9)典型题目: 某大型企业数万名员工年龄排序如何快速得知高考名次(腾讯面试) 这里我们以某大型企业数万名员工年龄排序来进行一个…

RV1126笔记十四:吸烟行为检测及部署<二>

若该文为原创文章,转载请注明原文出处。 PC下yolov5环境搭建 我使用的训练环境是Windows10+MiniConda 接下来记录搭建全过程 备注:条件允许就使用ubuntu物理机,最好要有显卡,训练有显卡速度会快很多,没有显卡,训练300轮,亲测大概40小时,不值得。 一、miniconda 安装…

从零了解进程(操作系统定位,进程的概念,特征,虚拟地址)

目录 操作系统的定位 进程的概念 如何描述进程? 如何组织进程? 进程的特征 1.pid 2.内存指针 3.文件描述符 4.进程调度的相关属性 (1)进程的状态 (2)优先级 (3)上下文 (4)记账信息 进程是如何利用cpu资源的? 进程的虚拟地址 物理地址 内存随机访问的特性 为…

【ML实验7】人脸识别综合项目(PCA、多分类SVM)

实验代码获取 github repo 山东大学机器学习课程资源索引 实验目的 实验环境 实验内容 PCA 两种方法EVD-PCA和SVD-PCA的实现、效率对比见我之前的博客一个PCA加速技巧,这里补充SVD方法的数学推导: 首先,设方阵AAA的特征值分解为AUΣUTAU\S…

ZKP应用:石头剪刀布游戏

1. 引言 开源代码见: https://github.com/spalladino/zkp-tests 对比了分别使用: Iden3团队的circom语言(易于学习ZKP)ZCash团队的Halo2框架Aztec团队的Noir语言(最友好) 编写石头剪刀布游戏的ZKP证明…

IPv6(计算机网络-网络层)

目录 IPv6 的特点 IPv6 数据报的格式 IPv6 分组的格式 IPv6 的扩展首部 从计算机本身发展以及从互联网规模和网络传输速率来看,现在 IPv4已很不适用。 要解决 IP 地址耗尽的问题的措施: 采用无类别编址 CIDR,使 IP 地址的分配更加合理…

MySQL 锁机制

文章目录MySQL 锁机制表锁读锁场景一场景二场景三总结写锁场景一场景二场景三总结行锁场景一场景二总结间隙锁缺点如何锁定一行MySQL 锁机制 表锁 读锁 查看哪些表被加锁了 语法:show open tables 添加读锁 read 读锁关键字 | write 写锁关键字 语法:l…

qt下采用libcurl实现ftp与tftp功能,提供源代码程序

一、FTP简介 FTP(文件传输协议),工作在应用层,是用于在网络上进行文件传输的一套标准协议。它使用 TCP 传输,客户在和服务器建立连接前要经过一个“三次握手”的过程, 保证客户与服务器之间的连接是可靠的&…

基于xml的自动装配之byName

基于xml的自动装配之byName 自动装配方式:byName byName:将自动装配的属性的属性名,作为bean的id在IOC容器中匹配相对应的bean进行赋值总结:当类型匹配的 bean 有多个时,此时可以使用 byName 实现自动装配 配置bean &…

02:损失函数总结

目录 nn.L1Loss: nn.NLLLoss: nn.MSELoss: nn.CrossEntropyLoss: 损失函数是用来估量模型的预测值与真实值的不一致程度,它是一个非负实值函数。我们训练模型的过程,就是通过不断的迭代计算,使用梯度下降的优化算法,使得损失函…

多叉树 [数据结构与算法][Java]

多叉树 在二叉树中每个结点只能有一个数据项, 并且最多有两个子节点, 如果允许每个结点可以有更多的数据项和更多的子节点, 那么就是多叉树 多叉树: multiway tree 那么我们为什么要提出多叉树? 因为二叉树有一定的问题: 即使二叉树的操作效率高, 但是也存在问题: 二叉树需…

django之前后端不分离的操作(增删改查)

背景: demo采用的是前后端不分离的操作,用Bootstrap作为前段页面框架使用 1.先说模板之间的继承(针对HTML来讲) 父模板中编写好公共代码块,例如一个网站的导航和footer部分 在内容部分注明 {%block content%} {%e…

53 记一次自定义 classloader 导致的 metadataspace OOM

前言 这是最近 flink 集群上面暴露出现的一个问题 具体的细节原因 就是 flink 上面提交任务的时候, 自定义的 classloader 加载 driver.jar 然后导致 metaspace OOM 由于这边的 TaskManager metadataspace 配置相对较小(MaxMetaspaceSize配置为96M), 然后导致 出现了 metada…

BUUCTF Misc [SUCTF2018]single dog 我吃三明治 sqltest [SWPU2019]你有没有好好看网课?

[SUCTF2018]single dog 下载文件 使用kali的binwalk工具分析 进行文件分离 解压其中的压缩包,得到1.txt 看内容应该是js加密后的结果,复制到AAEncode在线解密网站 得到flag flag{happy double eleven} 我吃三明治 下载文件 使用010 eito…

AICodeHelper - AI编程助手

AICodeHelper是一款AI编程助手,旨在帮助程序员提高他们的编码技能。 简单的像尝试的代码直接问就行,但是一些复杂的,就得需要写技巧; 下面是几个使用的小技巧:链接是:AICodeHelper 1.可以使用中文提问&a…

深度学习Faster-RCNN网络

目录1 网络工作流程1.1 数据加载1.2 模型加载1.3 模型预测过程1.3.1 RPN获取候选区域1.3.2 FastRCNN进行目标检测2 模型结构详解2.1 backbone2.2 RPN网络2.2.1 anchors2.2.2 RPN分类2.2.3 RPN回归2.2.4 Proposal层2.4 ROIPooling2.5 目标分类与回归3 FasterRCNN的训练3.1 RPN网…

Yolov5移植树莓派实现目标检测

Hallo,大家好啊!之前写了几篇Yolov5相关项目的博客,然后学习了树莓派之后,更新了几篇树莓派的博客,我的最终目的是将Yolov5移植到树莓派,通过树莓派上面的摄像头实现目标检测。你想啊,在工厂里面…