Show, Attend, and Tell | a PyTorch Tutorial to Image Captioning代码调试(跑通)

news2024/12/22 18:35:37

Show, Attend, and Tell | a PyTorch Tutorial to Image Captioning代码调试(跑通)

文章目录

  • Show, Attend, and Tell | a PyTorch Tutorial to Image Captioning代码调试(跑通)
  • 前言
  • 1. 创建、安装所用的包
    • 1.1 创建环境,安装pytorch包
    • 1.2 安装其他必要的包
  • 2. 准备数据
  • 3. create_input_files.py调试
  • 4. train.py调试
  • 5. caption.py调试
  • 6. eval.py调试
  • 总结

前言

Show, Attend, and Tell是一个使用图像生成描述性字幕的模型。该模型通过注意力机制,学习如何在生成字幕时,关注与当前要生成的单词最相关的图像部分。在生成字幕过程中,我们可以看到模型的视线在图像上移动。

代码:sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning: Show, Attend, and Tell | a PyTorch Tutorial to Image Captioning (github.com)
论文:Show, Attend and Tell: Neural Image CaptionGeneration with Visual Attention (arxiv.org)

p9CzIrq.png

1. 创建、安装所用的包

在本节中,我们主要介绍如何创建用于运行Show, Attend, and Tell代码的环境。

1.1 创建环境,安装pytorch包

  1. 创建show_attend_tell 环境
conda create -n show_attend_tell python=3.6
  1. 进入所创建环境show_attend_tell
conda activate show_attend_tell
  1. 为了在具有 CUDA 9.0 的服务器上安装 PyTorch 包,您可以按照官网的指南进行操作。安装命令如下(注意,后来发现我的服务器CUDA与pytorch不匹配,我又切换为CUDA=9.2版本,重新安装了Pytorch=1.5):
conda install pytorch==1.1.0 torchvision==0.3.0 cudatoolkit=9.0 -c pytorch
  1. 下图可以说明GPU是可以正常运行的。

image

1.2 安装其他必要的包

  1. scipy是一个Python科学计算库,包含了许多常用的数学、科学和工程计算功能,例如信号处理、优化器、图像处理等。详细描述,请参照 第三章(2):深入理解NTLK库基本使用方法_安静到无声的博客-CSDN博客

  2. nltk是自然语言处理(NLP)领域的Python库,提供了许多实用的工具和数据集,例如分词、词性标注、语言模型等。详细描述,请参照 第三章(3):深入理解Spacy库基本使用方法_安静到无声的博客-CSDN博客

  3. h5py是一种用于读写HDF5文件的Python库,HDF5文件是一种用于存储和交换科学数据的文件格式。在深度学习中,很多模型都使用HDF5格式来保存和加载权重参数。

  4. tqdm是一个Python进度条库,可以在循环过程中显示进度条和估计的剩余时间,方便用户追踪长时间运行的程序的进度。

conda install scipy
conda install nltk
conda install h5py
conda install tqdm

2. 准备数据

本实验是在Flickr8k数据上进行的,可以前往以下链接:[ Flickr 8k Dataset]下载官方数据。

数据集中文本文件类型如下:

  1. Flickr_8k.trainImages.txtFlickr_8k.testImages.txtFlickr_8k.devImages.txt:包含了训练、测试和验证集中图片的文件名。

  2. Flickr8k.token.txtFlickr8k.lemma.token.txt:包含了每张图片的标题。

  3. ExpertAnnotations.txtCrowdFlowerAnnotations.txt:包含了每张图片的人工评注数据。

p9PSFiD.png

之后需要对Flickr8k数据集的文件进行预处理,生成符合COCO JSON格式的输入数据,以用于后续的图像标题生成实验,具体教程请参照成功实现:将Flickr8k.token.txt转换为JSON格式(其他数据集可仿照迁移)_安静到无声的博客-CSDN博客。读者也可以访问该此网站自行下载。

3. create_input_files.py调试

在准备好数据之后, 我们运行create_input_files.py

入口函数参数值的含义如下:

dataset:数据集名称,可选值为 cocoflickr8kflickr30k
karpathy_json_path:Karpathy JSON 文件的路径,其中包含了数据集的划分和图像描述。
image_folder:包含下载的图像的文件夹路径。 captions_per_image:每张图像抽样的图像描述数。
min_word_freq:单词频率的阈值,小于此阈值的单词将被替换成 标记。
output_folder:保存文件的文件夹路径。 max_len:允许抽样的图像描述最大长度,超过此长度的描述将被过滤掉。

