TSception:从EEG中捕获时间动态和空间不对称性用于情绪识别

news2024/12/28 5:20:15

TSception:从EEG中捕获时间动态和空间不对称性用于情绪识别(论文复现)

  • 摘要
  • 模型结构
  • 代码实现
  • 写在最后

**这是一篇代码复现,原文通过Pytorch实现,本文中使用Keras对该结构进行复现。**该论文发表在IEEE Transactions on Affective Computing,第一作者Yi Ding

摘要

高时间分辨率和不对称的空间激活是脑内情绪过程的基本特征。为了学习EEG的时间动态性和空间不对称性,以实现准确和广义的情感识别,Yi Ding等人提出了一种多尺度卷积神经网络TSception,可以从EEG中对情感进行分类。Tsception由动态时间、非对称空间和高级融合层组成,同时学习时间和信道维度的区分表示。动态时域层由多尺度1D卷积核组成,其长度与EEG的采样率相关,其学习EEG的动态时间和频率表示。非对称空间层利用情绪的非对称EEG模式,学习有区别的全局和半球表示。原文代码可在以下网址获得:https://github.com/yi-ding-cs/TSception,本文复现完整代码可在下面获得:https://github.com/ruix6/tsception

模型结构

关于模型结构的相关公式推理可以参考原文,本文不详细展开,下图是模型的具体结构:
在这里插入图片描述
熟悉经典深度学习模型的同学应该能一眼看出来TSception的设计灵感来自Inception模型。结合脑电信号的特点,TSception分四步实现对脑电信号的计算。

  1. 多尺度时域卷积:第一步通过多个尺度的时域卷积核实现对EEG信号的分解与特征提取。多尺度的优势在于可以给模型提供多个不同的感受野,这对于脑电信号这种多源的复杂信号来说是十分合理的。作为对比,我们可以看一下EEGNet的结构,如下图,EEGNet的第一个部分也是一个一维时域卷积结构,但是由于单一的感受野,在很多任务中它很容易被低频部分的噪声所干扰,所以EEGNet在ERP或者SMR这种有效信息分布在低频段的任务比较友好,但是像情绪识别的话,该网络的原始结构似乎不能发挥其全部能量(改变其时域卷积核大小似乎能有效提升其能力)。在这里插入图片描述
  2. 不对称空间卷积层:模型的第二部分由两个尺度的卷积实现。大尺度的卷积层覆盖所有通道,小尺度的卷积层分可以分别卷积大脑左半球的通道和右半球的通道。前面提到过,不对称的空间激活(即受试者在不同的情绪状态下,大脑的左右半球的激活状态是不一样的,原文中分析了受试者的大脑激活状态,想进一步了解可以看看原文)是情绪的重要特征,所以通过小尺度的空间卷积可以进一步的抓住这些特征。
    在这里插入图片描述
  3. 高级融合层:该层为了进一步的融合输入的时空特征而设计,这个和EEGNet的可分离卷积层的效果是一样的,我更愿意称之为为了减小参数量而设计的🙂,这个卷积层的出现,其实使得模型的可解释性进一步降低。
  4. 分类:把高级融合层的输出做个全局平均池化然后全连接,最后输出。

代码实现

代码如下:

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, AveragePooling2D, Flatten, Dense, Dropout, BatchNormalization, concatenate, LeakyReLU


# 定义时域卷积块
def conv_block(input, out_chan, kernel, step, pool):
    x = Conv2D(out_chan, kernel, strides=step)(input)# padding='same', use_bias=False
    x = LeakyReLU()(x)
    x = AveragePooling2D(pool_size=(1, pool), strides=(1, pool))(x)

    return x


def Tsception(num_classes, Chans, Samples, sampling_rate, num_T, num_S, hidden, dropout_rate, pool=8):

    '''
    input_size: 输入数据的维度,(chans, samples, 1)
    '''
    inception_window = [0.5, 0.25, 0.125]
    # 定义输入层
    input = Input(shape=(Chans, Samples, 1))
    # 定义时域卷积层
    x1 = conv_block(input, num_T, (1, int(sampling_rate * inception_window[0])), 1, pool)
    x2 = conv_block(input, num_T, (1, int(sampling_rate * inception_window[1])), 1, pool)
    x3 = conv_block(input, num_T, (1, int(sampling_rate * inception_window[2])), 1, pool)
    # 在height维度上进行拼接
    x = concatenate([x1, x2, x3], axis=2)
    x = BatchNormalization()(x)
    # 定义空域卷积层
    y1 = conv_block(x, num_S, (Chans, 1), (Chans, 1), int(pool*0.25))
    y2 = conv_block(x, num_S, (int(Chans*0.5), 1), (int(Chans*0.5), 1), int(pool*0.25))
    # 在width维度上进行拼接
    y = concatenate([y1, y2], axis=1)
    y = BatchNormalization()(y)
    # 定义fusion_layer
    z = conv_block(y, num_S, (3, 1), (3, 1), 4)
    z = BatchNormalization()(z)
    # 定义全局平均池化层
    z = AveragePooling2D(pool_size=(1, z.shape[2]))(z)
    z = Flatten()(z)
    # 全连接层
    z = Dense(hidden, activation='relu')(z)# , use_bias=False
    z = Dropout(dropout_rate)(z)
    z = Dense(num_classes, activation='softmax')(z)# , use_bias=False

    return Model(inputs=input, outputs=z)

