双塔模型dssm实践

news2025/1/13 13:58:39

最近在学习向量召回,向量召回不得不用到dssm双塔模型,双塔模型的原理非常简单,就是用两个任务塔,一个是query侧的query任务塔,另一个是doc侧的doc任务塔,任务塔向上抽象形成verctor隐向量后,用cosin相似度度计算两个向量之间的相似程度,进而让query和doc的向量尽量的相近。
模型整体的结构如下所示,线上工程使用的时候,是将doc的隐向量进行离线保存,用faiss类似工具形成线上的向量检索索引,然后线上用户query查询的时候,就可以通过线上部署的query侧任务塔,实时的生成query侧的隐向量,然后通过query隐向量去查询faiss中相似的doc向量结果,达到query语义召回的目的。
在这里插入图片描述

代码:模型中的encoder是可以用任何nlp编码模型实现,比如说可以用预训练的bert实现。

class DSSM(nn.Module):
    """
    DSSM(Deep Structured Semantic Model) 模型实现, 采用cos值计算向量相似度, 精度稍低, 但计算速度快。
    Paper Reference: https://posenhuang.github.io/papers/cikm2013_DSSM_fullversion.pdf

    Args:
        nn (_type_): _description_
    """

    def __init__(self, encoder, dropout=None):
        """
        init func.

        Args:
            encoder (transformers.AutoModel): backbone, 默认使用 ernie 3.0
            dropout (float): dropout.
        """
        super().__init__()
        self.encoder = encoder
        hidden_size = 768
        self.dropout = nn.Dropout(dropout if dropout is not None else 0.1)

    def forward(
        self,
        input_ids: torch.tensor,
        token_type_ids: torch.tensor,
        attention_mask: torch.tensor
    ) -> torch.tensor:
        """
        forward 函数,输入单句子,获得单句子的embedding。

        Args:
            input_ids (torch.LongTensor): (batch, seq_len)
            token_type_ids (torch.LongTensor): (batch, seq_len)
            attention_mask (torch.LongTensor): (batch, seq_len)

        Returns:
            torch.tensor: embedding -> (batch, hidden_size)
        """
        embedding = self.encoder(
                input_ids=input_ids,
                token_type_ids=token_type_ids,
                attention_mask=attention_mask
            )["pooler_output"]                                  # (batch, hidden_size)
        return embedding

    def get_similarity(
        self,
        query_input_ids: torch.tensor,
        query_token_type_ids: torch.tensor,
        query_attention_mask: torch.tensor,
        doc_input_ids: torch.tensor,
        doc_token_type_ids: torch.tensor,
        doc_attention_mask: torch.tensor
    ) -> torch.tensor:
        """
        输入query和doc的向量,返回query和doc两个向量的余弦相似度。

        Args:
            query_input_ids (torch.LongTensor): (batch, seq_len)
            query_token_type_ids (torch.LongTensor): (batch, seq_len)
            query_attention_mask (torch.LongTensor): (batch, seq_len)
            doc_input_ids (torch.LongTensor): (batch, seq_len)
            doc_token_type_ids (torch.LongTensor): (batch, seq_len)
            doc_attention_mask (torch.LongTensor): (batch, seq_len)

        Returns:
            torch.tensor: (batch, 1)
        """
        print('query_input_ids=', query_input_ids, query_input_ids.shape)
        # print('query_token_type_ids=', query_token_type_ids)
        # print('query_attention_mask=', query_attention_mask)
        query_embedding = self.encoder(
            input_ids=query_input_ids,
            token_type_ids=query_token_type_ids,
            attention_mask=query_attention_mask
        )["pooler_output"]                                 # (batch, hidden_size)
        print('query_embedding=', query_embedding, query_embedding.shape)
        query_embedding = self.dropout(query_embedding)

        doc_embedding = self.encoder(
            input_ids=doc_input_ids,
            token_type_ids=doc_token_type_ids,
            attention_mask=doc_attention_mask
        )["pooler_output"]                                  # (batch, hidden_size)
        doc_embedding = self.dropout(doc_embedding)

        similarity = nn.functional.cosine_similarity(query_embedding, doc_embedding)
        return similarity

