【信号处理】基于DGGAN的单通道脑电信号增强和情绪检测(tensorflow)

news2025/1/9 22:50:30

关于

情绪检测,是脑科学研究中的一个常见和热门的方向。在进行情绪检测的分类中,真实数据不足,经常导致情绪检测模型的性能不佳。因此,对数据进行增强,成为了一个提升下游任务的重要的手段。本项目通过DCGAN模型实现脑电信号的扩充。

 图片来源:https://www.medicalnewstoday.com/articles/seizure-eeg

工具

数据

方法实现

DCGAN速递:https://arxiv.org/abs/1511.06434

数据加载和预处理
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
from sklearn.preprocessing import StandardScaler
from tensorflow.keras.utils import to_categorical
from sklearn.model_selection import train_test_split
import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.layers import Embedding
from tensorflow.keras.layers import LSTM
from tensorflow.keras.optimizers import SGD
from sklearn.metrics import accuracy_score
from model_DCGAN import DCGAN
from tensorflow.keras.optimizers import Adam, SGD, RMSprop
from sklearn.utils import shuffle
from sklearn.ensemble import GradientBoostingClassifier

use_feature_reduction = True

tf.keras.backend.clear_session()

df=pd.read_csv('dataset/emotions.csv')

encode = ({'NEUTRAL': 0, 'POSITIVE': 1, 'NEGATIVE': 2} )
#new dataset with replaced values
df_encoded = df.replace(encode)

print(df_encoded.head())
print(df_encoded['label'].value_counts()),

x=df_encoded.drop(["label"]  ,axis=1)
y = df_encoded.loc[:,'label'].values

scaler = StandardScaler()
scaler.fit(x)
x = scaler.transform(x)
y = to_categorical(y)

x_train, x_test, y_train, y_test = train_test_split(x, y, test_size = 0.2, random_state = 4)

if use_feature_reduction:
    # Feature reduction part
    est = GradientBoostingClassifier(n_estimators=10, learning_rate=0.1, random_state=0).fit(x_train,
                                                                                             y_train.argmax(-1))

    # Obtain feature importance results from Gradient Boosting Regressor
    feature_importance = est.feature_importances_
    epsilon_feature = 1e-2
    x_train = x_train[:, feature_importance > epsilon_feature]
    x_test = x_test[:, feature_importance > epsilon_feature]
设置DCGAN优化器

# setup optimzers
gen_optim = Adam(1e-4, beta_1=0.5)
disc_optim = RMSprop(5e-4)
 训练GAN生成类别0脑电数据
# generate samples for class 0
generator_class = 0
dcgan = DCGAN(gen_optim, disc_optim, noise_dim=100, dropout=0.3, input_dim=x_train.shape[2])
x_train_class_0 = x_train[y_train[:,generator_class]==1,:]
loss_history_class_0, acc_history_class_0, grads_history_class_0 = dcgan.train(x_train_class_0, epochs=100)
print("Class 0 fake samples are generating")
generator_class_0 = dcgan.generator
generated_samples_class_0, _ = dcgan.generate_fake_data(N=len(x_train_class_0))
  训练GAN生成类别1脑电数据
# generate samples for class 1
generator_class = 1
dcgan = DCGAN(gen_optim, disc_optim, noise_dim=100, dropout=0.3, input_dim=x_train.shape[2])
x_train_class_1 = x_train[y_train[:,generator_class]==1,:]
loss_history_class_1, acc_history_class_1, grads_history_class_1 = dcgan.train(x_train_class_1, epochs=100)
print("Class 1 fake samples are generating")
generator_class_1 = dcgan.generator
generated_samples_class_1, _ = dcgan.generate_fake_data(N=len(x_train_class_1))
 训练GAN生成类别2脑电数据
# generate samples for class 2
generator_class = 2
dcgan = DCGAN(gen_optim, disc_optim, noise_dim=100, dropout=0.3, input_dim=x_train.shape[2])
x_train_class_2 = x_train[y_train[:,generator_class]==1,:]
loss_history_class_2, acc_history_class_2, grads_history_class_2 = dcgan.train(x_train_class_2,epochs=100)
print("Class 2 fake samples are generating")
generator_class_2 = dcgan.generator
generated_samples_class_2, _ = dcgan.generate_fake_data(N=len(x_train_class_2))
合成数据融入真实训练数据集
generated_samples = np.concatenate((generated_samples_class_0,
                                    generated_samples_class_1,
                                    generated_samples_class_2),axis=0)
