对话模型Demo解读(使用代码解读原理)

news2024/12/29 17:15:44

文章目录

  • 前言
  • 一、数据加工
  • 二、模型搭建
  • 三、模型训练
    • 1、构建模型
    • 2、优化器与损失函数定义
    • 3、模型训练
  • 四、模型推理
  • 五、所有Demo源码


前言

对话模型是一种人工智能技术,旨在使计算机能够像人类一样进行对话和交流。这种模型通常基于深度学习和自然语言处理技术,能够理解自然语言并做出相应的回应。然而现有博客很少介绍对话模型内容,也很少用一个简单代码带领大家理解其原理。因此,我创建一个简单的对话模型,在不适用Hugging Face或LSTM结构,旨在使用一个简单的全连接神经网络来实现这个模型,且代码基于PyTorch框架搭建,意在帮助读者构建对话模型知识。当然,模型仅是一个简单模型,旨在帮助理解原理,不具备很好效果能力。


一、数据加工

文本数据最终都是转为对应字典索引代表其文本内容,输入模型加工,实现nlp任务,对话模型也不列外。因此,我们需要构建一个字典映射(可参考:点击这里)与文本数据,并按照字典映射转换为对应索引id,其代码如下:

# 定义一个简单的对话数据集
data = [
    ("hi", "hello"),
    ("how are you?", "I'm fine, thank you."),
    ("what's your name?", "I'm a chatbot.")
]

# 构建词汇表
vocab = list(set(" ".join([x[0] + " " + x[1] for x in data])))
vocab.append("<SOS>")
vocab.append("<EOS>")

word_to_idx = {word: i for i, word in enumerate(vocab)}
idx_to_word = {i: word for i, word in enumerate(vocab)}

# 将对话数据集转换为索引序列
def to_idx_seq(sentence):
    return [word_to_idx[word] for word in sentence]

data_x = [to_idx_seq(x[0]) for x in data]
data_y = [to_idx_seq(x[1]) for x in data]

data_y =[[word_to_idx["<SOS>"]]+list(x)+[word_to_idx["<EOS>"]] for x in data_y]


其字典内容如下:
在这里插入图片描述

二、模型搭建

这里,也是最重要内容,如何搭建对话模型,我是使用transformer结构搭建(之前模型使用LSTM模型搭建),创建一个简单的对话生成模型,也使用一个基于全连接层的神经网络来实现这个模型,其代码如下:


# 定义一个简单的Transformer生成式对话模型
class ChatbotTransformer(nn.Module):
    def __init__(self, input_dim, output_dim, nhead, num_encoder_layers, num_decoder_layers):
        super(ChatbotTransformer, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.embedding = nn.Embedding(len(vocab), input_dim)
        self.transformer = nn.Transformer(d_model=input_dim, nhead=nhead, num_encoder_layers=num_encoder_layers, num_decoder_layers=num_decoder_layers)
        self.fc = nn.Linear(input_dim, output_dim)

    def forward(self, src, tgt):
        src = self.embedding(src).permute(1, 0, 2)  # 调整输入张量维度
        tgt = self.embedding(tgt).permute(1, 0, 2)  # 调整输入张量维度
        output = self.transformer(src, tgt)
        output = self.fc(output)
        return output

从代码上看,我们需要输入src为前面提问数据,而答案生成输入,是每个tgt字输入,按照顺序输出预测。

三、模型训练

1、构建模型

我大概试了一下,使用更多层对于简单数据反而效果不佳,我的数据又比较简单,我构建了较少的层来预测模型。其模型构建代码如下:

# 创建模型实例
# model = ChatbotTransformer(input_dim=256, output_dim=len(vocab), nhead=8, num_encoder_layers=6, num_decoder_layers=6)
model = ChatbotTransformer(input_dim=16, output_dim=len(vocab), nhead=8, num_encoder_layers=1, num_decoder_layers=1)

2、优化器与损失函数定义

优化器定义我将不考虑介绍,只说下文本预测是交叉熵方式,实际文本基本都采用该方法作为loss计算。其代码如下:

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

3、模型训练

接下来,我们解读如何训练对话模型,我们获得对应输入数据与生成预测数据,我们开始训练模型,其代码如下:

# 模型训练
epochs = 800
for epoch in range(epochs):
    total_loss = 0
    for i in range(len(data_x)):
        optimizer.zero_grad()
        input_seq = torch.tensor(data_x[i])
        target_seq = torch.tensor(data_y[i])

        for j in range(len(target_seq)-1):
            if random.random() < 0.5:
                j = random.randint(0, len(target_seq)-2)
            output = model(input_seq.unsqueeze(0), target_seq[:j+1].unsqueeze(0))  # 使用目标序列的前n-1个词预测后n-1个词
            loss = criterion(output.view(-1, len(vocab)), target_seq[1:j+2].view(-1))  # 计算损失,需要将输出形状转换成二维
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

    if (epoch+1) % 10 == 0:
        print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, epochs, total_loss / len(data_x)))