train代码:

# !/usr/bin/env python3

import os
import time
import argparse
from functools import partial

import torch
from torch.utils.data import DataLoader
import evaluate
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel, default_data_collator, get_scheduler

from model import DSSM
from utils import convert_dssm_example
from iTrainingLogger import iSummaryWriter


parser = argparse.ArgumentParser()
parser.add_argument("--model", default='bert-base-chinese', type=str, help="backbone of encoder.")
parser.add_argument("--train_path", default='./data/dssm_data/train.csv', type=str, help="The path of train set.")
parser.add_argument("--dev_path", default='./data/dssm_data/train.csv', type=str, help="The path of dev set.")
parser.add_argument("--save_dir", default="./checkpoints", type=str, required=False, help="The output directory where the model predictions and checkpoints will be written.")
parser.add_argument("--max_seq_len", default=512, type=int,help="The maximum total input sequence length after tokenization. Sequences longer "
    "than this will be truncated, sequences shorter will be padded.", )
parser.add_argument("--batch_size", default=16, type=int, help="Batch size per GPU/CPU for training.", )
parser.add_argument("--learning_rate", default=1e-5, type=float, help="The initial learning rate for Adam.")
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
parser.add_argument("--num_train_epochs", default=10, type=int, help="Total number of training epochs to perform.")
parser.add_argument("--warmup_ratio", default=0.06, type=float, help="Linear warmup over warmup_ratio * total_steps.")
parser.add_argument("--valid_steps", default=200, type=int, required=False, help="evaluate frequecny.")
parser.add_argument("--logging_steps", default=10, type=int, help="log interval.")
parser.add_argument('--device', default="cpu", help="Select which device to train model, defaults to gpu.")
parser.add_argument("--img_log_dir", default='logs', type=str, help="Logging image path.")
parser.add_argument("--img_log_name", default='Model Performance', type=str, help="Logging image file name.")
args = parser.parse_args()

writer = iSummaryWriter(log_path=args.img_log_dir, log_name=args.img_log_name)


def evaluate_model(model, metric, data_loader, global_step):
    """
    在测试集上评估当前模型的训练效果。

    Args:
        model: 当前模型
        metric: 评估指标类(metric)
        data_loader: 测试集的dataloader
        global_step: 当前训练步数
    """
    model.eval()
    with torch.no_grad():
        for step, batch in enumerate(data_loader):
            logits = model.get_similarity(query_input_ids=batch['query_input_ids'].to(args.device),
                            query_token_type_ids=batch['query_token_type_ids'].to(args.device),
                            query_attention_mask=batch['query_attention_mask'].to(args.device),
                            doc_input_ids=batch['doc_input_ids'].to(args.device),
                            doc_token_type_ids=batch['doc_token_type_ids'].to(args.device),
                            doc_attention_mask=batch['doc_attention_mask'].to(args.device))
            logits[logits>=0.5] = 1
            logits[logits<0.5] = 0
            metric.add_batch(predictions=logits.to(torch.int), references=batch["labels"])
    eval_metric = metric.compute()
    model.train()
    return eval_metric['accuracy'], eval_metric['precision'], eval_metric['recall'], eval_metric['f1']


