python 库笔记:pytorch-tcn

news2025/3/19 8:59:36

 


提供以下功能

  • TCN类
  • Conv1dConvTranspose1d的实现,并带有因果/非因果切换功能
  • 流式推理(Streaming Inference)选项,可用于实时应用
  • 兼容 ONNX(Open Neural Network Exchange)格式,可在非Python环境(如C++)中使用已训练的TCN模型

1 TCN类

 

from pytorch_tcn import TCN

model = TCN(
    num_inputs: int,
    num_channels: ArrayLike,
    kernel_size: int = 4,
    dilations: Optional[ ArrayLike ] = None,
    dilation_reset: Optional[ int ] = None,
    dropout: float = 0.1,
    causal: bool = True,
    use_norm: str = 'weight_norm',
    activation: str = 'relu',
    kernel_initializer: str = 'xavier_uniform',
    use_skip_connections: bool = False,
    input_shape: str = 'NCL',
    embedding_shapes: Optional[ ArrayLike ] = None,
    embedding_mode: str = 'add',
    use_gate: bool = False,
    lookahead: int = 0,
    output_projection: Optional[ int ] = None,
    output_activation: Optional[ str ] = None,
)

1.1 输入&输出形状

 (N, Cin, L)

  • N:批量大小
  • Cin:输入/出 通道数(特征维度)
  • L:序列长度

1.2 参数详解

num_inputs输入数据的特征维度
num_channels

一个list,指定每个残差块的通道数

(这个TCN多少层也是通过这个来确定的)

kernel_size卷积核大小
dilations膨胀率
  • 若为 None,则自动计算为 2^(0...n)(标准做法)。
  • 也可手动传入特定膨胀率列表,如 [1, 2, 4, 8]
dilation_reset膨胀率重置
  • 若不重置,膨胀率会指数级增长,导致内存溢出。
  • 如果设置了重置,那么超过之后会从最小的dilation开始再来一轮。
    • 例如,dilation_reset=16 使膨胀率超过16后重置,如 [1, 2, 4, 8, 16, 1, 2, 4, ...]
dropout
causal

是否使用因果卷积

  • True:忽略未来信息,适用于实时预测
  • False:考虑未来信息,可用于非实时预测
use_norm

归一化方式

  • 可选:weight_norm(默认)/batch_norm/layer_norm/None
  • weight_norm 在原论文中使用,其他方式需根据具体任务测试。
activation激活函数:默认 relu
kernel_initializer

权重初始化

  • 可选:uniform / normal / kaiming_uniform / kaiming_normal / xavier_uniform / xavier_normal
  • 默认 xavier_uniform,相比 normal 初始化更稳定。
use_skip_connections

跳跃连接

  • True:每个残差块的输出都会传递到最终输出(类似 WaveNet)。
  • False(默认):不使用跳跃连接(原始 TCN 结构)
input_shape输入形状格式)
  • NCL(默认):批量大小、通道数、序列长度(PyTorch 格式)。
  • NLC:批量大小、序列长度、通道数(时序数据常见格式)。
output_projection输出投影:
  • 若非 None,则通过 1x1 卷积将输出投影到指定维度。
  • 适用于输入和输出维度不同的情况
output_activation输出激活函数:
  • 可选 softmax 等,用于分类任务。
  • None(默认):不使用激活函数。

2因果卷积

pytorch-TCN 提供了一个 因果卷积层,它继承自 PyTorch 的 Conv1d,可以直接替换标准的 Conv1d

2.1 参数说明

2.1.1 TemporalConv1d(因果卷积层)

in_channels输入通道数
out_channels输出通道数
kernel_size卷积核大小
stride步长,默认 1
padding默认 0,自动计算,若手动设置可能会报错
dilation膨胀率,默认 1
groups组卷积数,默认 1
bias是否使用偏置,默认 True
padding_mode填充模式,默认 zeros
device运行设备,默认 None(自动选择)
dtype数据类型,默认 None
causal是否使用因果卷积,默认 True

 2.1.2 TemporalConvTranspose1d(转置卷积层)

in_channels输入通道数
out_channels输出通道数
kernel_size卷积核大小
stride步长,默认 1
padding默认 0,自动计算,若手动设置可能会报错
dilation膨胀率,默认 1
groups组卷积数,默认 1
bias是否使用偏置,默认 True
padding_mode填充模式,默认 zeros
device运行设备,默认 None(自动选择)
dtype数据类型,默认 None
causal是否使用因果卷积,默认 True

2.2 基本使用举例 

from pytorch_tcn import TemporalConv1d, TemporalConvTranspose1d
import torch


# 因果卷积层
conv = TemporalConv1d(
    in_channels=32,     # 输入通道数
    out_channels=32,    # 输出通道数
    kernel_size=3,      # 卷积核大小
    causal=True         # 是否使用因果卷积
)

