基于昇思MindSpore,实现使用胶囊网络的图像描述生成算法

news2024/11/16 18:03:21

项目链接

https://github.com/Liu-Yuanqiu/acn_mindspore

项目描述

图像描述生成算法

人类可以轻易的使用语言来描述所看到的场景,但是计算机却很难做到,图像描述生成任务的目的就是教会计算机如何描述所看到的内容,其中涉及到了对视觉信息的处理以及如何生成符合人类语言习惯的语句,这两方面也分别对应的人工智能的两大领域——计算机视觉和自然语言处理。图像描述生成任务不仅仅在算法研究上具有重要的意义,同时在盲人助理、图文转换等实际应用场景也有广泛的应用。

胶囊网络

因为卷积是局部连接和参数共享的,并没有考虑到特征之间的相互关联和相互位置关系,即没有学习到特征之间的位姿信息;胶囊网络同时将空间信息和物体存在概率编码到胶囊向量中,通过动态路由决定当前胶囊输入到哪个更加高级的胶囊,并通过非线性激活函数在归一化的同时保持方向不变,从而使胶囊网络学习到有用的特征以及它们之间的关系。

网络结构

具体来说,图像描述生成算法通常使用编码器-解码器结构,如图1所示,其中编码器部分提取图片的视觉特征,通过多种注意力机制捕获视觉特征之间的关系并生成输出,其中包括了双线性池化模块和注意力胶囊模块,双线性池化模块通过对特征进行挤压-奖励操作获取特征之间的二阶交互,注意力胶囊模块将每一个视觉特征看作一个胶囊,从而捕获特征之间的相对位置关系;解码器部分根据视觉特征通过循环神经网络生成对应单词,解码器使用自回归方式解码,以上一时刻单词作为输入,逐个生成单词组成最终的描述语句。

图1 使用胶囊网络的图像描述生成算法框架

算法使用交叉熵损失监督训练,损失函数表示为:

其中表示前i-1步生成的单词,表示图像特征,表示在两者基础上生成本时刻单词的概率。

开发流程

在具体实现过程中,我们先进行数据集适配,然后对网络模型进行开发,之后编写了训练代码。MindSpore拥有良好的基础算子支持和简洁的API接口,不管是底层的矩阵运算还是高层的网络模型封装都可以完美的支持,极大的方便了开发和调试的过程。

数据集适配

因为图像描述生成任务涉及到对图片和文本的处理,因此训练数据的格式也具有多种类型,其中图像特征是使用Faster RCNN预处理好的矩阵,词表以txt文件存储,描述语句使用数组进行存储,因此我们使用自定义数据集来生成训练批次,首先定义一个数据类,类中实现__getitem__作为可迭代对象返回数据,使用GeneratorDataset组装数据类,从而进行混洗和抽取批次。

coco_train_set = CocoDataset(image_ids_path = os.path.join(args.dataset_path, 'txt', 'coco_train_image_id.txt'),
input_seq = os.path.join(args.dataset_path, 'sent', 'coco_train_input.pkl'),
target_seq = os.path.join(args.dataset_path, 'sent', 'coco_train_target.pkl'),
att_feats_folder = os.path.join(args.dataset_path, 'feature', 'up_down_36'),
seq_per_img = args.seq_per_img,
max_feat_num = -1)
dataset_train = ds.GeneratorDataset(coco_train_set,
column_names=["indices", "input_seq", "target_seq", "att_feats"],
shuffle=True,
python_multiprocessing=True,
num_parallel_workers=args.works)
dataset_train = dataset_train.batch(args.batch_size, drop_remainder=True)

网络模型开发

网络模型均使用nn.Cell作为基类进行开发,根据网络中模型的不同功能实现了模型整体类CapsuleXlan、编码器Encoder和解码器Decoder等等,其中涉及到线性层、激活层、矩阵操作、张量操作等等,具体网络结构定义如下。

