最优传输问题与Sinkhorn算法

news2025/1/11 0:03:49

目录

  • 引言
  • 例子:分甜点
  • 最优传输问题
  • Sinkhorn算法
    • Sinkhorn距离
    • 算法流程
    • 代码实验

引言

最近看到一篇特征匹配相关的论文,思想是将特征匹配问题转化为最优传输问题求解,于是我去学习了一下最优传输问题。
本文主要是对博文 Notes on Optimal Transport 的学习做一个记录总结,该博文写的不错,推荐阅读。

例子:分甜点

文章作者以一个简单的甜点分配例子引入了最优传输问题。
向量 c = [ 4 , 2 , 6 , 4 , 4 ] ⊤ \mathbf{c}=[4, 2, 6, 4, 4]^{\top} c=[4,2,6,4,4] 表示5种甜点的数量:
5种甜点的数量分布
向量 r = [ 3 , 3 , 3 , 4 , 2 , 2 , 2 , 1 ] ⊤ \mathbf{r}=[3, 3, 3, 4, 2, 2, 2, 1]^{\top} r=[3,3,3,4,2,2,2,1] 表示8个人需要的甜点数:
在这里插入图片描述
矩阵 M ∈ R 5 × 8 \mathbf{M}\in \mathbb{R}^{5\times 8} MR5×8 表示每个人对各种甜点的偏好,尺度区间 [ − 2 , 2 ] [-2, 2] [2,2],-2表示非常不喜欢,2表示非常喜欢:
在这里插入图片描述

我们的目标,就是要根据甜点的数量,同时考虑每个人的需求和偏好,将所有甜点合理地分配到每个人手中。

最优传输问题

最优运输问题的目标就是以最小的成本将一个概率分布转换为另一个概率分布。上面的分甜点的目标,用最优传输问题的定义来说,就是将概率分布 c \mathbf{c} c 以最小的成本转换到概率分布 r \mathbf{r} r
这就需要我们求得一个分配方案,由矩阵 P ∈ R 5 × 8 P\in \mathbb{R}^{5\times 8} PR5×8 表示,存储每个人分得的每个甜点的情况。

根据现实条件,这个分配矩阵 P P P 显然具有以下约束:

  1. 分配的甜点数量不能为负数;
  2. 每个人的需求都要满足,即 P P P 的行和服从分布 r \mathbf{r} r
  3. 每种甜点要全部分完,即 P P P 的列和服从分布 c \mathbf{c} c

于是在分布 r \mathbf{r} r c \mathbf{c} c 约束下, P P P 的解空间可以做如下定义:
U ( r , c ) = { P ∈ R > 0 n × m ∣ P 1 m = r , P ⊤ 1 n = c } (1) U(\mathbf{r}, \mathbf{c})=\left\{P \in \mathbb{R}_{>0}^{n \times m} \mid P \mathbf{1}_{m}=\mathbf{r}, P^{\top} \mathbf{1}_{n}=\mathbf{c}\right\} \tag 1 U(r,c)={PR>0n×mP1m=r,P1n=c}(1)
PS:这是博文的原公式,这里我有个疑问,为什么 P P P 的元素要求严格大于0,而不是大于等于0?希望有同学能够解答我的疑惑(感谢)

如前面所述,我们希望最小化转换成本,可以简单地反转偏好矩阵 M \mathbf{M} M 的符号,就可以得到成本矩阵(cost matrix)。于是就有了最优传输问题的公式化表示:
d M ( r , c ) = min ⁡ P ∈ U ( r , c ) ∑ i , j P i j M i j (2) d_{M}(\mathbf{r}, \mathbf{c})=\min _{P \in U(\mathbf{r}, \mathbf{c})} \sum_{i, j} P_{i j} M_{i j} \tag 2 dM(r,c)=PU(r,c)mini,jPijMij(2)

标量 d M d_{M} dM 也被称为推土机距离(earth mover distance),因为它可以解释为至少移动多少“泥土”(成本)才能将一个土堆(分布)变成另一个土堆(分布)。

Sinkhorn算法

Sinkhorn距离

