基于text2vec进行文本向量化、聚类

news2025/1/11 20:56:50

基于text2vec进行文本向量化、聚类

  • 基于text2vec进行文本向量化、聚类
    • 介绍
    • 安装
      • 安装text2vec库
      • 安装transformers库
    • 模型下载
    • 文本向量化
      • 使用text2vec
      • 使用transformers
    • 文本聚类
      • 训练流程:
      • 训练代码
      • 推理流程
      • 推理代码

基于text2vec进行文本向量化、聚类

介绍

文本向量表征工具,把文本转化为向量矩阵,是文本进行计算机处理的第一步。

text2vec实现了Word2Vec、RankBM25、BERT、Sentence-BERT、CoSENT等多种文本表征、文本相似度计算模型,并在文本语义匹配(相似度计算)任务上比较了各模型的效果。

安装

安装text2vec库

pip install  text2vec

安装transformers库

pip install transformers

模型下载

默认情况下模型会下载到cache的目录下,不方便直接调用

需要手动下载以下三个文件,新建bert_chinese文件夹,把这三个文件放进去。

https://huggingface.co/bert-base-chinese/tree/main

在这里插入图片描述

文本向量化

使用text2vec

from text2vec import SentenceModel
sentences = ['如何更换花呗绑定银行卡', '花呗更改绑定银行卡']
model_path = "bert_chinese"
model = SentenceModel(model_path)
embeddings = model.encode(sentences)
print(embeddings)

使用transformers

from transformers import BertTokenizer, BertModel
import torch

# Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

# Load model from local
model_path = "bert_chinese"
tokenizer = BertTokenizer.from_pretrained(model_path)
model = BertModel.from_pretrained(model_path)
sentences = ['如何更换花呗绑定银行卡', '花呗更改绑定银行卡']
# Tokenize sentences
encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')

# Compute token embeddings
with torch.no_grad():
    model_output = model(**encoded_input)
# Perform pooling. In this case, max pooling.
sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
print("Sentence embeddings:")
print(sentence_embeddings)

文本聚类

训练流程:

  • 加载新闻数据
  • 基于text2vec利用bert模型进行文本向量化
  • 基于KMeans对向量化的模型进行聚类
  • 基于三种评估指标查看模型好坏
  • 利用joblib保存模型

训练代码

from transformers import BertTokenizer, BertModel
import torch
from sklearn.cluster import KMeans
from sklearn import metrics
from sklearn.metrics import silhouette_score
from sklearn.metrics import  davies_bouldin_score
import joblib
import os

#get txt file
file_path = "data\THUCNews"
files = os.listdir(file_path)
contents = []
for file in files:
    file_p = os.path.join(file_path,file)
    with open(file_p, 'r',encoding='utf-8') as f:
        a = f.read()[:200]
        contents.append(a)


# Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

model_path = "bert_chinese"
# Load model from HuggingFace Hub
tokenizer = BertTokenizer.from_pretrained(model_path)
model = BertModel.from_pretrained(model_path)
# sentences = ['如何更换花呗绑定银行卡', '花呗更改绑定银行卡','明天下午会下雨','周二下午可能是阴天','星期六不是晴天']
# Tokenize sentences
encoded_input = tokenizer(contents, padding=True, truncation=True, return_tensors='pt')

# Compute token embeddings
with torch.no_grad():
    model_output = model(**encoded_input)
# Perform pooling. In this case, max pooling.
sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
print("Sentence embeddings:")
print(sentence_embeddings.shape)

X = sentence_embeddings
kmeans = KMeans(n_clusters=3)
kmeans.fit(X)
joblib.dump(kmeans, 'kmeans.joblib')
#kmeans = joblib.load('kmeans.joblib')

labels = kmeans.predict(X)
print(labels)
score = silhouette_score(X, labels)
ch_score = metrics.calinski_harabasz_score(X, kmeans.labels_)
davies_bouldin_score = davies_bouldin_score(X, kmeans.labels_)

print("Calinski-Harabasz指数:", ch_score)
print("轮廓系数评分为:", score)
print("Davies-Bouldin指数评分:", davies_bouldin_score)

推理流程

  • 输入文本
  • 基于text2vec利用bert模型进行文本向量化
  • 加载训练好的聚类模型
  • 对向量化的文本进行预测类别
  • 类别映射

推理代码

