时间卷积网络(TCN):序列建模的强大工具(附Pytorch网络模型代码)

news2024/9/21 2:45:09

1. 引言

引用自:Bai S, Kolter J Z, Koltun V. An empirical evaluation of generic convolutional and recurrent networks for sequence modeling. arXiv[J]. arXiv preprint arXiv:1803.01271, 2018, 10.

在这里插入图片描述

时间卷积网络(Temporal Convolutional Network,简称TCN)是一种专门用于处理序列数据的深度学习模型。它结合了卷积神经网络(CNN)的并行处理能力和循环神经网络(RNN)的长期依赖建模能力,成为序列建模任务中的强大工具。实验证明,对于某些任务下的长序LSTM和GRU等RNN架构,因此如果大家有多输入单输出(MISO)或多输入多输出(MIMO)序列建模任务,可以尝试使用TCN来作为创新点。
在这里插入图片描述

2. TCN的核心特性

在这里插入图片描述
图1所示。TCN中的架构元素。(a)一个扩张的因果卷积,其扩张因子d = 1,2,4,滤波器大小k = 3。接收野能够覆盖输入序列中的所有值。(b) TCN残余块。当剩余输入和输出具有不同的维数时,添加1x1卷积。© TCN中剩余连接的示例。蓝线是残差函数中的过滤器,绿线是恒等映射。

2.1 序列建模任务描述

在定义网络结构之前,我们先强调序列建模任务的核心特性。假设我们有输入序列 x 0 , … , x T x_0, \ldots, x_T x0,,xT,并希望在每个时间点预测对应的输出 y 0 , … , y T y_0, \ldots, y_T y0,,yT。关键约束在于,预测某个时间点 t t t 的输出 y t y_t yt 时,我们只能利用此前观察到的输入 x 0 , … , x t x_0, \ldots, x_t x0,,xt。形式上讲,序列建模网络是任何函数 f : X T + 1 → Y T + 1 f : X^{T+1} \rightarrow Y^{T+1} f:XT+1YT+1,它生成如下映射:

y ^ 0 , … , y ^ T = f ( x 0 , … , x T ) \hat{y}_0, \ldots, \hat{y}_T = f(x_0, \ldots, x_T) y^0,,y^T=f(x0,,xT)

若要满足因果性约束,即 y t y_t yt 只依赖于 x 0 , … , x t x_0, \ldots, x_t x0,,xt,而不依赖于任何“未来”的输入 x t + 1 , … , x T x_{t+1}, \ldots, x_T xt+1,,xT。在序列建模的学习目标中,是找到网络 f f f,使其最小化实际输出与预测值间的预期损失, L ( y 0 , … , y T , f ( x 0 , … , x T ) ) L(y_0, \ldots, y_T, f(x_0, \ldots, x_T)) L(y0,,yT,f(x0,,xT)),其中序列和输出根据某一概率分布抽取。

2.2 因果卷积

TCN使用因果卷积(Causal Convolution)来确保模型不会违反时间顺序。因果卷积即输出只依赖于当前时刻及其之前的输入,而不依赖于未来的输入(因为当前的你看不到未来的数据)。在标准的卷积操作中,每个输出值都基于其周围的输入值,包括未来的时间点。但在因果卷积中,权重仅应用于当前和过去的输入值,确保了信息流的方向性,避免了未来信息泄露到当前输出中。为了实现这一点,通常会在卷积核的右侧填充零(称为因果填充),这样只有当前和过去的信息被用于计算输出。

数学表示:

y ( t ) = ∑ i = 0 k − 1 f ( i ) ⋅ x ( t − i ) y(t) = \sum_{i=0}^{k-1} f(i) \cdot x(t-i) y(t)=i=0k1f(i)x(ti)

其中, f f f是卷积核, k k k是卷积核大小, x x x是输入序列。

2.3 扩张卷积

为了增加感受野而不增加参数数量,TCN采用扩张卷积(Dilated Convolution)。扩张卷积,也被称为空洞卷积,是一种在卷积核之间插入空隙(即跳过某些输入单元)的卷积形式。这种技术允许模型在不增加参数数量的情况下捕获更大的感受野,从而更好地理解输入数据中的上下文信息。扩张因子(dilation factor)决定了卷积核中元素之间的间距,例如,如果扩张因子为2,则卷积核中的元素会间隔一个输入单元。

