人工智能应用-实验7-胶囊网络分类minst手写数据集

news2024/10/7 3:22:25

文章目录

    • 🧡🧡实验内容🧡🧡
    • 🧡🧡代码🧡🧡
    • 🧡🧡分析结果🧡🧡
    • 🧡🧡实验总结🧡🧡

🧡🧡实验内容🧡🧡

编写胶囊网络分类软件(编程语言不限,如 Python 等),实现对 MNIST 数据集分类的操作。


🧡🧡代码🧡🧡

import torch
from torch import nn
from torch.optim import Adam
import torch.nn.functional as F
from torchvision import transforms, datasets
import time
import matplotlib.pyplot as plt

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_set = datasets.MNIST('data', train=True, download=True, transform=transform)
test_set = datasets.MNIST('data', train=False, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=1000, shuffle=True)


# !nvidia - smi
# 检查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)


# @title 胶囊 net
def squash(inputs, axis=-1):
    inputs = inputs.to(device)
    norm = torch.norm(inputs, p=2, dim=axis, keepdim=True)
    scale = norm ** 2 / (1 + norm ** 2) / (norm + 1e-8)
    return scale * inputs


# 预胶囊层
class PrimaryCapsule(nn.Module):
    def __init__(self, in_channels, out_channels, dim_caps, kernel_size, stride=1, padding=0):
        super(PrimaryCapsule, self).__init__()
        self.dim_caps = dim_caps
        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)

    def forward(self, x):
        outputs = self.conv2d(x)
        outputs = outputs.view(x.size(0), -1, self.dim_caps)
        return squash(outputs)


# 胶囊层(重构层)
class DenseCapsule(nn.Module):
    def __init__(self, in_num_caps, in_dim_caps, out_num_caps, out_dim_caps, routings=3):
        super(DenseCapsule, self).__init__()
        self.in_num_caps = in_num_caps
        self.in_dim_caps = in_dim_caps
        self.out_num_caps = out_num_caps
        self.out_dim_caps = out_dim_caps
        self.routings = routings
        self.weight = nn.Parameter(0.01 * torch.randn(out_num_caps, in_num_caps, out_dim_caps, in_dim_caps))

    def forward(self, x):
        x_hat = torch.squeeze(torch.matmul(self.weight, x[:, None, :, :, None]), dim=-1)
        x_hat_detached = x_hat.detach()  # 通过detach()方法获得的张量不再具有梯度,即它是一个常量张量,不会对反向传播产生影响。
        # The prior for coupling coefficient, initialized as zeros.
        # b.size = [batch, out_num_caps, in_num_caps]
        b = torch.zeros(x.size(0), self.out_num_caps, self.in_num_caps)
        b = b.to(device)
        # print(f"b shape={b.shape}")
        assert self.routings > 0, 'The \'routings\' should be > 0.'
        for i in range(self.routings):
            # c.size = [batch, out_num_caps, in_num_caps]
            c = F.softmax(b, dim=1)
            c = c.to(device)
            # print(f"c shape={c.shape}")
            if i == self.routings - 1:
                # dim=-2:倒数第二个维度求和, keepdim=True:保持计算前后维度不变
                outputs = squash(torch.sum(c[:, :, :, None] * x_hat, dim=-2, keepdim=True))
            else:  # Otherwise, use `x_hat_detached` to update `b`. No gradients flow on this path.
                outputs = squash(torch.sum(c[:, :, :, None] * x_hat_detached, dim=-2, keepdim=True))
                b = b + torch.sum(outputs * x_hat_detached, dim=-1)
        # print(f"outputs shape={torch.squeeze(outputs, dim=-2).shape}")
        return torch.squeeze(outputs, dim=-2)  # 将倒数第二维去掉