import joblib
from transformers import BertTokenizer, BertModel
import torch

map_labels = ["娱乐","星座",'体育']
contents = '双鱼综合症患者的自述(图)新浪网友:比雅   星座真心话征稿启事双鱼座是眼泪泡大的星座,双鱼座是多愁善感的星座,双鱼座是多情的星座,双鱼座是爱幻想的星座。'
kmeans = joblib.load('kmeans.joblib')

def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

model_path = "bert_chinese"
# Load model from HuggingFace Hub
tokenizer = BertTokenizer.from_pretrained(model_path)
model = BertModel.from_pretrained(model_path)
# Tokenize sentences
encoded_input = tokenizer(contents, padding=True, truncation=True, return_tensors='pt')

# Compute token embeddings
with torch.no_grad():
    model_output = model(**encoded_input)
# Perform pooling. In this case, max pooling.
sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
print("Sentence embeddings:")
print(sentence_embeddings.shape)

X = sentence_embeddings
labels = kmeans.predict(X)
print(map_labels[labels[0]])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/193871.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

21_ncwireshark

nc&wireshark 一、nc介绍 网上百度就有一堆介绍,平常主要用于监听和连接 二、nc判断端口是否开放 实验环境: win10虚拟机和kali虚拟机 win10虚拟机ip: 192.168.11.142 kali虚拟机ip: 192.168.11.131 此时win10虚拟机,开放了80端口,21端口,3306端口 在kali虚拟机使用…

【书】只会功能测试的我,核心竞争力在哪里?

在现实工作中,测试的工作在很多人眼里就是“点点点”,特别是在推进开发自测或者向上级回报的时候,测试的工作显得那么平白无奇。 不仅是在公司内部,在进行面试的时候也会被问到,你觉得你的优势在哪里? 以上…

BlockingQueue(阻塞队列)详解

一. 前言 在新增的Concurrent包中,BlockingQueue很好的解决了多线程中,如何高效安全“传输”数据的问题。通过这些高效并且线程安全的队列类,为我们快速搭建高质量的多线程程序带来极大的便利。本文详细介绍了BlockingQueue家庭中的所有成员&…

上海亚商投顾:A股三大指数震荡涨跌各异 大消费全天活跃

上海亚商投顾前言:无惧大盘涨跌,解密龙虎榜资金,跟踪一线游资和机构资金动向,识别短期热点和强势个股。市场情绪沪指今日窄幅震荡,创业板指小幅冲高后回落,科创50指数盘中涨近1.5%,随后涨幅明显…

gdb与gdbserver的使用

GDB调试示例以调试可执行程序gdbDebug为例,gdbDebug.cpp内容如下:使用gdb 启动gdbDebug程序左侧为gdb调试,右侧为gdbDebug.cpp内容GDB与GDBServer调试示例以调试可执行程序gdbDebug为例,gdbDebug.cpp内容如下:使用gdbs…

StarRocks荣获2022年度最具潜力数据库奖

近日,墨天轮颁布了《2022年度数据库获奖名单》,通过墨天轮排行榜排名及年度得分、生态建设、市场活动、市场份额、专利数等38个综合指标进行遴选,评选出2022年的数据库重要奖项,期望能够通过多维度评选,呈现出数据库产…

客户在国内,挑选海外服务器供应商有什么技巧?

​  一直以来,基于互联网管理的严格要求,在使用中国大陆服务器放置网站时,是需要备案手续的,这个手续时长快则7天,慢则也有接近1个月的情况,复杂耗时,当然,这也是对建站成本的增加…

程序员想兼职赚钱?这几个渠道你一定要知道?

某一天当一个程序员,一拍脑门想要兼职,赚点小钱,于是他打开了知乎,打开了百度搜索兼职。结果弹出了一大部分有兼职要视频剪辑的,写文稿的等等等等。逛了一圈,发现根本没有自己合适的兼职。 我想说&#xff…

0201 设置/修改元素内容和属性

document.write()方法文本内容追加到</body>前面的位置文本中标签会被解析<script>document.write(hello world)document.write(<h3>你好世界</h3>)</script>innerText属性将文本内容添加更新到任意标签位置文本包含的标签不会被解析<style&g…

JavaScript 类与类型判断

