基于MNE的EEGNet 神经网络的脑电信号分类实战(附完整源码)

news2024/12/19 17:00:34

利用MNE中的EEG数据,进行EEGNet神经网络的脑电信号分类实现:

代码:

代码主要包括一下几个步骤:
1)从MNE中加载脑电信号,并进行相应的预处理操作,得到训练集、验证集以及测试集,每个集中都包括数据和标签;
2)基于tensorflow构建EEGNet网络模型;
3)编译模型,配置损失函数、优化器和评估指标等,并进行模型训练和预测;
4)绘制训练集和验证集的损失曲线以及训练集和验证集的准确度曲线。
代码如下:

import mne
import os
from pathlib import Path
import numpy as np
from keras.src.utils import np_utils

from mne import io
from mne.datasets import sample
import matplotlib.pyplot as plt
import pathlib

from keras.models import Model
from keras.layers import Dense, Activation, Permute, Dropout
from keras.layers import Conv2D, MaxPooling2D, AveragePooling2D
from keras.layers import SeparableConv2D, DepthwiseConv2D
from keras.layers import BatchNormalization
from keras.layers import SpatialDropout2D
from keras.regularizers import l1_l2
from keras.layers import Input, Flatten
from keras.constraints import max_norm
from keras import backend as K
from keras.src.callbacks import ModelCheckpoint

from sklearn.pipeline import make_pipeline
from sklearn.linear_model import LogisticRegression


def EEGNet(nb_classes, Chans=64, Samples=128,
           dropoutRate=0.5, kernelLength=64,
           F1=8, D=2, F2=16, norm_rate=0.25,
           dropout_type='Dropout'):
    """
    EEGNet模型的实现。

    参数:
    - nb_classes: int, 输出类别的数量。
    - Chans: int, 通道数,默认为64。
    - Samples: int, 每个通道的样本数,默认为128。
    - dropoutRate: float, Dropout率,默认为0.5。
    - kernelLength: int, 卷积核的长度,默认为64。
    - F1: int, 第一个卷积层的滤波器数量,默认为8。
    - D: int, 深度乘法器,默认为2。
    - F2: int, 第二个卷积层的滤波器数量,默认为16。
    - norm_rate: float, 权重范数约束,默认为0.25。
    - dropout_type: str, Dropout类型,默认为'Dropout'。

    返回:
    - Model: Keras模型对象。
    """

    # 根据dropout_type参数确定使用哪种Dropout方式
    if dropout_type == 'SpatialDropout2D':
        dropoutType = SpatialDropout2D
    elif dropout_type == 'Dropout':
        dropoutType = Dropout
    else:
        raise ValueError('dropout_type must be one of SpatialDropout2D '
                         'or Dropout, passed as a string.')

    # 定义模型的输入层
    input1 = Input(shape=(Chans, Samples, 1))

    # 第一个卷积块
    block1 = Conv2D(F1, (1, kernelLength), padding='same',
                       input_shape=(Chans, Samples, 1),
                       use_bias=False)(input1)
    block1 = BatchNormalization()(block1)
    block1 = DepthwiseConv2D((Chans, 1), use_bias=False,
                               depth_multiplier=D,
                               depthwise_constraint=max_norm(1.))(block1)
    block1 = BatchNormalization()(block1)
    block1 = Activation('elu')(block1)
    block1 = AveragePooling2D((1, 4))(block1)
    block1 = dropoutType(dropoutRate)(block1)

    # 第二个卷积块
    block2 = SeparableConv2D(F2, (1, 16),
                               use_bias=False, padding='same')(block1)
    block2 = BatchNormalization()(block2)
    block2 = Activation('elu')(block2)
    block2 = AveragePooling2D((1, 8))(block2)
    block2 = dropoutType(dropoutRate)(block2)

    # 将卷积块的输出展平以便输入到全连接层
    flatten = Flatten(name='flatten')(block2)

    # 定义全连接层
    dense = Dense(nb_classes, name='dense', kernel_constraint=max_norm(norm_rate))(flatten)
    softmax = Activation('softmax', name='softmax')(dense)

    # 创建并返回模型
    return Model(inputs=input1, outputs=softmax)


