keras实现TCN网络层

news2025/1/4 9:56:05

keras实现TCN网络层,keras3.0可用。

from keras.layers import  Lambda,Dense,Layer,Conv1D
import tensorflow as tf

class TCNCell(Layer):
    """
    sumary_line:
    Chinese:让输入的时间序列[bs,seql,dim]提升kernel_size倍的感受野
    English: Double the receptive field of the input time series [bs, seql, dim]
    """
    def __init__(self, filters=32,ks=3,activation=None,name=None):
        self.filters = filters
        self.ks = ks
        self.activation = activation
        super(TCNCell, self).__init__(name=name)


    def build(self, input_shape):
        assert len(input_shape) == 3, f"Input shape should be [batch, timesteps, features], but got {input_shape}"
        self.input_shape = input_shape
        bs,seq_l,dim = input_shape
        if input_shape[1]==1:
            self.out = Dense(self.filters,activation='relu')
        else:
            if not seq_l%self.ks == 0:
                self.maxlen = seq_l+self.ks-seq_l%self.ks
                self.pad_layer = Lambda(lambda x: tf.pad(tensor=x, paddings=[[0,0],[self.maxlen-seq_l, 0], [0, 0]], constant_values=0),output_shape=(self.maxlen,dim))
                assert self.maxlen%self.ks == 0, 'kernel size should be divisible by input length'
            self.tcn_cell = Conv1D(filters=self.filters, kernel_size=self.ks, strides=self.ks,activation=self.activation,padding='valid')
        super(TCNCell, self).build(input_shape)
    
    def call(self,x):
        if x.shape[1]==1 and hasattr(self,'out'):
            return self.out(x)
        else:
            if hasattr(self, 'pad_layer') and hasattr(self,'maxlen'):
                x = self.pad_layer(x)
                x = self.tcn_cell(x)
                return x
            else:
                return self.tcn_cell(x)
    

    
class TCN(Layer):

    """
    input: (batch_size,seq_len,feature_dim)
    output: (batch_size,output_len,feature_dim)
    """

    def __init__(self,filters_list=[32,64,128],kernel_size_list=[3,3,3],seq_len=32,name='TCN'):
        assert len(filters_list) == len(kernel_size_list), "filters_list and kernel_size_list must have the same length"
        self.l = len(filters_list)
        assert seq_len is not None and seq_len > 2**self.l, f"seql is None or receptive field must be smaller than squence length, please check"
        self.filters_list = filters_list
        self.kernel_size_list = kernel_size_list
        self.seql = seq_len
        self.print_receptive_field()
        super(TCN,self).__init__(name=name)

    def cala_receptive_field(self):
        ce_list = []
        for idx,ks in enumerate(self.kernel_size_list):
            if idx == 0:
                ce_list.append(ks)
            else:
                ce_list.append(ce_list[-1]*ks)
        return ce_list[-1]



    def print_receptive_field(self):
        ce = self.cala_receptive_field()
        print(f'当前的参数将会使感受野提升{ce}倍,即输出时间维度一个时刻能够反应其之前{ce}个时刻的特征')
        print(f'The current parameter will increase the receptive field by {ce} times,' + ' '+
              f'which means that the output time dimension can reflect the features of {ce} times before it at one moment')


    def build(self, input_shape):
        bs,seql,dim  = input_shape
        assert seql==self.seql, f'输入序列长度{seql}与设定的序列长度{self.seql}不一致' + ' ' + f'The input sequence length {seql} does not match the set sequence length {self.seql}'
        self.tcn_cell_layers = []
        for i in range(self.l):
            self.tcn_cell_layers.append(
                TCNCell(filters=self.filters_list[i],ks=self.kernel_size_list[i])
            )
        super(TCN, self).build(input_shape)
    
    def call(self,x):
        for i in range(self.l):
            x = self.tcn_cell_layers[i](x)
        return x
    

    
if __name__ == '__main__':
    import numpy as np
    tcnlayer = TCN()
    out = tcnlayer(np.zeros((1,32,768)))
    print(out.shape)
if __name__ == '__main__':
    import numpy as np
    tcnlayer = TCN()
    out = tcnlayer(np.zeros((1,32,768)))
    print(out.shape)

