时间卷积网络(TCN)原理+代码详解

news2024/12/22 19:08:29

目录

  • 一、TCN原理
    • 1.1 因果卷积(Causal Convolution)
    • 1.2 扩张卷积(Dilated Convolution)
  • 二、代码实现
    • 2.1 Chomp1d 模块
    • 2.2 TemporalBlock 模块
    • 2.3 TemporalConvNet 模块
    • 2.4 完整代码示例
  • 参考文献

  在理解 TCN 的原理之前,我们可以先对传统的循环神经网络(RNN)进行简要回顾。RNN 是处理序列数据的常用方法,其核心思想是通过将前一个时间步的隐藏状态传递到下一个时间步,实现对序列依赖关系的建模。然而,RNN 在处理长序列时存在以下几个缺点:

  • 无法并行计算:RNN 的计算依赖于时间步的顺序,导致无法高效利用 GPU 并行计算。

  • 梯度消失/爆炸:在长时间依赖中,梯度在反向传播时会逐渐消失或变得不稳定。

  • 短期记忆限制:由于计算依赖于序列的逐步传递,RNN 难以捕获远距离的时间依赖。

  TCN 正是在这样的背景下提出的。它通过因果卷积和扩张卷积,突破了 RNN 的这些瓶颈,特别适用于长时间序列数据。接下来,我们将详细解析 TCN 的原理。

一、TCN原理

1.1 因果卷积(Causal Convolution)

  在卷积操作中,卷积核在输入上滑动时会同时处理前后时间步的数据,导致当前时间步的输出可能依赖于未来的输入。然而,对于时间序列任务,我们通常希望模型只依赖于过去的输入,不“窥探”未来,这样的结构称为“因果性”。

  TCN 使用因果卷积来确保这一点。因果卷积是指每个时间步的输出仅依赖于它之前的时间步,而不依赖于未来。简单来说,当前时间步的输出只会考虑卷积核覆盖的前几个时间步的输入。

  TCN 通过适当的填充(padding)来实现这一点,使得每一层的卷积不会跨越未来时间步。因果卷积的示意图如下:

在这里插入图片描述

1.2 扩张卷积(Dilated Convolution)

  为了捕捉长时间依赖关系,TCN 通过 扩张卷积(Dilated Convolution 来扩展卷积核的感受野。扩张卷积通过在卷积核的元素之间插入“间隔”,从而在保持卷积核大小不变的情况下,扩大卷积的感受野。

  例如,假设卷积核大小为 3,当扩张率 dilation=2 时,卷积核的元素之间插入 1 个间隔,感受野可以从 3 扩展到 5。通过这种扩张卷积,TCN 在每一层可以通过指数扩展的方式增大感受野,使得模型能够捕捉到远距离的依赖关系。例如,TCN 中第 i i i 层的感受野大小为 2 i 2^{i} 2i,这样层数越深,感受野就越大。如下图所示:

在这里插入图片描述

二、代码实现

2.1 Chomp1d 模块

  TCN 使用填充操作来保证卷积后的时间步不丢失,但填充会导致额外的时间步,因此需要 Chomp1d 来修剪掉多余部分,保证输入输出的时间维度一致。

class Chomp1d(nn.Module):
    def __init__(self, chomp_size):
        super(Chomp1d, self).__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        return x[:, :, :-self.chomp_size].contiguous()

  Chomp1d 的作用是对卷积结果的最后几个时间步进行修剪,这确保了卷积核在时间序列两端不会额外输出冗余的步长。

2.2 TemporalBlock 模块

  TemporalBlock 是 TCN 的基本构建单元,包含两层扩张卷积,每层后接激活函数和 Chomp1d 操作。

class TemporalBlock(nn.Module):
    def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout):
        super(TemporalBlock, self).__init__()
        # 第一层卷积
        self.ll_conv1 = nn.Conv1d(n_inputs, n_outputs, kernel_size, stride=stride, padding=padding, dilation=dilation)
        self.chomp1 = Chomp1d(padding)
        self.relu1 = nn.LeakyReLU()

        # 第二层卷积
        self.ll_conv2 = nn.Conv1d(n_outputs, n_outputs, kernel_size, stride=stride, padding=padding, dilation=dilation)
        self.chomp2 = Chomp1d(padding)
        self.relu2 = nn.LeakyReLU()

        # Dropout 作为正则化,防止过拟合
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # 第一个卷积、修剪、激活和 Dropout
        out = self.ll_conv1(x)
        out = self.chomp1(out)
        out = self.relu1(out)
        out = self.dropout(out)

        # 第二个卷积、修剪、激活和 Dropout
        out = self.ll_conv2(out)
        out = self.chomp2(out)
        out = self.relu2(out)
        out = self.dropout(out)

        return out
  • ll_conv1 和 ll_conv2 是两层扩张卷积层,dilation 参数决定了每层的感受野大小。

  • Chomp1d 保证卷积结果不会产生额外的时间步。

  • LeakyReLU 是非线性激活函数,为模型引入非线性。

  • Dropout 用于防止过拟合,通过随机丢弃一部分神经元。

