【深度学习】- 作业6: 图像自然语言描述生成

news2025/1/17 14:11:25

课程链接: 清华大学驭风计划

代码仓库:Victor94-king/MachineLearning: MachineLearning basic introduction (github.com)


驭风计划是由清华大学老师教授的,其分为四门课,包括: 机器学习(张敏教授) , 深度学习(胡晓林教授), 计算机语言(刘知远教授) 以及数据结构与算法(邓俊辉教授)。本人是综合成绩第一名,除了数据结构与算法其他单科均为第一名。<font color=Blue>代码和报告均为本人自己实现,由于篇幅限制,只展示任务布置以及关键代码,如果需要报告或者代码可以私聊博主 </font>



机器学习部分授课老师为胡晓林教授,主要主要通过介绍回归模型,多层感知机,CNN,优化器,图像分割,RNN & LSTM 以及生成式模型入门深度学习


有任何疑问或者问题,也欢迎私信博主,大家可以相互讨论交流哟~~



任务介绍

本次案例将使用深度学习技术来完成图像自然语言描述生成任务,输入一张图片,模型会给出关于图片内容的语言描述。本案例使用coco2014数据集 [1] ,包含82,783张训练图片,40,504张验证图片,40,775张测试图片。案例使用Andrej Karpathy[2]提供的数据集划分方式和图片标注信息,案例已提供数据处理的脚本,只需下载数据集和划分方式即可…

1 任务和数据简介

本次案例将使用深度学习技术来完成图像自然语言描述生成任务,输入一张图片,模型会给出关于图片内容的语言描述。本案例使用 coco2014 数据集[1],包含 82,783 张训练图片,40,504 张验证图片,40,775 张测试图片。案例使用 AndrejKarpathy[2]提供的数据集划分方式和图片标注信息,案例已提供数据处理的脚本,只需下载数据集和划分方式即可。

图像自然语言描述生成任务一般采用 Encoder-Decoder 的网络结构,Encoder采用 CNN 结构,对输入图片进行编码,Decoder 采用 RNN 结构,利用 Encoder编码信息,逐个单词的解码文字描述输出。模型评估指标采用 BLEU 分数[3],用来衡量预测和标签两句话的一致程度,具体计算方法可自行学习,案例已提供计算代码。



2 方法描述

模型输入图像统一到 256×256 大小,并且归一化到[−1,1]后还要对图像进行 RGB 三通道均值和标准差的标准化。语言描述标签信息既要作为目标标签,也要作为Decoder 的输入,以 <start>开始,<end>结束并且需要拓展到统一长度,例如:

< 𝑠𝑡𝑎𝑟𝑡 > 𝑎 𝑡𝑎𝑏𝑙𝑒 𝑡𝑜𝑝𝑝𝑒𝑑 𝑤𝑖𝑡ℎ 𝑝𝑙𝑎𝑡𝑒𝑠 𝑜𝑓 𝑓𝑜𝑜𝑑 𝑎𝑛𝑑 𝑑𝑟𝑖𝑛𝑘𝑠 < 𝑒𝑛𝑑 > < 𝑝𝑎𝑑 > < 𝑝𝑎𝑑 >< 𝑝𝑎𝑑 > ⋯

每个 token 按照词汇表转为相应的整数。同时还需要输入描述语言的长度,具体为单词数加2 (<start><end>),目的是为了节省在 <pad>上的计算时间。Encoder案例使用 ResNet101 网络作为编码器,去除最后 Pooling 和 Fc 两层,并添加了 AdaptiveAvgPool2d()层来得到固定大小的编码结果。编码器已在 ImageNet 上预训练好,在本案例中可以选择对其进行微调以得到更好的结果。



Decoder

Decoder 是本案例中着重要求的内容。案例要求实现两种 Decoder 方式,分别对应这两篇文章[4][5]。在此简要阐述两种 Decoder 方法,进一步学习可参考原文章。


第一种 Decoder 是用 RNN 结构来进行解码,解码单元可选择 RNN、LSTM、GRU 中的一种,初始的隐藏状态和单元状态可以由编码结果经过一层全连接层并做批归一化 (Batch Normalization) 后作为解码单元输入得到,后续的每个解码单元的输入为单词经过 word embedding 后的编码结果、上一层的隐藏状态和单元状态,解码输出经过全连接层和 Softmax 后得到一个在所有词汇上的概率分布,并由此得到下一个单词。Decoder 解码使用到了 teacher forcing 机制,每一时间步解码时的输入单词为标签单词,而非上一步解码出来的预测单词。训练时,经过与输入相同步长的解码之后,计算预测和标签之间的交叉熵损失,进行 BP反传更新参数即可。测试时由于不提供标签信息,解码单元每一时间步输入单词

