LLM - LLaMA-2 获取文本向量并计算 Cos 相似度

news2024/9/24 9:18:20
 

目录

一.引言

二.获取文本向量

1.hidden_states 与 last_hidden_states

◆ hidden_states

◆ last_hidden_states 

2.LLaMA-2 获取 hidden_states

◆ model config 

◆ get Embedding

三.获取向量 Cos 相似度

1.向量选择

2.Cos 相似度

3.BERT-whitening 特征白化

四.总结


一.引言

前面提到了两种基于统计的机器翻译评估方法: Rouge 与 BLEU,二者通过统计概率计算 N-Gram 的准确率与召回率,在机器翻译这种回答相对固定的场景该方法可以作为一定参考,但在当前大模型更加多样性的场景以及发散的回答的情况下,Rouge 与 BLEU 有时候并不能更好的描述文本之间的相似度,下面我们尝试从 LLM 大模型提取文本的 Embedding 并进行向量相似度计算。

二.获取文本向量

1.hidden_states 与 last_hidden_states

根据 LLM 模型类型的不同,有的 Model 提供 hidden_states 方法,例如 LLaMA-2-13B,有的模型提供 last_hidden_states 方法,例如 GPT-2。查找模型对应方法 API 可以在 Transformer 官网。

 hidden_states

hidden_states 类型为 typing.Optional[typing.Tuple[torch.FloatTensor]],其提供一个 Tuple[Tensor] 分别记录了每层的输出,完整的解释在参数下方: 

模型在每一层输出处的隐藏状态加上可选的初始嵌入输出。这里我们可以通过打印模型 Layer 和索引从而获取 hidden_states 中隐层的输出。

◆ last_hidden_states 

一些传统的模型例如 GPT-2,还有当下一些的新模型例如 ChatGLM2 都有 last_hidden_states 的 API,可以直接获取最后一层的 Embedding 输出,而如果使用 hidden_states 则只需要通过 [-1] 索引即可获得 last_hidden_states,相比来如前者更全面后者更方便。

2.LLaMA-2 获取 hidden_states

model config 

    config_kwargs = {
        "trust_remote_code": True,
        "cache_dir": None,
        "revision": 'main',
        "use_auth_token": None,
        "output_hidden_states": True
    }

    config = AutoConfig.from_pretrained(ori_model_path, **config_kwargs)

    llama_model = AutoModelForCausalLM.from_pretrained(
        ori_model_path,
        config=config,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
        revision='main'
    )

 根据 CausalLMOutputWithPast hidden_states 参数的提示,我们只需要在模型 config 中添加:

"output_hidden_states": True

get Embedding

def get_embeddings(result, llm_tokenizer, model, args):
    fw = open(args.output, 'w', encoding='utf-8')
    for qa in result:
        q = qa[0]
        a = qa[1]
        # 对输出文本进行 tokenize 和编码
        tokens = llm_tokenizer.encode_plus(a, add_special_tokens=True, padding='max_length', truncation=True,
                                           max_length=128, return_tensors='pt')
        input_ids = tokens["input_ids"]
        attention_mask = tokens['attention_mask']

        # 获取文本 Embedding
        with torch.no_grad():
            outputs = model(input_ids=input_ids.cuda(), attention_mask=attention_mask)
            embedding = list(outputs.hidden_states)
            last_hidden_states = embedding[-1].cpu().numpy()
            first_hidden_states = embedding[0].cpu().numpy()
            last_hidden_states = np.squeeze(last_hidden_states)
            first_hidden_states = np.squeeze(first_hidden_states)
            fisrt_larst_avg_status = np.mean(first_hidden_states + last_hidden_states, axis=0)

        log = "%s\t%s\t%s\n" % (q, a, toString(fisrt_larst_avg_status))
        fw.write(log)
    fw.close()

predict  预测       ➔  将 model 基于 Question generate 得到的 Answer 存入 result

encode 编码       ➔  对 Answer 进行编码获取对应 Token 与 input_ids、attention_mask

output 模型输出  ➔  直接调用 model 进行输出,有的也可以调用 model.transform 方法进行输出

hidden_states     ➔  outputs.hidden_states 获取各隐层输出

最后获取的向量需要先 cpu 然后再转为 numpy 数组,一般的做法是采用 mean 获得句子的平均表征。

三.获取向量 Cos 相似度

1.向量选择

在 BERT-flow 的论文中,如果不加任何后处理手段,那么基于 BERT 抽取句向量的最好 Pooling 方法是 BERT 的第一层与最后一层的所有 token 向量的平均,即 fisrt-larst-avg,对应 hidden_state 的 0 和 -1 索引,所以后面的相似度计算我们都以 fisrt-larst-avg 为基准来评估 Embedding 相似度。

