基于深度学习的面部表情分类识别系统

news2024/9/26 3:25:35

:温馨提示:文末有 CSDN 平台官方提供的学长 QQ 名片 :) 

1. 项目简介

        面部表情识别是计算机视觉领域的一个重要研究方向, 它在人机交互、心理健康评估、安全监控等领域具有广泛的应用。近年来,随着深度学习技术的快速发展, 面部表情识别的准确性和实时性得到了显著提升。本项目以 MobileNetV2 为基础模型构建面向面部表情识别的卷积神经网络, 完成模型的训练、验证和测试,面部表情识别准确率达到 85%以上。并利用 Flask + Bootstrap 框架搭建交互式分析平台,方便用户进行表情的识别。

        B站视频详情及代码下载:基于深度学习的面部表情分类识别系统_哔哩哔哩_bilibili

基于深度学习的面部表情分类识别系统

2. 面部表情数据集读取与预处理

        利用 opencv 读取面部表情图像数据,并转换为 numpy 数组,图片为灰度图:

def prepare_data(ori_data):
    """
    像素数组转成 numpy array
    """
    image_array = np.zeros(shape=(ori_data.shape[0], img_size, img_size))
    image_label = np.array(list(map(int, ori_data['emotion'])))

    for i, row in ori_data.iterrows():
        image = np.fromstring(row['pixels'], dtype=int, sep=' ')
        image = np.reshape(image, (img_size, img_size))
        image_array[i] = image
        
    return image_array, image_label

         读取的数据集,可视化部分样例数据:

3. 数据集制作与样本均衡处理

        数据集共包含:生气(Angry)、厌恶(Disgust)、恐惧(Fear)、开心(Happy)、悲伤(Sad)、惊讶(Surprise)和中性(Neutral)七种类型,其样本数量分布如下:

        可以看出,样本类别极具不均衡,如不处理样本均衡问题,将影响模型的训练,样本少的类别会得不到充分的学习。

         通过对样本进行采样,并切分出训练集、验证集和测试集:

class_data = [data[data['emotion'] == i] for i in range(7)]

# 对每个类别进行过采样,以匹配平均类别数量
oversampled_data = [resample(class_df, replace=True, n_samples=int(average_class_count), random_state=42) for class_df in class_data]
# 将过采样后的数据合并为一个平衡的数据集
balanced_data = pd.concat(oversampled_data)
# 重置索引并直接在原数据框上进行修改
balanced_data.reset_index(drop=True, inplace=True)

# 准备数据,将特征和标签分离
all_x, all_y = prepare_data(balanced_data)
# 重塑特征数据,以匹配输入形状
all_x = all_x.reshape((all_x.shape[0], img_size, img_size, 1))
# 将标签转换为独热编码格式
all_y = to_categorical(all_y)

# 将数据分为训练集和临时集,其中20%用于测试
x_train, x_temp, y_train, y_temp = train_test_split(all_x, all_y, test_size=0.2)
# 将临时集进一步分为验证集和测试集,各占50%
x_val, x_test, y_val, y_test = train_test_split(x_temp, y_temp, test_size=0.5)

print('训练集:{},验证集:{},测试集:{}'.format(x_train.shape[0], x_val.shape[0], x_test.shape[0]))

        可以看出,七种类型的样本数量已基本均衡,有利于神经网络的训练。

4. 面部表情识别卷积神经网络构建

4.1 MobileNetV2 基础模型

        MobileNetV2 是一种轻量级的深度神经网络模型,由Google在2018年发布,旨在用于移动和边缘设备上的高效图像识别任务。它是MobileNetV1的改进版,继承了其轻量级和高效的特点,并在多个方面进行了优化。以下是MobileNetV2模型的主要特点和结构:

        主要特点:

  1. 深度可分离卷积(Depthwise Separable Convolution): MobileNetV2依然采用了深度可分离卷积来减少模型参数和计算量。这种卷积将标准的卷积分解为两个步骤:深度卷积(depthwise convolution)和逐点卷积(pointwise convolution)。

  2. 线性瓶颈(Linear Bottlenecks): 在MobileNetV2中,作者引入了线性瓶颈的概念,即在网络的最后几层使用了线性激活函数(ReLU6)而不是传统的ReLU,这有助于减少信息的损失。

  3. 倒残差结构(Inverted Residuals): MobileNetV2采用了倒残差结构,即先通过一个逐点卷积扩展维度,然后进行深度卷积,最后再用逐点卷积减少维度。这种结构有助于提高网络的表达能力。

  4. 轻量级: 由于采用了上述结构,MobileNetV2在保持精度的同时大大减少了模型的参数数量和计算量,使其非常适合在资源受限的设备上运行。

