pytorch实战---IMDB情感分析

news2024/12/24 7:03:16

💗💗💗欢迎来到我的博客,你将找到有关如何使用技术解决问题的文章,也会找到某个技术的学习路线。无论你是何种职业,我都希望我的博客对你有所帮助。最后不要忘记订阅我的博客以获取最新文章,也欢迎在文章下方留下你的评论和反馈。我期待着与你分享知识、互相学习和建立一个积极的社区。谢谢你的光临,让我们一起踏上这个知识之旅!
请添加图片描述

文章目录

  • 🥦引言
  • 🥦完整代码
  • 🥦代码分析
    • 🥦导库
    • 🥦设置日志
    • 🥦模型定义
      • 🥦GCNN
      • 🥦TextClassificationModel
    • 🥦准备IMDb数据集
    • 🥦整理函数
    • 🥦训练函数
    • 🥦模型初始化和优化器
    • 🥦加载用于训练和评估的数据
    • 🥦恢复训练
    • 🥦调用训练
  • 🥦保存文件的读取
  • 🥦扩展 LSTM、GRU
  • 🥦总结

🥦引言

本文使用IMDB数据集,结合pytorch进行情感分析

🥦完整代码

import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score

from torch import utils

import torchtext
from tqdm import tqdm
from torchtext.datasets import IMDB

from torchtext.datasets.imdb import NUM_LINES
from torchtext.data import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.data.functional import to_map_style_dataset

import os
import sys
import logging
import logging

logging.basicConfig(
    level=logging.WARN, stream=sys.stdout, format = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")

VOCAB_SIZE = 15000


# step1 编写GCNN模型代码,门(Gate)卷积网络
class GCNN(nn.Module):
    def __init__(self, vocab_size=VOCAB_SIZE, embedding_dim=64, num_class=2):
        super(GCNN, self).__init__()

        self.embedding_table = nn.Embedding(vocab_size, embedding_dim)
        nn.init.xavier_uniform_(self.embedding_table.weight)

        # 都是1维卷积
        self.conv_A_1 = nn.Conv1d(embedding_dim, 64, 15, stride=7)
        self.conv_B_1 = nn.Conv1d(embedding_dim, 64, 15, stride=7)

        self.conv_A_2 = nn.Conv1d(64, 64, 15, stride=7)
        self.conv_B_2 = nn.Conv1d(64, 64, 15, stride=7)

        self.output_linear1 = nn.Linear(64, 128)
        self.output_linear2 = nn.Linear(128, num_class)

    def forward(self, word_index):
        """
        定义GCN网络的算子操作流程,基于句子单词ID输入得到分类logits输出
        """
        # 1. 通过word_index得到word_embedding
        # word_index shape: [bs, max_seq_len]
        word_embedding = self.embedding_table(word_index)  # [bs, max_seq_len, embedding_dim]

        # 2. 编写第一层1D门卷积模块,通道数在第2维
        word_embedding = word_embedding.transpose(1, 2)  # [bs, embedding_dim, max_seq_len]
        A = self.conv_A_1(word_embedding)
        B = self.conv_B_1(word_embedding)
        H = A * torch.sigmoid(B)  # [bs, 64, max_seq_len]

        A = self.conv_A_2(H)
        B = self.conv_B_2(H)
        H = A * torch.sigmoid(B)  # [bs, 64, max_seq_len]

        # 3. 池化并经过全连接层
        pool_output = torch.mean(H, dim=-1)  # 平均池化,得到[bs, 4096]
        linear1_output = self.output_linear1(pool_output)

        # 最后一层需要设置为隐含层数目
        logits = self.output_linear2(linear1_output)  # [bs, 2]

        return logits


# PyTorch官网的简单模型
class TextClassificationModel(nn.Module):
    """
    简单版embedding.DNN模型
    """

    def __init__(self, vocab_size=VOCAB_SIZE, embed_dim=64, num_class=2):
        super(TextClassificationModel, self).__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=False)
        self.fc = nn.Linear(embed_dim, num_class)

    def forward(self, token_index):
        # 词袋
        embedded = self.embedding(token_index)  # shape: [bs, embedding_dim]
        return self.fc(embedded)


# step2 构建IMDB Dataloader
BATCH_SIZE = 64


def yeild_tokens(train_data_iter, tokenizer):
    for i, sample in enumerate(train_data_iter):
        label, comment = sample
        yield tokenizer(comment)  # 字符串转换为token索引的列表


train_data_iter = IMDB(root="./data", split="train")  # Dataset类型的对象
tokenizer = get_tokenizer("basic_english")
# 只使用出现次数大约20的token
vocab = build_vocab_from_iterator(yeild_tokens(train_data_iter, tokenizer), min_freq=20, specials=["<unk>"])
vocab.set_default_index(0)  # 特殊索引设置为0
print(f'单词表大小: len(vocab)')


