基于word2vec 和 fast-pytorch-kmeans 的文本聚类实现,利用GPU加速提高聚类速度

news2024/11/16 11:26:39

文章目录

    • 简介
      • GPU加速
    • 代码实现
    • kmeans
    • 聚类结果
    • kmeans 绘图函数
    • 相关资料参考

简介

本文使用text2vec模型,把文本转成向量。使用text2vec提供的训练好的模型权重进行文本编码,不重新训练word2vec模型。

直接用训练好的模型权重,方便又快捷

完整可运行代码如下:
https://github.com/JieShenAI/csdn/blob/main/machine_learning/kmeans_pytorch.ipynb

GPU加速

传统sklearn的TF-IDF文本转向量,在CPU上计算速度较慢。使用text2vec通过cuda加速,加快文本转向量的速度。
传统使用sklearn的kmeans聚类算法在CPU上计算,如遇到大批量的数据,计算耗时太长。
故本文使用kmeans_pytorch包,基于pytorch在GPU上计算,提高聚类速度。

代码实现

装包

pip install fast-pytorch-kmeans text2vec
import torch
import numpy as np

from text2vec import SentenceModel

不使用SentenceModel模型也可以,在 text2vec 中,还有很多其他的向量编码模型供选择。

文本编码模型

embedder = SentenceModel()

异常情况说明,该模型需要从huggingface下载模型权重,目前被墙了。(请想办法解决,或者尝试其他的编码模型)
在这里插入图片描述

语料库如下:

# Corpus with example sentences
corpus = [
    '花呗更改绑定银行卡',
    '我什么时候开通了花呗',
    'A man is eating food.',
    'A man is eating a piece of bread.',
    'The girl is carrying a baby.',
    'A man is riding a horse.',
    'A woman is playing violin.',
    'Two men pushed carts through the woods.',
    'A man is riding a white horse on an enclosed ground.',
]
corpus_embeddings = embedder.encode(corpus)
# numpy 转成 pytorch, 并转移到GPU显存中
corpus_embeddings = torch.from_numpy(corpus_embeddings).to('cuda')

如下图所示,编码的向量是768纬;

type(corpus_embeddings), corpus_embeddings.shape

在这里插入图片描述

kmeans

kmeans_pytorch vs fast-pytorch-kmeans:
在实验过程中,利用kmeans_pytorch 针对30万个词进行聚类的时候,发现显存炸了,程序崩溃退出。30万个词的词向量,占用显存还不到2G,但是运行kmeans_pytorch后,显存就炸了。

fast-pytorch-kmeans不存在上述显存崩溃的问题。本以为词向量很多会跑很长时间,但fast-pytorch-kmeans在非常短的时间内就完成了kmeans聚类。

# kmeans
# from kmeans_pytorch import kmeans
from fast_pytorch_kmeans import KMeans

num_class = 3 # 分类类别数
kmeans = KMeans(n_clusters=num_class, mode='euclidean', verbose=1)

# 模型预测结果
labels = kmeans.fit_predict(corpus_embeddings)

聚类程序运行如下:

used 2 iterations (0.3682s) to cluster 9 items into 3 clusters

模型中心点坐标:

kmeans.centroids

在这里插入图片描述

聚类结果

class_data = {
    i:[]
    for i in range(3)
}

for text,cls in zip(corpus, labels):
    class_data[cls.item()].append(text)

class_data

文本聚类结果如下:
0: 女
1:男
2: 花呗
在这里插入图片描述

kmeans 绘图函数

封装了KMeansPlot 绘图类,方便聚类结果可视化

from sklearn.decomposition import PCA
import matplotlib.pyplot as plt

class KMeansPlot:

    def __init__(self, numClass=4, func_type='PCA'):
        if func_type == 'PCA':
            self.func_plot = PCA(n_components=2)
        elif func_type == 'TSNE':
            from sklearn.manifold import TSNE
            self.func_plot = TSNE(2)
        self.numClass = numClass

    def plot_cluster(self, result, pos, cluster_centers=None):
        plt.figure(2)
        Lab = [[] for i in range(self.numClass)]
        index = 0
        for labi in result:
            Lab[labi].append(index)
            index += 1
        color = ['oy', 'ob', 'og', 'cs', 'ms', 'bs', 'ks', 'ys', 'yv', 'mv', 'bv', 'kv', 'gv', 'y^', 'm^', 'b^', 'k^',
                    'g^'] * 3

        for i in range(self.numClass):
            x1 = []
            y1 = []
            for ind1 in pos[Lab[i]]:
                # print ind1
                try:
                    y1.append(ind1[1])
                    x1.append(ind1[0])
                except:
                    pass
            plt.plot(x1, y1, color[i])

        if cluster_centers is not None:
            #绘制初始中心点
            x1 = []
            y1 = []

            for ind1 in cluster_centers:
                try:
                    y1.append(ind1[1])
                    x1.append(ind1[0])
                except:
                    pass

            plt.plot(x1, y1, "rv") #绘制中心

        plt.show()

    def plot(self, weight, label, cluster_centers=None):
        pos = self.func_plot.fit_transform(weight)
        # 高纬的中心点坐标,也经过降纬处理
        cluster_centers = self.func_plot.fit_transform(cluster_centers)
        self.plot_cluster(list(label), pos, cluster_centers)

