信号处理--基于混合CNN和transfomer自注意力的多通道脑电信号的情绪分类的简单应用

news2024/11/19 9:32:50

目录

关于

工具

数据集

数据集简述

方法实现

数据读取

​编辑数据预处理

传统机器学习模型(逻辑回归,支持向量机,随机森林)

多层感知机模型

CNN+transfomer模型

代码获取


关于

  • 本实验利用结合了卷积神经网络 (CNN) 和 Transformer 组件的混合架构,实现基于 EEG 的有效情绪分类。
  • 尝试各种机器学习模型,包括逻辑回归、支持向量机 (SVM)、随机森林分类器和多层感知器 (MLP) 神经网络,以比较不同模型的性能。 

 图片来自于: https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9991178

工具

数据集

数据集简述

脑电图数据是从两名受试者(1 名男性、1 名女性,年龄 20-22 岁)收集的,针对特定电影剪辑引发的六种情绪状态(积极、消极、中性)中的每一种状态。该数据集包括从脑电波中收集的 324,000 个数据点,这些数据点被重新采样到 150Hz。还收集了中性脑电波数据,作为代表受试者静息情绪状态的第三类数据。从四个电极(TP9、AF7、AF8、TP10)记录 EEG 数据,并进行处理以生成通过 1 秒滑动窗口提取的统计特征数据集。

图片来源:https://www.researchgate.net/figure/This-figure-shows-the-standard-locations-for-measuring-EEG-as-per-10-20-International_fig2_358644174 

方法实现

数据读取
raw_eeg_data = pd.read_csv('../data/features_raw.csv')
raw_eeg_data.head()

# plot the F8 column
plt.figure(figsize=(20, 5))
plt.plot(raw_eeg_data['F8'])
plt.title('F8 Electrode Data')
plt.ylabel('Voltage (uV)')
plt.xlabel('Time')
plt.show()


# plot the F7 column
plt.figure(figsize=(20, 5))
plt.plot(raw_eeg_data['F7'])
plt.title('F7 Electrode Data')
plt.ylabel('Voltage (uV)')
plt.xlabel('Time')
plt.show()
数据预处理
X = eeg_emotions_data.drop(['label'], axis=1)
y = eeg_emotions_data['label']

# Encoding categorical data
from sklearn.preprocessing import LabelEncoder, OneHotEncoder

labelencoder_emotions = LabelEncoder()
y = labelencoder_emotions.fit_transform(y)


# Standardizing the features in the dataset
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()

X = scaler.fit_transform(X)

传统机器学习模型(逻辑回归,支持向量机,随机森林)
from sklearn.linear_model import LogisticRegression
import pickle

# Create a logistic regression classifier
model = LogisticRegression(random_state=2003, multi_class='multinomial', max_iter=1000)

# Train the model
model.fit(X_train, y_train)

# Evaluate the model
evaluate_model(y_test, model.predict(X_test))

from sklearn.svm import SVC

# Create a model: a support vector classifier
model = SVC(kernel='rbf', gamma='auto', C=1.0, random_state=2003)

# Train the model
model.fit(X_train, y_train)

# Evaluate the model
evaluate_model(y_test, model.predict(X_test))

from sklearn.ensemble import RandomForestClassifier

# Create a random forest Classifier.
model = RandomForestClassifier(n_estimators=100, random_state=2003)

# Train the model
model.fit(X_train, y_train)

# Evaluate the model
evaluate_model(y_test, model.predict(X_test))

在传统机器模型,我们可以发现随机森林的性能表现最好; 

多层感知机模型
class EEGClassifier(nn.Module):
    def __init__(self, input_dim, num_classes, hidden_dim=256):
        super(EEGClassifier, self).__init__()

        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x


input_dim = 2548  # Number of features in EEG signal
num_classes = 3   # Number of classes for classification
model = EEGClassifier(input_dim, num_classes)
loss = nn.CrossEntropyLoss()