# 反卷积层(转置卷积)
conv_t = TemporalConvTranspose1d(
    in_channels=32,
    out_channels=32,
    kernel_size=4,
    stride=2
)

# 前向传播
x = torch.randn(10, 32, 100)  # (batch_size=10, in_channels=32, seq_len=100)

y = conv(x, inference=False, in_buffer=None)    
# 进行卷积运算
y_t = conv_t(x, inference=False, in_buffer=None)  
# 进行转置卷积


y.shape,y_t.shape
#(torch.Size([10, 32, 100]), torch.Size([10, 32, 200]))

3 流式推理

  • 模型能够逐块(blockwise)处理数据,而无需加载完整的序列。
  • 这一功能对于实时应用至关重要。

3.1  流式推理的挑战

  • 在 TCN 结构中,若 kernel_size > 1,为了保证输出的时间步与输入相同,TCN 始终会使用零填充(zero padding)
  • 然而,在块状处理(blockwise processing)时,零填充可能会导致感受野断裂,从而影响模型的推理效果
    • 假设输入序列为[ X1, X2, X3, X4 ]
    • 使用 kernel_size=3dilation=1 的因果网络时,第一层的填充长度为 2
    • 期望的输入格式:[ 0, 0, X1, X2, X3, X4 ]
    • 但若按块状输入 [X1, X2][X3, X4],则会导致:[ 0, 0, X1, X2 ] + [ 0, 0, X3, X4 ]
    • ——>不同块的填充断裂,使得推理结果与整体序列处理时不同,影响模型的表现。
  • ——>TCN 采用内部缓冲区(buffer),用于存储网络的输入历史,并在下一次推理时将其作为填充
    • 因此,无论数据是否是整块输入还是逐块输入,最终输出都是一致的。
    • 注意:流式推理时,批量大小 batch_size 必须为 1

3.2 如何使用流式推理

from pytorch_tcn import TCN

# 初始化 TCN 模型
tcn = TCN(
    num_inputs=10,      # 输入特征数
    num_channels=[32, 64, 128],
    causal=True,         # 流式推理只适用于因果网络
)

# 处理新序列前 **必须** 重置缓冲区
tcn.reset_buffers()
#作用是,清空存储的历史输入,确保每次推理不会受到上一次序列的影响。


# 流式推理的输入数据按块分割
for block in blocks:  # 每个 block 形状:(1, num_inputs, block_size)
    out = tcn(block, inference=True)  
    # 启用流式推理,使 TCN 在推理时利用缓冲区填充,确保逐块处理时的输出与完整序列处理时一致。
    

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2317692.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

新造车不再比拼排名,恰是曲终人散时,剩者为王

据称新能源汽车周销量不再发布,这可能也预示着新造车终于到了给出答案的时候了,新造车企业前三强已基本确立,其余那些落后的车企已很难有突围的机会,而特斯拉无疑是其中的最大赢家。 3月份第一周的数据显示,销量最高的…

博客迁移----宝塔面板一键迁移遇到问题

前景 阿里云轻量级服务器到期了,又免费领了个ESC, 安转了宝塔面板。现在需要迁移数据,使用宝塔面板一键迁移功能,完成了数据的迁移,改了域名的解析,现在进入博客是显示502 bad grateway 宝塔搬家参考链接…

大数据处理最容易的开源平台

