KL 散度(python+nlp)

news2024/11/15 16:20:27

python demo

KL 散度(Kullback-Leibler divergence),也称为相对熵,是衡量两个概率分布之间差异的一种方式。KL 散度是非对称的,也就是说,P 相对于 Q 的 KL 散度通常不等于 Q 相对于 P 的 KL 散度。
在这里插入图片描述
一个简单的 Python 类来计算两个离散概率分布之间的 KL 散度:

import numpy as np

class KLDivergence:
    def __init__(self, eps=1e-10):
        self.eps = eps  # 防止出现 log(0)

    def kl_divergence(self, p, q):
        """
        计算两个离散概率分布 P 和 Q 之间的 KL 散度。
        
        参数:
        p (np.array): 分布 P 的概率值。
        q (np.array): 分布 Q 的概率值。
        
        返回:
        float: P 相对于 Q 的 KL 散度。
        """
        p = np.asarray(p, dtype=np.float)
        q = np.asarray(q, dtype=np.float)
        
        # 防止分母为零
        q = np.clip(q, self.eps, 1 - self.eps)
        
        # 计算 KL 散度
        kl_div = np.sum(np.where(p != 0, p * np.log(p / q), 0))
        return kl_div

# 示例代码
if __name__ == "__main__":
    kld = KLDivergence()
    
    # 定义两个概率分布
    p = np.array([0.1, 0.4, 0.5])
    q = np.array([0.2, 0.3, 0.5])
    
    # 计算 KL 散度
    kl_div = kld.kl_divergence(p, q)
    print("KL Divergence:", kl_div)

在这里插入图片描述

nlp demo 1

在自然语言处理(NLP)中,KL 散度可以用于多种场景,比如评估文档的主题分布一致性、语料库中的词频分布比较等。这里提供一个使用 KL 散度来比较两个文档中词频分布的例子。我们将使用 scikit-learn 库来提取文档的词频,并计算它们之间的 KL 散度。

示例说明

在这个例子中,我们将使用 TF-IDF(Term Frequency-Inverse Document Frequency)来表示文档中的词频,并计算两个文档之间的 KL 散度。TF-IDF 是一种常用的文本特征表示方法,它可以衡量一个词对文档的重要程度。

import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from scipy.stats import entropy

class DocumentComparator:
    def __init__(self):
        self.vectorizer = TfidfVectorizer()
        self.eps = 1e-10

    def fit_transform(self, documents):
        """
        使用 TF-IDF 向量化器将文档转换为 TF-IDF 特征向量。
        
        参数:
        documents (list of str): 文档列表。
        
        返回:
        np.array: TF-IDF 特征矩阵。
        """
        tfidf_matrix = self.vectorizer.fit_transform(documents)
        return tfidf_matrix.toarray()

    def calculate_kl_divergence(self, doc1_tfidf, doc2_tfidf):
        """
        计算两个文档 TF-IDF 特征向量之间的 KL 散度。
        
        参数:
        doc1_tfidf (np.array): 第一个文档的 TF-IDF 特征向量。
        doc2_tfidf (np.array): 第二个文档的 TF-IDF 特征向量。
        
        返回:
        float: 两个文档之间的 KL 散度。
        """
        # 将 TF-IDF 向量归一化为概率分布
        doc1_prob = doc1_tfidf / np.sum(doc1_tfidf)
        doc2_prob = doc2_tfidf / np.sum(doc2_tfidf)
        
        # 防止分母为零
        doc2_prob = np.clip(doc2_prob, self.eps, 1 - self.eps)
        
        # 计算 KL 散度
        kl_div = entropy(doc1_prob, doc2_prob)
        return kl_div