def train():
    encoder = AutoModel.from_pretrained(args.model)
    model = DSSM(encoder)
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    dataset = load_dataset('text', data_files={'train': args.train_path,
                                                'dev': args.dev_path})    
    print(dataset)
    convert_func = partial(convert_dssm_example, tokenizer=tokenizer, max_seq_len=args.max_seq_len)
    dataset = dataset.map(convert_func, batched=True)
    
    train_dataset = dataset["train"]
    eval_dataset = dataset["dev"]
    train_dataloader = DataLoader(train_dataset, shuffle=True, collate_fn=default_data_collator, batch_size=args.batch_size)
    eval_dataloader = DataLoader(eval_dataset, collate_fn=default_data_collator, batch_size=args.batch_size)

    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=5e-5)
    model.to(args.device)

    # 根据训练轮数计算最大训练步数,以便于scheduler动态调整lr
    num_update_steps_per_epoch = len(train_dataloader)
    max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    warm_steps = int(args.warmup_ratio * max_train_steps)
    lr_scheduler = get_scheduler(
        name='linear',
        optimizer=optimizer,
        num_warmup_steps=warm_steps,
        num_training_steps=max_train_steps,
    )

    loss_list = []
    metric = evaluate.combine(["accuracy", "f1", "precision", "recall"])
    criterion = torch.nn.MSELoss()
    tic_train = time.time()
    global_step, best_f1 = 0, 0
    for epoch in range(1, args.num_train_epochs+1):
        for batch in train_dataloader:
            logits = model.get_similarity(query_input_ids=batch['query_input_ids'].to(args.device),
                            query_token_type_ids=batch['query_token_type_ids'].to(args.device),
                            query_attention_mask=batch['query_attention_mask'].to(args.device),
                            doc_input_ids=batch['doc_input_ids'].to(args.device),
                            doc_token_type_ids=batch['doc_token_type_ids'].to(args.device),
                            doc_attention_mask=batch['doc_attention_mask'].to(args.device))
            labels = batch['labels'].to(torch.float).to(args.device)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            loss_list.append(float(loss.cpu().detach()))
            
            if global_step % args.logging_steps == 0:
                time_diff = time.time() - tic_train
                loss_avg = sum(loss_list) / len(loss_list)
                writer.add_scalar('train/train_loss', loss_avg, global_step)
                print("global step %d, epoch: %d, loss: %.5f, speed: %.2f step/s"
                        % (global_step, epoch, loss_avg, args.logging_steps / time_diff))
                tic_train = time.time()

            if global_step % args.valid_steps == 0:
                cur_save_dir = os.path.join(args.save_dir, "model_%d" % global_step)
                if not os.path.exists(cur_save_dir):
                    os.makedirs(cur_save_dir)
                torch.save(model, os.path.join(cur_save_dir, 'model.pt'))
                tokenizer.save_pretrained(cur_save_dir)

                acc, precision, recall, f1 = evaluate_model(model, metric, eval_dataloader, global_step)
                writer.add_scalar('eval/accuracy', acc, global_step)
                writer.add_scalar('eval/precision', precision, global_step)
                writer.add_scalar('eval/recall', recall, global_step)
                writer.add_scalar('eval/f1', f1, global_step)
                writer.record()

                print("Evaluation precision: %.5f, recall: %.5f, F1: %.5f" % (precision, recall, f1))
                if f1 > best_f1:
                    print(
                        f"best F1 performence has been updated: {best_f1:.5f} --> {f1:.5f}"
                    )
                    best_f1 = f1
                    cur_save_dir = os.path.join(args.save_dir, "model_best")
                    if not os.path.exists(cur_save_dir):
                        os.makedirs(cur_save_dir)
                    torch.save(model, os.path.join(cur_save_dir, 'model.pt'))
                    tokenizer.save_pretrained(cur_save_dir)
                tic_train = time.time()
            
            global_step += 1


if __name__ == '__main__':
    train()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/627054.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【多同步挤压变换】基于多同步挤压变换处理时变信号和噪声信号研究(Matlab代码实现)

&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜密&#xff0c;逻辑清晰&#xff0c;为了方便读者。 ⛳️座右铭&a…

2022年国赛高教杯数学建模B题无人机遂行编队飞行中的纯方位无源定位解题全过程文档及程序

