PyTorch搭建RNN联合嵌入模型(LSTM GRU)实现视觉问答(VQA)实战(超详细 附数据集和源码)

news2024/12/23 23:24:53

需要源码和数据集请点赞关注收藏后评论区留言私信~~~

一、视觉问题简介

视觉问答(VQA)是一种同时设计计算机视觉和自然语言处理的学习任务。简单来说,VQA就是对给定的图片进行问答,一个VQA系统以一张图片和一个关于这张图片形式自由,开放式的自然语言问题作为输入,生成一条自然语言答案作为输出,视觉问题系统综合运用到了目前的计算机视觉和自然语言处理的技术,并设计模型设计,实验,以及可视化。

VQA问题的一种典型模型是联合嵌入模型,这种方法首先学习视觉与自然语言的两个不同模态特征在一个共同的特征空间的嵌入表示,然后根据这种嵌入表示产生回答。

二、数据集的准备

1:下载数据

这里使用VQA2.0数据集进行训练和验证,VQA2.0是一个公认有难度,并且语言验证得到了有效控制的数据集

本次使用到的图片为MSCOCO数据集中train2014子集和val2014子集,图片可以在官网下载

数据集网址

本次用到的图像特征是由目标检测网络Faster-RCNN检测并生成的,可评论区留言私信博主要

2:安装依赖

确保安装好PyTorch,然后在程序目录下运行pip install -r requirements.txt安装其他依赖项

三、关键模块简介

1:FCnet模块

FCnet即一系列的全连接层,各个层的输入输出大小在模块构建时给出,这个模块默认使其中的全连接层具有bias,并以ReLU作为激活函数 并使用weight normalization

2:SimpleClassifier模块

它的作用是:在视觉问答系统的末端,根据融合的特征得到最终答案

3:问题嵌入模块

在联合嵌入模型中,需要使用RNN将输入的问题编码成向量,LSTM和GRU使两种代表性的RNN,由于实践中GRU与LSTM表现相近且占用显存较少,所以这里选用GRU

4:词嵌入

要获得问题句子的嵌入表示,首先应该获得词嵌入表示,每一个词需要用一个唯一的数字表示

baseline代码如下

import torch
import torch.nn as nn
from lib.module import topdown_attention
from lib.module.language_model import WordEmbedding, QuestionEmbedding
from lib.module.classifier import SimpleClassifier
from lib.module.fc import FCNet

class Baseline(nn.Module):

    def __init__(self, w_emb, q_emb, v_att, q_net, v_net, classifer, need_internals=False):
        super(Baseline, self).__init__()

        self.need_internals = need_internals

        self.w_emb = w_emb
        self.q_emb = q_emb
        self.v_att = v_att
        self.q_net = q_net
        self.v_net = v_net
        self.classifier = classifer

    def forward(self, q_tokens, ent_features):

        w_emb = self.w_emb(q_tokens)
        q_emb = self.q_emb(w_emb)

        att = self.v_att(q_emb, ent_features) # [ B, n_ent, 1 ]
        v_emb = (att * ent_features).sum(1)  # [ B, hid_dim ]

        internals = [att.squeeze()] if self.need_internals else None

        q_repr = self.q_net(q_emb)
        v_repr = self.v_net(v_emb)
        joint_repr = q_repr * v_repr
        logits = self.classifier(joint_repr)

        return logits, internals

    @classmethod
    def build_from_config(cls, cfg, dataset, need_internals):
        w_emb = WordEmbedding(dataset.word_dict.n_tokens, cfg.lm.word_emb_dim, 0.0)
        q_emb = QuestionEmbedding(cfg.lm.word_emb_dim, cfg.hid_dim, cfg.lm.n_layers, cfg.lm.bidirectional, cfg.lm.dropout, cfg.lm.rnn_type)
        q_dim = cfg.hid_dim
        att_cls = topdown_attention.classes[cfg.topdown_att.type]
        v_att = att_cls(1, q_dim, cfg.ent_dim, cfg.topdown_att.hid_dim, cfg.topdown_att.dropout)
        q_net = FCNet([q_dim, cfg.hid_dim])
        v_net = FCNet([cfg.ent_dim, cfg.hid_dim])
        classifier = SimpleClassifier(cfg.hid_dim, cfg.mlp.hid_dim, dataset.ans_dict.n_tokens, cfg.mlp.dropout)
        return cls(w_emb, q_emb, v_att, q_net, v_net, classifier, need_internals)