2.3 TemporalConvNet 模块

  TemporalConvNet 是由多个 TemporalBlock 级联组成的模型,每一层的卷积感受野逐层递增。

class TemporalConvNet(nn.Module):
    def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.0):
        super(TemporalConvNet, self).__init__()
        layers = []
        self.num_levels = len(num_channels)

        for i in range(self.num_levels):
            dilation_size = 2 ** i  # 每层的扩张率递增
            in_channels = num_inputs if i == 0 else num_channels[i - 1]
            out_channels = num_channels[i]
            layers.append(
                TemporalBlock(
                    in_channels, out_channels, kernel_size, stride=1, dilation=dilation_size,
                    padding=(kernel_size - 1) * dilation_size, dropout=dropout
                )
            )

        self.network = nn.Sequential(*layers)

    def forward(self, x):
        return self.network(x)
  • TemporalConvNet 通过循环构建多层 TemporalBlock,每层的扩张率 dilation 是前一层的两倍,使得感受野指数级增长。

  • 使用 nn.Sequential 将所有层级联在一起,模型最终输出序列数据经过所有层的处理结果。

2.4 完整代码示例

  在这个例子中,输入数据有 8 个样本,每个样本有 3 个特征,序列长度为 10。经过 TCN 网络的三层处理,输出的特征维度从 3 增加到 64,但时间维度(10)保持不变。

import torch.nn as nn
import torch.nn.functional as F
import torch

class Chomp1d(nn.Module):
    def __init__(self, chomp_size):
        super(Chomp1d, self).__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        return x[:, :, : -self.chomp_size].contiguous()


class TemporalBlock(nn.Module):
    def __init__(
        self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout
    ):
        super(TemporalBlock, self).__init__()
        self.n_inputs = n_inputs
        self.n_outputs = n_outputs
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.padding = padding
        self.dropout = dropout
        self.ll_conv1 = nn.Conv1d(
            n_inputs,
            n_outputs,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
        )
        self.chomp1 = Chomp1d(padding)

        self.ll_conv2 = nn.Conv1d(
            n_outputs,
            n_outputs,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
        )
        self.chomp2 = Chomp1d(padding)
        self.sigmoid = nn.Sigmoid()

    def net(self, x, block_num, params=None):
        layer_name = "ll_tc.ll_temporal_block" + str(block_num)
        if params is None:
            x = self.ll_conv1(x)
        else:
            x = F.conv1d(
                x,
                weight=params[layer_name + ".ll_conv1.weight"],
                bias=params[layer_name + ".ll_conv1.bias"],
                stride=self.stride,
                padding=self.padding,
                dilation=self.dilation,
            )

        x = self.chomp1(x)
        x = F.leaky_relu(x)

        return x

    def init_weights(self):
        self.ll_conv1.weight.data.normal_(0, 0.01)
        self.ll_conv2.weight.data.normal_(0, 0.01)

    def forward(self, x, block_num, params=None):
        out = self.net(x, block_num, params)
        return out


