【从零开始实现意图识别】中文对话意图识别详解

news2024/9/24 5:26:52

前言

意图识别(Intent Recognition)是自然语言处理(NLP)中的一个重要任务,它旨在确定用户输入的语句中所表达的意图或目的。简单来说,意图识别就是对用户的话语进行语义理解,以便更好地回答用户的问题或提供相关的服务。

在NLP中,意图识别通常被视为一个分类问题,即通过将输入语句分类到预定义的意图类别中来识别其意图。这些类别可以是各种不同的任务、查询、请求等,例如搜索、购买、咨询、命令等。

下面是一个简单的例子来说明意图识别的概念:

用户输入: "我想订一张从北京到上海的机票。

意图识别:预订机票。

在这个例子中,通过将用户输入的语句分类到“预订机票”这个意图类别中,系统可以理解用户的意图并为其提供相关的服务。

意图识别是NLP中的一项重要任务,它可以帮助我们更好地理解用户的需求和意图,从而为用户提供更加智能和高效的服务。

在智能对话任务中,意图识别是一种非常重要的技术,它可以帮助系统理解用户的输入,从而提供更加准确和个性化的回答和服务。

模型

意图识别和槽位填充是对话系统中的基础任务。本仓库实现了一个基于BERT的意图(intent)和槽位(slots)联合预测模块。想法上实际与JoinBERT类似(GitHub:BERT for Joint Intent Classification and Slot Filling),利用 [CLS] token对应的last hidden state去预测整句话的intent,并利用句子tokens的last hidden states做序列标注,找出包含slot values的tokens。你可以自定义自己的意图和槽位标签,并提供自己的数据,通过下述流程训练自己的模型,并在JointIntentSlotDetector类中加载训练好的模型直接进行意图和槽值预测。

源GitHub:https://github.com/Linear95/bert-intent-slot-detector

在本文使用的模型中对数据进行了扩充、对代码进行注释、对部分代码进行了修改

Bert模型下载

Bert模型下载地址:https://huggingface.co/bert-base-chinese/tree/main

下载下方红框内的模型即可。

数据集介绍

训练数据以json格式给出,每条数据包括三个关键词:text表示待检测的文本,intent代表文本的类别标签,slots是文本中包括的所有槽位以及对应的槽值,以字典形式给出。

{

"text": "搜索西红柿的做法。",

"domain": "cookbook",

"intent": "QUERY",

"slots": {"ingredient": "西红柿"}

}

原始数据集:https://conference.cipsc.org.cn/smp2019/

本项目中在原始数据集中新增了部分数据,用来平衡数据。

模型训练

python train.py

# -----------training-------------
max_acc = 0
for epoch in range(args.train_epochs):
    total_loss = 0
    model.train()
    for step, batch in enumerate(train_dataloader):
        input_ids, intent_labels, slot_labels = batch

        outputs = model(
            input_ids=torch.tensor(input_ids).long().to(device),
            intent_labels=torch.tensor(intent_labels).long().to(device),
            slot_labels=torch.tensor(slot_labels).long().to(device)
        )

        loss = outputs['loss']
        total_loss += loss.item()

        if args.gradient_accumulation_steps > 1:
            loss = loss / args.gradient_accumulation_steps

        loss.backward()

        if step % args.gradient_accumulation_steps == 0:
            # 用于对梯度进行裁剪,以防止在神经网络训练过程中出现梯度爆炸的问题。
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

            optimizer.step()
            scheduler.step()
            model.zero_grad()

    train_loss = total_loss / len(train_dataloader)

    dev_acc, intent_avg, slot_avg = dev(model, val_dataloader, device, slot_dict)

    flag = False
    if max_acc < dev_acc:
        max_acc = dev_acc
        flag = True
        save_module(model, model_save_dir)
    print(f"[{epoch}/{args.train_epochs}] train loss: {train_loss}  dev intent_avg: {intent_avg} "
          f"def slot_avg: {slot_avg} save best model: {'*' if flag else ''}")