2022年国赛高教杯数学建模 B题 无人机遂行编队飞行中的纯方位无源定位 原题再现 无人机集群在遂行编队飞行时&#xff0c;为避免外界干扰&#xff0c;应尽可能保持电磁静默&#xff0c;少向外发射电磁波信号。为保持编队队形&#xff0c;拟采用纯方位无源定位的方法调整无人机…

win10录屏软件哪个好用?强烈推荐这3款!

案例&#xff1a;想要录制我的电脑屏幕&#xff0c;但是不知道如何选择合适的录屏工具&#xff0c;有没有好用的win10录屏软件&#xff1f; 【我想找一款好用的win10录屏工具&#xff0c;录制我的电脑屏幕&#xff0c;但是找了很久还没有找到&#xff0c;大家有好用的录屏软件…

Kibana:使用 Kibana 自带数据进行可视化(二)

在今天的练习中&#xff0c;我们将使用 Kibana 自带的数据来进行一些可视化的展示。希望对刚开始使用 Kibana 的用户有所帮助。这个是继上一篇文章 “Kibana&#xff1a;使用 Kibana 自带数据进行可视化&#xff08;一&#xff09;” 的续篇。 前提条件 如果你还没有安装好自己…

占据80%中国企业出海市场,亚马逊云科技如何为出海客户提供更多资源和附加值

亚马逊云科技就可以做到&#xff0c;作为占据80%中国企业出海市场的亚马逊云科技&#xff0c;其覆盖全球的业务体系&#xff0c;从亚马逊海外购、亚马逊全球开店、亚马逊智能硬件与服务&#xff0c;Amazon Alexa&#xff0c;Amazon Music都是属于亚马逊云科技“梦之队”的一员。…

【Android】WMS(二)Window的添加

软件盘相关模式 在 Android 应用开发中&#xff0c;软键盘的显示与隐藏是一个经常出现的问题&#xff0c;而 WindowManager 的 LayoutParams 中定义的软键盘相关模式则为开发者提供了一些解决方案。 其中&#xff0c;SoftInputMode 就是用于描述软键盘的显示方式和窗口的调整…

【LeetCode】HOT 100(6)

题单介绍&#xff1a; 精选 100 道力扣&#xff08;LeetCode&#xff09;上最热门的题目&#xff0c;适合初识算法与数据结构的新手和想要在短时间内高效提升的人&#xff0c;熟练掌握这 100 道题&#xff0c;你就已经具备了在代码世界通行的基本能力。 目录 题单介绍&#…

python包装与授权

欢迎关注博主 Mindtechnist 或加入【Linux C/C/Python社区】一起学习和分享Linux、C、C、Python、Matlab&#xff0c;机器人运动控制、多机器人协作&#xff0c;智能优化算法&#xff0c;滤波估计、多传感器信息融合&#xff0c;机器学习&#xff0c;人工智能等相关领域的知识和…

ai聊天对话工具哪种好用?这些ai对话聊天工具不要错过

在如今信息爆炸的时代&#xff0c;人工智能技术正在逐渐渗透到我们的生活和工作中。ai对话聊天技术作为其中的一项重要应用&#xff0c;吸引了越来越多的关注。但是&#xff0c;ai对话聊天技术并不是万能的&#xff0c;它需要一定的技巧和策略才能真正发挥其价值。那么&#xf…

CAN总线转串口

一、CAN总线在工程机械中的广泛应用 随着科技的进步和现代施工项目大型化的要求,新一代工程机械需要实现集成化操作和智能控制。CAN总线是国际上应用最广泛的现场总线之一。CAN总线以其高可靠性、实时性、无破坏仲裁、多主等特性&#xff0c;已广泛应用于工程机械中&#xff0c…

这里推荐几个前端动画效果网站

1. AnimistaAnimista 是一个 CSS 动画/转场库和在线工具。它有许多现成的 CSS 动画片段可以直接使用,也可以在线定制动画。 网站地址:Animista - On-Demand CSS Animations Library 2. Animate.cssAnimate.css 是一个免费的 CSS 动画库,里面有 Attention Seekers 、 Bouncing E…