JavaScript中的数据类型 JavaScript的数据类型大致分为两种&#xff1a;原始类型、对象类型。 原始类型&#xff1a;数字Number、字符串String、布尔值boolean、以及两个特殊值(null、undefined). 对象类型&#xff1a;数组Array、函数Function、日期Date、正则RegExp、错误Err…

Small RTOS51 学习笔记(10)Small RTOS51 的移植

个人笔记 文章目录准备一个 51 单片机工程将 Small RTOS51 相关文件添加到工程一个简单的程序运行效果遇到的问题准备一个 51 单片机工程 我打算拿一个现成的 51 单片机工程来移植 Small RTOS51&#xff0c; 当然&#xff0c;也可以重新创建一个新的工程。 将 Small RTOS51 相…

记录一次sql group by 优化记录

最近有个手动任务&#xff0c;需要计算每天的数据量&#xff0c;然后再进行处理。根据这种情况计算&#xff0c;sql是这样的SELECT FROM_UNIXTIME(publish_time / 1000, %Y-%m-%d) date,COUNT(*) as countFROMinfo_article_mainWHEREpublish_time BETWEEN ?AND ?GROUP BY dat…

Windows实时运动控制软核(六):LOCAL高速接口测试之Matlab

今天&#xff0c;正运动小助手给大家分享一下MotionRT7的安装和使用&#xff0c;以及使用Matlab对MotionRT7开发的前期准备。 01 MotionRT7简介 MotionRT7是深圳市正运动技术推出的跨平台运动控制实时内核&#xff0c;也是国内首家完全自主自研&#xff0c;自主可控的Windows…

flutter 中stack 控件的 大小是如何确定的

stack 控件 stack 是我们在flutter中常用到的控件&#xff0c;然而stack的大小是如何确定的是一个值得探究的问题&#xff0c;自己在网上也进行了搜索&#xff0c;但是总是不能解释自己遇到的新情况&#xff0c;所以我这里就根据目前的经验对stack大小是如何确定的进行一下总结…

【Java基础】006 -- 运算符

目录 一、算数运算符 1、基本用法 2、高级用法 ①、数字相加 ②、字符串相加 ③、字符相加 二、自增自减运算符 1、基本用法 三、赋值运算符 四、关系运算符 五、逻辑运算符 1、四种逻辑运算符 2、短路逻辑运算符 六、三元运算符 1、什么是三元运算符 2、三元运算符格式 七、运…

规则引擎-drools-3.3-drl文件构成-rule部分-条件Condition

文章目录drl文件构成-rule部分条件部分 LHS模式&#xff08;Pattern&#xff09;、绑定变量属性约束DRL中支持的规则条件元素&#xff08;关键字&#xff09;运算符比较操作符条件元素继承条件元素do对应多then条件drl文件构成-rule部分 drl文件构成&#xff0c;位于官网的第5…

工程师是怎样对待开源

工程师如何对待开源 本文是笔者作为一个在知名科技企业内从事开源相关工作超过 20 年的工程师&#xff0c;亲身经历或者亲眼目睹很多工程师对待开源软件的优秀实践&#xff0c;也看到了很多 Bad Cases&#xff0c;所以想把自己的一些心得体会写在这里&#xff0c;供工程师进行…

递归、dfs、回溯、剪枝,一针见血的

一、框架&#xff1a;回溯搜索的遍历过程&#xff1a;回溯法⼀般是在集合中递归搜索&#xff0c;集合的⼤⼩构成了树的宽度&#xff0c;递归的深度构成的树的深度。for循环就是遍历集合区间&#xff0c;可以理解⼀个节点有多少个孩⼦&#xff0c;这个for循环就执⾏多少次。back…

那些提升工作效率的Windows常用快捷键

那些提升工作效率的Windows常用快捷键 前言 在我们日常工作中&#xff0c;掌握一些常用的电脑快捷键&#xff0c;可以让办公效率事半功倍&#xff0c;熟用快捷键可以极大增加我们的工作效率&#xff0c;更重要的是键盘操作看起来更让人赏心悦目&#xff01; 我们通常将快捷键…

【C++】作用域与函数重载

【C】作用域与函数重载 1、作用域 1.1 作用域的作用 作用域——scope 通常来说&#xff0c;一段程序代码中所用到的名字并不总是有效/可用的&#xff0c;而限定这个名字的可用性的代码范围就是这个名字的作用域。 简单来说&#xff0c;作用域的使用减少了代码中名字的重复冲…