【深度学习】实验6答案:图像自然语言描述生成(让计算机“看图说话”)

news2024/12/24 9:30:13

DL_class

学堂在线《深度学习》实验课代码+报告(其中实验1和实验6有配套PPT),授课老师为胡晓林老师。课程链接:https://www.xuetangx.com/training/DP080910033751/619488?channel=i.area.manual_search。

持续更新中。
所有代码为作者所写,并非最后的“标准答案”,只有实验6被扣了1分,其余皆是满分。仓库链接:https://github.com/W-caner/DL_classs。 此外,欢迎关注我的CSDN:https://blog.csdn.net/Can__er?type=blog。
部分数据集由于过大无法上传,我会在博客中给出下载链接。如果对代码有疑问,有更好的思路等,也非常欢迎在评论区与我交流~

实验6:图像自然语言描述生成(让计算机“看图说话”)

实现原理

Encoder

使用 ResNet101 网络作为编码器,去除最后 Pooling 和 Fc 两层,并添加了 AdaptiveAvgPool2d()层来得到固定大小的编码结果。编码器已在 ImageNet 上预训练好,在本案例中可以选择对其进行微调以得到更好的结果。

Decoder-RNN

实现过程中参考了开源代码:https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning

第一种 Decoder 是用 RNN 结构来进行解码,解码单元可选择 RNN、LSTM、GRU 中的一种,这里选择了LSTM。

在每一个batch执行forward的过程中,首先进行了如下几个操作:

  • 按照caption_lengths降序排列输入数据,重构了encoder_out和encoded_captions。
  • 将encoded_captions经过embedding层进行嵌入表示,同时,caption长度减去1作为decode长度。
  • 创建tensor占位,以记录每一时刻预测的scores。

LSTM初始的隐藏状态和单元状态由encoder_out经过一层全连接层并做批归一化 (Batch Normalization) 后作为解码单元输入。

对后续的每个解码单元,按照t的顺序输入,每次筛选长度大于等于当前t的样本,否则没有输入。输入为单词经过 word embedding 后的编码结果,上一层的隐藏状态和单元状态。得到的解码输出,经过全连接层和 Softmax(存疑?代码中没有明显实现) 后得到一个在所有词汇上的概率分布,并由此得到下一个单词。

其中,训练过程使用了dropout和clip_gradient来防止梯度爆炸,Decoder 解码过程使用到了 teacher forcing 机制。训练时,经过与输入相同步长的解码之后,计算预测和标签之间的交叉熵损失,进行 BP反传更新参数即可。测试时由于不提供标签信息,解码单元每一时间步输入单词为上一步解码预测的单词,直到解码出信息。

核心代码如下:

# To Do: Implement the main decode step for forward pass 
# Hint: Decode words one by one
# Teacher forcing is used.
# At each time-step, generate a new word in the decoder with the previous word embedding
# Your Code Here!
for t in range(max(decode_lengths)):
    idx = sum([l > t for l in decode_lengths])
    preds, h, c = self.one_step(
        embeddings[:idx, t, :], h[:idx], c[:idx])
    predictions[:idx, t, :] = preds

Decoder-AttentionRNN

第二种 Decoder 是用 RNN 加上 Attention 机制来进行解码,Attention 机制做的是生成一组权重,对需要关注的部分给予较高的权重,对不需要关注的部分给予较低的权重。当生成某个特定的单词时,Attention 给出的权重较高的部分会在图像中该单词对应的特定区域,即该单词主要是由这片区域对应的特征生成的。

此处Attention 权重的计算方法(f_att)为:

𝛼 = 𝑠𝑜𝑓𝑡𝑚𝑎𝑥 (𝑓𝑐 (𝑟𝑒𝑙𝑢(𝑓𝑐(𝑒𝑛𝑐𝑜𝑑𝑒𝑟_𝑜𝑢𝑡𝑝𝑢𝑡) + 𝑓𝑐(ℎ))))