dev_acc, intent_avg, slot_avg = dev(model, val_dataloader, device, slot_dict)
print("last model dev intent_avg: {} def slot_avg: {}".format(intent_avg, slot_avg))

运行过程:

模型推理

python predict.py
 
def detect(self, text, str_lower_case=True):
    """
    text : list of string, each string is a utterance from user
    """
    list_input = True

    if isinstance(text, str):
        text = [text]
        list_input = False

    if str_lower_case:
        text = [t.lower() for t in text]

    batch_size = len(text)

    inputs = self.tokenizer(text, padding=True)

    with torch.no_grad():
        outputs = self.model(input_ids=torch.tensor(inputs['input_ids']).long().to(self.device))

    intent_logits = outputs['intent_logits']
    slot_logits = outputs['slot_logits']

    intent_probs = torch.softmax(intent_logits, dim=-1).detach().cpu().numpy()
    slot_probs = torch.softmax(slot_logits, dim=-1).detach().cpu().numpy()

    slot_labels = self._predict_slot_labels(slot_probs)
    intent_labels = self._predict_intent_labels(intent_probs)

    slot_values = self._extract_slots_from_labels(inputs['input_ids'], slot_labels, inputs['attention_mask'])

    outputs = [{'text': text[i], 'intent': intent_labels[i], 'slots': slot_values[i]}
               for i in range(batch_size)]

    if not list_input:
        return outputs[0]

    return outputs

推理结果:

模型检测相关代码

将概率值转换为实际标注值

def _predict_slot_labels(self, slot_probs):
    """
    slot_probs : probability of a batch of tokens into slot labels, [batch, seq_len, slot_label_num], numpy array
    """
    slot_ids = np.argmax(slot_probs, axis=-1)
    return self.slot_dict[slot_ids.tolist()]

def _predict_intent_labels(self, intent_probs):
    """
    intent_labels : probability of a batch of intent ids into intent labels, [batch, intent_label_num], numpy array
    """
    intent_ids = np.argmax(intent_probs, axis=-1)
    return self.intent_dict[intent_ids.tolist()]

槽位验证(确保检测结果的正确性)

def _extract_slots_from_labels_for_one_seq(self, input_ids, slot_labels, mask=None):
    results = {}
    unfinished_slots = {}  # dict of {slot_name: slot_value} pairs
    if mask is None:
        mask = [1 for _ in range(len(input_ids))]

    def add_new_slot_value(results, slot_name, slot_value):
        if slot_name == "" or slot_value == "":
            return results
        if slot_name in results:
            results[slot_name].append(slot_value)
        else:
            results[slot_name] = [slot_value]
        return results

    for i, slot_label in enumerate(slot_labels):
        if mask[i] == 0:
            continue
        # 检测槽位的第一字符(B_)开头
        if slot_label[:2] == 'B_':
            slot_name = slot_label[2:]  # 槽位名称 (B_ 后面)
            if slot_name in unfinished_slots:
                results = add_new_slot_value(results, slot_name, unfinished_slots[slot_name])
            unfinished_slots[slot_name] = self.tokenizer.decode(input_ids[i])
        # 检测槽位的后面字符(I_)开头
        elif slot_label[:2] == 'I_':
            slot_name = slot_label[2:]
            if slot_name in unfinished_slots and len(unfinished_slots[slot_name]) > 0:
                unfinished_slots[slot_name] += self.tokenizer.decode(input_ids[i])

    for slot_name, slot_value in unfinished_slots.items():
        if len(slot_value) > 0:
            results = add_new_slot_value(results, slot_name, slot_value)

    return results

源码获取

NLP/bert-intent-slot at main · mzc421/NLP (github.com)icon-default.png?t=N7T8https://github.com/mzc421/NLP/tree/main/bert-intent-slot

链接作者

欢迎关注我的公众号:@AI算法与电子竞赛

硬性的标准其实限制不了无限可能的我们,所以啊!少年们加油吧!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1244176.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

2020年12月 Scratch(三级)真题解析#中国电子学会#全国青少年软件编程等级考试

