CENet及多模态情感计算实战(论文复现)

news2024/9/20 20:26:31

CENet及多模态情感计算实战(论文复现)

本文所涉及所有资源均在传知代码平台可获取

文章目录

    • CENet及多模态情感计算实战(论文复现)
      • 概述
      • 研究背景
      • 主要贡献
      • 论文思路
      • 主要内容和网络架构
      • 数据集介绍
      • 性能对比
      • 复现过程(重要)
      • 演示结果

概述

本文对 “Cross-Modal Enhancement Network for Multimodal Sentiment Analysis” 论文进行讲解和手把手复现教学,解决当下热门的多模态情感计算问题,并展示在MOSI和MOSEI两个数据集上的效果

研究背景

情感分析在人工智能向情感智能发展中起着重要作用。早期的情感分析研究主要集中在分析单模态数据上,包括文本情感分析、图像情感分析、音频情感分析等。然而,人类的情感是通过人体的多种感官来传达的。因此,单模态情感分析忽略了人类情感的多维性。相比之下,多模态情感分析通过结合文本、视觉和音频等多模态数据来推断一个人的情感状态。与单模态情感分析相比,多模态数据包含多样化的情感信息,具有更高的预测精度。目前,多模态情感分析已被广泛应用于视频理解、人机交互、政治活动等领域。近年来,随着互联网和各种多媒体平台的快速发展,通过互联网表达情感的载体和方式也变得越来越多样化。这导致了多媒体数据的快速增长,为多模态情感分析提供了大量的数据源。下图展示了多模态在情感计算任务中的优势。

在这里插入图片描述

主要贡献

  • 提出了一种跨模态增强网络,通过融入长范围非文本情感语境来增强预训练语言模型中的文本表示;
  • 提出一种特征转换策略,通过减小文本模态和非文本模态的初始表示之间的分布差异,促进了不同模态的融合;
  • 融合了新的预训练语言模型SentiLARE来提高模型对情感词的提取效率,从而提升对情感计算的准确性。

论文思路

作者提出的跨模态增强网络(CENet)模型通过将视觉和声学信息集成到语言模型中来增强文本表示。在基于transformer的预训练语言模型中嵌入跨模态增强(CE)模块,根据非对齐非文本数据中隐含的长程情感线索增强每个单词的表示。此外,针对声学模态和视觉模态,提出了一种特征转换策略,以减少语言模态和非语言模态的初始表示之间的分布差异,从而促进不同模态的融合。

主要内容和网络架构

在这里插入图片描述

通过该图,我们可以看出该模型主要有以下几部分组成:

  • 非文本模态特征转化
  • 跨模态增强
  • 预训练语言模型输出
  • 非文本模态转换

针对预训练语言模型,初始文本表示是基于词汇表的单词索引序列,而视觉和听觉的表示则是实值向量序列。为了缩小这些异质模态之间的初始分布差异,进而减少在融合过程中非文本特征和文本特征之间的分布差距,本文提出了一种将非文本向量转换为索引的特征转换策略。这种策略有助于促进文本表征与非文本情感语境的有机融合。

具体而言,特征转换策略利用无监督聚类算法分别构建了“声学词汇表”和“视觉词汇表”。通过查询这些非语言词汇表,可以将原始的非语言特征序列转换为索引序列。下图展示了特征转换过程的具体步骤。考虑到k-means方法具有计算复杂度低和实现简单等优点,作者选择使用k-means算法来学习非语言模态的词汇。

在这里插入图片描述

  • 跨模态增强模块

本文提出的CE模块旨在将长程视觉和声学信息集成到预训练语言模型中,以增强文本的表示能力。CE模块的核心组件是跨模态嵌入单元,其结构如下图所示。该单元利用跨模态注意力机制捕捉长程非文本情感信息,并生成基于文本的非语言嵌入。嵌入层的参数可学习,用于将经过特征转换策略处理后得到的非文本索引向量映射到高维空间,然后生成文本模态对非文本模态的注意力权重矩阵。

在这里插入图片描述