数据集目录如下

四、结果可视化

读取了之前训练好的模型之后,使用数据为配置文件中的val,程序运行完成后结果可视化如下

机器对于给出的图片会输出对于的问答结果

 

 五、代码

部分代码如下

训练类

import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import LambdaLR
from torch.nn.utils import clip_grad_norm_
from bisect import bisect
from tqdm import tqdm

def bce_with_logits(logits, labels):
    assert logits.dim() == 2
    loss = F.binary_cross_entropy_with_logits(logits, labels)
    loss *= labels.size(1) # multiply by number of QAs
    return loss

def sce_with_logits(logits, labels):
    assert logits.dim() == 2
    loss = F.cross_entropy(logits, labels.nonzero()[:, 1])
    loss *= labels.size(1)
    return loss

def compute_score_with_logits(logits, labels):
    with torch.no_grad():
        logits = torch.max(logits, 1)[1] # argmax
        one_hots = torch.zeros(*labels.size()).cuda()
        one_hots.scatter_(1, logits.view(-1, 1), 1)
        scores = (one_hots * labels)
        return scores

def lr_schedule_func_builder(cfg):
    def func(step_idx):
        if step_idx <= cfg.train.warmup_steps:
            alpha = float(step_idx) / float(cfg.train.warmup_steps)
            return cfg.train.warmup_factor * (1. - alpha) + alpha
        else:
            idx = bisect(cfg.train.lr_steps, step_idx)
            return pow(cfg.train.lr_ratio, idx)
    return func

def train(model, cfg, train_loader, val_loader, n_epochs, val_freq, out_dir):

    os.makedirs(out_dir, exist_ok=True)
    optim = torch.optim.Adamax(model.parameters(), **cfg.train.optim)

    n_train_batches = len(train_loader)

    train_score = 0.0
    loss_fn = bce_with_logits if cfg.model.loss == "logistic" else sce_with_logits

    for epoch in range(n_epochs):

        epoch_loss = 0.0

        tic_0 = time.time()

        for i, data in enumerate(train_loader):

            tic_1 = time.time()

            q_tokens = data[2].cuda()
            a_targets = data[3].cuda()
            v_features = [_.cuda() for _ in data[4:]]

            tic_2 = time.time()

            optim.zero_grad()
            logits, _ = model(q_tokens, *v_features)
            loss = loss_fn(logits, a_targets)

            tic_3 = time.time()

            loss.backward()
            if cfg.train.clip_grad: clip_grad_norm_(model.parameters(), cfg.train.max_grad_norm)
            optim.step()

            tic_4 = time.time()

            batch_score = compute_score_with_logits(logits, a_targets).sum()
            epoch_loss += float(loss.data.item() * logits.size(0))
            train_score += float(batch_score)

            del loss

            logstr = "epoch %2d batch %4d/%4d | ^ %4dms | => %4dms | <= %4dms" % \
                  (epoch + 1, i + 1, n_train_batches, 1000*(tic_2-tic_0), 1000*(tic_3-tic_2), 1000*(tic_4-tic_3))
            print("%-80s" % logstr, end="\r")

            tic_0 = time.time()

        epoch_loss /= len(train_loader.dataset)
        train_score = 100 * train_score / len(train_loader.dataset)

        logstr = "epoch %2d | train_loss: %5.2f train_score: %5.2f" % (epoch + 1, epoch_loss, train_score)
        if (epoch + 1) % val_freq == 0:
            model.eval()
            val_score, upper_bound = validate(model, val_loader)
            model.train()
            logstr += " | val_score: %5.2f (%5.2f)" % (100 * val_score, 100 * upper_bound)
        print("%-80s" % logstr)

        model_path = os.path.join(out_dir, 'model_%d.pth' % (epoch + 1))
        torch.save(model.state_dict(), model_path)