Sinkhorn距离是对推土机距离的一种改进,在其基础上引入了熵正则化项:
d M λ ( r , c ) = min ⁡ P ∈ U ( r , c ) ∑ i , j P i j M i j − 1 λ h ( P ) (3) d_{M}^{\lambda}(\mathbf{r}, \mathbf{c})=\min _{P \in U(\mathbf{r}, \mathbf{c})} \sum_{i, j} P_{i j} M_{i j}-\frac{1}{\lambda} h(P) \tag 3 dMλ(r,c)=PU(r,c)mini,jPijMijλ1h(P)(3)
其中 h ( P ) = − ∑ P i j log ⁡ P i j h(P)=-\sum{P_{ij}\log{P_{ij}}} h(P)=PijlogPij 称作 P P P 的信息熵(information entropy), P P P 分布越均匀,信息熵越大。

熵正则化参数 λ \lambda λ 负责调整信息熵的影响程度, λ \lambda λ 越大,信息熵的影响越小,最终结果受成本矩阵的影响更大,即更多地考虑每个人的喜好;反之,最终结果则更倾向于均匀分配,每种甜点将平均分配给每个人。

算法流程

新增的熵正则化项似乎让问题更加难以优化,但Sinkhorn算法提供了一种简单且有效的方法应对这一问题,Sinkhorn算法认为,最优分配矩阵 P λ ∗ P^*_\lambda Pλ 的元素应该具有如下形式:
( P λ ∗ ) i j = α i β j e − λ M i j (4) (P^*_\lambda)_{ij}=\alpha_i \beta_j e^{-\lambda M_{ij}} \tag 4 (Pλ)ij=αiβjeλMij(4)
其中正是 α 1 , . . . , α n \alpha_1,...,\alpha_n α1,...,αn β 1 , . . . , β n \beta_1,...,\beta_n β1,...,βn 使得 P ∗ P^* P 满足分配矩阵的三个约束。

具体流程如下:

给定: 代价矩阵 M M M, 分布 r \mathbf{r} r, 分布 c \mathbf{c} c, 熵正则化参数 λ \lambda λ
初始化: 分配矩阵 P λ = e − λ M P_\lambda=e^{-\lambda M} Pλ=eλM
重复:

  1. 缩放行,使得 P P P 的行和逼近分布 r \mathbf{r} r
  2. 缩放列,使得 P P P 的列和逼近分布 c \mathbf{c} c

直到: 收敛

代码实验

以下是Sinkhorn代码实现:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


r = np.array([3, 3, 3, 4, 2, 2, 2, 1])
c = np.array([4, 2, 6, 4, 4])
M = np.array(
    [[2, 2, 1, 0, 0], 
    [0, -2, -2, -2, -2], 
    [1, 2, 2, 2, -1], 
    [2, 1, 0, 1, -1],
    [0.5, 2, 2, 1, 0], 
    [0, 1, 1, 1, -1], 
    [-2, 2, 2, 1, 1], 
    [2, 1, 2, 1, -1]],
    dtype=float) 
M = -M # 将M变号,从偏好转为代价

def compute_optimal_transport(M, r, c, lam, eplison=1e-8):
    """
    Computes the optimal transport matrix and Slinkhorn distance using the
    Sinkhorn-Knopp algorithm

    Inputs:
        - M : cost matrix (n x m)
        - r : vector of marginals (n, )
        - c : vector of marginals (m, )
        - lam : strength of the entropic regularization
        - epsilon : convergence parameter

    Outputs:
        - P : optimal transport matrix (n x m)
        - dist : Sinkhorn distance
    """
    n, m = M.shape  # 8, 5
    P = np.exp(-lam * M) # (8, 5)
    P /= P.sum()  # 归一化
    u = np.zeros(n) # (8, )
    # normalize this matrix
    while np.max(np.abs(u - P.sum(1))) > eplison: # 这里是用行和判断收敛
        # 对行和列进行缩放,使用到了numpy的广播机制,不了解广播机制的同学可以去百度一下
        u = P.sum(1) # 行和 (8, )
        P *= (r / u).reshape((-1, 1)) # 缩放行元素,使行和逼近r
        v = P.sum(0) # 列和 (5, )
        P *= (c / v).reshape((1, -1)) # 缩放列元素,使列和逼近c
    return P, np.sum(P * M) # 返回分配矩阵和Sinkhorn距离

我们来看看在不同 λ \lambda λ 下,得到的分配矩阵有什么特点:

lam = 0.1

P, d = compute_optimal_transport(M,
        r,
        c, lam=lam)