def get_data4EEGNet(kernels, chans, samples):
    """
    为EEGNet模型准备数据。

    该函数从指定的文件路径中读取原始EEG数据和事件数据,进行预处理,
    包括滤波、选择通道、分割数据集,并将数据集按给定的通道、核数和样本数进行重塑。

    参数:
    kernels - 数据集中的核数量。
    chans - 数据集中的通道数量。
    samples - 数据集中的样本数量。

    返回:
    X_train, X_validate, X_test, y_train, y_validate, y_test - 分别是训练、验证和测试数据集,
    以及相应的标签。
    """
    # 设置图像数据格式,确保数据维度顺序正确
    K.set_image_data_format('channels_last')

    # 定义数据路径
    data_path = Path("C:\\Users\\72671\\mne_data\\MNE-sample-data")

    # 定义原始数据和事件数据的文件路径
    raw_fname = os.path.join(data_path, "MEG", "sample", "sample_audvis_filt-0-40_raw.fif")
    event_fname = os.path.join(data_path, "MEG", "sample", "sample_audvis_filt-0-40_raw-eve.fif")

    # 定义时间范围和事件ID
    tmin, tmax = -0., 1
    event_id = dict(aud_l=1, aud_r=2, vis_l=3, vis_r=4)

    # 读取并预处理原始数据
    raw = io.Raw(raw_fname, preload=True, verbose=False)
    raw.filter(2, None, method='iir')
    events = mne.read_events(event_fname)

    # 设置无效通道并选择所需通道类型
    raw.info['bads'] = ['MEG 2443']
    picks = mne.pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False,
                           exclude='bads')

    # 创建epochs数据集
    epochs = mne.Epochs(raw, events, event_id, tmin, tmax, proj=False, picks=picks, baseline=None,
                        preload=True, verbose=False)
    labels = epochs.events[:, -1]

    # 获取数据并进行缩放
    X = epochs.get_data(copy=False) * 1e6
    y = labels

    # 分割数据集为训练、验证和测试集
    X_train = X[0:144, ]
    y_train = y[0:144]
    X_validate = X[144:216, ]
    y_validate = y[144:216]
    X_test = X[216:, ]
    y_test = y[216:]

    # 将训练、验证和测试数据集中的标签转换为one-hot编码
    # 减1是因为标签通常从1开始计数,而one-hot编码需要从0开始
    y_train = np_utils.to_categorical(y_train-1)
    y_validate = np_utils.to_categorical(y_validate-1)
    y_test = np_utils.to_categorical(y_test-1)


    # 重塑数据集以匹配EEGNet模型的输入要求
    X_train = X_train.reshape(X_train.shape[0], chans, samples, kernels)
    X_validate = X_validate.reshape(X_validate.shape[0], chans, samples, kernels)
    X_test = X_test.reshape(X_test.shape[0], chans, samples, kernels)

    # 返回准备好的数据集
    return X_train, X_validate, X_test, y_train, y_validate, y_test


#########################################################################
# 定义模型参数
kernels, chans, samples = 1, 60, 151
# 获取预处理后的EEG数据集
X_train, X_validate, X_test, y_train, y_validate, y_test = get_data4EEGNet(kernels, chans, samples)

# 初始化EEGNet模型
model = EEGNet(nb_classes=4, Chans=chans, Samples=samples, dropoutRate=0.5,
               kernelLength=32, F1=8, D=2, F2=16, dropout_type='Dropout')

# 编译模型,配置损失函数、优化器和评估指标
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

# 设置模型检查点以保存最佳模型
checkpointer = ModelCheckpoint(filepath='./models/EEGNet_best_model.h5', verbose=1, save_best_only=True)
# 定义类别权重
class_weights = {0: 1, 1: 1, 2: 1, 3: 1}

# 训练模型
fittedModel = model.fit(X_train, y_train, batch_size=32, epochs=500, verbose=2,
                        validation_data=(X_validate, y_validate),
                        callbacks=[checkpointer], class_weight=class_weights)

# 加载最佳模型权重
model.load_weights('./models/EEGNet_best_model.h5')

