Braindecode系列:在BCIC IV 2a数据集上进行试验
- 0. 引言
- 1. 环境介绍
- 1.1 环境配置
- 1.2 运行环境
- 2. Python实现
- 2.1 加载和预处理数据集
- 2.2 创建模型
- 2.3 模型训练
- 2.4 结果输出图像
- 3. 结果展示
- 4. 总结
0. 引言
最近在看运动想象相关的论文时,找到了一个很好的关于脑电信号处理的深度学习库,名为:Braindecode
。在该库包中,集成了众多模型,包括:EEGNet
、Shallow_fbcsp
、Atcnet
、Tcn
等。这里就如何使用Braindecode
进行简单的介绍,本节内容主要介绍一个小项目:在BCI IV 2a数据集上进行试验
。该项目主要分为四个系列,分别为:
- 在
BCIC IV 2a
数据集上进行试验(trialwise decoding
) BCIC IV 2a
数据集上的裁剪解码 (cropped decoding
)BCIC IV 2a
数据集的数据增强- 使用自定义数据集
Braindecode
项目地址:Braindecode:一个解决脑电信号处理的深度学习模型的开源Python库
1. 环境介绍
1.1 环境配置
首先,介绍下如何配置项目所需的环境。项目所需环境配置
主要分为三个步骤:
Pytorch
的配置。项目里面所用的诸多深度学习模型均基于Pytorch
框架,因此需要安装Pytorch
框架,具体安装需根据自己电脑显卡
及cuda
版本来进行;moabb
的配置。直接调用命令即可:pip install moabb
braindecode
的配置。直接调用命令即可:pip install braindecode
到这里,项目的环境就配置完成
了!!!
1.2 运行环境
虽然你已经根据上述内容配置好了自己的环境,但是如果环境版本不合适
往往会出现这样或者那样的问题。这里给出了我配置过程中的一个问题
,感兴趣的读者可以观看:Bug小能手系列(python)10: 使用Braindecode库报错 ‘EEGClassifier‘ object has no attribute ‘classes_inferred‘。另外,下面我给出了自己可以运行代码的对应环境
,如果你出现环境问题可以按照下面的版本进行配置环境。
Package Version
------------------------ ------------
aiofiles 22.1.0
aiosqlite 0.18.0
albumentations 1.2.1
anyio 3.5.0
appdirs 1.4.4
argon2-cffi 21.3.0
argon2-cffi-bindings 21.2.0
asttokens 2.0.5
attrs 22.1.0
Babel 2.11.0
backcall 0.2.0
beautifulsoup4 4.12.2
bleach 4.1.0
Braindecode 0.7
brotlipy 0.7.0
certifi 2023.5.7
cffi 1.15.1
charset-normalizer 3.2.0
chinese-calendar 1.8.0
colorama 0.4.6
comm 0.1.2
contourpy 1.1.0
coverage 7.2.7
cryptography 39.0.1
cycler 0.11.0
debugpy 1.5.1
decorator 5.1.1
defusedxml 0.7.1
entrypoints 0.4
executing 0.8.3
fastjsonschema 2.16.2
fonttools 4.40.0
h5py 3.9.0
idna 3.4
importlib-metadata 6.0.0
importlib-resources 5.12.0
ipykernel 6.19.2
ipython 8.12.0
ipython-genutils 0.2.0
ipywidgets 8.0.4
jedi 0.18.1
Jinja2 3.1.2
joblib 1.2.0
json5 0.9.6
jsonschema 4.17.3
jupyter 1.0.0
jupyter_client 8.1.0
jupyter-console 6.6.3
jupyter_core 5.3.0
jupyter-events 0.6.3
jupyter_server 2.5.0
jupyter_server_fileid 0.9.0
jupyter_server_terminals 0.4.4
jupyter_server_ydoc 0.8.0
jupyter-ydoc 0.2.4
jupyterlab 3.6.3
jupyterlab-pygments 0.1.2
jupyterlab_server 2.22.0
jupyterlab-widgets 3.0.5
kiwisolver 1.4.4
lxml 4.9.2
MarkupSafe 2.1.1
matplotlib 3.7.1
matplotlib-inline 0.1.6
memory-profiler 0.61.0
mistune 0.8.4
mkl-fft 1.3.6
mkl-random 1.2.2
mkl-service 2.4.0
mne 1.4.2
moabb 0.5.0
nbclassic 0.5.5
nbclient 0.5.13
nbconvert 6.5.4
nbformat 5.7.0
nest-asyncio 1.5.6
notebook 6.5.4
notebook_shim 0.2.2
numpy 1.25.0
packaging 23.0
pandas 1.5.3
pandocfilters 1.5.0
parso 0.8.3
pickleshare 0.7.5
Pillow 10.0.0
pip 23.1.2
platformdirs 2.5.2
ply 3.11
pooch 1.7.0
prometheus-client 0.14.1
prompt-toolkit 3.0.36
psutil 5.9.0
pure-eval 0.2.2
pycparser 2.21
Pygments 2.15.1
pyOpenSSL 23.0.0
pyparsing 3.1.0
pypiwin32 223
PyQt5 5.15.7
PyQt5-sip 12.11.0
pyriemann 0.3
pyrsistent 0.18.0
PySocks 1.7.1
python-dateutil 2.8.2
python-json-logger 2.0.7
pyttsx3 2.90
pytz 2022.7
PyWavelets 1.4.1
pywin32 305.1
pywinpty 2.0.10
PyYAML 6.0
pyzmq 25.1.0
qtconsole 5.4.2
QtPy 2.2.0
qudida 0.0.4
requests 2.29.0
rfc3339-validator 0.1.4
rfc3986-validator 0.1.1
scikit-learn 1.2.2
scipy 1.10.1
seaborn 0.12.2
Send2Trash 1.8.0
setuptools 67.8.0
sip 6.6.2
six 1.16.0
skorch 0.14.0
sniffio 1.2.0
soupsieve 2.4
stack-data 0.2.0
tabulate 0.9.0
terminado 0.17.1
threadpoolctl 2.2.0
tinycss2 1.2.1
toml 0.10.2
tomli 2.0.1
torch 1.12.1+cu116
torchaudio 0.12.1+cu116
torchvision 0.13.1+cu116
tornado 6.2
tqdm 4.65.0
traitlets 5.7.1
typing_extensions 4.6.3
tzdata 2023.3
urllib3 1.26.16
wcwidth 0.2.5
webencodings 0.5.1
websocket-client 0.58.0
wheel 0.38.4
widgetsnbextension 4.0.5
win-inet-pton 1.1.0
y-py 0.5.9
ypy-websocket 0.8.2
zipp 3.11.0
2. Python实现
主要目的:介绍如何在经典的EEG
设置中使用Braindecode
训练和测试深度学习模型:您有带标签的数据试验(例如,右手、左手等)。
2.1 加载和预处理数据集
首先,我们加载数据
。在本教程中,我们使用 braindecodes
通过MOABB
加载数据集以加载BCIC IV 2a
数据。具体代码如下:
from braindecode.datasets import MOABBDataset
subject_id = 3
# BNCI2014001 表示 BCIC IV 2a 数据集 subject_ids表示试验者编号
dataset = MOABBDataset(dataset_name="BNCI2014001", subject_ids=[subject_id])
然后,我们对数据进行预处理
操作。我们将带通滤波
等预处理应用于数据集。您可以应用mne.Raw
或mne.Epochs
提供的函数,也可以将自己的函数应用于mne
对象或底层numpy
数组。具体代码如下:
from braindecode.preprocessing import (
exponential_moving_standardize, preprocess, Preprocessor)
from numpy import multiply
low_cut_hz = 4. # low cut frequency for filtering
high_cut_hz = 38. # high cut frequency for filtering
# Parameters for exponential moving standardization
factor_new = 1e-3
init_block_size = 1000
# Factor to convert from V to uV
factor = 1e6
preprocessors = [
Preprocessor('pick_types', eeg=True, meg=False, stim=False), # Keep EEG sensors
Preprocessor(lambda data: multiply(data, factor)), # Convert from V to uV
Preprocessor('filter', l_freq=low_cut_hz, h_freq=high_cut_hz), # Bandpass filter
Preprocessor(exponential_moving_standardize, # Exponential moving standardization
factor_new=factor_new, init_block_size=init_block_size)
]
# Transform the data
preprocess(dataset, preprocessors)
剪切计算窗口也是数据预处理的重要一步
。计算窗口
,即训练期间深度网络的输入
。在trialwise decoding
的情况下,我们只需要决定是否要在试验之前和/或之后
剪切一些部分。对于这个数据集,在我们的工作中,在试验前截取500毫秒
通常是有益
的。具体代码如下:
from braindecode.preprocessing import create_windows_from_events
trial_start_offset_seconds = -0.5
# Extract sampling frequency, check that they are same in all datasets
sfreq = dataset.datasets[0].raw.info['sfreq']
assert all([ds.raw.info['sfreq'] == sfreq for ds in dataset.datasets])
# Calculate the trial start offset in samples.
trial_start_offset_samples = int(trial_start_offset_seconds * sfreq)
# Create windows using braindecode function for this. It needs parameters to define how
# trials should be used.
windows_dataset = create_windows_from_events(
dataset,
trial_start_offset_samples=trial_start_offset_samples,
trial_stop_offset_samples=0,
preload=True,
)
最后,同传统网络结构一样,对数据集进行切分
,将数据集分为训练集和验证集。由于在数据集存储过程中已经完成了切分,这里直接使用存储在描述属性中的附加信息轻松地拆分数据集。具体代码如下:
splitted = windows_dataset.split('session')
train_set = splitted['session_T']
valid_set = splitted['session_E']
2.2 创建模型
现在我们创建深度学习模型!Braindecode
为原始时域EEG
提供了一些预定义的卷积神经网络架构。在这里,我们使用深度学习的浅层ConvNet
模型和卷积神经网络进行EEG解码
和可视化。这些模型都是纯PyTorch
深度学习模型,因此要使用自己的模型,它只需要是一个普通的PyTorch nn.Module
。具体代码如下:
import torch
from braindecode.util import set_random_seeds
from braindecode.models import ShallowFBCSPNet
cuda = torch.cuda.is_available() # check if GPU is available, if True chooses to use it
device = 'cuda' if cuda else 'cpu'
if cuda:
torch.backends.cudnn.benchmark = True
# Set random seed to be able to roughly reproduce results
# Note that with cudnn benchmark set to True, GPU indeterminism
# may still make results substantially different between runs.
# To obtain more consistent results at the cost of increased computation time,
# you can set `cudnn_benchmark=False` in `set_random_seeds`
# or remove `torch.backends.cudnn.benchmark = True`
seed = 20200220
set_random_seeds(seed=seed, cuda=cuda)
n_classes = 4
# Extract number of chans and time steps from dataset
n_chans = train_set[0][0].shape[0]
input_window_samples = train_set[0][0].shape[1]
model = ShallowFBCSPNet(
n_chans,
n_classes,
input_window_samples=input_window_samples,
final_conv_length='auto',
)
# Send model to GPU
if cuda:
model.cuda()
2.3 模型训练
现在我们训练网络模型
!EEGClassifier
是一个Braindecode
对象,负责管理神经网络的训练
。它继承自skorch.NeuralNetClassifier
,因此训练逻辑与Skorch
中的相同。具体代码如下:
注意:在本教程中,我们使用了一些默认参数,这些参数在运动解码中效果良好,但我们强烈建议您对训练数据使用交叉验证来执行自己的超参数优化。
from skorch.callbacks import LRScheduler
from skorch.helper import predefined_split
from braindecode import EEGClassifier
# These values we found good for shallow network:
lr = 0.0625 * 0.01
weight_decay = 0
# For deep4 they should be:
# lr = 1 * 0.01
# weight_decay = 0.5 * 0.001
batch_size = 64
n_epochs = 4
clf = EEGClassifier(
model,
criterion=torch.nn.NLLLoss,
optimizer=torch.optim.AdamW,
train_split=predefined_split(valid_set), # using valid_set for validation
optimizer__lr=lr,
optimizer__weight_decay=weight_decay,
batch_size=batch_size,
callbacks=[
"accuracy", ("lr_scheduler", LRScheduler('CosineAnnealingLR', T_max=n_epochs - 1)),
],
device=device,
)
# Model training for a specified number of epochs. `y` is None as it is already supplied
# in the dataset.
clf.fit(train_set, y=None, epochs=n_epochs)
2.4 结果输出图像
最后,我们使用Skorch
在整个训练过程中存储的历史来绘制精度和损失曲线
。具体代码如下:
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import pandas as pd
# Extract loss and accuracy values for plotting from history object
results_columns = ['train_loss', 'valid_loss', 'train_accuracy', 'valid_accuracy']
df = pd.DataFrame(clf.history[:, results_columns], columns=results_columns,
index=clf.history[:, 'epoch'])
# get percent of misclass for better visual comparison to loss
df = df.assign(train_misclass=100 - 100 * df.train_accuracy,
valid_misclass=100 - 100 * df.valid_accuracy)
plt.style.use('seaborn')
fig, ax1 = plt.subplots(figsize=(8, 3))
df.loc[:, ['train_loss', 'valid_loss']].plot(
ax=ax1, style=['-', ':'], marker='o', color='tab:blue', legend=False, fontsize=14)
ax1.tick_params(axis='y', labelcolor='tab:blue', labelsize=14)
ax1.set_ylabel("Loss", color='tab:blue', fontsize=14)
ax2 = ax1.twinx() # instantiate a second axes that shares the same x-axis
df.loc[:, ['train_misclass', 'valid_misclass']].plot(
ax=ax2, style=['-', ':'], marker='o', color='tab:red', legend=False)
ax2.tick_params(axis='y', labelcolor='tab:red', labelsize=14)
ax2.set_ylabel("Misclassification Rate [%]", color='tab:red', fontsize=14)
ax2.set_ylim(ax2.get_ylim()[0], 85) # make some room for legend
ax1.set_xlabel("Epoch", fontsize=14)
# where some data has already been plotted to ax
handles = []
handles.append(Line2D([0], [0], color='black', linewidth=1, linestyle='-', label='Train'))
handles.append(Line2D([0], [0], color='black', linewidth=1, linestyle=':', label='Valid'))
plt.legend(handles, [h.get_label() for h in handles], fontsize=14)
plt.tight_layout()
为了更好地展现结果,这里给出了绘制混淆矩阵
的代码:
from sklearn.metrics import confusion_matrix
from braindecode.visualization import plot_confusion_matrix
# generate confusion matrices
# get the targets
y_true = valid_set.get_metadata().target
y_pred = clf.predict(valid_set)
# generating confusion matrix
confusion_mat = confusion_matrix(y_true, y_pred)
# add class labels
# label_dict is class_name : str -> i_class : int
label_dict = valid_set.datasets[0].windows.event_id.items()
# sort the labels by values (values are integer class labels)
labels = list(dict(sorted(list(label_dict), key=lambda kv: kv[1])).keys())
# plot the basic conf. matrix
plot_confusion_matrix(confusion_mat, class_names=labels)
3. 结果展示
这里展示了模型精度和损失变化曲线
以及最后的混淆矩阵
图像!!
4. 总结
到此,使用 Braindecode
库系列(1):在BCIC IV 2a数据集上进行试验 已经介绍完毕了!!! 如果有什么疑问欢迎在评论区提出,对于共性问题可能会后续添加到文章介绍中。
如果觉得这篇文章对你有用,记得点赞、收藏并分享给你的小伙伴们哦😄。