# 组合以上
class CapsuleNet(nn.Module):
    def __init__(self, input_size, classes, routings):
        super(CapsuleNet, self).__init__()
        self.input_size = input_size
        self.classes = classes
        self.routings = routings

        # Layer 1: Just a conventional Conv2D layer
        self.conv1 = nn.Conv2d(
            input_size[0], 256, kernel_size=9, stride=1, padding=0)

        # Layer 2: Conv2D layer with `squash` activation, then reshape to [None, num_caps, dim_caps]
        self.primarycaps = PrimaryCapsule(
            256, 256, 8, kernel_size=9, stride=2, padding=0)

        # Layer 3: Capsule layer. Routing algorithm works here.
        self.digitcaps = DenseCapsule(in_num_caps=32 * 6 * 6, in_dim_caps=8,
                                      out_num_caps=classes, out_dim_caps=16, routings=routings)

        self.relu = nn.ReLU()

    def forward(self, x, y=None):
        x = self.relu(self.conv1(x))
        x = self.primarycaps(x)
        x = self.digitcaps(x)
        # 对(64,10,16),计算最后一维向量的长度,即对应10个分类类别的存在概率,16蕴含特征的空间、局部位置等信息
        length = x.norm(dim=-1)
        return length


# @title Train and Test
def test(data_iter, net):
    acc_sum, n = 0.0, 0
    with torch.no_grad():
        for X, y in data_iter:
            X = X.to(device)
            y = y.to(device)
            if isinstance(net, torch.nn.Module):
                net.eval()  # 评估模式, 这会关闭dropout
                acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()
                net.train()  # 改回训练模式
            else:
                if ('is_training' in net.__code__.co_varnames):  # 如果有is_training这个参数
                    # 将is_training设置成False
                    acc_sum += (net(X, is_training=False).argmax(dim=1) == y).float().sum().item()
                else:
                    acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()
            n += y.shape[0]
    return acc_sum / n


def train(train_iter, test_iter, net, loss, optimizer, epochs):
    batch_count = 0
    loss_list = []
    test_acc_list = []

    for epoch in range(epochs):
        train_loss_sum, train_acc_sum, n, start = 0.0, 0.0, 0, time.time()
        for X, y in train_iter:
            X = X.to(device)
            y = y.to(device)
            y_hat = net(X)
            l = loss(y_hat, y)
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            train_loss_sum += l.item()
            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()
            n += y.shape[0]
            batch_count += 1
            # print(f"batch_count: {batch_count}")
        test_acc = test(test_iter, net)
        loss_list.append(train_loss_sum / batch_count)
        test_acc_list.append(test_acc)
        print('==========Epoch=%d========== \nloss=%.4f, train_acc=%.5f, test_acc=%.5f, time %f min'
              % (epoch + 1, train_loss_sum / batch_count, train_acc_sum / n,
                 test_acc, (time.time() - start) / 60))
    # 绘制损失函数随训练轮数的变化图
    plt.plot(range(1, epochs + 1), loss_list)
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Training Loss')
    plt.show()
    # 绘制准确率随训练轮数的变化图
    plt.plot(range(1, epochs + 1), test_acc_list)
    plt.xlabel('Epochs')
    plt.ylabel('acc')
    plt.title('Test acc')
    plt.show()


batch_size, lr, epochs = 100, 0.001, 4
net = CapsuleNet(input_size=[1, 28, 28], classes=10, routings=8)
net = net.to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
loss = nn.CrossEntropyLoss()
loss = loss.to(device)
train(train_loader, test_loader, net, loss, optimizer, epochs)

🧡🧡分析结果🧡🧡

数据预处理:
加载数据集:
加载torch库中自带的minst数据集
转换数据:
转为tensor变量(相当于直接除255归一化到值域为(0,1)),再设置mean和st
d标准化到(-1,1)区间。

设置基本参数:
在这里插入图片描述