在初始训练阶段,由于语言表征和非语言表征处于不同的特征空间,它们之间的相关性通常较低。因此,注意力权重矩阵中的元素可能较小。为了更有效地学习模型参数,研究者在应用softmax之前使用超参数来缩放这些注意力权重矩阵。

基于注意力权重矩阵,可以生成基于文本的非语言向量。将基于文本的声学嵌入和基于文本的视觉嵌入结合起来,形成非语言增强嵌入。最后,通过整合非语言增强嵌入来更新文本的表示。因此,CE模块的提出旨在为文本提供非语言上下文信息,通过增加非语言增强嵌入来调整文本表示,从而使其在语义上更加准确和丰富。

  • 预训练语言模型输出

作者采用SentiLARE作为语言模型,其利用包括词性和单词情感极性在内的语言知识来学习情感感知的语言表示。CE模块被集成到预训练语言模型的第i层中。值得注意的是,任何基于Transformer的预训练语言模型都可以与CE模块集成。下面是作者根据SentiLARE的设置进行的步骤:

  1. 给定一个单词序列,首先通过Stanford Log-Linear词性(POS)标记器学习其词性序列,并通过SentiwordNet学习单词级情感极性序列。
  2. 然后,使用预训练语言模型的分词器获取词标索引序列。这个序列作为输入,产生一个初步的增强语言知识表示。
  3. 更新后的文本表示将作为第(i+1)层的输入,并通过SentiLARE中的剩余层进行处理。
  4. 每一层的输出将是具有视觉和听觉信息的文本主导的高级情感表示。
  5. 最后,将这些文本表示输入到分类头中,以获取情感强度。

因此,CE模块通过将非语言增强嵌入集成到预训练语言模型中,有助于生成更富有情感感知的语言表示。这种方法能够在文本表示中有效地整合视觉和听觉信息,从而提升情感分析等任务的性能。

数据集介绍

1. CMU-MOSI: 它是一个多模态数据集,包括文本、视觉和声学模态。它来自Youtube上的93个电影评论视频。这些视频被剪辑成2199个片段。每个片段都标注了[-3,3]范围内的情感强度。该数据集分为三个部分,训练集(1,284段)、验证集(229段)和测试集(686段)。

2. CMU-MOSEI: 它类似于CMU-MOSI,但规模更大。它包含了来自在线视频网站的23,453个注释视频片段,涵盖了250个不同的主题和1000个不同的演讲者。CMU-MOSEI中的样本被标记为[-3,3]范围内的情感强度和6种基本情绪。因此,CMU-MOSEI可用于情感分析和情感识别任务。

性能对比

有下图可以观察到,该论文提出的CENet与其他SOTA模型对比性能有明显提升:

在这里插入图片描述

复现过程(重要)

1. 数据集准备
下载MOSI和MOSEI数据集已提取好的特征文件(.pkl)。把它放在"./dataset”目录。

2. 下载预训练语言模型
下载SentiLARE语言模型文件,然后将它们放入"/pretrained-model / sentilare_model”目录。

3. 下载需要的包

pip install -r requirements.txt

4. 搭建CENet模块
利用pytorch框架对CENet模块进行搭建:

class CE(nn.Module):
    def __init__(self, beta_shift_a=0.5, beta_shift_v=0.5, dropout_prob=0.2):
        super(CE, self).__init__()
        self.visual_embedding = nn.Embedding(label_size + 1, TEXT_DIM, padding_idx=label_size)
        self.acoustic_embedding = nn.Embedding(label_size + 1, TEXT_DIM, padding_idx=label_size)
        self.hv = SelfAttention(TEXT_DIM)
        self.ha = SelfAttention(TEXT_DIM)
        self.cat_connect = nn.Linear(2 * TEXT_DIM, TEXT_DIM)
        

    def forward(self, text_embedding, visual=None, acoustic=None, visual_ids=None, acoustic_ids=None):
        visual_ = self.visual_embedding(visual_ids)
        acoustic_ = self.acoustic_embedding(acoustic_ids)
        visual_ = self.hv(text_embedding, visual_)
        acoustic_ = self.ha(text_embedding, acoustic_) 
        visual_acoustic = torch.cat((visual_, acoustic_), dim=-1)
        shift = self.cat_connect(visual_acoustic)
        embedding_shift = shift + text_embedding
    
        return embedding_shift

