N6 word2vec文本分类

news2024/9/22 17:24:14
  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊# 前言

前言

上周学习了训练word2vec模型,这周进行相关实战

1. 导入所需库和设备配置
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms, datasets
import os, PIL, pathlib, warnings

warnings.filterwarnings("ignore")  # 忽略警告信息

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

import pandas as pd
2. 加载数据
train_data = pd.read_csv('./train.csv', sep='\t', header=None)
print(train_data)
3. 数据预处理
def coustom_data_iter(texts, labels):
    for x, y in zip(texts, labels):
        yield x, y

x = train_data[0].values[:]
y = train_data[1].values[:]

from gensim.models.word2vec import Word2Vec
import numpy as np

w2v = Word2Vec(vector_size=100, min_count=3)
w2v.build_vocab(x)
w2v.train(x, total_examples=w2v.corpus_count, epochs=20)
  • 定义自定义数据迭代器coustom_data_iter
  • 提取文本和标签数据。
  • 使用Word2Vec训练词向量模型,设置词向量维度为100,最小词频为3。
def average_vec(text):
    vec = np.zeros(100).reshape((1, 100))
    for word in text:
        try:
            vec += w2v.wv[word].reshape((1, 100))
        except KeyError:
            continue
    return vec

x_vec = np.concatenate([average_vec(z) for z in x])
w2v.save('w2v_model.pkl')

train_iter = coustom_data_iter(x_vec, y)
print(len(x), len(x_vec))
label_name = list(set(train_data[1].values[:]))
print(label_name)

text_pipeline = lambda x: average_vec(x)
label_pipeline = lambda x: label_name.index(x)

print(text_pipeline("你在干嘛"))
print(label_pipeline("Travel-Query"))
  • 定义函数average_vec,将文本转换为词向量的平均值。
  • 将所有文本转换为词向量并保存Word2Vec模型。
  • 打印文本和向量的数量,以及所有标签的名称。
  • 定义文本和标签的预处理函数text_pipelinelabel_pipeline
4. 数据加载器
from torch.utils.data import DataLoader

def collate_batch(batch):
    label_list, text_list = [], []

    for (_text, _label) in batch:
        label_list.append(label_pipeline(_label))
        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.float32)
        text_list.append(processed_text)

    label_list = torch.tensor(label_list, dtype=torch.int64)
    text_list = torch.cat(text_list)

    return text_list.to(device), label_list.to(device)

dataloader = DataLoader(train_iter, batch_size=8, shuffle=False, collate_fn=collate_batch)
  • 定义函数collate_batch,将批次中的文本和标签转换为张量。
  • 创建数据加载器dataloader
5. 定义模型
class TextClassificationModel(nn.Module):

    def __init__(self, num_class):
        super(TextClassificationModel, self).__init__()
        self.fc = nn.Linear(100, num_class)

    def forward(self, text):
        return self.fc(text)

num_class = len(label_name)
model = TextClassificationModel(num_class).to(device)
  • 定义文本分类模型TextClassificationModel,包含一个全连接层。
  • 初始化模型,设置输出类别数。
6. 训练和评估函数
import time

def train(dataloader):
    model.train()
    total_acc, train_loss, total_count = 0, 0, 0
    log_interval = 50
    start_time = time.time()

    for idx, (text, label) in enumerate(dataloader):
        predicted_label = model(text)

        optimizer.zero_grad()
        loss = criterion(predicted_label, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()

        total_acc += (predicted_label.argmax(1) == label).sum().item()
        train_loss += loss.item()
        total_count += label.size(0)

        if idx % log_interval == 0 and idx > 0:
            elapsed = time.time() - start_time
            print('| epoch {:1d} | {:4d}/{:4d} batches | train_acc {:4.3f} train_loss {:4.5f}'.format(
                epoch, idx, len(dataloader), total_acc / total_count, train_loss / total_count))
            total_acc, train_loss, total_count = 0, 0, 0
            start_time = time.time()

def evaluate(dataloader):
    model.eval()
    total_acc, train_loss, total_count = 0, 0, 0

    with torch.no_grad():
        for idx, (text, label) in enumerate(dataloader):
            predicted_label = model(text)
            loss = criterion(predicted_label, label)
            total_acc += (predicted_label.argmax(1) == label).sum().item()
            train_loss += loss.item()
            total_count += label.size(0)

    return total_acc / total_count, train_loss / total_count
  • 定义训练函数train和评估函数evaluate
7. 训练模型
from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset

EPOCHS = 10
LR = 5
BATCH_SIZE = 64

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
total_accu = None

train_iter = coustom_data_iter(train_data[0].values[:], train_data[1].values[:])
train_dataset = to_map_style_dataset(train_iter)

split_train_, split_valid_ = random_split(train_dataset, [int(len(train_dataset) * 0.8), int(len(train_dataset) * 0.2)])

train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)