参照原文的各个超参数,引用方式为:

if __name__ == '__main__': 
    model = Tsception(num_classes=2, Chans=28, Samples=512, sampling_rate=128, num_T=15, num_S=15, hidden=32, dropout_rate=0.5)
    model.summary()

最后的输出为:

__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to
==================================================================================================
 input_1 (InputLayer)           [(None, 28, 512, 1)  0           []
                                ]

 conv2d (Conv2D)                (None, 28, 449, 15)  975         ['input_1[0][0]']

 conv2d_1 (Conv2D)              (None, 28, 481, 15)  495         ['input_1[0][0]']

 conv2d_2 (Conv2D)              (None, 28, 497, 15)  255         ['input_1[0][0]']

 leaky_re_lu (LeakyReLU)        (None, 28, 449, 15)  0           ['conv2d[0][0]']

 leaky_re_lu_1 (LeakyReLU)      (None, 28, 481, 15)  0           ['conv2d_1[0][0]']

 leaky_re_lu_2 (LeakyReLU)      (None, 28, 497, 15)  0           ['conv2d_2[0][0]']

 average_pooling2d (AveragePool  (None, 28, 56, 15)  0           ['leaky_re_lu[0][0]']
 ing2D)

 average_pooling2d_1 (AveragePo  (None, 28, 60, 15)  0           ['leaky_re_lu_1[0][0]']
 oling2D)

 average_pooling2d_2 (AveragePo  (None, 28, 62, 15)  0           ['leaky_re_lu_2[0][0]']
 oling2D)

 concatenate (Concatenate)      (None, 28, 178, 15)  0           ['average_pooling2d[0][0]',
                                                                  'average_pooling2d_1[0][0]',
                                                                  'average_pooling2d_2[0][0]']

 batch_normalization (BatchNorm  (None, 28, 178, 15)  60         ['concatenate[0][0]']
 alization)

 conv2d_3 (Conv2D)              (None, 1, 178, 15)   6315        ['batch_normalization[0][0]']

 conv2d_4 (Conv2D)              (None, 2, 178, 15)   3165        ['batch_normalization[0][0]']

 leaky_re_lu_3 (LeakyReLU)      (None, 1, 178, 15)   0           ['conv2d_3[0][0]']

 leaky_re_lu_4 (LeakyReLU)      (None, 2, 178, 15)   0           ['conv2d_4[0][0]']

 average_pooling2d_3 (AveragePo  (None, 1, 89, 15)   0           ['leaky_re_lu_3[0][0]']
 oling2D)

 average_pooling2d_4 (AveragePo  (None, 2, 89, 15)   0           ['leaky_re_lu_4[0][0]']
 oling2D)

 concatenate_1 (Concatenate)    (None, 3, 89, 15)    0           ['average_pooling2d_3[0][0]',
                                                                  'average_pooling2d_4[0][0]']

 batch_normalization_1 (BatchNo  (None, 3, 89, 15)   60          ['concatenate_1[0][0]']
 rmalization)

 conv2d_5 (Conv2D)              (None, 1, 89, 15)    690         ['batch_normalization_1[0][0]']

 leaky_re_lu_5 (LeakyReLU)      (None, 1, 89, 15)    0           ['conv2d_5[0][0]']

 average_pooling2d_5 (AveragePo  (None, 1, 22, 15)   0           ['leaky_re_lu_5[0][0]']
 oling2D)

 batch_normalization_2 (BatchNo  (None, 1, 22, 15)   60          ['average_pooling2d_5[0][0]']
 rmalization)

 average_pooling2d_6 (AveragePo  (None, 1, 1, 15)    0           ['batch_normalization_2[0][0]']
 oling2D)

 flatten (Flatten)              (None, 15)           0           ['average_pooling2d_6[0][0]']

 dense (Dense)                  (None, 32)           512         ['flatten[0][0]']

 dropout (Dropout)              (None, 32)           0           ['dense[0][0]']

 dense_1 (Dense)                (None, 2)            66          ['dropout[0][0]']

