基于BERT的文本分类项目的实现
一、项目背景
本项目是一个情感分析任务,属于二分类问题。以下给出大致流程及部分代码示例:
二、数据集介绍
2.1 数据集基本信息
数据集:自定义
类型:二分类(正面/负面)
样本量:训练集 + 验证集 + 测试集
文本长度:平均 x 字(最大 x 字)
领域:商品评论、影视评论
import pandas as pd


class _Split:
    """Read-only view over one tab-separated split of the corpus.

    Supports the access patterns the rest of this file relies on:
    ``split[i]`` returns row ``i`` as a dict, ``split['column']`` returns a
    whole column as a list, and ``len(split)`` is the number of rows.
    """

    def __init__(self, frame):
        # Rows are addressed by position, so drop any non-default index.
        self._frame = frame.reset_index(drop=True)

    def __len__(self):
        return len(self._frame)

    def __getitem__(self, key):
        if isinstance(key, str):
            return self._frame[key].tolist()
        return self._frame.iloc[key].to_dict()


# NOTE(review): only 'data/train.txt' appears in the original snippet; the
# validation/test paths are assumed by analogy - confirm the real file names.
SPLIT_FILES = {
    'train': 'data/train.txt',
    'validation': 'data/validation.txt',
    'test': 'data/test.txt',
}

# read_csv returns a single DataFrame, which cannot be indexed by split name;
# keep one view per split so that dataset['train'][0] is the first sample.
dataset = {
    name: _Split(pd.read_csv(path, sep='\t'))
    for name, path in SPLIT_FILES.items()
}
print(dataset['train'][0])
2.2 数据分析
2.2.1 句子长度分布
import matplotlib. pyplot as plt
def analyze_length(texts, max_len=256, bins=30):
    """Plot a histogram of the character counts of ``texts``.

    Args:
        texts: Iterable of strings; ``len()`` of each one is plotted.
        max_len: Upper edge of the histogram range. Values above it lie
            outside the range passed to ``plt.hist`` and are not drawn.
        bins: Number of histogram bins.
    """
    lengths = [len(t) for t in texts]
    plt.figure(figsize=(12, 5))
    plt.hist(lengths, bins=bins, range=(0, max_len), color='blue', alpha=0.7)
    plt.title("文本长度分布", fontsize=14)
    plt.xlabel("字符数")
    plt.ylabel("样本量")
    plt.show()


analyze_length(dataset['train']['text'])
2.2.2 标签分布
import pandas as pd
# Count the training samples per label and show each label's share as a pie chart.
pd.Series(dataset['train']['label']).value_counts().plot.pie(
    autopct='%1.1f%%',
    title='类别分布(0-负面 1-正面)',
)
plt.show()
2.2.3 类别平衡处理
from torch. utils. data import WeightedRandomSampler
import torch

# Weight every training sample by the inverse size of its class, so that each
# class carries the same total sampling weight.
labels = dataset['train']['label']
positives = sum(labels)
negatives = len(labels) - positives
# max(..., 1) avoids 1 / 0 = inf when a class has no samples; the weight of an
# absent class is never looked up below, so the result is otherwise unchanged.
class_weights = 1 / torch.Tensor([max(negatives, 1), max(positives, 1)])
sampler = WeightedRandomSampler(
    # Plain floats rather than 0-d tensors, one weight per training sample.
    weights=[class_weights[label].item() for label in labels],
    num_samples=len(labels),
    replacement=True,
)
三、数据处理
3.1 BERT分词器
from transformers import BertTokenizer
# Tokenizer shared by every batch; loaded from the 'bert-base-chinese' checkpoint.
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')


def collate_fn(batch):
    """Convert a list of samples into one padded batch of tensors.

    Args:
        batch: Sequence of mappings, each providing a 'text' and a 'label'
            entry.

    Returns:
        Dict with the tokenizer's 'input_ids' and 'attention_mask' plus a
        'labels' LongTensor built from the samples' labels.
    """
    texts, labels = [], []
    for sample in batch:
        texts.append(sample['text'])
        labels.append(sample['label'])

    # Pad to the longest text in the batch and cut anything beyond 256 tokens.
    encoded = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=256,
        return_tensors='pt',
    )
    return dict(
        input_ids=encoded['input_ids'],
        attention_mask=encoded['attention_mask'],
        labels=torch.LongTensor(labels),
    )
3.2 数据加载器
from torch. utils. data import DataLoader
# Training batches are drawn through the weighted sampler.
train_loader = DataLoader(
    dataset['train'],
    batch_size=32,
    collate_fn=collate_fn,
    sampler=sampler,
)

# Evaluation loaders use no sampler.
val_loader = DataLoader(
    dataset['validation'],
    batch_size=32,
    collate_fn=collate_fn,
)