class TemporalConvNet(nn.Module):
    def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.0):
        super(TemporalConvNet, self).__init__()
        layers = []
        self.num_levels = len(num_channels)

        for i in range(self.num_levels):
            dilation_size = 2 ** i
            in_channels = num_inputs if i == 0 else num_channels[i - 1]
            out_channels = num_channels[i]
            setattr(
                self,
                "ll_temporal_block{}".format(i),
                TemporalBlock(
                    in_channels,
                    out_channels,
                    kernel_size,
                    stride=1,
                    dilation=dilation_size,
                    padding=(kernel_size - 1) * dilation_size,
                    dropout=dropout,
                ),
            )

    def forward(self, x, params=None):

        for i in range(self.num_levels):
            temporal_block = getattr(self, "ll_temporal_block{}".format(i))
            x = temporal_block(x, i, params=params)
        return x


# 定义一个 TCN 模型,输入通道数为 3,输出通道分别为 16, 32, 64,核大小为 2
tcn = TemporalConvNet(num_inputs=3, num_channels=[16, 32, 64], kernel_size=2, dropout=0.2)

# 假设输入的张量形状为 (batch_size, num_inputs, sequence_length)
x = torch.randn(8, 3, 10)  # 8 个样本,3 个输入特征,序列长度为 10

# 通过 TCN 进行前向传播
output = tcn(x)

print(output.shape)  # 输出的形状为 (batch_size, 64, sequence_length),即 (8, 64, 10)

参考文献

[1] https://github.com/locuslab/TCN

[2] 如何理解扩张卷积(dilated convolution)

[3] 【机器学习】详解 扩张/膨胀/空洞卷积 (Dilated / Atrous Convolution)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2198296.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

GIS后端工程师岗位职责、技术要求和常见面试题

GIS 后端工程师负责设计、开发与维护地理信息系统的后端服务,包括数据存储、处理、分析以及与前端的交互接口等,以实现高效的地理数据管理和功能支持。 GIS 后端工程师岗位职责 一、系统设计与开发 参与地理信息系统(GIS)项目的…

安装 Petalinux

资料准备 ubuntu 22.04: 运行内存8G 存储空间500G Petalinux:2024.1 安装流程 安装依赖 sudo apt-get update sudo apt-get upgrade sudo apt-get install iproute2 sudo apt-get install gawk sudo apt-get install build-essential sudo apt-ge…

7.3 物联网平台-Thingsboard使用教程

物联网平台-Thingsboard使用教程 目录概述需求: 设计思路实现思路分析 免费下载参考资料和推荐阅读 Survive by day and develop by night. talk for import biz , show your perfect code,full busy,skip hardness,make a better result,wait for chang…

如何使用ssm实现基于web技术的税务门户网站的实现+vue

TOC ssm820基于web技术的税务门户网站的实现vue 绪论 1.1 研究背景 当前社会各行业领域竞争压力非常大,随着当前时代的信息化,科学化发展,让社会各行业领域都争相使用新的信息技术,对行业内的各种相关数据进行科学化&#xff…

基于matlab的语音信号处理

摘要 利用所学习的数字信号处理知识,设计了一个有趣的音效处理系统,首先设计了几种不同的滤波器对声音进行滤波处理,分析了时域和频域的变化,比较了经过滤波处理后的声音与原来的声音有何变化。同时设计实现了语音的倒放&#xf…

从0开始linux(9)——进程(1)进程管理

欢迎来到博主的专栏:从0开始linux 博主ID:代码小豪 文章目录 查看进程进程管理PID与PPIDfork函数 在上一篇中我们了解到:当运行程序时,操作系统会将磁盘中的二进制文件读取到内存当中,程序运行到结束的过程称为进程&am…

【C++ 11】auto 自动类型推导

文章目录 【 1. 基本用法 】【 2. auto 的 应用 】2.0 auto 的限制2.1 简单实例2.2 auto 与指针、引用、const2.4 auto 定义迭代器2.5 auto 用于泛型编程 问题背景 在 C11 之前的版本(C98 和 C 03)中,定义变量或者声明变量之前都必须指明它的…

目标检测YOLO实战应用案例100讲-【目标检测】YOLOV11

目录 前言 算法原理 YOLO发展历程 什么是 YOLO11 YOLOv11 的主要特点 YOLO各版本概览 核心优势: YOLOv11改进方向 YOLOv11功能介绍 YOLOv11关键创新 YOLOv11 指标展示 YOLOV11实验 环境设置 准备数据集 训练模型 验证模型 应用领域 一、智慧交通与自动驾…

【Linux实践】实验八:Shell程序的创建及变量