Scratch等级考试(1~4级)全部真题・点这里 一、单选题(共25题,每题2分,共50分) 第1题 关于广播消息,以下说法正确的是? A:只有角色,可以通过“广播消息”积木,向其他角色或是背景发送消息 B:只有背景,可以通过“广播消息”积木,向其他角色或是背景发送消息 C:背…

远端WWW服务支持TRACE请求

安全扫描的时候&#xff0c;扫出来的问题&#xff0c;这里不分享如何处理&#xff0c;就只分享下&#xff0c;如何找到有问题的端口。 通过命令 curl -v -X TRACE -I ip:port&#xff0c;这里的ip和端口就是扫描出有问题的服务器地址ip以及开放的服务端口。 观察返回值&#x…

建设数字工厂管理系统对企业来说有哪些优势

随着科技的飞速发展&#xff0c;数字化转型已成为企业持续发展的必由之路。在这一背景下&#xff0c;建设数字工厂管理系统显得尤为重要。本文将详细分析数字工厂管理系统给企业带来的优势&#xff0c;以及企业如何选择合适的管理系统和成功实施数字化转型。 一、数字工厂管理系…

【华为数通HCIP | 网络工程师】821-IGP高频题、易错题之OSPF(1)

个人名片&#xff1a; &#x1f43c;作者简介&#xff1a;一名大三在校生&#xff0c;喜欢AI编程&#x1f38b; &#x1f43b;‍❄️个人主页&#x1f947;&#xff1a;落798. &#x1f43c;个人WeChat&#xff1a;hmmwx53 &#x1f54a;️系列专栏&#xff1a;&#x1f5bc;️…

【AIGC重塑教育】AI大爆发的时代,未来的年轻人怎样获得机会和竞争力?

目录 AI浪潮来袭 AI与教育 AI的优势 延伸阅读 推荐语 ​作者&#xff1a;刘文勇 来源&#xff1a;IT阅读排行榜 本文摘编自《AIGC重塑教育&#xff1a;AI大模型驱动的教育变革与实践》&#xff0c;机械工业出版社出版 AI浪潮来袭 这次&#xff0c;狼真的来了。 AI正迅猛地…

【c语言】重温一下动态内存,int数组过大会造成栈错误

项目场景&#xff1a; 项目场景&#xff1a;互助群同学在刷题的过程中&#xff0c;遇到的一个题目&#xff0c;需要申请一个很大数组&#xff0c;于是这个同学就写了int[1000000],其实这样写也没有错&#xff0c;可是运行后却显示栈错误。于是就找到我来请教&#xff0c;我想就…

mapTR环境配置和代码复现

MAPTR: STRUCTURED MODELING AND LEARNING FOR ONLINE VECTORIZED HD MAP CONSTRUCTION 论文 :https://arxiv.org/pdf/2208.14437.pdf 代码:https://github.com/hustvl/MapTR MapTR,是一个结构化的端到端框架,用于高效的在线矢量化高精地图构建。我们提出了一种基于统一…

Linux开发工具(含gdb调试教程)

文章目录 Linux开发工具&#xff08;含gdb调试教程&#xff09;1、Linux 软件包管理器 yum2、Linux开发工具2.1、Linux编辑器 -- vim的使用2.1.1、vim的基本概念2.1.2、vim的基本操作2.1.3、vim正常模式命令集2.1.4、vim末行模式命令集 2.2、vim简单配置 3、Linux编译器 -- gcc…

数仓成本下降近一半,StarRocks 存算分离助力云览科技业务出海

成都云览科技有限公司倾力打造了凤凰浏览器&#xff0c;专注于为海外用户提供服务&#xff0c;公司致力于构建一个全球性的数字内容连接入口&#xff0c;为用户带来更为优质、高效、个性化的浏览体验。 作为数据驱动的高科技公司&#xff0c;从数据中挖掘价值一直是公司核心任务…

佳易王羽毛球馆计时计费软件灯控系统安装教程

