57、通过EEG数据的SHAPE变化,揭开EEG-TCNet的黑匣子[看好了小子,我只教这一次]

news2024/11/24 20:04:59

之前在第18篇博客中对于EEG-TCNet这个处理EEG信号的sota模型进行了介绍,也给出了模型,目前也是全网对于EEG-TCNet浏览度最高的文章了,我觉得讲的已经很细致了,没想到还是有不少同学疑问,这也是全网缺少该模型pytorch代码的原因,因为pytorch中没有封装TCN模块,无法直接调用,而在Tensorflow中可直接调用,废话不多少,上菜:

EEG-TCNet模型图:

原论文EEG-TCNet结构图

模型结构分析:

1、BCI IV2a数据以4维数据输入,shape=(288,1,22,1000)

2、数据先经过一个完整的EEGNet结构(时间卷积+深度卷积+深度可分离),来处理这个4维数组

3、数据从EEGNet出来,进入到TCN块之前进行降维处理(TCN只能处理1维数组)

下面我们来看2a数据以(batch_size,1,22,1000)输入到EEG-TCNet中是如何改变shape的

我自己写的EEG-TCNet代码模型-结构图:(我自己画的,别盗图)

TCN块(膨胀因果卷积)分析:

代码编写:

Chomp1d(nn.Module):裁剪类

TemporalBlock(nn.Module):TCN主体类,调用Chomp1d(),在这个类使用的卷积是Conv1d

TemporalConvNet(nn.Module):调用TemporalBlock()TCN完全体类

TCNNet(nn.Module):调用TemporalConvNet(),降维,使得TCN完全体跑的通


讲上面这4个类,我要倒着讲,费点劲要:(为啥倒着讲?同学想想 0。0 )

input_data = batch_size,1,22,1000经过一个前3个block后,此时控制台输出shape = 32,8,1,31断点如下:

数据此时还是4维的,所以我们在这使用if来判断维度,给他降维度

1、Data = torch.rand(x.shape):生成一个空的和x的维度一致的张量数据,用来存储for循环TCN块裁剪的数据

2、空的张量数据也要送到GPU中,否则报错,因为此时X的数据都在GPU上

3、在x的第二维度channel = 1,进行for循环,通过self来调用类内的tcn_block对应的TCN方法,对x数据进行裁剪并提取数据,把这些数据(此时还是4维)送给张量data

4、x = data(乾坤大魔移!


tcn_block对应着咱们定义的TemporalConvNet() 完全体这个类,如下:

类里面调用了上面定义好的Chomp1d()这个裁剪的类

此时代码跑到了Chomp1d()里面,如下所示:

TCN之前的数据= 32,8,1,31

此时数据维度 = 32,8,40,这里代码自动的去掉了通道=1的维度,并+res这个对x下采样的数据

因为这里是for i in range(x.shape[2])的循环,此时i=0,x.shape[3] = 40,我们再进入下一个循环i=1看看

此时x.shape[3] = 49,所以就这样,在送到Chomp1d()进行裁剪时,x加上了res这个下采样特征数据,导致了x的数据量增加,我们规定了Chomp1d()中的chomp_size这个数值,只保留与原始数据总量相同的前chomp_size的这个数目,来最后送给Fc层做最后结果的输出

此时我们送给Fc的shape :

又变回原来的31个数据了,这事裁剪类的功劳!但此时前后的这个31数据是不同的,多了下采样的特征,所以TCNNet这个类实现了先降维再生维的神奇操作,使得代码流通,完事。

全部代码如下:

class Chomp1d(nn.Module):
    def __init__(self, chomp_size):
        super(Chomp1d, self).__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        return x[:, :, :-self.chomp_size].contiguous()

class TemporalBlock(nn.Module):
    def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2):
        super(TemporalBlock, self).__init__()
        self.conv1 = weight_norm(nn.Conv1d(n_inputs, n_outputs, kernel_size,
                                           stride=stride, padding=padding, dilation=dilation))
        self.chomp1 = Chomp1d(padding)
       
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)

        self.conv2 = weight_norm(nn.Conv1d(n_outputs, n_outputs, kernel_size,
                                           stride=stride, padding=padding, dilation=dilation))
        self.chomp2 = Chomp1d(padding)
       
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)

        self.net = nn.Sequential(self.conv1, self.chomp1,self.relu1, self.dropout1,
                                 self.conv2, self.chomp2,self.relu2, self.dropout2)
        self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None
        self.relu = nn.ReLU()
    #   self.init_weights()

    # def init_weights(self):
    #     self.conv1.weight.data.normal_(0, 0.01)
    #     self.conv2.weight.data.normal_(0, 0.01)
    #     if self.downsample is not None:
    #         self.downsample.weight.data.normal_(0, 0.01)

    def forward(self, x):
        out = self.net(x)
        res = x if self.downsample is None else self.downsample(x)
        return self.relu(out + res)