其中fc()表示全连接层,用于统一不同维度至decoder_dim,然后经过MLP得到attention权重。

此时,每一时间步解码单元的输入除了embedding,上一步的隐藏状态和单元状态外,还有一个向量,该向量为经过门控(上一刻的隐藏状态经过单层神经元)后的Attention 权重。

核心代码如下:

# To Do: Implement the forward pass for attention module
# Hint: follow the equation 
# "e = f_att(encoder_out, decoder_hidden)"
# "alpha = softmax(e)"
# "z = alpha * encoder_out"
# Your Code Here!
encoder_att = self.encoder_trans(encoder_out)
decoder_att = self.decoder_trans(decoder_hidden)
# att: (batch_size, num_pixels, attention_dim) + (batch_size, attention_dim).unsqueeze(1)
# e: (batch_size, num_pixels, attention_dim) dot (attention_dim, 1) -> (batch_size, num_pixels, 1)
e = self.full_trans(self.relu(encoder_att + decoder_att.unsqueeze(1)))
# alpha: (batch_size, num_pixels, 1)
alpha = self.softmax(e)
# z: (batch_size, encoder_dim)
z = (alpha * encoder_out).sum(dim = 1)

所做改进

因为飞桨平台总是断线,此处使用命令行+输出至文件的方式进行训练。首先对于两种Decoder模型原始参数进行3个周期的训练,分别保存训练过程于train1.logtrain2.log。可以发现,无论是准确率,收敛速度,还是在验证集上的BLUE,带Attention的Decoder都有着较好的表现:

  • Decoder-RNN(3周期)

在这里插入图片描述

  • Decoder-AttentionRNN(3周期)

在这里插入图片描述

我没有算力了,没有尝试使用其它改进。简单的学习了一下Adaptive Attention和Beam search ,发现原始代码后面使用的MLP较为简单,仅输入了h作为全连接层,进行预测,而论文中还需要拼接向量C_t作为输入,也就是self.fc 输入维度需要扩展decoder_dim + encoder_dim,同时,如下图所示,两者的最主要区别在于C_t的生成。

在这里插入图片描述

论文中提出的改进的 spatial attention 模型,在每个step的过程,是先经过的LSTM,将当前的隐藏单元(而不是上一时刻)作为Attention函数的输入,核心实现代码如下:

def one_step(self, embeddings, encoder_out, h, c):
    ############################################################################
    # To Do: Implement the one time decode step for forward pass
    # this function can be used for test decode with beam search
    # return predicted scores over vocabs: preds
    # return attention wAeight: alpha
    # return hidden state and cell state: h, c
    # Your Code Here!
    h, c = self.decode_step(torch.cat([embeddings, z], dim=1), (h, c))
    z, alpha = self.adpattention(encoder_out, h)
    gate = self.sigmoid(self.beta(h))
    z = gate * z
    preds = self.fc(self.dropout(torch.cat([h,z],dim=1)))
    ############################################################################
    return preds, alpha, h, c

同时针对Attention策略,作者认为对于非视觉词,它们的生成应该取决于历史信息而不是视觉信息,因此在这种情况下应该对视觉信息加以控制。这一部分的代码我没有运行成功(如果直接导入需要注释掉AdpAttention相关内容),我没有明白这里的m_t的含义:

在这里插入图片描述

最佳参数,测试集BLUE-4

对于实现的模型中某些参数进行调整,如fine_tune_encoder 设为True,grad_clip 稍微放宽至8,重新进行训练,第三个周期得到结果如下:

在这里插入图片描述

可以发现,允许预训练模型的微调能够带来更好的效果。经过5个周期的训练,最终得到三种模型每一周期的验证集BLUE如下图所示:

在这里插入图片描述

分别采取每种模型最好表现的checkpoint,测试集BlUE-4最高能达到29.4,存在明显的欠拟合情况,如果增加训练周期将会有更好的表现:

在这里插入图片描述

表现效果

随机选取了一些图片进行带Attention和仅Rnn的示例展示,样例如下:

在这里插入图片描述

在这里插入图片描述

可以看到,相比单纯的Rnn可解释性强,效果更好。但同时也存在缺点,即Attention机制不是一个"distance-aware"的,无法捕捉语序顺序,存在可改进空间。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/66997.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ElasticSearch实战

一、es集群的搭建 1.集群相关概念 单节点故障问题 单台服务器,往往都有最大的负载能力,超过这个阈值,服务器性能就会大大降低甚至不可用。单点的elasticsearch也是一样那单点的es服务器存在哪些可能出现的问题呢? 单台机器存储…

[附源码]计算机毕业设计基于springboot在线影院系统

项目运行 环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术: SSM mybatis Maven Vue 等等组成,B/S模式 M…

ROS MoveIT1(Noetic)安装总结

前言 由于MoveIT2的Humble的教程好多用的还是moveit1的环境,所以又装了Ubutun20.04和ROS1(Noetic)。【2022年12月6日】 环境 系统:Ubutun20.04LTS Ros:Noetic 虚拟机:VMware 安装 ROS Noetic 安装教程…

【微信小程序】canvasToTempFilePath遇到的问题

在微信小程序开发中,经常需要将绘制好的canvas保存到本地,这就需要调用canvasToTempFilePath将canvas画布转为本地临时文件。遇到过的问题如下: 1.create bitmap failed 2.fail canvas is empty 这个问题就是canvas还没画为空拿不到转化的临…

Eclipse+Maven+Tomcat 集成开发环境配置

在Eclipse中创建的Dynamic Web Project 类型的Web 项目, 通过Run As -> Run on Server 可以添加本地安装的Tomcat,在Eclipse 中启动Tomcat 进行整合开发。 但是如果创建的是Maven类型的项目,如果没有导入额外的包还正常, 但是…

Nginx入门到搭建

前言 上一篇文章我们分享了Linux的软件安装以及项目后端的部署,本篇文章将要分享的内容是,Nginx的入门安装、反向代理、负载均衡等。 一、Nginx简介 Nginx是一款轻量级的Web 服务器/反向代理服务器及电子邮件(IMAP/POP3) 代理服务…

如何让彩色网页变灰

如何让彩色网页变灰 在特殊的日子&#xff08;如清明节特殊纪念日等&#xff09;&#xff0c;需要让彩色网页变成灰色&#xff08;黑白色&#xff09;如下图所示&#xff0c;怎么做到呢&#xff1f; 下面先给出彩色正常的网页示例源码&#xff1a; <!DOCTYPE html> <…

消除数据库表中的重复组

重复组是在整个数据库表中重复的一系列字段/属性。大型和小型组织都面临着一个普遍的问题&#xff0c;这个问题可能会带来多种后果。例如&#xff0c;在不同区域中存在的同一组信息会导致数据冗余和数据不一致。而且&#xff0c;所有这些重复的数据可能会占用大量宝贵的磁盘空间…

【计算机图形学入门】笔记2:向量与线性代数(图形学中用到的线性代数)

02向量与线性代数&#xff08;图形学中用到的线性代数&#xff09;1.A Swift and Brutal Introduction to Linear Algebra!简单粗暴入门线性代数1.Graphics’ Dependencies 图形学依赖的一些知识2.Vectors 向量1.Dot product向量的点乘2.向量的叉乘Cross product3.矩阵Matrix4.…

Linux——进程并发控制(系统中的POSIX信息量机制、进程间通信)

目录 一、Linux系统中POSIX信号量机制 1、POSIX有名信号量 &#xff08;1&#xff09;常用函数 &#xff08;2&#xff09;有名信号量应用于多线程的例子 &#xff08;3&#xff09;有名信号量应用于多进程 2、POSIX无名信号量 &#xff08;1&#xff09;常用函数 &…

基于jsp+mysql+ssm大学生社交平台-计算机毕业设计