4.2 基于迁移学习的卷积神经网络构建

        以 MobileNetV2 为 base 模型,加载利用 ImageNet 大规模数据集预训练的 MobileNetV2 模型权重,构建

base_model = tf.keras.applications.MobileNetV2(
    weights='./pretrained_models/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_96_no_top.h5', 
    include_top=False,
    input_shape=(96,96,3)
)

# Inputlayer
input = tf.keras.layers.Input(name='0_Input',
                            shape=(img_size, img_size, 1))

# Preprocessing stage
x = tf.keras.layers.Resizing(name='1_Preprocessing_1',height = 96, width = 96)(input)
x = tf.keras.layers.Rescaling(name='1_Preprocessing_2',
                            scale = 1/127.0, offset=-1)(x)
x = tf.keras.layers.RandomRotation(name='1_Preprocessing_3',
                                 factor=0.20,
                                 seed=100)(x)
x = tf.keras.layers.RandomFlip(name='1_Preprocessing_4',
                             mode="horizontal",
                             seed=100)(x)

......

# Feature extracting stage
x = base_model(x)
x = tf.keras.layers.Flatten(name='3_Classification_1')(x)

# Classification stage
x = tf.keras.layers.Dense(name='3_Classification_2',
                        units=256,
                        kernel_regularizer=tf.keras.regularizers.l2(l2=regularization_rate),
                        kernel_initializer = 'he_uniform',
                        activation='relu')(x)

......

# Prediction stage
predictions = tf.keras.layers.Dense(name='4_Prediction',
                                  units = 7,
                                  kernel_initializer = 'zeros',
                                  activation=tf.nn.softmax)(x)

model = Model(inputs=input, outputs=predictions)

model.compile(optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate),
            loss='categorical_crossentropy',
            metrics=['acc'])
model.summary()

5. 模型训练与验证

5.1 模型训练 

        利用切分的训练集进行模型的训练,验证集进行模型的验证评估,并保存 val_acc 最高的模型权重:

checkpoint = ModelCheckpoint('save_models/best_model.h5', monitor='val_acc', verbose=1, mode='max',save_best_only=True)
early = EarlyStopping(monitor="val_acc", mode="max",restore_best_weights=True, patience=5)
lrp_reducer = ReduceLROnPlateau(monitor='val_loss', factor=lrp_factor, patience=lrp_patience, verbose=1)

callbacks_list = [checkpoint, early, lrp_reducer]

history = model.fit(
    x_train, y_train, 
    batch_size=batch_size,
    epochs=epochs,
    steps_per_epoch=x_train.shape[0] // batch_size,
    verbose=1,
    callbacks=callbacks_list,
    validation_data=(x_val, y_val),
    validation_steps=x_val.shape[0]//batch_size
)

5.2 测试集预测结果的 AUC 得分与 ROC score 分布

        模型预测测试集的 AUC 得分,并绘制 ROC 曲线:

# 获取疾病标签名称列表
labels = list(emotions.values())

# 创建一个范围,表示 x 轴上每个标签的位置
x = np.arange(len(labels))
# 设置柱状图的宽度
width = 0.80
fig, ax = plt.subplots(figsize=(20, 8), dpi=120)
rects = ax.bar(x, cate_auc, width, color='#EEC900')
ax.set_ylabel('AUC Score', fontsize=20)
ax.set_xlabel('标签', fontsize=20)
ax.set_title('不同类别模型预测 AUC Score 分布', fontsize=30)
ax.set_xticks(x, labels, fontsize=20)
ax.bar_label(rects, padding=3, fontsize=16)
fig.tight_layout()
plt.show()

5.3 困惑矩阵 Confusionmatrix

from matplotlib.colors import LogNorm
import seaborn as sns

true_labels = np.argmax(y_test, axis=1)
predictions = np.argmax(pred_test, axis=1)
conf_matrix = confusion_matrix(true_labels, predictions)

