【池化方法】多示例学习池化(MIL pooling)公式与代码

news2025/1/26 15:27:07

  一般的池化方法包括最大池化、平均池化、自适应池化与随机池化,这几天意外看到了多示例学习池化,感觉挺有意思的,记录一下。
  论文
  代码

1. 多示例学习(Multiple instance learning,MIL)

  经典深度学习的数据是一张图一个类别,而多示例学习的数据是一个数据包(bag),一个bag标记一个类别,bag中的每一张图称为一个示例(instance)。形象一点的例子就是,一位患者扫了一次CT,产生了很多张CT切片图像,此时,一张CT切片为一个instance,所有CT切片为一个bag。如果所有的CT切片都检测为没病,那么这位患者正常,否则,这名患者患病。
  其基本模式如下图所示:
在这里插入图片描述

2. MIL pooling

  最大池化和平均池化都是不可训练的,设计灵活且自适应的MIL池化可以通过针对任务和数据进行调整,以实现更好的结果。

2.1 注意机制(Attention mechanism)

  该方法使用每一个instance低维嵌入的加权平均值,其权重系数通过神经网络学习得到,权重系数之和为1。设 H = { h 1 , … , h K } H = \left\{ {{h_1}, \ldots ,{h_K}} \right\} H={h1,,hK}为一个bag中的K个嵌入,则:
z = ∑ k = 1 K a k h k {z = \sum\limits_{k = 1}^K {{a_k}{h_k}}} z=k=1Kakhk a k = exp ⁡ { w ⊤ tanh ⁡ ( V h k ⊤ ) } ∑ j = 1 K exp ⁡ { w ⊤ tanh ⁡ ( V h j ⊤ ) } {{a_k} = \frac{{\exp \left\{ {{w^ \top }\tanh (Vh_k^ \top )} \right\}}}{{\sum\limits_{j = 1}^K {\exp \left\{ {{w^ \top }\tanh (Vh_j^ \top )} \right\}} }}} ak=j=1Kexp{wtanh(Vhj)}exp{wtanh(Vhk)}  其中 w ∈ R L × 1 {w \in R{^{L \times 1}}} wRL×1, V ∈ R L × M {V \in R{^{L \times M}}} VRL×M为参数,可由全连接层实现。 L {L} L为低维嵌入大小, M {M} M为中间维度。

2.2 门控注意机制(Gated attention mechanism)

  由于 tanh ⁡ ( x ) {\tanh (x)} tanh(x) x ∈ [ − 1 , 1 ] {x \in [ - 1,1]} x[1,1]时近似线性,这可能会限制instance之间学习关系的最终表达。作者设计了一种门控机制,即:
a k = exp ⁡ { w ⊤ ( tanh ⁡ ( V h k ⊤ ) ⊙ s i g m o i d ( U h k ⊤ ) ) } ∑ j = 1 K exp ⁡ { w ⊤ ( tanh ⁡ ( V h j ⊤ ) ⊙ s i g m o i d ( U h j ⊤ ) ) } {{a_k} = \frac{{\exp \left\{ {{w^ \top }(\tanh (Vh_k^ \top ) \odot sigmoid(Uh_k^ \top ))} \right\}}}{{\sum\limits_{j = 1}^K {\exp \left\{ {{w^ \top }(\tanh (Vh_j^ \top ) \odot sigmoid(Uh_j^ \top ))} \right\}} }}} ak=j=1Kexp{w(tanh(Vhj)sigmoid(Uhj))}exp{w(tanh(Vhk)sigmoid(Uhk))}  其中, U ∈ R L × M {U \in R{^{L \times M}}} URL×M为参数, ⊙ { \odot } 为元素级相乘,门控机制引入了可学习的非线性,潜在地消除了 tanh ⁡ ( x ) {\tanh (x)} tanh(x)中麻烦的线性。

3. MIL pooling的PyTorch代码

import torch
import torch.nn as nn
import torch.nn.functional as F