最后基于flickr8k数据集,我们修改了参数配置,具体如下:

create_input_files(dataset='flickr8k',  
                   karpathy_json_path='/home/lihuanyu/Data/flickr8k/dataset_flickr8k.json',  
                   image_folder='/home/lihuanyu/Data/flickr8k/Flickr8k_Dataset/Flicker8k_Dataset/',  
                   captions_per_image=5,  
                   min_word_freq=5,  
                   output_folder='/home/lihuanyu/code/09show_attend_tell/result/',  
                   max_len=50)

不出意外当然还是报错呢~,具体错误如下:

p9PSmLt.png

我们安装pip install imageio,再将from scipy.misc import imread, imresize改为

from imageio import imread  
from scipy.misc import imresize

但是仍然会报出如下错误:

p9PSueP.png

参考解决方法cannot import name ‘imresize‘ from ‘scipy.misc‘ - 腾讯云开发者社区-腾讯云 (tencent.com)

我们将img = imresize(img, (256, 256))改为

img = np.array(Image.fromarray(img).resize((256, 256)))

程序最终可以正常运行:

p9PSKdf.png

由图可知,一共6000张用于训练、1000张用于测试、1000张用于验证。

最终在result文件夹中生成了如下文件:

p9PS1Jg.png

至此create_input_files.py调试完成,数据的准备阶段也已经完成。

4. train.py调试

打开train.py我们可以看到需要配置以下参数:

p9PS8zj.png

初始阶段,我们只修改与数据路径有关的参数配置,不改变其余参数,能够保证可以正常训练即可。

路径的修改如下图所示:

data_folder = '/home/lihuanyu/code/09show_attend_tell/result'  # folder with data files saved by create_input_files.py 由create_input_files.py创建的包含数据文件的文件夹”  
data_name = 'flickr8k_5_cap_per_img_5_min_word_freq'  # base name shared by data files 数据文件共享的基础名称

然后开始运行train.py程序,可以看到成功下载卷积神经网络权重数据。

p9PSJQs.png

但是还是报出了如下错误~

p9PSYyn.png

这个错误比较简单,修改为如下程序即可:

p9PSUe0.png

继续运行train.py函数,我们可以发现程序可以直接运行:

p9PSawV.png

但是在执行完第一个epoch之后,模型报错,如下所示:

在这里插入图片描述

这是因为只更改了train函数函数中的错误,没有更新validate函数错误。

在这里插入图片描述

修改后的代码,终于可以正常运行。

在这里插入图片描述

5. caption.py调试

  1. 直接运行caption.py程序,会显示如下错误:

p9CYeFf.png

我们安装matplotlib包即可,如果想了解matplotlib的基本使用方法,请访问:第二章(1):Python入门:语法基础、面向对象编程和常用库介绍_安静到无声的博客-CSDN博客

  1. 运行后,又报出如下错误:

p9CYI0I.png

我们接着安装skimage即可,参考解决方案如下:ModuleNotFoundError: No module named ‘skimage‘modulenotfounderror: no module named 'skimage==cjw==的博客-CSDN博客

  1. 继续运行,出现如下错误。

p9CYHtf.png

这与调试create_input_files.py时出现的错误一样,我们做相同的修改即可。

  1. 继续运行,出现如下错误。

p9CtmH1.png

这是由于没有加载模型,图片和字典,我们修改如下:

parser.add_argument('--img', '-i', default= '/home/lihuanyu/code/09show_attend_tell/img/1007129816_e794419615.jpg', help='path to image')  # 图片的路径  
parser.add_argument('--model', '-m', default= '/home/lihuanyu/code/09show_attend_tell/BEST_checkpoint_flickr8k_5_cap_per_img_5_min_word_freq.pth.tar', help='path to model')  # 模型的路径  
parser.add_argument('--word_map', '-wm',default='/home/lihuanyu/code/09show_attend_tell/result/WORDMAP_flickr8k_5_cap_per_img_5_min_word_freq.json', help='path to word map JSON') # json的路径
  1. 继续运行,报出如下错误。

p9CtLUx.png

错误修改方法,同第3点。

p9CURpT.png

  1. 继续运行,报出如下错误。

p9CwIBQ.png

这是由于版本不一致造成的,我们又在CUDA=9.2版本,重新安装了Pytorch=1.5的版本。

命令如下:

pip install torch==1.5.1+cu92 torchvision==0.6.1+cu92 -f https://download.pytorch.org/whl/torch_stable.html
  1. 继续运行