# 示例代码
if __name__ == "__main__":
    comparator = DocumentComparator()
    
    # 定义两个文档
    doc1 = "Python is a widely used high-level programming language."
    doc2 = "Python is a popular scripting language for data science."
    
    # 将文档转换为 TF-IDF 特征向量
    docs = [doc1, doc2]
    tfidf_matrix = comparator.fit_transform(docs)
    
    # 计算两个文档之间的 KL 散度
    kl_div = comparator.calculate_kl_divergence(tfidf_matrix[0], tfidf_matrix[1])
    print("KL Divergence between documents:", kl_div)

在这里插入图片描述

nlp demo 2

结合 Transformer 模型(如 BERT、GPT-2、T5 或 Llama2)与 KL 散度,我们可以构建一个更复杂的系统,用于评估不同文档或句子之间的相似性。这可以通过以下几种方式实现:

  • 使用预训练模型生成句子嵌入:使用 Transformer 模型来生成句子或文档级别的嵌入向量。
  • 计算句子嵌入的概率分布:将句子嵌入转换为概率分布。
  • 计算 KL 散度:使用 KL 散度来比较两个句子的概率分布。

这里我们使用 Hugging Face 的 Transformers 库来加载预训练的模型,并计算句子之间的 KL 散度。我们将使用 BERT 作为示例模型,但这种方法同样适用于 GPT-2、T5 或 Llama2 等其他 Transformer 模型。

import torch
from transformers import AutoTokenizer, AutoModel
from scipy.special import rel_entr
import numpy as np

class SentenceEmbeddingComparator:
    def __init__(self, model_name="bert-base-uncased"):
        """
        初始化 SentenceEmbeddingComparator 类。
        
        参数:
        model_name (str): 预训练模型的名字,默认为 'bert-base-uncased'。
        """
        # 设置设备为 GPU 如果可用,否则使用 CPU
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
        # 加载预训练的分词器和模型
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name).to(self.device)
        
        # 设置模型为评估模式
        self.model.eval()

    def _mean_pooling(self, model_output, attention_mask):
        """
        对模型输出执行平均池化以获取句子级别的嵌入。
        
        参数:
        model_output (torch.Tensor): 模型的输出。
        attention_mask (torch.Tensor): 注意力掩码,指示哪些位置是填充的。
        
        返回:
        torch.Tensor: 句子级别的嵌入。
        """
        # 获取最后一层的隐藏状态
        token_embeddings = model_output.last_hidden_state
        
        # 扩展注意力掩码以匹配嵌入的维度
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        
        # 计算每句话的嵌入的加权和
        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
        
        # 计算每个位置的有效掩码数量
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        
        # 计算平均嵌入
        return sum_embeddings / sum_mask

    def generate_sentence_embedding(self, sentence):
        """
        生成给定句子的嵌入。
        
        参数:
        sentence (str): 输入的句子。
        
        返回:
        np.array: 句子的嵌入向量。
        """
        # 对输入句子进行编码并添加到设备上
        encoded_input = self.tokenizer(sentence, padding=True, truncation=True, max_length=128, return_tensors='pt').to(self.device)
        
        # 使用模型生成输出
        with torch.no_grad():
            model_output = self.model(**encoded_input)
        
        # 使用平均池化获取句子嵌入
        sentence_embedding = self._mean_pooling(model_output, encoded_input['attention_mask'])
        
        # 将嵌入从张量转换为 NumPy 数组
        return sentence_embedding.cpu().numpy()[0]

    def calculate_kl_divergence(self, emb1, emb2):
        """
        计算两个句子嵌入之间的 KL 散度。
        
        参数:
        emb1 (np.array): 第一个句子的嵌入。
        emb2 (np.array): 第二个句子的嵌入。
        
        返回:
        float: 两个句子之间的 KL 散度。
        """
        # 将嵌入向量转换为概率分布
        # 归一化确保概率分布的总和为 1
        prob1 = emb1 / np.linalg.norm(emb1, ord=1)
        prob2 = emb2 / np.linalg.norm(emb2, ord=1)
        
        # 计算 KL 散度
        kl_div = np.sum(rel_entr(prob1, prob2))
        return kl_div

