【神经网络与深度学习】Long short-term memory网络(LSTM)

news2024/12/24 21:40:36

简单介绍

在这里插入图片描述
API介绍:

nn.LSTM(input_size=100, hidden_size=10, num_layers=1,batch_first=True, bidirectional=True)

inuput_size: embedding_dim
hidden_size: 每一层LSTM单元的数量
num_layers: RNN中LSTM的层数
batch_first: True对应[batch_size, seq_len, embedding_dim]
bidiectional: True对应使用双向LSTM

在这里插入图片描述
实例化LSTM对象后,不仅要传入数据,还有传入前一次的h_0和c_0
lstm(input, (h_0, c_0))
LSTM默认输出(output, (h_n, c_n))
output: [ seq_len, batch, hidden_size*num_directions ] (若batch_first=false)
h_n: [num_directions, batch, hidden_size]
c_n : [num_directions, batch, hidden_size]

import torch.nn as nn
import torch.nn.functional as F
import torch

batch_size = 10
seq_len =20 #句子长度
vocab_size = 100 # 词典数量
embedding_dim = 30 # 用embedding_dim长度的向量表示一个词语
hidden_size = 18

input = torch.randint(0, 100, [batch_size, seq_len])
print(input.size())
print("*"*100)
# 经过embedding
embed = nn.Embedding(vocab_size, embedding_dim)

input_embed = embed(input)  # [bs, seq_len, embedding_dim]
print(input_embed.size())
print("*"*100)
lstm = nn.LSTM(embedding_dim, hidden_size=hidden_size, num_layers=1, batch_first=True)
output,(h_n, c_n) = lstm(input_embed)
print(output.size())
print("*"*100)
print(h_n.size())
print("*"*100)
print(c_n.size())

通常由最后一个输出代替整个句子

使用双向LSTM实现