# 校对函数, batch是dataset返回值,主要是处理batch一组数据
def collate_fn(batch):
    """
    对DataLoader所生成的mini-batch进行后处理
    """
    target = []
    token_index = []
    max_length = 0  # 最大的token长度
    for i, (label, comment) in enumerate(batch):
        tokens = tokenizer(comment)
        token_index.append(vocab(tokens))  # 字符列表转换为索引列表

        # 确定最大的句子长度
        if len(tokens) > max_length:
            max_length = len(tokens)

        if label == "pos":
            target.append(0)
        else:
            target.append(1)

    token_index = [index + [0] * (max_length - len(index)) for index in token_index]
    # one-hot接收长整形的数据,所以要转换为int64
    return (torch.tensor(target).to(torch.int64), torch.tensor(token_index).to(torch.int32))


# step3 编写训练代码
def train(train_data_loader, eval_data_loader, model, optimizer, num_epoch, log_step_interval, save_step_interval,  eval_step_interval, save_path, resume=""):
    """
    此处data_loader是map-style dataset
    """
    start_epoch = 0
    start_step = 0
    if resume != "":
        # 加载之前训练过的模型的参数文件
        logging.warning(f"loading from resume")
        checkpoint = torch.load(resume)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch']
        start_step = checkpoint['step']

    for epoch_index in tqdm(range(start_epoch, num_epoch), desc="epoch"):
        ema_loss = 0
        total_acc_account = 0
        total_account = 0
        true_labels = []
        predicted_labels = []
        num_batches = len(train_data_loader)
        for batch_index, (target, token_index) in enumerate(train_data_loader):
            optimizer.zero_grad()
            step = num_batches * (epoch_index) + batch_index + 1
            logits = model(token_index)
            # one-hot需要转换float32才可以训练
            bce_loss = F.binary_cross_entropy(torch.sigmoid(logits), F.one_hot(target, num_classes=2).to(torch.float32))
            ema_loss = 0.9 * ema_loss + 0.1 * bce_loss  # 指数平均loss
            bce_loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 0.1)  # 梯度的正则进行截断,保证训练稳定
            optimizer.step()  # 更新参数

            true_labels.extend(target.tolist())
            predicted_labels.extend(torch.argmax(logits, dim=-1).tolist())

            if step % log_step_interval == 0:
                logging.warning(f"epoch_index: {epoch_index}, batch_index: {batch_index}, ema_loss: {ema_loss}")

            if step % save_step_interval == 0:
                os.makedirs(save_path, exist_ok=True)
                save_file = os.path.join(save_path, f"step_{step}.pt")
                torch.save({
                    "epoch": epoch_index,
                    "step": step,
                    "model_state_dict": model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': bce_loss
                }, save_file)


                logging.warning(f"checkpoint has been saved in {save_file}")
            if step % save_step_interval == 0:
                os.makedirs(save_path, exist_ok=True)
                save_file = os.path.join(save_path, f"step_{step}.pt")
                torch.save({
                    "epoch": epoch_index,
                    "step": step,
                    "model_state_dict": model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': bce_loss,
                    'accuracy': accuracy,
                    'precision': precision,
                    'recall': recall,
                    'f1': f1
                }, save_file)

                logging.warning(f"checkpoint has been saved in {save_file}")

            if step % eval_step_interval == 0:
                logging.warning("start to do evaluation...")
                model.eval()
                ema_eval_loss = 0
                total_acc_account = 0
                total_account = 0
                true_labels = []
                predicted_labels = []

                for eval_batch_index, (eval_target, eval_token_index) in enumerate(eval_data_loader):
                    total_account += eval_target.shape[0]
                    eval_logits = model(eval_token_index)
                    total_acc_account += (torch.argmax(eval_logits, dim=-1) == eval_target).sum().item()
                    eval_bce_loss = F.binary_cross_entropy(torch.sigmoid(eval_logits),
                                                           F.one_hot(eval_target, num_classes=2).to(torch.float32))
                    ema_eval_loss = 0.9 * ema_eval_loss + 0.1 * eval_bce_loss

                    true_labels.extend(eval_target.tolist())
                    predicted_labels.extend(torch.argmax(eval_logits, dim=-1).tolist())

                accuracy = accuracy_score(true_labels, predicted_labels)
                precision = precision_score(true_labels, predicted_labels)
                recall = recall_score(true_labels, predicted_labels)
                f1 = f1_score(true_labels, predicted_labels)

                logging.warning(f"ema_eval_loss: {ema_eval_loss}, eval_acc: {total_acc_account / total_account}")
                logging.warning(f"Precision: {precision}, Recall: {recall}, F1: {f1}, Accuracy: {accuracy}")
                model.train()