plt.figure(figsize=(10, 8))
sns.heatmap(conf_matrix, annot=True, cmap='GnBu', fmt='g', xticklabels=[emotions[i] for i in range(len(conf_matrix))], yticklabels=[emotions[i] for i in range(len(conf_matrix))], norm=LogNorm())

plt.title('Confusion Matrix')
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.show()

6. 基于深度学习的面部表情分类识别系统

        利用 Flask + Bootstrap 框架搭建响应式布局的交互分析 web 系统,利用 keras load_model 加载训练好的性能最佳的模型,提供标准化 rest api,提供面部表情的在线识别功能。

6.1 系统首页

6.2 面部表情在线识别

        通过上传待测试面部表情图片,提交预测后,后端调用模型进行表情预测,预测结果返回给前端进行渲染可视化,展示预测的标签类别,及各类标签预测的概率分布。

7. 结论

        本项目以 MobileNetV2 为基础模型构建面向面部表情识别的卷积神经网络, 完成模型的训练、验证和测试,面部表情识别准确率达到 85%以上。并利用 Flask + Bootstrap 框架搭建交互式分析平台,方便用户进行表情的识别。

欢迎大家点赞、收藏、关注、评论啦 ,由于篇幅有限,只展示了部分核心代码。技术交流、源码获取认准下方 CSDN 官方提供的学长 QQ 名片 :)

精彩专栏推荐订阅:

1. Python数据挖掘精品实战案例

2. 计算机视觉 CV 精品实战案例

3. 自然语言处理 NLP 精品实战案例

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1975608.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C++篇:入门(2)

引用 引用的概念以及定义: 在C中,引用(Reference)是一个非常重要的概念又可以称之为取别名,它允许我们创建一个已存在对象的别名。引用提供了一种机制,通过它可以直接访问另一个变量、对象或函数的值&#…

Nginx进阶-常见配置(一)

一、nginx Proxy 反向代理 1、代理原理 反向代理产生的背景: 在计算机世界里,由于单个服务器的处理客户端(用户)请求能力有一个极限,当用户的接入请求蜂拥而入时,会造成服务器忙不过来的局面&#xff0c…

【实现100个unity特效之12】Unity中的冲击波 ——如何使用ShaderGraph制作一个冲击波着色器

最终效果 文章目录 最终效果新增LitShaderGraph圆环扭曲效果优化冲击波效果屏幕全屏冲击波圆形冲击波最终连线图代码控制补充源码完结 新增LitShaderGraph 圆环扭曲效果 让我们从一个UV节点开始 创建一个Vector2变量RingSpawnPosition表示冲击波生成位置,在X和Y上将其默认值…

springboot实现前后端调用axios异步请求(后端单体服务器static部分)

目的:让页面调用controller,将数据加载到页面中(只不过这个前端页面我们直接就是放到了static里面了)。 第一步:导入文件 所需要的文件见本文最后“文件获取”: (1)文件如下&…

汇昌联信拼多多运营怎么样?

汇昌联信拼多多运营怎么样?在探讨汇昌联信在拼多多平台的运营情况时,首先需要明确的回答是:汇昌联信在拼多多的运营表现是积极的,并取得了一定的成效。接下来,我们将从五个不同的角度深入分析其运营策略及效果。 一、产品多样性与…

Centos7挂载数据盘

查看当前服务器有哪些磁盘 fdisk -l 2.格式化 mke2fs -t ext4 /dev/vdc 3.挂载数据盘 mkdir /sdxinfang mount /dev/vdc /sdxinfang/ 为了避免每次开机都要重新挂载,直接设置系统挂载信息,这样开机会自动挂载 vim /etc/fstab 在文件末尾增加以下内容&…

Axure八大优质Web端系统框架模版

在当今数字化转型的浪潮中,Axure作为一款强大的原型设计工具,以其快速、直观和易用的特点,成为了众多设计师和产品经理的首选。本文将详细介绍六套基于Axure制作的智慧系统原型框架模版,包括智慧园区、智慧社区、智慧乡村、智慧驾…

4个好用的 CSS 伪类 :not()、:has()、 :is()、:where()

文章目录 (1):not()(2):has()(3):is()(4):where()(5):where()与:is() 的区别 (1):not() :not 伪类:用于选择不满足给定条…

微信小程序开发费用一览表,不同开发方式的费用对比