generated_y =np.concatenate((np.zeros((len(x_train_class_0),),dtype=np.int32),
                             np.ones((len(x_train_class_1),),dtype=np.int32),
                             2 * np.ones((len(x_train_class_2),),dtype=np.int32)),axis=0)

generated_y = to_categorical(generated_y)

x_train_all = np.concatenate((x_train,generated_samples),axis=0)
y_train_all = np.concatenate((y_train,generated_y), axis=0)


#shuffle training data
x_train_all, y_train_all = shuffle(x_train_all,y_train_all)
 基于数据增强的LSTM模型情绪检测
model = Sequential()
model.add(LSTM(64, input_shape=(1,x_train_all.shape[2]),activation="relu",return_sequences=True))
model.add(Dropout(0.2))
model.add(LSTM(32,activation="sigmoid"))
model.add(Dropout(0.2))
model.add(Dense(3, activation='sigmoid'))
model.compile(loss = 'categorical_crossentropy', optimizer = "adam", metrics = ['accuracy'])
model.summary()


history = model.fit(x_train_all, y_train_all, epochs = 250, validation_data= (x_test, y_test))
score, acc = model.evaluate(x_test, y_test)

pred = model.predict(x_test)
predict_classes = np.argmax(pred,axis=1)
expected_classes = np.argmax(y_test,axis=1)
print(expected_classes.shape)
print(predict_classes.shape)
correct = accuracy_score(expected_classes,predict_classes)
print(f"Test Accuracy: {correct}")

已附DCGAN模型

相关项目和代码问题,欢迎沟通交流。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1550644.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

全局UI方法-弹窗三-文本滑动选择器弹窗(TextPickDialog)

1、描述 根据指定的选择范围创建文本选择器,展示在弹窗上。 2、接口 TextPickDialog(options?: TextPickDialogOptions) 3、TextPickDialogOptions 参数名称 参数类型 必填 参数描述 rang string[] | Resource 是 设置文本选择器的选择范围。 selected nu…

企业选购USB服务器,有哪些注意事项?

一、前言 随着信息技术的飞速发展,USB服务器在企业日常运营中的作用日益凸显。USB服务器不仅实现了远程USB设备的共享和管理,还提升了企业数据的安全性和管理效率。 然而,面对市场上琳琅满目的USB服务器产品,企业在选购时需要注…

2024年天津财经大学珠江学院退役大学生士兵专升本专业课报名须知

天津财经大学珠江学院2024年高职升本科(面向退役大学生士兵)职业技能综合考查报考须知 一、报名条件 报考天津财经大学珠江学院2024年高职升本科职业技能综合考查的退役大学生士兵应符合天津市及我院规定的报考资格。考生须完成天津市高职升本科文化考…

【数据结构】Java中Map和Set详解(含二叉搜索树和哈希表)

目录 Map和Set详解 1.二叉搜索树 2.Map常见方法 3.Set常见方法 4.哈希表 Map和Set详解 Map:一种键值对结构,hashMap中键和值均可以为空,hashTable中则不可以存放null值 Set:一种集合,不能存放重复元素&#xff0c…

Docker进阶:Docker-compose 实现服务弹性伸缩

Docker进阶:Docker-compose 实现服务弹性伸缩 一、Docker Compose基础概念1.1 Docker Compose简介1.2 Docker Compose文件结构 二、弹性伸缩的原理和实现步骤2.1 弹性伸缩原理2.2 实现步骤 三、技术实践案例3.1 场景描述3.2 配置Docker Compose文件3.3 使用 docker-…

[Semi-笔记]Switching Temporary Teachers for Semi-Supervised Semantic Segmentation

目录 概要创新一:Dual Temporary Teacher挑战:解决: 创新二:Implicit Consistency Learning(隐式一致性学习)挑战:解决: 实验结果小结论文地址代码地址 分享一篇2023年NeurIPS的文章…

碳素光线疗法与阳光猪舍

碳素光线疗法与阳光猪舍 生息在地球上的所有动物、在自然太阳光奇妙的作用下、生长发育。太阳光的能量使它们不断进化、繁衍种族。现在、生物能够生存、全仰仗于太阳的光线。太阳光线中、包含有动物健康所需要的极为重要的波长。因此、和户外饲养的动物相比、在室内喂养的观赏动…

动态规划相关题目

文章目录 1.动态规划理论基础2.斐波那契数3.爬楼梯4.使用最小花费爬楼梯5.不同路径6.不同路径 II7. 整数拆分8. 不同的二叉搜索树 1.动态规划理论基础 1.1 什么是动态规划? 动态规划,英文:Dynamic Programming,简称DP,如果某一…