佳易王羽毛球馆计时计费软件灯控系统安装教程 佳易王羽毛球馆计时计费软件&#xff0c;点击开始计时的时候&#xff0c;自动打开灯&#xff0c;结账后自动关闭灯。 因为场馆每一场地的灯功率都很大&#xff0c;需要加装交流接触器。这个由专业电工施工。 1、计时计费功能 &…

玻色量子“揭秘”之多项式回归问题与QUBO建模

摘要&#xff1a;多项式回归&#xff08;Polynomial Regression&#xff09;是一种回归分析方法&#xff0c;通过拟合一个多项式方程来模拟自变量与因变量之间的非线性关系。多项式回归的目标是找到一组多项式系数&#xff0c;使得拟合曲线尽可能地接近数据点。这种方法可以用于…

Arm64版本的centos编译muduo库遇到的问题的归纳

环境&#xff1a;Mac m2 pro下的VMware虚拟机中Arm64 centos ./build.sh 执行后提示如下 cmake -DCMAKE_BUILD_TYPErelease -DCMAKE_INSTALL_PREFIX…/release-install-cpp11 -DCMAKE_EXPORT_COMPILE_COMMANDSON /root/package/muduo-master – Boost version: 1.69.0 – Co…

12.docker的网络-host模式

1.docker的host网络模式简介 host模式下&#xff0c;容器将不会虚拟出自己的网卡、配置IP等&#xff0c;而是使用宿主机的IP和端口&#xff1b;也就说&#xff0c;宿主机的就是我的。 2. 以host网络模式创建容器 2.1 创建容器 我们仍然以tomcat这个镜像来说明一下。我们以h…

Java之《ATM自动取款机》(面向对象)

《JAVA编程基础》项目说明 一、项目名称&#xff1a; 基于JAVA控制台版本银行自动取款机 项目要求&#xff1a; 实现银行自动取款机的以下基本操作功能&#xff1a;读卡、取款、查询。&#xff08;自动取款机中转账、修改密码不作要求&#xff09; 具体要求&#xff1a; 读卡…

CSS实现三角形

CSS实现三角形 前言第一种:bordertransparent第二种borderrgb使用unicode字符 前言 本文讲解三种实现三角形的方式&#xff0c;并且配有图文以及代码解说。那么好&#xff0c;本文正式开始。 第一种:bordertransparent border是边框&#xff0c;而transparent是透明的颜色&a…

计数排序+桶排序+基数排序 详讲(思路+图解+代码详解)

文章目录 计数排序桶排序基数排序一、计数排序概念&#xff1a;写法一&#xff1a;写法二&#xff1a; 二、桶排序概念代码 三、基数排序概念1.LSD排序法&#xff08;最低位优先法&#xff09;2.MSD排序法&#xff08;最高位优先法&#xff09; 基数排序VS基数排序VS桶排序 计数…

C++语法知识点-vector+子数组

C语法知识点-vector子数组 一维数组定义无参数有参数迭代器扩容操作reserve 二维数组 vector 定义创建m*n的二维vectorvector< vector<int> > v(m, vector<int>(n) ) 初始化定义vector常用函数的实例分析访问操作resize 函数push _back ( )pop_back()函数siz…

宣传技能培训1——《新闻摄影技巧》光影魔法:理解不同光线、角度、构图的摄影效果,以及相机实战操作 + 新闻摄影实例讲解

新闻摄影技巧 写在最前面摘要 构图与拍摄角度景别人物表情与叙事远景与特写 构图与拍摄角度案例 主体、陪体、前景、背景强调主体利用前景和背景层次感的创造 探索新闻摄影中的构图技巧基本构图技巧构图技巧的应用实例实例分析1. 黄金分割和九宫格2. 三角型构图3. 引导线构图4.…

2 使用React构造前端应用

文章目录 简单了解React和Node搭建开发环境React框架JavaScript客户端ChallengeComponent组件的主要结构渲染与应用程序集成 第一次运行前端调试将CORS配置添加到Spring Boot应用使用应用程序部署React应用程序小结 前端代码可从这里下载&#xff1a; 前端示例 后端使用这里介…