GT_BERT文本分类

news2024/11/23 15:03:51

目录

  • GT-BERT
  • 结束语
  • 代码实现
  • 整个项目源码(数据集模型)

GT-BERT

在为了使 BERT 模型能够得到广泛的应用,在保证模型分类准确率不降低的情况下,减少模型参数规模并降低时间复杂度,提出一种基于半监督生成对抗网络与 BERT 的文本分类模型 GT-BERT。模型的整体框架如图3所示。
在这里插入图片描述

首先,对BERT进行压缩,通过实验验证选择使用BERT-of-theseus方法进行压缩得到BERT-theseus模型。损失函数设定为文本分类常用的交叉熵损失:
在这里插入图片描述

其中,为训练集的第j个样本,是的标签,C和c表示标签集合和一个类标签。接着,在压缩之后,从SS-GANs角度扩展BERT-theseus模型进行微调。在预训练过的BERT-theseus模型中添加两个组件:(1)添加特定任务层;(2)添加SS-GANs层来实现半监督学习。本研究假定K类句子分类任务,给定输入句子s=(, ,…,),其中开头的为分类特殊标记“[CLS]”,结尾的为句子分隔特殊标记“[SEP]”,其余部分对输入句子进行切分后标记序列输入BERT模型后得到编码向量序列为=(,…,)。
将生成器G生成的假样本向量与真实无标注数据输入BERT-theseus中所提取的特征向量,分别输入至判别器D中,利用对抗训练来不断强化判别器D。与此同时,利用少量标注数据对判别器D进行分类训练,从而进一步提高模型整体质量。
其中,生成器G输出服从正态分布的“噪声”,采用CNN网络,将输出空间映射到样本空间,记作∈。 判别器D也为CNN网络,它在输入中接收向量∈,其中可以为真实标注或者未标注样本 ,也可以为生成器生成的假样本数据。在前向传播阶段,当样本为真实样本时,即=,判别器D会将样本分类在K类之中。当样本为假样本时,即=,判别器D会把样本相对应的分类于K+1类别中。在此阶段生成器G和判别器D的损失分别被记作和,训练过程中G和D通过相互博弈而优化损失。
在反向传播中,未标注样本只增加。标注的真实样本只会影响,在最后和都会受到G的影响,即当D找不出生成样本时,将会受到惩罚,反亦然。在更新D时,改变BERT-theseus的权重来进行微调。训练完成后,生成器G会被舍弃,同时保留完整的BERT-theseus模型与判别器D进行分类任务的预测。

结束语

该文提出了一种用于文本分类任务的GT-BERT模型。首先,使用 theseus方法对BERT进行压缩,在不降低分类性能的前提下,有效降低了BERT 的参数规模和时间复杂度。然后,引人SS-GAN框架改进模型的训练方式,使 BERT-theseus模型能有效利用无标注数据,并实验了多组生成器与判别器的组合方式,获取了最优的生成器判别器组合配置,进一步提升了模型的分类性能。

代码实现

import torch
from transformers import BertTokenizer, BertModel
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import torch.nn as nn
import torch.optim as optim
import os
from glob import glob

torch.autograd.set_detect_anomaly(True)


# 定义数据集类
class TextDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_len):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        return {
            'text': text,
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'label': torch.tensor(label, dtype=torch.long)
        }


# 加载数据集函数
def load_data(dataset_name):
    if dataset_name == '20ng':
        dirs = glob("E:/python_project/GT_BERT/dateset/20_newsgroups/20_newsgroups/*")

        texts = []
        labels = []

        for i, d in enumerate(dirs):
            for j in glob(d + "/*")[:10]:
                try:
                    with open(j, "r", encoding="utf-8") as f:
                        one = f.read()
                except:
                    continue
                texts.append(one)
                labels.append(i)


    elif dataset_name == 'sst5':
        data_dir = 'path/to/sst/data'

        def load_sst_data(data_dir, split):
            sentences = []
            labels = []
            with open(os.path.join(data_dir, f'{split}.txt')) as f:
                for line in f:
                    label, sentence = line.strip().split(' ', 1)
                    sentences.append(sentence)
                    labels.append(int(label))
            return sentences, labels

        texts, labels = load_sst_data(data_dir, 'train')
    elif dataset_name == 'mr':
        file_path = 'path/to/mr/data'

        def load_mr_data(file_path):
            sentences = []
            labels = []
            with open(file_path) as f:
                for line in f:
                    label, sentence = line.strip().split(' ', 1)
                    sentences.append(sentence)
                    labels.append(int(label))
            return sentences, labels

        texts, labels = load_mr_data(file_path)
    elif dataset_name == 'trec':
        file_path = 'path/to/trec/data'

        def load_trec_data(file_path):
            sentences = []
            labels = []
            with open(file_path) as f:
                for line in f:
                    label, sentence = line.strip().split(' ', 1)
                    sentences.append(sentence)
                    labels.append(label)
            return sentences, labels

        texts, labels = load_trec_data(file_path)
    else:
        raise ValueError("Unsupported dataset")
    return texts, labels