5. 将CE模块与预训练语言模型融合

class BertEncoder(nn.Module):
    def __init__(self, config):
        super(BertEncoder, self).__init__()
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
        self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
        self.CE = CE()

    def forward(self, hidden_states, visual=None, acoustic=None, visual_ids=None, acoustic_ids=None, attention_mask=None, head_mask=None):
        all_hidden_states = ()
        all_attentions = ()
        for i, layer_module in enumerate(self.layer):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if i == ROBERTA_INJECTION_INDEX:
                hidden_states = self.CE(hidden_states, visual=visual, acoustic=acoustic, visual_ids=visual_ids, acoustic_ids=acoustic_ids)

6. 训练代码编写

定义一个整体的训练过程 train()函数,它负责训练模型多个 epoch,并在每个 epoch 结束后评估模型在验证集和测试集上的性能,并记录相关的指标和损失;并在训练最后一轮输出所有测试集id,true label 和 predicted label。

def train(
    args,
    model,
    train_dataloader,
    validation_dataloader,
    test_data_loader,
    optimizer,
    scheduler,
):
    valid_losses = []
    test_accuracies = []
    for epoch_i in range(int(args.n_epochs)):
        train_loss, train_pre, train_label = train_epoch(args, model, train_dataloader, optimizer, scheduler)
        valid_loss, valid_pre, valid_label = evaluate_epoch(args, model, validation_dataloader)

        test_loss, test_pre, test_label = evaluate_epoch(args, model, test_data_loader)
        train_acc, train_mae, train_corr, train_f_score = score_model(train_pre, train_label)
        test_acc, test_mae, test_corr, test_f_score = score_model(test_pre, test_label)
        non0_test_acc, _, _, non0_test_f_score = score_model(test_pre, test_label, use_zero=True)
        valid_acc, valid_mae, valid_corr, valid_f_score = score_model(valid_pre, valid_label)

        print(
            "epoch:{}, train_loss:{}, train_acc:{}, valid_loss:{}, valid_acc:{}, test_loss:{}, test_acc:{}".format(
                epoch_i, train_loss, train_acc, valid_loss, valid_acc, test_loss, test_acc
            )
        )
        valid_losses.append(valid_loss)
        test_accuracies.append(test_acc)
        wandb.log(
            (
                {
                    "train_loss": train_loss,
                    "valid_loss": valid_loss,
                    "train_acc": train_acc,
                    "train_corr": train_corr,
                    "valid_acc": valid_acc,
                    "valid_corr": valid_corr,
                    "test_loss": test_loss,
                    "test_acc": test_acc,
                    "test_mae": test_mae,
                    "test_corr": test_corr,
                    "test_f_score": test_f_score,
                    "non0_test_acc": non0_test_acc,
                    "non0_test_f_score": non0_test_f_score,
                    "best_valid_loss": min(valid_losses),
                    "best_test_acc": max(test_accuracies),
                }
            )
        )

    # 输出测试集的 id、真实标签和预测标签
    with torch.no_grad():
        for step, batch in enumerate(test_data_loader):
            batch = tuple(t.to(DEVICE) for t in batch)
            input_ids, visual_ids, acoustic_ids, pos_ids, senti_ids, polarity_ids, visual, acoustic, input_mask, segment_ids, label_ids = batch
            visual = torch.squeeze(visual, 1)
            outputs = model(
                input_ids,
                visual,
                acoustic,
                visual_ids,
                acoustic_ids,
                pos_ids, senti_ids, polarity_ids,
                token_type_ids=segment_ids,
                attention_mask=input_mask,
                labels=None,
            )
            logits = outputs[0]
            logits = logits.detach().cpu().numpy()
            label_ids = label_ids.detach().cpu().numpy()
            logits = np.squeeze(logits).tolist()
            label_ids = np.squeeze(label_ids).tolist()
            
            # 假设您从 label_ids 中获取 ids
            ids = [f"sample_{idx}" for idx in range(len(label_ids))]  # 这里是示例,您可以根据实际情况生成合适的 ids
            
            # 输出所有测试样本的 id、真实标签和预测标签值
            for i in range(len(ids)):
                print(f"id: {ids[i]}, true label: {label_ids[i]}, predicted label: {logits[i]}")