构建胶囊网络:

  • 第一层:(64,28,28,1) => (64,20,20,256)
    - 卷积核大小为9,步长为1,故图片由28×28缩小为20×20,输出特征数设为256
  • 第二层:(64, 20, 20, 256) => (64,6,6,256) => (64, 1152, 8)
    - 先进行一层卷积操作,卷积核大小为9,步长为2,故图片由20×20缩小为6×6,输出特征数仍然为256。
    - 将输出通道数256分割成32×8的向量,合并中间维度(64, 6×6×32, 8) =>(64, 1152, 8),即有1152个胶囊,每个胶囊维度为8。
    进行squash激活。
  • 第三层:(64, 1152, 8) => (64, 10, 16)
    - 输入的胶囊个数为1152,维度为8,目标是转为胶囊个数为10,维度为16。
    - 对输入向量扩维:(64, 1152, 8) =>(64, 1, 1152, 8, 1),然后使用维度为 (1, 10, 1152, 16, 8)的权重w与其进行矩阵乘法,得到x_hat ,其维度为(64, 10, 1152, 16, 1),去除最后一维,维度变成 (64, 10, 1152, 16)。
    - 设置初始权重b,其维度为(64, 10, 1152, 1),使用动态路由算法将归一化后的b与x_hat多次相乘并累加,得到维度为(64, 10, 1, 16)的向量,去除倒数第二个维度,最后变为(64, 10, 16)的输出向量。
  • 输出结果:(64, 10, 16) => (64, 10)
    - 64代表batch_size,10代表10个图片分类标签,16代表这10个向量所具有的空间、相对位置等相关信息(CNN则没有),通过计算这10个向量的长度,即可得到分类概率。
    在这里插入图片描述

模型训练与评估:
可以看到,设置了4个epoch,准确率即达到了98.7%, 高于CNN和BP。
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

分析胶囊参数对分类准确率的影响
更改胶囊参数原则:
在卷积核参数不变的情况下,主要第二层预胶囊层中的输出胶囊维度,然后自己手算第二层预胶囊层的输出个数;之后第三层胶囊层的输入胶囊个数和维度同第二层的胶囊输出维度,第三层胶囊层个数为10固定不变(因为有10个类别),胶囊层维度可以更改。
在这里插入图片描述
被对照实验:
输入胶囊个数=1152、输入胶囊维度=8、输出胶囊个数=10、输出胶囊维度=16
参照问题1,准确率为98.70%,运行时间为4.931 minutes
对照实验1:
输入胶囊个数=512、输入胶囊维度=16、输出胶囊个数=10、输出胶囊维度=16
在这里插入图片描述
在这里插入图片描述
对照实验2:
输入胶囊个数=512、输入胶囊维度=16、输出胶囊个数=10、输出胶囊维度=32
在这里插入图片描述
在这里插入图片描述
对照实验3:
输入胶囊个数=512、输入胶囊维度=16、输出胶囊个数=10、输出胶囊维度=8
在这里插入图片描述
在这里插入图片描述
总结以上结果如下表(epoch=4):
(表格中的“输入输出”指的是第三层胶囊层的输入输出胶囊)
在这里插入图片描述
由此分析可知:

  • 输入胶囊个数和维度:
    在第一行和第二行中,更改第三层的输入胶囊个数和维度,一方面增加了输入胶囊的维度,使得每个胶囊能够表示更多的特征信息,从而提高了分类准确率;另一方面胶囊个数减少,从而减少了运行时间。
  • 输出胶囊个数和维度:
    观察第二、三行、四行,可以看到随着输出胶囊维度减小,分类准确率得到提高,且模型运行时间相应减少。据此判断胶囊维度小,浓缩了对判断分类最为有利的特征信息,减少了冗余无用的干扰信息,因此使得准确率得到提高。

🧡🧡实验总结🧡🧡

理论理解方面:
胶囊网络与CNN卷积网络的主要区别:

  • CNN擅长提取图像中的局部特征,其中,池化层可以降低采样数,将一个邻域的重要特征提取出来以代表区域特征,但是,也正因为此,池化层的计算过程中会损失空间特征,即特征组合的相对信息位置。例如对于一张人脸图像,CNN很容易提取出眼睛、嘴巴、鼻子这些局部的重要的特征,但是眼睛、嘴巴、鼻子之间的相对位置关系却无法获得。
  • 基于以上CNN存在的问题,而提出了胶囊网络的概念,胶囊网络中的胶囊单元(包含许多神经元)能够捕捉图像中的层次结构和空间关系,并且设计了动态路由算法来实现胶囊之间的信息传递和姿态估计。