"""
定义模型
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from lib import ws,max_len
from dataset import get_data
import lib
import os
import numpy as np
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.embedding = nn.Embedding(len(ws), 100)
        self.lstm = nn.LSTM(input_size=100, hidden_size=lib.hidden_size, num_layers=lib.num_layers,batch_first=True, bidirectional=lib.bidirectional, dropout=lib.dropout)
        self.fc = nn.Linear(lib.hidden_size*2, 2)
        



    def forward(self, input):
        """

        :param input: [batch_size, max_len]
        :return:
        """
        x = self.embedding(input) # [batch_size, max_len, 100]
        x,(h_n,c_n)= self.lstm(x)
        output = torch.cat([h_n[-2,:,:],h_n[-1,:,:]],dim=-1)
        output = self.fc(output)
        return F.log_softmax(output,dim=-1)

model = MyModel().to(lib.device)
optimizer = torch.optim.Adam(model.parameters(),lr=0.001)
if os.path.exists("./model0/model.pkl"):
    model.load_state_dict(torch.load("./model0/model.pkl"))
    optimizer.load_state_dict(torch.load("./model0/optimizer.pkl"))

def train(epoch):
    for idx,(input,target) in enumerate(get_data(train=True)):
        input = input.to(lib.device)
        target = target.to(lib.device)
        # 梯度清零
        optimizer.zero_grad()
        output= model(input)
        loss = F.nll_loss(output,target)
        loss.backward()
        optimizer.step()
        print(epoch, idx, loss.item())

        if idx%100==0:
            torch.save(model.state_dict(),"./model0/model.pkl")
            torch.save(optimizer.state_dict(),"./model0/optimizer.pkl")

def eval():
    loss_list = []
    acc_list = []
    for idx,(input,target) in enumerate(get_data(train=False, batch_size=lib.test_batch_size)):
        input = input.to(lib.device)
        target = target.to(lib.device)
        with torch.no_grad():
            output= model(input)
            loss = F.nll_loss(output,target)
            loss_list.append(loss.cpu().item())
            pre = output.max(dim=-1)[-1]
            acc = pre.eq(target).float().mean()
            acc_list.append(acc.cpu().item())

    print("total loss, acc:", np.mean(loss_list), np.mean(acc_list))




if __name__ == '__main__':
    for i in range(10):
        train(epoch=i)

    eval()






本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1597549.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

nginx-ingress详解

一、ingress概述 1、概述 Kubernetes是一个拥有强大故障恢复功能的集群,当pod挂掉时,集群会重新创建一个pod出来,但是pod的IP也会随之发生变化,为了应对这种情况,引入了service,通过service的标签匹配&am…

Python Flask Web 框架-API接口开发_4

一、1、安装 Falsk 当前用户安装 pip3 install --user Flask 确认安装成功: 进入python交互模式看下Flask的介绍和版本: $ python3>>> import flask >>> print(flask.__doc__)flask~~~~~A microframework based on Werkzeug. Its …

【Leetcode】代码随想录Day16|二叉树3.0

文章目录 104 二叉树的最大深度559 n叉树的最大深度111 二叉树的最小深度222 完全二叉树的节点个数 104 二叉树的最大深度 递归法:无论是哪一种顺序,标记最大深度 class Solution(object):def depthHelper(self, root, depth):if root:depth 1left_de…

GPT 交互式提示工程

简介:交互式提示工程 人工智能领域,尤其是 GPT(生成式预训练变压器)等工具,凸显了即时工程的关键作用。 这篇扩展文章深入探讨了如何设计有效的提示,以从 GPT 等 AI 模型中获得出色的响应。 了解即时工程即…

尚硅谷html5+css3(4)浮动

1.浮动的概念 <head><style>.box1 {width: 200px;height: 200px;background-color: orange;/*通过浮动可以使一个元素向其父元素的左侧或右侧移动使用float属性设置子资源的浮动可选值&#xff1a;none默认值&#xff0c;元素不浮动left向左浮动right向右浮动注意…

分布式监控平台---Zabbix

一、Zabbix概述 作为一个运维&#xff0c;需要会使用监控系统查看服务器状态以及网站流量指标&#xff0c;利用监控系统的数据去了解上线发布的结果&#xff0c;和网站的健康状态。 利用一个优秀的监控软件&#xff0c;我们可以&#xff1a; 通过一个友好的界面进行浏览整个…

SETR——Rethinking系列工作,展示使用纯transformer在语义分割任务上是可行的,但需要很强的训练技巧

题目:Rethinking Semantic Segmentation from a Sequence-to-Sequence Perspective with Transformers 作者: 开源:https://fudan-zvg.github.io/SETR 1.研究背景 1.1 为什么要研究这个问题? 自[ 36 ]的开创性工作以来,现有的语义分割模型主要是**基于全卷积网络( FCN )的…

【机器学习300问】69、为什么深层神经网络比浅层要好用?

要回答这个问题&#xff0c;首先得知道神经网络都在计算些什么东西&#xff1f;之前我在迁移学习的文章中稍有提到&#xff0c;跳转链接在下面&#xff1a; 为什么其他任务预训练的模型参数&#xff0c;可以在我这个任务上起作用&#xff1f;http://t.csdnimg.cn/FVAV8 …

go work模块与go mod包管理是的注意事项

如下图所示目录结构 cmd中是服务的包&#xff0c;显然auth,dbtables,pkg都是为cmd服务的。 首先需要需要将auth,dbtables,pkg定义到go.work中&#xff0c;如下&#xff1a; 在这样在各个单独的go mod管理的模块就可以互相调用了。一般情况下这些都是IDE自动进行的&#xff0c;…

js纯前端实现语音播报,朗读功能(2024-04-15)

实现语音播报要有两个原生API 分别是【window.speechSynthesis】【SpeechSynthesisUtterance】 项目代码 // 执行函数 initVoice({text: 项目介绍,vol: 1,rate: 1 })// 函数 export function initVoice(config) {window.speechSynthesis.cancel();//播报前建议调用取消的函数…

CSS使用自己的字体

在项目的根目录下的static文件夹中放置字体文件。在项目中使用这个字体&#xff0c;需要2个步骤。 一. 你需要在全局样式文件中引入它。 假设你的全局样式文件是App.vue或者App.vue中引入的App.scss文件&#xff0c;你可以像这样引入字体文件&#xff1a; font-face {font-fa…

自然语言控制机械臂:ChatGPT与机器人技术的融合创新(下)

引言 在我们的上一篇文章中&#xff0c;我们探索了如何将ChatGPT集成到myCobot 280机械臂中&#xff0c;实现了一个通过自然语言控制机械臂的系统。我们详细介绍了项目的动机、使用的关键技术如ChatGPT和Google的Speech-to-text服务&#xff0c;以及我们是如何通过pymycobot模块…

C++面向对象程序设计 - 类和对象进一步讨论

在C中&#xff0c;关于面向对象程序设计已经讲了很大篇幅&#xff0c;也例举很多案例&#xff0c;此篇将通过一些习题来进一步了解对象、静态成员、指针、引用、友元、类模板等等相关知识。 一、习题一&#xff08;构造函数默认参数&#xff09; 示例代码&#xff1a; #includ…

主流App UI设计,7个在线平台供您选择!

在数字时代&#xff0c;用户界面&#xff08;UI&#xff09;设计变得非常重要&#xff0c;因为良好的UI设计可以改善用户体验&#xff0c;增强产品吸引力。随着技术的发展&#xff0c;越来越多的应用程序 ui在线设计网站的出现为设计师和团队提供了一种全新的创作方式。本文将盘…

关于Wordpress的操作问题1:如何点击菜单跳转新窗口

1.如果打开&#xff0c;外观-菜单-菜单结构内&#xff0c;没有打开新窗口属性&#xff0c;如图&#xff1a; 2.在页面的最上部&#xff0c;点开【显示选项】&#xff0c;没有这一步&#xff0c;不会出现新跳转窗口属性 3.回到菜单结构部分&#xff0c;就出现了

Mysql嵌套查询太简单了

1、子查询的分类 不相关查询&#xff1a; 子查询能独立执行 相关查询&#xff1a; 子查询不能独立运行 相关查询的执行顺序&#xff1a; 首先取外层查询中表的第一个元组,根据它与内层查询相关的属性值处理内层查询, 若WHERE子句返回值为真&#xff0c;则取此元组放入结果…

CSS基础:margin属性4种值类型,4个写法规则详解

你好&#xff0c;我是云桃桃。 一个希望帮助更多朋友快速入门 WEB 前端的程序媛。大专生&#xff0c;一枚程序媛&#xff0c;感谢关注。回复 “前端基础题”&#xff0c;可免费获得前端基础 100 题汇总&#xff0c;回复 “前端工具”&#xff0c;可获取 Web 开发工具合集 268篇…

PaddleOCR训练自己模型(1)----数据准备

一、下载地址&#xff1a; PaddleOCR开源代码&#xff08;下载的是2.6RC版本的&#xff0c;可以根据自己需求下载&#xff09; 具体环境安装就不详细介绍了&#xff0c; 挺简单的&#xff0c;也挺多教程的。 二、数据集准备及制作 &#xff08;1&#xff09;下载完代码及配置…

Navicat for MySQL 使用基础与 SQL 语言的DDL

一、目的&#xff1a; Navicat for MySQL 是一套专为 MySQL 设计的高性能数据库管理及开发 工具。它可以用于任何版本 3.21 或以上的 MySQL 数据库服务器&#xff0c;并支持大 部份 MySQL 最新版本的功能&#xff0c;包括触发器、存储过程、函数、事件、视图、 管理用户等。…

软件工程及开发模型

根据希赛相关视频课程汇总整理而成&#xff0c;个人笔记&#xff0c;仅供参考。 软件工程的基本要素包括方法、工具和&#xff08;过程&#xff09; 方法&#xff1a;完成软件开发的各项任务的技术方法&#xff1b; 工具&#xff1a;运用方法而提供的软件工程支撑环境&#xff…