【Image captioning】ruotianluo/self-critical.pytorch之1—数据集的加载与使用
作者:安静到无声 个人主页
数据加载程序示意图
使用方法
示例代码
#%%
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import time
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0' ##from six.moves import cPickle
import traceback
from collections import defaultdict
import captioning.utils.opts as opts
import captioning.models as models
from captioning.data.dataloader import *
import skimage.io
import captioning.utils.eval_utils as eval_utils
import captioning.utils.misc as utils
from captioning.utils.rewards import init_scorer, get_self_critical_reward
from captioning.modules.loss_wrapper import LossWrapper
import sys
sys.path.append("..")
import time
#%%
opt = opts.parse_opt()
opt.input_json = '/home/lihuanyu/code/07ImageCaptioning/data/cocotalk.json'
opt.input_label_h5 = '/home/lihuanyu/code/07ImageCaptioning/data/cocotalk_label.h5'
opt.input_fc_dir = '/home/lihuanyu/code/07ImageCaptioning/data/cocotalk_fc'
opt.input_att_dir = '/home/lihuanyu/code/07ImageCaptioning/data/cocotalk_att'
opt.batch_size = 1
opt.train_only = 1
opt.use_att = True
opt.use_att = True
opt.use_box = 0
#%%
print(opt.input_json)
print(opt.batch_size) #批量化为16
loader = DataLoader(opt) # 数据加载
#打印字内容
#print(loader.get_vocab()) #返回字典
for i in range(2):
data = loader.get_batch('train')
print('———————————————————※输入的信息特征※——————————————————') #[1,2048] 全连接特征
print('全连接特征【fc_feats】的形状:', data['fc_feats'].shape) #[1,2048] 全连接特征
print('全连接特征【att_feats】的形状:', data['att_feats'].shape) #[1,2048] 注意力特征
print('att_masks', data['att_masks'])
print('含有的信息infos:', data['infos']) #infos [{'ix': 117986, 'id': 495956, 'file_path': 'train2014/COCO_train2014_000000495956.jpg'}]
print('———————————————————※标签信息※——————————————————') #[1,2048] 全连接特征
print('labels', data['labels']) #添加了一些0
print('gts:', data['gts']) #没有添加的原始句子
print('masks', data['masks'])
print('———————————————————※记录遍历的位置※——————————————————') #[1,2048] 全连接特征
print('bounds', data['bounds'])
time.sleep(1)
print(data.keys())
输出结果:
Hugginface transformers not installed; please visit https://github.com/huggingface/transformers
meshed-memory-transformer not installed; please run `pip install git+https://github.com/ruotianluo/meshed-memory-transformer.git`
Warning: coco-caption not available
cider or coco-caption missing
/home/lihuanyu/code/07ImageCaptioning/data/cocotalk.json
1
是否使用【注意力特征[use_fc]】: True
是否使用【注意力特征[use_att]】: True
是否在注意力特征中使用【检测框特征[use_box]】: 0
DataLoader loading json file: /home/lihuanyu/code/07ImageCaptioning/data/cocotalk.json
vocab size is 9487
DataLoader loading h5 file: /home/lihuanyu/code/07ImageCaptioning/data/cocotalk_fc /home/lihuanyu/code/07ImageCaptioning/data/cocotalk_att data/cocotalk_box /home/lihuanyu/code/07ImageCaptioning/data/cocotalk_label.h5
max sequence length in data is 16
read 123287 image features
assigned 82783 images to split train(训练集有多少图片)
assigned 5000 images to split val(验证集有多少图片)
assigned 5000 images to split test(测试集有多少图片)
———————————————————※输入的信息特征※——————————————————
全连接特征【fc_feats】的形状: torch.Size([1, 2048])
全连接特征【att_feats】的形状: torch.Size([1, 196, 2048])
att_masks None
含有的信息infos: [{'ix': 60494, 'id': 46065, 'file_path': 'train2014/COCO_train2014_000000046065.jpg'}]
———————————————————※标签信息※——————————————————
labels tensor([[[ 0, 1, 271, 17, 7068, 35, 98, 6, 1, 102, 3,
912, 0, 0, 0, 0, 0, 0],
[ 0, 995, 2309, 2308, 609, 6, 1, 271, 119, 912, 0,
0, 0, 0, 0, 0, 0, 0],
[ 0, 2309, 9487, 179, 98, 6, 1, 46, 271, 0, 0,
0, 0, 0, 0, 0, 0, 0],
[ 0, 182, 35, 995, 7068, 6, 1, 271, 3, 60, 678,
32, 14, 29, 0, 0, 0, 0],
[ 0, 995, 915, 17, 2309, 3130, 6, 1, 46, 271, 0,
0, 0, 0, 0, 0, 0, 0]]])
gts: [array([[ 1, 271, 17, 7068, 35, 98, 6, 1, 102, 3, 912,
0, 0, 0, 0, 0],
[ 995, 2309, 2308, 609, 6, 1, 271, 119, 912, 0, 0,
0, 0, 0, 0, 0],
[2309, 9487, 179, 98, 6, 1, 46, 271, 0, 0, 0,
0, 0, 0, 0, 0],
[ 182, 35, 995, 7068, 6, 1, 271, 3, 60, 678, 32,
14, 29, 0, 0, 0],
[ 995, 915, 17, 2309, 3130, 6, 1, 46, 271, 0, 0,
0, 0, 0, 0, 0]], dtype=uint32)]
masks tensor([[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0.,
0.],
[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.,
0.],
[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.,
0.],
[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0.,
0.],
[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.,
0.]]])
———————————————————※记录遍历的位置※——————————————————
bounds {'it_pos_now': 1, 'it_max': 82783, 'wrapped': False}
dict_keys(['fc_feats', 'att_feats', 'att_masks', 'labels', 'masks', 'gts', 'bounds', 'infos'])
———————————————————※输入的信息特征※——————————————————
全连接特征【fc_feats】的形状: torch.Size([1, 2048])
全连接特征【att_feats】的形状: torch.Size([1, 196, 2048])
att_masks None
含有的信息infos: [{'ix': 106440, 'id': 151264, 'file_path': 'train2014/COCO_train2014_000000151264.jpg'}]
———————————————————※标签信息※——————————————————
labels tensor([[[ 0, 1, 230, 6, 14, 230, 237, 32, 1086, 627, 0,
0, 0, 0, 0, 0, 0, 0],
[ 0, 1, 6035, 230, 35, 274, 127, 225, 1598, 335, 1,
940, 0, 0, 0, 0, 0, 0],
[ 0, 1, 230, 35, 900, 32, 307, 756, 61, 607, 0,
0, 0, 0, 0, 0, 0, 0],
[ 0, 1, 230, 35, 98, 79, 1, 230, 224, 0, 0,
0, 0, 0, 0, 0, 0, 0],
[ 0, 1, 46, 1109, 230, 1596, 245, 1, 224, 0, 0,
0, 0, 0, 0, 0, 0, 0]]])
gts: [array([[ 1, 230, 6, 14, 230, 237, 32, 1086, 627, 0, 0,
0, 0, 0, 0, 0],
[ 1, 6035, 230, 35, 274, 127, 225, 1598, 335, 1, 940,
0, 0, 0, 0, 0],
[ 1, 230, 35, 900, 32, 307, 756, 61, 607, 0, 0,
0, 0, 0, 0, 0],
[ 1, 230, 35, 98, 79, 1, 230, 224, 0, 0, 0,
0, 0, 0, 0, 0],
[ 1, 46, 1109, 230, 1596, 245, 1, 224, 0, 0, 0,
0, 0, 0, 0, 0]], dtype=uint32)]
masks tensor([[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.,
0.],
[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0.,
0.],
[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.,
0.],
[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.,
0.],
[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.,
0.]]])
———————————————————※记录遍历的位置※——————————————————
bounds {'it_pos_now': 2, 'it_max': 82783, 'wrapped': False}
dict_keys(['fc_feats', 'att_feats', 'att_masks', 'labels', 'masks', 'gts', 'bounds', 'infos'])
推荐专栏
🔥 手把手实现Image captioning
💯CNN模型压缩
💖模式识别与人工智能(程序与算法)
🔥FPGA—Verilog与Hls学习与实践
💯基于Pytorch的自然语言处理入门与实践