基于昇思MindSpore实现使用胶囊网络的图像描述生成算法

news2025/2/27 7:28:52

基于昇思MindSpore实现使用胶囊网络的图像描述生成算法

项目链接

https://github.com/Liu-Yuanqiu/acn_mindspore

01

项目描述

1.1 图像描述生成算法

人类可以轻易的使用语言来描述所看到的场景,但是计算机却很难做到,图像描述生成任务的目的就是教会计算机如何描述所看到的内容,其中涉及到了对视觉信息的处理以及如何生成符合人类语言习惯的语句,这两方面也分别对应人工智能的两大领域——计算机视觉和自然语言处理。图像描述生成任务不仅仅在算法研究上具有重要的意义,同时在盲人助理、图文转换等实际应用场景也有广泛的应用。

1.2 胶囊网络

因为卷积是局部连接和参数共享的,并没有考虑到特征之间的相互关联和相互位置关系,即没有学习到特征之间的位置信息;胶囊网络同时将空间信息和物体存在概率编码到胶囊向量中,通过动态路由决定当前胶囊输入到哪个更加高级的胶囊,并通过非线性激活函数在归一化的同时保持方向不变,从而使胶囊网络学习到有用的特征以及它们之间的关系。

02

网络结构

具体来说,图像描述生成算法通常使用编码器-解码器结构,如图1所示,其中编码器部分提取图片的视觉特征,通过多种注意力机制捕获视觉特征之间的关系并生成输出,其中包括了双线性池化模块和注意力胶囊模块,双线性池化模块通过对特征进行挤压-奖励操作获取特征之间的二阶交互,注意力胶囊模块将每一个视觉特征看作一个胶囊,从而捕获特征之间的相对位置关系;解码器部分根据视觉特征通过循环神经网络生成对应单词,解码器使用自回归方式解码,以上一时刻单词作为输入,逐个生成单词组成最终的描述语句。

图1 使用胶囊网络的图像描述生成算法框架

算法使用交叉熵损失监督训练,损失函数表示为:

其中表示前i-1步生成的单词,表示图像特征,表示在两者基础上生成本时刻单词的概率。

03

开发流程

在具体实现过程中,我们先进行数据集适配,然后对网络模型进行开发,之后编写了训练代码。昇思MindSpore拥有良好的基础算子支持和简洁的API接口,不管是底层的矩阵运算还是高层的网络模型封装都可以完美的支持,极大的方便了开发和调试的过程。

3.1 数据集适配

因为图像描述生成任务涉及到对图片和文本的处理,因此训练数据的格式也具有多种类型,其中图像特征是使用Faster RCNN预处理好的矩阵,词表以txt文件存储,描述语句使用数组进行存储,因此我们使用自定义数据集来生成训练批次,首先定义一个数据类,类中实现__getitem__作为可迭代对象返回数据,使用GeneratorDataset组装数据类,从而进行混洗和抽取批次。

coco_train_set = CocoDataset(
  image_ids_path = os.path.join(args.dataset_path, 'txt', 'coco_train_image_id.txt'),
  input_seq = os.path.join(args.dataset_path, 'sent', 'coco_train_input.pkl'),
  target_seq = os.path.join(args.dataset_path, 'sent', 'coco_train_target.pkl'),
  att_feats_folder = os.path.join(args.dataset_path, 'feature', 'up_down_36'),
  seq_per_img = args.seq_per_img,
  max_feat_num = -1)
dataset_train = ds.GeneratorDataset(coco_train_set, 
  column_names=["indices", "input_seq", "target_seq", "att_feats"],
  shuffle=True,
  python_multiprocessing=True,
  num_parallel_workers=args.works)
dataset_train = dataset_train.batch(args.batch_size, drop_remainder=True)

3.2 网络模型开发

网络模型均使用nn.Cell作为基类进行开发,根据网络中模型的不同功能实现了模型整体类CapsuleXlan、编码器Encoder和解码器Decoder等等,其中涉及到线性层、激活层、矩阵操作、张量操作等等,具体网络结构定义如下。

class CapsuleXlan(nn.Cell):
    def __init__():
      self.encoder = Encoder()
      self.decoder = Decoder()