class TemporalConvNet(nn.Module):
    def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2):
        super(TemporalConvNet, self).__init__()
        layers = []
        num_levels = len(num_channels)
        for i in range(num_levels):
            dilation_size = 2 ** i
            in_channels = num_inputs if i == 0 else num_channels[i-1]
            out_channels = num_channels[i]
            layers += [TemporalBlock(in_channels, out_channels, kernel_size, stride=1, dilation=dilation_size, padding=(kernel_size-1) * dilation_size, 
                                     dropout=dropout)]

        self.network = nn.Sequential(*layers)

    def forward(self, x):
        return self.network(x)

import numpy as np

class TCNNet(nn.Module):
     def __init__ (self,*args) -> None:
        super(TCNNet,self).__init__()
        if len(args) < 2:
            print('error')
            exit()
        else:
            num_inputs = args[0]
            num_channels = args[1]
            kernel_size = int(args[2][0])
            
        self.tcn_block =  TemporalConvNet(num_inputs,num_channels,kernel_size) 
        #self.tcn_block =  TemporalConvNet(num_inputs=self.F2,num_channels=[tcn_filters,tcn_filters],kernel_size=tcn_kernelSize) 
     def forward(self,x) :
        if len(x.shape) == 4:
            data = torch.rand(x.shape)
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 
            data = data.to(device)
 
            for i in range(x.shape[2]):
              
                data[:,:,i,:] = self.tcn_block(x[:,:,i,:])
            x = data
        else:
            x = torch.squeeze(x,dim=2) 
            x = self.tcn_block(x)
        
        return x 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1604944.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

PLC通讯革新:EtherNetIP转PROFINET网关在工业现场的应用指南

通讯革新&#xff1a;通过Profinet和Ethernet/IP网关实现PLC与PLC之间进行通讯 在工业自动化领域&#xff0c;PLC扮演着至关重要的角色。随着技术的不断进步&#xff0c;PLC通讯协议的兼容性变得越来越重要。本文将详细介绍如何通过Profinet和Ethernet/IP网关&#xff0c;将罗克…

2024年Q1季度冰箱行业线上市场销售数据分析

Q1季度冰箱线上市场表现不如预期。 根据鲸参谋数据显示&#xff0c;2024年1月至3月线上电商平台&#xff08;京东天猫淘宝&#xff09;冰箱累计销量约410万件&#xff0c;环比下降11%&#xff0c;同比下降21%&#xff1b;累计销售额约98亿元&#xff0c;环比下降31%&#xff0…

外包干了16天,技术倒退明显

先说情况&#xff0c;大专毕业&#xff0c;18年通过校招进入湖南某软件公司&#xff0c;干了接近6年的功能测试&#xff0c;今年年初&#xff0c;感觉自己不能够在这样下去了&#xff0c;长时间呆在一个舒适的环境会让一个人堕落&#xff01; 而我已经在一个企业干了四年的功能…

十:深入理解 CyclicBarrier—— 栅栏锁

目录 1、CyclicBarrier 入门1.1、概念1.2、案例 2、CyclicBarrier 源码分析2.1、类结构2.2、await() 方法 —— CyclicBarrier2.2.1、dowait() 方法 —— CyclicBarrier2.2.1.1、breakBarrier() 方法 —— CyclicBarrier2.2.1.2、nextGeneration() 方法 —— CyclicBarrier 3、…

“400G网络:QSFP-DD的登场,谁主沉浮?”

&#x1f31f;QSFP-DD 作为400G 光模块的最小外形尺寸&#xff0c;提供业界最高的带宽密度&#xff0c;同时利用对低速 QSFP 可插拔模块和电缆的向后兼容性&#xff0c;使其在光纤制造商中很受欢迎。作为400G高速应用中最新的热门光收发器&#xff0c;QSFP-DD经常被拿来与QSFP5…

九州金榜|家庭教育中如何疏导孩子抑郁情绪?

在家庭教育的过程中&#xff0c;孩子抑郁情绪的疏导是一项至关重要的任务。抑郁情绪不仅会影响孩子的心理健康&#xff0c;还可能对其学习、生活和人际关系产生负面影响。因此&#xff0c;家长需要积极关注孩子的情绪变化&#xff0c;采取有效的措施来疏导孩子的抑郁情绪。下面…

【位运算】Leetcode 只出现一次的数字 ||

题目解析 137. 只出现一次的数字 II 算法讲解 nums中要么一个数字出现三次&#xff0c;一个数字出现一次&#xff0c;按照比特位来说只可能出现上面的四种情况&#xff1a; 3n个0 0 或者 3n个0 1 或者 3n个1 0 或者 3n个1 1&#xff0c;它们相加的结果依次是0&#xff0c;…

虚拟机数据恢复—KVM虚拟机磁盘文件数据恢复案例

虚拟化数据恢复环境&故障&#xff1a; KVM是Kernel-based Virtual Machine的简称&#xff0c;是一个开源的系统虚拟化模块&#xff0c;自Linux2.6.20版本之后集成在Linux的各个主要发行版本中。KVM使用Linux自身的调度器进行管理。 本案例中的服务器操作系统为Linux&#x…