# 获取文本 Embedding
with torch.no_grad():
    outputs = model(input_ids=input_ids.cuda(), attention_mask=attention_mask)
    embedding = list(outputs.hidden_states)
    last_hidden_states = embedding[-1].cpu().numpy()
    first_hidden_states = embedding[0].cpu().numpy()
    last_hidden_states = np.squeeze(last_hidden_states)
    first_hidden_states = np.squeeze(first_hidden_states)
    fisrt_larst_avg_status = np.mean(first_hidden_states + last_hidden_states, axis=0)

2.Cos 相似度

# 计算 Cos 相似度
def compute_cosine(a_vec, b_vec):
    norms1 = np.linalg.norm(a_vec, axis=1)
    norms2 = np.linalg.norm(b_vec, axis=1)
    dot_products = np.sum(a_vec * b_vec, axis=1)
    cos_similarities = dot_products / (norms1 * norms2)
    return cos_similarities

a_vec 为预测文本转化得到的 Embedding,b_vec 为人工标注正样本文本转化得到的 Embedding,通过计算二者相似度,评估预测文本与人工文本的相似程度。

3.BERT-whitening 特征白化

苏神在 BERT-whitening 一文中提出了一种基于 PCA 降维的无监督 Embedding 评估方式,Bert-whitening 又叫特征白化,其思路与 PCA 降维类似,意在对 SVD 分解后的主成分矩阵取前 λ 个特征向量构造特征值矩阵,提取向量中的关键信息,使输出向量矩阵每个维度均值为零,协方差矩阵为单位阵,λ 个特征值也对应前 λ 个主成分。其算法逻辑如下:

 下面我们调用 Sklearn 的 PCA 库简单实现下:

from sklearn.decomposition import PCA
from sklearn.preprocessing import normalize

    # 取出句子的平均表示 -> 使用 PCA 降维 -> 白化处理
    concatenate = np.concatenate((answer_vector, predict_vector))
    pca = PCA(n_components=2048)
    pca.fit(concatenate)
    ans_white_vec = pca.transform(answer_vector)
    ans_norm_vec = normalize(ans_white_vec)
    pre_white_vec = pca.transform(predict_vector)
    pre_norm_vec = normalize(pre_white_vec)

    pca_cos_similarities = compute_cosine(ans_norm_vec, pre_norm_vec)

answec_vector 和 predict_vector 均通过 first_and_last 方法从 hidden_states 中获取,n_components 即 top_k 的选择,以 LLaMA-2 为例,原始得到的向量维度为 5120,原文中也有使用 n_components = 256 实验。

四.总结

博主采用 1500+ 样本分别使用 cos、pca 和 self_pca [自己实现 SVD 与特征矩阵] 三种方法对向量相似度进行评估,n_components 设为 1024:

可以看到 SVD 处理后得到的 W 和 mu 的 shape,通过下述操作可完成向量的降维:

vecs = (vecs + bias).dot(kernel)

最终得到的结果 Cosine 与 PCA 降维的相似度差距较大,由于自然语言生成的样本没有严格意义的正样本,上面计算采用的参考文本也是人工标注,有一定的不确定性,所以基于不同的度量,我们也可以统计分析,定一个 threshold,认为大于该 threshold 的输入样本为可用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/958449.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

用户体验地图是什么?UX设计心得分享

大家好,我是设计师l1m0身。本篇文章是关于UX设计中的用户体验地图。 对于新手设计师来说,建立用户体验地图会有一些难度。本篇文章中,我会以简单、易懂的语言分享UX设计师如何制作用户体验地图,希望对你的日常项目体验提升有所帮…

私有化部署无忧企业文档2.1.7版本功能清单已更新!

无忧企业文档是软开企服研发的一款基于云端的在线文档管理工具,适用于团队内部协作、知识管理、项目管理等多个领域。与其他在线文档工具相比,无忧企业文档更注重团队协作和安全性,适合企业和团队使用。并且它还提供了丰富的API接口&#xff…

使用Dbeaver连接GaussDB

1.下载DBeaver,官网地址 2.安装软件,打开软件,点击数据库->驱动管理器,具体操作如下图: 3、选择新建后进行参数设置,如下图: 具体参数如下图 驱动名称: GS #随便定义 驱动类型&#…

Pytorch 的基本概念和使用场景介绍

文章目录 一、基本概念1. 张量(Tensor)2. 自动微分(Autograd)3. 计算图(Computation Graph)4. 动态计算图(Dynamic Computation Graph)5. 变量(Variable) 二、…

Nodejs入门 token校验

Nodejs入门token校验之jsonwebtoken的使用 前言 token校验作为项目里的必要项,其重要性不言而喻,今天介绍一个在Node.js中备受推崇的神奇工具——jsonwebtoken 一、token是什么jsonwebtoken是什么? 在互联网世界中,Token是一种用于…

路由技术介绍

路由技术介绍 一、路由概述1.1、为什么需要路由1.2、路由的定义1.3、直接路由数据通信分析1.4、间接路由数据通信分析1.5、认识路由设备1.6、路由的下一跳1.7、路由表的构成与维护1.8、路由表的构成1.9、路由表的度量值1.10、路由表的内容1.11、管理距离1.12、路由加表原则1.13…