为上一步解码预测的单词,直到解码出 <end>信息。测试时可以采用 beam search解码方法来得到更准确的语言描述,具体方法可自行学习。



第二种 Decoder 是用 RNN 加上 Attention 机制来进行解码,Attention 机制做的是生成一组权重,对需要关注的部分给予较高的权重,对不需要关注的部分给予较低的权重。当生成某个特定的单词时,Attention 给出的权重较高的部分会在图像中该单词对应的特定区域即该单词主要是由这片区域对应的特征生成的。

Attention 权重的计算方法为:

𝛼 = 𝑠𝑜𝑓𝑡𝑚𝑎𝑥 (𝑓𝑐 (𝑟𝑒𝑙𝑢(𝑓𝑐(𝑒𝑛𝑐𝑜𝑑𝑒𝑟_𝑜𝑢𝑡𝑝𝑢𝑡) + 𝑓𝑐(ℎ))))



其中 softmax()表示 Softmax 函数,fc()表示全连接层,relu()表示 ReLU 激活函encoder_output 是编码器的编码结果,h 是上一步的隐藏状态。初始的隐藏状态和单元状态由编码结果分别经过两个全连接层得到。每一时间步解码单元的输入除了上一步的隐藏状态和单元状态外,还有一个向量,该向量由单词经过word embedding 后的结果和编码器编码结果乘上注意力权重再经过一层全连接层后的结果拼接而成。解码器同样使用 teacher forcing 机制,训练和测试时的流程与第一种 Decoder 描述的一致。样例输出第一种 Decoder 得到的结果仅包含图像的文字描述,如下图:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-AUwARm4K-1684512879541)(image/hw6/1684512483493.png)]

第二种 Decoder 由于有 Attention 机制的存在,可以得到每个单词对应的图片

区域,如下图:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-wmN4KWJ9-1684512879542)(image/hw6/1684512490538.png)]



3 参考程序及使用说明

本次案例提供了完整、可供运行的参考程序,各程序简介如下:create_input_files.py : 下载好数据集和划分方式后需要运行该脚本文件,会生成案例需要的 json 和 hdf5 文件,注意指定输入和输出数据存放的位置。datasets.py : 定义符合 pytorch 标准的 Dataset 类,供数据按 Batch 读入。models.py : 定义 Encoder 和 Decoder 网络结构,其中 Encoder 已提前定义好,无需自己实现。两种 Decoder 方法需要自行实现,已提供部分代码,只需将 #To Do 部分补充完全即可。

solver.py : 定义了训练和验证函数,供模型训练使用。

train.ipynb : 用于训练的 jupyter 文件,其中超参数需要自行调节,训练过程中可以看到模型准确率和损失的变化,并可以得到每个 epoch 后模型在验证集上的 BLEU 分数,保存最优的验证结果对应的模型用于测试。

test.ipynb : 用于测试的 jupyter 文件,加载指定的模型,解码时不使用 teacher forcing,并使用 beam search 的解码方法,最终会得到模型在测试集上的 BLEU分数。

caption.ipynb : 加载指定模型,对单张输入图片进行语言描述,第一种 Decoder 方法只能得到用于描述的语句,第二种 Decoder 方法同时可以获取每个单词对应的注意力权重,最后对结果进行可视化。

utils.py : 定义一些可能需要用到的函数,如计算准确率、图像可视化等。

环境要求:python 包 pytorch, torchvision, numpy, nltk, tqdm, h5py, json, PIL,

matplotlib, scikit-image, scipy=1.1.0 等。



4 要求与建议

 完成 models.py 文件中的 #To Do 部分,可参考第 2 部分中的介绍或原论文;

 调节超参数,运行 train.ipynb,其中 attention 参数指示使用哪种 Decoder,分别训练使用两种不同 Decoder 的模型,可以分两个 jupyter 文件保存最佳参数和训练记录,如 train1.ipynb, train2.ipynb;

 运行 test.ipynb 得到两个模型在测试集上的 BLEU 分数,分别保留结果;

 选择一张图片,可以是测试集中的,也可以是自行挑选的,对图片进行语言描述自动生成,分别保留可视化结果;

 在参考程序的基础上,综合使用深度学习各项技术,尝试提升该模型在图像自然语言描述生成任务上的效果,如使用更好的预训练模型作为 Encoder,或者提出更好的 Decoder 结构,如 Adaptive Attention 等;

 完成一个实验报告,内容包括基础两个模型的实现原理说明、两个模型的最佳参数和对应测试集 BLEU 分数、两个模型在单个图片上的表现效果、自己所做的改进、对比分析两个基础模型结果的不同优劣。

 禁止任何形式的抄袭,借鉴开源程序务必加以说明。