for epoch in range(1, EPOCHS + 1):
    epoch_start_time = time.time()
    train(train_dataloader)
    val_acc, val_loss = evaluate(valid_dataloader)

    lr = optimizer.state_dict()['param_groups'][0]['lr']

    if total_accu is not None and total_accu > val_acc:
        scheduler.step()
    else:
        total_accu = val_acc
    print('-' * 69)
    print('| epoch {:1d} | time: {:4.2f}s | valid_acc {:4.3f} valid_loss {:4.3f} | lr {:4.6f}'.format(
        epoch, time.time() - epoch_start_time, val_acc, val_loss, lr))
    print('-' * 69)

test_acc, test_loss = evaluate(valid_dataloader)
print('模型准确率为:{:5.4f}'.format(test_acc))
  • 定义超参数并初始化损失函数、优化器和学习率调度器。
  • 创建数据集并进行训练集和验证集的划分。
  • 训练模型并在每个epoch后进行验证。
8. 预测函数
def predict(text, text_pipeline):
    with torch.no_grad():
        text = torch.tensor(text_pipeline(text), dtype=torch.float32)
        print(text.shape)
        output = model(text)
        return output.argmax(1).item()

ex_text_str = "还有双鸭山到淮阴的汽车票吗13号的"
model = model.to("cpu")
print("该文本的类别是:%s" % label_name[predict(ex_text_str, text_pipeline)])
  • 定义预测函数predict,将文本转换为张量并使用模型进行预测。
  • 使用示例文本进行预测并输出结果。

结果

在这里插入图片描述

总结

这周学习了通过word2vec文本分类,包括数据加载、预处理、模型训练、评估和预测。进一步加深了对word2vec的理解。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1915868.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

10x Visium HD数据分析

–https://satijalab.org/seurat/articles/visiumhd_analysis_vignette 留意更多内容,欢迎关注微信公众号:组学之心 1.数据准备-Seurat Visium HD 数据是由特定空间排列分布的寡核苷酸序列在 2um x 2um 的网格(bin)中生成的。然…

15. Revit API: Transaction(事务)与 Failures(故障处理)

前言 UI讲完,回到DB这块儿。在Document那篇,提到增删改查操作都是在Document上,是对Documet进行操作。 看到“增删改查”这四个,想到什么了没有? 数据库(DB)嘛~话说那本经典的红皮数据库的书叫…

Python学习笔记34:进阶篇(二十三)pygame的使用之颜色与字体

前言 基础模块的知识通过这么长时间的学习已经有所了解,更加深入的话需要通过完成各种项目,在这个过程中逐渐学习,成长。 我们的下一步目标是完成python crash course中的外星人入侵项目,这是一个2D游戏项目。在这之前&#xff…

算法训练营day28--134. 加油站 +135. 分发糖果+860.柠檬水找零+406.根据身高重建队列

一、 134. 加油站 题目链接:https://leetcode.cn/problems/gas-station/ 文章讲解:https://programmercarl.com/0134.%E5%8A%A0%E6%B2%B9%E7%AB%99.html 视频讲解:https://www.bilibili.com/video/BV1jA411r7WX 1.1 初见思路 得模拟分析出…

【Python实战因果推断】19_线性回归的不合理效果9

目录 De-Meaning and Fixed Effects Omitted Variable Bias: Confounding Through the Lens of Regression De-Meaning and Fixed Effects 您刚刚看到了如何在模型中加入虚拟变量来解释不同组间的不同干预分配。但是,FWL 定理真正的亮点在于虚拟变量。如果您有大量…

鸿蒙架构之AOP

零、主要内容 AOP 简介ArkTs AOP 实现原理 JS 原型链AOP实现原理 AOP的应用场景 统计类: 方法调用次数统计、方法时长统计防御式编程:参数校验代理模式实现 AOP的注意事项 一、AOP简介 对于Android、Java Web 开发者来说, AOP编程思想并不…

【前端】包管理器:npm、Yarn 和 pnpm 的全面比较

前端开发中的包管理器:npm、Yarn 和 pnpm 的全面比较 在现代前端开发中,包管理器是开发者必不可少的工具。它们不仅能帮我们管理项目的依赖,还能极大地提高开发效率。本文将详细介绍三种主流的前端包管理器:npm、Yarn 和 pnpm&am…

错位情缘悬疑升级