其中CapsuleXlan中包含了Encoder和Decoder,通过Encoder编码视觉信息,通过Decoder生成描述语句;Encoder中包含多层CapsuleLowRankLayer,每一层CapsuleLowRankLayer对特征进行处理后输入Capsule进行计算,结果处理后返回上一层;对于Decoder来说同理。

对于具体的运算,以Decoder为例,首先我们定义子层和对应的线性层,提前需要使用的算子操作,然后在construct中使用预先定义的各个层对输入进行处理,计算结果经过线性层和层归一化后返回。

class Decoder(nn.Cell):
    def __init__(self, layer_num, embed_dim, att_heads, att_mid_dim, att_mid_drop):
        super(Decoder, self).__init__()
        self.decoder = nn.CellList([])
        for _ in range(layer_num):
        sublayer = LowRankLayer(embed_dim=embed_dim, att_heads=8,
        att_mid_dim=[128, 64, 128], att_mid_drop=0.9)
        self.decoder.append(sublayer)
        self.proj = nn.Dense(embed_dim * (layer_num + 1), embed_dim)
        self.layer_norm = nn.LayerNorm([embed_dim])
        self.concat_last = ops.Concat(-1)

def construct(self, gv_feat, att_feats, att_mask):
    batch_size =  att_feats.shape[0]
    feat_arr = [gv_feat]
    for i, decoder_layer in enumerate(self.decoder):
        gv_feat = decoder_layer(gv_feat, att_feats, att_mask, gv_feat, att_feats)
        feat_arr.append(gv_feat)
        gv_feat = self.concat_last(feat_arr)
        gv_feat = self.proj(gv_feat)
        gv_feat = self.layer_norm(gv_feat)
    return gv_feat, att_feats

同时,对交叉熵损失进行实现并将其和网络模型组装到一个类中:
class CapsuleXlanWithLoss(nn.Cell):
    def __init__(self, model):
        super(CapsuleXlanWithLoss, self).__init__()
        self.model = model
        self.ce = nn.SoftmaxCrossEntropyWithLogits(sparse=True)

def construct(self, indices, input_seq, target_seq, att_feats):
    logit = self.model(input_seq, att_feats)
    logit = logit.view((-1, logit.shape[-1]))
    target_seq = target_seq.view((-1))
    mask = (target_seq > -1).astype("float32")
    loss = self.ce(logit, target_seq)
    loss = ops.ReduceSum(False)(loss * mask) / mask.sum()
    return loss

训练

class CapsuleXlan(nn.Cell):
    def __init__():
        self.encoder = Encoder()
        self.decoder = Decoder()
class Encoder(nn.Cell):
    def __init__():
        self.encoder = nn.CellList([])
        for _ in range(layer_num):
        sublayer = CapsuleLowRankLayer()
        self.encoder.append(sublayer)
class Decoder(nn.Cell):
    def __init__():
        self.decoder = nn.CellList([])
        for _ in range(layer_num):
        sublayer = LowRankLayer()
        self.decoder.append(sublayer)
class CapsuleLowRankLayer(nn.Cell):
    def __init__():
        self.attn_net = Capsule()
class LowRankLayer(nn.Cell):
    def __init__():
        self.attn_net = SCAtt()
class SCAtt(nn.Cell):
    def __init__():
class Capsule(nn.Cell):
    def __init__():

数据集和网络模型都已经准备完毕,在训练过程中只需要将模型、优化器、数据集、回调函数装配起来即可,在这里我们首先定义优化器,使用Adam优化器进行优化,然后定义回调函数,使用LossMonitor、TimeMonitor、ModelCheckpoint和SummaryCollector分别监听损失函数、计算每一步所用时间、保存模型以及保存可视化数据。最终使用nn.Model将四部分装配在一起进行训练,并可以通过MindInsight观察训练损失和参数的变化。