微信小程序作为当前移动互联网领域的重要入口之一,其开发费用因开发方式、功能需求、设计复杂度及开发团队的不同而有所差异。本文将详细梳理微信小程序开发的几种主要方式,并对比各方式的费用情况,以便企业和个人在选择时能够有更清晰的了解…

PHP 打印 V 和倒 V 图案的程序(Program to print V and inverted-V pattern)

倒 V 型模式:给定 n 的值,打印倒 V 型模式。示例: 输入:n 5 输出 : E D D C C B B A A 输入:n 7 输出 : G F F E E D D C C B B A…

pycharm中安装、使用扩展工具,以QT Designer为例

pycharm中安装、使用扩展工具,以QT Designer为例 第一步,下载QT Designer安装包。找到QT Designer.exe所在位置,复制路径 第二步,打开Pycharm,选择Setting,找到扩展工具(External Tools&#xf…

git回退未commit、回退已commit、回退已push、合并某一次commit到另一个分支

文章目录 1、git回退未commit2、git回退已commit3、git回退已push的代码3.1 直接丢弃某一次的push3.2 撤销push后,不丢弃改动,重新修改后要再次push 4、合并某一次commit到另一个分支 整理几个工作上遇到的git问题。 1、git回退未commit git回退未comm…

【C++】STL-哈希表封装unorder_set和unordered_map

目录 1、实现哈希表的泛型 2、unordered_set和unordered_map的插入 3、迭代器 3.1 operator 3.2 const迭代器 4、find 5、unordered_map的operator[] 6、对于无法取模的类型 7、介绍unordered_set的几个函数 7.1 bucket_count 7.2 bucket_size 7.3 bucket 7.4 rese…

Gcc/G++编译C/C++文件(主要以C++语言为主,C语言就做阐述 用法一样 就是将G++换成GCC)

首先,我们在Linux中创建一个helloc.cc文件(C文件) vim helloc.cc 直接用g裸编译 g helloc.cc 生成的a.out就是二进制可执行文件 如果要产生 自定义可执行文件 就需要下面的编译步骤 繁琐操作 g -c helloc.cc 会生成目标文件 g -o hello helloc.o 此时hell…

仿SOUL社交友附近人婚恋约仿陌陌APP系统源码

专门为单身男女打造的恋爱交友社区,就是一个由千千万万单身男女组建的大家庭。他们来自全国各地,或许有着不同的人生经历,却有着共同的对恋爱交友的渴望。他们可以通过文字、语音、视频聊天的方式,和镜头前的彼此诉说自己工作中发…

95页PPT丨IBM-IT应用规划

一、IBM针对IT应用规划项目核心内容IBM在IT应用规划项目中的核心内容,旨在帮助企业实现数字化转型,优化IT资源配置,并确保IT战略与业务目标的一致性。以下是IBM IT应用规划项目的详细核心内容: 资料下载方式,请看每张…

LabVIEW与CANopen实现自动化生产线的设备控制与数据采集

在某工厂的自动化生产线上,多个设备通过CANopen网络进行通信和控制。这些设备包括传感器、执行器和PLC,它们共同负责监测和控制生产过程中的关键参数,如温度、压力、速度等。为了实现对整个生产线的集中监控和管理,工厂决定使用La…

深入理解同城代驾系统源码:技术架构与实现细节

今天,小编将深入讲解同城代驾系统的技术架构与实现细节。 一、同城代驾系统的基本功能模块 一个完整的同城代驾系统通常包括以下核心功能模块: 1.用户端应用 2.司机端应用 3.后台管理系统 4.消息推送与通知 二、技术架构设计 同城代驾系统的技术架…

程序设计基础(c语言)_补充_1

1、编程应用双层循环输出九九乘法表 #include <stdio.h> #include <stdlib.h> int main() {int i,j;for(i1;i<9;i){for(j1;j<i;j)if(ji)printf("%d*%d%d",j,i,j*i);elseprintf("%d*%d%-2d ",j,i,j*i);printf("\n");}return 0…

DNS处理模块 dnspython

DNS处理模块 dnspython 标题介绍安装dnspython 模块常用方法介绍实践&#xff1a;DNS域名轮询业务监控 标题介绍 Dnspython 是 Python 的 DNS 工具包。它可用于查询、区域传输、动态更新、名称服务器测试和许多其他事情。 dnspython 模块提供了大量的 DNS 处理方法&#xff0c…