Bilinear CNN:细粒度图像分类网络,对Bilinear CNN中矩阵外积的解释。

news2025/1/12 10:50:10

文章目录

    • 一、Bilinear CNN 的网络结构
    • 二、矩阵外积(outer product)
      • 2.1 外积的计算方式
      • 2.2 外积的作用
    • 三、PyTorch 网络代码实现

细粒度图像分类(fine-grained image recognition)的目的是区分类别的子类,如判别一只狗子是哈士奇还是柴犬。细粒度图像分类可以分为基于强监督信息(图像类别、物体标注框、部位标注点等)和基于弱监督信息(只有图像类别),具体可以参考 细粒度图像分类

Bilinear CNN 是2015在论文 《Bilinear CNN Models for Fine-grained Visual Recognition》中提出来的,是一种基于弱监督信息的细粒度图像分类模型。

一、Bilinear CNN 的网络结构

Bilinear CNN 的网络结构如下:
在这里插入图片描述
Bilinear CNN 由两个 CNN 特征提取网络组成,它们的输出做外积(outer product)获得双线性向量(可称为图像描述符 image descriptor),再进行分类。

需要注意的是两个 CNN 其实是完全相同的,代码中用的就是一个网络(一般用预训练的 vgg16 或 ResNet18 网络),只是对网络输出值 x 计算了 x 和 xT 的矩阵乘积实现特征交互。
当然也可以使用两个不同的 CNN 网络。

双线性网络用于模拟图像的双因素变化。有一种说法是:网络A的作用是对图像中对象的特征部位进行定位,网络B则是用来对网络A检测到的特征区域进行特征提取。两个网络相互协调作用,实现细粒度图像分类。但如果用一个网络来实现,这种说法也太荒谬了。

由于模型对两个 CNN 的输出的操作是线性的(矩阵相乘是线性运算,因为只有加法和乘法操作),所以网络称为 bilinear CNNs。

二、矩阵外积(outer product)

2.1 外积的计算方式

网上很多博客说 矩阵外积就是克罗内克积,但是Bilinear CNN代码实现中的外积其实就是普通的矩阵相乘(就是线性代数中最常规的矩阵相乘),并非克罗内克积。

代码可见本文第三部分“PyTorch 网络代码实现”。

计算外积的代码为:

x = torch.bmm(x, torch.transpose(x, 1, 2)) / (28 * 28)

这里的 torch.bmm(a,b) 就是普通的矩阵相乘,举个例子证明:

import torch

a = torch.randint(low=0, high=5, size=(1, 2, 2))
b = torch.randint(low=0, high=5, size=(1, 2, 2))
c = torch.bmm(a, b)
print(f"a = {a}")
print(f"b = {b}")
print(f"c = {c}")


"""
a = tensor([[[4, 0],
             [4, 1]]])
b = tensor([[[1, 4],
             [2, 4]]])
c = tensor([[[ 4, 16],         4 = 4 * 1 + 0 * 2, 16 = 4 * 4 + 0 * 4
             [ 6, 20]]])       6 = 4 * 1 + 1 * 2, 20 = 4 * 4 + 1 * 4
"""

如果这个版本的 PyTorch 代码没有错误的话,这里的外积就是普通的矩阵相乘。当然我没有看 Bilinear CNN 的 Matlab 源码,源码地址为 Bilinear CNNs for Fine-grained Visual Recognition,欢迎大家批评指正(对于内积外积我也没分清楚)。

2.2 外积的作用

外积其实只是一种特征融合的方式,其他常用的特征融合方法还有:最大值融合、平均值融合、相加、concat 等。

但外积可以通过矩阵运算捕捉不同通道之间的特征相关性。由于描述向量的不同维度对应卷积特征的不同通道,而不同通道提取了不同的语义特征,因此,通过双线性操作,可以同时捕获输入图像的不同语义特征之间的关系。

三、PyTorch 网络代码实现

基于 vgg16:

import torch
import torch.nn as nn
import torchvision


