pytorch实现胶囊网络(capsulenet)

news2025/1/23 2:00:58

胶囊网络在hinton刚提出来的时候小热过一段时间,之后热度并没有维持多久。vision transformer之后基本少有人问津了。不过这个模型思路挺独特的,值得研究一下。

这个模型的提出是为了解决CNN模型学习到的特征之间没有空间上的关系,从而对于各种变换不鲁棒的缺点。

模型的整体思路如下:

1,胶囊:

抛开论文里花哨的描述,胶囊其实就是特征图上比点更大的单元,本质上我觉得类似transformer的patch。当然也有一定的差别,因为后续要用动态路由更新胶囊,所以胶囊必须要是向量,而不是标量。

2,动态路由:

由于pooling会导致信息丢失,作者使用动态路由来连接两个胶囊层,并更新胶囊。

同时,动态路由也能建立不同层胶囊(特征)在空间上的相对关系。

由于胶囊其实是向量,动态路由算法会根据这些向量的相似性(点积)和一致性(加权)来决定信息传递的路径。

3,整体结构:

1)卷积层

2)PrimaryCaps层:这层的作用就是把卷积特征转变成胶囊的形式

3)DigitCaps层:用动态路由迭代生成高层的胶囊。

4)解码器

4,loss

胶囊网络的损失函数主要由两部分组成:间隔损失(Margin Loss)和重构损失。

在计算间隔损失时,会使用一个阈值(通常设置为0.9和0.1)来区分正样本和负样本。如果某一类的胶囊输出向量的模长大于阈值m+(正样本阈值,例如0.9),则认为该类存在,并将其视为正样本;反之,如果输出向量的模长小于阈值m-(负样本阈值,例如0.1),则认为该类不存在,将其视为负样本。

重构损失的计算通常基于原始输入数据与重构数据之间的差异,例如使用均方误差(MSE)来衡量这种差异。

如果站在2024年的如今再来看当初的设计,其实胶囊的思路还是很像后来的transformer的,有点殊途同归的感觉。


pytorch实现:

1,实现初始胶囊

首先是会用到的压缩函数,压缩函数的作用是将向量的长度压缩到0和1之间,同时保留向量的方向不变。

公式:

def squash(inputs, axis=-1):
    norm = torch.norm(inputs, p=2, dim=axis, keepdim=True)
    scale = norm**2 / (1 + norm**2 + 1e-8) / (norm + 1e-8)
    return scale * inputs

初始胶囊,这一层的作用是将卷积特征转换为胶囊的形式。

class PrimaryCapsule(nn.Module):
    def __init__(self, in_channels, out_channels, dim_caps, kernel_size, stride=1, padding=0):
        super(PrimaryCapsule, self).__init__()
        self.dim_caps = dim_caps
        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)

    def forward(self, x):
        outputs = self.conv2d(x)
        outputs = outputs.reshape(x.size(0), -1, self.dim_caps)
        return squash(outputs)

2,实现胶囊层

路由算法

这个伪代码初看起来挺乱的,我翻译成人话如下:

首先,每一次迭代由两层胶囊层做点积后再通过softmax计算出耦合系数c。

耦合系数和下层胶囊的预测计算加权和,这是个投票的过程。

再通过压缩函数,就得到了本层的胶囊v。

因为这是个迭代的过程,需要不断更新耦合系数C。

新的耦合系数由两层胶囊之间的相似度决定。


具体实现中,会对低层胶囊先做一个变换,也就是下面代码里的weight。这个权重矩阵代表的是对下层胶囊的变化,变换之后的结果Ui|j用论文里的话说叫做“prediction vectors”。

胶囊层代码:

class DenseCapsule(nn.Module):
    def __init__(self, in_num_caps, in_dim_caps, out_num_caps, out_dim_caps, routings=3):
        super(DenseCapsule, self).__init__()
        self.in_num_caps = in_num_caps
        self.in_dim_caps = in_dim_caps
        self.out_num_caps = out_num_caps
        self.out_dim_caps = out_dim_caps
        self.routings = routings #路由的迭代次数
        #初始化
        self.weight = nn.Parameter(0.01 * torch.randn(out_num_caps, in_num_caps, out_dim_caps, in_dim_caps))

    def forward(self, x):
        u_hat = torch.squeeze(torch.matmul(self.weight, x[:, None, :, :, None]), dim=-1)
        #从当前计算图中分离出x_hat,这样在后续的反向传播中不会计算其梯度 
        u_hat_detached = u_hat.detach()
        b = torch.zeros(x.size(0), self.out_num_caps, self.in_num_caps).cuda()
        #路由算法
        for i in range(self.routings):
            c = F.softmax(b, dim=1)
            if i == self.routings - 1:
                v = squash(torch.sum(c[:, :, :, None] * u_hat, dim=-2, keepdim=True))
            else:
                v = squash(torch.sum(c[:, :, :, None] * u_hat_detached, dim=-2, keepdim=True))
                b = b + torch.sum(v * u_hat_detached, dim=-1)

        return torch.squeeze(v, dim=-2)

需要将的是u_hat_detached = u_hat.detach()这一步。将u_hat从计算图中分离出来的目的,是为了防止迭代过程中梯度不断累积,导致梯度过大。所以我们可以在后续的路由算法中看出,只有在最后一次计算路由时使用了u_hat,之前的迭代中都是使用的u_hat_detached。从而让整个路由过程中梯度只更新一次。

3,损失函数

def caps_loss(y_true, y_pred, x, x_recon, lambd=0.5):
    L = y_true * torch.clamp(0.9 - y_pred, min=0.) ** 2 + 0.5 * (1 - y_true) * torch.clamp(y_pred - 0.1, min=0.) ** 2
    L_margin = L.sum(dim=1).mean()

    L_recon = nn.MSELoss()(x_recon, x)

    return L_margin + lambd * L_recon

4,整体模型

模型返回两个值,一个是预测的概率,一个是重建的图像。这两个值会分别用来计算间隔损失和重构损失。

class CapsuleNet(nn.Module):
    def __init__(self, input_size, classes, routings):
        super(CapsuleNet, self).__init__()
        self.input_size = input_size
        self.classes = classes
        self.routings = routings
        self.conv1 = nn.Conv2d(input_size[0], 256, kernel_size=9, stride=1, padding=0)
        self.primarycaps = PrimaryCapsule(256, 256, 8, kernel_size=9, stride=2, padding=0)

        self.digitcaps = DenseCapsule(in_num_caps=32*6*6, in_dim_caps=8,
                                      out_num_caps=classes, out_dim_caps=16, routings=routings)

        self.decoder = nn.Sequential(
            nn.Linear(16*classes, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, input_size[0] * input_size[1] * input_size[2]),
            nn.Sigmoid()
        )

        self.relu = nn.ReLU()

    def forward(self, x, y=None):
        x = self.relu(self.conv1(x))
        x = self.primarycaps(x)
        x = self.digitcaps(x)
        length = x.norm(dim=-1)
        if y is None:
            index = length.max(dim=1)[1]
            y = torch.zeros(length.size()).scatter_(1, index.view(-1, 1), 1.)
        reconstruction = self.decoder((x * y[:, :, None]).view(x.size(0), -1))
        return length, reconstruction.view(-1, *self.input_size)

5,注意事项:

1)one-hot

在重建过程中使用的标签y是one-hot形式的,因此在训练和测试时需要加上这行代码,转换一下

targets = F.one_hot(targets, num_classes=classes).to(device)

2) loss

训练和测试时的loss设置如下

loss = caps_loss(y_true=targets,y_pred=y_pred,x=imgs,x_recon=x_recon,lambd=0.5)
        loss = loss.to(device)