partition = pd.DataFrame(P, index=np.arange(1, 9), columns=np.arange(1, 6))
ax = partition.plot(kind='bar', stacked=True)
print('Sinkhorn distance: {}'.format(d))
ax.set_ylabel('portions')
ax.set_title('Optimal distribution ($\lambda={}$)'.format(lam))

在这里插入图片描述

可以看到每个人分配得到的甜点基本上都符合初始甜点的分布比例 c = [ 4 , 2 , 6 , 4 , 4 ] ⊤ \mathbf{c}=[4, 2, 6, 4, 4]^{\top} c=[4,2,6,4,4]

试着调大 λ \lambda λ
在这里插入图片描述
可以看到最终的分配向每个人的偏好靠拢了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/192395.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

4.6 QR分解二:Householder变换

1 Householder reflector Householder反射是这样子的(图片来自瑞典皇家理工学院):   图中u是长度为1的向量。x是任意向量,H是u的Householder reflector。可见无论x是什么向量,HxHxHx始终除于和u正交的平面上。H和u的关系是: HI…

【z-library平替】Clibrary中文图书馆,电子书大全

目录1、z-library和Clibrary简介2、Clibrary网址3、具体操作界面1、z-library和Clibrary简介 喜欢阅读的盆友多多少少可能都听过z-library,书籍库非常全,而且是免费的,但是在z-library国内下线后,就一直没有找到合适的平替书库。 …

【vue2】vuex超超超级详解!(核心五大配置项)

🥳博 主:初映CY的前说(前端领域) 🌞个人信条:想要变成得到,中间还有做到! 🤘本文核心:vuex基础认识、state、getters、mutations actions、modules使用 目录(文末原素材) 一、…

新年找工作?python带你批量采集招聘数据

前言 大家早好、午好、晚好吖 ❤ ~ 必备素材: stealth.min.js 谷歌浏览器谷歌驱动selenium3.141.0 不知道怎么弄嘚同学可以私我获取哦~ 开发环境: python 3.8 pycharm 专业版 操作步骤 selenium 模块: 操作浏览器 打开一个浏览器 打开一个网址 获取数据 保存数据 …

性能测试工具-nmon

nmon 文章目录nmon介绍下载Linux系统服务器在服务器上新建nmon文件夹将下载文件上传到服务器新建的文件夹内修改文件名启动nmon启动nmon命令行使用nomn_analyser对监控结果进行分析图表分析nmon 主要用来做性能测试时对服务器的监控 捕捉各类系统资源的使用情况,并…

Maven实战-2.pom.xml标签说明

前言 持续更新中… pom.xml文件 1.<project> 这是pom.xml的根元素&#xff0c;所有的标签都包含在<project>…</project>之间。 2.<modelVersion> 指定当前POM模型的版本&#xff0c;对于maven2和maven3来说&#xff0c;它只能是4.0.0 <mode…

【linux】剖析底层——带你详细了解Linux内核源码的构成及其作用(1)

目录 一、arch文件 1.作用 2.arch文件下的子文件示意图 3.各个子文件的作用 &#xff08;1&#xff09;alpha &#xff08;2&#xff09;arc &#xff08;3&#xff09;arm &#xff08;4&#xff09;arm64 &#xff08;5&#xff09;cshy &#xff08;6&#xff09;…

8 加载数据集

文章目录前提知识了解数据集Mini-Batch常用术语DataLoader核心参数核心功能小tips课程代码实例课程来源&#xff1a; 链接课程文本部分来源&#xff08;参考&#xff09;&#xff1a; 链接以及&#xff08;强烈推荐&#xff09; Birandaの前提知识了解 enumerate函数 数据集 …

局域网中UTP连接,如何实现防止芯片损坏,防止信号产生各种误码,及实现CHIP之间的阻抗匹配

Hqst盈盛电子导读&#xff1a;局域网中UTP连接&#xff0c;如何实现防止芯片损坏&#xff0c;防止信号产生各种误码&#xff0c;及实现CHIP之间的阻抗匹配&#xff0c;浅谈网络滤波器作用一&#xff0c;在有线局域网中&#xff0c;计算机与服务器之间&#xff0c;计算机与路由器…

10、条件语句

目录 一、if语句的基本形式 1. if语句形式 2. if…else语句形式 3. else if语句形式 二、if的嵌套形式 三、条件运算符 四、switch语句 1. switch语句的基本形式 2. 多路开关模式的switch语句 一、if语句的基本形式 在if语句中&#xff0c;首先判断表达式的值&#x…