6. 开始训练+测试

python train.py

7. 输出结果
此外,为了方便直观地查看模型性能,我在最后一层训练结束后将所有测试集视频的clip id、真实标签和预测标签依次进行输出;并且结合wandb库自动保存结果可视化;结果在后续章节展示。

演示结果

训练过程

在这里插入图片描述

模型性能结果

在这里插入图片描述

接下来是我们自己补充的每个测试集的真实标签和预测标签

在这里插入图片描述

可视化

在这里插入图片描述

文章代码资源点击附件获取

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2118440.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

数据集成在搭建“智慧校园”中的使用

智慧校园是一种新型校园数字化建设方式,目前被全国高校使用。 智慧校园利用现代信息技术,如物联网、云计算、大数据等,搭建集成硬件设施和数据平台,实现校园管理高度信息化,提高校园管理的效率。 如何搭建“智慧校园…

SPI驱动学习五(如何编写SPI设备驱动程序)

目录 一、SPI驱动程序框架二、怎么编写SPI设备驱动程序1. 编写设备树2. 注册spi_driver3. 怎么发起SPI传输3.1 接口函数3.2 函数解析 三、示例1:编写SPI_DAC模块驱动程序1. 要做什么事情2. 硬件2.1 原理图2.2 连接 3. 编写设备树4. 编写驱动程序5. 编写app层操作程序…

OpenHarmony鸿蒙开发( Beta5.0)无感配网详解

1、简介 无感配网是指在设备联网过程中无需输入热点相关账号信息,即可快速实现设备配网,是一种兼顾高效性、可靠性和安全性的配网方式。2、配网原理 2.1 通信原理 手机和智能设备之间的信息传递,利用特有的NAN协议实现。利用手机和智能设备…

Linux 网络编程 --- 应用层

一、自定义协议和序列化反序列化 代码: 序列化反序列化实现网络版本计算器 二、HTTP协议 1、谈两个简单的预备知识 https://www.baidu.com/ --- 域名 --- 域名解析 --- IP地址 http的端口号为80端口,https的端口号为443 url为统一资源定位符。CSDN…

Java+Swing图书管理系统

JavaSwing图书管理系统 一、系统介绍二、功能展示1.管理员登陆2.图书查询3.图书入库4.借书5.还书6.图书证管理 三、系统实现1.BookManageMainFrame.java 四、其它1.其他系统实现 一、系统介绍 该系统实现了用户端实现书籍查询,借书,还书功能。用户能够查…

QLib学习

开源地址:GitHub - microsoft/qlib: Qlib is an AI-oriented quantitative investment platform that aims to realize the potential, empower research, and create value using AI technologies in quantitative investment, from exploring ideas to implementi…

为何在106短信群发前需完成实名认证?

为了严格遵循国家法律法规及电信运营商的规范要求,确保您及贵公司的合法权益得到充分保障,使用云通信服务进行106短信群发前,实名认证成为必要步骤。 根据《中华人民共和国网络安全法》第二十四条的明确规定: 法律强制性要求&…

vulhub think PHP 2-rce远程命令执行漏洞

1.开启环境 2。访问对应网站端口 3.这里我们直接构造payload,访问phpinfo() http://192.168.159.149:8080/?s/Index/index/L/${phpinfo()} 4.可以访问到我们的phpinfo, 所以写入一句话木马,也可使用蚁剑进行连接,获得其shell进…

GenAI 用于客户支持 — 第 4 部分:调整 RAG 搜索的相关性

作者:来自 Elastic Antonio Schnmann 欢迎阅读我们关于将生成式 AI 集成到 Elastic 客户支持的博客系列的第 4 部分。本期深入探讨了检索增强生成 (RAG) 在增强我们 AI 驱动的技术支持助理方面的作用。在这里,我们解决了改进搜索效果的挑战、解决方案和结…