# 默认加载 20 News Group 数据集
dataset_name = '20ng'
texts, labels = load_data(dataset_name)

label_encoder = LabelEncoder()
labels = label_encoder.fit_transform(labels)

# 使用BERT的tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
max_len = 128

# 将数据集划分为训练集和验证集
train_texts, val_texts, train_labels, val_labels = train_test_split(texts, labels, test_size=0.2)
train_dataset = TextDataset(train_texts, train_labels, tokenizer, max_len)
val_dataset = TextDataset(val_texts, val_labels, tokenizer, max_len)
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=16, shuffle=False)


# 定义BERT编码器
class BERTTextEncoder(nn.Module):
    def __init__(self):
        super(BERTTextEncoder, self).__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs[1]
        return pooled_output


# 定义生成器
class Generator(nn.Module):
    def __init__(self, noise_dim, output_dim):
        super(Generator, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(noise_dim, 128),
            nn.ReLU(),
            nn.Linear(128, output_dim),
            nn.Tanh()
        )

    def forward(self, noise):
        return self.fc(noise)


# 定义判别器
class Discriminator(nn.Module):
    def __init__(self, input_dim):
        super(Discriminator, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

    def forward(self, features):
        return self.fc(features)


# 定义完整的GT-BERT模型
class GTBERTModel(nn.Module):
    def __init__(self, bert_encoder, noise_dim, output_dim, num_classes):
        super(GTBERTModel, self).__init__()
        self.bert_encoder = bert_encoder
        self.generator = Generator(noise_dim, output_dim)
        self.discriminator = Discriminator(output_dim)
        self.classifier = nn.Linear(output_dim, num_classes)

    def forward(self, input_ids, attention_mask, noise):
        real_features = self.bert_encoder(input_ids, attention_mask)
        fake_features = self.generator(noise)
        disc_real = self.discriminator(real_features)
        disc_fake = self.discriminator(fake_features)
        class_output = self.classifier(real_features)
        return class_output, disc_real, disc_fake


# 初始化模型和超参数
noise_dim = 100
output_dim = 768
num_classes = len(set(labels))
bert_encoder = BERTTextEncoder()
model = GTBERTModel(bert_encoder, noise_dim, output_dim, num_classes)

# 定义损失函数和优化器
criterion_class = nn.CrossEntropyLoss()
criterion_disc = nn.BCELoss()
optimizer_G = optim.Adam(model.generator.parameters(), lr=0.0002)
optimizer_D = optim.Adam(model.discriminator.parameters(), lr=0.0002)
optimizer_BERT = optim.Adam(model.bert_encoder.parameters(), lr=2e-5)
optimizer_classifier = optim.Adam(model.classifier.parameters(), lr=2e-5)

num_epochs = 10

# 训练循环
e_id = 1
for epoch in range(num_epochs):
    model.train()
    for batch in train_dataloader:
        e_id += 1
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['label']

        # 生成噪声
        noise = torch.randn(input_ids.size(0), noise_dim)

        # 获取模型输出
        class_output, disc_real, disc_fake = model(input_ids, attention_mask, noise)

        # 计算损失
        real_labels = torch.ones(input_ids.size(0), 1)
        fake_labels = torch.zeros(input_ids.size(0), 1)

        loss_real = criterion_disc(disc_real, real_labels)
        loss_fake = criterion_disc(disc_fake, fake_labels)
        loss_class = criterion_class(class_output, labels)

        if e_id % 5 == 0:

            # 优化判别器
            optimizer_D.zero_grad()
            loss_D = (loss_real + loss_fake) / 2
            loss_D.backward(retain_graph=True)
            optimizer_D.step()

        elif e_id % 2 == 0:
            # 优化生成器
            loss_G = criterion_disc(disc_fake, real_labels)
            optimizer_G.zero_grad()
            loss_G.backward(retain_graph=True)
            optimizer_G.step()
        else:
            # 优化BERT和分类器

            optimizer_BERT.zero_grad()
            optimizer_classifier.zero_grad()
            loss_class.backward()
            optimizer_BERT.step()
            optimizer_classifier.step()

    print(
        f'Epoch [{epoch + 1}/{num_epochs}], Loss D: {loss_D.item()}, Loss G: {loss_G.item()}, Loss Class: {loss_class.item()}')

# 验证模型
model.eval()
val_loss = 0
correct = 0
with torch.no_grad():
    for batch in val_dataloader:
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['label']

        noise = torch.randn(input_ids.size(0), noise_dim)
        class_output, disc_real, disc_fake = model(input_ids, attention_mask, noise)

        loss = criterion_class(class_output, labels)
        val_loss += loss.item()
        pred = class_output.argmax(dim=1, keepdim=True)
        correct += pred.eq(labels.view_as(pred)).sum().item()

val_loss /= len(val_dataloader.dataset)
accuracy = correct / len(val_dataloader.dataset)
print(f'Validation Loss: {val_loss}, Accuracy: {accuracy}')

整个项目源码(数据集模型)

项目

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1841121.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

DNS污染是什么?防止和清洗DNS污染的解决方案

在运营互联网业务中,通常会遇到各种各样的问题。其实DNS污染就是其中一个很严重的问题,它甚至会导致我们的业务中断,无法进行。今天就来了解一下DNS污染是什么?以及如何防止和清洗DNS污染。 什么是DNS? 首先我们要了解…

企业微信,机器人定时提醒

场景: 每天定时发送文字,提醒群成员事情,可以用机器人代替 人工提醒。 1)在企业微信,创建机器人 2)在腾讯轻联,创建流程,选择定时任务,执行操作(企业微信机…

Qt利用Coin3D(OpenInventor)进行3d绘图

文章目录 1.安装1.1.下载coin3d1.2.下载quarter1.3.解压并合并 2.在Qt中使用3.画个网格4.加载wrl模型 1.安装 1.1.下载coin3d 首先,到官网下载[coin3d/coin] 我是Qt5.15.2vs2019的,因此我选择这个coin-4.0.2-msvc17-x64.zip 1.2.下载quarter 到官网…

milvus元数据解析工具milvusmetagui介绍使用

简介 milvusmetagui是一款用来对milvus的元数据进行解析的工具,milvus的元数据存储在etcd上,而且经过了序列化,通过etcd-manager这样的工具来查看是一堆二进制乱码,因此开发了这个工具对value进行反序列化解析。 在这里为了方便交…

arm-linux-strip 指令的作用

指令作用 arm-linux-strip 是一个用于从目标文件(如可执行文件或对象文件)中移除符号信息的工具。这些符号信息(如函数名、变量名等)在开发过程中很有用,因为它们允许调试器(如 GDB)确定内存地址…

安装cuda、cudnn、Pytorch(用cuda和cudnn加速计算)

写在前面 最近几个月都在忙着毕业的事,好一阵子没写代码了。今天准备跑个demo,发现报错 AssertionError: Torch not compiled with CUDA enabled 不知道啥情况,因为之前有cuda环境,能用gpu加速,看这个报错信息应该是P…

Elasticsearch搜索引擎(初级篇)

1.1 初识ElasticSearch | 《ElasticSearch入门到实战》电子书 (chaosopen.cn) 目录 第一章 入门 1.1 ElasticSearch需求背景 1.2 ElasticSearch 和关系型数据库的对比 1.3 基础概念 文档和字段 索引和映射 第二章 索引操作 2.0 Mapping映射属性 2.1 创建索引 DS…

Java宝藏实验资源库(1)文件

一、实验目的 掌握文件、目录管理以及文件操作的基本方法。掌握输入输出流的基本概念和流处理类的基本结构。掌握使用文件流进行文件输入输出的基本方法。 二、实验内容、过程及结果 1.显示指定目录下的每一级文件夹中的.java文件 运行代码如下 : import java.io.…

智慧校园综合管理系统:打造高效智慧的学校管理平台

智慧校园综合管理系统,作为提升教育管理与教学效率的数字化解决方案,它将信息技术深度融合于校园的每一个角落,构建了一个集信息共享、教学资源优化、智能管理、安全保障于一体的综合平台。该系统不仅提供了统一的信息门户,确保学…

MYSQL 四、mysql进阶 3(存储引擎)

mysql中表使用了不同的存储引擎也就决定了我们底层文件系统中文件的相关物理结构。 为了管理方便,人们把连接管理、语法解析、查询优化这些并不涉及真实数据存储的功能划分为 Mysql Server的功能,把真实存取数据的功能划分为存储引擎的功能&…

CVE-2023-50563(sql延时注入)

简介 SEMCMS是一套支持多种语言的外贸网站内容管理系统(CMS)。SEMCMS v4.8版本存在SQLI,该漏洞源于SEMCMS_Function.php 中的 AID 参数包含 SQL 注入 过程 打开靶场 目录扫描,发现安装install目录,进入,…

ruoyi登录功能源码分析

Ruoyi登录功能源码分析 上一篇文章我们分析了一下若依登录验证码生成的代码,今天我们来分析一下登录功能的代码 1、发送登录请求 前端通过http://localhost/dev-api/login向后端发送登录请求并携带用户的登录表单 在后端中的com.ruoyi.web.controller.system包下…

Artalk-CORS,跨域拦截问题

今天重新部署Artalk之后,遇到了CORS——跨域拦截的问题,卡了好一会记录一下。 起因 重新部署之后,浏览器一直提示CORS,之前在其他项目也遇到过类似的问题,原因就在于跨域问题。

[Linux] 系统的基本架构特点

Linux系统的基本结构 Linux is also a subversion of UNIX,it follows the basic structure of UNIX 内核(kernel): 操作系统的基本部分 管理与硬件相关的功能,分模块进行 常驻模块:进程控制IO操作文件\磁盘访问 用户不能直接访问内核 外壳(s…

C语言 | Leetcode C语言题解之第167题两数之和II-输入有序数组

题目&#xff1a; 题解&#xff1a; int* twoSum(int* numbers, int numbersSize, int target, int* returnSize) {int* ret (int*)malloc(sizeof(int) * 2);*returnSize 2;int low 0, high numbersSize - 1;while (low < high) {int sum numbers[low] numbers[high]…

Centos 配置安装Mysql

linux安装配置mysql的方法主要有yum安装和配置安装两种&#xff0c;由于yum安装比较简单&#xff0c;但是会将文件分散到不同的目录结构下面&#xff0c;配置起来比较麻烦&#xff0c;这里主要研究一下配置安装mysql的方法 1、环境说明 centos 7.9 mysql 5.7.372、环境检查 …

【Kubernetes】概念学习

Kubernetes介绍 Kubernetes 是谷歌开源的容器集群管理系统 是用于自动部署&#xff0c;扩展和管理 Docker 应用程序的开源系统&#xff0c;简称 K8S。 Kubernetes是一个可以移植、可扩展的开源平台&#xff0c;使用 声明式的配置 并依据配置信息自动地执行容器化应用程序的管…

27 map和set封装

map和set可以采用两套红黑树实现&#xff0c;也可以用同一个红黑树&#xff0c;就需要对前面的结构进行修改 迭代器的好处是可以方便遍历&#xff0c;是数据结构的底层实现与用户透明。如果想要给红黑树增加迭代器&#xff0c;需要考虑以前问题&#xff1a; begin()和end() s…

ChatGPT Plus GPT-4o Claude 3 Opus合租拼车全新方式

无需自己搭建&#xff0c;登录即可用&#xff0c;国内直连访问&#xff0c;聚合多家最强大模型&#xff0c;随意选择使用。立即体验 datapipe.top 支持 OpenAI 最新 GPT-4o &#xff0c;获得快速高质量的对话&#xff0c;保证可用配额。支持多种大模型&#xff0c;GPT-4o &…

课程设计---哈夫曼树的编码与解码(Java详解)

目录 一.设计任务&&要求&#xff1a; 二.方案设计报告&#xff1a; 2.1 哈夫曼树编码&译码的设计原理&#xff1a; 2.3设计目的&#xff1a; 2.3设计的主要过程&#xff1a; 2.4程序方法清单&#xff1a; 三.整体实现源码&#xff1a; 四.运行结果展示&…