net = CapsuleXlan()
net = CapsuleXlanWithLoss(net)
warmup_lr = nn.WarmUpLR(args.lr, args.warmup)   
optim = nn.Adam(params=net.trainable_params(), learning_rate=warmup_lr, beta1=0.9, beta2=0.98, eps=1.0e-9)
model = ms.Model(network=net, optimizer=optim)

loss_cb = LossMonitor(per_print_times=1)
time_cb = TimeMonitor(data_size=step_per_epoch)
ckpoint_cb = ModelCheckpoint(prefix='ACN',
directory=os.path.join(args.result_folder, 'checkpoints'))
summary_cb = SummaryCollector(summary_dir=os.path.join(args.result_folder, 'summarys'))
cbs = [loss_cb, time_cb, ckpoint_cb, summary_cb]
model.train(epoch=args.epochs, train_dataset=dataset_train, callbacks=cbs)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/83128.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

JavaOOP面试题(108道)

✅作者简介:热爱国学的Java后端开发者,修心和技术同步精进。 🍎个人主页:Java Fans的博客 🍊个人信条:不迁怒,不贰过。小知识,大智慧。 💞当前专栏:Java面试题…

mybatis以及mybatisplus批量插入问题

1. 思路分析: 批量插入是我们日常开放经常会使用到的场景,一般情况下我们也会有两种方案进行实施,如下所示。 方案一 就是用 for 循环循环插入: 优点:JDBC 中的 PreparedStatement 有预编译功能,预编译之…

vue3较vue不同的地方

自定义指令的区别: vue2的写法: Vue.directive(scroll, {}) //scroll是指令名称 vue3的写法: 定义全局的:在main.js文件中定义: createApp(App).directive("hello",{}).use(store).use(router).mount(#…

小程序import及include引用的简单理解

场景:在小程序中,WXML 提供两种文件引用方式import和include 我自己记录下自己的一些简单理解 官方文档:引用 | 微信开放文档 第一:import import,就是可以引入自定义指定的template模板 比如:我在import页…

stm32f767之ADC