假设input_seq =tensor([17, 6, 0, 15, 1, 8, 9, 15, 7, 6, 11, 4]),target_seq=tensor([20, 12, 18, 13, 15, 5, 16, 10, 9, 2, 15, 3, 17, 1, 10, 19, 15, 7, 6, 11, 14, 21]),vocab为字典映射,共22个映射。我将试图模型连续2轮解释一下,训练时候相关变化。
第二轮输出结果:
在这里插入图片描述
第三轮输出结果:
在这里插入图片描述
后面以此类推迭代。

很明显,提问每次都是全部输入,而输出则是第一个20开始与输入共同进模型,分别是模型src与tgt,不断重复与预测完成训练。而loss计算都是往后取一个target文本,也发现并不会计算20索引""文本。

四、模型推理

最后,让我们使用训练好的模型进行推理,实际和上面训练讲到方法类似,我们开始""文本开始,不断给出生成对应文本,也就是对话内容。

# 进行推理
def generate_response(input_sentence):
    model.eval()
    input_seq = torch.tensor(to_idx_seq(input_sentence))
    target_seq = torch.tensor([word_to_idx["<SOS>"]])  # 在开始时使用特殊的起始标记


    with torch.no_grad():
        for i in range(20):  # 限制生成的句子长度为20个词

            output = model(input_seq.unsqueeze(0), target_seq.unsqueeze(0))
            output_token = output.argmax(2)[-1].item()

            print(idx_to_word[output_token], end=" ")
            target_seq = torch.cat((target_seq, torch.tensor([output_token])), dim=0)

        res = [idx_to_word[int(k)] for k in target_seq]
    print(res[1:-1])
    return res[1:-1]

# 进行对话生成
generate_response("hi")


其结果如下:
在这里插入图片描述


五、所有Demo源码

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
# 定义一个简单的对话数据集
data = [
    # ("hi", "hello"),
    ("how are you?", "I'm fine, thank you."),
    # ("what's your name?", "I'm a chatbot.")
]

# 构建词汇表
vocab = list(set(" ".join([x[0] + " " + x[1] for x in data])))
vocab.append("<SOS>")
vocab.append("<EOS>")

word_to_idx = {word: i for i, word in enumerate(vocab)}
idx_to_word = {i: word for i, word in enumerate(vocab)}

# 将对话数据集转换为索引序列
def to_idx_seq(sentence):
    return [word_to_idx[word] for word in sentence]

data_x = [to_idx_seq(x[0]) for x in data]
data_y = [to_idx_seq(x[1]) for x in data]

data_y =[[word_to_idx["<SOS>"]]+list(x)+[word_to_idx["<EOS>"]] for x in data_y]


# 定义一个简单的Transformer生成式对话模型
class ChatbotTransformer(nn.Module):
    def __init__(self, input_dim, output_dim, nhead, num_encoder_layers, num_decoder_layers):
        super(ChatbotTransformer, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.embedding = nn.Embedding(len(vocab), input_dim)
        self.transformer = nn.Transformer(d_model=input_dim, nhead=nhead, num_encoder_layers=num_encoder_layers, num_decoder_layers=num_decoder_layers)
        self.fc = nn.Linear(input_dim, output_dim)

    def forward(self, src, tgt):
        src = self.embedding(src).permute(1, 0, 2)  # 调整输入张量维度
        tgt = self.embedding(tgt).permute(1, 0, 2)  # 调整输入张量维度
        output = self.transformer(src, tgt)
        output = self.fc(output)
        return output

# 创建模型实例
# model = ChatbotTransformer(input_dim=256, output_dim=len(vocab), nhead=8, num_encoder_layers=6, num_decoder_layers=6)
model = ChatbotTransformer(input_dim=16, output_dim=len(vocab), nhead=8, num_encoder_layers=1, num_decoder_layers=1)

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

# 模型训练
epochs = 800
for epoch in range(epochs):
    total_loss = 0
    for i in range(len(data_x)):
        optimizer.zero_grad()
        input_seq = torch.tensor(data_x[i])
        target_seq = torch.tensor(data_y[i])

        for j in range(len(target_seq)-1):
            if random.random() < 0.5:
                j = random.randint(0, len(target_seq)-2)

            output = model(input_seq.unsqueeze(0), target_seq[:j+1].unsqueeze(0))  # 使用目标序列的前n-1个词预测后n-1个词
            loss = criterion(output.view(-1, len(vocab)), target_seq[1:j+2].view(-1))  # 计算损失,需要将输出形状转换成二维
            loss.backward()
            optimizer.step()
            total_loss += loss.item()


    if (epoch+1) % 10 == 0:
        print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, epochs, total_loss / len(data_x)))