class BCNN_fc(nn.Module):
    def __init__(self):
        super(BCNN_fc, self).__init__()
        # VGG16的卷积层和池化层
        self.features = torchvision.models.vgg16(pretrained=True).features

        # 去除最后一个 pooling 层
        self.features = nn.Sequential(*list(self.features.children())[:-1])

        # 线性分类层
        self.fc = nn.Linear(512 * 512, 200)

        # 冻结以前的所有层
        for param in self.features.parameters():
            param.requres_grad = False

        # 初始化fc层
        nn.init.kaiming_normal_(self.fc.weight.data)
        if self.fc.bias is not None:
            nn.init.constant_(self.fc.bias.data, val=0)

    def forward(self, x):
        N = x.size()[0]
        assert x.size() == (N, 3, 448, 448)

        # 特征提取
        x = self.features(x)
        assert x.size() == (N, 512, 28, 28)
        x = x.view(N, 512, 28 * 28)

        # 双线性矩阵相乘
        # 对于 c=torch.bmm(a,b),其中 a.shape=[b,m,n], b.shape=[b,n,p], 则 c.shape=[b,m,p]
        # 这里其实是对 x 和 x^T 进行了相乘
        # 除以 28 * 28 是为了防止最后 softmax 的梯度过小
        x = torch.bmm(x, torch.transpose(x, 1, 2)) / (28 * 28)
        assert x.size() == (N, 512, 512)

        # 有符号平方根,y = sign(x) * sqrt(|x|)
        x = torch.sign(x) * torch.sqrt(torch.abs(x) + 1e-10)
        x = x.view(N, 512 * 512)
        assert x.size() == (N, 512 * 512)

        # L2归一化
        x = torch.nn.functional.normalize(x)
        assert x.size() == (N, 512 * 512)

        # 全连接分类层
        x = self.fc(x)
        assert x.size() == (N, 200)
        return x

if __name__ == '__main__':
    input = torch.randn(16, 3, 448, 448)
    model = BCNN_fc()
    output = model(input)
    print(output.shape)   # torch.Size([16, 200])

基于 ResNet18:

import torch
import torch.nn as nn
from torchvision.models import resnet18

class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.features = nn.Sequential(resnet18().conv1,
                                     resnet18().bn1,
                                     resnet18().relu,
                                     resnet18().maxpool,
                                     resnet18().layer1,
                                     resnet18().layer2,
                                     resnet18().layer3,
                                     resnet18().layer4)
        self.classifiers = nn.Sequential(nn.Linear(512**2,14))
        
    def forward(self,x):
        x=self.features(x)
        batch_size = x.size(0)
        feature_size = x.size(2)*x.size(3)
        x = x.view(batch_size , 512, feature_size)
        x = (torch.bmm(x, torch.transpose(x, 1, 2)) / feature_size).view(batch_size, -1)
        x = torch.nn.functional.normalize(torch.sign(x)*torch.sqrt(torch.abs(x)+1e-10))
        x = self.classifiers(x)
        return x

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/620729.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【web自动化测试】Web网页测试针对性的流程解析

前言 测试行业现在70%是以手工测试为主,那么只有20%是自动化测试,剩下的10%是性能测试。 有人可能会说,我现在做手工,我为什么要学自动化呢?我去学性能更好性能的人更少? 其实,性能的要求比自动…

蓝桥杯2022年第十三届决赛真题-齿轮

题目描述 这天,小明在组装齿轮。 他一共有 n 个齿轮,第 i 个齿轮的半径为 ri,他需要把这 n 个齿轮按一定顺序从左到右组装起来,这样最左边的齿轮转起来之后,可以传递到最右边的齿轮,并且这些齿轮能够起到提…

小程序容器与PWA是一回事吗?

PWA代表“渐进式网络应用”(Progressive Web Application)。它是一种结合了网页和移动应用程序功能的技术概念。PWA旨在提供类似于原生应用程序的用户体验,包括离线访问、推送通知、后台同步等功能,同时又具有网页的优势&#xff…

软件验收测试该怎么进行?权威的软件检测机构应该具备哪些资质?

软件测试是软件开发周期中非常重要的一个环节。软件测试的目的是发现软件在不同环境下的各种问题,保证软件在发布前能够达到用户的要求。软件验收测试是软件测试的最后一个环节,该环节主要验证软件是否满足用户需求。那么对于软件验收测试,该…

分布式事务二 Seata使用及其原理剖析

一 Seata 是什么 Seata 介绍 Seata 是一款开源的分布式事务解决方案,致力于提供高性能和简单易用的分布式事务服务。Seata 将为用户提供了 AT、TCC、SAGA 和 XA 事务模式,为用户打造一站式的分布式解决方案。AT模式是阿里首推的模式,阿里云上有商用版本…

【Spring源码】Spring源码导入Idea

1.基础环境准备 相关软件、依赖的版本号 Spring源码版本 5.3.x软件 ideaIU-2021.1.2.exeGradle gradle-7.2-bin.zip https://services.gradle.org/distributions/gradle-7.2-bin.zip - 网上说要单独下载gradle并配置环境变量,亲测当前5.3.X版本通过gradlew的方式进…

虚函数详解及应用场景