可以输出如下结果:

p9CBvTJ.png

6. eval.py调试

该程序的配置参数不多,具体如下所示:

p9C7rkR.png

我们按照train.py的配置方式,对eval.py代码进行修改。

# Parameters  
data_folder = '/home/lihuanyu/code/09show_attend_tell/result'  # folder with data files saved by create_input_files.py  
data_name =  'flickr8k_5_cap_per_img_5_min_word_freq'   # base name shared by data files  
checkpoint = '/home/lihuanyu/code/09show_attend_tell/BEST_checkpoint_flickr8k_5_cap_per_img_5_min_word_freq.pth.tar'  # model checkpoint  
word_map_file = '/home/lihuanyu/code/09show_attend_tell/result/WORDMAP_flickr8k_5_cap_per_img_5_min_word_freq.json'  # word map, ensure it's the same the data was encoded with and the model was trained with

运行程序,出现了以下错误:

p9CHeE9.png

我们参照了ValueError: max() arg is an empty sequence · Issue #191 · sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning (github.com)的方法进行解决。

p9Czf2j.png

再次运行,可以得到如下结果。

p9Czhxs.png

总结

至此我们完成了Show, Attend and Tell: Neural Image CaptionGeneration with Visual Attention的代码复现,后续将对每个py文件进行详细注解,感谢关注。


参考
Previous PyTorch Versions | PyTorch
show attend and tell代码实现(绝对详细)_show attend and tell pytorch代码_饿了就干饭的博客-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/440264.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【深度学习】OCR文本识别

OCR文字识别定义 OCR(optical character recognition)文字识别是指电子设备(例如扫描仪或数码相机)检查纸上打印的字符,然后用字符识别方法将形状翻译成计算机文字的过程;即,对文本资料进行扫描…

【数据结构】二叉树经典oj题

🚀write in front🚀 📜所属专栏:初阶数据结构 🛰️博客主页:睿睿的博客主页 🛰️代码仓库:🎉VS2022_C语言仓库 🎡您的点赞、关注、收藏、评论,是对…

B. Make Them Equal(Codeforces Round 673 (Div. 1))

传送门 题意: 思路: 首先判断是否能够操作达到目的:即所有的数都相等。 不能达到有两种情况: 1:所有数之和对n取余不等于0 2: 每个ai都是小于i的,例如n5, a[]{0,1,2,3,4}。因为每个数都是小于 i 的&am…

idea中的 debug 中小功能按钮都代表的意思

1.step over 步过----->一行一行的往下走,如果这一行中有方法那么不会进入该方法,直接一行一行往下走,除非你在该方法中打入断点 2.step into 步入—>可以进入方法内部,但是只能进入自己写的方法内部,而不会进入方法的类库中 3.Force step into 强制步入---->可以步…

编译livox ros driver2(ROS2、livox、rviz、ubuntu22.04)

1. 编译Livox-SDK2 官方地址:https://github.com/Livox-SDK/Livox-SDK2 执行一下命令: git clone https://github.com/Livox-SDK/Livox-SDK2.git cd ./Livox-SDK2/ mkdir build cd build cmake .. && make sudo make install 如上就安装完成了…

嵌入式【CPLD】5M570ZT100C5N、5M1270ZF256C5N、5M2210ZF256C5N采用独特的非易失性架构,低成本应用设计。

英特尔MAX V CPLD 采用独特的非易失性架构,提供低功耗片上功能,适用于以边缘为中心的应用。MAX V CPLD系列能够在单位空间中提供大量 I/O 和逻辑。这些设备还使用了低成本绿色封装技术,封装大小只有 20 毫米。 MAX V系列的关键应用包括&…

PCL点云库(1) - 简介与数据类型