# 示例代码
if __name__ == "__main__":
    # 创建 SentenceEmbeddingComparator 实例
    comparator = SentenceEmbeddingComparator("bert-base-uncased")
    
    # 定义两个句子
    sentence1 = "Python is a widely used high-level programming language."
    sentence2 = "Python is a popular scripting language for data science."
    
    # 生成句子嵌入
    emb1 = comparator.generate_sentence_embedding(sentence1)
    emb2 = comparator.generate_sentence_embedding(sentence2)
    
    # 计算两个句子之间的 KL 散度
    kl_div = comparator.calculate_kl_divergence(emb1, emb2)
    
    # 输出 KL 散度
    print("KL Divergence between sentences:", kl_div)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1993899.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

zabbix7.0TLS-05-快速入门-触发器

文章目录 1 概述2 查看触发器3 添加触发器4 验证触发器5 查看问题6 问题恢复 1 概述 监控项用于收集数据,但是我们并不能时刻观测每个监控项的数据,看看哪个监控项的数据超过了正常可接受的数值或状态,比如 CPU 负载高于 90%、磁盘使用率低于…

不平衡数据:Handling Imbalanced Dataset with SMOTE导致ValueError ⚖️

不平衡数据:Handling Imbalanced Dataset with SMOTE导致ValueError ⚖️📈 不平衡数据:Handling Imbalanced Dataset with SMOTE导致ValueError ⚖️📈摘要引言详细介绍什么是不平衡数据集?⚖️SMOTE简介&#x1f4c8…

加密案例分享:电子设备制造行业

企业核心诉求选择 1.某企业规模庞大,分支众多,数据安全管理方面极为复杂; 2.企业结构复杂,包括研发、销售、财务、总部、分部、办事处、销售等单位连结成为一个庞大的企业组织,数据产生、存储、流转、使用、销毁变化…

Selenium + Python 自动化测试08(截图)

我们的目标是:按照这一套资料学习下来,大家可以独立完成自动化测试的任务。 上一篇我们讨论了滑块的操作方法,本篇文章我们讲述一下截图的操作方法。希望能够帮到爱学的小伙伴。 在实际的测试项目组中我们经常要截屏保存报错信息&#xff0c…

做个一套C#面试题

1.int long float double 分别是几个字节 左到右范围从小到大:byte->short->int->long->float->double 各自所占字节大小:1字节、2字节、4字节、8字节、4字节、8字节 2.System.Object四个公共方法的申明 namespace System {//// 摘要…

C#如何解决引用类型的“深度”克隆问题

前言 在C#中我们new一个引用类型的对象称为对象1,如果我们再次new一个引用类型的对象称为对象2,如果直接将第一个对象直接赋值给第二个对象,然后如果我们这时候改变对象2的值,你会发现对象1的值也会被更改,这就是引用…

在ubuntu系统上安装nginx以及php的部署

1、安装依赖包 apt-get install gcc apt-get install libpcre3 libpcre3-dev apt-get install zlib1g zlib1g-dev sudo apt-get install openssl sudo apt-get install libssl-dev 2、到nginx官方下载 官方地址:nginx: download 图中下载的nginx1.22版本&#…

Python | Leetcode Python题解之第322题重新安排行程

题目: 题解: class Solution:def findItinerary(self, tickets: List[List[str]]) -> List[str]:def dfs(curr: str):while vec[curr]:tmp heapq.heappop(vec[curr])dfs(tmp)stack.append(curr)vec collections.defaultdict(list)for depart, arri…

element-ui周选择器,如何获取年、周、起止日期?

说明 版本:vue2、element-ui2.15.14 element-ui的日期选择器可以设为周,即typeweek,官方示例如下: 如果你什么都不操作,那么获取的周的值为: value1: Tue Aug 06 2024 00:00:00 GMT0800 (中国标准时间)如…

分布式存储ceph知识点整理

一、Ceph概述 如何选择存储 底层协议兼容性产品要有定位,功能有所取舍针对特定市场的应用存储被市场认可的存储系统 稳定性是第一位的性能第二数据功能要够用 一)存储分类 1、本地存储 本地的文件系统,不能在网络上用。 如:ext3、…