大数据处理最容易的开源平台可以从多个角度进行分析,包括易用性、灵活性、成本效益以及社区支持等方面。 Apache Spark Apache Spark 是一个广泛使用的开源大数据处理框架,以其快速、通用和易于使用的特点而著称。它支持多种编程语言(如 Scal…

Dify 使用 - 创建 翻译 工作流

文章目录 1、选择 模板2、设置 和 基本使用3、运行应用 1、选择 模板 2、设置 和 基本使用 翻译模板 自带了系统提示词,你也可以修改 3、运行应用 右上角 点击 发布 – 更新,运行应用,就可以在新的对话界面中使用此功能 2025-03-18&#x…

TreelabPLMSCM数字化供应链解决方案0608(61页PPT)(文末有下载方式)

详细资料请看本解读文章的最后内容。 资料解读:TreelabPLMSCM 数字化供应链解决方案 0608 在当今快速变化的市场环境中,企业面临着诸多挑战,Treelab 数智化 PLM_SCM 行业解决方案应运而生。该方案聚焦市场趋势与行业现状,致力于解…

LogicFlow介绍

LogicFlow介绍 LogicFlow是一款流程图编辑框架,提供了一系列流程图交互、编辑所必需的功能和灵活的节点自定义、插件等拓展机制。LogicFlow支持前端自定义开发各种逻辑编排场景,如流程图、ER图、BPMN流程等。在工作审批流配置、机器人逻辑编排、无代码平…

[蓝桥杯 2023 省 B] 飞机降落

[蓝桥杯 2023 省 B] 飞机降落 题目描述 N N N 架飞机准备降落到某个只有一条跑道的机场。其中第 i i i 架飞机在 T i T_{i} Ti​ 时刻到达机场上空,到达时它的剩余油料还可以继续盘旋 D i D_{i} Di​ 个单位时间,即它最早可以于 T i T_{i} Ti​ 时刻…

应用分层简介

一、什么是应用分层 应用分层是一种软件开发设计思想,它将应用程序分为多个层次,每个层次各司其职,多个层次之间协同提供完整的功能,根据项目的复杂程度,将项目分为三层或者更多层。 常见的MCV设计模式,就…

基于香橙派 KunpengPro学习CANN(3)——pytorch 模型迁移

通用模型迁移适配可以分为四个阶段:迁移分析、迁移适配、精度调试与性能调优。 迁移分析 迁移支持度分析: 准备NPU环境,获取模型的源码、权重和数据集等文件;使用迁移分析工具采集目标网络中的模型/算子清单,识别第三方…

电子硬件入门(三)——偏置电路

文章目录 一、先理解问题:为什么需要偏置电压?二.偏置电路生成的四大核心零件​三、工作流程图解​四、实物电路对照​五、常见问题答疑 一、先理解问题:为什么需要偏置电压? 想象一下,电机的电流像一条波浪线&#x…

使用C++写一个递推计算均方差和标准差的用例

文章目录 代码输出关键实现说明1. 类设计2. 算法核心3. 数值稳定性 扩展应用场景1. 实时传感器数据处理2. 大规模数据集分块处理 总结 以下是用 C 实现递推计算均值、方差和标准差的完整示例代码,基于 Welford 算法,适用于实时数据流或大数据场景&#x…

Vue:单文件组件

Vue:单文件组件 1、 什么是单文件组件? 在传统的Vue开发里,我们接触的是非单文件组件,它们通常被定义在同一个HTML文件中,随着项目规模的扩大,代码会变得杂乱无章,维护起来极为困难。而单文件…

JavaScript变量声明与DOM操作指南

变量声明 1.变量声明有三个 var let 和 const 我们应该用那个呢? 首先var 先排除,老派写法,问题很多,可以淘汰掉… 2.let or const ? 建议: const 优先,尽量使用const,原因是:…

[K!nd4SUS 2025] Crypto

最后一个把周末的补完。这个今天问了小鸡块神终于把一个补上,完成5/6,最后一个网站也上不去不弄了。 Matrices Matrices Matrices 这个是不是叫LWE呀,名词忘了,但意思还是知道。 b a*s e 这里的e是高斯分成,用1000…

工作记录 2017-02-04

工作记录 2017-02-04 序号 工作 相关人员 1 修改邮件上的问题。 更新RD服务器。 郝 更新的问题 1、DataExport的设置中去掉了ListPayors,见DataExport\bin\dataexport.xml 2、“IPA/Group Name” 改为 “Insurance Name”。 3、修改了Payment Posted的E…

Etcd 服务搭建

💢欢迎来到张胤尘的开源技术站 💥开源如江河,汇聚众志成。代码似星辰,照亮行征程。开源精神长,传承永不忘。携手共前行,未来更辉煌💥 文章目录 Etcd 服务搭建预编译的二进制文件安装下载 etcd 的…

【C++】stack和queue的使用及模拟实现(含deque的简单介绍)

文章目录 前言一、deque的简单介绍1.引入deque的初衷2.deque的结构3.为什么选择deque作为stack和queue的底层默认容器 二、stack1.stack的介绍2.stack的使用3.stack的模拟实现 三、queue1.queue的介绍2.queue的使用3.queue的模拟实现 前言 一、deque的简单介绍(引入…

MySQL原理:逻辑架构

目的:了解 SQL执行流程 以及 MySQL 内部架构,每个零件具体负责做什么 理解整体架构分别有什么模块每个模块具体做什么 目录 1 服务器处理客户端请求 1.1 MySQL 服务器端逻辑架构说明 2 Connectors 3 第一层:连接层 3.1 数据库连接池(Conn…

ora-600 ktugct: corruption detected---惜分飞

接手一个oracle 21c的库恢复请求,通过Oracle数据库异常恢复检查脚本(Oracle Database Recovery Check)脚本检测之后,发现undo文件offline之后,做了resetlogs操作,导致该文件目前处于WRONG RESETLOGS状态 尝试恢复数据库ORA-16433错误 SQL> recover datafile 1; ORA-00283:…

Houdini :《哪吒2》神话与科技碰撞的创新之旅

《哪吒2》(即《哪吒之魔童闹海》)截止至今日,荣登全球票房榜第五。根据猫眼专业版数据,截至2025年3月15日,《哪吒2》全球累计票房(含预售及海外)超过150.19亿元,超越《星球大战&…