扩张卷积的数学表示:

y ( t ) = ∑ i = 0 k − 1 f ( i ) ⋅ x ( t − d ⋅ i ) y(t) = \sum_{i=0}^{k-1} f(i) \cdot x(t-d \cdot i) y(t)=i=0k1f(i)x(tdi)

其中, d d d是扩张率。

一个扩张的因果卷积如下图所示:
在这里插入图片描述

2.4 残差连接

TCN使用残差连接来缓解梯度消失问题并促进更深层网络的训练。残差连接是残差网络(ResNets)的关键组成部分,由何凯明等人提出。它的主要目的是解决深层神经网络训练中的梯度消失/爆炸问题,以及提高网络的训练效率和性能。在残差连接中,网络的某一层的输出直接加到几层之后的另一层上,形成所谓的“跳跃连接”。具体来说,假设有一个输入 x x x,经过几层后得到 F ( x ) F(x) F(x),那么最终的输出不是 F ( x ) F(x) F(x)而是 x + F ( x ) x+F(x) x+F(x),也就是输入+输出。这种结构允许梯度在反向传播时可以直接流回更早的层,减少了梯度消失的问题,并且使得网络能够有效地训练更深的架构。残差块的输出可以表示为:

o u t p u t = a c t i v a t i o n ( i n p u t + F ( i n p u t ) ) output = activation(input + F(input)) output=activation(input+F(input))

其中, F F F是卷积层和激活函数的组合,残差连接如下图所示:
在这里插入图片描述

3. TCN的网络结构

TCN的基本结构包括多个残差块,每个残差块包含:

  1. 一维因果卷积层
  2. 层归一化
  3. ReLU激活函数
  4. Dropout层

TCN的整体结构可以表示为:
在这里插入图片描述

4. TCN vs RNN

相比于RNN,TCN有以下优势:

  1. 并行计算:卷积操作可以并行执行,提高计算效率。
  2. 固定感受野:可以精确控制输出对过去输入的依赖范围。
  3. 灵活的感受野大小:通过调整网络深度和扩张率,可以轻松处理不同长度的序列。
  4. 稳定梯度:避免了RNN中的梯度消失/爆炸问题。

5. TCN的应用

TCN在多个领域表现出色,包括:

  • 时间序列预测
  • 语音合成
  • 机器翻译
  • 动作识别
  • 音频生成

本篇文章不靠卖代码赚取收益,麻烦给个点赞和关注,后续还会有开源的免费优化算法及其代码,栓Q!同时如果大家有想要的算法可以在评论区打出,如果有空的话我可以帮忙复现

TCN的实现

以下是使用PyTorch实现TCN核心组件的示例代码(可以直接调用):

import torch
import torch.nn as nn
from torch.nn.utils import weight_norm


class Chomp1d(nn.Module):
    def __init__(self, chomp_size):
        super(Chomp1d, self).__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        return x[:, :, :-self.chomp_size].contiguous()


class TemporalBlock(nn.Module):
    def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2):
        super(TemporalBlock, self).__init__()
        self.conv1 = weight_norm(nn.Conv1d(n_inputs, n_outputs, kernel_size,
                                           stride=stride, padding=padding, dilation=dilation))
        self.chomp1 = Chomp1d(padding)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)

        self.conv2 = weight_norm(nn.Conv1d(n_outputs, n_outputs, kernel_size,
                                           stride=stride, padding=padding, dilation=dilation))
        self.chomp2 = Chomp1d(padding)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)

        self.net = nn.Sequential(self.conv1, self.chomp1, self.relu1, self.dropout1,
                                 self.conv2, self.chomp2, self.relu2, self.dropout2)
        self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None
        self.relu = nn.ReLU()
        self.init_weights()

    def init_weights(self):
        self.conv1.weight.data.normal_(0, 0.01)
        self.conv2.weight.data.normal_(0, 0.01)
        if self.downsample is not None:
            self.downsample.weight.data.normal_(0, 0.01)

    def forward(self, x):
        out = self.net(x)
        res = x if self.downsample is None else self.downsample(x)
        return self.relu(out + res)