def validate(model, loader):

    score = 0
    upper_bound = 0
    n_qas = 0

    with torch.no_grad():

        for i, data in enumerate(loader):

            q_tokens = data[2].cuda()
            a_targets = data[3].cuda()
            v_features = [_.cuda() for _ in data[4:]]

            logits, _ = model(q_tokens, *v_features)
            batch_score = compute_score_with_logits(logits, a_targets)
            score += batch_score.sum()

            upper_bound += (a_targets.max(1)[0]).sum()
            n_qas += logits.size(0)

            logstr = "val batch %5d/%5d" % (i + 1, len(loader))
            print("%-80s" % logstr, end='\r')

    score = score / n_qas
    upper_bound = upper_bound / n_qas
    return score, upper_bound

infer类

import os
import time
import json
import torch
import cv2
import shutil
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

colors = [ (175, 84, 65), (68, 194, 246), (136, 147, 65), (92, 192, 151) ]

def attention_map(im, boxes, atts, p=0.8, bgc=1.0, compress=0.85, box_color=(65, 81, 226)):

    height, width, channel = im.shape
    im = im / 255.0

    att_map = np.zeros([height, width])
    boxes = boxes.astype(np.int)
    for box, att in zip(boxes, atts):
        x1, y1, x2, y2 = box
        roi = att_map[y1:y2, x1:x2]
        roi[roi < att] = att

    att_map /= att_map.max()
    att_map = att_map ** p
    att_map = att_map * compress + (1-compress)

    att_map = cv2.resize(att_map, (int(width/16), int(height/16)))
    att_map = cv2.resize(att_map, (width, height))

    att_map = np.expand_dims(att_map, axis=2)

    bg = np.ones_like(att_map) * bgc
    att_im = im * att_map + bg * (1-att_map)

    att_im = (att_im * 255).astype(np.uint8)

    center = np.argmax(atts)
    x1, y1, x2, y2 = boxes[center]
    cv2.rectangle(att_im, (x1, y1), (x2, y2), box_color, 5)

    return att_im

def infer_visualize(model, args, cfg, ans_dict, loader):

    _, ckpt = os.path.split(args.checkpoint)
    ckpt, _ = os.path.splitext(ckpt)
    out_dir = os.path.join(args.out_dir, "%s_%s_%s_visualization" % (args.cfg_name, ckpt, args.data))
    os.makedirs(out_dir, exist_ok=True)
    model.eval()

    questions_path = cfg.data[args.data].composition[0].q_jsons[0]
    questions = json.load(open(questions_path))

    pbar = tqdm(total=args.n_batches * loader.batch_size)

    with torch.no_grad():

        for i, data in enumerate(loader):

            if i == args.n_batches: break

            question_ids = data[0]
            image_ids = data[1]
            q_tokens = data[2].cuda()
            obj_featuers = data[4].cuda()
            batch_boxes = data[5].numpy()

            logits, internals = model(q_tokens, obj_featuers)
            topdown_atts = internals[0]
            topdown_atts = topdown_atts.data.cpu().numpy()
            _, predictions = logits.max(dim=1)

            for idx in range(len(question_ids)):

                question_id = question_ids[idx]
                image_id = image_ids[idx]
                boxes = batch_boxes[idx]
                answer = ans_dict.idx2ans[predictions[idx]]
                q_entry = questions[question_id]
                topdown_att = topdown_atts[idx]
                question = q_entry["question"]
                gts = list(q_entry["answers"].items())
                gts = sorted(gts, reverse=True, key=lambda x: x[1])
                gt = gts[0][0]

                q_out_dir = os.path.join(out_dir, question_id)
                os.makedirs(q_out_dir, exist_ok=True)

                q_str = question + "\n" + "gt: %s\n" % gt + "answer: %s\n" % answer
                with open(os.path.join(q_out_dir, "qa.txt"), "w") as f: f.write(q_str)

                image_path = os.path.join(args.images_dir, "%s.jpg" % image_id)
                shutil.copy(image_path, os.path.join(q_out_dir, "original.jpg"))

                im = cv2.imread(image_path)
                att_map = attention_map(im.copy(), boxes, topdown_att)
                cv2.imwrite(os.path.join(q_out_dir, "topdown_att.jpg"), att_map)

                pbar.update(1)

创作不易 觉得有帮助请点赞关注收藏~~~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/32275.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

用HTML+CSS+JS写的切水果小游戏它来了