报告

核心代码

attention的实现

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-8XDNxizD-1684512879542)(image/hw6/1684512544937.png)]

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-FhRum8L5-1684512879543)(image/hw6/1684512570061.png)]

Decoder 实现

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-jmVQTZUZ-1684512879543)(image/hw6/1684512588512.png)]

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-VGraIch4-1684512879544)(image/hw6/1684512593190.png)]



结果

下面是普通LSTM与加入attention的对比

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-gg5udPYl-1684512879544)(image/hw6/1684512616319.png)]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/557001.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

2023 华为 Datacom-HCIE 题库 04--含解析

单选题 1.[试题编号&#xff1a;189921] &#xff08;单选题&#xff09;防火墙双机热备场景下&#xff0c;当VGMP工作在负载分担模式时&#xff0c;为了避免在来回路径不一致的场景下回程流量因没有匹配到会话表项而丢弃的现象&#xff0c;防火墙需要启开一下那些功能&#x…

【vim】从入门到放弃(“四种”模式、常用命令、正则表达式、文件属性、插件安装)

文章目录 一、vim简介二、vim操作2.1 三种模式及其切换2.2 常用命令2.21 命令模式下常用命令2.22 底行模式下常用命令 三、vim进阶3.1 进阶操作3.11 可视化模式3.12 正则表达式3.13 结合其他文本处理命令3.14 修改文件属性&#xff08;编码、格式、权限&#xff09; 3.2 进阶配…

cpu压力测试、平均负载、切换上下文(linux)

和windows下有很多图形化测试工具不同&#xff0c;linux下的压力测试通常需要命令行 一、平均负载 1.查看命令 uptime会给出类似如下的信息 2.说明 三个数值代表1分钟&#xff0c;5分钟&#xff0c;15分钟的平均进程数。 换成更容易理解但不准确的说法就是几个核满载 比如…

目前前端流行的框架总结

框架 前端框架 前端框架一般指用于简化网页设计的框架&#xff0c;使用广泛的前端开发套件&#xff0c;比如&#xff0c;jquery&#xff0c;extjs&#xff0c;bootstrap等等&#xff0c;这些框架封装了一些功能&#xff0c;比如html文档操作&#xff0c;漂亮的各种控件&#x…

取余,取模

目录 一&#xff1a;取整方式 1&#xff1a;向0取整 --- trunc取整函数 2.向-∞取整 --- floor&#xff08;地板&#xff09;函数 3.向∞取整 --- ceil函数 4.四舍五入取整 --- round 函数 5.四种取整方式的对比 二&#xff1a;取模 1.引入 2.取模与取余等价&#xff1f; 一&a…

JavaScript实现通过表格方式显示三角形的代码

以下为实现通过表格方式显示三角形的程序代码和运行截图 目录 前言 一、通过表格方式显示三角形 1.1 运行流程及思想 1.2 代码段 1.3 JavaScript语句代码 1.4 运行截图 前言 1.若有选择&#xff0c;您可以在目录里进行快速查找&#xff1b; 2.本博文代码可以根据题目要…

5.python列表

文章目录 一、什么是列表二、列表的表示方法三 、列表元素的索引四、访问列表元素五、修改列表元素直接赋值 六、添加列表元素6.1 方法append()6.2 方法insert() 七、删除列表元素7.1 语句del7.2方法pop()7.3方法remove() 八、组织列表8.1倒着打印列表8.2确定列表长度8.3 列表排…

【机器学习】 - 作业5: 基于Kmeans算法的AAAI会议论文聚类分析

课程链接: 清华大学驭风计划 代码仓库&#xff1a;Victor94-king/MachineLearning: MachineLearning basic introduction (github.com) 驭风计划是由清华大学老师教授的&#xff0c;其分为四门课&#xff0c;包括: 机器学习(张敏教授) &#xff0c; 深度学习(胡晓林教授), 计算…

HC-05蓝牙模块的使用

我最近刚刚开始学习嵌入式&#xff0c;在第一次使用蓝牙模块HC-05的时候遇到了很多问题&#xff0c; 甚至连接线都不会&#xff0c;因此下面我会十分详细地介绍我一步一步探索的步骤&#xff0c;直到完成使用手机APP和51单片机收发数据。 调试步骤 首先&#xff0c;我们需要明…

2023开放原子全球开源峰会分论坛即将来袭,Pick你最关注的峰会话题!

2023开放原子全球开源峰会即将开启 二十余场分论坛主题重磅首发 聚焦全球开源发展最新动向 前沿技术、行业实践、开源项目与治理等 多场知识盛宴等您来享 为更好地了解大家的参与意向 分论坛投票今天正式启动&#xff01; 投票时间&#xff1a;5月19-26日 长按识别二维码 …

MFC 给对话框添加图片背景

在windows开发当中做界面的主要技术之一就是使用MFC&#xff0c;通常我们看到的QQ,360,暴风影音这些漂亮的界面都可以用MFC来实现。今天我们来说一下如何用MFC美化对话框&#xff0c;默认情况下&#xff0c;对话框的背景如下&#xff1a; 那么&#xff0c;我们如何将它的背景变…

【Servlet 基础】

&#x1f389;&#x1f389;&#x1f389;点进来你就是我的人了博主主页&#xff1a;&#x1f648;&#x1f648;&#x1f648;戳一戳,欢迎大佬指点! 欢迎志同道合的朋友一起加油喔&#x1f93a;&#x1f93a;&#x1f93a; 目录 1. 什么是Servlet&#xff1f; 2. 第一个Serv…

微软 LoRA| 使用万分之一的参数微调你的GPT3模型

一、概述 title&#xff1a;LORA: LOW-RANK ADAPTATION OF LARGE LAN- GUAGE MODELS 论文地址&#xff1a;https://arxiv.org/abs/2106.09685 代码&#xff1a;GitHub - microsoft/LoRA: Code for loralib, an implementation of "LoRA: Low-Rank Adaptation of Large …

课时6—死锁(二)

一、死锁的避免 避免死锁同样属于事先预防策略&#xff0c;是在资源动态分配过程中&#xff0c;防止系统进入不安全状态&#xff0c;以避免发生死锁。 1、系统安全状态 在避免死锁方法中&#xff0c;把系统的状态分为安全状态和不安全状态。当系统处于安全状态时可避免发生死…

Android UI开发之多样式富文本的简洁实现

多样式富文本的简洁实现 原文链接&#xff1a;Android UI开发之多样式富文本的简洁实现 AppendableStyleString 允许你快速构建多种样式文字。 特性 支持对于同一个字符串设置多种样式。支持文字和图片。提供默认样式。采用 DSL 确保更清晰的样式作用范围 快速开始 下面的…

【事务失效】十种常见场景

前提 大多数Spring Boot项目只需要在方法上标记Transactional注解&#xff0c;即可一键开启方法的事务性配置。 但是&#xff0c;事务如果没有被正确出&#xff0c;很有可能会导致事务的失效&#xff0c;避免因为事务处理不当导致业务逻辑产生大量偶发性BUG 事务的传播类型 …

JDK8-17的特性发生了哪些变化

JDK8-17的特性发生了哪些变化 垃圾回收器Java交互式编程接口定义扩展String底层结构变更of 创建不可变序列HTTP 2 协议接口引入 var 关键字字符串增强lambda 表达式类型推导switch 增强支持文本块定义instanceof 模式匹配引入record 关键字新增密封类的定义switch二度加强模块…

栈及其实现

目录 一&#xff1a;栈 1.栈的概念和结构 2.栈的实现 <1>.初始化栈 <2>.入栈 <3>.出栈 <4>:获取栈顶元素 <5>.获取栈中有效元素个数 <6>.销毁栈 <7>.示例 二&#xff1a;栈的完整代码 一&#xff1a;栈 1.栈的概念和结构 …

Origin中log2的计算,设置以2为底的log坐标

使用高中的换底公式即可&#xff0c;把2的底换成10的底计算 ![在这里插入图片描述](https://img-blog.csdnimg.cn/5747fdbd2b5c43f095d716092fd17124.png

模式介绍和基本管理

模式介绍&#xff1a; 用户的模式(SCHEMA&#xff09;指的是用户账号拥有的对象集&#xff0c;在概念上可将其看作是包含表、 视图、索引和权限定义的对象。在 DM 中&#xff0c;一个用户可以创建多个模式&#xff0c;一个模式中的对象 &#xff08;表、视图等&#xff09;可以…