class Encoder(nn.Cell):
    def __init__():
      self.encoder = nn.CellList([])
      for _ in range(layer_num):
        sublayer = CapsuleLowRankLayer()
        self.encoder.append(sublayer)
class Decoder(nn.Cell):
    def __init__():
      self.decoder = nn.CellList([])
      for _ in range(layer_num):
        sublayer = LowRankLayer()
        self.decoder.append(sublayer)
class CapsuleLowRankLayer(nn.Cell):
    def __init__():
      self.attn_net = Capsule()
class LowRankLayer(nn.Cell):
    def __init__():
      self.attn_net = SCAtt()
class SCAtt(nn.Cell):
    def __init__():
class Capsule(nn.Cell):
    def __init__():

其中CapsuleXlan中包含了Encoder和Decoder,通过Encoder编码视觉信息,通过Decoder生成描述语句;Encoder中包含多层CapsuleLowRankLayer,每一层CapsuleLowRankLayer对特征进行处理后输入Capsule进行计算,结果处理后返回上一层;对于Decoder来说同理。

对于具体的运算,以Decoder为例,首先我们定义子层和对应的线性层,提前需要使用的算子操作,然后在construct中使用预先定义的各个层对输入进行处理,计算结果经过线性层和层归一化后返回。

class Decoder(nn.Cell):
    def __init__(self, layer_num, embed_dim, att_heads, att_mid_dim, att_mid_drop):
      super(Decoder, self).__init__()
      self.decoder = nn.CellList([])
      for _ in range(layer_num):
          sublayer = LowRankLayer(embed_dim=embed_dim, att_heads=8, 
                       att_mid_dim=[128, 64, 128], att_mid_drop=0.9)
          self.decoder.append(sublayer)
      self.proj = nn.Dense(embed_dim * (layer_num + 1), embed_dim)
      self.layer_norm = nn.LayerNorm([embed_dim])
      self.concat_last = ops.Concat(-1)

    def construct(self, gv_feat, att_feats, att_mask):
      batch_size =  att_feats.shape[0]
      feat_arr = [gv_feat]
      for i, decoder_layer in enumerate(self.decoder):
          gv_feat = decoder_layer(gv_feat, att_feats, att_mask, 
          gv_feat, att_feats)
      feat_arr.append(gv_feat)
      gv_feat = self.concat_last(feat_arr)
      gv_feat = self.proj(gv_feat)
      gv_feat = self.layer_norm(gv_feat)
      return gv_feat, att_feats

同时,对交叉熵损失进行实现并将其和网络模型组装到一个类中:

class CapsuleXlanWithLoss(nn.Cell):
    def __init__(self, model):
      super(CapsuleXlanWithLoss, self).__init__()
      self.model = model
      self.ce = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
    def construct(self, indices, input_seq, target_seq, att_feats):
      logit = self.model(input_seq, att_feats)
      logit = logit.view((-1, logit.shape[-1]))
      target_seq = target_seq.view((-1))
      mask = (target_seq > -1).astype("float32")
      loss = self.ce(logit, target_seq)
      loss = ops.ReduceSum(False)(loss * mask) / mask.sum()
      return loss

3.3 训练

数据集和网络模型都已经准备完毕,在训练过程中只需要将模型、优化器、数据集、回调函数装配起来即可,在这里我们首先定义优化器,使用Adam优化器进行优化,然后定义回调函数,使用LossMonitor、TimeMonitor、ModelCheckpoint和SummaryCollector分别监听损失函数、计算每一步所用时间、保存模型以及保存可视化数据。最终使用nn.Model将四部分装配在一起进行训练,并可以通过MindInsight观察训练损失和参数的变化。

net = CapsuleXlan()
net = CapsuleXlanWithLoss(net)
warmup_lr = nn.WarmUpLR(args.lr, args.warmup)   
optim = nn.Adam(params=net.trainable_params(), learning_rate=warmup_lr, beta1=0.9, beta2=0.98, eps=1.0e-9)
model = ms.Model(network=net, optimizer=optim)

