深度学习实战26-(Pytorch)搭建TextCNN实现多标签文本分类的任务

news2024/11/20 7:12:16

大家好,我是微学AI,今天给大家介绍一下深度学习实战26-(Pytorch)搭建TextCNN实现多标签文本分类的任务,TextCNN是一种用于文本分类的深度学习模型,它基于卷积神经网络(Convolutional Neural Networks, CNN)实现。TextCNN的主要思想是使用卷积操作从文本中提取有用的特征,并使用这些特征来预测文本的类别。

TextCNN将文本看作是一个一维的时序数据,将每个单词嵌入到一个向量空间中,形成一个词向量序列。然后,TextCNN通过堆叠一些卷积层和池化层来提取关键特征,并将其转换成一个固定大小的向量。最后,该向量将被送到一个全连接层进行分类。TextCNN的优点在于它可以非常有效地捕捉文本中的局部和全局特征,从而提高分类精度。此外,TextCNN的训练速度相对较快,具有较好的可扩展性.

TextCNN做多标签分类

1.库包导入

import os
import re
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, precision_score, recall_score
from collections import Counter

 2.定义参数

max_length = 20
batch_size = 32
embedding_dim = 100
num_filters = 100
filter_sizes = [2, 3, 4]
num_classes = 4
learning_rate = 0.001
num_epochs = 2000
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

3. 数据集处理函数


def load_data(file_path):
    df = pd.read_csv(file_path,encoding='gbk')
    texts = df['text'].tolist()
    labels = df['label'].apply(lambda x: x.split("-")).tolist()
    return texts, labels

def preprocess_text(text):
    text = re.sub(r'[^\w\s]', '', text)
    return text.strip().lower().split()

def build_vocab(texts, max_size=10000):
    word_counts = Counter()
    for text in texts:
        word_counts.update(preprocess_text(text))
    vocab = {"<PAD>": 0, "<UNK>": 1}
    for i, (word, count) in enumerate(word_counts.most_common(max_size - 2)):
        vocab[word] = i + 2
    return vocab

def encode_text(text, vocab):
    tokens = preprocess_text(text)
    return [vocab.get(token, vocab["<UNK>"]) for token in tokens]

def pad_text(encoded_text, max_length):
    return encoded_text[:max_length] + [0] * max(0, max_length - len(encoded_text))

def encode_label(labels, label_set):
    encoded_labels = []
    for label in labels:
        encoded_label = [0] * len(label_set)
        for l in label:
            if l in label_set:
                encoded_label[label_set.index(l)] = 1
        encoded_labels.append(encoded_label)
    return encoded_labels

class TextDataset(Dataset):
    def __init__(self, texts, labels):
        self.texts = texts
        self.labels = labels

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, index):
        return torch.tensor(self.texts[index], dtype=torch.long), torch.tensor(self.labels[index], dtype=torch.float32)

texts, labels = load_data("data_qa.csv")
vocab = build_vocab(texts)
label_set = ["人工智能", "卷积神经网络", "大数据",'ChatGPT']

encoded_texts = [pad_text(encode_text(text, vocab), max_length) for text in texts]
encoded_labels = encode_label(labels, label_set)

X_train, X_test, y_train, y_test = train_test_split(encoded_texts, encoded_labels, test_size=0.2, random_state=42)
#print(X_train,y_train)

train_dataset = TextDataset(X_train, y_train)
test_dataset = TextDataset(X_test, y_test)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

数据集样例:

textlabel
人工智能如何影响进出口贸易——基于国家层面数据的实证检验人工智能
生成式人工智能——ChatGPT的变革影响、风险挑战及应对策略人工智能-ChatGPT
人工智能与人的自由全面发展关系探究——基于马克思劳动解放思想人工智能
中学生人工智能技术使用持续性行为意向影响因素研究 人工智能
人工智能技术在航天装备领域应用探讨 人工智能
人工智能赋能教育的伦理省思 人工智能
人工智能的神话:ChatGPT与超越的数字劳动“主体”之辨 人工智能-ChatGPT
人工智能(ChatGPT)对社科类研究生教育的挑战与机遇 人工智能-ChatGPT
人工智能助推教育变革的现实图景——教师对ChatGPT的应对策略分析 人工智能-ChatGPT
智能入场与民主之殇:人工智能时代民主政治的风险与挑战 人工智能
国内人工智能写作的研究现状分析及启示 人工智能
人工智能监管:理论、模式与趋势 人工智能
“新一代人工智能技术ChatGPT的应用与规制”笔谈 人工智能-ChatGPT
ChatGPT新一代人工智能技术发展的经济和社会影响 人工智能-ChatGPT
ChatGPT赋能劳动教育的图景展现及其实践策略 人工智能-ChatGPT
人工智能聊天机器人—基于ChatGPT、Microsoft Bing视角分析 人工智能-ChatGPT
拜登政府对华人工智能产业的打压与中国因应 人工智能
人工智能技术在现代农业机械中的应用研究人工智能
人工智能对中国制造业创新的影响研究—来自工业机器人应用的证据 人工智能
人工智能技术在电子产品设计中的应用人工智能
ChatGPT等智能内容生成与新闻出版业面临的智能变革人工智能-ChatGPT
基于卷积神经网络的农作物智能图像识别分类研究人工智能-卷积神经网络
基于卷积神经网络的图像分类改进方法研究人工智能-卷积神经网络

 

