Pytorch intermediate(三) BiLSTM

news2025/1/8 21:38:57

Bi-directional Long Short-Term Memory,双向LSTM网络。

       有些时候预测可能需要由前面若干输入和后面若干输入共同决定,这样会更加准确。因此提出了双向循环神经网络,网络结构如上图。


       构建LSTM模型时,在参数中添加bidirectional=True,这样就构建了一个双向的LSTM模型。初始化参数时,全连接层的隐藏层特征数量x2,h0和c0参数也要相应改变。

import torch 
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

sequence_length = 28
input_size = 28
hidden_size = 128
num_layers = 2
num_classes = 10
batch_size = 100
num_epochs = 2
learning_rate = 0.003

# MNIST dataset
train_dataset = torchvision.datasets.MNIST(root='./data/',
                                           train=True, 
                                           transform=transforms.ToTensor(),
                                           download=True)

test_dataset = torchvision.datasets.MNIST(root='./data/',
                                          train=False, 
                                          transform=transforms.ToTensor())

# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size, 
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size, 
                                          shuffle=False)

class BiRNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(BiRNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(hidden_size*2, num_classes)  # 2 for bidirection
    
    def forward(self, x):
        # Set initial states
        h0 = torch.zeros(self.num_layers*2, x.size(0), self.hidden_size).to(device) # 2 for bidirection 
        c0 = torch.zeros(self.num_layers*2, x.size(0), self.hidden_size).to(device)
        
        # Forward propagate LSTM
        out, _ = self.lstm(x, (h0, c0))  # out: tensor of shape (batch_size, seq_length, hidden_size*2)
        
        # Decode the hidden state of the last time step
        out = self.fc(out[:, -1, :])
        return out

model = BiRNN(input_size, hidden_size, num_layers, num_classes).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

total_step = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = images.reshape(-1, sequence_length, input_size).to(device)
        labels = labels.to(device)
        
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if (i+1) % 100 == 0:
            print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' 
                   .format(epoch+1, num_epochs, i+1, total_step, loss.item()))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1001786.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

面试半个月后的一些想法

源于半个月面试经历后的一些想法,刚开始想的是随便写写,没想到居然写了这么多。 找不到目标找不到意义亦或是烦躁的时候,就写写文章吧,把那些困扰你很久的问题铺开来 花时间仔细想想,其实真正让我们生气懊恼&#xff0…

轻松搭建本地知识库的ChatGLM2-6B

近期发现了一个项目,它的前身是ChatGLM,在我之前的博客中有关于ChatGLM的部署过程,本项目在前者基础上进行了优化,可以基于当前主流的LLM模型和庞大的知识库,实现本地部署自己的ChatGPT,并可结合自己的知识…

华为OD机考算法题:分奖金

题目部分 题目分奖金难度难题目说明公司老板做了一笔大生意,想要给每位员工分配一些奖金,想通过游戏的方式来决定每个人分多少钱。按照员工的工号顺序,每个人随机抽取一个数字。按照工号的顺序往后排列,遇到第一个数字比自己数字…

我们这一代人的机会是什么?

大家好,我是苍何,今天作为专业嘉宾参观了 2023 年中国国际智能产业博览会(智博会),是一场以「智汇八方,博采众长」为主题的汇聚全球智能技术和产业创新的盛会,感触颇深,随着中国商业…

在Linux(Centos7)上编译whisper.cpp的详细教程

whisper.cpp的简单介绍: Whisper 是 OpenAI 推出的一个自动语音识别(ASR)系统,whisper.cpp 则是 Whisper 模型的 C/C 移植。whisper.cpp 具有无依赖项、内存使用量低等特点,支持 Mac、Windows、Linux、iOS 和 Android …

免费开源音乐聚合软件-洛雪音乐助手

一、软件介绍 洛雪音乐助手(LX Music),一个基于 Electron Vue 开发的免费开源音乐聚合软件,软件聚合多个音乐平台搜索接口,提供音乐在线播放和下载,而且免费无广告,支持在Windows、Mac OS、L…

CSP 202109-1 数值推导

答题 一眼看上去好像有点复杂&#xff0c;稍微想一下就知道&#xff0c;最大值就是把所有B加起来&#xff0c;最小值就是把不重复的B加起来&#xff0c;用set搞定不重复的元素 #include<iostream> #include<set> using namespace std; int main(){int n,max0,min0…

ARM接口编程—Interrupt(exynos 4412平台)

CPU与硬件的交互方式 轮询 CPU执行程序时不断地询问硬件是否需要其服务&#xff0c;若需要则给予其服务&#xff0c;若不需要一段时间后再次询问&#xff0c;周而复始中断 CPU执行程序时若硬件需要其服务&#xff0c;对应的硬件给CPU发送中断信号&#xff0c;CPU接收到中断信号…

