NNLM - 神经网络语言模型 | 高效的单词预测工具

news2025/1/21 6:30:25

本系列将持续更新NLP相关模型与方法,欢迎关注!

简介

神经网络语言模型(NNLM)是一种人工智能模型,用于学习预测词序列中下一个词的概率分布。它是自然语言处理(NLP)中的一个强大工具,在机器翻译、语音识别和文本生成等领域都有广泛的应用。

Paper - A Neural Probabilistic Language Model(2003)[1]

原理

NNLM 首先学习词的分布式表示,也称为词嵌入,它捕捉了词之间的语义相似性。然后将这些嵌入输入到神经网络模型中,通常是一个前馈神经网络或循环神经网络(RNN),该模型根据前面的词提供的上下文来学习预测序列中的下一个词。

alt

例如,给定句子“猫在坐在”,NNLM 可能会高概率地预测下一个词为“地板”,因为这是给定上下文的常见补充。

示例

假设我们有一个大型的文本语料库,比如一系列新闻文章。我们可以对这些数据进行 NNLM 训练,以学习单词和它们上下文之间的关系。训练完成后,模型可以生成连贯和与上下文相关的句子。

例如,如果我们提供初始短语“人工智能是”,NNLM 可能生成以下完成句子:“人工智能正在改变行业,重塑未来的工作。”

应用

  1. 机器翻译: NNLM 在机器翻译系统中发挥作用,通过预测源语言上下文的下一个词来生成流畅且准确的翻译。
  2. 语音识别: NNLM 在语音识别系统中起着至关重要的作用,通过从口语表达中预测最可能的词序列。
  3. 文本生成: NNLM 在各种文本生成任务中使用,包括对话生成、故事生成和内容摘要,在这些任务中,它们基于给定的输入生成连贯且与上下文相关的文本。
  4. 语言建模: NNLM 作为语言建模任务的基础,用于估计在给定上下文中序列单词发生的概率。这在拼写检查、自动完成和语法错误检测等任务中特别有用。

Code

# code by Tae Hwan Jung @graykode
import torch
import torch.nn as nn
import torch.optim as optim

def make_batch():
    input_batch = []
    target_batch = []

    for sen in sentences:
        word = sen.split() # space tokenizer
        input = [word_dict[n] for n in word[:-1]] # create (1~n-1) as input
        target = word_dict[word[-1]] # create (n) as target, We usually call this 'casual language model'

        input_batch.append(input)
        target_batch.append(target)

    return input_batch, target_batch

# Model
class NNLM(nn.Module):
    def __init__(self):
        super(NNLM, self).__init__()
        self.C = nn.Embedding(n_class, m)
        self.H = nn.Linear(n_step * m, n_hidden, bias=False)
        self.d = nn.Parameter(torch.ones(n_hidden))
        self.U = nn.Linear(n_hidden, n_class, bias=False)
        self.W = nn.Linear(n_step * m, n_class, bias=False)
        self.b = nn.Parameter(torch.ones(n_class))

    def forward(self, X):
        X = self.C(X) # X : [batch_size, n_step, m]
        X = X.view(-1, n_step * m) # [batch_size, n_step * m]
        tanh = torch.tanh(self.d + self.H(X)) # [batch_size, n_hidden]
        output = self.b + self.W(X) + self.U(tanh) # [batch_size, n_class]
        return output

