传知代码-图神经网络长对话理解(论文复现)

news2024/12/28 5:31:36

代码以及视频讲解

本文所涉及所有资源均在传知代码平台可获取

概述

情感识别是人类对话理解的关键任务。随着多模态数据的概念,如语言、声音和面部表情,任务变得更加具有挑战性。作为典型解决方案,利用全局和局部上下文信息来预测对话中每个单个句子(即话语)的情感标签。具体来说,全局表示可以通过对话级别的跨模态交互建模来捕获。局部表示通常是通过发言者的时间信息或情感转变来推断的,这忽略了话语级别的重要因素。此外,大多数现有方法在统一输入中使用多模态的融合特征,而不利用模态特定的表示。针对这些问题,我们提出了一种名为“关系时序图神经网络与辅助跨模态交互(CORECT)”的新型神经网络框架,它以模态特定的方式有效捕获了对话级别的跨模态交互和话语级别的时序依赖,用于对话理解。大量实验证明了CORECT的有效性,通过在IEMOCAP和CMUMOSEI数据集上取得了多模态ERC任务的最新成果。

模型整体架构

在这里插入图片描述

特征提取

文本采用transformerde方式进行编码
在这里插入图片描述

音频,视频都采用全连接的方式进行编码
在这里插入图片描述

通过添加相应的讲话者嵌入来增强技术增强
在这里插入图片描述

关系时序图卷积网络(RT-GCN)

解读:RT-GCN旨在通过利用话语之间以及话语与其模态之间的多模态图来捕获对话中每个话语的局部上下文信息,关系时序图在一个模块中同时实现了上下文信息,与模态之间的信息的传递。对话中情感识别需要跨模态学习到信息,同时也需要学习上下文的信息,整合成一个模块的作用将两部分并行处理,降低模型的复杂程度,降低训练成本,降低训练难度。

建图方式,模态与模态之间有边相连,对话之间有边相连:

在这里插入图片描述

建图之后,用图transformer融合不同模态,以及不同语句的信息,得到处理之后特征向量:
在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

两两交叉模态特征交互

跨模态的异质性经常提高了分析人类语言的难度。利用跨模态交互可能有助于揭示跨模态之间的“不对齐”特性和长期依赖关系。受到这一思想的启发(Tsai等人,2019),我们将配对的跨模态特征交互(P-CM)方法设计到我们提出的用于对话理解的框架中。

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

线性分类器

最后就是根据提取出来的特征进行情感分类了:

在这里插入图片描述

代码修改

这是对话中多模态情感识别(视觉,音频,文本)在数据集IEMOCAP目前为止的SOTA。在离线系统已经取得了相当不错的表现。(离线系统的意思是,是一段已经录制好的视频,而不是事实录制如线上开会)

但是却存在一个问题,输入的数据是已经给定的一个视频,分析某一句话的情感状态的时候,论文的方法使用了过去的信息,也使用了未来的信息,这样会在工业界实时应用场景存在一定的问题。

比如在开线上会议,需要检测开会双方的情绪,不可能用未来将要说的话预测现在的情绪。因为未来的话都还没被说话者说出来,此时,就不能参考到未来的语句来预测现在语句的情感信息。但是原文的方法在数据结构图的构建的时候,连接上了未来语句和现在语句的边,用图神经网络学习了之间的关联。

因此,修改建图方式,不考虑未来的情感信息,重新训练网络,得到了还可以接受的效果,精度大概在82%左右,原文的精度在84%左右,2%精度的牺牲解决了是否能实时的问题其实是值得的。

演示效果

在这里插入图片描述
在这里插入图片描述

核心逻辑

在这里可以粘贴您的核心代码逻辑:

# start

#模型核心部分

import torch
import torch.nn as nn
import torch.nn.functional as F

from .Classifier import Classifier
from .UnimodalEncoder import UnimodalEncoder
from .CrossmodalNet import CrossmodalNet
from .GraphModel import GraphModel
from .functions import multi_concat, feature_packing
import corect

log = corect.utils.get_logger()