SOC 2.0安全运营中心

SOC&#xff0c;安全运营中心&#xff0c;为取得其最佳效果&#xff0c;以及真正最小化网络风险&#xff0c;需要全员就位&#xff0c;让安全成为每个人的责任。 早在几年前&#xff0c;企业就开始创建SOC来集中化威胁与漏洞的监视和响应。第一代SOC的目标&#xff0c;是集中管…

【使用Cpolar和Qchan搭建自己的个人图床】

文章目录 前言1. Qchan网站搭建1.1 Qchan下载和安装1.2 Qchan网页测试1.3 cpolar的安装和注册 2. 本地网页发布2.1 Cpolar云端设置2.2 Cpolar本地设置 3. 公网访问测试总结 前言 图床作为云存储的一项重要应用场景&#xff0c;在大量开发人员的努力下&#xff0c;已经开发出大…

TSMC逻辑制程技术命名和发展

参考来源&#xff1a; 逻辑制程 - 台湾积体电路制造股份有限公司 (tsmc.com) 1、 TSMC逻辑制程发展路线图 2、从65nm到3nm制程技术发展的简介 2.1 65nm 制程 Y2005&#xff1a;成功试产65nm晶片&#xff1b; Y2006&#xff1a;成功通过65nm制程技术的产品验证&#xff0c;推…

asio中的锁

asio到底有没有锁 asio是有锁的&#xff0c;所以规避锁的写法还是值得研究的 windows中的锁 先来张截图&#xff1a; dispatch_mutex_主要是为了保护定时器队列和完成端口回调的队列。 保护定时器队列 保护完成端口回调的队列 在PostQueuedCompletionStatus失败时&#x…

React总结1

3 React技术 React是Facebook于2013年开源的框架。React解决的是前端MVC框架中的View视图层的问题。 3.1 Virtual DOM* DOM&#xff08;文档对象模型Document Object Model&#xff09; 将网页内所有内容映射到一棵树型结构的层级对象模型上&#xff0c;浏览器提供对DOM的支…

构建全面 AI Agent 解决方案:Chocolate Factory 框架的文本到 UI、图表和测试用例生成...

长太不看版&#xff1a;基于领域驱动设计思考的 AI Agent 框架 Chocolate Factory&#xff0c;框架现在还在 PoC 阶段&#xff0c;欢迎加入开发。&#xff08;当前主要关注于 SDLC AIGC 的场景&#xff09;。 GitHub&#xff1a;https://github.com/unit-mesh/chocolate-facto…

MySQL的架构和性能优化

一、架构 MySQL逻辑架构整体分为三层&#xff0c;最上层为客户端&#xff0c;并非MySQL独有&#xff0c;诸如&#xff1a;连接处理&#xff0c;授权认证&#xff0c;安全等功能均在这一层处理 MySQL大多数核心服务均在中间这一层&#xff0c;包括查询解析&#xff0c;分析优化…

【Mysql】数据库第三讲(表的约束、基本查询语句)

表的约束和基本查询 1.表的约束1.1 空属性1.2默认值1.3列描述1.4 zerofill1.5主键1.6 自增长1.7 唯一键1.8外键 1.表的约束 真正约束字段的是数据类型&#xff0c;但是数据类型约束很单一&#xff0c; 需要有一些额外的约束&#xff0c; 更好的保证数据的合法性&#xff0c;从…

【Flowable】FlowableUI使用以及在IDEA使用flowable插件(二)

前言 之前有需要使用到Flowable&#xff0c;鉴于网上的资料不是很多也不是很全也是捣鼓了半天&#xff0c;因此争取能在这里简单分享一下经验&#xff0c;帮助有需要的朋友&#xff0c;也非常欢迎大家指出不足的地方。 一、部署FlowableUI 1.准备war包 在这里提供了&#xf…

Java之Hashset的原理及解析

4.数据结构 4.1二叉树【理解】 二叉树的特点 二叉树中,任意一个节点的度要小于等于2 节点: 在树结构中,每一个元素称之为节点 度: 每一个节点的子节点数量称之为度 二叉树结构图 4.2二叉查找树【理解】 二叉查找树的特点 二叉查找树,又称二叉排序树或者二叉搜索树 每一…

【算法专题突破】双指针 - 最大连续1的个数 III(11)

目录 1. 题目解析 2. 算法原理 3. 代码编写 写在最后&#xff1a; 1. 题目解析 题目链接&#xff1a;1004. 最大连续1的个数 III - 力扣&#xff08;Leetcode&#xff09; 这道题不难理解&#xff0c;其实就是求出最长的连续是1的子数组&#xff0c; 但是&#xff0c;他支…