# 对测试集进行预测
probs = model.predict(X_test)
# 获取预测标签
preds = probs.argmax(axis=-1)
# 计算分类准确率
acc = np.mean(preds == y_test.argmax(axis=-1))

# 输出分类准确率
print("Classification accuracy: %f " % (acc))


# 获取训练历史
history = fittedModel.history

# 绘制训练集和验证集的损失曲线
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(history['loss'], label='Training Loss')
plt.plot(history['val_loss'], label='Validation Loss')
plt.title('Loss Curves')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

# 绘制训练集和验证集的准确度曲线
plt.subplot(1, 2, 2)
plt.plot(history['accuracy'], label='Training Accuracy')
plt.plot(history['val_accuracy'], label='Validation Accuracy')
plt.title('Accuracy Curves')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()

plt.tight_layout()
plt.show()

效果如下:

在这里插入图片描述

参考资料:

论文链接: EEGNet: a compact convolutional neural network for EEG-based brain–computer interfaces(Journal of Neural Engineering,SCI JCR2,Impact Factor:4.141)
Github链接: the Army Research Laboratory (ARL) EEGModels project

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2262268.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Element@2.15.14-tree checkStrictly 状态实现父项联动子项,实现节点自定义编辑、新增、删除功能

背景:现在有一个新需求,需要借助树结构来实现词库的分类管理,树的节点是不同的分类,不同的分类可以有自己的词库,所以父子节点是互不影响的;同样为了选择的方便性,提出了新需求,选择…

SAP-ABAP开发学习-面向对象开发ooalv(2)

SAP-ABAP开发学习-面向对象OOALV(1)-CSDN博客 本文目录 一、类的继承 多态性类继承的实现 二、抽象类 三、最终类 四、接口 五、定义全局对象 一、类的继承 继承的本质是代码重用。当我们要构造一个新类时,无需从零开始,可…

典型案例 | 旧PC新蜕变!东北师范大学依托麒麟信安云“旧物焕新生”

东北师范大学始建于1946年,坐落于吉林省长春市,是中国共产党在东北地区创建的第一所综合性大学。作为国家“双一流”建设高校,学校高度重视教学改革和科技创新,校园信息化建设工作始终走在前列。基于麒麟信安云,东北师…

Linux脚本语言学习--上

1.shell概述 1.1 shell是什么? Shell是一个命令行解释器,他为用户提供了一个向Linux内核发送请求以便运行程序的界面系统级程序,用户可以使用Shell来启动,挂起,停止甚至是编写一些程序。 Shell还是一个功能相当强大…

2024年底-Sre面试问题总结-持续更新