前言 切水果游戏曾经是一款风靡手机的休闲游戏&#xff0c;今天要分享的就是一款网页版的切水果游戏&#xff0c; 由HTMLCSSJS实现&#xff0c;虽然功能和原版的相差太大&#xff0c;但基本的功能具备&#xff0c;效果逼真。感兴趣的小伙伴可收藏学习&#xff08;完整源码在文…

Heterogeneous Parallel Programming 异构并行编程 - UIUC伊利诺伊大学(持续更新)

Lecture 11.2 Introduction to Heterogeneous异构1.3 Portability and Scalability1.4 Introduction to CUDA 数据并行化 and 执行模型1.5 Introduction to CUDA 内存模型 and 基本函数API1.6 Introduction to CUDA Kernel-based SPMD1.7 更高维的Grid的Kernel-based SPMD例子1…

Linux的基本协议与他的堂兄堂弟

14天学习训练营导师课程&#xff1a; 互联网老辛《 符合学习规律的超详细linux实战快速入门》 努力是为了不平庸~ 学习有些时候是枯燥的&#xff0c;但收获的快乐是加倍的&#xff0c;欢迎记录下你的那些努力时刻&#xff08;学习知识点/题解/项目实操/遇到的bug/等等&#xf…

教程一 在Go使用JavaScript、HTML和CSS构建Windows、Linux、MacOSX跨平台的桌面应用

Energy是Go语言使用JavaScript、HTML和CSS构建跨平台的桌面应用程序可用于构建跨平台的桌面应用内嵌 Chromium CEF 二进制 环境安装 Energy 命令行工具 使用命令行工具自动安装Energy框架的所有依赖(CEF)&#xff0c;支持Window、Linux、MacOSX 安装过程从网络下载CEF和Energy…

二、vue基础入门

一、vue简介 1.1、什么是vue 官方给出的概念&#xff1a;Vue (读音 /vjuː/&#xff0c;类似于 view) 是一套用于构建用户界面的前端框架。 1.2、vue的特性 vue框架的特性&#xff0c;主要体现在如下两方面&#xff1a; 数据驱动视图双向数据绑定 1.2.1、数据驱动视图 在…

高灵敏度艾美捷小鼠肿瘤坏死因子α-ELISpot试剂盒

肿瘤坏死因子-a&#xff08;TNF-a&#xff09;由许多不同的细胞类型产生&#xff0c;例如单核细胞&#xff0c;巨噬细胞&#xff0c;T细胞和B细胞。在TNF-a的许多作用中&#xff0c;有针对细菌感染&#xff0c;细胞生长调节&#xff0c;免疫系统调节和参与败血症性休克的保护。…

现代气象仪器 | 太阳辐射测量

南京信息工程大学 实验&#xff08;实习&#xff09;报告 实验&#xff08;实习&#xff09;名称 现代气象仪器 实验&#xff08;实习&#xff09;日期 10.28 得分 指导老师 学院 电信院 专业 电子信息工程 年级 2020 班次 4 姓名 学号 20208327 实验…

万字博客带你全面剖析Spring的依赖注入

1.写在前面 前面的博客我们已经写了Spring的依赖查找&#xff0c;这篇博客我们来了解写Spring的依赖注入。 2.依赖注入的模式和类型 手动模式 - 配置或者编程的方式&#xff0c; 提前安排注入规则 XML 资源配置元信息Java 注解配置元信息API 配置元信息 自动模式 - 实现方…

华为机试 - 最大括号深度