if __name__ == '__main__':
    n_step = 2 # number of steps, n-1 in paper
    n_hidden = 2 # number of hidden size, h in paper
    m = 2 # embedding size, m in paper

    sentences = ["i like dog""i love coffee""i hate milk"]

    word_list = " ".join(sentences).split()
    word_list = list(set(word_list))
    word_dict = {w: i for i, w in enumerate(word_list)}
    number_dict = {i: w for i, w in enumerate(word_list)}
    n_class = len(word_dict)  # number of Vocabulary

    model = NNLM()

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    input_batch, target_batch = make_batch()
    input_batch = torch.LongTensor(input_batch)
    target_batch = torch.LongTensor(target_batch)

    # Training
    for epoch in range(5000):
        optimizer.zero_grad()
        output = model(input_batch)

        # output : [batch_size, n_class], target_batch : [batch_size]
        loss = criterion(output, target_batch)
        if (epoch + 1) % 1000 == 0:
            print('Epoch:''%04d' % (epoch + 1), 'cost =''{:.6f}'.format(loss))

        loss.backward()
        optimizer.step()

    # Predict
    predict = model(input_batch).data.max(1, keepdim=True)[1]

    # Test
    print([sen.split()[:2for sen in sentences], '->', [number_dict[n.item()] for n in predict.squeeze()])

总的来说,神经网络语言模型(NNLM)是自然语言处理中的强大工具,利用神经网络架构来预测文本序列中的下一个词。从机器翻译到文本生成,NNLM 继续推动人工智能在理解和生成人类语言方面的能力。

Reference
[1]

paper: http://www.jmlr.org/papers/volume3/bengio03a/bengio03a.pdf

本文由 mdnice 多平台发布

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1455018.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

从kafka如何保证数据一致性看通常数据一致性设计

一、前言 在数据库系统中有个概念叫事务,事务的作用是为了保证数据的一致性,意思是要么数据成功,要么数据失败,不存在数据操作了一半的情况,这就是数据的一致性。在很多系统或者组件中,很多场景都需要保证…

IO进程线程作业day1

1> 使用fgets统计给定文件的行数 #include <stdio.h> #include <stdlib.h> #include <string.h> #include <math.h> #include <unistd.h> int main(int argc, const char *argv[]) {//判断外部输入文件名是否规范if(argc!2){printf("in…

K8S之运用污点、容忍度设置Pod的调度约束

污点、容忍度 污点容忍度 taints 是键值数据&#xff0c;用在节点上&#xff0c;定义污点&#xff1b; tolerations 是键值数据&#xff0c;用在pod上&#xff0c;定义容忍度&#xff0c;能容忍哪些污点。 污点 污点是定义在k8s集群的节点上的键值属性数据&#xff0c;可以决…

ARMv8-AArch64 的异常处理模型详解之异常处理详解(进入异常以及异常路由)

在上篇文章 ARMv8-AArch64 的异常处理模型详解之异常处理概述Handling exceptions中&#xff0c;作者对异常处理整体流程以及相关概念做了梳理。接下来&#xff0c;本文将详细介绍处理器在获取异常、异常处理以及异常返回等过程中都做了哪些工作。 ARMv8-AArch64 的异常处理模型…

批量追踪中通快递

在物流信息的管理中&#xff0c;批量追踪中通快递单号一直是个让人头疼的问题。但有了固乔快递查询助手&#xff0c;这一切都变得轻而易举。 固乔快递查询助手&#xff0c;作为市场上备受好评的快递查询软件&#xff0c;专门针对批量查询需求进行了优化。用户只需将中通快递单号…

鸿蒙生态来了 ,60k 高薪向你招手

最近&#xff0c;各大平台都被华为鸿蒙不断刷屏。原因是在华为秋季发布会上&#xff0c;华为宣布启动鸿蒙原生应用&#xff0c;不再兼容安卓应用。一石激起千层浪&#xff0c;这无疑是IT界的一颗核弹&#xff0c;各大企业和开发者都纷纷开始加入“鸿蒙朋友圈”。 鸿蒙原生应用…

数据分析 — Matplotlib 、Pandas、Seaborn 绘图

目录 一、Matplotlib1、折线图2、柱状图3、水平条形图4、直方图5、散点图6、饼图 二、pandas1、折线图2、柱状图 三、seaborn1、散点图2、箱线图3、直方核密度图4、成对图 一、Matplotlib Matplotlib 是一个用于绘制数据可视化图形的 Python 库。它提供了丰富的绘图工具&#…

Eliminating Domain Bias for Federated Learning in Representation Space【文笔可参考】

文章及作者信息&#xff1a; NIPS2023 Jianqing Zhang 上海交通大学 之前中的NeurIPS23论文刚今天传到arxiv上&#xff0c;这次我把federated learning的每一轮看成是一次bi-directional knowledge transfer过程&#xff0c;提出了一种促进server和client之间bi-direction…

浅析Linux设备驱动:IO端口和IO内存

文章目录 概述IO端口和IO内存的区别 IO资源管理IO资源类型IO端口资源IO内存资源 IO资源分配 IO端口访问IO端口操作函数 IO内存访问IO内存操作函数 相关参考 概述 在计算机系统中&#xff0c;外部设备通常会提供一组寄存器或内存用于处理器配置和访问设备功能。这些寄存器或内存…

MCU电源控制(PWR)与低功耗

目录 一、STM32 的内核和外设电源系统管理&#xff1a; 二、MCU电源监控&#xff1a; 三、三种低功耗模式&#xff1a; 1、睡眠模式&#xff1a; 2、停止模式&#xff1a; 3、待机模式&#xff1a; 一、STM32 的内核和外设电源系统管理&#xff1a; ① 电池备份区域&#…

思迈特再获国家权威认证:代码自主率98.78%

日前&#xff0c;思迈特软件自主研发的商业智能与数据分析软件&#xff08;Smartbi Insight&#xff09;通过中国赛宝实验室&#xff08;工业和信息化部电子第五研究所&#xff09;代码扫描测试&#xff0c;Smartbi Insight V11版本扫描测得代码自主率为98.78%的好成绩&#xf…

原创详解OpenAI Sora是什么?技术先进在哪里?能够带来什么影响?附中英文技术文档

一&#xff1a;Sora是什么 Sora是一个文本到视频的模型&#xff0c;由美国的人工智能研究机构OpenAI开发。Sora可以根据描述性的文本提示&#xff0c;生成高质量的视频&#xff0c;也可以根据已有的视频&#xff0c;向前或向后延伸&#xff0c;生成更长的视频。 Sora的主要功…

【完全二叉树节点数!】【深度优先】【广度优先】Leetcode 222 完全二叉树的节点个数

【完全二叉树】【深度优先】【广度优先】Leetcode 222 完全二叉树的节点个数 :star:解法1 按照完全二叉树解法2 按照普通二叉树&#xff1a;深度优先遍历 后序 左右中解法3 按照普通二叉树&#xff1a;广度优先遍历 层序遍历 ---------------&#x1f388;&#x1f388;题目链接…

【漏洞复现-通达OA】通达OA swfupload_new存在前台SQL注入漏洞

一、漏洞简介 通达OA(Office Anywhere网络智能办公系统)是由北京通达信科科技有限公司自主研发的协同办公自动化软件,是与中国企业管理实践相结合形成的综合管理办公平台。通达OA为各行业不同规模的众多用户提供信息化管理能力,包括流程审批、行政办公、日常事务、数据统计…

~汉诺塔~(C语言)~

引言 汉诺塔&#xff08;Hanoi Tower&#xff09;&#xff0c;又称河内塔&#xff0c;源于印度一个古老传说。大梵天创造世界的时候做了三根金刚石柱子&#xff0c;在一根柱子上从下往上按照大小顺序摞着64片黄金圆盘。大梵天命令婆罗门把圆盘从上面开始按大小顺序重新摆放在…

[计算机网络]深度学习传输层TCP协议

&#x1f493; 博客主页&#xff1a;从零开始的-CodeNinja之路 ⏩ 收录专栏&#xff1a;深度学习传输层TCP协议 &#x1f389;欢迎大家点赞&#x1f44d;评论&#x1f4dd;收藏⭐文章 [计算机网络]深度学习传输层TCP协议 前提概括一: TCP协议段格式二:确认应答三:超时重传四:…

IgG1 (mouse), ELISA kit——ENZO热销产品

90分钟内可得结果的高特异性定量ELISA试剂盒 免疫球蛋白G&#xff08;IgG&#xff09;是一种免疫球蛋白单体&#xff0c;由两条&#xff08;γ&#xff09;重链和两条轻链组成。每个IgG分子包含两个抗原结合域和一个效应&#xff08;Fc&#xff09;域。Enzo Life Sciences可提供…

【hcie-cloud】【29】华为云Stack数据安全服务

文章目录 前言数据安全概述数据产业发展和敏感数据上云趋势下对数据安全的需求重大隐私数据泄露事件云端数据安全问题成为业务上云的主要障碍数据安全相关法律法规密集出台数据安全法 - 欧盟的GDPR中国的数据安全法端到端考虑数据安全数据安全生命周期华为云Stack全生命周期数据…

七、Mybatis缓存

缓存就是内存中的数据&#xff0c;常常来自对数据库查询结果的保存&#xff0c;使用缓存、可以避免频繁的与数据库进行交互&#xff0c;进而提高响应速度一级缓存是sqlSession级别的缓存&#xff0c;在操作数据库时需要构造sqlsession对象&#xff0c;在对象中有一个数据结构&a…

WEB APIs(2)

应用定时器可以写一个定时轮播图&#xff0c;如下 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta http-equiv"X-UA-Compatible" content"IEedge"><meta name"viewport&qu…