LInux下C语言模拟实现 —— 极简版的命令行解释器

根据对进程的理解&#xff0c;我们知道然后去使用系统接口去调用程序和加载程序&#xff0c;因此我们可以利用接口去实现一个简易版的命令行解释器&#xff0c;核心思路就是获取用户输入的指令信息&#xff0c;然后利用指令信息去调用相关的接口&#xff0c;因此首先就是要如何…

Linux安装mysql 8.0

1.使用root登录服务器 2.创建安装包存放目录 # mkdir /software # cd /software3.下载并解压mysql安装包 # wget https://dev.mysql.com/get/Downloads/MySQL-8.0/mysql-8.0.21-linux-glibc2.12-x86_64.tar.xz # tar xvJf mysql-8.0.21-linux-glibc2.12-x86_64.tar.xz # mv m…

Redis的特性与安装

回顾 Redis是一个在内存中存储数据的中间件&#xff0c;可以用来当数据库用&#xff0c;也可以作为缓存用(这里的缓存往往是对数据库缓存)。 中间件&#xff1a;和业务无关的服务&#xff0c;功能更加通用&#xff0c;如&#xff1a;数据库&#xff0c;缓存&#xff0c;消息队…

基于springboot实现音乐网站管理系统项目【项目源码+论文说明】计算机毕业设计

基于SpringBoot实现音乐网站管理系统演示 摘要 随着信息技术在管理上越来越深入而广泛的应用&#xff0c;管理信息系统的实施在技术上已逐步成熟。本文介绍了音乐网站的开发全过程。通过分析音乐网站管理的不足&#xff0c;创建了一个计算机管理音乐网站的方案。文章介绍了音乐…

LeetCode-热题100:230. 二叉搜索树中第K小的元素

题目描述 给定一个二叉搜索树的根节点 root &#xff0c;和一个整数 k &#xff0c;请你设计一个算法查找其中第 k 个最小元素&#xff08;从 1 开始计数&#xff09;。 示例 1&#xff1a; 输入&#xff1a; root [3,1,4,null,2], k 1 输出&#xff1a; 1 示例 2&#…

算法课程笔记——List

缺点&#xff1a;不能用下标计算得到 只能 一步步来 这样才是赋值 只是得到拷贝的结果 很多容器都需要&#xff08;int&#xff09;强制转化 list可以用sort 但是 例如&#xff0c;sort(L2.begin(), L2.end());&#xff0c;这种是algorithm标准算法类提供&#xff0c;属于…

钡铼IOy系列模块在智能装备制造中发挥重要作用提升整体效能

随着科技的不断发展&#xff0c;智能装备制造已经成为推动工业进步的重要力量之一。在智能装备制造领域&#xff0c;钡铼IOy系列模块在智能装备制造中起关键作用&#xff0c;对生产效率、产品质量和工厂管理也有一定的影响。 首先&#xff0c;钡铼IOy系列模块在智能装备制造中…

阿里云服务器多少钱一年?2024年阿里云服务器租用费用一览

阿里云服务器租用价格表2024年最新&#xff0c;云服务器ECS经济型e实例2核2G、3M固定带宽99元一年&#xff0c;轻量应用服务器2核2G3M带宽轻量服务器一年61元&#xff0c;ECS u1服务器2核4G5M固定带宽199元一年&#xff0c;2核4G4M带宽轻量服务器一年165元12个月&#xff0c;2核…

po+selenium+unittest自动化测试项目实战

&#x1f525; 交流讨论&#xff1a;欢迎加入我们一起学习&#xff01; &#x1f525; 资源分享&#xff1a;耗时200小时精选的「软件测试」资料包 &#x1f525; 教程推荐&#xff1a;火遍全网的《软件测试》教程 &#x1f4e2;欢迎点赞 &#x1f44d; 收藏 ⭐留言 &#x1…

电脑缺失api-ms-win-core-path-l1-1-0.dll的5种解决方法

在计算机使用过程中&#xff0c;我们经常会遇到一些错误提示&#xff0c;其中之一就是"api-ms-win-core-path-l1-1-0.dll丢失"。这个问题可能会导致某些软件无法正常运行或系统功能受限。那么&#xff0c;如何解决这个问题呢&#xff1f;下面将详细介绍api-ms-win-co…

【Android Studio报错】:* What went wrong:Out of memory. Java heap space

项目场景&#xff1a; 今天&#xff0c;刚打开自己的安卓项目发现报错&#xff1a; 报错&#xff1a; * What went wrong: Out of memory. Java heap space Possible solution: - Check the JVM memory arguments defined for the gradle process in: gradle.properties in…

windows C++fmt库下载

下载地址 https://github.com/fmtlib/fmt vs2019 debug x64进行编译 安装包如下 https://download.csdn.net/download/qq_36314864/89163873