model = GCNN()
# model = TextClassificationModel()
print("模型总参数:", sum(p.numel() for p in model.parameters()))
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

train_data_iter = IMDB(root="data", split="train")  # Dataset类型的对象
train_data_loader = torch.utils.data.DataLoader(
    to_map_style_dataset(train_data_iter), batch_size=BATCH_SIZE, collate_fn=collate_fn, shuffle=True)

eval_data_iter = IMDB(root="data", split="test")  # Dataset类型的对象
# collate校对
eval_data_loader = utils.data.DataLoader(
    to_map_style_dataset(eval_data_iter), batch_size=8, collate_fn=collate_fn)

# resume = "./data/step_500.pt"
resume = ""

train(train_data_loader, eval_data_loader, model, optimizer, num_epoch=10, log_step_interval=20, save_step_interval = 500, eval_step_interval = 300, save_path = "./log_imdb_text_classification2", resume = resume)

🥦代码分析

🥦导库

首先导入需要的库

import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
from torch import utils
import torchtext
from tqdm import tqdm
from torchtext.datasets import IMDB
  • torch (PyTorch):
    PyTorch 是一个用于机器学习和深度学习的开源深度学习框架。它提供了张量计算、自动微分、神经网络层和优化器等功能,使用户能够构建和训练深度学习模型。

  • torch.nn:
    torch.nn 模块包含了PyTorch中用于构建神经网络模型的类和函数。它包括各种神经网络层、损失函数和优化器等。

  • torch.nn.functional:
    torch.nn.functional 模块提供了一组函数,用于构建神经网络的非参数化操作,如激活函数、池化和卷积等。这些函数通常与torch.nn一起使用。

  • sklearn.metrics (scikit-learn):
    scikit-learn是一个用于机器学习的Python库,其中包含了一系列用于评估模型性能的度量工具。导入的precision_score、recall_score、f1_score 和 accuracy_score 用于计算分类模型的精确度、召回率、F1分数和准确性。

  • torch.utils:
    torch.utils 包含了一些实用工具和数据加载相关的函数。在这段代码中,它用于构建数据加载器。

  • torchtext:
    torchtext 是一个PyTorch的自然语言处理库,用于文本数据的处理和加载。它提供了用于文本数据预处理和构建数据集的功能。

  • tqdm:
    tqdm 是一个Python库,用于创建进度条,可用于监视循环迭代的进度。在代码中,它用于显示训练和评估的进度。

  • torchtext.datasets.IMDB:
    torchtext.datasets.IMDB 是TorchText库中的一个数据集,包含了IMDb电影评论的数据。这些评论用于情感分析任务,其中评论被标记为积极或消极。

🥦设置日志

logging.basicConfig(
    level=logging.WARN, stream=sys.stdout, format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
)

在代码中设置日志的作用是记录程序的运行状态、调试信息和重要事件,以便在开发和生产环境中更轻松地诊断问题和了解程序的行为。设置日志有以下作用:

  • 问题诊断:当程序出现错误或异常时,日志记录可以提供有关错误发生的位置、原因和上下文的信息。这有助于开发人员快速定位和修复问题。

  • 性能分析:通过记录程序的运行时间和关键操作的时间戳,日志可以用于性能分析,帮助开发人员识别潜在的性能瓶颈。

  • 跟踪进度:在长时间运行的任务中,例如训练深度学习模型,日志记录可以帮助跟踪任务的进度,以便了解训练状态、完成的步骤和剩余时间。

  • 监控和警报:日志可以与监控系统集成,以便在发生关键事件或异常情况时触发警报。这对于及时响应问题非常重要。

  • 审计和合规:在某些应用中,日志记录是合规性的一部分,用于追踪系统的操作和用户的活动。日志可以用于审计和调查。

在上述代码中,设置日志的目的是跟踪训练进度、记录训练损失以及保存检查点。它允许开发人员监视模型训练的进展并在需要时查看详细信息,例如损失值和评估指标。此外,日志还可以用于调试和查看模型性能。

🥦模型定义

代码定义了两个模型:

GCNN:用于文本分类的门控卷积神经网络。
TextClassificationModel:使用嵌入和线性层的简单文本分类模型。

🥦GCNN