一,基本介绍 1,ADC时钟。 ADC时钟一般常用来自于经可编程预分频器分频的APB2 时钟,该预分频器允许ADC 在fPCLK2/2、 /4、/6 或/8 下工作。ADCCLK 的最大值限制。2,ADC通道。 有16 条复用通道。我的理解是每个ADC(1&…

气泡水位计安装示意图 气泡水位计工作原理

气泡式水位计测量精度高,免气瓶,免测井,免维护,抗振动,寿命长,特别适用于流动水体、大中小河流等水深比较大的场合。具有安装简单,操作、组网灵活,尤其是无井水位测量最理想的水位监…

城市燃气系统安全解决方案

汽车制造业 MES系统 DNC系统 生产 安全域1 管理层 工控安全隔离装置 交换机 安全配置核查系统 HMI 历史数据库 运行监控系统 实时数据库 打印机过程 安全域2 监控层 工控漏洞扫描系统 安全交换机 工控安全审计系统 工控入侵检测系统工程师站 A 操作员站 A 实时数据库A 操作员站…

[附源码]Python计算机毕业设计SSM基于的冠状病毒疫情防控资讯交流推荐网站(程序+LW)

项目运行 环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术: SSM mybatis Maven Vue 等等组成,B/S模式 M…

stream_open函数分析

在讲 stream_open() 函数之前,需要先了解 stream_open() 里面使用到的一些基本的数据结构。如下: 第一个数据结构是 struct VideoState ,VideoState 可以说是播放器的全局管理器。字段非常多,时钟,队列,解…

Android 11 的状态栏的隐藏

概述 Android 11 的状态栏与导航栏较之前的版本有较大的差异, 在Android 7.0 SystemUI 状态/导航栏的隐藏与显示中所描述的部分内容已不再适用. 比如, 自动隐藏的时间, 隐藏的动画, 较之前的版本已面目全非, 本文将对隐藏状态栏部分的内容进行一些补充. APP如何隐藏状态栏 参…

Yield Guild Games 成功举办首届 SubDAO 峰会

Yield Guild Games(YGG)于 2022 年 11 月 18 日在菲律宾马尼拉举行了第一届 SubDAO 峰会。 SubDAO 峰会与菲律宾 Web3 狂欢节两个活动同时举行,为 YGG 的区域 SubDAO 提供了在 Web3 应用中心——菲律宾进行面对面交流的机会。此次活动旨在传达…

运维开发实践 - helm

1. helm介绍 helm 是一个用于管理部署在kubernetes上的应用的工具 使用要求:一个Kubernetes集群 2.下载安装 Helm Github Download Helm Huawei Source 按照自己的操作系统版本下载相应的helm压缩包 并将helm添加到环境变量中; # 检查是否安装成功 helm version…

read_thread解复用线程分析

read_thread() 线程的主要作用从 MP4 里面读取 AVPacket,然后丢进去 PacketQueue 队列。所以需要先学习一下 strcut PacketQueue 跟 struct MyAVPacketList 数据结构。如下: typedef struct MyAVPacketList {AVPacket *pkt;int serial; } MyAVPacketLis…

html文件里怎么引用vue组件?

这里我们使用 http-vue-loader 来实现&#xff1a;https://www.npmjs.com/package/http-vue-loader Load .vue files directly from your html/js. No node.js environment, no build step. 我做了个demo如下&#xff1a; html文件里面写下面的代码 <!DOCTYPE html> &l…

计算机研究生就业方向之当老师(中小学)

我一直跟学生们说你考计算机的研究生之前一定要想好你想干什么&#xff0c;如果你只是转码&#xff0c;那么你不一定要考研&#xff0c;至少以下几个职位研究生是没有啥优势的&#xff1a; 1&#xff0c;软件测试工程师&#xff08;培训一下就行&#xff09; 2&#xff0c;前…

股票购买接口系统怎么使用vn.py进行量化策略?

一般情况下&#xff0c;股票购买接口系统主要是可以运用在股票量化交易系统开发的一个大方向&#xff0c;也就是说&#xff0c;股票购买接口系统是根据这些量化的特点来开发的&#xff0c;就比如使用vn.py进行量化策略&#xff0c;在这方面&#xff0c;对交易者进行量化分析也起…

Web前端105天-day-41-JSCORE

JSCORE01 目录 前言 一、声明提升 二、宿主 window 三、断点功能 四、匿名函数解决全局污染 五、作用域链 六、闭包 七、私有 八、arguments 九、函数重载 十、方括号属性语法 十一、重载练习 十二、this 总结 前言 JSCORE01学习开始 一、声明提升 报错方案: 让…

走进SpringCloud微服务

微服务概述一、注册中心&#xff1a;Eureka ⭐⭐⭐1.1 原理1.2 代码二、负载均衡&#xff1a;Ribbon ⭐三、远程调用&#xff1a;Feigh ⭐⭐⭐3.1 原理3.2 代码四、熔断限流&#xff1a;Hystrix ⭐⭐⭐4.1线程池策略4.2 信号量隔离策略4.3 方法降级4.4 断路器、熔断器五、网关&…

MongoDB和MongoTemplate对于嵌套数据的判空查询

前言&#xff1a; 不知道有没有和小名一样&#xff0c;接触MongDB时间不长的小伙伴。由于MongoDB是以文档形式存储数据的&#xff0c;所以其中的数据类型相对MySql或者Oracle关系型数据库丰富一些&#xff08;MongoDB是NoSQL数据库这里比较不是很准确&#xff09; 我们在关系…

Dropout方法原理和使用方法

来源&#xff1a;投稿 作者&#xff1a;梦飞翔 编辑&#xff1a;学姐 为什么提出这种方法&#xff1f; 神经网络在训练过程中&#xff0c;由于正负样本不均衡、样本特征显著性不足、训练参数不合适等原因会导致发生过拟合现象&#xff0c;即模型参数陷入局部最优&#xff0c;仅…