人工智能(pytorch)搭建模型9-pytorch搭建一个ELMo模型,实现训练过程

news2024/11/24 11:38:33

大家好,我是微学AI,今天给大家介绍一下人工智能(pytorch)搭建模型9-pytorch搭建一个ELMo模型,实现训练过程,本文将介绍如何使用PyTorch搭建ELMo模型,包括ELMo模型的原理、数据样例、模型训练、损失值和准确率的打印以及预测。文章将提供完整的代码实现。

目录

  1. ELMo模型简介
  2. 数据准备
  3. 搭建ELMo模型
  4. 训练模型
  5. 预测
  6. 总结

1. ELMo模型简介

ELMo(Embeddings from Language Models)是一种基于深度双向LSTM(Long Short-Term Memory)的预训练语言模型。ELMo的主要特点是能够生成上下文相关的词向量,这意味着同一个词在不同的上下文中可以有不同的词向量表示。这种表示能够捕捉到词汇的多义性,从而提高自然语言处理任务的性能。

ELMo的数学原理:

ELMo模型由两个组件组成:一个双向语言模型和一个线性组合层,如下所示:

ELMo ( t ) = E ( x t θ L M ) = γ ( ∑ k = 0 K − 1 s k ⋅ h t , k ) \text{ELMo} (t) = E \left( x_t\theta^{LM} \right) = \gamma \left( \sum{k=0}^{K-1} s_k \cdot h_{t,k} \right) ELMo(t)=E(xtθLM)=γ(k=0K1skht,k)

其中, x t x_t xt 是输入的词向量, t t t 表示词汇表中第 t t t 个词汇;
h t , k h_{t,k} ht,k 是 BiLM 的第 k k k 层的输出,它是一个大小为 2 d 2d 2d 的向量,其中 d d d 是隐藏层的维度,由于 BiLM 是双向的,因此每个词汇的表示将由其左侧和右侧的隐藏层状态组成;
s k s_k sk 是一个可训练的标量权重,用于加权 BiLM 的不同层的表示, K K K 是 BiLM 的层数;
γ \gamma γ 是一个可训练的标量参数,用于调整线性组合的规模。
对于每个词汇 t t t,ELMo 模型将词汇的表示 ELMo t \text{ELMo}_t ELMot 定义为 BiLM 的不同层的加权和。这种方法使得每个词汇的表示都是上下文相关的,而不是固定的。

在训练过程中,ELMo 模型使用了两个损失函数:一种是正向语言模型的损失函数,另一种是反向语言模型的损失函数。这些损失函数的目标是最小化模型在单个词汇和上下文中预测下一个词汇的错误率。在训练完成后,ELMo 模型中的参数被用来计算每个词汇的上下文相关表示。

ELMo 模型的数学原理包括双向语言模型和线性组合层,其中双向语言模型使用了两个损失函数来学习上下文相关的词向量表示。
在这里插入图片描述

2. 数据准备

我们将使用一个简单的文本数据集来演示ELMo模型的训练和预测。数据集包含以下句子:

I have a cat.
She likes to play with her toys.
My cat is very cute.

首先,我们需要对数据进行预处理,包括分词、构建词汇表和生成训练数据。

import torch
from torch.utils.data import Dataset, DataLoader
from collections import Counter
import numpy as np

# 分词
def tokenize(text):
    return text.lower().split()

# 构建词汇表
def build_vocab(tokenized_text):
    word_counts = Counter(tokenized_text)
    vocab = {word: idx for idx, (word, _) in enumerate(word_counts.most_common())}
    return vocab

# 生成训练数据
class TextDataset(Dataset):
    def __init__(self, text, vocab):
        self.text = text
        self.vocab = vocab

    def __len__(self):
        return len(self.text)

    def __getitem__(self, idx):
        return self.text[idx], self.vocab[self.text[idx]]

text = "I have a cat. She likes to play with her toys. My cat is very cute."
tokenized_text = tokenize(text)
vocab = build_vocab(tokenized_text)
dataset = TextDataset(tokenized_text, vocab)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

3. 搭建ELMo模型

接下来,我们将使用PyTorch搭建ELMo模型。模型包括一个词嵌入层、一个双向LSTM层和一个线性输出层。

import torch.nn as nn

class ELMo(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers):
        super(ELMo, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers, bidirectional=True)
        self.linear = nn.Linear(hidden_dim * 2, vocab_size)

    def forward(self, x):
        x = self.embedding(x)
        x, _ = self.lstm(x)
        x = self.linear(x)
        return x

