循环神经网络RNN用于分类任务

news2025/1/15 6:29:05

RNN是一类拥有隐藏状态,允许以前的输出可用于当前输入的神经网络,

 输入一个序列,对于序列中的每个元素与前一个元素的隐藏状态一起作为RNN的输入,通过计算当前的输出和隐藏状态。当前的影藏状态作为下一个单元的输入...

 

 RNN的种类

上图中的红色方块代表输入,蓝色方块代表输出,绿色方块代表

一对一(one to one)

一对一的结构一般指经典的卷积神经网络和全连接神经网络,被誉为香草神经网络,它们处理固定大小的输入,并输出固定大小的输出,主要用于图像分类任务。

下面介绍循环神经网络RNN的几种结构和其应用:

多对一(many to one)

多对一结构用于分类,如情感分类(输入一个句子,判断是积极还是消极的情感)

 一对多(one to many)

一对多可用于生成,如音乐生成,还可以用于图像字幕(输入一张图片,输出一个文本句子)

多对多(many to many)

多对多结构可用于机器翻译(输入一个英文句子,输出一个中文句子)和视频分类(视频中的多个帧作为输入,对视频中的每一帧进行分类)

RNN多对一用于图像分类

下面介绍使用RNN进行分类(图像分类)的方法。RNN用于分类任务的结构是多对一结构。假设图像的尺寸是28x28的黑白图片,将每一个张图片作为RNN的输入,每一行作为一个输入特征,所有的行构成一个序列。

RNN的输出只选择最后一行的输出(使用out[:, -1, :]),这就是多对一的精髓。最后再套一个全连接层,使用交叉熵作为判断准则。

RNN图像分类的示例代码

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyper-parameters 
# input_size = 784  # 28x28
hidden_size = 128
num_classes = 10
num_epochs = 2
batch_size = 100
learning_rate = 0.001

input_size = 28
sequence_length = 28
num_layers = 2

# MNIST dataset 
train_dataset = torchvision.datasets.MNIST(root='./data', 
                                           train=True, 
                                           transform=transforms.ToTensor(),  
                                           download=True)

test_dataset = torchvision.datasets.MNIST(root='./data', 
                                          train=False, 
                                          transform=transforms.ToTensor())

# Data loader
train_loader = DataLoader(dataset=train_dataset,
                                           batch_size=batch_size, 
                                           shuffle=True)

test_loader = DataLoader(dataset=test_dataset,
                                          batch_size=batch_size, 
                                          shuffle=False)

examples = iter(test_loader)
example_data, example_targets = next(examples)

# for i in range(6):
#     plt.subplot(2,3,i+1)
#     plt.imshow(example_data[i][0], cmap='gray')
# plt.show()

# Fully connected neural network with one hidden layer
class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(RNN, self).__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
        #  x: (batch_size, sequence_len, input_size)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)

        out, _ = self.rnn(x, h0)
        # out: batch_size, seq_length, hidden_size
        # out: (N, 28, 128)
        out = out[:, -1, :]
        # out (N, 128)
        out = self.fc(out)
        return out

model = RNN(input_size, hidden_size, num_layers,num_classes).to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)  

# Train the model
n_total_steps = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):  
        # origin shape: [100, 1, 28, 28]
        # resized: [100, 28, 28]
        images = images.reshape(-1, sequence_length, input_size).to(device)
        labels = labels.to(device)
        
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if (i+1) % 100 == 0:
            print (f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{n_total_steps}], Loss: {loss.item():.4f}')

# Test the model
# In test phase, we don't need to compute gradients (for memory efficiency)
with torch.no_grad():
    n_correct = 0
    n_samples = 0
    for images, labels in test_loader:
        images = images.reshape(-1, sequence_length, input_size).to(device)
        labels = labels.to(device)
        outputs = model(images)
        # max returns (value ,index)
        _, predicted = torch.max(outputs.data, 1)
        n_samples += labels.size(0)
        n_correct += (predicted == labels).sum().item()

    acc = 100.0 * n_correct / n_samples
    print(f'Accuracy of the network on the 10000 test images: {acc} %')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/654664.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

AMEYA:如何设计好DC-DC电源,注意事项有哪些

DC-DC变换器(DC-DC converter)是指在直流电路中将一个电压值的电能变为另一个电压值的电能的装置。DC-DC的layout非常重要,会直接影响到产品的稳定性与EMI效果。 DC-DC电源几点经验以及规则 1、处理好反馈环,反馈线不要走肖特基下…

基于JavaWeb的体育赛事平台的设计与实现

摘要 体育是随着社会生产力的发展而产生和发展的,在其漫长的历史中,由于社会、政治和经济发展的影响,其内容、形式、功能和操作方法不断变化。奥运会和世界杯等大型体育赛事代表着体育发展的顶峰,因为它们不仅给组织者带来了巨大…

【考研复习】李春葆新编C语言习题与解析(错误答案订正)持续更新

新编C语言习题与解析 做习题时发现有些错误答案,写篇博客进行改正记录。不对地方欢迎指正~ 第二章 C. 其中b的表达形式错误,若加上0x1e2b则正确。所以C错误。 D. e后为整数。指数命名规则:e前有数,后有整数。所以D错…

实验篇(7.2) 15. 站对站安全隧道 - 多条隧道聚合(FortiGate-IPsec) ❀ 远程访问

【简介】虽然隧道冗余可以解决连接问题,但是当大量数据访问或要求访问不能中断时,隧道冗余就力不从心了。这种情况就要用到隧道聚合。但是对宽带的要求也高了,双端都至少需要二条宽带。 实验要求与环境 OldMei集团深圳总部部署了域服务器和ER…

C语言复合类型之结构(struct)篇(结构指针)