class CORECT(nn.Module):
    def __init__(self, args):
        super(CORECT, self).__init__()

        self.args = args
        self.wp = args.wp
        self.wf = args.wf
        self.modalities = args.modalities
        self.n_modals = len(self.modalities)
        self.use_speaker = args.use_speaker
        g_dim = args.hidden_size
        h_dim = args.hidden_size

        ic_dim = 0
        if not args.no_gnn:
            ic_dim = h_dim * self.n_modals

            if not args.use_graph_transformer and (args.gcn_conv == "gat_gcn" or args.gcn_conv == "gcn_gat"):
                ic_dim = ic_dim * 2

            if args.use_graph_transformer:
                ic_dim *= args.graph_transformer_nheads
        
        if args.use_crossmodal and self.n_modals > 1:
            ic_dim += h_dim * self.n_modals * (self.n_modals - 1)

        if self.args.no_gnn and (not self.args.use_crossmodal or self.n_modals == 1):
            ic_dim = h_dim * self.n_modals

        
        a_dim = args.dataset_embedding_dims[args.dataset]['a']
        t_dim = args.dataset_embedding_dims[args.dataset]['t']
        v_dim = args.dataset_embedding_dims[args.dataset]['v']
        
        dataset_label_dict = {
            "iemocap": {"hap": 0, "sad": 1, "neu": 2, "ang": 3, "exc": 4, "fru": 5},
            "iemocap_4": {"hap": 0, "sad": 1, "neu": 2, "ang": 3},
            "mosei": {"Negative": 0, "Positive": 1},
        }

        dataset_speaker_dict = {
            "iemocap": 2,
            "iemocap_4": 2,
            "mosei":1,
        }
        
        
        tag_size = len(dataset_label_dict[args.dataset])
        self.n_speakers = dataset_speaker_dict[args.dataset]

        self.wp = args.wp
        self.wf = args.wf
        self.device = args.device


        self.encoder = UnimodalEncoder(a_dim, t_dim, v_dim, g_dim, args)
        self.speaker_embedding = nn.Embedding(self.n_speakers, g_dim)

        print(f"{args.dataset} speakers: {self.n_speakers}")
        if not args.no_gnn:
            self.graph_model = GraphModel(g_dim, h_dim, h_dim, self.device, args)
            print('CORECT --> Use GNN')

        if args.use_crossmodal and self.n_modals > 1:
            self.crossmodal = CrossmodalNet(g_dim, args)
            print('CORECT --> Use Crossmodal')
        elif self.n_modals == 1:
            print('CORECT --> Crossmodal not available when number of modalitiy is 1')

        self.clf = Classifier(ic_dim, h_dim, tag_size, args)

        self.rlog = {}


    def represent(self, data):

        # Encoding multimodal feature
        a = data['audio_tensor'] if 'a' in self.modalities else None
        t = data['text_tensor'] if 't' in self.modalities else None
        v = data['visual_tensor'] if 'v' in self.modalities else None

        a, t, v = self.encoder(a, t, v, data['text_len_tensor'])


        # Speaker embedding
        if self.use_speaker:
            emb = self.speaker_embedding(data['speaker_tensor'])
            a = a + emb if a != None else None
            t = t + emb if t != None else None
            v = v + emb if v != None else None

        # Graph construct
        multimodal_features = []

        if a != None:
            multimodal_features.append(a)
        if t != None:
            multimodal_features.append(t)
        if v != None:
            multimodal_features.append(v)

        out_encode = feature_packing(multimodal_features, data['text_len_tensor'])
        out_encode = multi_concat(out_encode, data['text_len_tensor'], self.n_modals)

        out = []

        if not self.args.no_gnn:
            out_graph = self.graph_model(multimodal_features, data['text_len_tensor'])
            out.append(out_graph)


        if self.args.use_crossmodal and self.n_modals > 1:
            out_cr = self.crossmodal(multimodal_features)

            out_cr = out_cr.permute(1, 0, 2)
            lengths = data['text_len_tensor']
            batch_size = lengths.size(0)
            cr_feat = []
            for j in range(batch_size):
                cur_len = lengths[j].item()
                cr_feat.append(out_cr[j,:cur_len])

            cr_feat = torch.cat(cr_feat, dim=0).to(self.device)
            out.append(cr_feat)
        
        if self.args.no_gnn and (not self.args.use_crossmodal or self.n_modals == 1):
            out = out_encode
        else:
            out = torch.cat(out, dim=-1)

        return out

    def forward(self, data):
        graph_out = self.represent(data)
        out = self.clf(graph_out, data["text_len_tensor"])

        return out
    
    def get_loss(self, data):
        graph_out = self.represent(data)
        loss = self.clf.get_loss(
                graph_out, data["label_tensor"], data["text_len_tensor"])
        
        return loss

    def get_log(self):
        return self.rlog


        

#图神经网络
import torch
import torch.nn as nn
from torch_geometric.nn import RGCNConv, TransformerConv

import corect