kmeans.centroids :是一个高纬空间的中心点坐标,故在plot函数中,将其降纬到2D平面上;

k_plot = KMeansPlot(num_class)
k_plot.plot(
    corpus_embeddings.to('cpu'),
    labels.to('cpu'),
    kmeans.centroids.to('cpu')
)

在这里插入图片描述

完整可运行代码如下:
https://github.com/JieShenAI/csdn/blob/main/machine_learning/kmeans_pytorch.ipynb

相关资料参考

  • 动手实战基于 ML 的中文短文本聚类
  • tfidf和word2vec构建文本词向量并做文本聚类
    提到训练word2vec模型,silhouette_score_show(word2vec, 'word2vec') 轮廓系数,判断分几个类别最好。
  • 机器学习:Kmeans聚类算法总结及GPU配置加速demo
    PyTorch kmeans 加速。from scratch 实现;
  • KMeans算法全面解析与应用案例 通俗易懂的原理讲解
  • pytorch K-means算法的实现 底层代码实现
  • 【pytorch】Kmeans_pytorch用于一般聚类任务的代码模板 使用pytorch封装的kmeans包实现,包括训练和预测;
  • text2vec 包

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1517391.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

软件无线电系列——模拟无线电、数字无线电、软件无线电

本节目录 一、模拟无线电 二、数字无线电 1、窄带数字无线电 2、宽带数字无线电 三、软件无线电本节内容 一、模拟无线电 20世纪80年代的模拟体制(美国的AMPS/欧洲的TACS)被称为第一代移动通信,简称1G,主要目标是为在大范围内有限的用户提供移动电话服务。最主要的…

uniapp遇到的问题

【uniapp】小程序中input输入框的placeholder-class不生效解决办法 解决:写在scope外面 uniapp设置底部导航 引用:https://www.jianshu.com/p/738dd51a0162 【微信小程序】moveable-view / moveable-area的使用 https://blog.csdn.net/qq_36901092/…

Figure与OpenAI 联手推出新机器人;荣耀首款「AI PC」即将发布

▶ Figure 与 OpenAI 联手推出新机器人 AI 机器人公司 Figure 发布了他们与 OpenAI 的合作成果,将 OpenAI 的大模型运用在其机器人 Figure 01 上。 据介绍,OpenAI 大模型加持的 Figure 01 机器人现在可以与人全面对话。 OpenAI 模型为机器人提供了高级…

微信小程序(五十九)使用鉴权组件时原页面js自动加载解决方法(24/3/14)