# test_loader is consumed by the confusion-matrix section further down but was
# never created.
# NOTE(review): assumes the dataset exposes a 'test' split - confirm.
test_loader = DataLoader(
    dataset['test'],
    batch_size=32,
    collate_fn=collate_fn,
)
四、模型构建
4.1 BERT分类模型
import torch. nn as nn
from transformers import BertModel
class BertClassifier(nn.Module):
    """BERT encoder followed by dropout and a linear classification head.

    Args:
        pretrained_name: Checkpoint name passed to ``BertModel.from_pretrained``.
        num_classes: Number of output scores produced per sample.
        dropout: Dropout probability applied to the pooled encoder output.
    """

    def __init__(self, pretrained_name='bert-base-chinese', num_classes=2,
                 dropout=0.1):
        super().__init__()
        self.bert = BertModel.from_pretrained(pretrained_name)
        self.dropout = nn.Dropout(dropout)
        # Size the head from the loaded encoder instead of hard-coding 768 so
        # that checkpoints with another hidden width also work.
        self.fc = nn.Linear(self.bert.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        """Return the class scores computed from the encoder's pooled output."""
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled = self.dropout(outputs.pooler_output)
        return self.fc(pooled)
4.2 模型配置
import torch
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BertClassifier().to(device)
# The learning rate must be written as one literal: "2e - 5" does not parse.
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
criterion = nn.CrossEntropyLoss()
五、模型训练与验证
5.1 训练流程
from tqdm import tqdm
def train_epoch(model, loader):
    """Run one optimisation pass over ``loader``.

    Relies on the module-level ``optimizer``, ``criterion`` and ``device``.

    Args:
        model: Module called as ``model(input_ids, attention_mask)``.
        loader: Iterable of batches with 'input_ids', 'attention_mask' and
            'labels' entries.

    Returns:
        The summed per-batch loss divided by ``len(loader)``.
    """
    model.train()
    running_loss = 0
    for batch in tqdm(loader):
        token_ids = batch['input_ids'].to(device)
        mask = batch['attention_mask'].to(device)
        targets = batch['labels'].to(device)

        # Clear the gradients left over from the previous step.
        optimizer.zero_grad()
        logits = model(token_ids, mask)
        batch_loss = criterion(logits, targets)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()
    return running_loss / len(loader)
5.2 验证流程
def evaluate(model, loader):
    """Compute the accuracy of ``model`` over every batch in ``loader``.

    Relies on the module-level ``device``.

    Args:
        model: Module called as ``model(input_ids, attention_mask)``.
        loader: Iterable of batches with 'input_ids', 'attention_mask' and
            'labels' entries.

    Returns:
        Number of correct predictions divided by the number of labels seen.
    """
    model.eval()
    hits, seen = 0, 0
    with torch.no_grad():
        for batch in loader:
            token_ids = batch['input_ids'].to(device)
            mask = batch['attention_mask'].to(device)
            targets = batch['labels'].to(device)

            logits = model(token_ids, mask)
            # The predicted class is the index of the largest score per row.
            predicted = logits.argmax(dim=1)
            hits += predicted.eq(targets).sum().item()
            seen += len(targets)
    return hits / seen
六、实验结果
6.1 评估指标
from sklearn. metrics import confusion_matrix
import seaborn as sns
def plot_confusion_matrix(loader):
    """Draw the confusion matrix of the module-level ``model`` on ``loader``.

    Relies on the module-level ``model`` and ``device``.

    Args:
        loader: Iterable of batches with 'input_ids', 'attention_mask' and
            'labels' entries.
    """
    truths, guesses = [], []
    model.eval()
    with torch.no_grad():
        for batch in loader:
            token_ids = batch['input_ids'].to(device)
            mask = batch['attention_mask'].to(device)
            targets = batch['labels'].to(device)

            logits = model(token_ids, mask)
            predicted = logits.argmax(dim=1)
            # Move results back to the CPU before collecting them.
            truths.extend(targets.cpu().numpy())
            guesses.extend(predicted.cpu().numpy())

    matrix = confusion_matrix(truths, guesses)
    sns.heatmap(matrix, annot=True, fmt='d', cmap='Blues')
    plt.title('混淆矩阵')
    plt.xlabel('预测标签')
    plt.ylabel('真实标签')
    plt.show()


plot_confusion_matrix(test_loader)
6.2 学习曲线
from torch. utils. tensorboard import SummaryWriter
# Train for three epochs and log the learning curves to TensorBoard. The
# context manager closes the writer when the loop ends, including on error.
with SummaryWriter() as writer:
    for epoch in range(3):
        train_loss = train_epoch(model, train_loader)
        val_acc = evaluate(model, val_loader)
        writer.add_scalar('Loss/Train', train_loss, epoch)
        writer.add_scalar('Accuracy/Validation', val_acc, epoch)
七、流程架构图
原始文本
→ 分词编码
→ BERT特征提取
→ 全连接分类
→ 损失计算
→ 反向传播
→ 模型评估