【Java|多线程与高并发】线程安全问题以及synchronized使用实例

文章目录 1. 前言2. 线程安全问题演示3.线程安全问题的原因4.synchronized关键字5. 总结 1. 前言 Java多线程环境下&#xff0c;多个线程同时访问共享资源时可能出现的数据竞争和不一致的情况。 线程安全一直都是一个令人头疼的问题.为了解决这个问题,Java为我们提供了很多方式…

MySQL为什么有了redolog还需要double write buffer?

MySQL为什么有了redolog还需要double write buffer&#xff1f; 问题 我们知道MySQL InnoDB引擎使用redolog作为异常容灾恢复的机制&#xff0c;当MySQL进程发生异常退出、机器断电等&#xff0c;在重新启动时&#xff0c;使用redolog恢复。 OK&#xff0c;redolog是被MySQL…

进程同步与进程通信(#include <windows.h>)

目录 实验二 进程同步与进程通信 一、实验目的 二、实验内容 任务一、进程同步与互斥 任务二、进程通信 实验二 进程同步与进程通信 备注&#xff1a;大二&#xff08;下&#xff09;操作系统实验二 一、实验目的 掌握基本的同步与互斥算法&#xff0c;理解P&#xff…

移植蓝牙芯片后,PCM 无声音问题记录

背景:投影仪项目上的蓝牙模组本地已经验证ok,送到客户那里发现HFP打电话没声音。 1. 客户平台是3566,android 11的环境, 该环境下其他的模组是可以的 2. 在3566上安装QQ, 波通VOIP电话后, 无阴影, 3. 通过示波器接收pcm 无波形输出, 问题分析查证 1.查看HCI log ,…

【LeetCode热题100】打卡第17天:接雨水全排列旋转图像

文章目录 【LeetCode热题100】打卡第17天&#xff1a;接雨水&全排列&旋转图像⛅前言 接雨水&#x1f512;题目&#x1f511;题解 全排列&#x1f512;题目&#x1f511;题解 旋转图像&#x1f512;题目&#x1f511;题解 【LeetCode热题100】打卡第17天&#xff1a;接雨…

Elasticsearch 中文分词器

IK 分词器 我们在ES中最常用的中文分词器就是IK分词器&#xff0c;其项目地址为&#xff1a;https://github.com/medcl/elasticsearch-analysis-ik 下载安装 下载地址&#xff1a; https://github.com/medcl/elasticsearch-analysis-ik/releases 下载时注意和es的版本对应&a…

Network 之十二 iPXE 源码、编译过程、Linker tables 机制、移植新驱动、固件使用

最近&#xff0c;正在学习 iPXE 源码&#xff0c;于是开始各种 Google 查找 iPXE 的资料进行学习。以下就是学习过程中一些感觉比较重要的点&#xff0c;特此记录&#xff0c;以备后续查阅。 起源 上世纪 90 年代初&#xff0c;网卡开始在其扩展卡上包含启动 ROM&#xff0c;每…

2023-6-9-一天一种设计模式

&#x1f37f;*★,*:.☆(&#xffe3;▽&#xffe3;)/$:*.★* &#x1f37f; &#x1f4a5;&#x1f4a5;&#x1f4a5;欢迎来到&#x1f91e;汤姆&#x1f91e;的csdn博文&#x1f4a5;&#x1f4a5;&#x1f4a5; &#x1f49f;&#x1f49f;喜欢的朋友可以关注一下&#xf…

当在浏览器截屏过曝时,应该采取的措施

一、问题来源 屏幕打开了HDR模式后&#xff0c;浏览器在截图的一瞬间出现色彩错误 正常情况如下&#xff1a; HDR截图过曝后如下&#xff1a; 二、解决方法 1. 关闭屏幕HDR模式 桌面右键显示设置关闭HDR选项 2. 修改浏览器选项 地址栏输入 edge://flags&#xff08;Edg…