WPF学习(11)-ToolTip控件(提示工具)+Popup弹出窗口

ToolTip控件 ToolTip控件继承于ContentControl,它不能有逻辑或视觉父级,意思是说它不能以控件的形式实例化,它必须依附于某个控件。因为它的功能被设计成提示信息,当鼠标移动到某个控件上方时,悬停一会儿,…

【React】实现输入框切换

需求 类似designable-antd平台的这个切换功能: 点击右边按钮,可以切换不同的输入框样式。 实现 维护一个type-component的类型数组遍历数组,找到当前组件类型并渲染当切换输入框样式的时候,获取下一个组件类型并渲染。如果为最…

UE5——如何在UI界面中显示鼠标并可以点击按钮

首先进入UI蓝图的图标界面,在Event Construct节点 后连接一个Set Input Model UI Only去设置用户的输入模式 同时使用Get Player Controller获取玩家控制器并连接到Set Input Model UI Only的Player Controller 连接好后是这个样子。 此时整个UI界面只能获取到鼠标的…

Apple 智能基础语言模型

Introducing Apple’s On-Device and Server Foundation Models technical details June 10, 2024 在2024年的全球开发者大会上,苹果推出了Apple Intelligence,这是一个深度集成到iOS 18、iPadOS 18和macOS Sequoia中的个人智能系统。Apple Intelligen…

【系统响应慢排查所需命令】ps -ef、grep、jstat、pmap 、sort 、head 、jmap 、dump.hprof

列出所有进程,找到需要的进程id【ps -ef】 UID: 进程所属的用户 ID。 PID: 进程 ID。 PPID: 父进程 ID。 C: CPU 使用率。 STIME: 进程启动的时间。 TTY: 与进程关联的终端。 TIME: 进程占用的 CPU 时间。 CMD: 启动进程的命令。 假如是搜索功能缓慢&#x…

算法板子:分解质因数

目录 1. 质因数的概念 2. 代码 1. 质因数的概念 这道题的目的是找到x这个数的质因数的底数和指数。例如280这个数&#xff0c;可以看成2^3 * 5^1 * 7^1&#xff0c;其中2、5和7分别是三个质因数的底数&#xff0c;3、1、1分别是三个质因数的指数。 2. 代码 #include <io…

Java | Leetcode Java题解之第332题重新安排行程

题目&#xff1a; 题解&#xff1a; class Solution {Map<String, PriorityQueue<String>> map new HashMap<String, PriorityQueue<String>>();List<String> itinerary new LinkedList<String>();public List<String> findItine…

onnxruntime和tensorrt动态输入推理

onnxruntime动态输入推理 lenet的onnxruntime动态输入推理 导出下面的onnx模型&#xff1a; 可以看到&#xff0c;该模型的输入batch是动态的。 onnx动态输入推理&#xff08;python&#xff09;&#xff1a; import cv2 import numpy as np import onnxruntime from path…

AI 手机的技术展望

某某领导问到我&#xff0c;AI手机这个产业发展如何&#xff1f;对于&#xff0c;地方科技园区&#xff0c;应该如何发展相关产业&#xff1f;我一时还真说不上来&#xff0c;于是&#xff0c;查了一下资料&#xff0c;大概应对了一下。 一&#xff1a;AI手机的定义 首先&…

《车辆路径规划问题》专栏_安全提示3——关于抄袭并通过其本人有偿获取内容的安全提示

近期经粉丝反馈&#xff0c;咸鱼用户《白芷归露》 未经允许&#xff0c;盗用本人原创代码 &#xff1a; 【自适应大邻域算法(ALNS)求解MDHFVRPTW『Py』】 本人在此声明&#xff0c;此咸鱼号 非本博主运营&#xff0c;其行为与本人无关&#xff0c;如有在处上当受骗者&#xf…