核心思路:使用valid卷积,卷积核大小和stride大小取相同的值,Conv1d只会沿着一个方向(序列正方向)进行移动,因此卷积核计算的特征具有因果特性(与pading=='causal'效果一样)。每经过一层卷积,得到的每个时刻就代表一个kernel_size个感受野。通过堆叠层数,实现感受野的增加。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2269255.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【D3.js in Action 3 精译_047】5.2:图形的堆叠(一)—— 图解 D3 中的堆叠布局生成器

当前内容所在位置: 第五章 饼图布局与堆叠布局 ✔️ 5.1 饼图和环形图的创建 5.1.1 准备阶段(一)5.1.2 饼图布局生成器(二)5.1.3 圆弧的绘制(三)5.1.4 数据标签的添加(四&#xff09…

自建私有云相册:Docker一键部署Immich,照片视频备份利器

自建私有云相册:Docker一键部署Immich,照片视频备份利器 前言 随着人们手机、PC、平板等电子产品多样,我们拍摄和保存的照片和视频数量也在不断增加。如何高效地管理和备份这些珍贵的记忆成为了一个重要的问题。 传统的云备份虽然方便&…

[微服务] - MQ高级

在昨天的练习作业中,我们改造了余额支付功能,在支付成功后利用RabbitMQ通知交易服务,更新业务订单状态为已支付。 但是大家思考一下,如果这里MQ通知失败,支付服务中支付流水显示支付成功,而交易服务中的订单…

【Unity3D】A*寻路(2D究极简单版)

运行后点击透明格子empty即执行从(0,0)起点到点击为止终点(测试是(5,5))如下图 UICamera深度要比MainCamera大,Clear Flags:Depth only,正交视野 MainCamera保持原样;注意Line绘线物体的位置大小旋转信息,不…

xadmin后台首页增加一个导入数据按钮

xadmin后台首页增加一个导入数据按钮 效果 流程 1、在添加小组件中添加一个html页面 2、写入html代码 3、在urls.py添加导入数据路由 4、在views.py中添加响应函数html代码 <!DOCTYPE html> <html lang

压敏电阻MOV选型【EMC】

左侧的压敏电阻用来防护差模干扰&#xff1b;右侧并联在L N 两端的压敏电阻是用来防护共模干扰&#xff1a; 选择压敏电阻时&#xff0c;通常需要考虑以下几个关键因素&#xff0c;以确保它能够有效保护电路免受浪涌电流或过电压的损害&#xff0c;同时满足 EMC 要求&#xff1…

pycharm pytorch tensor张量可视化,view as array

Evaluate Expression 调试过程中&#xff0c;需要查看比如attn_weight 张量tensor的值。 方法一&#xff1a;attn_weight.detach().numpy(),view as array 方法二&#xff1a;attn_weight.cpu().numpy(),view as array

log4j2的Strategy、log4j2的DefaultRolloverStrategy、删除过期文件

文章目录 一、DefaultRolloverStrategy1.1、DefaultRolloverStrategy节点1.1.1、filePattern属性1.1.2、DefaultRolloverStrategy删除原理 1.2、Delete节点1.2.1、maxDepth属性 二、知识扩展2.1、DefaultRolloverStrategy与Delete会冲突吗&#xff1f;2.1.1、场景一&#xff1a…

设计模式之访问者模式:一楼千面 各有玄机

~犬&#x1f4f0;余~ “我欲贱而贵&#xff0c;愚而智&#xff0c;贫而富&#xff0c;可乎&#xff1f; 曰&#xff1a;其唯学乎” 一、访问者模式概述 \quad 江湖中有一个传说&#xff1a;在遥远的东方&#xff0c;有一座神秘的玉楼。每当武林中人来访&#xff0c;楼中的各个房…

结合实例来聊聊UDS诊断中的0x2F服务

1、什么是UDS中的0x2F服务 0x2F简单来说&#xff0c;就是输入输出控制服务。先看官方的简绍 翻译如下&#xff1a; InputOutputControlByldentifier服务来替换输入信号、内部服务器函数和/或强制控制为电子系统的输出&#xff08;执行器&#xff09;的值。通常&#xff0c;此…

1月第二讲:WxPython跨平台开发框架之图标选择界面

1、图标分类介绍 这里图标我们分为两类&#xff0c;一类是wxPython内置的图标资源&#xff0c;以wx.Art_开始。wx.ART_ 是 wxPython 提供的艺术资源&#xff08;Art Resource&#xff09;常量&#xff0c;用于在界面中快速访问通用的图标或位图资源。这些资源可以通过 wx.ArtP…

【弱监督视频异常检测】2024-TCSVT-基于片段间特征相似度的多尺度时间 MLP 弱监督视频异常检测

2024-TCSVT-Inter-clip Feature Similarity based Weakly Supervised Video Anomaly Detection via Multi-scale Temporal MLP 基于片段间特征相似度的多尺度时间 MLP 弱监督视频异常检测摘要1. 引言2. 相关工作A. 分布外检测B. 弱监督视频异常检测C. 多层感知器 3. 方法A. 概述…

C# OpenCV机器视觉:凸包检测

在一个看似平常却又暗藏玄机的午后&#xff0c;阿强正悠闲地坐在实验室里&#xff0c;翘着二郎腿&#xff0c;哼着小曲儿&#xff0c;美滋滋地品尝着手中那杯热气腾腾的咖啡&#xff0c;仿佛整个世界都与他无关。突然&#xff0c;实验室的门 “砰” 的一声被撞开&#xff0c;小…

【英特尔IA-32架构软件开发者开发手册第3卷:系统编程指南】2001年版翻译,2-44

文件下载与邀请翻译者 学习英特尔开发手册&#xff0c;最好手里这个手册文件。原版是PDF文件。点击下方链接了解下载方法。 讲解下载英特尔开发手册的文章 翻译英特尔开发手册&#xff0c;会是一件耗时费力的工作。如果有愿意和我一起来做这件事的&#xff0c;那么&#xff…

8.若依系统监控与定时任务

帮助开发者和运维快速了解应用程序的性能状态。 数据监控 定时任务 实现动态管理任务。 需求&#xff1a;每间隔5s&#xff0c;控制台输出系统时间。 新建的任务类必须在指定目录ruoyi-quartz模块下的task包下。 状态设置为启动 执行策略 场景&#xff1a;比如一个任务每个…

【JAVA高级篇教学】第六篇:Springboot实现WebSocket

在 Spring Boot 中对接 WebSocket 是一个常见的场景&#xff0c;通常用于实现实时通信。以下是一个完整的 WebSocket 集成步骤&#xff0c;包括服务端和客户端的实现。本期做个简单的测试用例。 目录 一、WebSocket 简介 1. 什么是 WebSocket&#xff1f; 2. WebSocket 的特…

【YOLO 项目实战】(12)红外/可见光多模态目标检测

欢迎关注『youcans动手学模型』系列 本专栏内容和资源同步到 GitHub/youcans 【YOLO 项目实战】&#xff08;10&#xff09;YOLO8 环境配置与推理检测 【YOLO 项目实战】&#xff08;11&#xff09;YOLO8 数据集与模型训练 【YOLO 项目实战】&#xff08;12&#xff09;红外/可…

Ubuntu开机The root filesystem on /dev/sdbx requires a manual fsck 问题

出现“Manual fsck”错误可能由以下几种原因引起&#xff1a; 不正常关机&#xff1a;如果系统意外断电或被强制重启&#xff0c;文件系统可能未能正确卸载&#xff0c;导致文件系统损坏。磁盘故障&#xff1a;硬盘的物理损坏可能会引发文件系统错误。文件系统配置问题&#x…

RFSOC 47dr Dp口测试(ARM裸机)

47DR 内核还是一个4核A53的MPSOC&#xff0c;测试方式和MPSOC一样 首先设置好BD文件 编译好BIT设置VITIS工程 examle工程测试即可 但是本人硬件会跑飞不知道为何&#xff0c;通过注释掉下图的子函数得以解决 值得注意的是&#xff0c;最好用HP的线&#xff0c;不要用DP转…

protobuf: 通讯录2.1

先引入需要知道的proto3语法&#xff1a; 1.proto3 1.hexdump 作用&#xff1a; hexdump&#xff1a;是Linux下的⼀个⼆进制⽂件查看⼯具&#xff0c;它可以将⼆进制⽂件转换为ASCII、⼋进制、 ⼗进制、⼗六进制格式进⾏查看。 -C: 表⽰每个字节显⽰为16进制和相应的ASCI…