结构相关知识总结 什么是结构?结构的声明与简单使用结构的初始化结构中成员变量的访问结构的初始化器结构数组结构数组的声明结构数组的成员标识 结构的嵌套结构指针结构作为参数在函数中传递将结构成员作为参数进行传递将结构地址(指向结构的指针)作为参数进行传递…

AI数字人之语音驱动人脸模型Wav2Lip

1 Wav2Lip模型介绍 2020年,来自印度海德拉巴大学和英国巴斯大学的团队,在ACM MM2020发表了的一篇论文《A Lip Sync Expert Is All You Need for Speech to Lip Generation In The Wild 》,在文章中,他们提出一个叫做Wav2Lip的AI模…

面试题:完败的面试,被虐得体无完肤

经过上一轮的面试,我信心一下子就建立起来了,说巧不巧,前几周正好看到美团校招,想着试一下也不会怎样,就找了学长要了内推码,试着投递了一下,然后就通知周六参加笔试,结果惨不忍睹。…

flv 报错 Unsupported codec in video frame: 12

视频播放器播放 flv 报错 [TransmuxingController] > DemuxException: type CodecUnsupported, info Flv: Unsupported codec in video frame: 12 原因 主要是因为我们的播放器不支持 H.265 视频编码; 解决办法 方法一:将设备端的视频编码改为 …

FPGA实现USB3.0 UVC 相机HDMI视频输出 基于FT602驱动 提供工程源码和QT上位机源码

目录 1、前言2、UVC简介3、FT602芯片解读4、我这儿的 FT601 USB3.0通信方案5、详细设计方案基于FT602的UVC模块详解 6、vivado工程详解7、上板调试验证8、福利:工程代码的获取 1、前言 目前USB3.0的实现方案很多,但就简单好用的角度而言,FT6…

基于多层感知机MLP的数据预测与误差分析的完整matlab代码分享

多层感知机(MLP,Multilayer Perceptron)也叫人工神经网络(ANN,Artificial Neural Network),除了输入输出层,它中间可以有多个隐层,最简单的MLP只含一个隐层,即三层的结构。多层感知器(multilayer Perceptron,MLP)是指可以是感知器的人工神经元组成的多个层次。MPL的…

在Windows和Linux系统上,用C语言实现命令行下输入密码回显星号和完全隐藏密码

本篇目录 引子在Windows 上实现在Linux上实现回显星号代码解读运行 完全隐藏运行 引子 在Windows系统上,当我们使用命令行和MySQL进行交互时,第一步就是要输入密码: -p后面的参数紧跟着的就是相应用户的密码。然而这种方式并不安全&#xff…

【数学建模】2019 年全国大学生数学建模竞赛C题全国一等奖获奖论文

2021 年高教社杯全国大学生数学建模竞赛题目 机场的出粗车问题 大多数乘客下飞机后要去市区(或周边)的目的地,出租车是主要的交通工具之一。国内多数机场都是将送客(出发)与接客(到达)通道分开…

2. windows系统下在QT中配置OPenCV开发环境

1. 说明: 在Windows系统中配置相对简单,不需要对下载的源码进行编译,在官网上下载的OPenCV可以直接使用,本文系统版本为win10,opencv是最新版本4.7.0。 效果展示: 2. 配置步骤: 2.1 下载OPenCV压缩包 打开opencv的官网OPenCV下载地址,可以在其页面内下载到最新的压…

iPhone手机UDID获取方法

UDID:iOS设备的唯一识别码,每台iOS设备都有一个独一无二的编码,这个编码,就称为识别码,也叫做UDID(Unique Device Identifier) 一、通过Xcode查看 手机连接电脑打开Xcode,选择wind…

入职2个月,那个高薪挖来的自动化软件测试被劝退了....

其实,在很多小伙伴的想法中,是希望通过跳槽实现薪酬涨幅,可是跳槽不是冲动后决定,应该谨慎啊~ 01 我的学弟,最近向我吐槽,2020 年上半年入职一家公司,当时是高薪挖走的他,所谓钱到…

阿里云无影云电脑使用教程全流程(5分钟上手)

阿里云无影云电脑即无影云桌面,云桌面如何使用?云桌面购买后没有用户名和密码,先创建用户设置密码,才可以登录连接到云桌面。云桌面想要访问公网还需要开通互联网访问功能。阿里云百科来详细说下阿里云无影云电脑从购买、创建用户…

h264结构与码流

h264基本概念结构图 H264视频压缩后会成为一个序列帧,帧里包含图像,图像分为很多片,每个片可以分为宏块,每个宏块由许多子块组成 H264结构中,一个视频图像编码后的数据叫做一帧,一帧由一个片(sl…

Redis系列--布隆过滤器(Bloom Filter)

一、前言 在实际开发中,会遇到很多要判断一个元素是否在某个集合中的业务场景,类似于垃圾邮件的识别,恶意ip地址的访问,缓存穿透等情况。类似于缓存穿透这种情况,有许多的解决方法,如:redis存储…

Python 自动化测试五种自动化测试模型实战详解

目录 前言: 自动化测试模型都有哪些? 线性模型 模块化驱动模型 数据驱动模型 关键字驱动模型 行为驱动模型 扩展知识 前言: Python是一种流行的编程语言,广泛应用于自动化测试领域。自动化测试可以帮助测试人员更快、更准确地发…

人脸识别4:Android InsightFace实现人脸识别Face Recognition(含源码)

人脸识别4:Android InsightFace实现人脸识别Face Recognition(含源码) 目录 人脸识别4:Android InsightFace实现人脸识别Face Recognition(含源码) 1. 前言 2. 项目说明 (1)开发版本 (2)依赖库说明(O…