目录 概述1. 虚函数概述2. 虚函数的声明与重写3. 析构函数与虚函数的关系4. 虚函数的应用场景4.1. 多态性4.2. 接口定义与实现分离4.3. 运行时类型识别4.4. 多级继承与虚函数覆盖 结论 概述 虚函数是C中一种实现多态性的重要机制,它允许在基类中声明一个函数为虚函…

PDCCH monitoring capability

欢迎关注同名微信公众号“modem协议笔记”。 前段时间看search space set group (SSSG) switching相关内容时,注意到R17和R16的描述由于PDCCH monitoring capability的变化,内容有些不一样。于是就顺带看了下R16 R17PDCCH monitoring capability的内容。…

Domino 14.0早期测试版本

大家好,才是真的好。 本篇是超级图片篇,图片多,内容丰富,流量党请勿手残。 前天我们说到Engageug2023正在如火如荼进行,主题是“The Future is Now”。 因为时差的关系,实际上在写这篇公众号时&#xff…

设计模式(七):结构型之适配器模式

设计模式系列文章 设计模式(一):创建型之单例模式 设计模式(二、三):创建型之工厂方法和抽象工厂模式 设计模式(四):创建型之原型模式 设计模式(五):创建型之建造者模式 设计模式(六):结构型之代理模式 设计模式…

Java --- springboot3之web内容协商原理

一、内容协商原理 HttpMessageConverter 定制 HttpMessageConverter 来实现多端内容协商 编写WebMvcConfigurer提供的configureMessageConverters底层,修改底层的MessageConverter ResponseBody由HttpMessageConverter处理 标注了ResponseBody的返回值 将会由支持它…

蹭个高考热度,中国人民大学与加拿大女王大学金融硕士项目给你更多的选择

今日各大平台热搜都被“高考”霸屏,朋友圈里到处都是高考的祝福。期待莘莘学子都将交上满意的答卷,考出理想的未来。针对职场上的我们而言高考已是过去时,但知识的力量却是无穷的,在职的我们依然可以向上生长,中国人民…

FreeRTOS_任务相关API函数

目录 1. 任务创建和删除 API 函数 1.1 函数 xTaskCreate() 1.2 函数 xTaskCreateStatic() 1.3 函数 xTaskCreateRestricted() 1.4 函数 vTaskDelete() 2. 任务创建和删除实验(动态方法) 2.1 实验程序与分析 3. 任务创建和删除实验(静…

ZC-CLS381RGB颜色识别——配置寄存器组(上)

文章目录 前言一、ZC-CLS381RGB简介二、配置寄存器组1.主控寄存器2.检测速率寄存器2.增益寄存器2.颜色数据寄存器 三、状态转移图和信号波形图绘制总结 前言 在现代工业生产中,颜色识别技术已经成为了一个非常重要的技术。颜色识别可以用于产品质量检测、物料分类、…

特瑞仕|常见电子元器件的故障现象及原因详解

​电子元器件是现代电子设备中不可或缺的组成部分,但在长时间的使用过程中,它们也可能会出现各种故障现象。本文将详细介绍一些常见电子元器件的故障现象及原因,以帮助读者更好地理解和处理这些问题。 一、电阻器 故障现象:电阻值…

湖南人的商业策略:用“副产品免费”的模式,推动主产品消费

湖南人的商业策略:用“副产品免费”的模式,推动主产品消费 什么是副产品免费模式?(主产品要钱,副产品不要钱) 免费商业模型设计的核心就是通过延长产业链,以此来达到利润链条的延伸,在这个过程中衍生和挖掘…

1.8 掌握Scala函数

一、声明函数 (一)显式声明函数 案例演示 (1)加法函数 package net.huawei.day08import scala.io.StdIn/*** 功能:显式声明函数* 作者:* 日期:2023年03月20日*/ object Example01 {def add1…

测试用例设计方法之因果图详解

一、因果图概述 因果图是从需求中找出因(输入条件)和果(输出或程序状态的改变),通过分析输入条件之间的关系(组合关系、约束关系等)及输入和输出之间的关系绘制出因果图,再转化成判…

composer-创建自己的依赖库

1.环境 码云账号(或者GitHub)码云地址composer 官方仓库账号 Packagist composer官方仓库安装composer 2.步骤 2.1 发行composer的依赖包是需要从git 或者svn里拉取的,所以得先在码云里创建一个仓库 2.2 依赖包中必须有composer.json配置标明名字依赖等信息,配置大概如下,配…

Vue基础第五篇

一、动态组件 1.基本使用 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>动态组件</title><script src"https://cdn.bootcdn.net/ajax/libs/vue/2.6.12/vue.min.js"></sc…