class TemporalConvNet(nn.Module):
    def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2):
        super(TemporalConvNet, self).__init__()
        layers = []
        num_levels = len(num_channels)
        for i in range(num_levels):
            dilation_size = 2 ** i
            in_channels = num_inputs if i == 0 else num_channels[i-1]
            out_channels = num_channels[i]
            layers += [TemporalBlock(in_channels, out_channels, kernel_size, stride=1, dilation=dilation_size,
                                     padding=(kernel_size-1) * dilation_size, dropout=dropout)]

        self.network = nn.Sequential(*layers)

    def forward(self, x):
        return self.network(x)
        

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1936312.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Six common classification algorithms in machine learning

分类算法是一种机器学习算法,其主要目的是从数据中发现规律并将数据分成不同的类别。分类算法通过对已知类别训练集的计算和分析,从中发现类别规则并预测新数据的类别。常见的分类算法包括决策树、朴素贝叶斯、逻辑回归、K-最近邻、支持向量机等。分类算…

减分兔搜题-12123学法减分20题目及答案 #媒体#职场发展

对于即将参加驾驶考试的朋友来说,掌握一些经典题目和答案至关重要。今天,我就为大家带来了这样一份干货——20道驾驶考试题目和答案,助你轻松应对考试!这些题目不仅包括了考试中常考的内容,还有针对难点和重点的详细解…

​数据结构之初始二叉树(3)

找往期文章包括但不限于本期文章中不懂的知识点: 个人主页:我要学编程(ಥ_ಥ)-CSDN博客 所属专栏:数据结构(Java版) 二叉树的基本操作 通过上篇文章的学习,我们简单的了解了二叉树的相关操作。接下来就是有…

前端组件化技术实践:Vue自定义顶部导航栏组件的探索

摘要 随着前端技术的飞速发展,组件化开发已成为提高开发效率、降低维护成本的关键手段。本文将以Vue自定义顶部导航栏组件为例,深入探讨前端组件化开发的实践过程、优势以及面临的挑战,旨在为广大前端开发者提供有价值的参考和启示。 一、引…

从微软发iPhone,聊聊企业设备管理

今天讲个上周的旧闻,微软给员工免费发iPhone。其实上周就有很多朋友私信问我,在知乎上邀请我回答相关话题,今天就抽点时间和大家一起聊聊这事。我不想讨论太多新闻本身,而是更想聊聊事件的主要原因——微软企业设备管理&#xff0…

深入浅出WebRTC—DelayBasedBwe

WebRTC 中的带宽估计是其拥塞控制机制的核心组成部分,基于延迟的带宽估计是其中的一种策略,它主要基于延迟变化推断出可用的网络带宽。 1. 总体架构 1.1. 静态结构 1)DelayBasedBwe 受 GoogCcNetworkController 控制,接收其输入…

C++STL初阶(7):list的运用与初步了解

在了解了vector之后,我们只需要简单学习List与vector不一样的接口即可 1.list的基本接口 1.1 iterator list中,与vector最大的区别就是迭代器由随机迭代器变成双向迭代器 string和vector中的迭代器都是随机迭代器,支持-等,而LIS…

MOGONET:患者分类与biomarker识别

为了充分利用组学技术的进步并更全面地了解人类疾病,需要新的计算方法来综合分析多种类型的组学数据。多组学图卷积网络 (MOGONET,Multi-Omics Graph cOnvolutional NETworks)是一种用于生物医学分类的新型多组学整合方法。MOGONET 包含特定组学的学习和…

Keil开发IDE

Keil开发IDE 简述Keil C51Keil ARMMDK DFP安装 简述 Keil公司是一家业界领先的微控制器(MCU)软件开发工具的独立供应商。Keil公司由两家私人公司联合运营,分别是德国慕尼黑的Keil Elektronik GmbH和美国德克萨斯的Keil Software Inc。Keil公…

三、初识C语言(3)

1.操作符 &#xff08;1&#xff09;算术操作符 - * / % 商 余&#xff08;取模&#xff09; 小算法&#xff1a; 若a<b&#xff0c;则a%b a 若a%b c&#xff0c;则0 < c < b-1 若两个int 类型数相除&#xff0c;结果有小数会被舍弃。 保留小数…

苹果电脑pdf合并软件 苹果电脑合并pdf 苹果电脑pdf怎么合并