vocab_size = len(vocab)
embedding_dim = 100
hidden_dim = 128
num_layers = 2
model = ELMo(vocab_size, embedding_dim, hidden_dim, num_layers)

4. 训练模型

现在我们可以开始训练模型。我们将使用交叉熵损失函数和Adam优化器。

import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

num_epochs = 20
for epoch in range(num_epochs):
    for batch in dataloader:
        _, inputs = batch
        inputs = torch.tensor(inputs).long()  # 将输入数据转换为张量
        targets = torch.tensor(inputs).long()  # 将目标数据转换为张量
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs.view(-1, vocab_size), targets.view(-1))
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item()}")

5. 预测

训练完成后,我们可以使用模型进行预测。这里我们将预测一个简单的句子:“My cat likes to play.”

def predict(model, sentence, vocab):
    tokenized_sentence = tokenize(sentence)
    input_ids = [vocab[word] for word in tokenized_sentence]
    inputs = torch.tensor(input_ids).unsqueeze(1)
    outputs = model(inputs)
    predictions = torch.argmax(outputs, dim=-1)
    pred = [tokenized_text[x] for x in list(predictions.numpy().reshape(-1))]

    return [word for word, _ in vocab.items() if word in pred ]

sentence = "My cat likes to play"
predictions = predict(model, sentence, vocab)
print("Predictions:", predictions)

6. 总结

这篇文章主要介绍了如何使用PyTorch搭建ELMo模型,包括模型的原理、数据准备、模型搭建、训练和预测。我们提供了完整的代码实现,确保代码可运行且无错误。希望本文能帮助您理解ELMo模型并在自己的项目中应用,更多模型的运用技巧请持续关注。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/605891.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

labelimg闪退解决方法(之前使用过labelimg,但新一次使用,打开文件夹无反应,再次打开闪退的问题)及标注经验

问题描述: 之前使用过labelimg进行好多次的标注,但新一次运行使用,发现打开目录无反应,再次打开闪退的问题,重启电脑并且从新运行labelimg仍然无效。 解决方法: 关闭labelimg,然后删除文件C…

一文纵览Umi‘s Friends生态,GameFi浪潮的变革者

以“P2E”为特性的 GameFi,代表着游戏时代的新盈利模式,它将 NFT 或其他形式的代币化资产作为游戏内容,游戏内资产的寿命会,则随着这些资产继续存在于玩家的钱包中而延长(即便游戏关闭),资产的互…

class文件中,常量池、方法表、属性表,异常表等等相关数据解析!小白就跟我一起对照学【class字节码文件分析】

前言:前段时间读《深入java虚拟机》介绍到class文件的时候,由于理论知识较多,人总感觉疲惫不堪,就泛泛阅读了一下。在工作中使用起来知识点知道,但是总是需要查阅各种资料。今天有时间,继续整理常量池后面的…

session与cookie

session是一种会话机制。当客户端发送登录请求时,服务端会生成一个sessionId存储在cookie中返回给客户端,客户端通过响应数据中的set-cookie字段来获取cookie并保存。如果客户端再向同一网站发送请求时,会自动携带cookie,相当于一…

离散数学_十章-图 ( 5 ):连通性 - 下

📷10.5 图的连通性 4. 有向图的连通性4.1 强连通4.2 弱连通4.3 (有向图的)强连通分支 5. 通路与同构6. 顶点间通路个数的计算 4. 有向图的连通性 根据是否考虑边的方向,在有向图中有两种连通性概念: 4.1 强连通 强连…

C/C++线程绑核详解

在一些大型的工程或者特殊场景中,我们会听到绑核,绑核分为进程绑核和线程绑核。绑核的最终目的都是为了提高程序和性能或者可靠性。 一:为什么需要绑核 操作系统发展至今,已经能很好的平衡运行在操作系统上层的应用,兼…

16.3:岛屿数量问题2

岛屿数量问题2 https://leetcode.cn/problems/number-of-islands-ii/ 给你一个大小为 m x n 的二进制网格 grid 。网格表示一个地图,其中,0 表示水,1 表示陆地。最初,grid 中的所有单元格都是水单元格(即&#xff0c…

Dubbo源码解析一网络通信原理

Dubbo 网络通信原理 1. Dubbo高可用集群1.1 服务集群的概述1.1.1 服务集群的概述1.1.2 调用过程1.1.3 组件介绍 1.2 集群容错机制1.2.1 内置集群容错策略1.2.1.1 Failover(失败自动切换)1.2.1.2 Failsafe(失败安全)1.2.1.3 Failfast(快速失败)1.2.1.4 Failback(失败自动恢复)1.…

