copy is all you need前向绘图 和疑惑标记

news2025/1/4 19:32:06

疑惑的起因
在这里插入图片描述
简化前向图
在这里插入图片描述
GPT4解释

这段代码实现了一个神经网络模型,包含了BERT、GPT-2和MLP等模块。主要功能是给定一个文本序列和一个查询序列,预测查询序列中的起始和结束位置,使其对应文本序列中的一个短语。具体实现细节如下:

首先,使用BERT将文本序列编码成一个张量作为输入,并从GPT-2模型中提取查询序列的表示。然后,将查询序列的表示拆分为起始和结束两部分,并使用MLP将短语的起始和结束点表示投影到另一个空间中。
接下来,利用生成式模型计算每个查询起始和结束位置之间的所有令牌的概率分布,并根据此生成预测位置。使用交叉熵损失函数来计算这些预测和真实值之间的差距。
最后,计算起始和结束位置的准确性和整个短语的准确性,并返回损失和准确性的值来优化模型的性能。

import torch
from transformers import AutoModel, AutoTokenizer, GPT2LMHeadModel
from torch import nn
from torch.nn import functional as F


class Copyisallyouneed(nn.Module):

    def __init__(self, **args):
        super(Copyisallyouneed, self).__init__()
        self.args = args

        # bert-encoder model
        self.phrase_encoder = AutoModel.from_pretrained(
            self.args['phrase_encoder_model'][self.args['lang']]
        )
        self.bert_tokenizer = AutoTokenizer.from_pretrained(
            self.args['phrase_encoder_tokenizer'][self.args['lang']]
        )
        self.bert_tokenizer.add_tokens(['<|endoftext|>', '[PREFIX]'])
        self.prefix_token_id = self.bert_tokenizer.convert_tokens_to_ids('[PREFIX]')
        self.phrase_encoder.resize_token_embeddings(self.phrase_encoder.config.vocab_size + 2)

        # model and tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(self.args['prefix_encoder_tokenizer'][self.args['lang']])
        self.vocab_size = len(self.tokenizer)
        self.pad = self.tokenizer.pad_token_id if self.args['lang'] == 'zh' else self.tokenizer.bos_token_id

        self.model = GPT2LMHeadModel.from_pretrained(self.args['prefix_encoder_model'][self.args['lang']])
        self.token_embeddings = nn.Parameter(list(self.model.lm_head.parameters())[0])
        # MLP: mapping bert phrase start representations
        self.s_proj = nn.Sequential(
            nn.Dropout(p=args['dropout']),
            nn.Tanh(),
            nn.Linear(self.model.config.hidden_size, self.model.config.hidden_size // 2)
        )
        # MLP: mapping bert phrase end representations
        self.e_proj = nn.Sequential(
            nn.Dropout(p=args['dropout']),
            nn.Tanh(),
            nn.Linear(self.model.config.hidden_size, self.model.config.hidden_size // 2)
        )
        self.gen_loss_fct = nn.CrossEntropyLoss(ignore_index=self.pad)

    @torch.no_grad()
    def get_query_rep(self, ids):
        self.eval()
        output = self.model(input_ids=ids, output_hidden_states=True)['hidden_states'][-1][:, -1, :]
        return output

    def get_token_loss(self, ids, hs, ids_mask):
        # no pad token
        label = ids[:, 1:]
        logits = torch.matmul(
            hs[:, :-1, :],
            self.token_embeddings.t()
        )
        # TODO: inner loss function remove the temperature factor
        logits /= self.args['temp']
        loss = self.gen_loss_fct(logits.view(-1, logits.size(-1)), label.reshape(-1))
        chosen_tokens = torch.max(logits, dim=-1)[1]
        gen_acc = (chosen_tokens.reshape(-1) == label.reshape(-1)).to(torch.long)
        valid_mask = (label != self.pad).reshape(-1)
        valid_tokens = gen_acc & valid_mask
        gen_acc = valid_tokens.sum().item() / valid_mask.sum().item()
        return loss, gen_acc

    def forward(self, batch):
        ## gpt2 query encoder
        ids, ids_mask = batch['gpt2_ids'], batch['gpt2_mask']
        last_hidden_states = \
        self.model(input_ids=ids, attention_mask=ids_mask, output_hidden_states=True).hidden_states[-1]
        # get token loss
        loss_0, acc_0 = self.get_token_loss(ids, last_hidden_states, ids_mask)

        ## encode the document with the BERT encoder model
        dids, dids_mask = batch['bert_ids'], batch['bert_mask']
        output = self.phrase_encoder(dids, dids_mask, output_hidden_states=True)['hidden_states'][-1]  # [B, S, E]
        # collect the phrase start representations and phrase end representations
        s_rep = self.s_proj(output)
        e_rep = self.e_proj(output)
        s_rep = s_rep.reshape(-1, s_rep.size(-1))
        e_rep = e_rep.reshape(-1, e_rep.size(-1))  # [B_doc*S_doc, 768//2]

        # collect the query representations
        query = last_hidden_states[:, :-1].reshape(-1, last_hidden_states.size(-1))
        query_start = query[:, :self.model.config.hidden_size // 2]
        query_end = query[:, self.model.config.hidden_size // 2:]

        # training the representations of the start tokens
        candidate_reps = torch.cat([
            self.token_embeddings[:, :self.model.config.hidden_size // 2],
            s_rep], dim=0)
        logits = torch.matmul(query_start, candidate_reps.t())
        logits /= self.args['temp']

        # build the padding mask for query side
        query_padding_mask = ids_mask[:, :-1].reshape(-1).to(torch.bool)

        # build the padding mask: 1 for valid and 0 for mask
        attention_mask = (dids_mask.reshape(1, -1).to(torch.bool)).to(torch.long)
        padding_mask = torch.ones_like(logits).to(torch.long)
        # Santiy check over
        padding_mask[:, self.vocab_size:] = attention_mask

        # build the position mask: 1 for valid and 0 for mask
        pos_mask = batch['pos_mask']
        start_labels, end_labels = batch['start_labels'][:, 1:].reshape(-1), batch['end_labels'][:, 1:].reshape(-1)
        position_mask = torch.ones_like(logits).to(torch.long)
        query_pos = start_labels > self.vocab_size
        # ignore the padding mask
        position_mask[query_pos, self.vocab_size:] = pos_mask
        assert padding_mask.shape == position_mask.shape
        # overall mask
        overall_mask = padding_mask * position_mask
        ## remove the position mask
        # overall_mask = padding_mask

        new_logits = torch.where(overall_mask.to(torch.bool), logits, torch.tensor(-1e4).to(torch.half).cuda())
        mask = torch.zeros_like(new_logits)
        mask[range(len(new_logits)), start_labels] = 1.
        loss_ = F.log_softmax(new_logits[query_padding_mask], dim=-1) * mask[query_padding_mask]
        loss_1 = (-loss_.sum(dim=-1)).mean()

        ## split the token accuaracy and phrase accuracy
        phrase_indexes = start_labels > self.vocab_size
        phrase_indexes_ = phrase_indexes & query_padding_mask
        phrase_start_acc = new_logits[phrase_indexes_].max(dim=-1)[1] == start_labels[phrase_indexes_]
        phrase_start_acc = phrase_start_acc.to(torch.float).mean().item()
        phrase_indexes_ = ~phrase_indexes & query_padding_mask
        token_start_acc = new_logits[phrase_indexes_].max(dim=-1)[1] == start_labels[phrase_indexes_]
        token_start_acc = token_start_acc.to(torch.float).mean().item()

        # training the representations of the end tokens
        candidate_reps = torch.cat([
            self.token_embeddings[:, self.model.config.hidden_size // 2:],
            e_rep], dim=0
        )
        logits = torch.matmul(query_end, candidate_reps.t())  # [Q, B*]  
        logits /= self.args['temp']
        new_logits = torch.where(overall_mask.to(torch.bool), logits, torch.tensor(-1e4).to(torch.half).cuda())
        mask = torch.zeros_like(new_logits)
        mask[range(len(new_logits)), end_labels] = 1.
        loss_ = F.log_softmax(new_logits[query_padding_mask], dim=-1) * mask[query_padding_mask]
        loss_2 = (-loss_.sum(dim=-1)).mean()
        # split the phrase and token accuracy
        phrase_indexes = end_labels > self.vocab_size
        phrase_indexes_ = phrase_indexes & query_padding_mask
        phrase_end_acc = new_logits[phrase_indexes_].max(dim=-1)[1] == end_labels[phrase_indexes_]
        phrase_end_acc = phrase_end_acc.to(torch.float).mean().item()
        phrase_indexes_ = ~phrase_indexes & query_padding_mask
        token_end_acc = new_logits[phrase_indexes_].max(dim=-1)[1] == end_labels[phrase_indexes_]
        token_end_acc = token_end_acc.to(torch.float).mean().item()
        return (
            loss_0,  # token loss
            loss_1,  # token-head loss
            loss_2,  # token-tail loss
            acc_0,  # token accuracy
            phrase_start_acc,
            phrase_end_acc,
            token_start_acc,
            token_end_acc
        )

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/929278.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

面向对象的理解

想要对象了&#xff1f;没问题&#xff0c;new一个就好了。 但是&#xff0c;new太多对象&#xff0c;对象也会生气的哦。 你瞧&#xff0c;她来了 从两段代码发现端倪 我们来计算一个矩形的面积&#xff0c;看看这两段代码有什么区别呢&#xff1f; 第一段&#xff1a; con…

并发-线程池

阻塞队列 笔记地址 点击进入 队列&#xff1a;先进先出 限定在一端进行插入&#xff0c;一端进行删除 出队为队头&#xff0c;入队为队尾 阻塞队列 BlockingQueue Queue接口继承Collection接口添加元素&#xff1a;add()&#xff0c;队列满了对抛出异常offer()&#xff0c;队…

【二分】搜索旋转数组

文章目录 不重复数组找最小值&#xff0c;返回下标重复数组找最小值&#xff0c;返回下标不重复数组找target&#xff0c;返回下标重复数组找target&#xff0c;返回bool重复数组找target&#xff0c;返回下标 不重复数组找最小值&#xff0c;返回下标 class Solution {public …

js中作用域的理解?

1.作用域 作用域&#xff0c;即变量(变量作用域又称上下文)和函数生效(能被访问)的区域或集合 换句话说&#xff0c;作用域决定了代码区块中变量和其他资源的可见性 举个例子 function myFunction() {let inVariable "函数内部变量"; } myFunction();//要先执行这…

SOLIDWORKS提高装配效率的方法:随配合复制

SOLIDWORKS用户在工程装配设计中会经常遇到这样的问题&#xff0c;在同一个装配体中某个零件需要调入多次&#xff0c;虽然装配位置可能不同&#xff0c;但是配合方式基本一致&#xff0c;传统的方法是多次插入该零件&#xff0c;然后添加配合关系&#xff0c;过程比较繁琐&…

多传感器时间同步

文章目录 存在的问题可能的解决方法 存在的问题 不同的传感器记录的时间存在差异&#xff0c;若是不进行同步操作&#xff0c;会导致融合结果异常&#xff0c;这是我可视化后的轨迹&#xff0c;存在大幅度抖动。 未同步的imu和gps融合 gps轨迹 可能的解决方法 在ROS中&am…

vite+vue3配置环境变量

首先在项目根目录添加环境文件&#xff0c;比如有测试环境、开发环境、生产环境&#xff1a; 环境说明 .env&#xff1a;所有环境都会加载.env.development&#xff1a;开发环境加载.env.production&#xff1a;生产环境加载.env.test&#xff1a;测试环境加载 语法说明 变量…

《存储IO路径》专题:数据魔法师DMA

初识DMA 大家好,今天我要给大家介绍一位在计算机世界中不可或缺的魔法师——DMA(Direct Memory Access)。让我们一起揭开这位魔法师的神秘面纱,看看它是如何让数据在内存之间自由穿梭的。 DMA这位魔法师可是大有来头。在现代计算机系统中,CPU、内存和各种设备之间需要进…

【微服务】05-网关与BFF(Backend For Frontend)

文章目录 1.打造网关1.1 简介1.2 连接模式1.3 打造网关 2.身份认证与授权2.1 身份认证方案2.1.1 JWT是什么2.1.2 启用JwtBearer身份认证2.1.3 配置身份认证2.1.4 JWT注意事项 1.打造网关 1.1 简介 BFF(Backend For Frontend)负责认证授权&#xff0c;服务聚合&#xff0c;目标…

TMS FlexCel Studio for VCL and FireMonkey Crack

TMS FlexCel Studio for VCL and FireMonkey Crack FlexCel for VCL/FireMonkey是一套允许操作Excel文件的Delphi组件。它包括一个广泛的API&#xff0c;允许本机读取/写入Excel文件。如果您需要在没有安装Excel的Windows或macOS机器上阅读或创建复杂的电子表格&#xff0c;Fle…

skywalking服务部署

一、前言 Apache SkyWalking 是一个开源的分布式跟踪、监控和诊断系统&#xff0c;旨在帮助用户监控和诊断分布式应用程序、微服务架构和云原生应用的性能和健康状况。它提供了可视化的分析工具&#xff0c;帮助开发人员和运维团队深入了解应用程序的性能、调用链和异常情况 …

spring整合MybatisAOP整合PageHelper插件

一&#xff0c;spring集成Mybatis的概念 Spring 整合 MyBatis 是将 MyBatis 数据访问框架与 Spring 框架进行集成&#xff0c;以实现更便捷的开发和管理。在集成过程中&#xff0c;Spring 提供了许多特性和功能&#xff0c;如依赖注入、声明式事务管理、AOP 等 它所带来给我们的…

C++笔记之静态成员函数可以在类外部访问私有构造函数吗?

C笔记之静态成员函数可以在类外部访问私有构造函数吗&#xff1f; code review! 静态成员函数可以在类外部访问私有构造函数。在C中&#xff0c;访问控制是在编译时执行的&#xff0c;而不是在运行时执行的。这意味着静态成员函数在编译时是与类本身相关联的&#xff0c;而不…

Linux(基础篇一)

Linux基础篇 Linux基础篇一1. Linux文件系统与目录结构1.1 Linux文件系统1.2 Linux目录结构 2. VI/VIM编辑器2.1 vi/vim是什么2.2 模式间的转换2.3 一般模式2.4 插入模式2.4.1 进入编辑模式2.4.2 退出编辑模式 2.5 命令模式 3. 网络配置3.1 网络连接模式3.2 修改静态ip3.3 配置…

ubuntu下自启动设置,为了开机自启动launch文件

1、书写sh脚本文件 每隔5秒钟启动一个launch文件&#xff0c;也可以直接在一个launch文件中启动多个&#xff0c;这里为了确保启动顺利&#xff0c;添加了一些延时 #! /bin/bash ### BEGIN INIT sleep 5 gnome-terminal -- bash -c "source /opt/ros/melodic/setup.bash…

4.12 TCP 连接,一端断电和进程崩溃有什么区别?

目录 TCP keepalive TCP 的保活机制 主机崩溃 进程崩溃 有数据传输的场景 客户端主机宕机&#xff0c;又迅速重启 客户端主机宕机&#xff0c;一直没有重启 TCP连接服务器宕机和进程退出情况总结 TCP keepalive TCP 的保活机制 TCP 保活机制需要通过 socket 接口设置 S…

算法学习——递归

引言 从这个专栏开始&#xff0c;我们将会一起来学习算法知识。首先我们要一起来学习的算法便是递归。为什么呢&#xff1f;因为这个算法是我很难理解的算法。我希望通过写这些算法博客&#xff1b;来加深自己对于递归算法的理解和运用。当然&#xff0c;学习算法最快的方式便是…

一文解析block io生命历程

作为存储业务的一个重要组成部分&#xff0c;block IO是非易失存储的唯一路径&#xff0c;它的生命历程每个阶段都直接关乎我们手机的性能、功耗、甚至寿命。本文试图通过block IO的产生、调度、下发、返回的4个阶段&#xff0c;阐述一个block IO的生命历程。 一、什么是块设备…

Python“牵手”lazada商品列表数据,关键词搜索lazadaAPI接口数据,lazadaAPI接口申请指南

lazada平台API接口是为开发电商类应用程序而设计的一套完整的、跨浏览器、跨平台的接口规范&#xff0c;lazadaAPI接口是指通过编程的方式&#xff0c;让开发者能够通过HTTP协议直接访问lazada平台的数据&#xff0c;包括商品信息、店铺信息、物流信息等&#xff0c;从而实现la…

安装启动yolo5教程

目录 一、下载yolo5项目 二、安装miniconda&#xff08;建议不要安装在C盘&#xff09; 三、安装CUDA 四、安装pytorch 五、修改配置参数 六、修改电脑参数 七、启动项目 博主硬件&#xff1a; Windows 10 家庭中文版 一、下载yolo5项目 GitHub - ultralytics/yolov5:…