机器学习-基于attention机制来实现对Image Caption图像描述实验

news2024/12/22 4:44:35

机器学习-基于attention机制来实现对Image Caption图像描述实验

实验目的

基于attention机制来实现对Image Caption图像描述

实验内容

1.了解一下RNN的Encoder-Decoder结构

在最原始的RNN结构中,输入序列和输出序列必须是严格等长的。但在机器翻译等任务中,源语言句子的长度和目标语言句子的长度往往不同,因此我们需要将原始序列映射为一个不同长度的序列。Encoder-Decoder模型就解决了这样一个长度不一致的映射问题。

1

2.模型架构训练

在Image Caption输入的图像代替了之前机器翻译中的输入的单词序列,图像是一系列的像素值,我们需要从使用图像特征提取常用的CNN从图像中提取出相应的视觉特征,然后使用Decoder将该特征解码成输出序列,下图是论文的网络结构,特征提取采用的是CNN,Decoder部分,将RNN换成了性能更好的LSTM,输入还是word embedding,每步的输出是单词表中所有单词的概率。

实验数据和程序清单

import json
 
# 加载数据集标注
with open("annotations/captions_train2014.json", "r") as f:
    annotations = json.load(f)
 
# 提取图像文件名和描述
image_path_to_caption = {}
for val in annotations["annotations"]:
    caption = f"<start> {val['caption']} <end>"
    image_path = "train2014/" + "COCO_train2014_" + "%012d.jpg" % (val["image_id"])
    if image_path in image_path_to_caption:
        image_path_to_caption[image_path].append(caption)
    else:
        image_path_to_caption[image_path] = [caption]
 
image_paths = list(image_path_to_caption.keys())
归一化处理
import tensorflow as tf
 
def load_image(image_path):
    img = tf.io.read_file(image_path)
    img = tf.image.decode_jpeg(img, channels=3)
    img = tf.image.resize(img, (299, 299))
    img = tf.keras.applications.inception_v3.preprocess_input(img)
return img, image_path
模型构建
from tensorflow.keras.applications import InceptionV3
 
encoder = InceptionV3(weights="imagenet", include_top=False)
from tensorflow.keras.layers import Input, Embedding, LSTM, Dense
from tensorflow.keras.models import Model
 
embedding_dim = 256
vocab_size = 10000  # 您可以根据需要调整词汇表大小
max_length = 40  # 您可以根据需要调整最大描述长度
 
# 解码器输入
input_caption = Input(shape=(max_length,))
embedding = Embedding(vocab_size, embedding_dim)(input_caption)
lstm_output = LSTM(256)(embedding)
output_caption = Dense(vocab_size, activation="softmax")(lstm_output)
 
# 定义解码器模型
decoder = Model(inputs=input_caption, outputs=output_caption)
模型训练
optimizer = tf.keras.optimizers.Adam()
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction="none")
 
def loss_function(real, pred):
    mask = tf.math.logical_not(tf.math.equal(real, 0))
    loss_ = loss_object(real, pred)
 
    mask = tf.cast(mask, dtype=loss_.dtype)
    loss_ *= mask
 
return tf.reduce_mean(loss_)
@tf.function
def train_step(img_tensor, target):
    loss = 0
    hidden = decoder.reset_state(batch_size=target.shape[0])
 
    dec_input = tf.expand_dims([tokenizer.word_index["<start>"]] * target.shape[0], 1)
 
    with tf.GradientTape() as tape:
        features = encoder(img_tensor)
        for i in range(1, target.shape[1]):
            predictions = decoder([features, hidden, dec_input])
            loss += loss_function(target[:, i], predictions)
 
            dec_input = tf.expand_dims(target[:, i], 1)
 
    total_loss = loss / int(target.shape[1])
 
    trainable_variables = encoder.trainable_variables + decoder.trainable_variables
    gradients = tape.gradient(loss, trainable_variables)
    optimizer.apply_gradients(zip(gradients, trainable_variables))
 
return loss, total_loss
import time
 
epochs = 10
batch_size = 64
buffer_size = 1000
 