ClickUp、clickupython

文章目录 关于 ClickUp关于 clickupython安装Authentication方法 1: API Key (最快)当前 ClickUpClient 功能TaskListFolderAttachmentsCommentsTeamsChecklistsGoalsMembersTagsSpacesTime Tracking 教程 关于 ClickUp 官网:https://clickup.com/University Cours…

Acer宏碁暗影骑士擎AN515-58笔记本电脑工厂模式原厂Win11系统ISO镜像安装包下载

宏基AN515-58原装出厂OEM预装Windows11系统工厂包,恢复出厂时开箱状态一模一样,带恢复还原功能 链接:https://pan.baidu.com/s/1iCVSYtList-hPqbyTyaRqQ?pwdt2gw 提取码:t2gw 宏基原装系统自带所有驱动、NITROSENSE风扇键盘灯…

【动手学深度学习】深入浅出深度学习之线性神经网络

目录 🌞一、实验目的 🌞二、实验准备 🌞三、实验内容 🌼1. 线性回归 🌻1.1 矢量化加速 🌻1.2 正态分布与平方损失 🌼2. 线性回归的从零开始实现 🌻2.1. 生成数据集 &#x…

H4020 40V输入EN智能控制 带线补 同步降压控制器芯片

40V输入EN智能控制带线补的同步降压控制器芯片的工作原理涉及多个方面,以下是其关键部分的解释: 输入电压管理:芯片能够接收高达40V的输入电压。这通常需要一个前端电路来处理高电压,比如使用电阻分压、电容器滤波或者其他保护电…

C++:变量和常量(3)

变量 什么是变量:变量就是一个装东西的盒子 通俗:变量是用于存放数据的容器。我们通过变量名获取数据,甚至数据可以修改 变量的作用:给指定的内存空间起名,后期通过起的名字就可以调用整个内存空间 定义变量的格式 &a…

LVS负载均衡-DR模式配置

LVS:Linux virtual server ,即Linux虚拟服务器 LVS自身是一个负载均衡器(Director),不直接处理请求,而是将请求转发至位于它后端的真实服务器real server上。 LVS是四层(传输层 tcp/udp)负载均衡…

基于javaweb(springboot)汽车配件管理系统设计和实现以及文档报告

基于javaweb(springboot)汽车配件管理系统设计和实现以及文档报告 博主介绍:多年java开发经验,专注Java开发、定制、远程、文档编写指导等,csdn特邀作者、专注于Java技术领域 作者主页 央顺技术团队 Java毕设项目精品实战案例《1000套》 欢迎点赞 收藏 ⭐…

python--冒泡排序和main函数

1.判断是不是回文数: x int(input("请输入一个正整数:")) x str(x) if x x[::-1]:print("是回文数。") else:print("不是回文数。") 2.冒泡排序 # 冒泡排序: # [30,8,-10, 50&am…

【漏洞复现】网络验证系统getInfo接口处存在SQL注入漏洞

免责声明:文章来源互联网收集整理,请勿利用文章内的相关技术从事非法测试,由于传播、利用此文所提供的信息或者工具而造成的任何直接或者间接的后果及损失,均由使用者本人负责,所产生的一切不良后果与文章作者无关。该…

【算法】KMP-快速文本匹配

文章目录 一、KMP算法说明二、详细实现1. next数组定义2. 使用next加速匹配3. next数组如何快速生成4. 时间复杂度O(mn)的证明a) next生成的时间复杂度b) 匹配过程时间复杂度 三、例题1. [leetcode#572](https://leetcode.cn/problems/subtree-of-another-tree/description/)2.…

Python 从0开始 一步步基于Django创建项目(13)将数据关联到用户

在city_infos应用程序中,每个城市信息条目是关联到城市的,所以只需要将城市条目关联到用户即可。 将数据关联到用户,就是把‘顶层’数据关联到用户。 设计思路: 1、修改顶层数据模型,向其中添加‘用户’属性 2、根…

虹科Pico汽车示波器 | 免拆诊断案例 | 2018款东风风神AX7车发动机怠速抖动、加速无力

一、故障现象 一辆2018款东风风神AX7车,搭载10UF01发动机,累计行驶里程约为5.3万km。该车因发动机怠速抖动、加速无力及发动机故障灯异常点亮而进厂维修,维修人员用故障检测仪检测,提示气缸3失火;与其他气缸对调点火线…