项目介绍 本系统需要满足校园网上社交方面的基本需要。需要实现用户所要求的功能&#xff0c;方便他们进行交流。在界面上力求做到美观、操作方面尽量避免由于会员操作不当带来系统的出错现象。对数据库操作的性能需要做到优化&#xff0c;数据库过大将会影响运行速度。编程过…

(四) Docker镜像

Docker镜像一、概述二、镜像加载原理三、镜像注意点四、Docker镜像commit操作五、总结一、概述 书面解释 是一种轻量级、可执行的独立软件包&#xff0c;它包含运行某个软件所需的所有内容&#xff0c;我们把应用程序和配置依赖打包好形成一个可交付的运行环境(包括代码、运行时…

开关电源环路稳定性分析(04)-电压控制模式

大家好&#xff0c;这里是大话硬件。 在前3节分析了一个开环电源是如何工作的&#xff0c;开环电源的弊端也很明显&#xff0c;无法维持输出的稳定&#xff0c;不能抗扰动&#xff0c;无法得到我们想要的电压等等。因此&#xff0c;开关电源的闭环环路对稳定性来说非常重要。 …

LeetCode简单题之统计共同度过的日子数

题目 Alice 和 Bob 计划分别去罗马开会。 给你四个字符串 arriveAlice &#xff0c;leaveAlice &#xff0c;arriveBob 和 leaveBob 。Alice 会在日期 arriveAlice 到 leaveAlice 之间在城市里&#xff08;日期为闭区间&#xff09;&#xff0c;而 Bob 在日期 arriveBob 到 l…

大数据:Storm和流处理简介

一、Storm 1.1 简介 Storm 是一个开源的分布式实时计算框架&#xff0c;可以以简单、可靠的方式进行大数据流的处理。通常用于实时分析&#xff0c;在线机器学习、持续计算、分布式 RPC、ETL 等场景。Storm 具有以下特点&#xff1a; 支持水平横向扩展&#xff1b;具有高容错…

信息安全技术 信息安全风险评估方法 汇总

概述 风险评估应贯穿于评估对象生命周期 各阶段中。评估对象生命周期各阶段中涉及的风险评估原则和方法昆一致的&#xff0c;但由干各阶段实施内容对象、安全需求不同.使得风险评估的对象、目的、要求等各方面也有所不同。在规划设计阶段&#xff0c;通过风险评估以确定评估对…

(推荐阅读)H264, H265硬件编解码基础及码流分析

需求 在移动端做音视频开发不同于基本的UI业务逻辑工作,音视频开发需要你懂得音视频中一些基本概念,针对编解码而言,我们必须提前懂得编解码器的一些特性,码流的结构,码流中一些重要信息如sps,pps,vps,start code以及基本的工作原理,而大多同学都只是一知半解,所以导致代码中的…

JAVA-元注解和注解

故事背景&#xff1a;罗芭是一名正在学习java的妹子&#xff0c;最近看甲骨文的官方文档&#xff0c;学到了注解Annotation这里&#xff0c;发现注解我可以自定义&#xff0c;但罗芭不会诶。但是布洛特 亨德尔已经学习过了java注解。 罗芭&#xff0c;help me~ 唰唰唰&#xff…

Redis05:Redis高级部分

Redis高级部分SpringBoot整合Redis整合测试序列化配置解决乱码问题redis自定义RedisTemplateSpringBoot整合Redis 说明&#xff1a;在SpringBoot2.x之后&#xff0c;原来使用jedis被替换成了letttuce! jedis:采用的时直连&#xff0c;多个线程操作的话&#xff0c;是不安全的&a…

MySQL下载和安装(Windows)

前言&#xff1a;刚换了一台电脑&#xff0c;里面所有东西都需要重新配置&#xff0c;习惯了所有东西都配好的装配&#xff0c;突然需要自己从头来配才发现不知道如何下手&#xff0c;所以决定将这些步骤都做个记录&#xff0c;以便后续查看。仅限没有安装过的人使用&#xff0…