在数字化办公日益普及的今天&#xff0c;pdf文件因其跨平台兼容性强、格式稳定等特点&#xff0c;已经成为工作、学习和生活中不可或缺的文件格式。然而&#xff0c;我们常常面临一个问题&#xff1a;如何将多个pdf文件合并为一个&#xff1f;这不仅有助于文件的整理和管理&…

苏州金龙海格汽车入选2024中国汽车行业可持续发展实践案例

2024年7月11日-13日&#xff0c;由中国汽车工业协会主办的第14届中国汽车论坛在上海嘉定举办。本届论坛隆重发布了“2024中国汽车行业可持续发展实践案例”&#xff0c;苏州金龙因在坚持绿色可持续发展方面做出的努力和贡献获评2024中国汽车行业可持续发展实践案例“绿色发展”…

Ideal窗口中左右侧栏消失了

不知道大家在工作过程中有没有遇到过此类问题&#xff0c;不论是Maven项目还是Gradle项目&#xff0c;突然发现Ideal窗口右侧图标丢失了&#xff0c;同事今天突然说大象图标不见了&#xff0c;不知道怎样刷新gradle。 不要慌张&#xff0c;下面提供一些解决思路&#xff1a; 1…

HarmonyOS ArkUi 唤起系统APP:指定设置界面、浏览器、相机、拨号界面、选择通讯录联系人

效果&#xff1a; 完整工具类&#xff1a; import { common, Want } from kit.AbilityKit; import { BusinessError } from kit.BasicServicesKit; import { call } from kit.TelephonyKit; import { promptAction } from kit.ArkUI; import { contact } from kit.Contacts…

PHP宠物店萌宠小程序系统源码

&#x1f43e;萌宠生活新方式&#x1f43e; &#x1f3e1;【一键直达萌宠世界】 你是否也梦想着拥有一家随时能“云撸猫”、“云吸狗”的神奇小店&#xff1f;现在&#xff0c;“宠物店萌宠小程序”就是你的秘密花园&#xff01;&#x1f31f;只需轻轻一点&#xff0c;就能瞬…

使用Velero备份与恢复K8s集群及应用

作者&#xff1a;红米 环境 3台虚拟机组成一主两从的测试集群&#xff0c;使用NFS作为动态存储。 主机IP系统k8s-master192.168.1.10centos7.9k8s-node1192.168.1.11centos7.9k8s-node2192.168.1.12centos7.9 1、介绍 1.1 简介 备份容灾 一键恢复 集群迁移 支持备份pv&…

CH04_依赖项属性

第4章&#xff1a;依赖项属性 本章目标 理解依赖项属性理解属性验证 依赖项属性 ​ 属性与事件是.NET抽象模型的核心部分。WPF使用了更高级的依赖项属性&#xff08;Dependency Property&#xff09;功能来替换原来.NET的属性&#xff0c;实现了更高效率的保存机制&#xf…

卷积神经网络【CNN】--池化层的原理详细解读

池化层&#xff08;Pooling Layer&#xff09;是卷积神经网络&#xff08;CNN&#xff09;中的一个关键组件&#xff0c;主要用于减少特征图&#xff08;feature maps&#xff09;的维度&#xff0c;同时保留重要的特征信息。 一、池化层的含义 池化层在卷积神经网络中扮演着降…

python调用chrome浏览器自动化如何选择元素

功能描述&#xff1a;在对话框输入文字&#xff0c;并发送。 注意&#xff1a; # 定位到多行文本输入框并输入内容。在selenium 4版本中&#xff0c;元素定位需要填写父元素和子元素名。 textarea driver.find_element(By.CSS_SELECTOR,textarea.el-textarea__inner) from …

ACM中国图灵大会专题 | 图灵奖得主Manuel Blum教授与仓颉团队交流 | 华为论坛:面向全场景应用编程语言精彩回顾

ACM 中国图灵大会&#xff08;ACM Turing Award Celebration Conference TURC 2024&#xff09;于2024年7月5日至7日在长沙举行。本届大会由ACM主办&#xff0c;in cooperation with CCF&#xff0c;互联网之父Vinton Cerf、中国计算机学会前理事长梅宏院士和廖湘科院士担任学术…