注释很详细,直接上代码 上一篇 新增内容: 1.使用覆盖函数的方法阻止原页面的自动执行方法 2.使用判断实现只有当未登录时才进行方法覆盖 源码: app.json {"pages": ["pages/index/index","pages/logs/logs"],…

mac删除带锁标识的app

一 、我们这里要删除FortiClient.app 带锁 常规方式删除不掉带锁的 app【如下图】 二、删除命令,依次执行即可。 /bin/ls -dleO /Applications/FortiClient.app sudo /usr/bin/chflags -R noschg /Applications/FortiClient.app /bin/ls -dleO /Applications/Forti…

2024计算机二级3

1. 2. 3. 4. 5. 6. append每次只能添加一个元素,两个元素都在同一个列表内相当于是一个整体 7. d.get后边括号内会存在一个默认值,如果题目给出的选项内没有已经存在的键值名,则会返回后边的默认值 8. 字典是映射数据类型,不属于…

【QT】TCP简易聊天框

我们首先复习一下TCP通信的流程 基于linuxTCP客户端和服务器 QT下的TCP处理流程 服务器先启动(处于监听状态) 各函数的意义和使用 QTcpServer Class *QTcpServer*类提供了一个基于TCP的服务器。这个类可以接受传入的TCP连接。您可以指定端口或让QTcpS…

碳储量监测的新技术:遥感在草原碳汇评估中的潜力

在全球环境问题日益严重的今天,以全球变暖为主要特征的气候变化成为了人类面临的巨大挑战。它威胁着地球的生态平衡,对全球可持续发展构成了严峻的挑战。为了应对这一挑战,各国纷纷采取行动,致力于实现碳达峰和碳中和的目标。 在…

Zabbix 监控 tomcat

zabbix-java-gateway服务组件 zabbix监控tomcat需要用到zabbix-java-gateway组件,它充当zabbix服务和java应用程序之间的网关。它允许zabbix服务器用过java网关与java应用程序进行通信,从而监控和收集java应用程序的性能数据。 zabbix-agent服务&#xf…

gradio部署视频输入输出示例,gradio网页输出视频nan,输出视频无法播放解决方法

gradio部署视频输入输出示例,gradio网页输出视频nan,输出视频无法播放 Opencv不能采用h64格式进行编码解决方案moviepy介绍浏览器接受的视频编码格式:采用h264编码合成视频: gradio部署视频输入输出示例Gradio视频组件使用详解简介…

小程序学习3 goods-card

pages/home/home home.wxml <goods-listwr-class"goods-list-container"goodsList"{{goodsList}}"bind:click"goodListClickHandle"bind:addcart"goodListAddCartHandle"/> <goods-list>是一个自定义组件&#xff0c;它具…

【MIT 6.S081】2020, 实验记录(8),Lab: locks

目录 Task 1&#xff1a;Memory allocator (moderate)</font>Task 2&#xff1a;Buffer cache (hard)</font> Task 1&#xff1a;Memory allocator (moderate) 这个任务就是练习将一把大锁拆分为多个小锁&#xff0c;同时可以更加深入地理解 memory allocator 运行…

PY32离线烧录器功能介绍,可批量烧录,支持PY32系列多款单片机

PY32离线烧录器可以对PY系列单片机进行批量烧录&#xff0c;现支持PY32F002A/002B/002/003/030/071/072/040/403/303芯片各封装和XL2409&#xff0c;XL32F001/003等芯片。PY32离线烧录器需要搭配上位机软件才能使用&#xff0c;上位机软件在我们官网&#xff08;www.xinlinggo.…

JVM基础篇

什么是JVM java虚拟机 JVM的功能 1.解释和运行 对字节码文件中的指令&#xff0c;实时的解释成机器码&#xff0c;让计算机执行 2.内存管理 自动为对象、方法等分配内存 自动的垃圾回收机制&#xff0c;回收不再使用的对象&#xff08;c不会自动回收&#xff0c;相当于降…

QT 如何防止 QTextEdit 自动滚动到最下方

在往QTextEdit里面append字符串时&#xff0c;如果超出其高度&#xff0c;默认会自动滚动到QTextEdit最下方。但是有些场景可能想从文本最开始的地方展示&#xff0c;那么就需要禁止自动滚动。 我们可以在append之后&#xff0c;添加如下代码&#xff1a; //设置编辑框的光标位…

指针的函数传参的详细讲解(超详细)

如果对指针基础知识已经有可以直接跳到 函数的指针传参与解引用&#xff0c;哪里不明白可以评论&#xff0c;随时解答。 目录 所以就有了一句话&#xff1a;指针就是地址&#xff0c;地址就是指针 对于指针在C语言中&#xff0c;指针类型就是数据类型&#xff0c;是给编译器…

PHP极简网盘系统源码 轻量级文件管理与共享系统网站源码

PHP极简网盘系统源码 轻量级文件管理与共享系统网站源码 极简网盘是一个轻量级文件管理与共享系统&#xff0c;支持多用户&#xff0c;可充当网盘程序&#xff0c;程序无需数据库 安装步骤&#xff1a; 1.建议安装在apache环境下&#xff0c;并确保.htaccess可用 2.解压文件…

论文阅读——RingMo

RingMo: A Remote Sensing Foundation Model With Masked Image Modeling 与自然场景相比&#xff0c;RS图像存在以下困难。 1&#xff09;分辨率和方位范围大&#xff1a;受遥感传感器的影响&#xff0c;图像具有多种空间分辨率。此外&#xff0c;与自然图像的实例通常由于重…

【Python】Leetcode 240. 搜索二维矩阵 II - 削减矩阵+递归,击败88%

描述 搜索二维矩阵 II 编写一个高效的算法来搜索 m x n 矩阵 matrix 中的一个目标值 target 。 该矩阵具有以下特性&#xff1a; 每行的元素从左到右升序排列。 每列的元素从上到下升序排列。 思路 确定左右及上下限&#xff0c;削减矩阵&#xff0c;递归。 注意判断四个端…

15届蓝桥杯第一期模拟赛所有题目解析

文章目录 &#x1f9e1;&#x1f9e1;t1_字母数&#x1f9e1;&#x1f9e1;问题描述思路代码 &#x1f9e1;&#x1f9e1;t2_大乘积&#x1f9e1;&#x1f9e1;问题描述思路代码 &#x1f9e1;&#x1f9e1;t3_星期几&#x1f9e1;&#x1f9e1;问题描述思路代码 &#x1f9e1;…