# 进行推理
def generate_response(input_sentence):
    model.eval()
    input_seq = torch.tensor(to_idx_seq(input_sentence))
    target_seq = torch.tensor([word_to_idx["<SOS>"]])  # 在开始时使用特殊的起始标记


    with torch.no_grad():
        for i in range(20):  # 限制生成的句子长度为20个词

            output = model(input_seq.unsqueeze(0), target_seq.unsqueeze(0))
            output_token = output.argmax(2)[-1].item()

            print(idx_to_word[output_token], end=" ")
            target_seq = torch.cat((target_seq, torch.tensor([output_token])), dim=0)

        res = [idx_to_word[int(k)] for k in target_seq]
    print(res[1:-1])
    return res[1:-1]

# 进行对话生成
generate_response("how are you?")


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1443868.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

MongoDB从入门到实战之.NET Core使用MongoDB开发ToDoList系统(1)-后端项目框架搭建

前言&#xff1a; 前面的四个章节我们主要讲解了MongoDB的相关基础知识&#xff0c;接下来我们就开始进入使用.NET7操作MongoDB开发一个ToDoList系统实战教程。本章节主要介绍的是如何快熟搭建一个简单明了的后端项目框架。 MongoDB从入门到实战的相关教程 MongoDB从入门到实战…

从信息隐藏到功能隐藏

本文主要记录复旦大学张新鹏教授于2022年12月在第三届CSIG中国媒体取证与安全大会上的汇报

蓝桥杯Web应用开发-CSS3 新特性【练习一:属性有效性验证】

练习一&#xff1a;属性有效性验证 页面上有一个邮箱输入框&#xff0c;当你的输入满足邮箱格式时&#xff0c;输入框的背景颜色为绿色&#xff1b;当你的输入不满足要求&#xff0c;背景颜色为红色。 新建一个 index2.html 文件&#xff0c;在其中写入以下内容。 <!DOCTYP…

Stata实证命令代码汇总

Stata代码命令汇总 数据内容&#xff1a;包括数据导入和管理、数据的处理、描述性统计、相关性分析、实证模型、内生性解决、检验分析、结果导出 具体如下&#xff1a; 一、数据导入和管理&#xff1a;数据导入、数据导出 二、数据的处理&#xff1a;生成新变量、格式转换、…

计算机二级C语言备考学习记录

一、C语言程序的结构 1.程序的构成&#xff0c;main函数和其他函数。 程序是由main函数和其他函数构成main作为主函数&#xff0c;一个C程序里只有一个main函数其他函数可以分为系统函数和用户函数&#xff0c;系统函数为编译系统提供&#xff0c;用户函数由用户自行编写 2.…

北斗卫星在物联网时代的应用探索

北斗卫星在物联网时代的应用探索 在当今数字化时代&#xff0c;物联网的应用已经深入到人们的生活中的方方面面&#xff0c;让我们的生活更加智能便捷。而北斗卫星系统作为我国自主研发的卫星导航系统&#xff0c;正为物联网的发展提供了强有力的支撑和保障。本文将全面介绍北…

爬虫练习——动态网页的爬取(股票和百度翻译)

动态网页也是字面意思&#xff1a;实时更新的那种 还有就是你在股票这个网站上&#xff0c;翻页。他的地址是不变的 是动态的加载&#xff0c;真正我不太清楚&#xff0c;只知道他是不变的。如果用静态网页的方法就不可行了。 静态网页的翻页&#xff0c;是网址是有规律的。 …

【Linux】信号概念与信号产生

信号概念与信号产生 一、初识信号1. 信号概念2. 前台进程和后台进程3. 认识信号4. 技术应用角度的信号 二、信号的产生1. 键盘组合键2. kill 命令3. 系统调用4. 异常&#xff08;1&#xff09;观察现象&#xff08;2&#xff09;理解本质 5. 软件条件闹钟 一、初识信号 1. 信号…

【网络】:序列化和反序列化

序列化和反序列化 一.json库 二.简单使用json库 前面已经讲过TCP和UDP&#xff0c;也写过代码能够进行双方的通信了&#xff0c;那么有没有可能这种通信是不安全的呢&#xff1f;如果直接通信&#xff0c;可能会被底层捕捉&#xff1b;可能由于网络问题&#xff0c;一方只接收到…