loss_cb = LossMonitor(per_print_times=1)
time_cb = TimeMonitor(data_size=step_per_epoch)
ckpoint_cb = ModelCheckpoint(prefix='ACN', 
    directory=os.path.join(args.result_folder, 'checkpoints'))
summary_cb = SummaryCollector(summary_dir=os.path.join(args.result_folder, 'summarys'))
cbs = [loss_cb, time_cb, ckpoint_cb, summary_cb]
model.train(epoch=args.epochs, train_dataset=dataset_train, callbacks=cbs)

04

模型效果

模型基于昇腾软硬件平台进行训练和测试,在硬件层面,使用Ascend 910 NPU作为训练设备,在软件层面,以CANN作为驱动,使用昇思MindSpore框架实现,以MindInsight可视化调试调优工具实时查看损失变化,使用几十万对数据进行训练,达到了良好的训练和推理效果,能够为图片生成完整准确的描述语句。

同时,方案获得了华为技术认证书,根据此方案申请的专利也进入到公布实审阶段。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/105295.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

昇思MindSpore动静结合中list和dict方法实现

01 概述 静态图和动态图是神经学习框架中的重要概念,昇思MindSpore同时支持动态图和静态图两种模式,在动态图与静态图的结合方面做了很多工作。本文以昇思MindSpore框架中图模式下list和dict的实现方式为例,介绍昇思MindSpore框架中的动静结…

C与C++如何互相调用

个人主页:董哥聊技术我是董哥,嵌入式领域新星创作者创作理念:专注分享高质量嵌入式文章,让大家读有所得!文章目录1、为什么会有差异?2、extern "C"3、C调用C正确方式4、C调用C5、总结在项目开发过…

[第十二届蓝桥杯/java/算法]C——卡片

🧑‍🎓个人介绍:大二软件生,现学JAVA、Linux、MySQL、算法 💻博客主页:渡过晚枫渡过晚枫 👓系列专栏:[编程神域 C语言],[java/初学者],[蓝桥杯] &#x1f4d6…

中外法律文献查找下载常用数据库大盘点

中外法律文献查找下载常用数据库有: 一、Westlaw(法律全文数据库) 是法律出版集团Thomson Legal and Regulator’s于1975年开发的,为国际法律专业人员提供的互联网的搜索工具。 Westlaw International其丰富的资源来自法律、法规…

图(Graph)详解 - 数据结构

文章目录:图的基本概念图的存储结构邻接矩阵邻接矩阵的实现邻接表邻接表实现图的遍历图的广度优先搜索(BFS)图的深度优先搜索(DFS)最小生成树Kruskal算法Prim算法最短路径单源最短路径 - Dijkstra算法单源最短路径 - B…

Linux学习-91-Discuz论坛安装

17.22 Discuz论坛安装 通过 Discuz! 搭建社区论坛、知识付费网站、视频直播点播站、企业网站、同城社区、小程序、APP、图片素材站,游戏交流站,电商购物站、小说阅读、博客、拼车系统、房产信息、求职招聘、婚恋交友等等绝大多数类型的网站。Discuz!自2…

《教养的迷思》

在读《穷查理宝典》时,查理芒格在有一讲,专门谈及《教养的迷思》一书,说到作者朱迪斯哈里斯。查理芒格认为哈里斯在探求真理的道路上走得很顺利,取得成功的因素之一就是她热衷于摧毁自己的观念。 朱迪斯在书的开端首先严肃地纠正了…

【案例教程】无人机生态环境监测、图像处理与GIS数据分析综合实践

【查看原文】无人机生态环境监测、图像处理与GIS数据分析综合实践技术应用 构建“天空地”一体化监测体系是新形势下生态、环境、水文、农业、林业、气象等资源环境领域的重大需求,无人机生态环境监测在一体化监测体系中扮演着极其重要的角色。通过无人机航空遥感技…

Fabric系列 - 多通道技术(Muti-channel)

可在节点,通道和联盟级别上配置。 一个Fabric网络中能够运行多个账本,每个通道间的逻辑相互隔离不受影响,如下图所示,每种颜色的线条代表一个逻辑上的通道,每个Peer节点可以加入不同的通道,每个通道都拥有…