class GNN(nn.Module):
    def __init__(self, g_dim, h1_dim, h2_dim, num_relations, num_modals, args):
        super(GNN, self).__init__()
        self.args = args

        self.num_modals = num_modals
        
        if args.gcn_conv == "rgcn":
            print("GNN --> Use RGCN")
            self.conv1 = RGCNConv(g_dim, h1_dim, num_relations)

        if args.use_graph_transformer:
            print("GNN --> Use Graph Transformer")
           
            in_dim = h1_dim
                
            self.conv2 = TransformerConv(in_dim, h2_dim, heads=args.graph_transformer_nheads, concat=True)
            self.bn = nn.BatchNorm1d(h2_dim * args.graph_transformer_nheads)
            

    def forward(self, node_features, node_type, edge_index, edge_type):

        if self.args.gcn_conv == "rgcn":
            x = self.conv1(node_features, edge_index, edge_type)
        
        if self.args.use_graph_transformer:
            x = nn.functional.leaky_relu(self.bn(self.conv2(x, edge_index)))
        
        return x

使用方式&部署方式

首先建议安装conda,因为想要复现深度学习的代码,github上不同项目的环境差别太大,同时处理多个项目的时候很麻烦,在这里就不做conda安装的教程了,请自行学习。

安装pytorch:
请到pytorch官网找安装命令,尽量不要直接pip install
https://pytorch.org/get-started/previous-versions/

给大家直接对着我安装版本来下载,因为图神经网络的包版本要求很苛刻,版本对应不上很容易报错:
在这里插入图片描述
在这里插入图片描述

只要环境配置好了,找到这个文件,里面的代码粘贴到终端运行即可
在这里插入图片描述

温馨提示

1.数据集和已训练好的模型都在.md文件中有百度网盘链接,直接下载放到指定文件夹即可
2.注意,训练出来的模型是有硬件要求的,我是用cpu进行训练的,模型只能在cpu跑,如果想在gpu上跑,请进行重新训练
3.如果有朋友希望用苹果的gpu进行训练,虽然现在pytorch框架已经支持mps(mac版本的cuda可以这么理解)训练,但是很遗憾,图神经网络的包还不支持,不过不用担心,这个模型的训练量很小,我全程都是苹果笔记本完成训练的。

源码下载

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1912762.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

2024世界人工智能大会:AI产品技术与未来趋势的深度解析

随着2024年世界人工智能大会(WAIC 2024)在上海的圆满落幕,我们见证了人工智能技术的又一次飞跃。本次大会以“以共商促共享,以善治促善智”为主题,汇聚了全球顶尖的智慧,共同探讨了AI技术的未来趋势和应用前…

妙笔生词智能写歌词软件:创新助力还是艺术之殇?

在音乐创作日益普及和多样化的当下,各种辅助工具层出不穷,妙笔生词智能写歌词软件便是其中之一。那么,它到底表现如何呢? 妙笔生词智能写歌词软件(veve522)的突出优点在于其便捷性和高效性。对于那些灵感稍…

JVM内存泄露的ThreadLocal详解

目录 一、为什么要有ThreadLocal 二、ThreadLocal的使用 三、实现解析 实现分析 具体实现 Hash冲突的解决 开放定址法 链地址法 再哈希法 建立公共溢出区 四、引发的内存泄漏分析 内存泄漏的现象 分析 总结 错误使用ThreadLocal导致线程不安全 一、为什么要有Thr…

Test-Time Adaptation via Conjugate Pseudo-labels--论文笔记

论文笔记 资料 1.代码地址 https://github.com/locuslab/tta_conjugate 2.论文地址 https://arxiv.org/abs/2207.09640 3.数据集地址 论文摘要的翻译 测试时间适应(TTA)指的是使神经网络适应分布变化,在测试时间仅访问来自新领域的未标记测试样本。以前的TT…

【pytorch24】Visdom可视化

TensorboardX pytorch有一个工具借鉴了tensorboard pip install tensorboardX 有查看变量的数值、监听曲线等功能 如何使用 新建SummaryWriter()实例 要把监听的数据,比如说要监听dummy_s1[0](y 坐标)存放到data/scalar1中,…

普中51单片机:中断系统与寄存器解析(六)

文章目录 引言中断流程图中断优先级下降沿中断结构图中断相关寄存器IE中断允许寄存器(可位寻址)XICON辅助中断控制寄存器(可位寻址)TCON标志控制寄存器SCON串行口控制寄存器 中断号中断响应条件中断函数代码模板电路图开发板IO连接…

洁净车间的压缩空气质量如何检测(露点、水油、粒子、浮游菌)

通常一个空压机站的设备即为一个狭义的压缩空气系统,下图为一个典型的压缩空气系统流程图: 气源设备(空气压缩机)吸入大气,将自然状态下的空气压缩成为具有较高压力的压缩空气,经过净化设备除去压缩空气中的…