文章目录 实验八:Shell程序的创建及变量实验目的:实验内容:操作步骤:1. 查看环境变量2. 定义变量AK3. 定义变量AM并比较4. 创建Shell程序 实验八:Shell程序的创建及变量 实验目的: 掌握Shell程序的创建过…

【C++】AVL树的底层以及实现

个人主页 文章目录 ⭐一、AVL树的概念🎉二、AVL树的性质🏝️三、AVL树的实现1. 树的基本结构2. 树的插入3. 树的旋转• 左单旋• 右单旋• 左右双旋• 右左双旋 🎡四、AVL树的其它功能1. 树的查找2. 树的遍历3. 树的高度4. 树的大小 &#x…

RK3568平台开发系列讲解(I2C篇)i2c 总线驱动介绍

🚀返回专栏总目录 文章目录 一、i2c 总线定义二、i2c 总线注册三、i2c 设备和 i2c 驱动匹配规则沉淀、分享、成长,让自己和他人都能有所收获!😄 i2c 总线驱动由芯片厂商提供,如果我们使用 ST 官方提供的 Linux 内核, i2c 总线驱动已经保存在内核中,并且默认情况下已经…

vulnhub-matrix-breakout-2-morpheus靶机的测试报告

目录 一、测试环境 1、系统环境 2、使用工具/软件 二、测试目的 三、操作过程 1、信息搜集 2、Getshell ①nc反弹shell连接 ②Webshell上传 3、提权 ①使用kali自带的poc ②使用msf进行渗透 四、结论 一、测试环境 1、系统环境 渗透机:kali2021.1(19…

项目构建工具

一般面试中被问到的项目构建工具,常常会回答的是Maven 今天大概了解了一下目前项目构建构建有Maven,Ant,Gradle Gradle 是一个构建工具,它是用来帮助我们构建app的,构建包括编译,打包等过程。我们可以为Gradle指定构建规则&…

matlab 相关

1、xcorr 本质上是两个函数做内积运算 相关算法有两种: 在Matlab上既可以 1.用自带的xcorr函数计算互相关,2.通过在频域上乘以共轭复频谱来计算互相关; 网友验证程序 clc;clear;close all; % s1,s2为样例数据 s1 [-0.00430297851562500;-…

攻防世界----->Replace

前言:做题笔记。 下载 查壳。 upx32脱壳。 32ida打开。 先运行看看: 没有任何反应? 猜测又是 地址随机化(ASLR)---遇见过。 操作参考: 攻防世界---->Windows_Reverse1_dsvduyierqxvyjrthdfrtfregreg-CSDN博客 然后…

Spring系列 Bean创建过程

文章目录 初始化时机单例初始化流程getBeandoGetBeangetSingleton(String) 获取单例getSingleton(String, ObjectFactory) 创建单例beforeSingletonCreationcreateBeanafterSingletonCreation 创建 Bean 过程doCreateBeanaddSingletonFactory createBeanInstance 创建 Bean 对象…

医院管理智能化:Spring Boot技术革新

3系统分析 3.1可行性分析 通过对本医院管理系统实行的目的初步调查和分析,提出可行性方案并对其一一进行论证。我们在这里主要从技术可行性、经济可行性、操作可行性等方面进行分析。 3.1.1技术可行性 本医院管理系统采用JAVA作为开发语言,Spring Boot框…

ctf.bugku - game1

题目来源: game1 - Bugku CTF 访问页面,让玩游戏 得到100分,没拿到flag 查看页面源码, GET请求带有 score、IP、sign 三个参数,最后的flag 应该跟分数有关; 给了score一个99999分数, sign 为 …

STM32编码器接口

一、概述 1、Encoder Interface 编码器接口概念 编码器接口可接收增量(正交)编码器的信号,根据编码器旋转产生的正交信号脉冲,自动控制CNT自增或自减,从而指示编码器的位置、旋转方向和旋转速度每个高级定时器和通用…

如何录制微课教程?K12教育相关课程录制录屏软件推荐

在当今数字化教育的时代,微课作为一种重要的教学资源,受到了越来越多教师和学生的关注。制作一节优质的微课,录制是关键的环节之一。下面我们将结合相关知识,详细介绍如何录制微课教程。 一、微课录制前的准备 1.教学设计文档编写…