目录 1.1 简介 1.2 PCL安装 1.2.1 安装方法 1.2.2 测试程序 1.3 PCL数据类型 1.4 PCL中自定义point类型 1.4.1 增加自定义point的步骤 1.4.2 完整代码 1.1 简介 来源:PCL(点云库)_百度百科 PCL(Point Cloud Library&…

个推打造消息推送专项运营提升方案,数据驱动APP触达效果升级

“数智化运营”能力已成为企业的核心竞争力之一。借助数据和算法,构建完善的数智化运营体系,企业可增强用户洞察和科学决策能力,提高日常运营效率和投入产出比。近半年,个推精准把握行业客户的切实需求,将“数智化运营…

分析型数据库:MPP 数据库的概念、技术架构与未来发展方向

随着企业数据量的增多,为了配合企业的业务分析、商业智能等应用场景,从而驱动数据化的商业决策,分析型数据库诞生了。由于数据分析一般涉及的数据量大,计算复杂,分析型数据库一般都是采用大规模并行计算或者分布式计算…

css的属性选择器

文章目录 属性选择器的原理简单的语法介绍子串值(Substring value)属性选择器 CSS 属性选择器的最基本用法复杂一点的用法层叠选择多条件复合选择伪正则写法配合 **:not()** 伪类重写行内样式 组合用法:搭配伪元素提升用户体验角标功能属性选…

基于51单片机的智能晾衣架的设计与实现(源码+论文)_kaic

【摘要】随着社会和市场的变化,我国经济的快速发展和房地产行业的快速扩张,使得装修家居行业飞速发展,在行业高速发展的同时,消费者家居智能化要求也在日益发展。随着科学技术的进步和电子技术的发展,单片机作为智能控…

Stable Diffusion一键安装器,只有2.3M

最近AI画图真的是太火了,但是Midjourney收费之后大家就不知道去哪里能用AI画图了, Stable Diffusion很多人听过,但是安装特别麻烦。所以为大家介绍一款软件,一键安装SD。 Stable Diffusion一键安装器_SD一键启动器-Stable Diffus…

LeetCode:459. 重复的子字符串 —【2、KMP算法】

🍎道阻且长,行则将至。🍓 🌻算法,不如说它是一种思考方式🍀 算法专栏: 👉🏻123 一、🌱459. 重复的子字符串 题目描述:给定一个非空的字符串 s &…

Docker数据管理与Docker镜像的创建

目录 1.管理数据的方式 1.数据卷 2.数据卷容器 3.容器互联(使用centos镜像) 2.Docker镜像的创建 1.基于现有镜像创建 2.基于本地模板创建 3.基于Dockerfile创建 4.Dockerfile案例 总结 1.管理数据的方式 容器中管理数据主要有两种方式&#xff1…

c++作业

自己定义mystring类实现string功能 #include <iostream> #include<cstring> using namespace std;class myString {private:char *str; //记录c风格的字符串int size; //记录字符串的实际长度public://无参构造myString():size(10){str new …

tomcat服务搭建

系列文章目录 文章目录 系列文章目录一、Tomcat1.核心功能 二、Tomcat服务搭建1.Tomcat服务2.Tomcat 虚拟主机配置1.创建 kgc 和 benet 项目目录和文件2.修改 Tomcat 主配置文件 server.xml3.客户端浏览器访问验证 三、Tomcat多实例部署 一、Tomcat 1.核心功能 1.connector&a…

Spring Bean生命周期源码之包扫描、创建BeanDefinition、合并BeanDefinition源码

文章目录 Bean生命周期源码生成BeanDefinitionSpring容器启动时创建单例Bean合并BeanDefinition Bean生命周期源码 我们创建一个ApplicationContext对象时&#xff0c;这其中主要会做两件时间&#xff1a;包扫描得到BeanDefinition的set集合&#xff0c;创建非懒加载的单例Bea…

体验ChatGPT在具体应用场景下的能力与表现——vuedraggable的move多次触发问题

当下人工智能模型在满天飞&#xff0c;今天拿一个具体的应用场景&#xff0c;来体验下ChatGPT的能力与表现&#xff0c;看看是否能解决实际问题。 顺便填一下之前遇到的一个具体的坑&#xff1a;vuedraggable的move多次触发问题。 背景 背景是这样的&#xff0c;实现低代码开…

Hadoop启动相关命令

Hadoop启动相关配置 文章目录 Hadoop启动相关配置格式化节点的情况什么情况下Hadoop需要进行格式化节点&#xff1f; Hadoop启动步骤Hadoop的启动步骤只是start-dfs.sh即可吗 *hdfs*的web管理页面参数说明参数的评价场景 格式化节点的情况 什么情况下Hadoop需要进行格式化节点…

赛效:怎么用改图鸭进行一键Logo设计?

改图鸭工具是一款在线图像处理工具&#xff0c;可以对图片进行大小调整、添加色彩、滤镜等&#xff0c;用户使用改图鸭可快速轻松地对多种图像进行处理操作&#xff0c;另外&#xff0c;改图鸭工具还支持一键进行Logo设计&#xff0c;很多人对改图鸭工具比较陌生&#xff0c;不…