✨🔥【错位情缘,悬疑升级!关芝芝与黄牡丹的惊世婚约】🔥✨在这个迷雾重重的剧场,一场前所未有的错位大戏正悄然上演!👀 你没看错,昔日兄弟的前女友关芝芝,竟摇身一变成了…

axios使用sm2加密数据后请求参数多了双引号解决方法

axios使用sm2加密数据后请求参数多了双引号解决 背景问题描述解决过程 背景 因项目安全要求,需对传给后端的入参加密,将请求参数加密后再传给后端 前期将axios降低到1.6.7后解决了问题,但最近axios有漏洞,安全要求对版本升级&…

通过电压差判定无源晶振是否起振正确吗?

在电子工程中,无源晶振作为许多数字电路的基础组件,其是否成功起振对于系统的正常运行至关重要。然而,通过简单检测晶振两端的电压差来判断晶振是否工作,这一方法存在一定的误区,晶发电子将深入探讨这一话题&#xff0…

【AIGC】一、本地docker启动私有大模型

本地docker启动私有大模型 一、最终效果中英文对话生成代码 二、资源配置三、搭建步骤启动docker容器登录页面首次登录请注册登录后的效果 配置模型尝试使用选择模型选项下载模型选择适合的模型开始下载 试用效果返回首页选择模型中英文对话生成代码 四、附录资源监控 五、参考…

浮点类型使用陷阱

引言 当我们进行条件判断时,经常会遇到两个数是否相等的情况,但如果在程序中进行判断一个可以除尽的小数和数学上除出来所得的数是否相等时,就会神奇的发型居然不相等??! 遇到问题 看如下代码 double num5 2.7;//2.7double num6 8.1 / 3;//接近2.7System.out.println(n…

NAS免费用,鲁大师 AiNAS正式发布,「专业版」年卡仅需264元

7月10日,鲁大师召开新品发布会,正式发布旗下以“提供本地Ai部署和使用能力以及在线NAS功能”并行的复合软件产品:鲁大师 AiNAS。 全新的鲁大师 AiNAS将持续满足现如今大众对于数字化生活的全新需求,将“云存储”的便捷与NAS的大容…

学圣学最终的目的是:达到思无邪的状态( 纯粹、思想纯正、积极向上 )

学圣学最终的目的是:达到思无邪的状态( 纯粹、思想纯正、积极向上 ) 中华民族,一直以来,教学都是以追随圣学为目标,所以中华文化也叫圣学文化,是最高深的上等学问; 圣人那颗心根本…

如何配置yolov10环境?

本文介绍如何快速搭建起yolov10环境,用于后续项目推理、模型训练。教程适用win、linux系统 yolo10是基于yolo8(ultralytics)的改进,环境配置跟yolo8几乎一模一样。 目录 第1章节:创建虚拟环境 第2章节:…

tesla p100显卡显示资源不足,api调用失败

🏆本文收录于《CSDN问答解惑-专业版》专栏,主要记录项目实战过程中的Bug之前因后果及提供真实有效的解决方案,希望能够助你一臂之力,帮你早日登顶实现财富自由🚀;同时,欢迎大家关注&&收…

数据结构JAVA

1.数据结构之栈和队列 栈结构 先进后出 队列结构 先进先出 队列 2.数据结构之数组和链表 数组结构 查询快、增删慢 队列结构 查询慢、增删快 链表的每一个元素我们叫结点 每一个结点都是独立的对象

浅谈“不要卷模型,要卷应用”

目录 1.概述 2.AI技术应用场景探索 3.避免超级应用陷阱的策略 3.1.追求DAU的弊端 3.2.平衡用户活跃度与应用实用性的策略 4.个性化智能体开发 4.1. 用户需求分析与数据收集 4.2. 技术选择与开发 4.3. 个性化算法设计 4.4. 安全性与隐私保护 4.5. 多渠道集成与响应机…

《昇思25天学习打卡营第14天|计算机视觉-ShuffleNet图像分类》

FCN图像语义分割&ResNet50迁移学习&ResNet50图像分类 当前案例不支持在GPU设备上静态图模式运行,其他模式运行皆支持。 ShuffleNet网络介绍 ShuffleNetV1是旷视科技提出的一种计算高效的CNN模型,和MobileNet, SqueezeNet等一样主要应用在移动端…

GraphGNSSLib Series[2]:在CLion中不同Node间进行debug

CLion实现Node debug 步骤: 我了解到的node,大多是通过终端运行,但是使用clion不断debug断点进行调试一直使我很苦恼,所以此次记录一下如何通过clion实现node节点之间通过publisher以及subscriber进行节点话题间的发布与通信&…