这几个缩写 贴一下是因为真的会有人问:( SRE “Site Reliability Engineer” 站点可靠性工程师 SLA “Service Level Agreement” 服务可用性协议 CICD “Continuos Integration Continous Deployment” 持续集成 持续部署 3个高频问题 K8s生产环境中处理过哪些复杂 or 印象…

【硬件接口】I2C总线接口

本文章是笔者整理的备忘笔记。希望在帮助自己温习避免遗忘的同时,也能帮助其他需要参考的朋友。如有谬误,欢迎大家进行指正。 一、概述 I2C总线是一种非常常用的总线,其多用于一个主机(或多个)与单个或多个从设备通讯…

OkHttp源码分析:分发器任务调配,拦截器责任链设计,连接池socket复用

目录 一,分发器和拦截器 二,分发器处理异步请求 1.分发器处理入口 2.分发器工作流程 3.分发器中的线程池设计 三,分发器处理同步请求 四,拦截器处理请求 1.责任链设计模式 2.拦截器工作原理 3.OkHttp五大拦截器 一&#…

SAP:如何修改已释放的请求

SAP:如何修改已释放的请求 QQ出了一个新功能,把10年前的旧日志推给自己。这个10年前的日志,是用户反映在SE10中把请求释放后发现漏了内容,想修改已释放的请求。经调查写了一个小程序,实现用户的需求。 *&-------------------…

python怎么循环嵌套

嵌套循环: 概念:循环中再定义循环,称为嵌套循环; 【注意】嵌套循环可能有多层,但是一般我们实际开发最多两层就可以搞定了(99%的情况) 格式: 1、while中套while常用 2、while中套for in 3、for in中套…

前端优雅(装逼)写法(updating····)

1.>>右位移运算符取整数 它将一个数字的二进制位向右移动指定的位数,并在左侧填充符号位(即负数用1填充,正数用0填充)。 比如 2.99934 >> 0:取整结果是2,此处取整并非四舍五入 2.99934 会先…

MySQL -- 库的相关操作

目录 查看数据库 创建数据库 直接创建: 加约束条件 if not exists 字符集和校对规则 什么是字符集 什么是校对规则 校对规则的主要功能 校对规则的特性 查看指定的数据库使用的字符集和校对规则: 比较是否区分大小写字母差异 显示创建语句 …

Moretl开箱即用日志采集

永久免费: 至Gitee下载 使用教程: Moretl使用说明 使用咨询: 用途 定时全量或增量采集工控机,电脑文件或日志. 优势 开箱即用: 解压直接运行.不需额外下载.管理设备: 后台统一管理客户端.无人值守: 客户端自启动,自更新.稳定安全: 架构简单,兼容性好,通过授权控制访问. 架…

分享一次接口性能摸底测试过程

接口性能测试是用于验证应用程序中的接口是否可以满足系统的性能要求的一种测试方法。确定应用程序在各种负载条件下的性能指标,例如响应时间、吞吐量、并发性能等,以便提高系统的性能和可靠性。本文主要讲述接口性能测试从前期准备、方案设计到环境搭建…

【机器学习】机器学习的基本分类-无监督学习-t-SNE(t-分布随机邻域嵌入)

t-SNE(t-分布随机邻域嵌入) t-SNE(t-distributed Stochastic Neighbor Embedding)是一种用于降维的非线性技术,常用于高维数据的可视化。它特别适合展示高维数据在二维或三维空间中的分布结构,同时能够很好…

【教学类-83-03】20241218立体书盘旋蛇3.0——圆点蛇1(蚊香形)

背景需求: 制作儿童简易立体书贺卡 【教学类-83-01】20241215立体书三角嘴1.0——小鸡(正菱形嘴)-CSDN博客文章浏览阅读1k次,点赞24次,收藏18次。【教学类-83-01】20241215立体书三角嘴1.0——小鸡(正菱形…

监控视频汇聚融合云平台一站式解决视频资源管理痛点

随着5G技术的广泛应用,各领域都在通信技术加持下通过海量终端设备收集了大量视频、图像等物联网数据,并通过人工智能、大数据、视频监控等技术方式来让我们的世界更安全、更高效。然而,随着数字化建设和生产经营管理活动的长期开展&#xff0…

JAVA 零拷贝技术和主流中间件零拷贝技术应用

目录 介绍Java代码里面有哪些零拷贝技术java 中文件读写方式主要分为什么是FileChannelmmap实现sendfile实现 文件IO实战需求代码编写实战IOTest.java 文件上传阿里云,测试运行代码看耗时为啥带buffer的IO比普通IO性能高?BufferedInputStream为啥性能高点…

云灾备技术

目录 云灾备分类与定义 云容灾定义与主要应用场景 云容灾定义 应用场景 云备份定义与主要应用场景 云备份定义 应用场景 云容灾参考模型与关键技术 云备份参考模型与关键技术 云灾备分类与定义 云容灾技术是指保护云数据中心业务持续性的灾备技术,它是云灾…

进程通信方式---共享映射区(无血缘关系用的)

5.共享映射区(无血缘关系用的) 文章目录 5.共享映射区(无血缘关系用的)1.概述2.mmap&&munmap函数3.mmap注意事项4.mmap实现进程通信父子进程练习 无血缘关系 5.mmap匿名映射区 1.概述 原理:共享映射区是将文件…

leetcode 面试经典 150 题:长度最小的子数组

链接长度最小的子数组题序号209题型数组解题方法滑动窗口难度中等 题目 给定一个含有 n 个正整数的数组和一个正整数 target 。找出该数组中满足其总和大于等于 target 的长度最小的 子数组 [numsl, numsl1, …, numsr-1, numsr] ,并返回其长度。如果不存在符合条件…