==================================================================================================
Total params: 12,653
Trainable params: 12,563
Non-trainable params: 90
__________________________________________________________________________________________________

原文结果,测试数据集是DEAP情绪数据集:
在这里插入图片描述

写在最后

原文的作者并没有对模型进行更加具体的调参,事实上,输出的全连接层的神经元设置为32应该是意义不大的,效果可能不如直接链接到输出层,并且如果要考虑进一步缩小参数量的话,各个卷积层的偏置权重其实可以去除。在原文中,模型的效果表现与其它的卷积神经网络并没有统计学意义上的差别,但是,如果能够进一步调参的话,效果实际上要好很多。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/661980.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Python 随练】古典问题:兔子繁殖问题

题目: 古典问题:有一对兔子,从出生后第 3 个月起每个月都生一对兔子,小兔子长到第三个月,后每个月又生一对兔子,假如兔子都不死,问每个月的兔子总数为多少? 简介: 在本…

三、DSMP/OLS等夜间灯光数据贫困地区识别——MPI和灯光指数拟合、误差分析

一、前言 当我们准备好MPI和灯光指数(包括总灯光指数和平均灯光指数)之后,接下来主要的过程就是通过将MPI和灯光指数拟合,构建多维度指数估算模型,这里我解释一下前文中的MPI计算过程,其实利用熵值法确定指标权重,并通过各 指 标 归 一 化 数 值 乘 以 对 应 的 权 重 …