ESP32在线仿真器

1. Wokwi是一个电子在线仿真平台,支持的芯片有ESP32,STM32,树莓派,Arduino 网址: https://wokwi.com ; 而且支持在vscode内置插件仿真 2. ESP32可以使用microPython开发,相关sdk说明MicroPython(ESP32)快…

IndexDB

1 新建一个数据库 (1)在utils文件中新建一个indexedDB.ts export default class DB {private dbName: string // 数据库名称constructor(dbName: string){this.dbName dbName}// 打开数据库public openStore() {const request window.indexedDB.open(this.dbName,1)request…

关于linux openssl的自签证书认证与nginx配置

自签文档链接 重点注意这块,不能写一样的,要是一样的话登录界面锁会报不安全 域名这块跟最后发布的一致 nginx配置的话 server {listen 443 ssl; //ssl 说明为https 默认端口为443server_name www.skyys.com; //跟openssl设置的域名保持一致s…

GD32F4_CAN1无法进入接收中断

Q、GD32F450/470的外设CAN1按正常的初始化顺序配置后在正常和回环模式下都无法进入接收中断。 A、注意以下两点 【1】要使用CAN1的接收中断必须要开启CAN0的时钟 【2】CAN1的接收过滤序号应设置为15

肖sir__linux详解__001

linux详解: 1、ifconfig 查看ip地址 2、6版本:防火墙的命令: service iptables status 查看防火墙状态 service iptables statrt 开启防火墙 service iptables stop 关闭防火墙 service iptables restart 重启防火墙状态 7版本: systemctl s…

【Flutter】Flutter 使用 Shimmer 实现闪光效果的加载动画占位符

【Flutter】Flutter 使用 Shimmer 实现闪光效果的加载动画占位符 文章目录 一、前言二、为什么选择 shimmer 以及其安装和基本使用1. 闪光效果在 UI 设计中的价值2. shimmer 与其他类似工具的比较3. 如何在 Flutter 项目中安装 shimmer4. 基本使用方法和代码示例 三、深入了解 …

Activity基础之开发环境

工欲善其事必先利其器。 一、Android开发工具AndroidStudio安装以及环境搭建。 AS下载路径:https://developer.android.google.cn/studio AS历史版本下载路径:https://developer.android.google.cn/studio/archive?hlzh-cn 安装过程省略。。。 Jav…

C语言每日一练----Day(12)

本专栏为c语言练习专栏,适合刚刚学完c语言的初学者。本专栏每天会不定时更新,通过每天练习,进一步对c语言的重难点知识进行更深入的学习。 今日练习题关键字:最大连续1的个数 完全数计算 💓博主csdn个人主页&#xff1…

Python数据分析案例30——中国高票房电影分析(爬虫获取数据及分析可视化全流程)

案例背景 最近总看到《消失的她》票房多少多少,《孤注一掷》票房又破了多少多少..... 于是我就想自己爬虫一下获取中国高票房的电影数据,然后分析一下。 数据来源于淘票票:影片总票房排行榜 (maoyan.com) 爬它就行。 代码实现 首先爬虫获…

【css】z-index与层叠上下文

z-index属性用来设置元素的堆叠顺序,使用z-index有一个大的前提:z-index所作用元素的样式列表中必须有position属性并且属性值为absolute、relative或fixed中的一个,否则z-index无效。 层叠上下文 MDN讲解 我们给元素设置的z-index都是有一…

selenium中定位shadow-root,以及获取shadow-root内部的数据

通过shadow-root的父级定位到shadow-root,再通过语句进行操作 两种方法: 第一种,Python种JS实现 第二种,selenium实现 1.0 案例网站 参考某橘色网站 2.0 js语句定位 可在控制台进行测试 测试语句 document.querySelector("ali-ba…

猫头虎博主解析:Spring中的“Unknown return value type: java.lang.Boolean“问题

🌷🍁 博主猫头虎(🐅🐾)带您 Go to New World✨🍁 🦄 博客首页——🐅🐾猫头虎的博客🎐 🐳 《面试题大全专栏》 🦕 文章图文…

纽扣电池做CE认证 需要的介绍 EMC检测、RED检测和LVD检测

认证需要测试什么项目 1、EMC检测项目:传导骚扰、辐射骚扰、静电放电抗扰度、射频电磁场辐射抗扰度、电快速瞬变脉)冲群抗扰度、浪涌(冲击)抗扰度、射频场感应的传导抗扰度、工频磁场抗扰度、电压暂降、短时中断和电压变化抗扰度、谐波电流、电压波动和闪烁。 2、LVD检测项目:…

Laf 中大猫谱:让每一只流浪猫都有家

猫谱简介 中大猫谱是一款辅助校园流浪猫救助的开源小程序项目,服务端使用 Laf 云开发。 猫谱主要功能包括:猫咪信息登记、照片分享、拍照识猫、公告和留言等。项目创立的初衷,是解决校园猫猫交流群里的一个常见问题:问猫猫是谁。…