dataset = tf.data.Dataset.from_tensor_slices((image_paths, captions))
dataset = dataset.map(load_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.shuffle(buffer_size).batch(batch_size)
dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
 
for epoch in range(epochs):
    start = time.time()
    total_loss = 0
 
    for (batch, (img_tensor, target)) in enumerate(dataset):
        batch_loss, t_loss = train_step(img_tensor, target)
        total_loss += t_loss
 
        if batch % 100 == 0:
            print(f"Epoch {epoch+1} Batch {batch} Loss {batch_loss.numpy() / int(target.shape[1]):.4f}")
 
    print(f"Epoch {epoch+1} Loss {total_loss/len(image_paths):.6f}")
    print(f"Time taken for 1 epoch: {time.time() - start:.2f} sec\n")

可视化:


import matplotlib.pyplot as plt
import numpy as np

def plot_attention(image_path, result, attention_plot):
   img = plt.imread(image_path)

   fig = plt.figure(figsize=(10, 10))
   len_result = len(result)
   for i in range(len_result):
       temp_att = np.resize(attention_plot[i], (8, 8))
       grid_size = max(np.ceil(len_result / 2), 2)
       ax = fig.add_subplot(grid_size, grid_size, i + 1)
       ax.set_title(result[i])
       imgplot = ax.imshow(img)
       ax.imshow(temp_att, cmap="gray", alpha=0.6, extent=imgplot.get_extent())

   plt.tight_layout()
   plt.show()

plot_attention(image_path, result, attention_plot)


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1351108.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

你真的会用Pycharm?这本耗时2年编写的《Pycharm中文指南》,解决你的困惑!

很多读者应该非常了解 JetBrains 开发的 PyCharm 了&#xff0c;它差不多是 Python 最常用的 IDE之一。PyCharm的优势在于可以为我们节省大量时间、管理代码&#xff0c;并完成大量其他任务&#xff0c;如 debug 和可视化等。 需要最新专业版PyCharm永久使用权限的扫码获取 那…

华为高级Java面试真题

今年IT寒冬&#xff0c;大厂都裁员或者准备裁员&#xff0c;作为开猿节流主要目标之一&#xff0c;我们更应该时刻保持竞争力。为了抱团取暖&#xff0c;林老师开通了《知识星球》&#xff0c;并邀请我阿里、快手、腾讯等的朋友加入&#xff0c;分享八股文、项目经验、管理经验…

解决jenkins的Exec command命令不生效,或者执行停不下来的问题

Jenkins构建完后将war包通过 Publish Over SSH 的插件发布到服务器上&#xff0c;在服务器上执行脚本时&#xff0c;脚本中的 nohup 命令无法执行&#xff0c;并不生效&#xff0c;我配置的Exec command命令是后台启动一个war包&#xff0c;并输出日志文件。 nohup java -jar /…

第二十三章 反射

第二十三章 反射 1.反射机制问题2.反射快速入门3.发射原理图4.反射相关类5.发射调用优化6.Class类分析7.Class常用方法8.获取Class对象的6种方式9.哪些类型有Class对象10.动态和静态加载11.类加载流程图12.类加载五个阶段&#xff08;1&#xff09;13.类加载五个阶段&#xff0…

OpenGL FXAA抗锯齿算法(Qt,Consloe版本)

文章目录 一、简介二、实现代码三、实现效果参考资料一、简介 之前已经提供了使用VCG读取Mesh的方式,接下来就需要针对读取的网格数据进行一些渲染操作了。在绘制Mesh数据时总会遇到图形的抗锯齿问题,OpenGL本身已经为我们提供了一种MSAA技术,但该技术对于一些实时渲染性能有…

【代数学作业1完整版-python实现GNFS一般数域筛】构造特定的整系数不可约多项式:涉及素数、模运算和优化问题

代数学作业1-完整版&#xff1a;python实现GNFS一般数域筛 写在最前面背景在GNFS算法中选择互质多项式时&#xff0c;需要考虑哪些关键因素&#xff0c;它们对算法的整体运行时间有何影响? 练习1题目题目分析Kleinjung方法简介通用数域筛法&#xff08;GNFS&#xff09;中的多…

论文阅读--EFFICIENT OFFLINE POLICY OPTIMIZATION WITH A LEARNED MODEL

作者&#xff1a;Zichen Liu, Siyi Li, Wee Sun Lee, Shuicheng YAN, Zhongwen Xu 论文链接&#xff1a;Efficient Offline Policy Optimization with a Learned Model | OpenReview 发表时间&#xff1a; ICLR 2023年1月21日 代码链接&#xff1a;https://github.com/s…

Nginx 代理静态资源,解决跨域问题

&#x1f602; 背景&#xff1a;移动端 H5 项目&#xff0c;依赖了一个外部的 JS 文件。访问时&#xff0c;出现跨域&#xff0c;导致请求被 block。 当前域名&#xff1a;https://tmcopss.test.com要访问的 JS 文件&#xff1a;https://tm.test.com/public/scripts/y-jssdk.j…

下载与安装Python解释器

文章目录 一. 下载Python解释器二. 安装Python解释器总结 一. 下载Python解释器 下载地址&#xff1a;https://www.python.org/downloads/release/python-372/ 查找目标文件&#xff1a;Windows x86-64 executable installer – 单击即可下载。 这里贴出我下载好的网盘链接…

从0到1实战,快速搭建SpringBoot工程

目录 一、前言 二、准备工作 2.1 安装JDK 2.2 安装Maven 2.3 下载IDEA 三、从0到1搭建 3.1 创建SpringBoot工程 3.2 运行SpringBoot工程 四、总结 一、前言 SpringBoot是一个在Spring框架基础上构建的开源框架&#xff0c;不仅继承了Spring框架原有的优秀特性&#x…

软件测试/测试开发丨Python 模块与包

python 模块与包 python 模块 项目目录结构 组成 package包module模块function方法 模块定义 定义 包含python定义和语句的文件.py文件作为脚本运行 导入模块 import 模块名from <模块名> import <方法 | 变量 | 类>from <模块名> import * 注意&a…

JAVA反序列化之URLDNS链分析

简单介绍下urldns链 在此之前最好有如下知识&#xff0c;请自行bing or google学习。 什么是序列化 反序列化 &#xff1f;特点&#xff01; java对象反射调用&#xff1f; hashmap在java中是一种怎样的数据类型&#xff1f; dns解析记录有那…

canvas绘制网格线示例

查看专栏目录 canvas示例教程100专栏&#xff0c;提供canvas的基础知识&#xff0c;高级动画&#xff0c;相关应用扩展等信息。canvas作为html的一部分&#xff0c;是图像图标地图可视化的一个重要的基础&#xff0c;学好了canvas&#xff0c;在其他的一些应用上将会起到非常重…

ArkTS - @Builder自定义构建函数

这个Builder作用就是可以把组件样式抽离出来&#xff0c;写成公共组件&#xff0c;下边记录下全局自定义构建函数用法及注意的地方。 官方文档&#xff1a;开发者可以将重复使用的UI元素抽象成一个方法&#xff0c;在build方法里调用。 一、用法 下边代码&#xff0c;我在Co…

LangChain(0.0.340)官方文档十一:Agents之Agent Types

LangChain官网、LangChain官方文档 、langchain Github、langchain API文档、llm-universe《Agent Types》、《Examples using AgentType》 文章目录 一、快速入门1.1 概念1.2 基本示例1.2.1 配置LangSmith1.2.2 使用LCEL语法创建Agents1.2.3 使用自定义runtime执行1.2.4 使用A…

uniapp中用户登录数据的存储方法探究

Hello大家好&#xff01;我是咕噜铁蛋&#xff01;作为一个博主&#xff0c;我们经常需要在应用程序中实现用户登录功能&#xff0c;并且需要将用户的登录数据进行存储&#xff0c;以便在多次使用应用程序时能够方便地获取用户信息。铁蛋通过科技手段帮大家收集整理了些知识&am…

八大算法排序@选择排序

目录 选择排序概念算法思想示例步骤1步骤2步骤...n最后一步 代码实现时间复杂度空间复杂度特性总结 选择排序 概念 选择排序&#xff08;Selection Sort&#xff09;是一种简单直观的排序算法。基本思想是在未排序的序列中找到最小&#xff08;或最大&#xff09;元素&#xf…

BIO和NIO编程(待完善)

目录 IO模型 BIO NIO 常见问题 IO模型 Java共支持3种网络编程IO模式&#xff1a;BIO&#xff0c;NIO&#xff0c;AIO BIO 同步阻塞模型&#xff0c;一个客户端连接对应一个处理线程 代码示例&#xff1a; Server端&#xff1a; public class BioServer {private static …

K8S集群部署MySql

挂载MySQL数据卷 在k8s集群中挂载MySQL数据卷 需要安装一个NFS。 在主节点安装NFS yum install -y nfs-utils rpcbind 在主节点创建目录 mkdir -p /nfs chmod 777 /nfs 更改归属组与用户 chown -R nfsnobody:nfsnobody /nfs 配置共享目录 echo "/nfs *(insecure,rw,s…

工业相机如何实现实时和本地Raw格式图像和Bitmap格式图像的保存和相互转换(C#代码,UI界面版)

工业相机如何实现实时和本地Raw图像和Bitmap图像的保存和相互转换&#xff08;C#代码&#xff0c;UI界面版&#xff09; 工业相机图像格式工业相机实现Raw图像和Bitmap图像的保存和转换的技术背景在相机SDK中获取图像转换图像的代码分析工业相机回调函数里保存Bitmap图像数据工…