UG NX二次开发(C#)-用UFun函数导出图像(Image)

文章目录 1、前言2、在UG NX中交互导出图像的操作2.1 打开一个三维模型2.2 打开导出图像的界面3、采用UFun函数来实现3.1 搜索image的方法3.2 帮助说明3.3 应用环境3.4 方法应用4、后记1、前言 在UG NX二次开发过程中,三维CAD模型有时需要导出为图像,如.png、.jpg、.bmp、.t…

类与封装的概念

类通常分为以下两个部分 类的实现细节 类的使用方式 当使用类时,不需要关心其实现细节 当创建类时,才需要考虑其内部实现细节 封装的基本概念 根据经验:并不是类的每个属性都是对外公开的 如:女孩子不希望外人知道自己的体重…

【系统开发】尚硅谷 - 谷粒商城项目笔记(二):搭建分布式系统基本环境

文章目录 搭建分布式系统基本环境引入spring-cloud-alibaba依赖Nacos作为注册中心Feign 远程调用Nacos作为配置中心Nacos配置中心进阶Nacos加载多配置集GateWay网关网关路由分发解释 搭建分布式系统基本环境 引入spring-cloud-alibaba依赖 在common的pom.xml中加入 &#xff…

Socket网络通信过程 与 IO多路复用原理

0、引言 本文主要讲述Socket网络编程的基本知识、IO多路复用的select、poll、epoll实现原理以及比较,并解答了一些socket建立连接、阻塞的常见问题。 1、什么是Socket、网络通信的过程 Socket 的中文名叫作插口,事实上,双方要进行网络通信前…

HTML(一)

一.HTML的标准结构 <!doctype html> 声明文档类型<html> HTML根标签<head> 头标签<title></title> 标题标签</head><body> 主题标签...</body></html> 二.标签介绍 2.1 段落标签 1.注释标签 <!--我是一个注释--…

送外卖适合什么蓝牙耳机,推荐几款适合户外佩戴的骨传导耳机

骨传导耳机&#xff0c;是通过震动的方式将声音转化为不同频率的机械振动&#xff0c;由于不需要通过耳膜就可以听到声音&#xff0c;骨传导耳机在保留传统耳机的优点的基础上&#xff0c;解决了传统耳机不能在开放环境中使用的问题。那么在骨传导耳机中&#xff0c;究竟有哪些…

MobaXterm 常用设置

MobaXterm 是用于远程计算的工具箱&#xff0c;作为一个 Windows 应用程序&#xff0c;它为程序员、网站管理员、IT管理员和几乎所有需要以更简单的方式处理远程工作的用户量身定制了大量功能。MobaXterm 提供了所有重要的远程网络工具(SSH, X11, RDP, VNC, FTP, MOSH&#xff…

消息队列常见问题整理

前言 消息队列&#xff08;Message Queue&#xff09;&#xff0c;从广义上讲是一种消息队列服务中间件&#xff0c;提供一套完整的信息生产、传递、消费的软件系统。 消息队列所涵盖的功能远不止于队列&#xff08;Queue&#xff09;&#xff0c;其本质是两个进程传递信息的…

Java Web程序设计的学习

属于B/S结构、服务器软件&#xff1a;Apache Tomcat、 Web 项目 目录结构&#xff1a; 1.src目录&#xff1a;存放Java源文件 2.WebRoot目录&#xff1a; 存在两个子目录&#xff1a; META-INF目录 WEB-INF目录&#xff1a;&#xff08;lib目录&#xff1a;存放驱动…

Notepad++安装json插件

Notepad是Windows操作系统下的一套文本编辑器(软件版权许可证:GPL)&#xff0c;有完整的中文化接口及支持多国语言编写的功能(UTF8技术)。 Notepad功能比 Windows 中的Notepad(记事本)强大&#xff0c;除了可以用来制作一般的纯文字说明文件&#xff0c;也十分适合编写计算机程…

MySQL数据表:对数据的基础操作(增、删、查、改)以及运算符的讲解

目录 前言 一.增加数据 二.查询数据 2.1查询数据表中所有信息 2.2查询表中指定的列信息 2.3查询通过计算的列 2.4使用别名代替列名 2.5查询不带有重复值的列 2.6将查询的结果进行排序 2.7条件查询 2.7.1条件查询的种类 2.7.2使用运算符查询的讲解 2.8分页查询 …

2015年全国硕士研究生入学统一考试管理类专业学位联考写作试题

2015年1月真题&#xff1a; 四、写作&#xff1a;第56~57小题&#xff0c;共65 分。其中论证有效性分析30 分&#xff0c;论说文35 分。 56、论证有效性分析&#xff1a; 分析下述论证存在的缺陷和漏洞&#xff0c;选择若干要点&#xff0c;写一篇600 字的文章&#xff0c;对…

MyCat2介绍以及部署和读写分离/分库分表(MyCat2.0)

一&#xff0c;MyCat入门 1.什么是mycat 官网&#xff1a;http://www.mycat.org.cn/​ mycat是数据库中间件 它可以干什么&#xff1f; 读写分离数据分片&#xff1a;垂直拆分&#xff0c;水平拆分多数据源整合 2.数据库中间件 ​ 中间件&#xff1a;是一类连接软件组件和…

KSM01.2B-061C-35N-M1-HP0-SE-NN伺服电机力士乐

​ KSM01.2B-061C-35N-M1-HP0-SE-NN伺服电机力士乐 KSM01.2B-061C-35N-M1-HP0-SE-NN伺服电机力士乐 从应用对象的规模上来说&#xff1a; PLC一般应用在小型自控场所&#xff0c;比如设备的控制或少量的模拟量的控制及联锁&#xff0c;而大型的应用一般都是DCS。当然&#x…

STM32开发——DMA(数据搬运)

目录 1.DMA简介 2.从内存到内存搬运 2.1CubeMX设置 2.2函数代码 3.内存到外设 3.1CubeMX配置 3.2 函数代码 4.外设到内存 4.1CubeMX配置 4.1函数代码 1.DMA简介 DMA(Direct Memory Access&#xff0c;直接存储器访问) 提供在外设与内存、存储器和存储器、外设 与外设…

APM二次开发(二):添加一个任务

固件版本 APM copter 4.3.1 参考&#xff1a;https://ardupilot.org/dev/docs/code-overview-scheduling-your-new-code-to-run-intermittently.html APM添加任务比PX4要简单很多&#xff0c;直接在调度器里添加函数即可。 先定义一个要调度的函数my_test() 然后加到调度器中…

C++ [STL容器反向迭代器]

本文已收录至《C语言和高级数据结构》专栏&#xff01; 作者&#xff1a;ARMCSKGT STL容器反向迭代器 前言正文适配器反向迭代器反向迭代器框架默认成员函数反向迭代器的遍历反向迭代器的比较反向迭代器数据访问反向迭代器代码测试反向迭代器 最后 前言 我们知道STL大部分容器…

(2023最新版)互联网大厂1120道Java面试真题附答案详解

很多 Java 工程师的技术不错&#xff0c;但是一面试就头疼&#xff0c;10 次面试 9 次都是被刷&#xff0c;过的那次还是去了家不知名的小公司。 问题就在于&#xff1a;面试有技巧&#xff0c;而你不会把自己的能力表达给面试官。 应届生&#xff1a;你该如何准备简历&#…