新手如何正确学习Python?分享我是如何2个月熟练掌握Python的!学习大纲+学习方式+学习资料 汇总!

前言 一直以来都有很多想学习Python的朋友们问我,学Python怎么学?爬虫和数据分析怎么学?web开发的学习路线能教教我吗? 我先告诉大家一个点,不管你是报了什么培训班,还是自己在通过各种渠道自学&#xff…

[C++][ProtoBuf][Proto3语法][三]详细讲解

目录 1.默认值2.更新消息1.更新规则2.保留字段reserved 3.未知字段1.是什么?2.未知字段从哪获取 4.前后兼容性5.选项option1.选项分类2.常用选项列举3.设置自定义选项 1.默认值 反序列化消息时,如果被反序列化的⼆进制序列中不包含某个字段,…

elasticsearch集群模式部署

系统版本:CentOS Linux release 7.9.2009 (Core) es版本: elasticsearch-7.6.2 本次搭建es集群为三个节点 添加启动用户 确保elasticsearch的启动用户为普通用户,这里我创建了es用户用于启动elasticsearch 执行命令为es用户添加sudo权限 v…

数学建模及国赛

认识数学建模及国赛 认识数学建模 环境类:预测一下明天的气温 实证类: 评价一下政策的优缺点 农业类: 预测一下小麦的产量 财经类: 分析一下理财产品的最优组合 规划类: 土地利用情况进行 合理的划分 力学类&#xf…

如何在 CentOS 中配置 Linux 命名空间(ip netns)

引言 Linux 命名空间是一项强大的技术,允许在同一系统上创建多个独立的虚拟化实例,每个实例可以拥有自己的网络栈、路由表、IP 地址等网络资源,实现资源的隔离和管理。本文将深入探讨如何在 CentOS 中配置和使用 ip netns 命名空间&#xff0…

网络安全合规建设

网络安全合规建设 一、法律安全需求基本合规(1)《网络安全法》重要节点等级保护政策核心变化 二、安全需求 业务刚需(1)内忧(2)外患 三、解决方法(1)总安全战略目标图(2&…

PaddleVideo:Squeeze Time算法移植

参考PaddleVideo/docs/zh-CN/contribute/add_new_algorithm.md at develop PaddlePaddle/PaddleVideo GitHubAwesome video understanding toolkits based on PaddlePaddle. It supports video data annotation tools, lightweight RGB and skeleton based action recognitio…

Xilinx FPGA UltraScale SelectIO 接口逻辑资源

目录 1. 简介 2. Bank Overview 2.1 Diagram 2.2 IOB 2.3 Slice 2.4 Byte Group 2.5 I/O bank 示例 2.6 Pin Definition 2.7 数字控制阻抗(DCI) 2.8 SelectIO 管脚供电电压 2.8.1 VCCO 2.8.2 VREF 2.8.3 VCCAUX 2.8.4 VCCAUX_IO 2.8.5 VCCINT_IO 3. 总结 1. 简介…

基于信号量的生产者消费者模型

文章目录 信号量认识概念基于线程分析信号量信号量操作 循环队列下的生产者消费者模型理论认识代码部分 信号量 认识概念 信号量本质: 计数器 它也叫做公共资源 为了线程之间,进程间通信------>多个执行流看到的同一份资源---->多个资源都会并发访问这个资源(此时易出现…

【Qt课设】基于Qt实现的中国象棋

一、摘 要 本报告讨论了中国象棋程序设计的关键技术和方法。首先介绍了中国象棋的棋盘制作,利用Qt中的一些绘画类的函数来进行绘制。在创作中国象棋棋子方面,首先,我们先定义一下棋子类,将棋子中相同的部分进行打包,使…

Python:安装/Mac

之前一直陆陆续续有学python!今天开始!正式开肝!!! 进入网站:可能会有点慢,多开几个网页 https://www.python.org 点击下载,然后进入新的页面,往下滑 来到File&#xff0…

PHP验证日本免费电话号码格式

首先,您需要了解免费电话号码的格式。 日本免费电话也就那么几个号段:0120、0990、0180、0570、0800等开头的,0800稍微特殊点,在手机号里面有080 开头,但是后面不一样了。 关于免费电话号码的划分,全部写…

忘记Apple ID密码怎么退出苹果ID账号?

忘记Apple ID密码怎么退出账号?Apple ID对每个苹果用户来说都是必不可少的,没有它,用户就不能享受iCloud、App Store、iTunes等服务。苹果手机软件下载、丢失解锁、恢复出厂设置等都需要使用Apple ID。如果忘记Apple ID 密码,这会…