CNN+transfomer模型
class EEGConformer(nn.Module):
    def __init__(self, input_dim, num_classes):
        super(EEGConformer, self).__init__()

        # CNN
        self.conv1 = nn.Conv2d(1, 40, kernel_size=(1, 25), stride=(1, 1))
        self.conv2 = nn.Conv2d(40, 40, kernel_size=(1, input_dim), stride=(1, 1))
        self.batchnorm = nn.BatchNorm2d(40)

        # Transformer
        self.layernorm1 = nn.LayerNorm(40)
        self.multiheadattention = nn.MultiheadAttention(40, 1)
        self.layernorm2 = nn.LayerNorm(40)

        self.feedworward_block = nn.Sequential(
            nn.Linear(40, 32),
            nn.GELU(),
            nn.Dropout(p=0.1),
            nn.Linear(32, 40)
        )

        # MLP
        self.fc1 = nn.Linear(40, 32)
        self.fc2 = nn.Linear(32, 32)
        self.fc3 = nn.Linear(32, num_classes)

    def forward(self, x):
        # CNN
        x = x.unsqueeze(1).unsqueeze(1)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.batchnorm(x)

        # Transformer
        x = x.squeeze()
        x = self.layernorm1(x)
        attn_out = self.multiheadattention(x, x, x)
        x = x + nn.Dropout(0.1)(attn_out[0])
        x = self.layernorm2(x)
        x = self.feedworward_block(x)
        x = nn.Dropout(p=0.1)(x)

        # MLP
        x = self.fc1(x)
        x = F.elu(x)
        x = nn.Dropout(p=0.5)(x)
        x = self.fc2(x)
        x = F.elu(x)
        x = nn.Dropout(p=0.3)(x)
        x = self.fc3(x)
        
        return x

input_dim = 2524  # Number of features in EEG signal
num_classes = 3   # Number of classes for classification
model = EEGConformer(input_dim, num_classes)
loss = nn.CrossEntropyLoss()

代码获取

后台私信,注明来意和文章名称;

其他问题,欢迎沟通交流。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1547818.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

在DasViewer里怎么查看三维模型的坐标系?

量测就可以查看坐标系了,或者查看xml文件中坐标系的代号。量测就可以查看坐标系了,或者查看xml文件中坐标系的代号。 DasViewer是由大势智慧自主研发的免费的实景三维模型浏览器,采用多细节层次模型逐步自适应加载技术,让用户在极低的电脑配置下,也能流畅…

Go语言学习Day3:数据类型、运算符与流程控制

名人说:莫愁千里路,自有到来风。 ——钱珝 创作者:Code_流苏(CSDN)(一个喜欢古诗词和编程的Coder😊) 目录 1、数据类型①布尔类型②整型③浮点型④string⑤类型转换 2、运算符①算术运算符②逻辑运算符③关…

STM32学习笔记(6_8)- TIM定时器的编码器接口代码

无人问津也好,技不如人也罢,都应静下心来,去做该做的事。 最近在学STM32,所以也开贴记录一下主要内容,省的过目即忘。视频教程为江科大(改名江协科技),网站jiangxiekeji.com 现在开…

원클릭으로 주류 전자상거래 플랫폼 상품 상세 데이터 수집 및 접속 시연 예제 (한국어판)

클릭 한 번으로 전자상거래 플랫폼 데이터를 캡처하는 것은 일반적으로 웹 페이지에서 정보를 자동으로 추출 할 수있는 네트워크 파충류 기술과 관련됩니다.그러나 모든 형태의 데이터 수집은 해당 웹 사이트의 사용 약관 및 개인 정보 보호 정책 및 현지 법률 및 규정을 준수…

百度蜘蛛池平台在线发外链-原理以及搭建教程

蜘蛛池平台是一款非常实用的SEO优化工具,它可以帮助网站管理员提高网站的排名和流量。百度蜘蛛池原理是基于百度搜索引擎的搜索算法,通过对网页的内容、结构、链接等方面进行分析和评估,从而判断网页的质量和重要性,从而对网页进行…

JAVA面试大全之基础篇

目录 1、语法基础 1.1、面向对象特性?​​​​​​​ 1.2、a a b 与 a b 的区别 1.3、3*0.1 0.3 将会返回什么? true 还是 false? 1.4、能在 Switch 中使用 String 吗? 1.5、对equals()和hashCode()的理解? 1.6、final、finalize 和 finally 的不同之…

【物联网开源平台】tingsboard二次开发

别看这篇了,这篇就当我的一个记录,我有空我再写过一篇,编译的时候出现了一个错误,然后我针对那一个错误执行了一个命令,出现了绿色的succes,我就以为整个tingsboard项目编译成功了,后面发现的时候&#xff…

怎么清理苹果电脑内存?CleanMyMac X4.15.2最新中文版使用教程

近日,我的苹果电脑似乎遭遇了一点小麻烦,每当深入工作或沉浸于娱乐之时,突如其来的一个警告弹窗就像一颗冰凉的霰弹,打断了我所有的思绪:内存不足!!!怎么清理苹果电脑内存&#xff0…