代码实操方面:

  • 首先是运行时间问题,前两次CNN和BP实验中,用CPU跑,分别设计epoch=5和epoch=15,运行时间在3-4min左右,而本次实验,我首先用的CPU跑,但是跑了半个小时仍然没有完成一个epoch,于是感叹它的复杂性。好在Google colab自带免费的GPU,才解决了我的燃眉之急。本实验中,一开始设置为epoch=5时,超过了Google colab自带的GPU算力,运行不下去,只好统一设计为epoch=4。
  • 其次是复杂的矩阵运算。二维矩阵与二维矩阵相乘我还会,但是在重构胶囊信息时,要进行五维矩阵与五维矩阵的乘法,例如维度(64, 10, 1152, 16, 8)的矩阵与维度(64, 1, 1152, 8, 1)进行矩阵乘法,就比较懵了。通过网上学习,了解到:只要看最后两维度相乘,前面的维度通过“维度广播”一般取最大值即可,所以相乘之后的矩阵的维度为(64, 10, 1152, 16, 1)。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1688889.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

vue3+ts实战

目录 一、ts语法练习 1.1、安装 1.2、语法 二、vue3ts 2.1、项目创建 2.1.1、项目创建(建议node版本在16.及以上) 2.1.2、下载路由、axios 2.1.3、引入element-plus 2.1.4、报错解决 (1)文件路径下有红色波浪 (2)组件名称下有红色波浪 (3)引入模块下有红色波浪 2.…

快速幂算法6

eg: n10&#xff0c;10%20, 10/25, 5%21,4* 5/22, 2%20,4*256 0/20, 1024 递归算法 #include<iostream> using namespace std; long long quick_pow(int b,int e) {if(b0)return 0;if(e0)return 1;if(e%20){int tempquick_pow(b,e/2);return temp*temp;}if(e%2!0)…

大数据学习之安装并配置maven环境

什么是Maven Maven字面意&#xff1a;专家、内行Maven是一款自动化构建工具&#xff0c;专注服务于Java平台的项目构建和依赖管理。依赖管理&#xff1a;jar之间的依赖关系&#xff0c;jar包管理问题统称为依赖管理项目构建&#xff1a;项目构建不等同于项目创建 项目构建是一…

【SQL国际标准】ISO/IEC 9075:2023 系列SQL的国际标准详情

目录 &#x1f30a;1. 前言 &#x1f30a;2. ISO/IEC 9075:2023 系列SQL的国际标准详情 &#x1f30a;1. 前言 ISO&#xff08;国际标准化组织&#xff0c;International Organization for Standardization&#xff09;是一个独立的、非政府间的国际组织&#xff0c;其宗旨是…

C++语言学习(五)—— 类与对象(一)

目录 一、类类型的定义 二、类成员的访问控制 2.1 什么是"类内"和"类外" 2.2 对于访问控制属性的说明 三、类类型的使用 3.1 进行抽象 3.2 声明类 3.3 实现类 3.4 使用类 四、构造函数的引入 五、析构函数的引入 六、重载构造函数的引入 6.1 …

# 分布式链路追踪_skywalking_学习(2)

分布式链路追踪_skywalking_学习&#xff08;2&#xff09; 一、分布式链路追踪_skywalking &#xff1a;Rpc 调用监控 1、Skywalking(6.5.0) 支持的 Rpc 框架有以下几种&#xff1a; Dubbo 2.5.4 -> 2.6.0Dubbox 2.8.4Apache Dubbo 2.7.0Motan 0.2.x -> 1.1.0gRPC 1.…

Live800:客户为王,企业竞争的新趋势与核心要素!

在企业经营管理中&#xff0c;客户始终是最重要的资源和战略。从企业经营的角度来说&#xff0c;企业管理的核心是客户管理&#xff0c;客户管理的核心是价值创造和价值分配&#xff0c;这是企业经营的基础。这里主要讨论了企业竞争的新趋势与核心要素&#xff0c;认为客户为王…

嵌入式岗位,你有能力,你同样可以拿到高薪资

在开始前刚好我有一些资料&#xff0c;是我根据网友给的问题精心整理了一份「嵌入式的资料从专业入门到高级教程」&#xff0c; 点个关注在评论区回复“888”之后私信回复“888”&#xff0c;全部无偿共享给大家&#xff01;&#xff01;&#xff01; 就算你进去了&#xff0…

景源畅信:小白做抖音运营难吗?