k8s-资源限制与监控 15

资源限制 上传实验所需镜像 Kubernetes采用request和limit两种限制类型来对资源进行分配。 request(资源需求)&#xff1a;即运行Pod的节点必须满足运行Pod的最基本需求才能 运行Pod。 limit(资源限额)&#xff1a;即运行Pod期间&#xff0c;可能内存使用量会增加&#xff0…

区间dp 笔记

区间dp一般是先枚举区间长度&#xff0c;再枚举左端点&#xff0c;再枚举分界点&#xff0c;时间复杂度为 环形石子合并 将 n 堆石子绕圆形操场排放&#xff0c;现要将石子有序地合并成一堆。 规定每次只能选相邻的两堆合并成新的一堆&#xff0c;并将新的一堆的石子数记做该…

分布式搜索引擎 elasticsearch

分布式搜索引擎 elasticsearch 第一部分 1.初识elasticsearch 1.1.了解ES 1.1.1.elasticsearch的作用 elasticsearch是一款非常强大的开源搜索引擎&#xff0c;具备非常多强大功能&#xff0c;可以帮助我们从海量数据中快速找到需要的内容 例如&#xff1a; 在GitHub搜索…

159基于matlab的基于密度的噪声应用空间聚类(DBSCAN)算法对点进行聚类

基于matlab的基于密度的噪声应用空间聚类(DBSCAN)算法对点进行聚类&#xff0c;聚类结果效果好&#xff0c;DBSCAN不要求我们指定集群的数量&#xff0c;避免了异常值&#xff0c;并且在任意形状和大小的集群中工作得非常好。它没有质心&#xff0c;聚类簇是通过将相邻的点连接…

[论文总结] 深度学习在农业领域应用论文笔记12

文章目录 1. 3D-ZeF: A 3D Zebrafish Tracking Benchmark Dataset (CVPR, 2020)摘要背景相关研究所提出的数据集方法和结果个人总结 2. Automated flower classification over a large number of classes (Computer Vision, Graphics & Image Processing, 2008)摘要背景分割…

猜猜谁是凶手?

目录 一、题目二、思路三、完整代码 一、题目 日本某地发生了一件谋杀案&#xff0c;警察通过排查确定杀人凶手必为4个嫌疑犯的一个。 以下为4个嫌疑犯的供词: A说&#xff1a;不是我。 B说&#xff1a;是C。 C说&#xff1a;是D。 D说&#xff1a;C在胡说 已知3个人说了…

hexo 博客搭建以及踩雷总结

搭建时的坑 文章置顶 安装一下这个依赖 npm install hexo-generator-topindex --save然后再文章的上面设置 top: number&#xff0c;数字越大&#xff0c;权重越大&#xff0c;也就是越靠顶部 hexo 每次推送 nginx 都访问不到 宝塔自带的 nginx 的 config 里默认的角色是 …

RabbitMQ高级篇

消息队列在使用过程中&#xff0c;面临着很多实际问题需要思考&#xff1a; 一、消息可靠性 消息从发送&#xff0c;到消费者接收&#xff0c;会经历多个过程&#xff1a; 其中的每一步都可能导致消息丢失&#xff0c;常见的丢失原因包括&#xff1a; 发送时丢失&#xff1a;…

【sentinel流量卫兵配置持久化到Nacos】

sentinel流量卫兵配置持久化到Nacos 概述&#xff1a; 一、添加配置二、配置说明限流规则配置&#xff1a;降级规则配置&#xff1a;热点规则配置&#xff1a;授权规则配置&#xff1a;系统规则配置&#xff1a; 三、服务整合 概述&#xff1a; 控制台配置的参数&#xff0c;默…

政安晨:示例演绎TensorFlow的官方指南(一){基础知识}

为什么要示例演绎&#xff1f; 既然有了官方指南&#xff0c;咱们在官方指南上看看就可以了&#xff0c;为什么还要写示例演绎的文章呢&#xff1f; 其实对于初步了解TensorFlow的小伙伴们而言&#xff0c;示例演绎才是最重要的。 官方文档已经假定了您已经具备了相当合适的…

卫星通讯领域FPGA关注技术:算法和图像方面(4)

最近关注的公众号提到了从事移动通信、卫星通讯等领域的FPGA、ASIC、信号处理算法等工程师可能需要关注的技术&#xff0c;有5G NTN、多址技术、低轨通信卫星LEO&#xff0c;以下做了一些基础的调研&#xff1a; 1 5G NTN 来自《5G NTN技术白皮书&#xff1a;天地一体、手机直…