class Attention(nn.Module):
    def __init__(self):
        super(Attention, self).__init__()
        self.L = 500
        self.D = 128
        self.K = 1

        self.feature_extractor_part1 = nn.Sequential(
            nn.Conv2d(1, 20, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),
            nn.Conv2d(20, 50, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2)
        )

        self.feature_extractor_part2 = nn.Sequential(
            nn.Linear(50 * 4 * 4, self.L),
            nn.ReLU(),
        )
        # w 和 V 由两个线性层实现
        self.attention = nn.Sequential(
            nn.Linear(self.L, self.D),
            nn.Tanh(),
            nn.Linear(self.D, self.K)
        )

        self.classifier = nn.Sequential(
            nn.Linear(self.L*self.K, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        # 设输入张量大小为[20, 1, 30, 30],即有20个instance
        x = x.squeeze(0)   # [20, 1, 30, 30]
       
        H = self.feature_extractor_part1(x)  # [20, 50, 4, 4] 特征提取下采样
        
        H = H.view(-1, 50 * 4 * 4)   # [20, 800] 通道合并
        
        H = self.feature_extractor_part2(H)  # NxL  [20, 500] 低维嵌入
        
        A = self.attention(H)  # NxK  [20, 1] 计算ak
       
        A = torch.transpose(A, 1, 0)  # KxN  [1, 20] 每个instance一个权重
        
        A = F.softmax(A, dim=1)  # softmax over N  [1, 20] softmax使权重之和为1

        M = torch.mm(A, H)  # KxL  [1, 500] 计算ak乘以hk

        Y_prob = self.classifier(M)  # [1, 1] 分类器输出概率
        
        Y_hat = torch.ge(Y_prob, 0.5).float()  # [1, 1] 大于0.5为1

        return Y_prob, Y_hat, A


class GatedAttention(nn.Module):
    def __init__(self):
        super(GatedAttention, self).__init__()
        self.L = 500
        self.D = 128
        self.K = 1

        self.feature_extractor_part1 = nn.Sequential(
            nn.Conv2d(1, 20, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),
            nn.Conv2d(20, 50, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2)
        )

        self.feature_extractor_part2 = nn.Sequential(
            nn.Linear(50 * 4 * 4, self.L),
            nn.ReLU(),
        )

        self.attention_V = nn.Sequential(
            nn.Linear(self.L, self.D),
            nn.Tanh()
        )

        self.attention_U = nn.Sequential(
            nn.Linear(self.L, self.D),
            nn.Sigmoid()
        )
       
        self.attention_weights = nn.Linear(self.D, self.K)   # w

        self.classifier = nn.Sequential(
            nn.Linear(self.L*self.K, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = x.squeeze(0)

        H = self.feature_extractor_part1(x)
        H = H.view(-1, 50 * 4 * 4)
        H = self.feature_extractor_part2(H)  # NxL

        A_V = self.attention_V(H)  # NxD tanh
        A_U = self.attention_U(H)  # NxD Sigmoid
        A = self.attention_weights(A_V * A_U) # element wise multiplication # NxK
        A = torch.transpose(A, 1, 0)  # KxN
        A = F.softmax(A, dim=1)  # softmax over N

        M = torch.mm(A, H)  # KxL

        Y_prob = self.classifier(M)
        Y_hat = torch.ge(Y_prob, 0.5).float()

        return Y_prob, Y_hat, A

   MIL pooling也不一定限制在多示例学习中使用,如对三维数据采用不同的二维降采样方法,得到的数据经特征提取后进行融合,也可以采用这种池化方法。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/433669.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

梯度下降算法原理详解及MATLAB程序代码(最简单)

模型就是线性规划及线性规划的对偶理论,单纯形法以及它的实际应用:整数规划及其解法(分支定界法、割平面法匈牙利算法Q),目标规划,非线性规划动态规划、决策分析等等。 其它的一些优化算法。比如说一维搜索里面的黄金分割法、加步…

PostMan笔记(二)发送请求

1. 发送请求功能介绍 Postman是一款流行的API开发工具,它可以让开发人员更方便地测试、调试和使用API。其中,发送请求功能是Postman最为重要和基础的功能之一。 在Postman中,发送请求功能主要包括以下几个步骤: 选择请求方法&am…

数据分析时,进行数据建模该如何筛选关键特征?

1.为什么要做关键特征筛选? 在数据量与日俱增的时代,我们收集到的数据越来越多,能运用到数据分析挖掘的数据也逐渐丰富起来,但同时,我们也面临着如何从庞大的数据中筛选出与我们业务息息相关的数据。(大背景…

Java的对象克隆

本节我们会讨论 Cloneable 接口,这个接口指示一个类提供了一个安全的 clone() 方法。 Object 类提供的 clone() 方法是 “浅拷贝”,并没有克隆对象中引用的其他对象,原对象和克隆的对象仍然会共享一些信息。深拷贝指的是:在对象中…

微服务---一篇学完SpringCloud

SpringCloud 1.认识微服务 随着互联网行业的发展,对服务的要求也越来越高,服务架构也从单体架构逐渐演变为现在流行的微服务架构。这些架构之间有怎样的差别呢? 1.0.学习目标 了解微服务架构的优缺点 1.1.单体架构 单体架构&#xff1a…

java企业级信息系统开发学习笔记06基于xml配置方式使用Spring MVC

文章目录 一、学习目标二、Spring MVC概述1、MVC架构2、Spirng MVC3、使用Spring MVC的两种方式 三、基于xml配置与注解的方式使用Spring MVC(一)创建Maven项目(二)添加相关依赖(三)给项目添加Web功能&…

SpringMVC表格提交中文乱码和配置logback

最佳解决方案还是使用Spring提供的过滤器&#xff0c;将其配置到WEB.XML文件中&#xff1a; <filter><filter-name>characterEncodingFilter</filter-name><filter-class>org.springframework.web.filter.CharacterEncodingFilter</filter-class&g…

nginx部署VUE项目

前言 目前公司的前端代码基本都是部署在nginx下&#xff0c;特此来记录一下 开发环境&#xff1a;window10 nginx环境搭建&#xff08;参考下方文章&#xff09; window环境安装 mac环境安装 本地我将nginx放置于F盘 前端项目打包 一个nginx服务下可能会放置多个前端包&…

echarts 折线图

Echarts 常用各类图表模板配置 注意&#xff1a; 这里主要就是基于各类图表&#xff0c;更多的使用 Echarts 的各类配置项&#xff1b; 以下代码都可以复制到 Echarts 官网&#xff0c;直接预览&#xff1b; 图标模板目录Echarts 常用各类图表模板配置一、简洁折线图二、环形图…

结构体的存储

由于要想知道结构体的大小&#xff0c;了解结构体是如何存储在内存中的 我们需要先了解一个知识点&#xff1a; 结构体内存对齐 1. 第一个成员在与结构体变量偏移量为0的地址处 (偏移量是某个字节相较于起始存储空间的相差字节数 例如第一个字节的偏移量是0&#xff0c;第二个…

一套专业的C#医院体检管理系统源码 PEIS体检报告管理系统源码 C/S医院PEIS系统源码

医院PEIS体检管理系统源码&#xff0c;有源码&#xff0c;有演示&#xff0c;自主研发&#xff0c;官方正版授权&#xff01; 开发语言&#xff1a;C# 开发工具&#xff1a;VS2013版本起 后端框架&#xff1a;winform 数 据 库&#xff1a;oracle 12c 医院体检系统主要特点…

人大金仓亮相2023CHITEC,五大看点不容错过

近日&#xff0c;由中国卫生信息与健康医疗大数据学会和《中国卫生信息管理杂志》社联合举办的2023&#xff08;17th&#xff09;中国卫生信息技术/健康医疗大数据应用交流大会暨软硬件与健康医疗产品展览会&#xff08;2023 CHITEC&#xff09;在安徽合肥顺利召开。 作为数据库…

【DAY38】BOM/VUE初步学习

pageXOffset 设置或返回当前页面相对于窗口显示区左上角的 X 位置。 pageYOffset 设置或返回当前页面相对于窗口显示区左上角的 Y 位置。 screenLeft&#xff0c;screenTop&#xff0c;screenX&#xff0c;screenY 声明了窗口的左上角在屏幕上的的 x 坐标和 y 坐标。IE、Safari…

JavaScript历史

JavaScript历史 参考视频1 1990年&#xff0c;第一个终端显示网页被蒂姆博士创造出来&#xff0c;表现为超链接跳转、无图的特点。文本格式定义、文本传输协议即应用层协议&#xff0c;解析显示引擎是关键。1993年&#xff0c;随着人们对视觉效果的要求逐渐变高&#xff0c;马…

Https详解

文章目录 一. 什么是 Https1. "加密"是什么?2. 对称加密3. 非对称加密4. "中间人攻击" 二. 引入证书理解签名黑客能否伪造证书?黑客能否替换公钥?黑客能否篡改签名?如何查看证书? 一. 什么是 Https https 就是 http 安全层(SSL)–> 用来加密的协…

黑马在线教育数仓实战6

6. 意向用户主题看板_增量流程 6.1 数据采集(拉链表) 7. hive的索引 ​ 索引的作用: 加快查询的效率 为什么索引可以提升查询效率呢? hive索引是在 分区 分桶优化基础上, 又提供一种新的优化手段, 如果分区 和分桶受限, 可以尝试使用索引的方式来优化处理 hive提供了三种索…

VMware ESXi 8.0U1 macOS Unlocker OEM BIOS (标准版和厂商定制版)

ESXi 8.0U1 标准版&#xff0c;Dell HPE 联想 浪潮 定制版 请访问原文链接&#xff1a; https://sysin.org/blog/vmware-esxi-8-u1-oem/&#xff0c;查看最新版。原创作品&#xff0c;转载请保留出处。 作者主页&#xff1a;sysin.org 2023-04-18, VMware vSphere 8.0U1 发布…

家用洗地机实用吗?家用洗地机款式推荐

要说现在家居清洁用什么单品更省心&#xff0c;洗地机必须要算一项。虽然这在国际上也不是什么新鲜的概念了&#xff0c;但是在国内兴起也只是这几年的事&#xff0c;关于家用洗地机什么牌子最好之类的问题也是很多人都比较关心的问题。我个人也是不喜欢做家务的&#xff0c;家…

C++算法:排序、查找

排序 排序是一个非常经典的问题&#xff0c;它以一定的顺序对一个数组&#xff08;或一个列表&#xff09;中的项进行重新排序 有许多不同的排序算法&#xff0c;每个都有其自身的优点和局限性。 时间复杂度&#xff1a;对排序数据的总的操作次数。反映当n变化时&#xff0c;操…

SQL之SQL优化

文章目录 一、插入数据优化insert优化大批量插入数据 二、主键优化数据组织方式页分裂页合并主键设计原则三、order by优化 四、Group By 优化五、limit优化六、count优化count的几种用法 七、update优化总结 一、插入数据优化 insert优化 insert into tb_test values(1, tom…