【BetterBench】2023年美赛辅导

通知 2023年美赛快开始啦&#xff0c;提醒大家比赛信息&#xff0c;比赛期间我会全称提供辅导&#xff0c;包括建模方案、实现代码&#xff01; 可以参考往年所有建模比赛&#xff0c;本人开源的建模方案及实现代码 2020-2023年所有数学建模竞赛专栏 报名信息 1.辅助报名截止…

【异常】前端Babel提示 Support for the experimental syntax ‘jsx‘ isn‘t currently enabled

一、报错内容 17:33:41 - Building for production... 17:34:13 ERROR Failed to compile with 5 errors5:34:09 PM 17:34:13 17:34:13 error in ./src/layout/components/Sidebar/Item.vue?vue&typescript&langjs& 17:34:13 17:34:13 Syntax Error…

《流浪地球2》看不懂?根服务器、权威解析,专业科普来了

随着《流浪地球2》的上映&#xff0c;关于国产硬科幻电影的话题也火爆起来&#xff0c;片中各种脑洞大开&#xff0c;科技设定可圈可点&#xff0c;例如量子计算机、脑机接口、太空电梯等。从专业角度来看&#xff0c;作为国产科幻大片之光的《流浪地球2》为了保证真实性确实狠…

二叉平衡树 之 红黑树 (手动模拟实现)

目录 1、红黑树的概念 2、红黑树的性质 3、红黑树节点的定义 4、红黑树的插入 5、红黑树验证 代码汇总 6、红黑树的删除&#xff08;了解&#xff09; 7、红黑树的应用 8、红黑树 VS AVL树 1、红黑树的概念 红黑树&#xff0c;就是一种特殊的二叉搜索树&#xff0c;每个…

MySQL详解(四)——高级 2.0

性能分析 Explain 使用EXPLAIN关键字可以模拟优化器&#xff08;不改变查询结果前提下&#xff0c;调整查询顺序&#xff0c;生成执行计划&#xff09;执行SQL查询语句&#xff0c;从而知道MySQL是如何处理你的SQL语句的。分析你的查询语句或是表结构的性能瓶颈 功能&#x…

ECharts线性渐变色示例演示(2种渐变方式)

第003个点击查看专栏目录Echarts的渐变色采用了echarts.graphic.LinearGradient的方法&#xff0c;可以根据代码中的内容来看如何使用。线性渐变&#xff0c;多用于折线柱形图&#xff0c;前四个参数分别是 x0, y0, x2, y2, 范围从 0 - 1&#xff0c;相当于在图形包围盒中的百分…

PTA L1-025 正整数A+B(详解)

前言&#xff1a;本期是关于正整数AB的详解&#xff0c;内容包括四大模块&#xff1a;题目&#xff0c;代码实现&#xff0c;大致思路&#xff0c;代码解读&#xff0c;今天你c了吗&#xff1f; 题目&#xff1a; 题的目标很简单&#xff0c;就是求两个正整数A和B的和&#xf…

用户使用苹果AirTag来追踪宠物存在风险,苹果Find My功能用处广

苹果的 AirTag 不失为追踪宠物的一种便捷方式&#xff0c;这样宠物即便挣脱宠物圈或者其它方式丢失&#xff0c;都可以通过“Find My”方式追踪定位。正如《华尔街日报》所指出的&#xff0c;这种方式也存在 AirTag 被宠物吞食的风险。 AirTag 的直径为 1.26 英寸&#xff0c…

【Faster R-CNN】之 Resize_and_Padding 代码精读

【Faster R-CNN】之 Resize_and_Padding1、前言&#xff1a;2、resize_image_and_bbox1&#xff09;先对图像做resize处理2&#xff09;再对 bounding box 做resize处理3、padding_images代码1、前言&#xff1a; 在上一篇文章 【Faster R-CNN】之 Dataset and Dataloader 代码…

Linux网络:传输层之UDPTCP协议

文章目录一、端口号1.端口号范围划分2.常用命令二、UDP 协议1.格式2.特点3. UDP 的缓冲区4. UDP 使用注意事项5.基于 UDP 的应用层协议三、TCP 协议1.格式2.确认应答机制3.超时重传机制4.连接管理机制三次握手四次挥手5.滑动窗口6.流量控制7.拥塞控制8.延迟应答9.捎带应答10.面…