【NLP练习】Transformer实战-单词预测

news2024/12/23 9:54:41
  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

任务:自定义输入一段英文文本进行预测

一、定义模型

from tempfile import TemporaryDirectory
from typing import Tuple
from torch import nn,Tensor
from torch.nn import TransformerEncoder, TransformerEncoderLayer
import math, os, torch

class TransformerModel(nn.Module):
    def __init__(self, ntoken: int, d_model: int, nhead: int, d_hid: int, nlayers: int, dropout: float = 0.5):
        super().__init__()
        self.pos_encoder = PositionalEncoding(d_model, dropout)

        #定义编码器层
        encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout)
        #定义编码器,pytorch将Transformer编码器进行了打包,这里直接调用即可
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.embedding = nn.Embedding(ntoken,d_model)
        self.d_model = d_model
        self.linear = nn.Linear(d_model, ntoken)
        self.init_weights()
        #初始化权重
    def init_weights(self) -> None:
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.linear.bias.data.zeros_()
        self.linear.weight.data.uniform_(-initrange, initrange)

    def forward(self, src:Tensor, src_mask: Tensor = None) -> Tensor:
        """
        Arguments:
            src:      Tensor, 形状为[seq_len, batch_size]
            src_mask: Tensor, 形状为[seq_len, seq_len]

        Returns:
            输出的Tensor,形状为[seq_len, batch_size, ntoken]
        """
        src = self.embedding(src) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, src_mask)
        output = self.linear(output)
        return output