在数字化时代&#xff0c;社交媒体已成为人们生活的一部分&#xff0c;而抖音作为其中的翘楚&#xff0c;吸引了众多希望通过平台实现自我价值和商业目标的用户。对于刚入门的小白来说&#xff0c;运营抖音账号可能会遇到不少挑战。接下来&#xff0c;我们将详细探讨这一话题&a…

交换机部分综合实验

实验要求 1.内网IP地址使用172.16.0.0/16 2.sw1和sW2之间互为备份; 3.VRRP/mstp/vlan/eth-trunk均使用; 4.所有pc均通过DHcP获取Ip地址; 5.ISP只配置IP地址; 6.所有电脑可以正常访问IsP路由器环回 实验拓扑 实验思路 1.给交换机创建vlan&#xff0c;并将接口划入vlan 2.在SW1和…

pytorch-13_2 模型结构选择策略:层数、激活函数、神经元个数

一、拟合度概念 在所有的模型优化问题中&#xff0c;最基础的也是最核心的问题&#xff0c;就是关于模型拟合程度的探讨与优化。根据此前的讨论&#xff0c;模型如果能很好的捕捉总体规律&#xff0c;就能够有较好的未知数据的预测效果。但限制模型捕捉总体规律的原因主要有两点…

Qt for android 添加自己的java包

java 包 目录 将目录放在项目的android目录中 .pro 中添加 或(可以在Qt Creator中显示) DISTFILES android/src/ScytheStudio/*.java \android/src/Serial/*.java \

火山引擎边缘云亮相 Force 原动力大会,探索 AI 应用新范式

5月15日&#xff0c;2024 春季火山引擎 FORCE 原动力大会在北京正式举办。大会聚焦 AI 主题&#xff0c;以大模型应用为核心、以 AI 落地为导向&#xff0c;展示了火山引擎在大模型、云计算领域的实践应用&#xff0c;携手汽车、手机终端、金融、消费、互联网等领域的专家和企业…

CSS3 新增背景属性 + 新增边框属性(如果想知道CSS3新增背景属性和新增边框属性的知识点,那么只看这一篇就够了!)

前言&#xff1a;CSS3在CSS2的基础上&#xff0c;新增了很多强大的新功能&#xff0c;从而解决一些实际面临的问题&#xff0c;本篇文章主要讲解的为CSS3新增背景属性和新增边框属性。 ✨✨✨这里是秋刀鱼不做梦的BLOG ✨✨✨想要了解更多内容可以访问我的主页秋刀鱼不做梦-CSD…

关于廉洁的短视频:四川京之华锦信息技术公司

关于廉洁的短视频&#xff1a;传递清廉之风 在信息爆炸的时代&#xff0c;短视频以其短小精悍、直观生动的特点&#xff0c;成为了人们获取信息、传播价值观念的重要渠道。四川京之华锦信息技术公司在众多主题中&#xff0c;关于廉洁的短视频尤为引人注目&#xff0c;它们以独…

B站自动回复插件_无需千粉,轻松适配引流拉新资源分享

项目介绍 B站关键词自动回复插件&#xff0c;无需千粉&#xff0c; 很适合做流量做引流做私欲的朋友&#xff0c; 前期没有千粉是无法开启官方自动回复的&#xff0c; 适当的情况下可以用这个插件顶一下&#xff0c; 三联好评领取资源的打法真的超级涨粉&#xff0c; 感谢插件…

第十二节 SpringBoot Starter 系列结束语

感谢阅读&#xff0c;到这里&#xff0c;本系列课程就结束了。 一、为什么选择 SpringBoot Starter SpringBoot 近年来已经成为 Java 应用的必备框架&#xff1b; 而 SpringBoot starter 模式已经成为各大中间件集成到 SpringBoot 应用的首选方式&#xff0c;通过引入 xxx-st…

【MATLAB源码-第215期】基于matlab的8PSK调制CMA均衡和RLS-CMA均衡对比仿真,对比星座图和ISI。

操作环境&#xff1a; MATLAB 2022a 1、算法描述 CMA算法&#xff08;恒模算法&#xff09; CMA&#xff08;Constant Modulus Algorithm&#xff0c;恒模算法&#xff09;是一种自适应盲均衡算法&#xff0c;主要用于消除信道对信号的码间干扰&#xff08;ISI&#xff09;…