class GCNN(nn.Module):
    def __init__(self, vocab_size=VOCAB_SIZE, embedding_dim=64, num_class=2):
        super(GCNN, self).__init__()

        self.embedding_table = nn.Embedding(vocab_size, embedding_dim)
        nn.init.xavier_uniform_(self.embedding_table.weight)

        # 都是1维卷积
        self.conv_A_1 = nn.Conv1d(embedding_dim, 64, 15, stride=7)
        self.conv_B_1 = nn.Conv1d(embedding_dim, 64, 15, stride=7)

        self.conv_A_2 = nn.Conv1d(64, 64, 15, stride=7)
        self.conv_B_2 = nn.Conv1d(64, 64, 15, stride=7)

        self.output_linear1 = nn.Linear(64, 128)
        self.output_linear2 = nn.Linear(128, num_class)

    def forward(self, word_index):
        """
        定义GCN网络的算子操作流程,基于句子单词ID输入得到分类logits输出
        """
        # 1. 通过word_index得到word_embedding
        # word_index shape: [bs, max_seq_len]
        word_embedding = self.embedding_table(word_index)  # [bs, max_seq_len, embedding_dim]

        # 2. 编写第一层1D门卷积模块,通道数在第2维
        word_embedding = word_embedding.transpose(1, 2)  # [bs, embedding_dim, max_seq_len]
        A = self.conv_A_1(word_embedding)
        B = self.conv_B_1(word_embedding)
        H = A * torch.sigmoid(B)  # [bs, 64, max_seq_len]

        A = self.conv_A_2(H)
        B = self.conv_B_2(H)
        H = A * torch.sigmoid(B)  # [bs, 64, max_seq_len]

        # 3. 池化并经过全连接层
        pool_output = torch.mean(H, dim=-1)  # 平均池化,得到[bs, 4096]
        linear1_output = self.output_linear1(pool_output)

        # 最后一层需要设置为隐含层数目
        logits = self.output_linear2(linear1_output)  # [bs, 2]

        return logits

🥦TextClassificationModel

class TextClassificationModel(nn.Module):
    """
    简单版embedding.DNN模型
    """

    def __init__(self, vocab_size=VOCAB_SIZE, embed_dim=64, num_class=2):
        super(TextClassificationModel, self).__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=False)
        self.fc = nn.Linear(embed_dim, num_class)

    def forward(self, token_index):
        # 词袋
        embedded = self.embedding(token_index)  # shape: [bs, embedding_dim]
        return self.fc(embedded)

🥦准备IMDb数据集

这行代码使用TorchText的IMDB数据集对象,导入IMDb数据集的训练集部分。

# 数据集导入
train_data_iter = IMDB(root="./data", split="train")

这行代码创建了一个用于将文本分词为单词的分词器。

# 数据预处理
tokenizer = get_tokenizer("basic_english")

这里,build_vocab_from_iterator 函数根据文本数据创建了一个词汇表,只包括出现频率大于等于20次的单词。特殊标记用于处理未知单词。然后,set_default_index将特殊标记的索引设置为0。

# 构建词汇表
vocab = build_vocab_from_iterator(yeild_tokens(train_data_iter, tokenizer), min_freq=20, specials=["<unk>"])
vocab.set_default_index(0)

这是一个自定义的校对函数,用于处理DataLoader返回的批次数据,将文本转换为可以输入模型的张量形式。

def collate_fn(batch):
    """
    对DataLoader所生成的mini-batch进行后处理
    """
    target = []
    token_index = []
    max_length = 0  # 最大的token长度
    for i, (label, comment) in enumerate(batch):
        tokens = tokenizer(comment)
        token_index.append(vocab(tokens))  # 字符列表转换为索引列表

        # 确定最大的句子长度
        if len(tokens) > max_length:
            max_length = len(tokens)

        if label == "pos":
            target.append(0)
        else:
            target.append(1)

    token_index = [index + [0] * (max_length - len(index)) for index in token_index]
    # one-hot接收长整形的数据,所以要转换为int64
    return (torch.tensor(target).to(torch.int64), torch.tensor(token_index).to(torch.int32))

这行代码将IMDb训练数据集加载到DataLoader对象中,以便进行模型训练。collate_fn函数用于处理数据的批处理。

train_data_loader = torch.utils.data.DataLoader(
    to_map_style_dataset(train_data_iter), batch_size=BATCH_SIZE, collate_fn=collate_fn, shuffle=True)

上述代码块执行了IMDb数据集的准备工作,包括导入数据、分词、构建词汇表和设置数据加载器。这些步骤是为了使数据集可用于训练文本分类模型。

🥦整理函数

这个 collate_fn 函数用于对 DataLoader 批次中的数据进行处理,确保每个批次中的文本序列具有相同的长度,并将标签转换为适用于模型输入的张量形式。它的工作包括以下几个方面:

提取标签和评论文本。
使用分词器将评论文本分词为单词。
确定批次中最长评论的长度。
根据最长评论的长度,将所有评论的单词索引序列填充到相同的长度。
将标签转换为适当的张量形式(这里是将标签转换为长整数型)。
返回处理后的批次数据,其中包括标签和填充后的单词索引序列。

这个整理函数确保了模型在训练期间能够处理不同长度的文本序列,并将它们转换为模型可接受的张量输入。

🥦训练函数