机器学习——神经网络简单了解

一、神经网络基本概念 神经网络可以分为生物神经网络和人工神经网络 (1)生物神经网络,指的是生物脑内的神经元、突触等构成的神经网络,可以使生物体产生意识,并协助生物体思考、行动和管理各机体活动。 (2)人工神经网络,是目前热门的深度学习的研究…

蔓灵花组织wmRAT分析

wmRAT分析 MD5:35639088a2406aa9e22fa8c03e989983 样本分析 多次调用sleep函数绕过沙箱检测。 创建线程获取username computername 磁盘驱动器个数 通过域名microsoft.com获取ip地址 通过c2服务器域名maxdimservice.com获取ip地址85.239.53.31 408460函数获取…

44 el-dialog 的 appendToBody 属性, 导致 vue 响应式失效

前言 我们经常会碰到 一些 模型和视图 不同步的问题 通常意义上 主要的问题为 列表的某响应式数据更新着更新着 后面就变成非响应式对象了, 然后 就造成了 数据一直在更新, 但是 视图的渲染后面就未渲染了, 这是一个由于 模型上的问题 导致的数据的不在响应式更新 又或者 是…

借力AI+视频号电商,腾讯广告业务这驾马车能跑多远?

腾讯的“功劳簿”又添上了几笔。 日前,腾讯披露了2023年四季度及全年财报。报告显示,2023年,腾讯营收6090.15亿元,同比增长10%;调整后净利润(Non-IFRS)1576.88亿元,同比增长36%。 …

在stable diffusion中手指纠错的指令和关键词是什么?

在Stable Diffusion模型中,如果您想对生成的图像中的手指进行纠错,您可以在描述中使用特定的指令和关键词来引导模型关注于手指区域并作出调整。 "Perfect hand" (完美的手) "Five fingers" (五个…

中国科学院半导体研究所汪林望:在曙光超级计算机上对第一性原理计算软件LS3DF进行1000万个硅原子模拟

编者荐语: 面对纳米材料等大体系时,电荷补丁法可以计算几千甚至上万原子, 但是电荷补丁法作为非自洽计算,不能给出原子受力,也不能用来弛豫原子坐标。面对摩尔条纹或线性位错等问题,我们需要弛豫原子的坐标…

javaSwing模拟写字板

一、摘要 目前,很多新的技术领域都涉及到了Java语言,Java语言是面向对象编程,并且涉及到网络、多线程等重要的基础知识,因此Java语言也是学习面向对象编程和网络编程的首选语言。此简易JAVA写字板程序,使用Java程序编…

Object Detection--Loss Function:从IoU到CIoU

本篇总结Loss Function中的IoU系列代码。 1. IoU 交并集,两个框交集面积除以并集面积。(论写写画画的重要性)(找原文看看) """ box1[x1, y1, x2, y2] box2[x1, y1, x2, y2] return iou ""…

Qt 作业 24/3/26

1、实现闹钟 #ifndef WIDGET_H #define WIDGET_H#include <QWidget> #include <QTime> #include <QLineEdit>QT_BEGIN_NAMESPACE namespace Ui { class Widget; } QT_END_NAMESPACEclass Widget : public QWidget {Q_OBJECTpublic:Widget(QWidget *parent …

python(django)之单一接口管理功能后台开发

1、创建数据模型 在apitest/models.py下加入以下代码 class Apis(models.Model):Product models.ForeignKey(product.Product, on_deletemodels.CASCADE, nullTrue)# 关联产品IDapiname models.CharField(接口名称, max_length100)apiurl models.CharField(接口地址, max_…

关键技术解析:CH-99除硼树脂在超纯水制备中对硼高效去除的应用实践与性能优势

超纯水(UPW)是科技和研究领域的关键资源&#xff0c;其中硼元素的去除对于保证其品质至关重要。本文将介绍一种高效的除硼技术——Tulsimer CH-99树脂&#xff0c;并阐述其在超纯水制备中的应用及优势。 首先&#xff0c;让我们了解超纯水的制备过程。超纯水是通过一系列精密的…

JAVA面试大全之集合IO篇

目录 1、集合 1.1、Collection 1.1.1、集合有哪些类&#xff1f; 1.1.2、ArrayList的底层&#xff1f; 1.1.3、ArrayList自动扩容&#xff1f; 1.1.4、ArrayList的Fail-Fast机制&#xff1f; 1.2、MAP 1.2.1、Map有哪些类&#xff1f; 1.2.2、JDK7 HashMap如何实现…