Braindecode系列 (1):在BCIC IV 2a数据集上进行试验

news2024/10/7 10:13:55

Braindecode系列:在BCIC IV 2a数据集上进行试验

  • 0. 引言
  • 1. 环境介绍
    • 1.1 环境配置
    • 1.2 运行环境
  • 2. Python实现
    • 2.1 加载和预处理数据集
    • 2.2 创建模型
    • 2.3 模型训练
    • 2.4 结果输出图像
  • 3. 结果展示
  • 4. 总结

0. 引言

最近在看运动想象相关的论文时,找到了一个很好的关于脑电信号处理的深度学习库,名为:Braindecode。在该库包中,集成了众多模型,包括:EEGNetShallow_fbcspAtcnetTcn等。这里就如何使用Braindecode 进行简单的介绍,本节内容主要介绍一个小项目:在BCI IV 2a数据集上进行试验。该项目主要分为四个系列,分别为:

  1. BCIC IV 2a数据集上进行试验(trialwise decoding
  2. BCIC IV 2a 数据集上的裁剪解码 (cropped decoding
  3. BCIC IV 2a数据集的数据增强
  4. 使用自定义数据集

Braindecode 项目地址:Braindecode:一个解决脑电信号处理的深度学习模型的开源Python库

1. 环境介绍

1.1 环境配置

首先,介绍下如何配置项目所需的环境。项目所需环境配置主要分为三个步骤:

  1. Pytorch的配置。项目里面所用的诸多深度学习模型均基于Pytorch框架,因此需要安装Pytorch框架,具体安装需根据自己电脑显卡cuda版本来进行;
  2. moabb的配置。直接调用命令即可:pip install moabb
  3. braindecode的配置。直接调用命令即可:pip install braindecode

到这里,项目的环境就配置完成了!!!

1.2 运行环境

虽然你已经根据上述内容配置好了自己的环境,但是如果环境版本不合适往往会出现这样或者那样的问题。这里给出了我配置过程中的一个问题,感兴趣的读者可以观看:Bug小能手系列(python)10: 使用Braindecode库报错 ‘EEGClassifier‘ object has no attribute ‘classes_inferred‘。另外,下面我给出了自己可以运行代码的对应环境,如果你出现环境问题可以按照下面的版本进行配置环境。

Package                  Version
------------------------ ------------
aiofiles                 22.1.0
aiosqlite                0.18.0
albumentations           1.2.1
anyio                    3.5.0
appdirs                  1.4.4
argon2-cffi              21.3.0
argon2-cffi-bindings     21.2.0
asttokens                2.0.5
attrs                    22.1.0
Babel                    2.11.0
backcall                 0.2.0
beautifulsoup4           4.12.2
bleach                   4.1.0
Braindecode              0.7
brotlipy                 0.7.0
certifi                  2023.5.7
cffi                     1.15.1
charset-normalizer       3.2.0
chinese-calendar         1.8.0
colorama                 0.4.6
comm                     0.1.2
contourpy                1.1.0
coverage                 7.2.7
cryptography             39.0.1
cycler                   0.11.0
debugpy                  1.5.1
decorator                5.1.1
defusedxml               0.7.1
entrypoints              0.4
executing                0.8.3
fastjsonschema           2.16.2
fonttools                4.40.0
h5py                     3.9.0
idna                     3.4
importlib-metadata       6.0.0
importlib-resources      5.12.0
ipykernel                6.19.2
ipython                  8.12.0
ipython-genutils         0.2.0
ipywidgets               8.0.4
jedi                     0.18.1
Jinja2                   3.1.2
joblib                   1.2.0
json5                    0.9.6
jsonschema               4.17.3
jupyter                  1.0.0
jupyter_client           8.1.0
jupyter-console          6.6.3
jupyter_core             5.3.0
jupyter-events           0.6.3
jupyter_server           2.5.0
jupyter_server_fileid    0.9.0
jupyter_server_terminals 0.4.4
jupyter_server_ydoc      0.8.0
jupyter-ydoc             0.2.4
jupyterlab               3.6.3
jupyterlab-pygments      0.1.2
jupyterlab_server        2.22.0
jupyterlab-widgets       3.0.5
kiwisolver               1.4.4
lxml                     4.9.2
MarkupSafe               2.1.1
matplotlib               3.7.1
matplotlib-inline        0.1.6
memory-profiler          0.61.0
mistune                  0.8.4
mkl-fft                  1.3.6
mkl-random               1.2.2
mkl-service              2.4.0
mne                      1.4.2
moabb                    0.5.0
nbclassic                0.5.5
nbclient                 0.5.13
nbconvert                6.5.4
nbformat                 5.7.0
nest-asyncio             1.5.6
notebook                 6.5.4
notebook_shim            0.2.2
numpy                    1.25.0
packaging                23.0
pandas                   1.5.3
pandocfilters            1.5.0
parso                    0.8.3
pickleshare              0.7.5
Pillow                   10.0.0
pip                      23.1.2
platformdirs             2.5.2
ply                      3.11
pooch                    1.7.0
prometheus-client        0.14.1
prompt-toolkit           3.0.36
psutil                   5.9.0
pure-eval                0.2.2
pycparser                2.21
Pygments                 2.15.1
pyOpenSSL                23.0.0
pyparsing                3.1.0
pypiwin32                223
PyQt5                    5.15.7
PyQt5-sip                12.11.0
pyriemann                0.3
pyrsistent               0.18.0
PySocks                  1.7.1
python-dateutil          2.8.2
python-json-logger       2.0.7
pyttsx3                  2.90
pytz                     2022.7
PyWavelets               1.4.1
pywin32                  305.1
pywinpty                 2.0.10
PyYAML                   6.0
pyzmq                    25.1.0
qtconsole                5.4.2
QtPy                     2.2.0
qudida                   0.0.4
requests                 2.29.0
rfc3339-validator        0.1.4
rfc3986-validator        0.1.1
scikit-learn             1.2.2
scipy                    1.10.1
seaborn                  0.12.2
Send2Trash               1.8.0
setuptools               67.8.0
sip                      6.6.2
six                      1.16.0
skorch                   0.14.0
sniffio                  1.2.0
soupsieve                2.4
stack-data               0.2.0
tabulate                 0.9.0
terminado                0.17.1
threadpoolctl            2.2.0
tinycss2                 1.2.1
toml                     0.10.2
tomli                    2.0.1
torch                    1.12.1+cu116
torchaudio               0.12.1+cu116
torchvision              0.13.1+cu116
tornado                  6.2
tqdm                     4.65.0
traitlets                5.7.1
typing_extensions        4.6.3
tzdata                   2023.3
urllib3                  1.26.16
wcwidth                  0.2.5
webencodings             0.5.1
websocket-client         0.58.0
wheel                    0.38.4
widgetsnbextension       4.0.5
win-inet-pton            1.1.0
y-py                     0.5.9
ypy-websocket            0.8.2
zipp                     3.11.0

2. Python实现

主要目的:介绍如何在经典的EEG设置中使用Braindecode训练和测试深度学习模型:您有带标签的数据试验(例如,右手、左手等)。

2.1 加载和预处理数据集

首先,我们加载数据。在本教程中,我们使用 braindecodes通过MOABB加载数据集以加载BCIC IV 2a数据。具体代码如下:

from braindecode.datasets import MOABBDataset

subject_id = 3
# BNCI2014001 表示 BCIC IV 2a 数据集   subject_ids表示试验者编号
dataset = MOABBDataset(dataset_name="BNCI2014001", subject_ids=[subject_id])

然后,我们对数据进行预处理操作。我们将带通滤波等预处理应用于数据集。您可以应用mne.Rawmne.Epochs提供的函数,也可以将自己的函数应用于mne对象或底层numpy数组。具体代码如下:

from braindecode.preprocessing import (
    exponential_moving_standardize, preprocess, Preprocessor)
from numpy import multiply

low_cut_hz = 4.  # low cut frequency for filtering
high_cut_hz = 38.  # high cut frequency for filtering
# Parameters for exponential moving standardization
factor_new = 1e-3
init_block_size = 1000
# Factor to convert from V to uV
factor = 1e6

preprocessors = [
    Preprocessor('pick_types', eeg=True, meg=False, stim=False),  # Keep EEG sensors
    Preprocessor(lambda data: multiply(data, factor)),  # Convert from V to uV
    Preprocessor('filter', l_freq=low_cut_hz, h_freq=high_cut_hz),  # Bandpass filter
    Preprocessor(exponential_moving_standardize,  # Exponential moving standardization
                 factor_new=factor_new, init_block_size=init_block_size)
]

# Transform the data
preprocess(dataset, preprocessors)

剪切计算窗口也是数据预处理的重要一步计算窗口,即训练期间深度网络的输入。在trialwise decoding的情况下,我们只需要决定是否要在试验之前和/或之后剪切一些部分。对于这个数据集,在我们的工作中,在试验前截取500毫秒通常是有益的。具体代码如下:

from braindecode.preprocessing import create_windows_from_events

trial_start_offset_seconds = -0.5
# Extract sampling frequency, check that they are same in all datasets
sfreq = dataset.datasets[0].raw.info['sfreq']
assert all([ds.raw.info['sfreq'] == sfreq for ds in dataset.datasets])
# Calculate the trial start offset in samples.
trial_start_offset_samples = int(trial_start_offset_seconds * sfreq)

# Create windows using braindecode function for this. It needs parameters to define how
# trials should be used.
windows_dataset = create_windows_from_events(
    dataset,
    trial_start_offset_samples=trial_start_offset_samples,
    trial_stop_offset_samples=0,
    preload=True,
)

最后,同传统网络结构一样,对数据集进行切分,将数据集分为训练集和验证集。由于在数据集存储过程中已经完成了切分,这里直接使用存储在描述属性中的附加信息轻松地拆分数据集。具体代码如下:

splitted = windows_dataset.split('session')
train_set = splitted['session_T']
valid_set = splitted['session_E']

2.2 创建模型

现在我们创建深度学习模型!Braindecode为原始时域EEG提供了一些预定义的卷积神经网络架构。在这里,我们使用深度学习的浅层ConvNet模型和卷积神经网络进行EEG解码和可视化。这些模型都是纯PyTorch深度学习模型,因此要使用自己的模型,它只需要是一个普通的PyTorch nn.Module。具体代码如下:

import torch
from braindecode.util import set_random_seeds
from braindecode.models import ShallowFBCSPNet

cuda = torch.cuda.is_available()  # check if GPU is available, if True chooses to use it
device = 'cuda' if cuda else 'cpu'
if cuda:
    torch.backends.cudnn.benchmark = True
# Set random seed to be able to roughly reproduce results
# Note that with cudnn benchmark set to True, GPU indeterminism
# may still make results substantially different between runs.
# To obtain more consistent results at the cost of increased computation time,
# you can set `cudnn_benchmark=False` in `set_random_seeds`
# or remove `torch.backends.cudnn.benchmark = True`
seed = 20200220
set_random_seeds(seed=seed, cuda=cuda)

n_classes = 4
# Extract number of chans and time steps from dataset
n_chans = train_set[0][0].shape[0]
input_window_samples = train_set[0][0].shape[1]

model = ShallowFBCSPNet(
    n_chans,
    n_classes,
    input_window_samples=input_window_samples,
    final_conv_length='auto',
)

# Send model to GPU
if cuda:
    model.cuda()

2.3 模型训练

现在我们训练网络模型EEGClassifier 是一个Braindecode对象,负责管理神经网络的训练。它继承自skorch.NeuralNetClassifier,因此训练逻辑与Skorch中的相同。具体代码如下:
注意:在本教程中,我们使用了一些默认参数,这些参数在运动解码中效果良好,但我们强烈建议您对训练数据使用交叉验证来执行自己的超参数优化。

from skorch.callbacks import LRScheduler
from skorch.helper import predefined_split

from braindecode import EEGClassifier
# These values we found good for shallow network:
lr = 0.0625 * 0.01
weight_decay = 0

# For deep4 they should be:
# lr = 1 * 0.01
# weight_decay = 0.5 * 0.001

batch_size = 64
n_epochs = 4

clf = EEGClassifier(
    model,
    criterion=torch.nn.NLLLoss,
    optimizer=torch.optim.AdamW,
    train_split=predefined_split(valid_set),  # using valid_set for validation
    optimizer__lr=lr,
    optimizer__weight_decay=weight_decay,
    batch_size=batch_size,
    callbacks=[
        "accuracy", ("lr_scheduler", LRScheduler('CosineAnnealingLR', T_max=n_epochs - 1)),
    ],
    device=device,
)
# Model training for a specified number of epochs. `y` is None as it is already supplied
# in the dataset.
clf.fit(train_set, y=None, epochs=n_epochs)

2.4 结果输出图像

最后,我们使用Skorch在整个训练过程中存储的历史来绘制精度和损失曲线。具体代码如下:

import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import pandas as pd

# Extract loss and accuracy values for plotting from history object
results_columns = ['train_loss', 'valid_loss', 'train_accuracy', 'valid_accuracy']
df = pd.DataFrame(clf.history[:, results_columns], columns=results_columns,
                  index=clf.history[:, 'epoch'])

# get percent of misclass for better visual comparison to loss
df = df.assign(train_misclass=100 - 100 * df.train_accuracy,
               valid_misclass=100 - 100 * df.valid_accuracy)

plt.style.use('seaborn')
fig, ax1 = plt.subplots(figsize=(8, 3))
df.loc[:, ['train_loss', 'valid_loss']].plot(
    ax=ax1, style=['-', ':'], marker='o', color='tab:blue', legend=False, fontsize=14)

ax1.tick_params(axis='y', labelcolor='tab:blue', labelsize=14)
ax1.set_ylabel("Loss", color='tab:blue', fontsize=14)

ax2 = ax1.twinx()  # instantiate a second axes that shares the same x-axis

df.loc[:, ['train_misclass', 'valid_misclass']].plot(
    ax=ax2, style=['-', ':'], marker='o', color='tab:red', legend=False)
ax2.tick_params(axis='y', labelcolor='tab:red', labelsize=14)
ax2.set_ylabel("Misclassification Rate [%]", color='tab:red', fontsize=14)
ax2.set_ylim(ax2.get_ylim()[0], 85)  # make some room for legend
ax1.set_xlabel("Epoch", fontsize=14)

# where some data has already been plotted to ax
handles = []
handles.append(Line2D([0], [0], color='black', linewidth=1, linestyle='-', label='Train'))
handles.append(Line2D([0], [0], color='black', linewidth=1, linestyle=':', label='Valid'))
plt.legend(handles, [h.get_label() for h in handles], fontsize=14)
plt.tight_layout()

为了更好地展现结果,这里给出了绘制混淆矩阵的代码:

from sklearn.metrics import confusion_matrix
from braindecode.visualization import plot_confusion_matrix

# generate confusion matrices
# get the targets
y_true = valid_set.get_metadata().target
y_pred = clf.predict(valid_set)

# generating confusion matrix
confusion_mat = confusion_matrix(y_true, y_pred)

# add class labels
# label_dict is class_name : str -> i_class : int
label_dict = valid_set.datasets[0].windows.event_id.items()
# sort the labels by values (values are integer class labels)
labels = list(dict(sorted(list(label_dict), key=lambda kv: kv[1])).keys())

# plot the basic conf. matrix
plot_confusion_matrix(confusion_mat, class_names=labels)

3. 结果展示

这里展示了模型精度和损失变化曲线以及最后的混淆矩阵图像!!
在这里插入图片描述
在这里插入图片描述

4. 总结

到此,使用 Braindecode库系列(1):在BCIC IV 2a数据集上进行试验 已经介绍完毕了!!! 如果有什么疑问欢迎在评论区提出,对于共性问题可能会后续添加到文章介绍中。

如果觉得这篇文章对你有用,记得点赞、收藏并分享给你的小伙伴们哦😄。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/743827.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

在线培训系统的保障措施带来安全、可靠的学习环境

在今天的数字时代,越来越多的人选择在线培训系统作为学习的方式。然而,随着在线教育市场的不断增长,安全和可靠性成为消费者普遍关心的问题。因此,在线培训系统需要采取一系列保护措施以确保学生的数据和隐私得到保护,…

Python 运算符(二)

文章目录 Python逻辑运算符Python成员运算符Python身份运算符Python运算符优先级后记 Python逻辑运算符 Python语言支持逻辑运算符,以下假设变量 a 为 10, b为 20: 运算符逻辑表达式描述实例andx and y 布尔"与" - 如果 x 为 False,x and y …

php周练

前言:博主个人小练(纯小白)。 目录 1.[SWPUCTF 2021 新生赛]gift_F12已解决2.[SWPUCTF 2021 新生赛]jicao3.[ZJCTF 2019]NiZhuanSiWei4.[SWPUCTF 2021 新生赛]no_wakeup5.[SWPUCTF 2021 新生赛]ez_unserialize 1.[SWPUCTF 2021 新生赛]gift_…

Ae 效果:CC RepeTile

风格化/CC RepeTile Stylize/CC RepeTile CC RepeTile(CC 重复拼贴)效果可对整个图层进行复制并扩展,通过重复拼贴来创建平铺效果。 ◆ ◆ ◆ 效果属性说明 Expand Right 向右扩展 设置图层向右扩展的距离。 Expand Left 向左扩展 设置图层…

VMware vCenter Server 7.0 Update 3n 下载 - 集中管理 vSphere 环境

VMware vCenter Server 7.0 Update 3n 下载 - 集中管理 vSphere 环境 请访问原文链接:https://sysin.org/blog/vmware-vcenter-7-u3/,查看最新版。原创作品,转载请保留出处。 作者主页:sysin.org VMware vCenter Server 是一款高…

【菜菜丸的菜鸟教程】制作带闹铃和振动功能的仿真闹钟

一、准备闹钟模型 (一)下载模型 从Unity资源商店和其他模型网站可以下载到各种各样的闹钟模型。为了帮助大家了解机械钟表的设置原理,建议使用带有时针、分针和秒针的钟表,如下图。 注意:时针、分针和秒针最好是挂在闹钟父物体下的三个独立的…

【数据结构】--二叉树

注:本文树和二叉树的概念及结构部分有部分参考了别的文章,其他的二叉树的实现和性质习题等等都是自己一点点写的,创作不易,希望支持! ————————————————————— 目录 一. 树概念及结构 1、树概念…

springboot家具商城系统

开发语言:Java 框架:springboot JDK版本:JDK1.8 服务器:tomcat7 数据库:mysql 5.7(一定要5.7版本) 数据库工具:Navicat11 开发软件:eclipse/myeclipse/idea Maven…

Spring Boot 系列2 -- 配置文件

目录 1. 配置文件的作用 2. 配置文件的格式 3. properties 配置文件说明 3.1 properties 基本语法 3.2 读取配置文件 3.3 properties 缺点 4.yml 配置文件说明 4.1 yml 基本语法 4.2 yml 使用进阶 4.2.1 yml 配置不同数据类型及 null 4.2.2 yml 配置读取 4.2.3 注意…

FPGA学习——点亮流水灯

文章目录 一、前言二、源码三、ModelSim仿真3.1 tb文件源码:3.2 创建项目3.3 ModelSim仿真 一、前言 在FPGA开发板中,一般板载LED默认高电平点亮,因此我们只需要将想要亮起的LED赋值为1即可。 本入门实验要求为每隔1s开发板上的LED轮流亮起&…

STM32 Proteus仿真医用仓库环境控制系统紫外线消毒RS232上传CO2 -0066

STM32 Proteus仿真医用仓库环境控制系统紫外线消毒RS232上传CO2 -0066 Proteus仿真小实验: STM32 Proteus仿真医用仓库环境控制系统紫外线消毒RS232上传CO2 -0066 功能: 硬件组成:STM32F103R6单片机 LCD1602显示器DHT11温度湿度电位器模拟…

高分卫星影像及GIS技术在甘南泥石流灾害中的应用

本文使用的甘南夏河县泥石流灾情专题数据如下(来源于高分甘肃中心): (1)灾前遥感影像 (2)灾害位置 (3)基础地理数据:行政区划、交通路网、河流水系、湖泊水库…

前端Vue自定义精美steps步骤条进度条插件 物流信息跟踪展示组件 流程审批跟进组件

随着技术的发展,开发的复杂度也越来越高,传统开发方式将一个系统做成了整块应用,经常出现的情况就是一个小小的改动或者一个小功能的增加可能会引起整体逻辑的修改,造成牵一发而动全身。 通过组件化开发,可以有效实现…

怎么用PDF派工具将Word转成PDF

Word是我们最常用的一种格式文件,它易于编辑,但是安全性和稳定性较差,有时候我们发送给别人的Word文件,接收到打开内容已经乱码。遇到这种情况,我们可以优先将Word文件转换成稳定性好的PDF文件。那么如何进行文件格式转…

如何使用伪元素 ::before 实现 Antd 表单一模一样的 required 红色 * 号

如何使用伪元素 ::before 实现 Antd 表单一模一样的 required 红色 * 号 背景 以一个简单的 Form.Item 包裹 Select 为例 我们去实现它的 * 号 操作 F12 打开控制台选中这个元素上面查看 CSS 属性 仿照这个写在 .less 文件里 // .less .ruleTable::before {display: inlin…

RiProV2主题一级分类显示包含子分类的数量Ritheme主题美化WordPress美化类似的步骤

美化-RiProV2主题一级分类显示包含子分类的数量 WordPress主题一级分类页面显示包含子分类的数量 一级分类显示子分类相加的数量 原主题配置项 原来的RiProV2主题,虽然有个配置用来显示分类下的数量。 但是该数量有个问题,就是一级分类的数量显示不包含该一级分类下二级…

操作系统14:缓冲区和磁盘调度算法

目录 1、缓冲区管理 (1)单缓冲区和双缓冲区 1.1 - 单缓冲区 1.2 - 双缓冲区 (2)环形缓冲区/多缓冲区 (3)缓冲池(Buffer Pool) 3.1 - 缓冲池的组成 3.2 - 缓冲池的工作方式 2、磁盘存储器的性能和调…

面向对象进阶一(static,继承,多态)

面向对象进阶一 一、static二、继承2.1 继承的定义和特点2.2 继承内容、成员变量和成员方法的访问特点2.2.1继承内容2.2.2 成员变量的访问特点2.2.3 成员方法的访问方法、方法的重写 2.3 继承中构造方法的访问特点 三、this、super使用总结四、多态4.1 多态的基本概念4.2 多态调…

MySQL 学习笔记 2:触发器

MySQL 学习笔记 2:触发器 图源:ubiq.co 触发器,就像字面意思那样,它会在数据库某些事件发生时执行一些操作。 具体来说,触发器会在特定表的INSERT、UPDATE、DELETE这些类型的 SQL 语句执行时被“触发”,并…