def train(train_data_loader, eval_data_loader, model, optimizer, num_epoch, log_step_interval, save_step_interval,  eval_step_interval, save_path, resume=""):
    """
    此处data_loader是map-style dataset
    """
    start_epoch = 0
    start_step = 0
    if resume != "":
        # 加载之前训练过的模型的参数文件
        logging.warning(f"loading from resume")
        checkpoint = torch.load(resume)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch']
        start_step = checkpoint['step']

    for epoch_index in tqdm(range(start_epoch, num_epoch), desc="epoch"):
        ema_loss = 0
        total_acc_account = 0
        total_account = 0
        true_labels = []
        predicted_labels = []
        num_batches = len(train_data_loader)
        for batch_index, (target, token_index) in enumerate(train_data_loader):
            optimizer.zero_grad()
            step = num_batches * (epoch_index) + batch_index + 1
            logits = model(token_index)
            # one-hot需要转换float32才可以训练
            bce_loss = F.binary_cross_entropy(torch.sigmoid(logits), F.one_hot(target, num_classes=2).to(torch.float32))
            ema_loss = 0.9 * ema_loss + 0.1 * bce_loss  # 指数平均loss
            bce_loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 0.1)  # 梯度的正则进行截断,保证训练稳定
            optimizer.step()  # 更新参数

            true_labels.extend(target.tolist())
            predicted_labels.extend(torch.argmax(logits, dim=-1).tolist())

            if step % log_step_interval == 0:
                logging.warning(f"epoch_index: {epoch_index}, batch_index: {batch_index}, ema_loss: {ema_loss}")

            if step % save_step_interval == 0:
                os.makedirs(save_path, exist_ok=True)
                save_file = os.path.join(save_path, f"step_{step}.pt")
                torch.save({
                    "epoch": epoch_index,
                    "step": step,
                    "model_state_dict": model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': bce_loss
                }, save_file)


                logging.warning(f"checkpoint has been saved in {save_file}")
            if step % save_step_interval == 0:
                os.makedirs(save_path, exist_ok=True)
                save_file = os.path.join(save_path, f"step_{step}.pt")
                torch.save({
                    "epoch": epoch_index,
                    "step": step,
                    "model_state_dict": model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': bce_loss,
                    'accuracy': accuracy,
                    'precision': precision,
                    'recall': recall,
                    'f1': f1
                }, save_file)

                logging.warning(f"checkpoint has been saved in {save_file}")

            if step % eval_step_interval == 0:
                logging.warning("start to do evaluation...")
                model.eval()
                ema_eval_loss = 0
                total_acc_account = 0
                total_account = 0
                true_labels = []
                predicted_labels = []

                for eval_batch_index, (eval_target, eval_token_index) in enumerate(eval_data_loader):
                    total_account += eval_target.shape[0]
                    eval_logits = model(eval_token_index)
                    total_acc_account += (torch.argmax(eval_logits, dim=-1) == eval_target).sum().item()
                    eval_bce_loss = F.binary_cross_entropy(torch.sigmoid(eval_logits),
                                                           F.one_hot(eval_target, num_classes=2).to(torch.float32))
                    ema_eval_loss = 0.9 * ema_eval_loss + 0.1 * eval_bce_loss

                    true_labels.extend(eval_target.tolist())
                    predicted_labels.extend(torch.argmax(eval_logits, dim=-1).tolist())

                accuracy = accuracy_score(true_labels, predicted_labels)
                precision = precision_score(true_labels, predicted_labels)
                recall = recall_score(true_labels, predicted_labels)
                f1 = f1_score(true_labels, predicted_labels)

                logging.warning(f"ema_eval_loss: {ema_eval_loss}, eval_acc: {total_acc_account / total_account}")
                logging.warning(f"Precision: {precision}, Recall: {recall}, F1: {f1}, Accuracy: {accuracy}")
                model.train()

这段代码定义了一个名为 train 的函数,用于执行训练过程。下面是该函数的详细说明:

train 函数接受以下参数:
    train_data_loader: 训练数据的 DataLoader,用于迭代训练数据。
    eval_data_loader: 用于评估的 DataLoader,用于评估模型性能。
    model: 要训练的神经网络模型。
    optimizer: 用于更新模型参数的优化器。
    num_epoch: 训练的总周期数。
    log_step_interval: 记录日志的间隔步数。
    save_step_interval: 保存模型检查点的间隔步数。
    eval_step_interval: 执行评估的间隔步数。
    save_path: 保存模型检查点的目录。
    resume: 可选的,用于恢复训练的检查点文件路径。