这里设置多标签,用“-”符号隔开多个标签。

4.构建模型

class TextCNN(nn.Module):
    def __init__(self, vocab_size, embedding_dim, num_filters, filter_sizes, num_classes, dropout=0.5):
        super(TextCNN, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.convs = nn.ModuleList([nn.Conv2d(1, num_filters, (fs, embedding_dim)) for fs in filter_sizes])
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(num_filters * len(filter_sizes), num_classes)

    def forward(self, x):
        x = self.embedding(x)
        x= x.unsqueeze(1)
        x = [torch.relu(conv(x)).squeeze(3) for conv in self.convs]
        x = [torch.max_pool1d(i, i.size(2)).squeeze(2) for i in x]
        x = torch.cat(x, 1)
        x = self.dropout(x)
        logits = self.fc(x)
        return torch.sigmoid(logits)

5.模型训练

def train_epoch(model, dataloader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct_preds = 0  # 记录正确预测的数量
    total_preds = 0  # 记录总的预测数量
    for inputs, targets in dataloader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

        # 计算正确预测的数量
        predicted_labels = torch.argmax(outputs, dim=1)
        targets = torch.argmax(targets, dim=1)

        correct_preds += (predicted_labels == targets).sum().item()
        total_preds += len(targets)

    accuracy = correct_preds / total_preds  # 计算准确率
    return running_loss / len(dataloader), accuracy  # 返回平均损失和准确率

def evaluate(model, dataloader, device):
    model.eval()
    preds = []
    targets = []
    with torch.no_grad():
        for inputs, target in dataloader:
            inputs = inputs.to(device)
            outputs = model(inputs)
            preds.extend(outputs.cpu().numpy())
            targets.extend(target.numpy())
    return np.array(preds), np.array(targets)

def calculate_metrics(preds, targets, threshold=0.5):
    preds = (preds > threshold).astype(int)
    f1 = f1_score(targets, preds, average="micro")
    precision = precision_score(targets, preds, average="micro")
    recall = recall_score(targets, preds, average="micro")
    return {"f1": f1, "precision": precision, "recall": recall}

model = TextCNN(len(vocab), embedding_dim, num_filters, filter_sizes, num_classes).to(device)
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

for epoch in range(num_epochs):
    if epoch % 20==0:
        train_loss,accuracy = train_epoch(model, train_loader, criterion, optimizer, device)
        print(f"Epoch: {epoch + 1}, Train Loss: {train_loss:.4f}, Train Accuracy: {accuracy:.4f}")

        preds, targets = evaluate(model, test_loader, device)
        metrics = calculate_metrics(preds, targets)
        print(f"Epoch: {epoch + 1}, F1: {metrics['f1']:.4f}, Precision: {metrics['precision']:.4f}, Recall: {metrics['recall']:.4f}")
...
Epoch: 1821, Train Loss: 0.0055, Train Accuracy: 0.8837
Epoch: 1821, F1: 0.9429, Precision: 0.9429, Recall: 0.9429
Epoch: 1841, Train Loss: 0.0064, Train Accuracy: 0.9070
Epoch: 1841, F1: 0.9429, Precision: 0.9429, Recall: 0.9429
Epoch: 1861, Train Loss: 0.0047, Train Accuracy: 0.8837
Epoch: 1861, F1: 0.9429, Precision: 0.9429, Recall: 0.9429
Epoch: 1881, Train Loss: 0.0058, Train Accuracy: 0.8605
Epoch: 1881, F1: 0.9429, Precision: 0.9429, Recall: 0.9429
Epoch: 1901, Train Loss: 0.0064, Train Accuracy: 0.8488
Epoch: 1901, F1: 0.9429, Precision: 0.9429, Recall: 0.9429
Epoch: 1921, Train Loss: 0.0062, Train Accuracy: 0.8140
Epoch: 1921, F1: 0.9429, Precision: 0.9429, Recall: 0.9429
Epoch: 1941, Train Loss: 0.0059, Train Accuracy: 0.8953
Epoch: 1941, F1: 0.9429, Precision: 0.9429, Recall: 0.9429
Epoch: 1961, Train Loss: 0.0053, Train Accuracy: 0.8488
Epoch: 1961, F1: 0.9429, Precision: 0.9429, Recall: 0.9429
Epoch: 1981, Train Loss: 0.0055, Train Accuracy: 0.8488
Epoch: 1981, F1: 0.9429, Precision: 0.9429, Recall: 0.9429

大家可以利用自己的数据集进行训练,按照格式修改即可

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/444473.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C语言实现链表--数据结构

魔王的介绍&#xff1a;&#x1f636;‍&#x1f32b;️一名双非本科大一小白。魔王的目标&#xff1a;&#x1f92f;努力赶上周围卷王的脚步。魔王的主页&#xff1a;&#x1f525;&#x1f525;&#x1f525;大魔王.&#x1f525;&#x1f525;&#x1f525; ❤️‍&#x1…

gateway整合knife4j(微服务在线文档)

文章目录 knife4j 微服务整合一、微服务与单体项目文档整合的区别二、开始整合1. 搭建一个父子maven模块的微服务,并引入gateway2.开始整合文档 总结 knife4j 微服务整合 由于单个服务的knife4j 整合之前已经写过了,那么由于效果比较好,然后微服务的项目中也想引入,所以开始微…

【Linux】多线程的互斥与同步

目录 一、线程冲突 二、重入与线程安全 1、线程不安全的情况 2、线程安全的情况 3、不可重入的情况 4、可重入的情况 5、可重入和线程安全的联系 6、STL是否线程安全 7、智能指针是否线程安全 三、互斥锁 1、互斥锁的使用 2、基于RAII风格的互斥锁的封装 2.1Mutex…

ChatGPT-4回答电子电路相关问题,感觉它有思想,有灵魂,一起看看聊天记录

前几天发了一篇文章&#xff0c;讲了我们平常摸电脑或者其它电器设备的时候&#xff0c;会有酥酥麻麻的感觉&#xff0c;这个并不是静电&#xff0c;而是Y电容通过金属壳泄放高频扰动&#xff0c;我们手摸金属壳的时候&#xff0c;就给Y电容提供了一个泄放回路&#xff0c;所以…

全网抓包天花板教程,CSDN讲的最详细的Fiddler抓包教程。2小时包你学会

目录 前言 一、安装 Fiddler 二、打开 Fiddler 并设置代理 三、抓取 HTTP/HTTPS 流量 四、流量分析和调试 五、应用场景 六、注意事项 七、实际案例 八、拓展阅读 九、结论 前言 Fiddler 是一款功能强大的网络调试工具&#xff0c;可以用于捕获和分析 HTTP 和 HTTPS …

生物信息学有哪些SCI期刊推荐? - 易智编译EaseEditing

以下是几个生物信息学领域的SCI期刊推荐&#xff1a; Bioinformatics&#xff1a; 该期刊是生物信息学领域最具影响力的SCI期刊之一&#xff0c;涵盖了生物信息学、计算生物学、系统生物学、生物医学工程等多个研究方向。 BMC Bioinformatics&#xff1a; 该期刊是生物信息学…

数据结构入门(C语言版)二叉树链式结构的实现

二叉树链式结构的实现 二叉树的概念及结构创建1、概念2、结构创建2、创建结点函数3、建树函数 二叉树的遍历1、前序遍历2、中序遍历3、后序遍历4、层序遍历 二叉树的销毁结语 二叉树的概念及结构创建 1、概念 简单回顾一下二叉树的概念&#xff1a; ★ 空树 ★非空&#xff1…

intellij 从2020升级到2023 踩坑实录

1.下载新版本intellij 工作机器上的intellij版本为2020社区版&#xff0c;版本比较老旧&#xff0c;需要进行升级。IDE这种提高生产力的工具&#xff0c;还是蛮重要的&#xff0c;也是值得稍微多花点时间研究一下的。升级之前就预计到了不会是那么简单&#xff0c;后面事实也证…

大型体检管理系统源码,Vs2012,C/S架构

体检管理系统源码&#xff0c;PEIS源码 一套专业的体检管理系统源码&#xff0c;核心功能有体检档案的录入、体检报告的输出、体检档案的统计查询和对比分析。该系统的使用&#xff0c;可以大大提高体检档案管理人员的工作效率&#xff0c;使体检档案的管理更加准确、全面、完…

以人为本的重点是有效网络安全计划的关键

安全和风险管理 (SRM) 领导者在根据九大行业趋势创建和实施网络安全计划时&#xff0c;必须重新考虑他们在技术和以人为本的元素之间的投资平衡。 以人为本的网络安全方法对于减少安全故障至关重要。 在控制设计和实施以及通过业务沟通和网络安全人才管理中关注人&#xff…

Python中的异常——概述和基本语法

Python中的异常——概述和基本语法 摘要&#xff1a;Python中的异常是指在程序运行时发生的错误情况&#xff0c;包括但不限于除数为0、访问未定义变量、数据类型错误等。异常处理机制是Python提供的一种解决这些错误的方法&#xff0c;我们可以使用try/except语句来捕获异常并…

基于linux:MySql-5.7二进制安装部署

基于linux&#xff1a;MySql-5.7二进制安装 1&#xff09;检查当前系统是否安装过Mysql [ ~]$ rpm -qa|grep mariadb mariadb-libs-5.5.56-2.el7.x86_64 //如果存在通过如下命令卸载 [ ~]$ sudo rpm -e --nodeps mariadb-libs //用此命令卸载mariadb2&#xff09;解压MySQ…

限流算法浅析

前言 在前文接口请求安全措施中&#xff0c;简单提到过接口限流&#xff0c;那里是通过Guava工具类的RateLimiter实现的&#xff0c;它实际上是令牌桶限流的具体实现&#xff0c;那么下面分别介绍几种限流算法&#xff0c;做一个更详细的了解。 固定窗口限流 1、核心思想 在…

基于 Flink CDC 的现代数据栈实践

摘要&#xff1a;本文整理自阿里云技术专家&#xff0c;Apache Flink PMC Member & Committer, Flink CDC Maintainer 徐榜江和阿里云高级研发工程师&#xff0c;Apache Flink Contributor & Flink CDC Maintainer 阮航&#xff0c;在 Flink Forward Asia 2022 数据集成…

初识C语言————4

文章目录 常见关键字 1、 关键字 typedef 2、关键字static define 定义常量和宏 指针 结构体 前言 这是博主初识C语言系列的最后一篇&#xff0c;之后博主会更新更详细的关于C语言学习的知识。希望各位老铁多多支持。 一、常见关键字 1、 关键字 typedef typedef 顾名思义是…

海康威视发布2022年ESG报告:科技为善, 助力可持续的美好未来

近日&#xff0c;海康威视正式发布《2022环境、社会及管治报告》&#xff08;以下简称“海康威视ESG报告”)&#xff0c;连续5年呈现在环境、社会发展、企业治理等领域的思考和创新成果。此外&#xff0c;报告中首次披露了碳中和业务蓝图&#xff0c;积极布局绿色生产、绿色运营…

HTTP特性

1 HTTP/1.1 的优点有哪些&#xff1f; 2 HTTP/1.1 的缺点有哪些&#xff1f; 3 HTTP/1.1 的性能如何&#xff1f; HTTP 协议是基于 TCP/IP&#xff0c;并且使用了「请求 - 应答」的通信模式&#xff0c;所以性能的关键就在这两点里。 3.1 长连接 早期 HTTP/1.0 性能上的一…

分布式Id生成之雪花算法(SnowFlake)

目录 前言 回顾二进制 二进制概念 运算法则 位&#xff08;Bit&#xff09; 字节&#xff08;Byte&#xff09; 字符 字符集 二进制原码、反码、补码 有符号数和无符号数 疑问&#xff1a;为什么不是-127 &#xff5e; 127 &#xff1f; 为什么需要分布式全局唯一ID…

sql中 join 的简单用法总结(带例子)

join 常见的用法有&#xff1a; 目录 left join&#xff08;left outer join&#xff09;right join&#xff08;right outer join&#xff09;join&#xff08;inner join&#xff09;full join&#xff08;full outer join 、outer join&#xff09;cross join 说明&#xf…

docker自定义镜像

文章目录 一、自定义镜像1.1 镜像结构1.2 Dockerfile1.3 dockerCompose1.3.1 dockerCompose的作用1.3.2 dockerCompose的常用命令 1.4 镜像仓库 一、自定义镜像 1.1 镜像结构 自定义镜像通常包含三个基本部分&#xff1a;基础镜像、应用程序代码和配置文件。 基础镜像&#…