其中lambd这个系数决定的是重构损失所占的比例 loss=margin_loss+lambd*recon_loss

总结:

胶囊网络分类结果不算差,在我的一些任务中train from scratch的胶囊网络就超越了imagenet1k上预训练过再finetune的vit。也超过了无预训练的VGG和resnet。(但是不如预训练过的vgg和resnet)。

这样的表现放在2017年已经很能打了,没火的原因我感觉有3个:

首先,由于胶囊网络迭代过程需要多次完整的特征图点乘特征图,所以内存消耗和时间消耗都是巨大的。我跑256的图时,24g显存的4090也只能把batch设置成5,运行速度非常慢。放在2017年,只能用1080ti来跑这个模型,简直折磨。(我2018年时也试过这个模型,训练都是按周算的,这谁愿意用啊)

另外一个原因可能是它的改进潜力不大。例如vit的核心机制是自注意力,注意力大家都玩出花来了,各种改进思路都很好借鉴。虽然vit效果很一般,但是后续的改进模型一个比一个厉害。而胶囊网络的核心路由算法想要创新就比较难。

最后还有一点就是原作者没放出胶囊网络在imagenet上的预训练模型。这个对模型热度的影响其实挺大的

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1587087.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

python毕业设计django游泳馆管理系统-flask

游泳馆管理系统具有信息管理功能的选择。游泳馆管理系统采用python技术,基于mysql开发,实现了首页,教练信息,培训信息,交流版块,活动公告,个人中心,后台管理等内容进行管理&#xff…

怎样将PDF转成PPT,有免费的工具吗?

PDF转换为PPT的需求在现代办公和学习中越来越常见。很多人可能遇到过需要将PDF文件中的内容转移到PPT中以方便编辑和展示的情况。幸运的是,现在市面上有许多工具可以帮助我们实现这一目标,而且其中不乏一些免费的选项。本文将详细介绍如何使用这些免费工…

Python 批量检测ip地址连通性,以json格式显示(支持传参单IP或者网段)

代码 ########################################################################## File Name: check_ip_test.py# Author: eight# Mail: 18847097110163.com # Created Time: Thu 11 Apr 2024 08:52:45 AM CST################################################…

小程序中配置scss

找到:project.config.json 文件 setting 模块下添加: "useCompilerPlugins": ["sass","其他的样式类型"] 配置完成后,重启开发工具,并新建文件 结果:

使用hexo+gitee从零搭建个人博客

一、环境准备 1.Node.js:下载 | Node.js 中文网 (nodejs.cn) ,Hexo 是基于Node.js 的博客框架 教程:https://blog.csdn.net/weixin_52799373/article/details/123840137 node -v npm -v 安装 Node.js 淘宝镜像加速器 (cnpm&am…

JS-28-AJAX

一、AJAX的定义 AJAX不是JavaScript的规范,它只是一个哥们“发明”的缩写:Asynchronous JavaScript and XML,意思就是用JavaScript执行异步网络请求。 如果仔细观察一个Form的提交,你就会发现,一旦用户点击“Submit”…

【JavaEE初阶系列】——网络编程 UDP客户端/服务器 程序实现

目录 🚩UDP和TCP之间的区别 🎈TCP是有连接的 UDP是无连接的 🎈TCP是可靠传输 UDP是不可靠传输 🎈TCP是面向字节流 UDP是面向数据报 🎈TCP和UDP是全双工 👩🏻‍💻UDP的socket ap…

【muzzik 分享】3D模型平面切割

# 前言 一年一度的征稿到了,倒腾点存货,3D平面切割通常用于一些解压游戏里,例如水果忍者,切菜这些,今天我就给大家讲讲怎么实现3D切割以及其原理,帮助大家更理解3D中的 Mesh(网格),以及UV贴图和…

ssm+vue的实验室课程管理系统(有报告)。Javaee项目,ssm vue前后端分离项目。