卡尔曼滤波 | Matlab实现利用卡尔曼滤波器估计电池充电状态(Kalman Filtering)

文章目录 效果一览文章概述研究内容程序设计参考资料效果一览 文章概述 卡尔曼滤波 | Matlab实现利用卡尔曼滤波器估计电池充电状态(Kalman Filtering) 研究内容

gyp verb `which` failed Error: not found: python2

安装node-sass居然需要python2,7环境,不能python3 我只能重新降版本: python2.7:https://www.python.org/ftp/python/2.7/python-2.7.amd64.msi npm ERR! code 1 npm ERR! path F:\idea2021work\music01 初始化\music-client\node_modules\node-sass np…

自然语言处理从入门到应用——自然语言处理的基础任务:词性标注(POS Tagging)和句法分析(Syntactic Parsing)

分类目录:《自然语言处理从入门到应用》总目录 词性标注 词性是词语在句子中扮演的语法角色,也被称为词类(Part-Of-Speech,POS)。例如,表示抽象或具体事物名字(如“计算机”)的词被…

【遗传算法简介】

遗传算法:原理与实战 简介 遗传算法是一种模拟达尔文生物进化论的自然选择以及遗传学机制的搜索算法,由 John Holland 在20世纪70年代提出。它们在各种搜索、优化和机器学习任务中已被广泛应用。 遗传算法原理 1. 编码 遗传算法的第一步是将问题的可…

Andriod开发 Room 数据库处理框架

1.Room框架 Room是Android Jetpack组件库中的一部分,它是一个SQLite数据库的抽象层,提供了更简单的API和更好的性能,适合于中大型应用程序。 2.Room的使用 使用Room和之前使用SQLite搭建数据库的过程类似,但是更加简单了。 1&…

JAVA网络编程(一)

一、什么是网络编程 定义:在网络通信协议下,不同计算机上运行的程序,进行的数据传输。 应用场景:即时通信,网游,邮件等 不管什么场景,都是计算机与计算机之间通过网络在进行数据传输 java提供一…

软件测试必会:cookie、session和token的区别

今天就来说说session、cookie、token这三者之间的关系!最近这仨玩意搞得头有点大🤣 01、为什么会有它们三个 我们都知道 HTTP 协议是无状态的,所谓的无状态就是客户端每次想要与服务端通信,都必须重新与服务端链接,意…

穿针引线之 AsyncLocalStorage

在 Node.js 中,如何更优雅地获取请求上下文一直是一个问题,看一下下面的例子。 背景 const http require(http); function handler1(req, res) {console.log(req.url); }function handler2(req, res) {console.log(req.url); }http.createServer((req…

【react全家桶】react-Hook (下)

本人大二学生一枚&#xff0c;热爱前端&#xff0c;欢迎来交流学习哦&#xff0c;一起来学习吧。 <专栏推荐> &#x1f525;&#xff1a;js专栏 &#x1f525;&#xff1a;vue专栏 &#x1f525;&#xff1a;react专栏 文章目录 15【react-Hook &#xff08;下&#x…

进程控制(Linux)

进程控制 fork 在Linux中&#xff0c;fork函数是非常重要的函数&#xff0c;它从已存在进程中创建一个新进程。新进程为子进程&#xff0c;而原进程为父进程。 返回值&#xff1a; 在子进程中返回0&#xff0c;父进程中返回子进程的PID&#xff0c;子进程创建失败返回-1。 …

Spring - BeanFactory与ApplicationContext介绍

文章目录 Spring Bean一、BeanFactory 快速入门1.1 BeanFactory 开发步骤1.2 DI依赖注入 二、ApplicationContext快速入门2.1 入门2.2 BeanFactory 与 ApplicationContext关系2.3 BeanFactory 继承体系2.4 ApplicationContext 继承体系 Spring Bean 之前也了解过Spring Bean&a…

高斯过程回归 | Matlab实现高斯过程回归多输入单输出预测(Gaussian Process Regression)

文章目录 效果一览文章概述研究内容程序设计参考资料效果一览 文章概述 高斯过程回归 | Matlab实现高斯过程回归多输入单输出预测(Gaussian Process Regression) 研究内容 高斯过程回归(Gaussian Process Regression,GPR)是一种基于概率模型的非参数回归方法,可以用于