训练函数的主要工作如下:
    它首先检查是否有恢复训练的检查点文件。如果有,它会加载之前训练的模型参数和优化器状态,以便继续训练。
    然后,它开始进行一系列的训练周期(epochs),每个周期内包含多个训练步(batches)。
    在每个训练步中,它执行以下操作:
        零化梯度,以准备更新模型参数。
        计算模型的预测输出(logits)。
        计算二进制交叉熵损失(binary cross-entropy loss)。
        使用反向传播(backpropagation)计算梯度并更新模型参数。
        记录损失、真实标签和预测标签。
        如果步数达到了 log_step_interval,则记录损失。
        如果步数达到了 save_step_interval,则保存模型检查点。
        如果步数达到了 eval_step_interval,则执行评估:
    将模型切换到评估模式(model.eval())。
    对评估数据集中的每个批次执行以下操作:
        计算模型的预测输出。
        计算二进制交叉熵损失。
        计算准确性、精确度、召回率和F1分数。
        记录评估损失和评估指标。
    将模型切换回训练模式(model.train())。

最后,训练函数返回经过训练的模型。

这个训练函数执行了完整的训练过程,包括了模型的前向传播、损失计算、梯度更新、日志记录、模型检查点的保存和评估。通过调用这个函数,你可以训练模型并监视其性能。

🥦模型初始化和优化器

model = GCNN()
# model = TextClassificationModel()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

🥦加载用于训练和评估的数据

在提供的代码中,加载用于训练和评估的数据的部分如下:

train_data_iter = IMDB(root="data", split="train")

这一行代码使用 TorchText 的 IMDB 数据集对象,导入 IMDB 数据集的训练集部分。这部分数据将用于模型的训练。

eval_data_iter = IMDB(root="data", split="test")

这一行代码使用 TorchText 的 IMDB 数据集对象,导入 IMDB 数据集的测试集部分。这部分数据将用于评估模型的性能。


之后,这些数据集通过以下代码转化为 DataLoader 对象,以便用于模型训练和评估:

# 训练数据 DataLoader
train_data_loader = torch.utils.data.DataLoader(
    to_map_style_dataset(train_data_iter), batch_size=BATCH_SIZE, collate_fn=collate_fn, shuffle=True)
# 评估数据 DataLoader
eval_data_loader = utils.data.DataLoader(
     to_map_style_dataset(eval_data_iter), batch_size=8, collate_fn=collate_fn)

这些 DataLoader 对象将数据加载到内存中,以便训练和评估使用。collate_fn 函数用于处理数据的批次,确保它们具有适当的格式,以便输入到模型中。

这些部分负责加载和准备用于训练和评估的数据,是机器学习模型训练和评估的重要准备步骤。训练数据用于训练模型,而评估数据用于评估模型的性能。

🥦恢复训练

start_epoch = 0
start_step = 0
if resume != "":
    # 加载之前训练过的模型的参数文件
    logging.warning(f"loading from resume")
    checkpoint = torch.load(resume)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch']
    start_step = checkpoint['step']

上述代码段位于训练函数中的开头部分,主要用于检查是否有已经训练过的模型的检查点文件,以便继续训练。具体解释如下:

如果 resume 变量不为空(即存在要恢复的检查点文件路径),则执行以下操作:
通过 torch.load 加载之前训练过的模型的检查点文件。
使用 load_state_dict 方法将已保存的模型参数加载到当前的模型中,以便继续训练。
同样,使用 load_state_dict 方法将已保存的优化器状态加载到当前的优化器中,以确保继续从之前的状态开始训练。
获取之前训练的轮数和步数,以便从恢复的状态继续训练。

这部分代码的目的是允许从之前保存的模型检查点继续训练,而不是从头开始。这对于长时间运行的训练任务非常有用,可以在中途中断训练并在之后恢复,而不会丢失之前的训练进度。

🥦调用训练

train(train_data_loader, eval_data_loader, model, optimizer, num_epoch=10, log_step_interval=20, save_step_interval=500, eval_step_interval=300, save_path="./log_imdb_text_classification2", resume=resume)

🥦保存文件的读取

import torch

# 指定已存在的 .pt 文件路径
file_path = "./log_imdb_text_classification/step_3500.pt"  # 替换为实际的文件路径

# 使用 torch.load() 加载文件
checkpoint = torch.load(file_path)

# 查看准确率、精确率、召回率和F1分数
accuracy = checkpoint["accuracy"]
precision = checkpoint["precision"]
recall = checkpoint["recall"]
f1 = checkpoint["f1"]

print("Accuracy:", accuracy)
print("Precision:", precision)
print("Recall:", recall)
print("F1 Score:", f1)

在这里插入图片描述

🥦扩展 LSTM、GRU

本文原作者使用的是卷积神经网络,但是卷积神经网络的优化模型GCNN,但是这个模型对于图更好,由此我接下来引入两个循环神经网络LSTM和GRU