演示视频: ssmvue的实验室课程管理系统(有报告)。Javaee项目,ssm vue前后端分离项目。 项目介绍: 采用M(model)V(view)C(controller)三层体系结构…

【NLP笔记】大模型微调方法概述

对于一些生成式场景而言,没有固定的回答结果,采用AI Agent的增强范式,可以极大地提升模型生成的效果。但是对于有固定格式、输出目标的场景而言,仅从prompt优化的角度出发很难突破瓶颈,需要通过微调来提升效果&#xf…

背 单 词 (考研词汇闪过)

单词: 买考研词汇闪过 研究艾宾浩斯遗忘曲线 https://www.bilibili.com/video/BV18Y4y1h7YR/?spm_id_from333.337.search-card.all.click&vd_source5cbefe6dd70d6d84830a5891ceab2bf9 单词方法 闪记背两排(5min)重复一遍(2mi…

vue中预览docx、xlsx、pptx、pdf

前言:其实本来是要做全类型文件预览的,但是一直找不到合适的doc,xlx,ppt预览插件。要是有可以使用的,可以评论推荐给我 我使用的node版本:v18.19.1 参考官网:preview 文件预览 | ran 引入方式: //安装组…

Flask快速搭建文件上传服务与接口

说明:仅供学习使用,请勿用于非法用途,若有侵权,请联系博主删除 作者:zhu6201976 一、需求背景 前端通过浏览器,访问后端服务器地址,将目标文件进行上传。 访问地址:http://127.0.0…

✔ ★Java项目——设计一个消息队列(二)

Java项目——设计一个消息队列 四. 项⽬创建五. 创建核⼼类创建 Exchange(名字、类型、持久化)创建 MSGQueue(名字、持久化、独占标识)创建 Binding(交换机名字、队列名字、bindingKey用于与routingKey匹配&#xff09…

前端docker jenkins nginx CI/CD持续集成持续部署-实战

最近用go react ts开发了一个todolist后端基本开发完了,前端采用CI/CD方式去部署。 步骤总结 先安装docker 和 docker-compose。安装jenkins镜像,跑容器的时候要配好数据卷。配置gitee或github(我这里使用gitee)在服务器上一定要创建好dokcer的数据卷,以便持久保存jenkin…

【MySQL】锁篇

SueWakeup 个人主页:SueWakeup 系列专栏:学习技术栈 个性签名:保留赤子之心也许是种幸运吧 本文封面由 凯楠📸友情提供 目录 本系列专栏 1. MySQ 中的锁 2. 表锁和行锁 表锁 行锁 3. InnoDB 存储引擎的三种行级锁 4. 悲观锁…

怎么开发一个预约小程序_一键预约新体验

预约小程序,让生活更便捷——轻松掌握未来,一键预约新体验 在快节奏的现代生活中,我们总是在不断地奔波,为了工作、为了生活,不停地忙碌着。然而,在这繁忙的生活中,我们是否曾想过如何更加高效…

探探各个微前端框架

本文作者为 360 奇舞团前端开发工程师 微前端架构是为了在解决单体应用在一个相对长的时间跨度下,由于参与的人员、团队的增多、变迁,从一个普通应用演变成一个巨石应用(Frontend Monolith)后,随之而来的应用不可维护的问题。这类问题在企业级…

点击按钮(文字)调起elementUI大图预览

时隔一年,我又回来了 ~ 最近在做后台,遇到一个需求,就是点击“查看详情”按钮,调起elementUI的大图预览功能,预览多张图片,如下图: 首先想到的是使用element-ui的el-image组件,但它是…

Towards Geolocation of Millions of IP Addresses(2012年)

下载地址: Towards geolocation of millions of IP addresses | Proceedings of the 2012 Internet Measurement Conference 被引用次数:70 Hu Z, Heidemann J, Pradkin Y. Towards geolocation of millions of IP addresses[C]//Proceedings of the 2012 Internet Measure…