AI编译器XLA调研

文章目录一、XLA简介二、XLA在TensorFlow中的应用2.1 XLA是什么?(tensorflow\compiler\xla)2.2 TensorFlow怎样转化为XLA (tensorflow\compiler\tf2xla)2.3 JIT(just in time) 即时编译 (tensorflow\compil…

【大数据技术Hadoop+Spark】Flume、Kafka的简介及安装(图文解释 超详细)

Flume简介 Flume是Cloudera提供的一个高可用、高可靠、分布式的海量日志采集、聚合和传输的系统,Flume支持在日志系统中定制各类数据发送方,用于收集数据;同时,Flume提供对数据进行简单处理,并写到各种数据接受方&…

NLP学习笔记(四) Seq2Seq基本介绍

大家好,我是半虹,这篇文章来讲序列到序列模型 (Sequence To Sequence, Seq2Seq) 本文写作思路如下: 从循环神经网络的应用场景引入,介绍循环神经网络作为编码器和解码器使用,最后是序列到序列模型 在之前的文章中&am…

微信消息收发与微信内部emoji表情转义

微信消息收发与微信内部emoji表情转义 目录 微信内部emoji表情转义与消息收发 一、概述 二、常用标准emoji表情字符、微信内部转义符、unicode对照表 1、比如 2、微信聊天窗口emoji表情字符 2.1、PC端表情选择,01~03排: 2.2、PC端表情选择&#…

华为IMC培训——通信基础

目录 一、华为设备图标 二、数据的传递 三、专业术语 四、网络设备及相关知识 五、OSI七层模型 六、TCP和UDP数据报格式 七、TCP的三次握手 八、 TCP窗口滑动机制 一、华为设备图标 AP:相当于家用路由器一般配和AC使用。 AC和AP的区别_wangzhibo_csdn的博客…

创意被盗用,这3个加水印方法,让照片刻上我们专属印记

一般我们为了保护自己的图片不被别人盗用,都会选择在图片上刻上专属印记。那么便是加水印方法,它包含两种:文字水印和图片水印。想知道怎么给图片添加水印吗?其实有很多种法子可以做到,下面就由我来分享这3个简单好用的…

代码随想录刷题记录 day48 两个字符串的删除操作+编辑距离

代码随想录刷题记录 day48 两个字符串的删除操作编辑距离 583. 两个字符串的删除操作 思想 两个元素都能删除了,还是考虑第i-1个字符和第j-1个字符是不是相同的,不相同的话考虑三种情况,删除i-1;删除j-1,同时删除 1…

css实现鼠标禁用(鼠标滑过显示红色禁止符号)

css实现鼠标禁用(鼠标滑过显示红色禁止符号)创作背景css鼠标禁用创作背景 从本文开始,将会用三篇文章来一步一步实现vueantdts实战后台管理系统中table表格的不可控操作。中间会补充两篇css知识文章,方便后续功能的实现。实现表格…

非零基础自学Golang 第14章 反射 14.2 基本用法 14.2.2 获取类型的值 14.2.3 使用反射调用函数

非零基础自学Golang 文章目录非零基础自学Golang第14章 反射14.2 基本用法14.2.2 获取类型的值14.2.3 使用反射调用函数第14章 反射 14.2 基本用法 14.2.2 获取类型的值 Go语言使用reflect.TypeOf来获取类型信息,使用reflect.ValueOf来获取变量值的信息。 refle…

云原生|kubernetes|CKA真题解析-------(6-10题)

第六题: service配置 解析: 考察两个知识点: deployment控制器内的port命名 暴露一个pod内的端口到新建的服务内的 这里有一个需要注意的地方,没有告诉你deployment控制器在哪个namespace。假设这个front-end这个pod是在A这个…

前端CSS Flex布局8大重难点知识,收藏起来吧

2009年,W3C提出了一种新的方案—-Flex布局,可以简便、完整、响应式地实现各种页面布局。目前,它已经得到了所有浏览器的支持,这意味着,现在就能很安全地使用这项功能。 Flex布局将成为未来布局的首选方案。这也是学习前…