class LSTMModel(nn.Module):
    def __init__(self, vocab_size=VOCAB_SIZE, embedding_dim=64, hidden_dim=64, num_class=2):
        super(LSTMModel, self).__init__()
        self.embedding_table = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=1, batch_first=True)
        self.output_linear = nn.Linear(hidden_dim, num_class)

    def forward(self, word_index):
        word_embedding = self.embedding_table(word_index)
        lstm_out, _ = self.lstm(word_embedding)
        lstm_out = lstm_out[:, -1, :]  # 取最后一个时间步的输出
        logits = self.output_linear(lstm_out)
        return logits

class GRUModel(nn.Module):
    def __init__(self, vocab_size=VOCAB_SIZE, embedding_dim=64, hidden_dim=64, num_class=2):
        super(GRUModel, self).__init__()
        self.embedding_table = nn.Embedding(vocab_size, embedding_dim)
        self.gru = nn.GRU(embedding_dim, hidden_dim, num_layers=1, batch_first=True)
        self.output_linear = nn.Linear(hidden_dim, num_class)

    def forward(self, word_index):
        word_embedding = self.embedding_table(word_index)
        gru_out, _ = self.gru(word_embedding)
        gru_out = gru_out[:, -1, :]  # 取最后一个时间步的输出
        logits = self.output_linear(gru_out)
        return logits
# 创建LSTM模型
lstm_model = LSTMModel()
print("模型总参数:", sum(p.numel() for p in lstm_model.parameters()))
lstm_optimizer = torch.optim.Adam(lstm_model.parameters(), lr=0.001)

# 创建GRU模型
# gru_model = GRUModel()
# print("模型总参数:", sum(p.numel() for p in gru_model.parameters()))
# gru_optimizer = torch.optim.Adam(gru_model.parameters(), lr=0.001)
# 训练LSTM模型
train(train_data_loader, eval_data_loader, lstm_model, lstm_optimizer, num_epoch=10, log_step_interval=20, save_step_interval=500, eval_step_interval=300, save_path="./log_imdb_lstm", resume="")

# 训练GRU模型
# train(train_data_loader, eval_data_loader, gru_model, gru_optimizer, num_epoch=10, log_step_interval=20, save_step_interval=500, eval_step_interval=300, save_path="./log_imdb_gru", resume="")

感兴趣的小伙伴可以试试,对比一下

🥦总结

本文代码来自网络仅供学习,原文地址

请添加图片描述

挑战与创造都是很痛苦的,但是很充实。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1123720.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【基于形态学的权重自适应去噪】

【基于形态学的权重自适应去噪】 1 引言2 数学形态学原理3 权重自适应的多结构形态学去噪4 实现代码4.1 主函数代码4.2 串、并联去噪4.3 图像权值计算4.4 计算 PSNR 值 5 实验结果 参考书籍&#xff1a;计算机视觉与深度学习实战:以MATLAB、Python为工具&#xff0c; 主编&…

【Java】多态中调用成员的特点