【C++】windwos下vscode多文件项目创建、编译、运行

目录 🌕vscode多文件项目创建方法🌙具体案例⭐命令行创建项目名,并在vscode中打开项目⭐创建include目录和头文件⭐创建src目录和cpp文件⭐根目录下创建main.cpp 🌕运行项目失败(找不到include目录下的头文件和src目录…

实习生进了公司不会开发?一个真实业务教会你如何分析业务!

一、权限业务介绍 本业务为个人试用期的时候的真实企业级业务,进行了简单的替换词和业务内容模糊,但是业务逻辑没有变化,值得学习,代码部分可能因为替换词有部分误解,发现问题请评论区提醒我。 这个业务教会了我&…

最新前端开发VSCode高效实用插件推荐清单

在此进行总结归类工作中用到的比较实用的、有助于提升开发效率的VSCode插件。大家有其他的好插件推荐的也欢迎留言评论区哦😄 基础增强 Chinese (Simplified) Language Pack: 提供中文界面。 Code Spell Checker: 检查代码中的拼写错误。 ESLint: 集成 ESLint&…

AIoTedge边缘计算+边缘物联网平台

在数字化转型的浪潮中,AIoTedge边缘计算平台以其边云协同的架构和强大的分布式AIoT处理能力,正成为推动智能技术发展的关键力量。AIoTedge通过在数据源附近处理信息,实现低延迟、快速响应,增强了应用的实时性。同时,它…

CCF推荐C类会议和期刊总结:(计算机网络领域)

CCF推荐C类会议和期刊总结(计算机网络领域) 在计算机网络领域,中国计算机学会(CCF)推荐的C类会议和期刊为研究者提供了广泛的学术交流平台。以下是对所有C类会议和期刊的总结,包括全称、出版社、dblp文献网…

DMDPC单副本集群安装

1. 环境描述 2. 部署步骤 2.1. 安装DM数据库软件启动DMAP [dmdbalei1 ~]$ DmAPService status DmAPService (pid 1269) is running.2.2. 初始化数据库实例 [dmdbalei1 data]$ dminit path/dmdba/data/sp1 instance_nameSP1 port_num5236 ap_port_num6000 dpc_modeSP initdb …

野火霸天虎V2学习记录

文章目录 嵌入式开发常识汇总1、嵌入式Linux和stm32之间的区别和联系2、stm32程序下载方式3、Keil5安装芯片包4、芯片封装种类5、STM32命名6、数据手册和参考手册7、什么是寄存器、寄存器映射和内存映射8、芯片引脚顺序9、stm32芯片里有什么10、存储器空间的划分11、如何理解寄…

【前端学习】AntV G6-08 深入图形与图形分组、自定义节点、节点动画(下)

【课程链接】 AntV G6:深入图形与图形分组、自定义节点、节点动画(下)_哔哩哔哩_bilibili 本章十吾老师讲解了一个复杂的自定义节点中,应该怎样去计算和绘制图形,如何给一个图形制作不间断的动画,以及在鼠…

长芯微国产LS0104电平转换器/电平移位器P2P替代TXS0104

LS0104这个4位非反相转换器是一个双向电压电平转换器,可用于建立数字混合电压系统之间的切换兼容性。它使用两个单独的可配置电源轨支持1.65V至5.5V工作电压的端口同时它跟踪VCCA电源和支持的B端口工作电压从2.3V到5.5V,同时跟踪VCCB供应。这允许更低和更…

【C++】stackqueuedeque

目录 一、stack: 1、栈的介绍: 2、栈的使用: 3、栈的成员函数的模拟实现: 二、queue: 1、队列的介绍: 2、队列的使用: 3、队列的模拟实现: 三、deque: deque的底层结构&am…

电脑丢失msvcp120.dll问题原因及详细介绍msvcp120.dll丢失的解决方法

在使用 Windows 操作系统时,你可能会遇到一个常见的错误提示:“程序无法启动,因为计算机中丢失 msvcp120.dll”。这个错误通常意味着你的系统中缺少了一个重要的动态链接库文件,即msvcp120.dll。这个文件是 Microsoft Visual C 20…