class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p = dropout)
        #生成位置编码的位置张量
        position = torch.arange(max_len).unsqueeze(1)
        #计算位置编码的除数项
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        #创建位置编码张量
        pe = torch.zeros(max_len, 1, d_model)
        #使用正弦函数计算位置编码中的基数维度部分
        pe[:, 0, 1::2] = torch.sin(position * div_term)
        #使用余弦函数计算位置编码中的偶数维度部分
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Arguments:
            x:      Tensor, 形状为[seq_len, batch_size, embedding_dim]
        """
        #将位置编码添加到输入张量
        x = x + self.pe[:x.size(0)]
        #应用dropout
        return self.dropout(x)

二、加载数据集

本实验使用torchtext生成Wikitext-2数据集。在此之前,你需要安装下面的包:

  • pip install portalocker
  • pip install torchdata
import torchtext
from torchtext.datasets.wikitext2 import WikiText2
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

#从torchtext库中导入WikiTetx2数据集
train_iter = WikiText2(split = 'train')

#获取基本的英语分词器
tokenizer = get_tokenizer('basic_english')
#通过迭代器构建词汇表
vocab = build_vocab_from_iterator(map(tokenizer, train_iter), specials=['<unk>'])
#将默认索引设置为'<unk>'
vocab.set_default_index(vocab['<unk>'])

def data_process(raw_text_iter: dataset.IterableDataset) -> Tensor:
    """将原始文本转换为扁平的张量"""
    data = [torch.tensor(vocab(tokenizer(item)),
                        dtype = torch.long) for item in raw_text_iter]
    return torch.cat(tuple(filer(lambda t: t.numel() > 0, data)))

#由于构建词汇表时"train_iter"被使用了,所以需要重新创建
train_iter, val_iter, test_iter = WikiText2()

#队训练、验证和测试数据进行处理
train_data = data_process(train_iter)
val_data = data_process(val_iter)
test_data = data_process(test_iter)

#检查是否有可用的CUDA设备,将设备设置为GPU或者CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def batchify(data: Tensor, bsz: int) -> Tensor:
    """将数据划分为bsz个单独的序列,去除不能完全容纳的额外元素。
    参数:
        data: Tensor,形状为``[N]``
        bsz:int,批大小
    返回:
        形状为[N // bsz, bsz]的张量
    """
    seq_len = data.size(0) // bsz
    data = data[:seq_len * bsz]
    data = data.view(bsz, seq_len).t().contiguous()
    return data.to(device)

#设置批大小和评估批大小
batch_size = 20
eval_batch_size = 10
#将训练、验证和测试数据进行批处理
train_data = batchify(train_data, batch_size)   #形状为[seq_len, batch_size]
val_data = batchify(val_data, eval_batch_size)
test_data = batchify(test_data, eval_batch_size)
bptt = 35

#获取批次数据
def get_batch(source:Tensor, i: int) -> Tuple[Tensor, Tensor]:
    """
    参数:
        source: Tensor,形状为``[full_seq_len, batch_size]``
        i : int, 当前批次索引
    返回:
        tuple(data, target),
        -data形状为[seq_len, batch_size],
        -target形状为[seq_len * batch_size]
    """
    #计算当前批次的序列长度,最大为bptt,确保不超过source的长度
    seq_len = min(bptt, len(source) - 1 - i)
    #获取data,从i开始,长度为seq_len
    data = source[i:i+seq_len]
    #获取target,从i+1开始,长度为seq_len,并将其形状转换为一维张量
    target = source[i+1:i+1+seq_len].reshape(-1)
    return data, target

三、初始化实例

ntokend = len(vocab)
emsize = 200
d_hid = 200
nlayers = 2
nhead = 2
dropout = 0.2
#创建transformer模型
model = TransformerModel(ntokend,
                        emsize,
                        nhead,
                        d_hid,
                        nlayers,
                        dropout).to(device)

四、训练模型

结合使用CrossEntropyLoss与SGD(随机梯度下降优化器)。训练期间,使用torch.nn.utils.clip_grad_norm_来防止梯度爆炸

import time
criterion = nn.CrossEntropyLoss() #定义交叉熵损失函数
lr = 5.0
optimizer = torch.optim.SGD(model.parameters(),lr = lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gama = 0.95)

def train(model: nn.Module) -> None:
    model.train() #开启训练模式
    total_loss = 0.
    log_interval = 200 #
    start_time = time.time()

    num_batches = len(train_data) // bptt
    for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
        data, targets = get_batch(train_data, i)
        output = model(data)
        output_flat = output.view(-1, ntokens)
        loss = criterion(output_flat, targets) #计算损失

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        total_loss += loss.item()
        if batch % log_interval == 0 and batch > 0:
            lr = scheduler.get_last_lr()[0]
            ms_per_batch = (time.time() - start_time) * 1000 / log_interval
            cur_loss = total_loss / log_interval
            ppl = math.exp(cur_loss)

            print(f'| epoch{epoch:3d} | {batch:5d} / {num_batches:5d} batches |'
                 f'lr{lr:02.2f} | ms/batch {ms_per_batch:5.2f} |'
                f'loss {cur_loss:5.2f}|ppl{ppl:8.2f}')
            total_loss = 0
            start_time = time.time()

def evaluate(model:nn.Module, eval_data:Tensor) -> float:
    model.eval()
    total_loss = 0.
    with torch.no_grad():
        for i in range(0,eval_data.size(0) - 1, bptt):
            data, targets = get_batch(eval_data,i)
            seq_len = data.size(0)
            output = model(data)
            output_flat = output.view(-1,ntokens)
            total_loss += seq_len * criterion(output_flat, targets).item()

    return total_loss / (len(eval_data) - 1)
    
        
best_val_loss = float('inf')
epochs = 1

with TemporaryDirectory() as tempdir:
    best_model_params_path = os.path.join(tempdir, "best_model_params.pt")

    for epoch in range(1, epochs + 1):
        epoch_start_time = time.time()
        train(model)
        val_loss = evaluate(model, val_data)
        val_ppl = math_exp(val_loss)
        elapsed = time.time() - epoch_start_time

        #打印当前epoch的信息,包括耗时、验证损失和困惑度
        print('-' * 89)
        print(f'|end of epoch {epoch:3d} | time:{elapsed: 5.2f}s |'
             f'valid loss {val_loss:5.2f} | valid ppl {val_ppl: 8.2f}')
        print('-' * 89)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), best_model_params_path)
        scheduler.step()    #更新学习率
    model.load_state_dict(torch.load(best_model_params_path))

代码输出:
在这里插入图片描述

五、总结

加载数据集时,注意包的版本关联关系。另外,注意结合使用优化器提升优化性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1847570.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

DVWA 靶场 Open HTTP Redirect 通关解析

前言 DVWA代表Damn Vulnerable Web Application&#xff0c;是一个用于学习和练习Web应用程序漏洞的开源漏洞应用程序。它被设计成一个易于安装和配置的漏洞应用程序&#xff0c;旨在帮助安全专业人员和爱好者了解和熟悉不同类型的Web应用程序漏洞。 DVWA提供了一系列的漏洞场…

Nuxt快速学习开发---Nuxt3视图Views

Views Nuxt提供了几个组件层来实现应用程序的用户界面 默认情况下&#xff0c;Nuxt 会将app.vue文件视为入口点并为应用程序的每个路由呈现其内容 应用程序.vue <template> <div> <h1>Welcome to the homepage</h1> </div> </template> …

湖南科技大学24计算机考研情况,软工学硕考数二,分数线290分,录取均分321分!

湖南科技大学&#xff08;Hunan University of Science and Technology&#xff09;坐落在伟人故里、人文圣地湘潭&#xff0c;处于长株潭核心区域&#xff0c;比邻湘潭九华经济技术开发区&#xff08;国家级&#xff09;&#xff0c;是应急管理部、国家国防科技工业局与湖南省…

复分析——第6章—— Γ 函数和 ζ 函数(E.M. Stein R. Shakarchi)

第6章 Γ函数和Ζ函数(The Gamma and Zeta Functions) 毫不夸张地说&#xff0c;Γ函数和Ζ函数是数学中最重要的非初等函数之一。Γ函数在自然界中无处不在。它出现在大量计算中&#xff0c;并以分析中出现的大量恒等式为特征。对此的部分解释可能在于Γ函数的基本结构特性&…

Nginx HTTPS(证书) 部署实战

一、申请证书与认证 要搭建https服务首先需有SSL证书&#xff0c;证书通常是在第三方申请&#xff0c;在阿里云的安全服务中有SSL证书这一项&#xff0c;可以在里面申请免费的证书。也可以在自己电脑中生成&#xff0c;虽然也能完成加密&#xff0c;但是浏览器是不认可的&…

编译 CanMV 固件

前言 上一章节中已经搭建好了基于 CanMV 的 C 开发环境&#xff0c;这么一来便可以进行基于 C 语言和 FreeRTOS 的应用开发或者编译基于 MicroPython 语法的应用开发方式所需的 CanMV 固件&#xff0c;本 章就将带领读者体验一下 CanMV 固件的编译流程。 本章分为如下几个小节&…

Java面试题:mysql执行速度慢的原因和优化

Sql语句执行速度慢 原因 聚合查询 多表查询 表数据量过大查询 深度分页查询 分析 sql的执行计划 可以使用EXPLAIN或者DESC获取Mysql如何执行SELECT语句的信息 直接在select语句前加关键字explain/desc 得到一个执行信息表 信息字段分析 possible_keys:可能使用到的索…

云计算【第一阶段(18)】磁盘管理与文件系统

一、磁盘基础 磁盘&#xff08;disk&#xff09;是指利用磁记录技术存储数据的存储器。 磁盘是计算机主要的存储介质&#xff0c;可以存储大量的二进制数据&#xff0c;并且断电后也能保持数据不丢失。 早期计算机使用的磁盘是软磁盘&#xff08;Floppy Disk&#xff0c;简称…

海外社媒网站抓取经验总结:如何更高效实现网页抓取?

有效的网络抓取需要采取战略方法来克服挑战并确保最佳数据提取。让我们深入研究一些关键实践&#xff0c;这些实践将使您能够掌握复杂的网络抓取。 一、了解 Web 抓取检测 在深入探讨最佳实践之前&#xff0c;让我们先了解一下网站如何识别和抵御网络爬虫。了解您在这一过程中…

深度神经网络一

文章目录 深度神经网络 (DNN)1. 概述2. 基本概念3. 网络结构 深度神经网络的层次结构详细讲解1. 输入层&#xff08;Input Layer&#xff09;2. 隐藏层&#xff08;Hidden Layers&#xff09;3. 输出层&#xff08;Output Layer&#xff09;整体流程深度神经网络的优点深度神经…

[行业原型] 线上药房管理系统

​行业背景 据中国网上药店理事会调查报告显示&#xff1a;2011年&#xff0c;医药B2C的规模达到4亿元&#xff0c;仅出现5家销售额达5000万元的网上药店。而2011年医药行业的市场规模达到3718亿&#xff0c;线上药品的销售额还不到网下药店的一个零头&#xff0c;还有很大的发…

C++类基本常识

文章目录 一、类的默认方法二、类的成员变量初始化1 类的成员变量有三种初始化方法&#xff1a;2 成员变量初始化顺序3 const和static的初始化 三、C内存区域四、const和static 一、类的默认方法 C的类都会有8个默认方法 默认构造函数默认拷贝构造函数默认析构函数默认重载赋…

C语言基础关键字的含义和使用方法

​关键字在C语言中扮演着非常重要的角色&#xff0c;它们定义了语言的基本构造和语法规则&#xff0c;通过使用关键字&#xff0c;开发者可以创建变量、定义数据类型、控制程序流程&#xff08;如循环和条件判断&#xff09;、声明函数等。由于这些字是保留的&#xff0c;所以编…

springSecurity(二):实现登入获取token与解析token

登入生成token 主要思想 springSecurity使用UsernamePasswordAuthenticationToken类来封装用户名和密码的认证信息 代码实现 发起登入请求后&#xff0c;进入到login()方法 /*** 在接口中我们通过AuthenticationManager的authenticate方法来进行用户认证,* 所以需要在Secur…

Polyp-DDPM: Diffusion-Based Semantic Polyp Synthesis for Enhanced Segmentation

Polyp- ddpm:基于扩散的语义Polyp合成增强分割 摘要&#xff1a; 本研究介绍了一种基于扩散的方法Polyp-DDPM&#xff0c;该方法用于生成假面条件下息肉的逼真图像&#xff0c;旨在增强胃肠道息肉的分割。我们的方法解决了与医学图像相关的数据限制、高注释成本和隐私问题的挑…

网络编程(四)wireshark基本使用 TCP的三次握手和四次回挥手 TCP和UDP的比较

一、使用wireshark抓包分析协议头 &#xff08;一&#xff09;wireshark常用的过滤语句 tcp.port <想要查看的端口号> ip.src <想要查看的源IP地址> ip.dest <想要查看的目的IP地址> ip.addr <想要查看的IP地址>&#xff08;二&#xff09;抓包分…

DevEco鸿蒙开发请求网络交互设置

首先&#xff0c;在鸿蒙项目下config.json中找到module项&#xff0c;在里面填写"reqPermissions": [{"name": "ohos.permission.INTERNET"}] 在页面对应js文件内&#xff0c;填写import fetch from system.fetch;。 GET和POST区别 GET将表单数…

界面控件DevExpress v24.1全新发布 - 跨平台性进一步增强

DevExpress拥有.NET开发需要的所有平台控件&#xff0c;包含600多个UI控件、报表平台、DevExpress Dashboard eXpressApp 框架、适用于 Visual Studio的CodeRush等一系列辅助工具。屡获大奖的软件开发平台DevExpress 今年第一个重要版本v23.1正式发布&#xff0c;该版本拥有众多…

1. 基础设计流程(以时钟分频器的设计为例)

1. 准备工作 1. 写有vcs编译命令的run_vcs.csh的shell脚本 2. 装有timescale&#xff0c;设计文件以及仿真文件的flish.f&#xff08;filelist文件&#xff0c;用于VCS直接读取&#xff09; vcs -R -full64 -fsdb -f flist.f -l test.log 2. 写代码&#xff08;重点了解代码…

【Mac】DMG Canvas for mac(DMG镜像制作工具)软件介绍

软件介绍 DMG Canvas 是一款专门用于创建 macOS 磁盘映像文件&#xff08;DMG&#xff09;的软件。它的主要功能是让用户可以轻松地设计、定制和生成 macOS 上的安装器和磁盘映像文件&#xff0c;以下是它的一些主要特点和功能。 主要特点和功能 1. 用户界面设计 DMG Canva…