目录 题目描述 输入描述 输出描述 用例 题目解析 算法源码 题目描述 现有一字符串仅由 ‘(‘&#xff0c;’)’&#xff0c;{‘&#xff0c;’}’&#xff0c;[‘&#xff0c;’]’六种括号组成。 若字符串满足以下条件之一&#xff0c;则为无效字符串&#xff1a; ①…

【MySQL】拿来即用 —— MySQL中的数据类型

个人简介&#xff1a;Java领域新星创作者&#xff1b;阿里云技术博主、星级博主、专家博主&#xff1b;正在Java学习的路上摸爬滚打&#xff0c;记录学习的过程~ 个人主页&#xff1a;.29.的博客 学习社区&#xff1a;进去逛一逛~ MySQL数据类型⚪熟悉SQL一、MySQL数据类型总结…

设备树和设备树语法

设备树 驱动代码只负责处理驱动的逻辑&#xff0c;而关于设备的具体信息存放到设备树文件中。许多硬件设备信息可以直 接通过它传递给 Linux&#xff0c;而不需要在内核中堆积大量的冗余代码。 设备树&#xff0c;将这个词分开就是“设备”和“树”&#xff0c;描述设备树的文…

【计算机毕业设计】22.毕业设计选题系统ssm源码

一、系统截图&#xff08;需要演示视频可以私聊&#xff09; 引言 近年来&#xff0c;电子商务发展的愈趋成熟使得人们的消费方式以及消费观念发生巨大改变&#xff0c;网上竞拍的拍卖模式随之发展起来。大学拍卖网旨在为湘大学生提供一个线上拍卖的交易平台。平台展示的商品大…

【American English】美式发音,英语发音,美国音音标列表及发音

首先声明&#xff0c;网上各种英式发音和美式发音的教程&#xff0c;而我的目的是寻找美式发音。但是自己现在也是在不断地找寻中&#xff0c;所以资料找错了请莫怪。另外&#xff0c;资料顺序采用部分倒叙&#xff0c;不喜请勿吐槽。 文章目录发音示意图49. [](https://www.bi…

百度地图有感

以前总认为坚持会让我们变强大&#xff0c;但是长大后发现&#xff0c;让我们强大的&#xff0c;是放下。 生活也许就是这样&#xff0c;多一分经验便少一分幻想&#xff0c;以实际的愉快平衡现实的痛苦。 百度地图开放平台 百度地图入门指南 百度地图开发指南 百度地图API文…

性早熟和微生物群:性激素-肠道菌群轴的作用

谷禾健康 肠道菌群 & 性激素 青春期是生命的一个关键阶段&#xff0c;与性成熟相关的生理变化有关&#xff0c;是一个受多种内分泌和遗传控制调控的复杂过程。 青春期发育可以在适当的时候&#xff0c;早熟或延迟。 未经治疗的性早熟的孩子通常不会达到成年身高的全部潜力。…

Activity的最佳实践

文章目录Activity的最佳实践知晓当前是在哪一个Activiy随时随地退出程序启动Activity的最佳写法Activity的最佳实践 知晓当前是在哪一个Activiy 创建一个BaseActivity类,继承AppCompatActivity类.重写onCreate方法 open class BaseActivity : AppCompatActivity() {override…

xilinx PL测 DP 点屏 /接收(二)--RX

环境&#xff1a; a)硬件&#xff1a;官方ZCU106开发板 , tb-fmch-vfmc-dp子卡。 b)软件&#xff1a;vivado2021.1&#xff0c;vitis2021.1&#xff0c;裸机程序。 1、官方例程&#xff1a; 2、DP RX IP &#xff1a; 3、DP RX寄存器&#xff1a; 4、时钟&#xff1a; 5、像素&…

CentOS 6.6系统怎么安装?CentOS Linux系统安装配置图解教程

服务器相关设置如下&#xff1a; 操作系统&#xff1a;CentOS 6.6 64位 IP地址&#xff1a;192.168.21.129 网关&#xff1a;192.168.21.2 DNS&#xff1a;8.8.8.8 8.8.4.4 备注&#xff1a; CentOS 6.6系统镜像有32位和64位两个版本&#xff0c;并且还有专门针对服务器优化过的…

【端到端存储解决方案】Weka,让企业【文件存储】速度飞起来!

一、HK-Weka概述 虹科WekaIO&#xff08;简称HK-Weka&#xff09;是一个可共享、可扩展的文件存储系统解决方案&#xff0c;其并行文件系统WekaFS支持NVMeoF的flash-native并行文件系统、比传统的NAS存储及本地存储更快。 HK-Weka后端主机被配置为集群&#xff0c;它与安装在应…

在Mysql中新建序列Sequence

在Oracle数据库中想要一个连续的自增数据类型的值&#xff0c;可以通过创建一个sequence来实现。而在Mysql数据库中并没有sequence&#xff0c;如想要在Mysql中像Oracle那样使用序列&#xff0c;该如何操作呢&#xff1f;&#xff08;可以使用mysql中的自增主键&#xff09; 1、…