示例代码 public class Test {public static void main(String[] args) {//创建对象&#xff08;多态方式&#xff09;//父类 f new 子类();Animal a new Dog();//调用成员变量&#xff1a;编译看左边&#xff0c;运行也看左边//编译看左边&#xff1a;javac编译代码的时候&…

Python深度学习进阶与应用丨注意力(Attention)机制、Transformer模型、生成式模型、目标检测算法、图神经网络、强化学习详解等

目录 第一章 注意力&#xff08;Attention&#xff09;机制详解 第二章 Transformer模型详解 第三章 生成式模型详解 第四章 目标检测算法详解 第五章 图神经网络详解 第六章 强化学习详解 第七章 深度学习模型可解释性与可视化方法详解 更多应用 近年来&#xff0c;伴…

【Java】JDK 21中的虚拟线程以及其他新特性

目录 一、字符串模板&#xff08;String Templates&#xff09; 二、序列化集合&#xff08;Sequenced Collections&#xff09; 三、分代ZGC&#xff08;Generational ZGC&#xff09; 四、记录模式&#xff08;Record Patterns&#xff09; 五、Fibers&#xff08;纤程&…

实战SRC

附言&#xff1a;从补天的公益src公司中选中了幸运儿。 1. 通过hunter鹰图平台搜索公司的相关资产&#xff0c;发现其采用了华途应用安全网关。 2.访问相关地址&#xff0c;尝试使用弱口令登录&#xff0c;发现直接利用admin/admin就登录了&#xff0c;可以看到后台的相关日志…

汉语言语的声学特点是什么

汉语言语的声学特点是什兰明 医学硕士&#xff0c;听力学博士&#xff0c;听觉健康门诊主任 虽然互联网已经将英语作为最常用的&#xff08;第二&#xff09;语言的地位&#xff0c;但中文&#xff08;普通话&#xff09;仍然是最常用的母语。2010年&#xff0c;以中文为…

新成果展示:AlGaN/GaN基紫外光电晶体管的设计与制备

紫外光电探测器被广泛应用于导弹预警、火灾探测、非可见光通信、环境监测等民事和军事领域&#xff0c;这些应用场景的实现需要器件具有高信噪比和高灵敏度。因此&#xff0c;光电探测器需要具备响应度高、响应速度快和暗电流低的特性。近期&#xff0c;天津赛米卡尔科技有限公…

C++ 读取数量不定的输入数据

在C中&#xff0c;有时我们会遇到&#xff0c;在事先没有知道&#xff0c;要对多少个数进行求和的情况下&#xff0c;这就需要不断的读取数据直至没有新的输入为止&#xff1a; demo&#xff1a; #include <iostream> using namespace std;int main() {int sum 0;in…

如何打造小红书产品差异化,打造产品优势?

其实在当今的时代&#xff0c;我们实质上已经进入到了一个产能过剩的时代&#xff0c;这意味着大量的同质化产品出现在市场上&#xff0c;选择更多了但是选择也更少了。今天为大家分享下如何打造小红书产品差异化&#xff0c;打造产品优势&#xff1f; 下面是一些产品差异化策略…

Redis数据结构完全解析:底层实现细节揭秘

文章目录 &#x1f34a; 简单字符串&#x1f389; 问题1&#xff1a;SDS结构体的三个属性分别表示什么意思&#xff1f;&#x1f389; 问题2&#xff1a;SDS字符串的内存分配方式是怎么样的&#xff1f;&#x1f389; 问题3&#xff1a;SDS字符串的拼接操作是怎么样的&#xff…

Pyside6 QFile

Pyside6 QFile QFile使用QFile常用函数文件编辑类函数判断文件是否存在重命名文件删除文件函数复制文件 文件内容操作类函数文件打开函数文件关闭函数文件读取函数read函数使用readLine函数使用readAll函数使用 文件写入函数追加方式写文件重写方式写文件 程序界面程序主程序 P…

数据结构和算法——图

图 有向图 带权图 邻接矩阵 邻接表相较于邻接矩阵&#xff0c;减少了存储空间&#xff1b; 邻接表 参考视频&#xff1a;【尚硅谷】数据结构与算法&#xff08;Java数据结构与算法&#xff09;_哔哩哔哩_bilibili

高精度数字压力表丨铭控传感多款数字压力表在多场景中的应用

时代日新月异、变化万千&#xff0c;压力表应用需求始终在不断变化&#xff0c;但铭控传感对压力测量的应用一如既往的了如指掌。铭控传感总是能够为您提供最合适符合您要求的成本和功能都极佳产品解决方案&#xff0c;通过您的需求定制MEOKON产品&#xff0c;铭控传感始终为用…

【EP2C35F672C8 EDA试验箱下载】

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 前言一、试验箱如何下载&#xff1f;1. 编译工程没问题后&#xff0c;配置引脚2.配置完引脚后&#xff0c;记得重新编译3.配置下载4.配置下载器&#xff0c;需要装驱…

如何使用Python进行自动化测试

目录 一、选择适合的测试框架 二、编写测试用例 三、运行和分析测试结果 四、重构测试用例 五、注意事项 总结 随着软件行业的快速发展&#xff0c;自动化测试已成为软件开发过程中不可或缺的一部分。使用Python进行自动化测试可以帮助我们快速、高效地测试应用程序&…

Explainable-ZSL

模型 体会 作者的实验做得很充足&#xff0c;但未提供可直接运行的代码

可变参数模板 - c++11

文章目录&#xff1a; 可变参数模板的认识参数包的展开递归函数方式展开参数包逗号表达式展开参数包 STL容器中的empalce相关接口函数 可变参数模板的认识 c11 引入了可变参数模板&#xff08;variadic templates&#xff09;的特性&#xff0c;使得编写支持任意数量参数的模板…

交易想简化分析并少失误,波浪原则anzo capital认为必不可少

要想在交易中简化分析并少失误&#xff0c;不管是交易新手还是交易高手&#xff0c;anzo capital认为其实很容易&#xff0c;只要了解艾略特波浪原则。 艾略特波浪原则&#xff0c;每一个趋势都由特定的基本元素(波浪)组成&#xff0c;这些元素具有重复的趋势。这些波浪可以根…

企业或人力资源公司可利用直播将职位以视频直播的方式展现

抖音直播招聘报白是一种通过直播方式展示职位信息并与求职者互动的招聘方式。抖音的短视频流量能够让岗位信息覆盖更广泛的人群&#xff0c;增加招聘信息的曝光度。通过抖音的短视频流量红利和精准推送&#xff0c;能够提高岗位信息的曝光度和求职者的留存率。如果你想做招聘报…

Windows系统安装node-red

Quick Start 1. Install Node.js 第一步下载node.js,超链接在后面 Download the latest LTS version of Node.js from the official Node.js home page. It will offer you the best version for your system. Run the downloaded MSI file. Installing Node.js requires l…